/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.lda;

import com.google.common.base.Charsets;
import com.google.common.io.Files;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.io.Writer;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.utils.vectors.VectorHelper;

/**
 * Command-line utility that prints, or writes to one file per topic, the highest-scoring words
 * of each topic in an LDA output state.
 *
 * <p>The state is read from the {@code part-*} sequence files under the input directory. Each
 * record key is an {@link IntPairWritable} of (topic, word index) and each value is a
 * {@link DoubleWritable} score. Word indices are mapped to words through a dictionary that is
 * either a text file or a sequence file.
 */
public final class LDAPrintTopics {
    private LDAPrintTopics() {
    }

    /**
     * Appends empty queues to {@code queues} until index {@code k} is valid, i.e. until it holds
     * at least {@code k + 1} queues. Does nothing for a negative {@code k}.
     */
    private static void ensureQueueSize(Collection<PriorityQueue<StringDoublePair>> queues, int k) {
        for (int i = queues.size(); i <= k; ++i) {
            queues.add(new PriorityQueue<StringDoublePair>());
        }
    }

    /**
     * Entry point.
     *
     * <p>Options: {@code --input/-i} (required, LDA state directory), {@code --dict/-d} (required,
     * dictionary), {@code --output/-o} (optional directory; when absent results go to stdout),
     * {@code --words/-w} (words per topic, default 20), {@code --dictionaryType/-dt}
     * ({@code text} (default) or {@code sequencefile}) and {@code --help/-h}.
     *
     * @throws OptionException if the command line cannot be parsed (help is printed first)
     * @throws IllegalArgumentException if the dictionary type is not recognised
     * @throws IOException if the output directory cannot be created or written
     */
    public static void main(String[] args) throws Exception {
        DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
        ArgumentBuilder abuilder = new ArgumentBuilder();
        GroupBuilder gbuilder = new GroupBuilder();
        DefaultOption inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription("Path to an LDA output (a state)").withShortName("i").create();
        DefaultOption dictOpt = obuilder.withLongName("dict").withRequired(true).withArgument(abuilder.withName("dict").withMinimum(1).withMaximum(1).create()).withDescription("Dictionary to read in, in the same format as one created by org.apache.mahout.utils.vectors.lucene.Driver").withShortName("d").create();
        DefaultOption outOpt = obuilder.withLongName("output").withRequired(false).withArgument(abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription("Output directory to write top words").withShortName("o").create();
        DefaultOption wordOpt = obuilder.withLongName("words").withRequired(false).withArgument(abuilder.withName("words").withMinimum(0).withMaximum(1).withDefault("20").create()).withDescription("Number of words to print").withShortName("w").create();
        DefaultOption dictTypeOpt = obuilder.withLongName("dictionaryType").withRequired(false).withArgument(abuilder.withName("dictionaryType").withMinimum(1).withMaximum(1).create()).withDescription("The dictionary file type (text|sequencefile)").withShortName("dt").create();
        DefaultOption helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
        // helpOpt must be part of the group, otherwise "-h" is rejected as an unexpected token.
        Group group = gbuilder.withName("Options").withOption(dictOpt).withOption(outOpt).withOption(wordOpt).withOption(inputOpt).withOption(dictTypeOpt).withOption(helpOpt).create();
        try {
            Parser parser = new Parser();
            parser.setGroup(group);
            // Lets "-h" alone succeed without tripping validation of the required options.
            parser.setHelpOption(helpOpt);
            CommandLine cmdLine = parser.parse(args);
            if (cmdLine.hasOption(helpOpt)) {
                CommandLineUtil.printHelp(group);
                return;
            }
            String input = cmdLine.getValue(inputOpt).toString();
            String dictFile = cmdLine.getValue(dictOpt).toString();
            int numWords = 20;
            if (cmdLine.hasOption(wordOpt)) {
                numWords = Integer.parseInt(cmdLine.getValue(wordOpt).toString());
            }
            Configuration config = new Configuration();
            String dictionaryType = "text";
            if (cmdLine.hasOption(dictTypeOpt)) {
                dictionaryType = cmdLine.getValue(dictTypeOpt).toString();
            }
            List<String> wordList;
            if ("text".equals(dictionaryType)) {
                wordList = Arrays.asList(VectorHelper.loadTermDictionary(new File(dictFile)));
            } else if ("sequencefile".equals(dictionaryType)) {
                wordList = Arrays.asList(VectorHelper.loadTermDictionary(config, dictFile));
            } else {
                throw new IllegalArgumentException("Invalid dictionary format");
            }
            List<List<String>> topWords = topWordsForTopics(input, config, wordList, numWords);
            if (cmdLine.hasOption(outOpt)) {
                File output = new File(cmdLine.getValue(outOpt).toString());
                if (!output.exists() && !output.mkdirs()) {
                    throw new IOException("Could not create directory: " + output);
                }
                writeTopWords(topWords, output);
            } else {
                printTopWords(topWords);
            }
        } catch (OptionException e) {
            CommandLineUtil.printHelp(group);
            throw e;
        }
    }

    /**
     * Offers {@code word} to the bounded queue {@code q}. The queue keeps at most
     * {@code numWordsToPrint} entries: the highest scores seen so far. Its head is the lowest
     * retained score. Ties with that lowest score do not evict it. A non-positive
     * {@code numWordsToPrint} leaves the queue untouched.
     */
    private static void maybeEnqueue(Queue<StringDoublePair> q, String word, double score, int numWordsToPrint) {
        if (q.size() < numWordsToPrint) {
            q.add(new StringDoublePair(score, word));
        } else if (!q.isEmpty() && score > q.peek().score) {
            // Evict the current lowest score to make room for a better one.
            q.poll();
            q.add(new StringDoublePair(score, word));
        }
    }

    /** Prints each topic's word list to standard output under a "Topic i" header. */
    private static void printTopWords(List<List<String>> topWords) {
        for (int i = 0; i < topWords.size(); ++i) {
            List<String> topK = topWords.get(i);
            System.out.println("Topic " + i);
            System.out.println("===========");
            for (String word : topK) {
                System.out.println(word);
            }
        }
    }

    /**
     * Reads the LDA state under {@code dir} and returns, for every topic index seen, up to
     * {@code numWordsToPrint} words ordered by descending score.
     *
     * <p>Records with a negative topic or word index are skipped. A topic whose records all have
     * negative word indices still gets an (empty) entry, so list positions match topic ids.
     *
     * @throws IndexOutOfBoundsException if a word index is not covered by {@code wordList}
     */
    private static List<List<String>> topWordsForTopics(String dir, Configuration job, List<String> wordList, int numWordsToPrint) {
        List<PriorityQueue<StringDoublePair>> queues = new ArrayList<PriorityQueue<StringDoublePair>>();
        // reuseKeysValues=true: key/value objects are recycled, so only primitives are retained below.
        for (Pair<IntPairWritable, DoubleWritable> record
                : new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(new Path(dir, "part-*"), PathType.GLOB, null, null, true, job)) {
            IntPairWritable key = record.getFirst();
            int topic = key.getFirst();
            int word = key.getSecond();
            ensureQueueSize(queues, topic);
            if (word < 0 || topic < 0) {
                continue;
            }
            double score = record.getSecond().get();
            String realWord = wordList.get(word);
            maybeEnqueue(queues.get(topic), realWord, score, numWordsToPrint);
        }
        List<List<String>> result = new ArrayList<List<String>>(queues.size());
        for (PriorityQueue<StringDoublePair> queue : queues) {
            LinkedList<String> words = new LinkedList<String>();
            // A PriorityQueue iterator does not visit elements in priority order; poll() does
            // (lowest score first), so prepending each polled word yields descending score order.
            while (!queue.isEmpty()) {
                words.addFirst(queue.poll().word);
            }
            result.add(words);
        }
        return result;
    }

    /**
     * Writes each topic's word list as UTF-8 to {@code output/topic-<i>}, one word per line,
     * after a "Topic i" header.
     *
     * @throws IOException if a file cannot be opened or written
     */
    private static void writeTopWords(List<List<String>> topWords, File output) throws IOException {
        for (int i = 0; i < topWords.size(); ++i) {
            List<String> topK = topWords.get(i);
            Writer writer = Files.newWriter(new File(output, "topic-" + i), Charsets.UTF_8);
            try {
                writer.write("Topic " + i + '\n');
                writer.write("===========\n");
                for (String word : topK) {
                    writer.write(word + '\n');
                }
            } finally {
                writer.close();
            }
        }
    }

    /**
     * A (score, word) pair. Natural ordering is by score only, so it is NOT consistent with
     * {@link #equals}. That is sufficient for the priority queues above, which only need
     * ordering by score.
     */
    private static class StringDoublePair implements Comparable<StringDoublePair> {
        private final double score;
        private final String word;

        StringDoublePair(double score, String word) {
            this.score = score;
            this.word = word;
        }

        @Override
        public int compareTo(StringDoublePair other) {
            return Double.compare(this.score, other.score);
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof StringDoublePair)) {
                return false;
            }
            StringDoublePair other = (StringDoublePair) o;
            // Double.compare (not ==) keeps equals consistent with the doubleToLongBits-based hashCode.
            return Double.compare(this.score, other.score) == 0
                && (this.word == null ? other.word == null : this.word.equals(other.word));
        }

        @Override
        public int hashCode() {
            long bits = Double.doubleToLongBits(this.score);
            int scoreHash = (int) (bits ^ (bits >>> 32));
            return scoreHash ^ (this.word == null ? 0 : this.word.hashCode());
        }
    }
}

