/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.crf;

import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.crf.CRFObjectiveFunction;
import edu.berkeley.nlp.crf.Inference;
import edu.berkeley.nlp.crf.InstanceSequence;
import edu.berkeley.nlp.crf.LabeledInstanceSequence;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.LBFGSMinimizer;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Lists;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Pair;
import edu.berkeley.nlp.util.PriorityQueue;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A trained linear-chain CRF tagger that decodes label sequences for an
 * {@link InstanceSequence}.
 *
 * <p>Decoding delegates to an {@link Inference} engine, which builds a k-best chart
 * for a sequence under the weight vector {@code w}. Label indices read from that
 * chart are mapped back to label objects through the {@link Encoding}. Instances are
 * produced by {@link Factory#trainTagger(List)} or built directly from previously
 * trained components.
 *
 * <p>NOTE(review): the weight array is stored by reference, not copied. A caller that
 * mutates it afterwards changes this tagger's behavior. Serializing a tagger also
 * requires the encoding and inference objects to be serializable.
 *
 * @param <V> vertex (per-position) instance type
 * @param <E> edge instance type
 * @param <L> label type
 */
public class ChainCRFTagger<V, E, L>
implements Serializable {
    private static final long serialVersionUID = 9165167851374358823L;
    private final Encoding<?, L> encoding;
    private final Inference<V, E, ?, L> inf;
    private final double[] w;

    /**
     * Creates a tagger from already-trained components.
     *
     * @param encoding maps label indices used by the chart back to label objects
     * @param inf      inference engine used to build the k-best chart
     * @param w        weight vector handed to {@code inf}; stored without copying
     */
    public ChainCRFTagger(Encoding<?, L> encoding, Inference<V, E, ?, L> inf, double[] w) {
        this.encoding = encoding;
        this.inf = inf;
        this.w = w;
    }

    /**
     * Returns the single highest-ranked label sequence for {@code s}.
     *
     * @param s the sequence to tag
     * @return one label per position of {@code s}; an empty list when {@code s} is empty
     * @throws IllegalStateException if inference yields no candidate for a non-empty sequence
     */
    public List<L> getViterbiLabelSequence(InstanceSequence<V, E, L> s) {
        if (s.getSequenceLength() == 0) {
            // The only labeling of an empty sequence is the empty one.
            return new ArrayList<L>(0);
        }
        List<Pair<List<L>, Double>> best = this.getTopKLabelSequencesAndScores(s, 1);
        if (best.isEmpty()) {
            throw new IllegalStateException("Inference produced no candidate label sequence");
        }
        return best.get(0).getFirst();
    }

    /**
     * Returns up to {@code k} label sequences for {@code s}, each paired with its chart score.
     *
     * <p>Every (label, candidate) cell at the last position is ranked by score. Each
     * popped cell is then traced back through the chart's back-pointers to a full label
     * sequence. Results appear in the order the priority queue releases them. This is
     * presumably highest score first, assuming {@code edu.berkeley.nlp.util.PriorityQueue}
     * is a max-queue; confirm.
     *
     * @param s the sequence to decode
     * @param k maximum number of sequences to return; must be non-negative
     * @return at most {@code k} (label sequence, score) pairs; empty when {@code k == 0}
     *         or {@code s} has no positions
     * @throws IllegalArgumentException if {@code k < 0}
     */
    public List<Pair<List<L>, Double>> getTopKLabelSequencesAndScores(InstanceSequence<V, E, L> s, int k) {
        if (k < 0) {
            throw new IllegalArgumentException("k must be non-negative, got " + k);
        }
        int n = s.getSequenceLength();
        if (k == 0 || n == 0) {
            // Nothing to decode. This also avoids indexing the chart at position n - 1 == -1.
            return new ArrayList<Pair<List<L>, Double>>(0);
        }
        Pair<int[][][][], double[][][]> chart = this.inf.getKBestChartAndBacktrace(s, this.w, k);
        int[][][][] backtrace = chart.getFirst();
        PriorityQueue<Pair<Integer, Integer>> rankedScores = this.buildRankedScoreQueue(chart.getSecond()[n - 1]);
        List<Pair<List<L>, Double>> sentences = new ArrayList<Pair<List<L>, Double>>(k);
        while (sentences.size() < k && rankedScores.hasNext()) {
            // Read the priority before next() removes the element it belongs to.
            double score = rankedScores.getPriority();
            Pair<Integer, Integer> chain = rankedScores.next();
            sentences.add(Pair.makePair(this.rebuildChain(backtrace, chain.getFirst(), chain.getSecond()), score));
        }
        return sentences;
    }

    /**
     * Follows back-pointers from a final (label, candidate) cell to the first position.
     *
     * <p>{@code backtrace[i][label][candidate]} holds {@code {previousLabel, previousCandidate}}
     * for position {@code i}.
     *
     * @param backtrace    back-pointer table with one entry per sequence position
     * @param endLabel     label index at the last position
     * @param endCandidate candidate (k-best rank) index at the last position
     * @return the decoded labels in sequence order
     */
    private List<L> rebuildChain(int[][][][] backtrace, int endLabel, int endCandidate) {
        int n = backtrace.length;
        List<L> labels = new ArrayList<L>(n);
        int currentLabel = endLabel;
        int currentCandidate = endCandidate;
        for (int i = n - 1; i >= 0; --i) {
            labels.add(this.encoding.getLabel(currentLabel));
            int[] pointer = backtrace[i][currentLabel][currentCandidate];
            currentLabel = pointer[0];
            currentCandidate = pointer[1];
        }
        // The first position's back-pointer is expected to be the (-1, 0) start sentinel.
        assert currentLabel == -1 && currentCandidate == 0 : "unexpected start back-pointer";
        // Labels were collected last-to-first; restore sequence order in place.
        Collections.reverse(labels);
        return labels;
    }

    /**
     * Builds a priority queue over every (label, candidate) cell, prioritized by its score.
     *
     * @param scores {@code scores[label][candidate]} for a single position
     * @return a queue whose keys are {@code (label, candidate)} pairs
     */
    private PriorityQueue<Pair<Integer, Integer>> buildRankedScoreQueue(double[][] scores) {
        PriorityQueue<Pair<Integer, Integer>> pq = new PriorityQueue<Pair<Integer, Integer>>();
        for (int l = 0; l < scores.length; ++l) {
            for (int c = 0; c < scores[l].length; ++c) {
                pq.add(Pair.makePair(l, c), scores[l][c]);
            }
        }
        return pq;
    }

    /**
     * Trains {@link ChainCRFTagger}s from labeled sequences. It builds a feature/label
     * {@link Encoding} and then minimizes a {@link CRFObjectiveFunction} with L-BFGS.
     *
     * @param <V> vertex instance type
     * @param <E> edge instance type
     * @param <F> feature type produced by the extractors
     * @param <L> label type
     */
    public static class Factory<V, E, F, L> {
        private final FeatureExtractor<V, F> vertexExtractor;
        private final FeatureExtractor<E, F> edgeExtractor;
        private final double sigma;
        private final int iterations;

        /**
         * @param vertexExtractor extracts features from each position's vertex instance
         * @param edgeExtractor   extracts features from edge instances
         * @param sigma           passed to {@link CRFObjectiveFunction}; presumably the
         *                        regularization width (TODO confirm)
         * @param iterations      passed to {@link LBFGSMinimizer}; presumably the maximum
         *                        iteration count (TODO confirm)
         */
        public Factory(FeatureExtractor<V, F> vertexExtractor, FeatureExtractor<E, F> edgeExtractor, double sigma, int iterations) {
            this.vertexExtractor = vertexExtractor;
            this.edgeExtractor = edgeExtractor;
            this.sigma = sigma;
            this.iterations = iterations;
        }

        /**
         * Builds an encoding from {@code trainingData}, fits weights with L-BFGS starting
         * from all zeros, and wraps the result in a tagger.
         *
         * @param trainingData labeled sequences used both for the encoding and the objective
         * @return a tagger using the learned weights
         */
        public ChainCRFTagger<V, E, L> trainTagger(List<? extends LabeledInstanceSequence<V, E, L>> trainingData) {
            Encoding<F, L> encoding = this.buildEncoding(trainingData);
            CRFObjectiveFunction<V, E, F, L> objective = new CRFObjectiveFunction<V, E, F, L>(trainingData, encoding, this.vertexExtractor, this.edgeExtractor, this.sigma);
            LBFGSMinimizer minimizer = new LBFGSMinimizer(this.iterations);
            // Zero-initialized weight vector of length numFeatures * numLabels.
            double[] initialWeights = DoubleArrays.constantArray(0.0, encoding.getNumFeatures() * encoding.getNumLabels());
            double[] w;
            Logger.startTrack("Training with LBFGS", new Object[0]);
            try {
                // 1.0E-4 is presumably the convergence tolerance (TODO confirm against LBFGSMinimizer).
                w = minimizer.minimize(objective, initialWeights, 1.0E-4, true);
            } finally {
                // Close the log track even if optimization throws.
                Logger.endTrack();
            }
            Inference<V, E, F, L> inference = new Inference<V, E, F, L>(encoding, this.vertexExtractor, this.edgeExtractor);
            return new ChainCRFTagger<V, E, L>(encoding, inference, w);
        }

        /**
         * Indexes every gold label and every vertex/edge feature seen in {@code trainingData}.
         *
         * <p>Labels are indexed in a first pass because edge features are extracted once per
         * known label in the second pass.
         *
         * @param trainingData labeled sequences to scan
         * @return an encoding over the collected features and labels
         */
        private Encoding<F, L> buildEncoding(List<? extends LabeledInstanceSequence<V, E, L>> trainingData) {
            Indexer<L> labelIndexer = new Indexer<L>();
            for (LabeledInstanceSequence<V, E, L> sequence : trainingData) {
                for (int i = 0; i < sequence.getSequenceLength(); ++i) {
                    labelIndexer.add(sequence.getGoldLabel(i));
                }
            }
            Indexer<F> featureIndexer = new Indexer<F>();
            for (LabeledInstanceSequence<V, E, L> sequence : trainingData) {
                for (int i = 0; i < sequence.getSequenceLength(); ++i) {
                    featureIndexer.addAll(this.vertexExtractor.extractFeatures(sequence.getVertexInstance(i)).keySet());
                    if (i > 0) {
                        // Edge instances take a label argument. Enumerate every known label,
                        // presumably so all edge features inference can query are indexed.
                        for (int l = 0; l < labelIndexer.size(); ++l) {
                            featureIndexer.addAll(this.edgeExtractor.extractFeatures(sequence.getEdgeInstance(i, labelIndexer.getObject(l))).keySet());
                        }
                    }
                }
            }
            return new Encoding<F, L>(featureIndexer, labelIndexer);
        }
    }
}

