/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.PCFGLA;

import edu.berkeley.nlp.PCFGLA.ArrayParser;
import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.Corpus;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.GrammarTrainer;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.SophisticatedLexicon;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.CommandLineUtils;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.PriorityQueue;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * Command-line tool and helpers that merge latent substates of a previously trained grammar.
 *
 * <p>The pipeline in {@link #main} is:
 * <ol>
 *   <li>compute merge weights ({@link #computeMergeWeights}) and pairwise merge costs ({@link #computeDeltas});</li>
 *   <li>pick sibling substate pairs whose cost is at or below a quantile threshold ({@link #determineMergePairs});</li>
 *   <li>apply those merges to grammar and lexicon ({@link #doTheMerges});</li>
 *   <li>run EM iterations on the merged grammar, keeping the grammar with the best validation log likelihood;
 *       training stops after two consecutive non-improving iterations once {@code -minIt} has been reached,
 *       or after {@code -maxIt} iterations.</li>
 * </ol>
 */
public class GrammarMerger {
    /** Command-line help text, printed when arguments are missing. */
    private static final String USAGE = "usage: java GrammarMerger \n"
            + "\t\t  -i       Input File for Grammar (Required)\n"
            + "\t\t  -o       Output File for Merged Grammar (Required)\n"
            + "\t\t  -p       Merging percentage (Default: 0.5)\n"
            + "\t\t  -2p      Merging percentage for non-siblings (Default: 0.0)\n"
            + "\t\t  -top     Keep top N substates, overrides -p!\n"
            + "\t\t  -path    Path to Corpus (Default: null)\n"
            + "\t\t  -chsh    If this is enabled, then we train on a short segment of\n"
            + "\t\t           the Chinese treebank (Default: false)\n"
            + "\t\t  -trfr    The fraction of the training corpus to keep (Default: 1.0)\n"
            + "\t\t  -maxIt   Maximum number of EM iterations (Default: 100)\n"
            + "\t\t  -minIt   Minimum number of EM iterations (Default: 0)\n"
            + "\t\t  -f       Filter rules with prob under f (Default: -1)\n"
            + "\t\t  -dL      Delete labels? (true/false) (Default: false)\n"
            + "\t\t  -ent     Use Entropic prior (Default: false)\n"
            + "\t\t  -maxL    Maximum sentence length (Default: 10000)\n"
            + "\t\t  -sep     Set merging threshold for grammar and lexicon separately (Default: false)";

    /**
     * Entry point: loads a grammar, merges substates, re-estimates with EM and saves the result.
     *
     * <p>Checkpoints of the current EM iterate are written every 5 iterations to {@code <out>-it-<iter>};
     * the best grammar by validation log likelihood is written to {@code <out>}.
     *
     * @param args command-line options, see {@link #USAGE}
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.out.println(USAGE);
            System.exit(2);
        }
        System.out.print("Running with arguments:  ");
        for (String arg : args) {
            System.out.print(" '" + arg + "'");
        }
        System.out.println("");
        Map<String, String> input = CommandLineUtils.simpleCommandLineParser(args);
        double mergingPercentage = Double.parseDouble(CommandLineUtils.getValueOrUseDefault(input, "-p", "0.5"));
        // NOTE(review): -2p, -top, -dL and -chsh are listed in USAGE but are not consulted by this class.
        String outFileName = CommandLineUtils.getValueOrUseDefault(input, "-o", null);
        String inFileName = CommandLineUtils.getValueOrUseDefault(input, "-i", null);
        if (inFileName == null || outFileName == null) {
            // Both are documented as required; without -o the grammar would be saved to a file named "null".
            System.out.println("Both -i and -o are required.");
            System.out.println(USAGE);
            System.exit(2);
        }
        System.out.println("Loading grammar from " + inFileName + ".");
        ParserData pData = ParserData.Load(inFileName);
        if (pData == null) {
            System.out.println("Failed to load grammar from file " + inFileName + ".");
            System.exit(1);
        }
        int minIterations = Integer.parseInt(CommandLineUtils.getValueOrUseDefault(input, "-minIt", "0"));
        if (minIterations > 0) {
            System.out.println("I will do at least " + minIterations + " iterations.");
        }
        boolean separateMerge = CommandLineUtils.getValueOrUseDefault(input, "-sep", "").equals("true");
        int maxIterations = Integer.parseInt(CommandLineUtils.getValueOrUseDefault(input, "-maxIt", "100"));
        if (maxIterations > 0) {
            System.out.println("But at most " + maxIterations + " iterations.");
        }
        boolean useEntropicPrior = CommandLineUtils.getValueOrUseDefault(input, "-ent", "").equals("true");
        int maxSentenceLength = Integer.parseInt(CommandLineUtils.getValueOrUseDefault(input, "-maxL", "10000"));
        System.out.println("Will remove sentences with more than " + maxSentenceLength + " words.");
        String path = CommandLineUtils.getValueOrUseDefault(input, "-path", null);
        double trainingFractionToKeep = Double.parseDouble(CommandLineUtils.getValueOrUseDefault(input, "-trfr", "1.0"));

        Grammar grammar = pData.getGrammar();
        Lexicon lexicon = pData.getLexicon();
        Numberer.setNumberers(pData.getNumbs());
        int h_markov = pData.h_markov;
        int v_markov = pData.v_markov;
        Binarization bin = pData.bin;
        short[] numSubStatesArray = pData.numSubStatesArray;
        Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
        double filter = Double.parseDouble(CommandLineUtils.getValueOrUseDefault(input, "-f", "-1"));
        if (filter > 0.0) {
            System.out.println("Will remove rules with prob under " + filter);
        }

        Corpus corpus = new Corpus(path, Corpus.TreeBankType.WSJ, trainingFractionToKeep, false);
        List<Tree<String>> trainTrees = Corpus.binarizeAndFilterTrees(corpus.getTrainTrees(), v_markov, h_markov, maxSentenceLength, bin, false, false);
        List<Tree<String>> validationTrees = Corpus.binarizeAndFilterTrees(corpus.getValidationTrees(), v_markov, h_markov, maxSentenceLength, bin, false, false);
        int nTrees = trainTrees.size();
        System.out.println("There are " + nTrees + " trees in the training set.");
        StateSetTreeList trainStateSetTrees = new StateSetTreeList(trainTrees, numSubStatesArray, false, tagNumberer);
        StateSetTreeList validationStateSetTrees = new StateSetTreeList(validationTrees, numSubStatesArray, false, tagNumberer);

        // Decide on and apply the merges.
        double[][] mergeWeights = computeMergeWeights(grammar, lexicon, trainStateSetTrees);
        double[][][] deltas = computeDeltas(grammar, lexicon, mergeWeights, trainStateSetTrees);
        boolean[][][] mergeThesePairs = determineMergePairs(deltas, separateMerge, mergingPercentage, grammar);
        Grammar newGrammar = doTheMerges(grammar, lexicon, mergeThesePairs, mergeWeights);
        printMergingStatistics(grammar, newGrammar);

        // Rebuild the tree annotations for the reduced substate counts.
        short[] newNumSubStatesArray = newGrammar.numSubStates;
        trainStateSetTrees = new StateSetTreeList(trainTrees, newNumSubStatesArray, false, tagNumberer);
        validationStateSetTrees = new StateSetTreeList(validationTrees, newNumSubStatesArray, false, tagNumberer);

        System.out.println("completing lexicon merge");
        SophisticatedLexicon newLexicon = new SophisticatedLexicon(newNumSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, lexicon.getSmoothingParams(), lexicon.getSmoother(), filter);
        boolean updateOnlyLexicon = true;
        double trainingLikelihood = GrammarTrainer.doOneEStep(newGrammar, lexicon, null, newLexicon, trainStateSetTrees, updateOnlyLexicon, 4);
        System.out.println("The training LL is " + trainingLikelihood);
        newLexicon.optimize();

        System.out.println("Doing some iterations of EM to clean things up...");
        double maxLikelihood = Double.NEGATIVE_INFINITY;
        int droppingIter = 0;
        int iter = 0;
        while (droppingIter < 2 && iter < maxIterations) {
            ++iter;
            SophisticatedLexicon previousLexicon = newLexicon;
            Grammar previousGrammar = newGrammar;
            ArrayParser trainParser = new ArrayParser(previousGrammar, previousLexicon);
            newLexicon = new SophisticatedLexicon(newNumSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, lexicon.getSmoothingParams(), lexicon.getSmoother(), filter);
            newGrammar = new Grammar(newNumSubStatesArray, grammar.findClosedPaths, grammar.smoother, grammar, filter);
            if (useEntropicPrior) {
                // NOTE(review): the flag is set on `grammar` (the current best / template), not on `newGrammar`,
                // which is the instance optimized below; confirm which one Grammar.optimize consults.
                grammar.useEntropicPrior = true;
            }

            // E-step over the training trees using the previous iterate.
            int n = 0;
            trainingLikelihood = 0.0;
            for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
                boolean secondHalf = (double) n++ > (double) nTrees / 2.0;
                trainParser.doInsideOutsideScores(stateSetTree, false, false);
                double ll = logLikelihood(stateSetTree);
                if (Double.isInfinite(ll) || Double.isNaN(ll)) {
                    System.out.println("Training sentence " + n + " is given " + ll + " log likelihood!");
                    GrammarTrainer.printBadLLReason(stateSetTree, previousLexicon);
                    continue;
                }
                trainingLikelihood += ll;
                newGrammar.tallyStateSetTree(stateSetTree, previousGrammar);
                newLexicon.trainTree(stateSetTree, -1.0, previousLexicon, secondHalf, false, 4);
            }
            System.out.println("The training LL is " + trainingLikelihood);

            // M-step, then score the new iterate on held-out data.
            newLexicon.optimize();
            newGrammar.optimize(0.0);
            double validationLikelihood = validationLogLikelihood(new ArrayParser(newGrammar, newLexicon), validationStateSetTrees);
            // `iter` was already incremented at the top of the loop, so it is the number of completed iterations.
            System.out.println("The validation LL after merging and " + iter + " iterations is " + validationLikelihood);

            if (iter < minIterations || validationLikelihood > maxLikelihood) {
                // Before minIterations every iterate is accepted; afterwards only improvements are.
                maxLikelihood = Math.max(validationLikelihood, maxLikelihood);
                grammar = newGrammar;
                lexicon = newLexicon;
                droppingIter = 0;
            } else {
                ++droppingIter;
            }

            if (iter % 5 == 0) {
                // The checkpoint holds the current iterate, so report its own score, not the best-so-far.
                ParserData checkpoint = new ParserData(newLexicon, newGrammar, null, Numberer.getNumberers(), newNumSubStatesArray, v_markov, h_markov, bin);
                String checkpointName = outFileName + "-it-" + iter;
                System.out.println("Saving grammar to " + checkpointName + ".");
                System.out.println("It gives a validation data log likelihood of: " + validationLikelihood);
                if (checkpoint.Save(checkpointName)) {
                    System.out.println("Saving successful");
                } else {
                    System.out.println("Saving failed!");
                }
            }
        }

        if (iter == 0) {
            // No EM iteration ran (e.g. -maxIt 0): `grammar` is still the unmerged input grammar,
            // which does not match newNumSubStatesArray. Save the merged grammar/lexicon instead.
            grammar = newGrammar;
            lexicon = newLexicon;
        }
        System.out.println("Saving grammar to " + outFileName + ".");
        System.out.println("It gives a validation data log likelihood of: " + maxLikelihood);
        ParserData newPData = new ParserData(lexicon, grammar, null, Numberer.getNumberers(), newNumSubStatesArray, v_markov, h_markov, bin);
        if (newPData.Save(outFileName)) {
            System.out.println("Saving successful.");
        } else {
            System.out.println("Saving failed!");
        }
        System.exit(0);
    }

    /**
     * Returns the root log likelihood of a tree whose inside scores have already been computed:
     * {@code log(iScore(0)) + 100 * iScale} of the root label.
     */
    private static double logLikelihood(Tree<StateSet> tree) {
        StateSet root = tree.getLabel();
        return Math.log(root.getIScore(0)) + 100 * root.getIScale();
    }

    /**
     * Sums the log likelihood of all trees under {@code parser}, skipping (and reporting) trees whose
     * log likelihood is infinite or NaN.
     */
    private static double validationLogLikelihood(ArrayParser parser, StateSetTreeList trees) {
        double total = 0.0;
        int n = 0;
        for (Tree<StateSet> stateSetTree : trees) {
            ++n;
            parser.doInsideScores(stateSetTree, false, false, null);
            double ll = logLikelihood(stateSetTree);
            if (Double.isInfinite(ll) || Double.isNaN(ll)) {
                System.out.println("Validation sentence " + n + " is given " + ll + " log likelihood!");
                continue;
            }
            total += ll;
        }
        return total;
    }

    /**
     * Prints, for every state, its substate count before and after merging, followed by the per-state
     * new counts grouped into lexicon (non-grammar) tags and grammar tags.
     *
     * @param grammar    the grammar before merging
     * @param newGrammar the grammar after merging
     */
    public static void printMergingStatistics(Grammar grammar, Grammar newGrammar) {
        PriorityQueue<String> lexiconStates = new PriorityQueue<String>();
        PriorityQueue<String> grammarStates = new PriorityQueue<String>();
        short[] oldCounts = grammar.numSubStates;
        short[] newCounts = newGrammar.numSubStates;
        Numberer tagNumberer = grammar.tagNumberer;
        for (int state = 0; state < oldCounts.length; ++state) {
            Object tag = tagNumberer.object(state);
            System.out.print("\nState " + tag + " had " + oldCounts[state] + " substates and now has " + newCounts[state] + ".");
            PriorityQueue<String> bucket = grammar.isGrammarTag(state) ? grammarStates : lexiconStates;
            bucket.add((String) tag, newCounts[state]);
        }
        System.out.print("\n");
        System.out.println("Lexicon: " + lexiconStates.toString());
        System.out.println("Grammar: " + grammarStates.toString());
    }

    /**
     * Applies the requested merges to the grammar and (in place) to the lexicon.
     *
     * <p>Merges are applied in rounds: each round takes, per state, a set of requested pairs in which no
     * substate appears twice, merges them, lets {@code Grammar.fixMergeWeightsEtc} update the bookkeeping
     * arrays, and repeats until no requested pair remains.
     *
     * @param grammar         grammar to merge (a new merged instance is produced each round)
     * @param lexicon         lexicon, merged in place via {@code Lexicon.mergeStates}
     * @param mergeThesePairs requested pairs {@code [state][sub1][sub2]}; consumed (cleared) as merges happen
     * @param mergeWeights    merge weights as produced by {@link #computeMergeWeights}
     * @return the fully merged grammar, after {@code makeCRArrays()}
     */
    public static Grammar doTheMerges(Grammar grammar, Lexicon lexicon, boolean[][][] mergeThesePairs, double[][] mergeWeights) {
        int numStates = grammar.numSubStates.length;
        short[] currentNumSubStates = grammar.numSubStates;
        while (hasPendingMerge(mergeThesePairs, numStates, currentNumSubStates)) {
            // Start from a copy of all pending requests ...
            boolean[][][] mergeThisIteration = new boolean[currentNumSubStates.length][][];
            for (int tag = 0; tag < numStates; ++tag) {
                int size = mergeThesePairs[tag].length;
                mergeThisIteration[tag] = new boolean[size][size];
                for (int i = 0; i < size; ++i) {
                    System.arraycopy(mergeThesePairs[tag][i], 0, mergeThisIteration[tag][i], 0, size);
                }
            }
            // ... and drop any pair touching a substate already involved in an earlier pending pair,
            // so that each substate takes part in at most one merge this round.
            for (int tag = 0; tag < numStates; ++tag) {
                boolean[][] pending = mergeThesePairs[tag];
                int size = pending.length;
                boolean[] alreadyDecidedToMerge = new boolean[size];
                for (int i = 0; i < size; ++i) {
                    for (int j = 0; j < size; ++j) {
                        if (alreadyDecidedToMerge[i] || alreadyDecidedToMerge[j]) {
                            mergeThisIteration[tag][i][j] = false;
                        }
                        if (pending[i][j]) {
                            alreadyDecidedToMerge[i] = true;
                            alreadyDecidedToMerge[j] = true;
                        }
                    }
                }
            }
            // Remove the pairs handled this round from the pending set.
            for (int tag = 0; tag < numStates; ++tag) {
                int size = mergeThesePairs[tag].length;
                for (int i = 0; i < size; ++i) {
                    for (int j = 0; j < size; ++j) {
                        if (mergeThisIteration[tag][i][j]) {
                            mergeThesePairs[tag][i][j] = false;
                        }
                    }
                }
            }
            Grammar mergedGrammar = grammar.mergeStates(mergeThisIteration, mergeWeights);
            lexicon.mergeStates(mergeThisIteration, mergeWeights);
            grammar.fixMergeWeightsEtc(mergeThesePairs, mergeWeights, mergeThisIteration);
            grammar = mergedGrammar;
            currentNumSubStates = grammar.numSubStates;
        }
        grammar.makeCRArrays();
        return grammar;
    }

    /** Returns true if any {@code mergeThesePairs[tag][i][j]} is set within the current substate counts. */
    private static boolean hasPendingMerge(boolean[][][] mergeThesePairs, int numStates, short[] numSubStates) {
        for (int tag = 0; tag < numStates; ++tag) {
            for (int i = 0; i < numSubStates[tag]; ++i) {
                for (int j = 0; j < numSubStates[tag]; ++j) {
                    if (mergeThesePairs[tag][i][j]) {
                        return true;
                    }
                }
            }
        }
        return false;
    }

    /**
     * Accumulates pairwise merge costs over the training trees.
     *
     * <p>Trees whose log likelihood is infinite or NaN are skipped, since tallying them would contaminate
     * every accumulated value.
     *
     * @return deltas indexed {@code [state][sub1][sub2]}, filled by {@code Grammar.tallyMergeScores}
     */
    public static double[][][] computeDeltas(Grammar grammar, Lexicon lexicon, double[][] mergeWeights, StateSetTreeList trainStateSetTrees) {
        ArrayParser parser = new ArrayParser(grammar, lexicon);
        int maxSubStates = mergeWeights[0].length;
        double[][][] deltas = new double[grammar.numSubStates.length][maxSubStates][maxSubStates];
        for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
            parser.doInsideOutsideScores(stateSetTree, false, false);
            double ll = logLikelihood(stateSetTree);
            if (Double.isInfinite(ll) || Double.isNaN(ll)) {
                continue;
            }
            grammar.tallyMergeScores(stateSetTree, deltas, mergeWeights);
        }
        return deltas;
    }

    /**
     * Accumulates and normalizes per-substate merge weights over the training trees, printing the total
     * training log likelihood. Trees with infinite or NaN log likelihood are reported and skipped.
     *
     * @return weights indexed {@code [state][substate]}, normalized by {@code Grammar.normalizeMergeWeights}
     */
    public static double[][] computeMergeWeights(Grammar grammar, Lexicon lexicon, StateSetTreeList trainStateSetTrees) {
        double[][] mergeWeights = new double[grammar.numSubStates.length][(int) ArrayUtil.max(grammar.numSubStates)];
        ArrayParser parser = new ArrayParser(grammar, lexicon);
        double trainingLikelihood = 0.0;
        int n = 0;
        for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
            parser.doInsideOutsideScores(stateSetTree, false, false);
            double ll = logLikelihood(stateSetTree);
            if (Double.isInfinite(ll) || Double.isNaN(ll)) {
                System.out.println("Training sentence " + n + " is given " + ll + " log likelihood!");
            } else {
                trainingLikelihood += ll;
                grammar.tallyMergeWeights(stateSetTree, mergeWeights);
            }
            ++n;
        }
        System.out.println("The trainings LL before merging is " + trainingLikelihood);
        grammar.normalizeMergeWeights(mergeWeights);
        return mergeWeights;
    }

    /**
     * Chooses which sibling substate pairs {@code (2k, 2k+1)} to merge.
     *
     * <p>All non-zero sibling deltas are sorted ascending, and every pair whose delta is at or below the
     * value at position {@code count * mergingPercentage} is selected. With {@code separateMerge}, grammar
     * tags and lexicon tags get separate thresholds, the lexicon one using {@code 1.5 * mergingPercentage}.
     * Positions past the end are clamped to the largest delta; only sibling pairs are ever selected.
     *
     * @return a {@code [state][sub1][sub2]} matrix with {@code true} at each selected {@code (2k, 2k+1)}
     */
    public static boolean[][][] determineMergePairs(double[][][] deltas, boolean separateMerge, double mergingPercentage, Grammar grammar) {
        short[] numSubStatesArray = grammar.numSubStates;
        int numStates = numSubStatesArray.length;
        List<Double> deltaSiblings = new ArrayList<Double>();
        List<Double> deltaGrammar = new ArrayList<Double>();
        List<Double> deltaLexicon = new ArrayList<Double>();
        for (int state = 0; state < numStates; ++state) {
            for (int sub = 0; sub < numSubStatesArray[state] - 1; sub += 2) {
                double delta = deltas[state][sub][sub + 1];
                if (delta == 0.0) {
                    continue;
                }
                deltaSiblings.add(delta);
                if (separateMerge) {
                    if (grammar.isGrammarTag(state)) {
                        deltaGrammar.add(delta);
                    } else {
                        deltaLexicon.add(delta);
                    }
                }
            }
        }

        double threshold = -1.0;
        double thresholdGr = -1.0;
        double thresholdLex = -1.0;
        System.out.println("Going to merge " + (int) (mergingPercentage * 100.0) + "% of the substates siblings.");
        if (separateMerge) {
            System.out.println("Setting the merging threshold for lexicon and grammar separately.");
            Collections.sort(deltaGrammar);
            Collections.sort(deltaLexicon);
            thresholdGr = valueAtPosition(deltaGrammar, (double) deltaGrammar.size() * mergingPercentage);
            thresholdLex = valueAtPosition(deltaLexicon, (double) deltaLexicon.size() * mergingPercentage * 1.5);
            System.out.println("Setting the threshold for lexical siblings to " + thresholdLex);
            System.out.println("Setting the threshold for grammatical siblings to " + thresholdGr);
        } else {
            Collections.sort(deltaSiblings);
            threshold = valueAtPosition(deltaSiblings, (double) deltaSiblings.size() * mergingPercentage);
            System.out.println("Setting the threshold for siblings to " + threshold + ".");
        }

        boolean[][][] mergeThesePairs = new boolean[numStates][][];
        int mergeSiblings = 0;
        for (int state = 0; state < numStates; ++state) {
            int nSub = numSubStatesArray[state];
            mergeThesePairs[state] = new boolean[nSub][nSub];
            for (int sub = 0; sub < nSub - 1; sub += 2) {
                double delta = deltas[state][sub][sub + 1];
                if (delta == 0.0) {
                    continue;
                }
                double limit;
                if (!separateMerge) {
                    limit = threshold;
                } else if (grammar.isGrammarTag(state)) {
                    limit = thresholdGr;
                } else {
                    limit = thresholdLex;
                }
                if (delta <= limit) {
                    mergeThesePairs[state][sub][sub + 1] = true;
                    ++mergeSiblings;
                }
            }
        }
        // Non-sibling merging is not implemented, so the "other pairs" count is always 0.
        System.out.println("Merging " + mergeSiblings + " siblings and 0 other pairs.");

        for (int state = 0; state < deltas.length; ++state) {
            System.out.print("State " + grammar.tagNumberer.object(state));
            for (int i = 0; i < numSubStatesArray[state]; ++i) {
                for (int j = i + 1; j < numSubStatesArray[state]; ++j) {
                    if (mergeThesePairs[state][i][j]) {
                        System.out.print(". Merging pair (" + i + "," + j + ") at cost " + deltas[state][i][j]);
                    }
                }
            }
            System.out.print(".\n");
        }
        return mergeThesePairs;
    }

    /**
     * Returns {@code sorted.get((int) position)} with the index clamped to {@code [0, size - 1]}, or
     * negative infinity (select nothing) when the list is empty.
     */
    private static double valueAtPosition(List<Double> sorted, double position) {
        if (sorted.isEmpty()) {
            return Double.NEGATIVE_INFINITY;
        }
        int index = (int) position;
        index = Math.max(0, Math.min(index, sorted.size() - 1));
        return sorted.get(index);
    }
}

