/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.crf;

import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.crf.ChainCrfFeatureExtractor;
import com.aliasi.crf.ChainCrfFeatures;
import com.aliasi.crf.ForwardBackwardTagLattice;
import com.aliasi.features.Features;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Matrices;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.tag.MarginalTagger;
import com.aliasi.tag.NBestTagger;
import com.aliasi.tag.ScoredTagging;
import com.aliasi.tag.TagLattice;
import com.aliasi.tag.Tagger;
import com.aliasi.tag.Tagging;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.BoundedPriorityQueue;
import com.aliasi.util.Iterators;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Formatter;
import java.util.IllegalFormatException;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

public class ChainCrf<E>
implements Tagger<E>,
NBestTagger<E>,
MarginalTagger<E>,
Serializable {
    static final long serialVersionUID = -4868542587460878290L;
    // Output tags; tag index k selects coefficient vector mCoefficients[k].
    private final List<String> mTagList;
    // mLegalTagStarts[k] is true if tag k may label the first token.
    private final boolean[] mLegalTagStarts;
    // mLegalTagEnds[k] is true if tag k may label the last token.
    // NOTE(review): decoding only enforces this for inputs of two or more tokens.
    private final boolean[] mLegalTagEnds;
    // mLegalTagTransitions[kFrom][kTo] is true if tag kTo may directly follow tag kFrom.
    private final boolean[][] mLegalTagTransitions;
    // One coefficient vector per tag; all have mNumDimensions dimensions (checked in the constructor).
    private final Vector[] mCoefficients;
    // Maps feature names to vector dimensions; passed to Features.toVector.
    private final SymbolTable mFeatureSymbolTable;
    // Produces node and edge features for a token sequence.
    private final ChainCrfFeatureExtractor<E> mFeatureExtractor;
    // Passed to Features.toVector; presumably adds the intercept dimension to every feature vector.
    private final boolean mAddInterceptFeature;
    // Dimensionality of the coefficient vectors (coefficients[0].numDimensions()).
    private final int mNumDimensions;
    // Reserved feature name that estimation adds first to the feature symbol table when an intercept is requested.
    static final String INTERCEPT_FEATURE_NAME = "*&^INTERCEPT%$^&**";
    // Shared empty arrays used to build the lattice for empty input.
    static final double[][] EMPTY_DOUBLE_2D_ARRAY = new double[0][];
    static final double[][][] EMPTY_DOUBLE_3D_ARRAY = new double[0][][];

    /**
     * Constructs a chain CRF without structural constraints: every tag is a
     * legal start and a legal end, and every tag-to-tag transition is legal.
     *
     * @param tags output tags, indexed in parallel with {@code coefficients}
     * @param coefficients one coefficient vector per tag
     * @param featureSymbolTable mapping from feature names to dimensions
     * @param featureExtractor extractor for node and edge features
     * @param addInterceptFeature whether feature vectors get an intercept feature
     */
    public ChainCrf(String[] tags, Vector[] coefficients, SymbolTable featureSymbolTable, ChainCrfFeatureExtractor<E> featureExtractor, boolean addInterceptFeature) {
        this(tags,
             trueArray(tags.length),
             trueArray(tags.length),
             trueArray(tags.length, tags.length),
             coefficients,
             featureSymbolTable,
             featureExtractor,
             addInterceptFeature);
    }

    /**
     * Constructs a chain CRF with explicit start, end and transition constraints.
     *
     * <p>The arrays are stored directly (not copied).
     *
     * @param tags output tags; must contain at least one tag
     * @param legalTagStarts {@code legalTagStarts[k]} is true if tag k may start a tagging
     * @param legalTagEnds {@code legalTagEnds[k]} is true if tag k may end a tagging
     * @param legalTagTransitions {@code legalTagTransitions[kFrom][kTo]} is true if
     *     tag kTo may follow tag kFrom
     * @param coefficients one coefficient vector per tag, all of the same dimensionality
     * @param featureSymbolTable mapping from feature names to dimensions
     * @param featureExtractor extractor for node and edge features
     * @param addInterceptFeature whether feature vectors get an intercept feature
     * @throws IllegalArgumentException if there are no tags, if the lengths of
     *     {@code tags}, {@code coefficients}, {@code legalTagStarts},
     *     {@code legalTagEnds}, {@code legalTagTransitions} or any transition row
     *     disagree, or if the coefficient vectors differ in dimensionality
     */
    public ChainCrf(String[] tags, boolean[] legalTagStarts, boolean[] legalTagEnds, boolean[][] legalTagTransitions, Vector[] coefficients, SymbolTable featureSymbolTable, ChainCrfFeatureExtractor<E> featureExtractor, boolean addInterceptFeature) {
        if (tags.length < 1) {
            // Previously this message was built but never thrown, so an empty tag set
            // failed later with an ArrayIndexOutOfBoundsException on coefficients[0].
            String msg = "Require at least one tag.";
            throw new IllegalArgumentException(msg);
        }
        if (tags.length != coefficients.length) {
            String msg = "Require tags and coefficients to be same length. Found tags.length=" + tags.length + " coefficients.length=" + coefficients.length;
            throw new IllegalArgumentException(msg);
        }
        if (tags.length != legalTagStarts.length) {
            String msg = "Tags and starts must be same length. Found tags.length=" + tags.length + " legalTagStarts.length=" + legalTagStarts.length;
            throw new IllegalArgumentException(msg);
        }
        if (tags.length != legalTagEnds.length) {
            // Report the array that actually failed the check (was copy-pasted from the starts check).
            String msg = "Tags and ends must be same length. Found tags.length=" + tags.length + " legalTagEnds.length=" + legalTagEnds.length;
            throw new IllegalArgumentException(msg);
        }
        if (tags.length != legalTagTransitions.length) {
            String msg = "Tags and transitions must be same length. Found tags.length=" + tags.length + " legalTagTransitions.length=" + legalTagTransitions.length;
            throw new IllegalArgumentException(msg);
        }
        for (int i = 0; i < legalTagTransitions.length; ++i) {
            if (tags.length != legalTagTransitions[i].length) {
                String msg = "Tags and transition rows must be same length. Found tags.length=" + tags.length + " legalTagTransitions[" + i + "].length=" + legalTagTransitions[i].length;
                throw new IllegalArgumentException(msg);
            }
        }
        int numDimensions = coefficients[0].numDimensions();
        for (int k = 1; k < coefficients.length; ++k) {
            if (numDimensions != coefficients[k].numDimensions()) {
                String msg = "All coefficients must be same length. Found coefficents[0].numDimensions()=" + numDimensions + " coefficients[" + k + "].numDimensions()=" + coefficients[k].numDimensions();
                throw new IllegalArgumentException(msg);
            }
        }
        this.mTagList = Arrays.asList(tags);
        this.mLegalTagStarts = legalTagStarts;
        this.mLegalTagEnds = legalTagEnds;
        this.mLegalTagTransitions = legalTagTransitions;
        this.mCoefficients = coefficients;
        this.mNumDimensions = numDimensions;
        this.mFeatureSymbolTable = featureSymbolTable;
        this.mFeatureExtractor = featureExtractor;
        this.mAddInterceptFeature = addInterceptFeature;
    }

    /**
     * Returns the tags of this CRF in index order.
     *
     * @return unmodifiable view of the tag list
     */
    public List<String> tags() {
        List<String> readOnlyTags = Collections.unmodifiableList(mTagList);
        return readOnlyTags;
    }

    /**
     * Returns the tag with the given index.
     *
     * @param k tag index
     * @return tag at index {@code k}
     */
    public String tag(int k) {
        String result = mTagList.get(k);
        return result;
    }

    /**
     * Returns the coefficient vectors, one per tag in tag-index order.
     *
     * <p>A fresh array is returned each call; each element is an unmodifiable
     * view produced by {@code Matrices.unmodifiableVector}.
     *
     * @return per-tag coefficient vectors
     */
    public Vector[] coefficients() {
        int numTags = mCoefficients.length;
        Vector[] views = new Vector[numTags];
        for (int tagId = 0; tagId < numTags; ++tagId) {
            views[tagId] = Matrices.unmodifiableVector(mCoefficients[tagId]);
        }
        return views;
    }

    /**
     * Returns the symbol table mapping feature names to dimensions.
     *
     * @return unmodifiable view of the feature symbol table
     */
    public SymbolTable featureSymbolTable() {
        SymbolTable view = MapSymbolTable.unmodifiableView(mFeatureSymbolTable);
        return view;
    }

    /**
     * Returns the feature extractor used to compute node and edge features.
     *
     * @return the feature extractor
     */
    public ChainCrfFeatureExtractor<E> featureExtractor() {
        ChainCrfFeatureExtractor<E> extractor = mFeatureExtractor;
        return extractor;
    }

    /**
     * Returns whether an intercept feature is added to feature vectors.
     *
     * @return the intercept flag supplied at construction
     */
    public boolean addInterceptFeature() {
        boolean intercept = mAddInterceptFeature;
        return intercept;
    }

    /**
     * Returns the highest-scoring tagging of the tokens, found with the Viterbi
     * algorithm.
     *
     * <p>The score of a tagging sums, for each position, the node score of its tag.
     * For every position after the first, it also adds the edge score from the
     * previous tag. Illegal starts, transitions and (for inputs of two or more
     * tokens) illegal ends are excluded.
     *
     * <p>NOTE(review): for single-token inputs only the legal-start constraint is
     * applied; legal ends are not checked. This mirrors {@code forwardBackward}.
     *
     * @param tokens input tokens
     * @return best tagging; an empty tagging for empty input
     * @throws IllegalStateException if no tag sequence has a finite score (the
     *     legality constraints admit no path, or scores are NaN)
     */
    @Override
    public Tagging<E> tag(List<E> tokens) {
        int numTokens = tokens.size();
        if (numTokens == 0) {
            return new Tagging<E>(tokens, Collections.<String>emptyList());
        }
        int numTags = mTagList.size();
        double[][] bestScores = new double[numTokens][numTags];
        // backPointers[n - 1][kTo] is the best previous tag for tag kTo at position n, or -1 if none.
        int[][] backPointers = new int[numTokens - 1][numTags];
        ChainCrfFeatures<E> features = mFeatureExtractor.extract(tokens, mTagList);

        Vector nodeVector0 = nodeFeatures(0, features);
        for (int k = 0; k < numTags; ++k) {
            bestScores[0][k] = mLegalTagStarts[k]
                ? nodeVector0.dotProduct(mCoefficients[k])
                : Double.NEGATIVE_INFINITY;
        }

        Vector[] edgeVectors = new Vector[numTags];
        for (int n = 1; n < numTokens; ++n) {
            Vector nodeVector = nodeFeatures(n, features);
            for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                edgeVectors[kFrom] = edgeFeatures(n, kFrom, features);
            }
            boolean isLast = n == numTokens - 1;
            for (int kTo = 0; kTo < numTags; ++kTo) {
                if (isLast && !mLegalTagEnds[kTo]) {
                    bestScores[n][kTo] = Double.NEGATIVE_INFINITY;
                    backPointers[n - 1][kTo] = -1;
                    continue;
                }
                double nodeScore = nodeVector.dotProduct(mCoefficients[kTo]);
                double bestScore = Double.NEGATIVE_INFINITY;
                int backPtr = -1;
                for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                    if (!mLegalTagTransitions[kFrom][kTo]) {
                        continue;
                    }
                    double score = nodeScore + edgeVectors[kFrom].dotProduct(mCoefficients[kTo]) + bestScores[n - 1][kFrom];
                    if (score > bestScore) {
                        bestScore = score;
                        backPtr = kFrom;
                    }
                }
                bestScores[n][kTo] = bestScore;
                backPointers[n - 1][kTo] = backPtr;
            }
        }

        double[] lastScores = bestScores[numTokens - 1];
        double bestFinalScore = Double.NEGATIVE_INFINITY;
        int bestFinalTag = -1;
        for (int k = 0; k < numTags; ++k) {
            if (lastScores[k] > bestFinalScore) {
                bestFinalScore = lastScores[k];
                bestFinalTag = k;
            }
        }
        if (bestFinalTag < 0) {
            // Previously this fell through to mTagList.get(-1).
            String msg = "No tagging with a finite score exists for input of length=" + numTokens;
            throw new IllegalStateException(msg);
        }

        // Follow back pointers from the last position to the first, then reverse.
        ArrayList<String> tags = new ArrayList<String>(numTokens);
        int currentTag = bestFinalTag;
        tags.add(mTagList.get(currentTag));
        for (int n = numTokens - 2; n >= 0; --n) {
            currentTag = backPointers[n][currentTag];
            tags.add(mTagList.get(currentTag));
        }
        Collections.reverse(tags);
        return new Tagging<E>(tokens, tags);
    }

    /**
     * Returns an iterator over the best taggings of the tokens, at most
     * {@code maxResults} of them.
     *
     * <p>Empty input yields a single empty tagging with score 0.0. Otherwise an
     * {@code NBestIterator} is created with its boolean flag set to {@code false}.
     * Presumably this selects joint rather than conditional scores; confirm against
     * NBestIterator.
     *
     * @param tokens input tokens
     * @param maxResults maximum number of taggings to return
     * @return iterator over scored taggings
     */
    @Override
    public Iterator<ScoredTagging<E>> tagNBest(List<E> tokens, int maxResults) {
        if (!tokens.isEmpty()) {
            return new NBestIterator(tokens, false, maxResults);
        }
        List<String> noTags = Collections.emptyList();
        ScoredTagging<E> emptyTagging = new ScoredTagging<E>(tokens, noTags, 0.0);
        return Iterators.singleton(emptyTagging);
    }

    /**
     * Returns an iterator over the best taggings of the tokens, at most
     * {@code maxResults} of them.
     *
     * <p>Empty input yields a single empty tagging with score 0.0. Otherwise an
     * {@code NBestIterator} is created with its boolean flag set to {@code true}.
     * Presumably this selects conditional scores; confirm against NBestIterator.
     *
     * @param tokens input tokens
     * @param maxResults maximum number of taggings to return
     * @return iterator over scored taggings
     */
    @Override
    public Iterator<ScoredTagging<E>> tagNBestConditional(List<E> tokens, int maxResults) {
        if (!tokens.isEmpty()) {
            return new NBestIterator(tokens, true, maxResults);
        }
        List<String> noTags = Collections.emptyList();
        ScoredTagging<E> emptyTagging = new ScoredTagging<E>(tokens, noTags, 0.0);
        return Iterators.singleton(emptyTagging);
    }

    /**
     * Returns the forward-backward tag lattice for the tokens.
     *
     * <p>Empty input yields a lattice built from empty arrays and a log
     * normalizer of 0.0.
     *
     * @param tokens input tokens
     * @return tag lattice over the input
     */
    @Override
    public TagLattice<E> tagMarginal(List<E> tokens) {
        if (tokens.isEmpty()) {
            return new ForwardBackwardTagLattice<E>(tokens, mTagList, EMPTY_DOUBLE_2D_ARRAY, EMPTY_DOUBLE_2D_ARRAY, EMPTY_DOUBLE_3D_ARRAY, 0.0);
        }
        return forwardBackward(tokens, features(tokens));
    }

    /**
     * Returns a multi-line description of this CRF.
     *
     * <p>The description contains the feature extractor, the intercept flag and the
     * tags. It then gives one line per tag listing its non-zero coefficients as
     * {@code feature=value} pairs.
     *
     * @return string representation
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Feature Extractor=").append(featureExtractor()).append("\n");
        sb.append("Add intercept=").append(addInterceptFeature()).append("\n");
        List<String> tagList = tags();
        sb.append("Tags=").append(tagList).append("\n");
        Vector[] coeffs = coefficients();
        SymbolTable symTab = featureSymbolTable();
        sb.append("Coefficients=\n");
        for (int tagId = 0; tagId < coeffs.length; ++tagId) {
            Vector coeff = coeffs[tagId];
            sb.append(tagList.get(tagId)).append("  ");
            String separator = "";
            for (int dim : coeff.nonZeroDimensions()) {
                sb.append(separator);
                sb.append(symTab.idToSymbol(dim));
                sb.append("=");
                sb.append(coeff.value(dim));
                separator = ", ";
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Java serialization hook: this CRF is serialized through a
     * {@code Serializer} proxy rather than directly.
     *
     * @return the serialization proxy for this CRF
     */
    Object writeReplace() {
        Object proxy = new Serializer(this);
        return proxy;
    }

    /**
     * Converts the node features at a position into a vector in this CRF's
     * feature space.
     *
     * @param position token position
     * @param features extracted features for the whole input
     * @return node feature vector
     */
    private Vector nodeFeatures(int position, ChainCrfFeatures<E> features) {
        return Features.toVector(features.nodeFeatures(position),
                                 mFeatureSymbolTable,
                                 mNumDimensions,
                                 mAddInterceptFeature);
    }

    /**
     * Converts the edge features into a vector in this CRF's feature space. The
     * edge is the one entering a position, given the previous tag.
     *
     * @param position token position (at least 1)
     * @param lastTagIndex index of the previous position's tag
     * @param features extracted features for the whole input
     * @return edge feature vector
     */
    private Vector edgeFeatures(int position, int lastTagIndex, ChainCrfFeatures<E> features) {
        return Features.toVector(features.edgeFeatures(position, lastTagIndex),
                                 mFeatureSymbolTable,
                                 mNumDimensions,
                                 mAddInterceptFeature);
    }

    /**
     * Precomputes the node and edge feature vectors for an input.
     *
     * <p>Node vectors are indexed by position. Edge vectors are indexed by
     * {@code [position - 1][previousTag]} for positions 1 through n-1.
     *
     * @param tokens input tokens
     * @return the feature vectors, or {@code null} if {@code tokens} is empty
     */
    private FeatureVectors features(List<E> tokens) {
        int numTokens = tokens.size();
        if (numTokens == 0) {
            return null;
        }
        int numTags = mTagList.size();
        ChainCrfFeatures<E> features = mFeatureExtractor.extract(tokens, mTagList);
        Vector[] nodeFeatureVectors = new Vector[numTokens];
        for (int n = 0; n < numTokens; ++n) {
            nodeFeatureVectors[n] = nodeFeatures(n, features);
        }
        Vector[][] edgeFeatureVectorss = new Vector[numTokens - 1][numTags];
        for (int n = 1; n < numTokens; ++n) {
            for (int k = 0; k < numTags; ++k) {
                edgeFeatureVectorss[n - 1][k] = edgeFeatures(n, k, features);
            }
        }
        return new FeatureVectors(nodeFeatureVectors, edgeFeatureVectorss);
    }

    /**
     * Runs forward-backward over a non-empty input and returns the resulting
     * lattice.
     *
     * <p>{@code logPotentials[n - 1][kFrom][kTo]} is the edge score plus the node
     * score for tag kTo at position n. It is -infinity when the transition is
     * illegal, or when kTo is an illegal end at the last position. The first
     * position's forward scores are node scores for legal starts and -infinity
     * otherwise. The backward scores at the last position are 0.0. The log
     * normalizer is the log-sum-exp of the last forward row.
     *
     * @param tokens input tokens; callers pass non-empty inputs only
     * @param features precomputed feature vectors for {@code tokens}
     * @return forward-backward tag lattice
     */
    TagLattice<E> forwardBackward(List<E> tokens, FeatureVectors features) {
        int numTokens = tokens.size();
        int numTags = mTagList.size();

        double[] startLogPotentials = new double[numTags];
        for (int k = 0; k < numTags; ++k) {
            startLogPotentials[k] = mLegalTagStarts[k]
                ? features.mNodeFeatureVectors[0].dotProduct(mCoefficients[k])
                : Double.NEGATIVE_INFINITY;
        }

        double[][][] logPotentials = new double[numTokens - 1][numTags][numTags];
        for (double[][] slice : logPotentials) {
            for (double[] row : slice) {
                Arrays.fill(row, Double.NEGATIVE_INFINITY);
            }
        }
        for (int nTo = 1; nTo < numTokens; ++nTo) {
            boolean isLast = nTo == numTokens - 1;
            for (int kTo = 0; kTo < numTags; ++kTo) {
                if (isLast && !mLegalTagEnds[kTo]) {
                    continue; // illegal final tag: potentials stay at -infinity
                }
                double nodePotential = features.mNodeFeatureVectors[nTo].dotProduct(mCoefficients[kTo]);
                for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                    if (mLegalTagTransitions[kFrom][kTo]) {
                        logPotentials[nTo - 1][kFrom][kTo] = features.mEdgeFeatureVectorss[nTo - 1][kFrom].dotProduct(mCoefficients[kTo]) + nodePotential;
                    }
                }
            }
        }

        // Scratch buffer reused for every log-sum-exp below.
        double[] scratch = new double[numTags];

        double[][] logForwards = new double[numTokens][numTags];
        System.arraycopy(startLogPotentials, 0, logForwards[0], 0, numTags);
        for (int nTo = 1; nTo < numTokens; ++nTo) {
            double[] previous = logForwards[nTo - 1];
            double[][] potentials = logPotentials[nTo - 1];
            for (int kTo = 0; kTo < numTags; ++kTo) {
                for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                    scratch[kFrom] = previous[kFrom] + potentials[kFrom][kTo];
                }
                logForwards[nTo][kTo] = com.aliasi.util.Math.logSumOfExponentials(scratch);
            }
        }

        // The last row is left at its default 0.0.
        double[][] logBackwards = new double[numTokens][numTags];
        for (int nFrom = numTokens - 2; nFrom >= 0; --nFrom) {
            double[] next = logBackwards[nFrom + 1];
            double[][] potentials = logPotentials[nFrom];
            for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                for (int kTo = 0; kTo < numTags; ++kTo) {
                    scratch[kTo] = next[kTo] + potentials[kFrom][kTo];
                }
                logBackwards[nFrom][kFrom] = com.aliasi.util.Math.logSumOfExponentials(scratch);
            }
        }

        double logZ = com.aliasi.util.Math.logSumOfExponentials(logForwards[numTokens - 1]);
        return new ForwardBackwardTagLattice<E>(tokens, mTagList, logForwards, logBackwards, logPotentials, logZ, false);
    }

    /**
     * Marks every tag that begins at least one of the given tag-id sequences.
     * Empty sequences are ignored.
     *
     * @param tagIdss tag-id sequences, one per training instance
     * @param numTags total number of tags
     * @return {@code result[k]} is true if some sequence starts with tag k
     */
    static boolean[] legalStarts(int[][] tagIdss, int numTags) {
        boolean[] starts = new boolean[numTags];
        for (int[] tagIds : tagIdss) {
            if (tagIds.length == 0) {
                continue;
            }
            starts[tagIds[0]] = true;
        }
        return starts;
    }

    /**
     * Marks every tag that ends at least one of the given tag-id sequences.
     * Empty sequences are ignored.
     *
     * @param tagIdss tag-id sequences, one per training instance
     * @param numTags total number of tags
     * @return {@code result[k]} is true if some sequence ends with tag k
     */
    static boolean[] legalEnds(int[][] tagIdss, int numTags) {
        boolean[] ends = new boolean[numTags];
        for (int[] tagIds : tagIdss) {
            int last = tagIds.length - 1;
            if (last < 0) {
                continue;
            }
            ends[tagIds[last]] = true;
        }
        return ends;
    }

    /**
     * Marks every adjacent tag pair that occurs in the given tag-id sequences.
     *
     * @param tagIdss tag-id sequences, one per training instance
     * @param numTags total number of tags
     * @return {@code result[kFrom][kTo]} is true if tag kTo directly follows tag
     *     kFrom in some sequence
     */
    static boolean[][] legalTransitions(int[][] tagIdss, int numTags) {
        boolean[][] transitions = new boolean[numTags][numTags];
        for (int[] tagIds : tagIdss) {
            for (int pos = 1; pos < tagIds.length; ++pos) {
                int fromTag = tagIds[pos - 1];
                int toTag = tagIds[pos];
                transitions[fromTag][toTag] = true;
            }
        }
        return transitions;
    }

    /**
     * Returns a new boolean array of the given length with every element true.
     *
     * @param m array length
     * @return array of {@code m} true values
     */
    static boolean[] trueArray(int m) {
        boolean[] allTrue = new boolean[m];
        Arrays.fill(allTrue, 0, allTrue.length, true);
        return allTrue;
    }

    /**
     * Returns a new {@code m} by {@code n} boolean matrix with every element true.
     *
     * @param m number of rows
     * @param n number of columns
     * @return matrix of true values
     */
    static boolean[][] trueArray(int m, int n) {
        boolean[][] allTrue = new boolean[m][n];
        for (boolean[] row : allTrue) {
            Arrays.fill(row, true);
        }
        return allTrue;
    }

    /**
     * Estimates a chain CRF from the training taggings of a corpus by stochastic
     * gradient updates, regularized by a regression prior.
     *
     * <p>Each epoch visits every non-empty training instance once. The observed
     * node and edge feature vectors of the gold tags are added to those tags'
     * weights. Their expectations under the current model, computed by
     * forward-backward, are then subtracted; expectations with a log probability
     * below -400 are skipped. The prior is applied after every
     * {@code priorBlockSize} non-empty instances and once more at the end of each
     * epoch.
     *
     * <p>Training stops after {@code maxEpochs} epochs, or earlier when two
     * conditions hold: at least {@code minEpochs} epochs have completed, and the
     * rolling average relative change of the objective is below
     * {@code minImprovement}.
     *
     * @param corpus corpus whose training taggings are visited
     * @param featureExtractor extractor for node and edge features
     * @param addInterceptFeature whether to add an intercept feature
     * @param minFeatureCount pruning threshold passed to {@code ObjectToCounterMap.prune}
     *     when building the feature symbol table
     * @param cacheFeatureVectors if true, feature vectors are computed once and
     *     reused every epoch; otherwise they are recomputed each time they are used
     * @param allowUnseenTransitions if true, every start, end and transition is
     *     legal; otherwise only those observed in the training data
     * @param prior regression prior on the coefficients
     * @param priorBlockSize number of non-empty instances between prior updates
     * @param annealingSchedule supplies the learning rate for each epoch
     * @param minImprovement convergence threshold on the rolling average relative
     *     difference of the objective
     * @param minEpochs minimum number of epochs before convergence may stop training
     * @param maxEpochs maximum number of epochs
     * @param reporter destination for progress output; {@code null} means silent
     * @return the estimated CRF
     * @throws IOException if visiting the corpus throws
     */
    public static <F> ChainCrf<F> estimate(Corpus<ObjectHandler<Tagging<F>>> corpus, ChainCrfFeatureExtractor<F> featureExtractor, boolean addInterceptFeature, int minFeatureCount, boolean cacheFeatureVectors, boolean allowUnseenTransitions, RegressionPrior prior, int priorBlockSize, AnnealingSchedule annealingSchedule, double minImprovement, int minEpochs, int maxEpochs, Reporter reporter) throws IOException {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        reporter.info("ChainCrf.estimate Parameters");
        reporter.info("featureExtractor=" + featureExtractor);
        reporter.info("addInterceptFeature=" + addInterceptFeature);
        reporter.info("minFeatureCount=" + minFeatureCount);
        reporter.info("cacheFeatureVectors=" + cacheFeatureVectors);
        reporter.info("allowUnseenTransitions=" + allowUnseenTransitions);
        reporter.info("prior=" + prior);
        reporter.info("annealingSchedule=" + annealingSchedule);
        reporter.info("minImprovement=" + minImprovement);
        reporter.info("minEpochs=" + minEpochs);
        reporter.info("maxEpochs=" + maxEpochs);
        reporter.info("priorBlockSize=" + priorBlockSize);

        reporter.info("Computing corpus tokens and features");
        List<List<F>> tokenss = ChainCrf.corpusTokens(corpus);
        String[][] tagss = ChainCrf.corpusTags(corpus);
        int numTrainingInstances = tagss.length;
        long numTrainingTokens = 0L;
        for (String[] tags : tagss) {
            numTrainingTokens += tags.length;
        }
        int[][] tagIdss = new int[numTrainingInstances][];
        MapSymbolTable tagSymbolTable = ChainCrf.tagSymbolTable(tagss, tagIdss);
        MapSymbolTable featureSymbolTable = ChainCrf.featureSymbolTable(tagss, tokenss, addInterceptFeature, featureExtractor, minFeatureCount);
        int numTags = tagSymbolTable.numSymbols();
        String[] allTags = new String[numTags];
        for (int k = 0; k < numTags; ++k) {
            allTags[k] = tagSymbolTable.idToSymbol(k);
        }
        boolean[] legalTagStarts = allowUnseenTransitions ? ChainCrf.trueArray(numTags) : ChainCrf.legalStarts(tagIdss, numTags);
        boolean[] legalTagEnds = allowUnseenTransitions ? ChainCrf.trueArray(numTags) : ChainCrf.legalEnds(tagIdss, numTags);
        boolean[][] legalTagTransitions = allowUnseenTransitions ? ChainCrf.trueArray(numTags, numTags) : ChainCrf.legalTransitions(tagIdss, numTags);
        int numDimensions = featureSymbolTable.numSymbols();
        // Declared as DenseVector[] so updates need no casts. The CRF below shares
        // this array, so in-place updates are immediately visible to it.
        DenseVector[] weightVectors = new DenseVector[numTags];
        for (int k = 0; k < numTags; ++k) {
            weightVectors[k] = new DenseVector(numDimensions);
        }
        reporter.info("Corpus Statistics");
        reporter.info("Num Training Instances=" + numTrainingInstances);
        reporter.info("Num Training Tokens=" + numTrainingTokens);
        reporter.info("Num Dimensions After Pruning=" + numDimensions);
        reporter.info("Tags=" + tagSymbolTable);
        ChainCrf<F> crf = new ChainCrf<F>(allTags, legalTagStarts, legalTagEnds, legalTagTransitions, weightVectors, featureSymbolTable, featureExtractor, addInterceptFeature);

        FeatureVectors[] featureVectorsCache = null;
        if (cacheFeatureVectors) {
            reporter.info("Caching Feature Vectors");
            featureVectorsCache = new FeatureVectors[numTrainingInstances];
            for (int j = 0; j < numTrainingInstances; ++j) {
                // Was `super.features(...)`, which does not compile in a static method.
                featureVectorsCache[j] = crf.features(tokenss.get(j));
            }
        }

        // Large finite sentinel so the first relative difference is well defined.
        double lastLog2LikelihoodAndPrior = -(Double.MAX_VALUE / 2.0);
        double rollingAverageRelativeDiff = 1.0;
        double bestLog2LikelihoodAndPrior = Double.NEGATIVE_INFINITY;
        long cumFeatureExtractionMs = 0L;
        long cumForwardBackwardMs = 0L;
        long cumUpdateMs = 0L;
        long cumLossMs = 0L;
        long cumPriorUpdateMs = 0L;
        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            int instancesSinceLastPriorUpdate = 0;
            double learningRate = annealingSchedule.learningRate(epoch);
            double learningRatePerTrainingInstance = learningRate / (double) numTrainingInstances;
            for (int j = 0; j < numTrainingInstances; ++j) {
                int[] tagIds = tagIdss[j];
                List<F> tokens = tokenss.get(j);
                int numTokens = tokens.size();
                if (numTokens < 1) {
                    continue;
                }
                long startMs = System.currentTimeMillis();
                FeatureVectors features = cacheFeatureVectors ? featureVectorsCache[j] : crf.features(tokens);
                long featsMs = System.currentTimeMillis();
                cumFeatureExtractionMs += featsMs - startMs;
                TagLattice<F> lattice = crf.forwardBackward(tokens, features);
                long fwdBkMs = System.currentTimeMillis();
                cumForwardBackwardMs += fwdBkMs - featsMs;

                // Observed node and edge features for the gold tagging.
                for (int nTo = 0; nTo < numTokens; ++nTo) {
                    weightVectors[tagIds[nTo]].increment(learningRate, features.mNodeFeatureVectors[nTo]);
                }
                for (int nTo = 1; nTo < numTokens; ++nTo) {
                    weightVectors[tagIds[nTo]].increment(learningRate, features.mEdgeFeatureVectorss[nTo - 1][tagIds[nTo - 1]]);
                }
                // Expected node features under the current model; negligible marginals are skipped.
                for (int nTo = 0; nTo < numTokens; ++nTo) {
                    for (int kTo = 0; kTo < numTags; ++kTo) {
                        double logP = lattice.logProbability(nTo, kTo);
                        if (!(logP < -400.0)) {
                            double p = Math.exp(logP);
                            weightVectors[kTo].increment(-p * learningRate, features.mNodeFeatureVectors[nTo]);
                        }
                    }
                }
                // Expected edge features under the current model.
                for (int nTo = 1; nTo < numTokens; ++nTo) {
                    for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                        for (int kTo = 0; kTo < numTags; ++kTo) {
                            double logP = lattice.logProbability(nTo, kFrom, kTo);
                            if (!(logP < -400.0)) {
                                double p = Math.exp(logP);
                                weightVectors[kTo].increment(-p * learningRate, features.mEdgeFeatureVectorss[nTo - 1][kFrom]);
                            }
                        }
                    }
                }
                long updateMs = System.currentTimeMillis();
                cumUpdateMs += updateMs - fwdBkMs;
                if (++instancesSinceLastPriorUpdate == priorBlockSize) {
                    ChainCrf.adjustWeightsWithPrior(weightVectors, prior, (double) instancesSinceLastPriorUpdate * learningRatePerTrainingInstance);
                    instancesSinceLastPriorUpdate = 0;
                }
                long priorMs = System.currentTimeMillis();
                cumPriorUpdateMs += priorMs - updateMs;
            }
            // Apply the prior for instances left over since the last block update.
            long finalPriorStartMs = System.currentTimeMillis();
            ChainCrf.adjustWeightsWithPrior(weightVectors, prior, (double) instancesSinceLastPriorUpdate * learningRatePerTrainingInstance);
            long finalPriorEndMs = System.currentTimeMillis();
            cumPriorUpdateMs += finalPriorEndMs - finalPriorStartMs;

            long lossStartMs = System.currentTimeMillis();
            double log2Likelihood = 0.0;
            for (int j = 0; j < numTrainingInstances; ++j) {
                List<F> tokens = tokenss.get(j);
                if (tokens.size() < 1) {
                    continue;
                }
                FeatureVectors features = cacheFeatureVectors ? featureVectorsCache[j] : crf.features(tokens);
                TagLattice<F> lattice = crf.forwardBackward(tokens, features);
                // NOTE(review): lattice log probabilities are natural logs (they are
                // exponentiated with Math.exp above), while prior.log2Prior is base 2
                // by its name, so the objective mixes log bases. Confirm before
                // relying on the reported values.
                log2Likelihood += lattice.logProbability(0, tagIdss[j]);
            }
            double log2Prior = prior == null ? 0.0 : prior.log2Prior(weightVectors);
            double log2LikelihoodAndPrior = log2Likelihood + log2Prior;
            double relativeDiff = com.aliasi.util.Math.relativeAbsoluteDifference(lastLog2LikelihoodAndPrior, log2LikelihoodAndPrior);
            rollingAverageRelativeDiff = (9.0 * rollingAverageRelativeDiff + relativeDiff) / 10.0;
            lastLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            if (log2LikelihoodAndPrior > bestLog2LikelihoodAndPrior) {
                bestLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            }
            long lossMs = System.currentTimeMillis();
            cumLossMs += lossMs - lossStartMs;

            if (reporter.isDebugEnabled()) {
                Formatter formatter = new Formatter(Locale.ENGLISH);
                try {
                    formatter.format("epoch=%5d lr=%11.9f ll=%11.4f lp=%11.4f llp=%11.4f llp*=%11.4f", epoch, learningRate, log2Likelihood, log2Prior, log2LikelihoodAndPrior, bestLog2LikelihoodAndPrior);
                    reporter.debug(formatter.toString());
                } catch (IllegalFormatException e) {
                    reporter.warn("Illegal format in ChainCrf.estimate");
                } finally {
                    formatter.close();
                }
            }
            // minEpochs was previously accepted and logged but never enforced.
            if (epoch + 1 >= minEpochs && rollingAverageRelativeDiff < minImprovement) {
                reporter.info("Converged with rollingAverageRelativeDiff=" + rollingAverageRelativeDiff);
                break;
            }
        }
        reporter.info("Feat Extraction Time=" + Strings.msToString(cumFeatureExtractionMs));
        reporter.info("Forward Backward Time=" + Strings.msToString(cumForwardBackwardMs));
        reporter.info("Update Time=" + Strings.msToString(cumUpdateMs));
        reporter.info("Prior Update Time=" + Strings.msToString(cumPriorUpdateMs));
        reporter.info("Loss Time=" + Strings.msToString(cumLossMs));
        return crf;
    }

    /**
     * Moves every weight toward the prior's mode by one gradient step of the prior.
     *
     * <p>For each dimension whose weight differs from {@code prior.mode(dim)}, the
     * weight is decreased by {@code prior.gradient(weight, dim) * rate}. The result
     * is clipped at the mode, so a step never crosses it. A {@code null} or uniform
     * prior leaves the weights unchanged.
     *
     * @param weightVectors per-tag weight vectors, updated in place
     * @param prior regression prior; may be {@code null}
     * @param learningRateDividedByNumTrainingInstances step size applied to the prior gradient
     */
    static void adjustWeightsWithPrior(DenseVector[] weightVectors, RegressionPrior prior, double learningRateDividedByNumTrainingInstances) {
        // estimate() treats a null prior as "no prior" when computing the log prior;
        // do the same here instead of failing with a NullPointerException.
        if (prior == null || prior.isUniform()) {
            return;
        }
        for (DenseVector weights : weightVectors) {
            int numDimensions = weights.numDimensions();
            for (int dim = 0; dim < numDimensions; ++dim) {
                double weight = weights.value(dim);
                double priorMode = prior.mode(dim);
                if (weight == priorMode) {
                    continue;
                }
                double delta = prior.gradient(weight, dim) * learningRateDividedByNumTrainingInstances;
                double updated = weight - delta;
                // Clip at the mode so the prior only shrinks toward it, never past it.
                double newVal = weight > priorMode
                    ? Math.max(priorMode, updated)
                    : Math.min(priorMode, updated);
                weights.setValue(dim, newVal);
            }
        }
    }

    /**
     * Builds a tag symbol table from the training tag sequences. As a side effect,
     * it fills {@code tagIdss} with each sequence converted to tag ids.
     *
     * <p>Ids are assigned by {@code getOrAddSymbol} in order of first occurrence
     * while scanning the sequences.
     *
     * @param tagss tag sequences, one per training instance
     * @param tagIdss output array; {@code tagIdss[j]} receives the ids for {@code tagss[j]}
     * @return the tag symbol table
     */
    static MapSymbolTable tagSymbolTable(String[][] tagss, int[][] tagIdss) {
        MapSymbolTable symbols = new MapSymbolTable();
        for (int j = 0; j < tagss.length; ++j) {
            String[] tags = tagss[j];
            int[] ids = new int[tags.length];
            tagIdss[j] = ids;
            for (int pos = 0; pos < ids.length; ++pos) {
                ids[pos] = symbols.getOrAddSymbol(tags[pos]);
            }
        }
        return symbols;
    }

    /**
     * Builds the feature symbol table from the node and gold-edge features of the
     * training data.
     *
     * <p>Each feature is counted once per position at which it appears. Counts are
     * pruned with {@code minFeatureCount}, and the surviving features are added
     * to the table. When {@code addInterceptFeature} is set, the intercept feature
     * is added first.
     *
     * @param tagss gold tag sequences, one per training instance
     * @param tokenss token sequences, parallel to {@code tagss}
     * @param addInterceptFeature whether to reserve a symbol for the intercept
     * @param featureExtractor extractor for node and edge features
     * @param minFeatureCount pruning threshold passed to {@code ObjectToCounterMap.prune}
     * @return the feature symbol table
     */
    static <F> MapSymbolTable featureSymbolTable(String[][] tagss, List<List<F>> tokenss, boolean addInterceptFeature, ChainCrfFeatureExtractor<F> featureExtractor, int minFeatureCount) {
        ObjectToCounterMap<String> counts = new ObjectToCounterMap<String>();
        for (int j = 0; j < tagss.length; ++j) {
            String[] goldTags = tagss[j];
            // The extractor's tag list is this instance's own gold tags, so tag index
            // n - 1 below names the gold tag at the previous position.
            List<String> goldTagList = Arrays.asList(goldTags);
            List<F> tokens = tokenss.get(j);
            ChainCrfFeatures<F> features = featureExtractor.extract(tokens, goldTagList);
            for (int n = 0; n < goldTags.length; ++n) {
                for (String name : features.nodeFeatures(n).keySet()) {
                    counts.increment(name);
                }
            }
            for (int n = 1; n < goldTags.length; ++n) {
                for (String name : features.edgeFeatures(n, n - 1).keySet()) {
                    counts.increment(name);
                }
            }
        }
        counts.prune(minFeatureCount);
        MapSymbolTable symbols = new MapSymbolTable();
        if (addInterceptFeature) {
            symbols.getOrAddSymbol(INTERCEPT_FEATURE_NAME);
        }
        for (String name : counts.keySet()) {
            symbols.getOrAddSymbol(name);
        }
        return symbols;
    }

    /**
     * Collects the token list of every training tagging in {@code corpus},
     * in visitation order.
     *
     * @param corpus corpus whose training section is visited
     * @return one token list per training tagging
     * @throws IOException if visiting the corpus fails
     */
    static <F> List<List<F>> corpusTokens(Corpus<ObjectHandler<Tagging<F>>> corpus) throws IOException {
        final List<List<F>> tokenLists = new ArrayList<List<F>>();
        ObjectHandler<Tagging<F>> collector = new ObjectHandler<Tagging<F>>() {

            @Override
            public void handle(Tagging<F> tagging) {
                tokenLists.add(tagging.tokens());
            }
        };
        corpus.visitTrain(collector);
        return tokenLists;
    }

    /**
     * Collects the tag sequence of every training tagging in {@code corpus},
     * in visitation order, as a ragged 2-D array.
     *
     * @param corpus corpus whose training section is visited
     * @return one tag array per training tagging
     * @throws IOException if visiting the corpus fails
     */
    static <F> String[][] corpusTags(Corpus<ObjectHandler<Tagging<F>>> corpus) throws IOException {
        // Typed list: the decompiled raw list plus a cast to an undeclared
        // type variable T did not compile.
        final List<String[]> corpusTagList = new ArrayList<String[]>(1024);
        corpus.visitTrain(new ObjectHandler<Tagging<F>>() {

            @Override
            public void handle(Tagging<F> tagging) {
                corpusTagList.add(tagging.tags().toArray(Strings.EMPTY_STRING_ARRAY));
            }
        });
        return corpusTagList.toArray(Strings.EMPTY_STRING_2D_ARRAY);
    }

    /**
     * Returns a new array holding a copy of each vector in {@code xs}. Each
     * element is copied with the {@code DenseVector} copy constructor.
     *
     * @param xs vectors to copy
     * @return array of copies, same length and order as {@code xs}
     */
    static DenseVector[] copy(DenseVector[] xs) {
        DenseVector[] copies = new DenseVector[xs.length];
        for (int i = 0; i < copies.length; ++i) {
            copies[i] = new DenseVector(xs[i]);
        }
        return copies;
    }

    /**
     * Returns the length of the longest tag sequence in {@code tagss}, or 0
     * if there are none.
     *
     * @param tagss tag sequences; elements must be non-null
     * @return maximum row length
     */
    static int longestInput(String[][] tagss) {
        int max = 0;
        for (String[] tags : tagss) {
            max = Math.max(max, tags.length);
        }
        return max;
    }

    /**
     * Simple holder pairing an array of node feature vectors with a 2-D
     * array of edge feature vectors. The references are stored as given, not
     * copied.
     */
    static class FeatureVectors {
        final Vector[] mNodeFeatureVectors;
        final Vector[][] mEdgeFeatureVectorss;

        FeatureVectors(Vector[] nodeFeatureVectors, Vector[][] edgeFeatureVectorss) {
            mEdgeFeatureVectorss = edgeFeatureVectorss;
            mNodeFeatureVectors = nodeFeatureVectors;
        }
    }

    /**
     * Immutable link in a singly linked chain used by the n-best search.
     * {@code mK} is a tag index, {@code mPointer} is the next link (or
     * {@code null} at the end of the chain), and {@code mScore} is the score
     * carried by this link.
     */
    static class ForwardPointer {
        final int mK;
        final ForwardPointer mPointer;
        final double mScore;

        ForwardPointer(int k, ForwardPointer pointer, double score) {
            mScore = score;
            mPointer = pointer;
            mK = k;
        }
    }

    /**
     * Lazy iterator over the n-best taggings of a token sequence, returned
     * in descending order of score.
     *
     * <p>The constructor fills three tables:
     * <ul>
     * <li>{@code mTransitionScores[n-1][kMinus1][k]}: the score for tag
     * {@code kMinus1} at position {@code n-1} followed by tag {@code k} at
     * position {@code n}. It is the node score of {@code k} at {@code n}
     * plus the edge score. Disallowed moves (illegal transition, illegal
     * start at {@code n == 1}, illegal end at the last position) stay
     * {@code -Infinity}.</li>
     * <li>{@code mViterbiScores[n][k]}: best score of any prefix ending with
     * tag {@code k} at position {@code n}.</li>
     * <li>{@code mBackPointers[n-1][k]}: predecessor tag on that best
     * prefix, or {@code -1} if there is none.</li>
     * </ul>
     *
     * <p>Each queue entry is a {@link NBestState}: a (position, tag) whose
     * prefix is the Viterbi back-pointer path, plus an explicit suffix chain
     * of {@link ForwardPointer}s. When an entry is popped, the search walks
     * its back-pointer path and, at every position, pushes each
     * non-back-pointer predecessor as a new candidate.
     *
     * <p>If {@code normToConditional} is set, {@code mLogZ} is the log
     * partition function from the forward recursion in {@link #logZ()}. It
     * is subtracted from every returned score. Otherwise it is 0.
     *
     * <p>NOTE(review): {@code tokens} must be non-empty. For an empty list,
     * {@code numTokens - 1} is negative and array allocation throws
     * {@code NegativeArraySizeException}. Callers presumably guard this.
     */
    class NBestIterator
    extends Iterators.Buffered<ScoredTagging<E>> {
        final List<E> mTokens;
        final double mLogZ;
        final double[][][] mTransitionScores;
        final double[][] mViterbiScores;
        final int[][] mBackPointers;
        final BoundedPriorityQueue<NBestState> mPriorityQueue;

        /**
         * Runs the dynamic programs for {@code tokens} and seeds the queue
         * with one candidate per tag at the last position.
         *
         * @param tokens the input tokens
         * @param normToConditional whether to subtract the log partition
         *        function from returned scores
         * @param maxResults bound passed to the priority queue
         */
        NBestIterator(List<E> tokens, boolean normToConditional, int maxResults) {
            this.mPriorityQueue = new BoundedPriorityQueue<NBestState>(ScoredObject.comparator(), maxResults);
            this.mTokens = tokens;
            int numTokens = tokens.size();
            int numTags = ChainCrf.this.mTagList.size();

            this.mTransitionScores = new double[numTokens - 1][numTags][numTags];
            for (double[][] xss : this.mTransitionScores) {
                for (double[] xs : xss) {
                    Arrays.fill(xs, Double.NEGATIVE_INFINITY);
                }
            }
            this.mViterbiScores = new double[numTokens][numTags];
            for (double[] xs : this.mViterbiScores) {
                Arrays.fill(xs, Double.NEGATIVE_INFINITY);
            }
            this.mBackPointers = new int[numTokens - 1][numTags];
            for (int[] ptrs : this.mBackPointers) {
                Arrays.fill(ptrs, -1);
            }

            ChainCrfFeatures<E> features = ChainCrf.this.mFeatureExtractor.extract(tokens, ChainCrf.this.mTagList);
            computeTransitionScores(features, numTokens, numTags);
            computeViterbi(features, numTokens, numTags);

            this.mLogZ = normToConditional ? this.logZ() : 0.0;
            for (int k = 0; k < numTags; ++k) {
                this.offer(this.mViterbiScores[numTokens - 1][k], null, numTokens - 1, k);
            }
        }

        /** Fills {@code mTransitionScores} for positions 1..numTokens-1; disallowed moves stay -Infinity. */
        private void computeTransitionScores(ChainCrfFeatures<E> features, int numTokens, int numTags) {
            Vector[] edgeVectors = new Vector[numTags];
            for (int n = 1; n < numTokens; ++n) {
                Vector nodeVector = ChainCrf.this.nodeFeatures(n, features);
                for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                    // At n == 1 the predecessor is position 0, where only legal start tags can occur.
                    if (n != 1 || ChainCrf.this.mLegalTagStarts[kMinus1]) {
                        edgeVectors[kMinus1] = ChainCrf.this.edgeFeatures(n, kMinus1, features);
                    }
                }
                for (int k = 0; k < numTags; ++k) {
                    if (n == numTokens - 1 && !ChainCrf.this.mLegalTagEnds[k]) {
                        continue;
                    }
                    double nodeScore = nodeVector.dotProduct(ChainCrf.this.mCoefficients[k]);
                    for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                        if (!ChainCrf.this.mLegalTagTransitions[kMinus1][k]) {
                            continue;
                        }
                        if (n == 1 && !ChainCrf.this.mLegalTagStarts[kMinus1]) {
                            continue;
                        }
                        this.mTransitionScores[n - 1][kMinus1][k] = nodeScore + edgeVectors[kMinus1].dotProduct(ChainCrf.this.mCoefficients[k]);
                    }
                }
            }
        }

        /** Fills {@code mViterbiScores} and {@code mBackPointers} using the transition scores already computed. */
        private void computeViterbi(ChainCrfFeatures<E> features, int numTokens, int numTags) {
            Vector nodeVector0 = ChainCrf.this.nodeFeatures(0, features);
            for (int k = 0; k < numTags; ++k) {
                if (ChainCrf.this.mLegalTagStarts[k]) {
                    this.mViterbiScores[0][k] = nodeVector0.dotProduct(ChainCrf.this.mCoefficients[k]);
                }
            }
            for (int n = 1; n < numTokens; ++n) {
                for (int k = 0; k < numTags; ++k) {
                    if (n == numTokens - 1 && !ChainCrf.this.mLegalTagEnds[k]) {
                        continue;
                    }
                    double bestScore = Double.NEGATIVE_INFINITY;
                    int backPtr = -1;
                    for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                        if (!ChainCrf.this.mLegalTagTransitions[kMinus1][k]) {
                            continue;
                        }
                        double score = this.mViterbiScores[n - 1][kMinus1] + this.mTransitionScores[n - 1][kMinus1][k];
                        if (score > bestScore) {
                            bestScore = score;
                            backPtr = kMinus1;
                        }
                    }
                    this.mViterbiScores[n][k] = bestScore;
                    this.mBackPointers[n - 1][k] = backPtr;
                }
            }
        }

        /**
         * Returns the log partition function. It runs the forward recursion
         * in log space over {@code mTransitionScores}, starting from the
         * position-0 scores, and log-sum-exps the final forward vector.
         */
        double logZ() {
            double[] forwards = this.mViterbiScores[0].clone();
            int numTags = forwards.length;
            double[] previousForwards = new double[numTags];
            double[] exps = new double[numTags];
            for (int n = 0; n < this.mTransitionScores.length; ++n) {
                // Swap buffers rather than reallocating each position.
                double[] temp = previousForwards;
                previousForwards = forwards;
                forwards = temp;
                for (int k = 0; k < numTags; ++k) {
                    for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                        exps[kMinus1] = previousForwards[kMinus1] + this.mTransitionScores[n][kMinus1][k];
                    }
                    forwards[k] = com.aliasi.util.Math.logSumOfExponentials(exps);
                }
            }
            return com.aliasi.util.Math.logSumOfExponentials(forwards);
        }

        /**
         * Adds a candidate state to the queue. Candidates with an infinitely
         * bad prefix score or suffix score are dropped.
         */
        void offer(double score, ForwardPointer pointer, int n, int k) {
            if (score == Double.NEGATIVE_INFINITY) {
                return;
            }
            if (pointer != null && pointer.mScore == Double.NEGATIVE_INFINITY) {
                return;
            }
            this.mPriorityQueue.offer(new NBestState(score, pointer, n, k));
        }

        /**
         * Pops the best remaining state and expands its alternatives along
         * the back-pointer path. Returns the state as a tagging, or
         * {@code null} when the queue is exhausted.
         */
        @Override
        public ScoredTagging<E> bufferNext() {
            NBestState resultState = this.mPriorityQueue.poll();
            if (resultState == null) {
                return null;
            }
            int n = resultState.mN - 1;
            int k = resultState.mK;
            ForwardPointer fwdPointer = resultState.mForwardPointer;
            while (n >= 0) {
                this.addAlternatives(n, k, fwdPointer);
                int kMinus1 = this.mBackPointers[n][k];
                // Extend the suffix chain by tag k at position n+1 so that
                // alternatives found further left carry the full suffix score.
                double fwdScore = this.mTransitionScores[n][kMinus1][k];
                if (fwdPointer != null) {
                    fwdScore += fwdPointer.mScore;
                }
                fwdPointer = new ForwardPointer(k, fwdPointer, fwdScore);
                k = kMinus1;
                --n;
            }
            return this.toScoredTagging(resultState);
        }

        /**
         * Offers each predecessor tag at position {@code n} other than the
         * back-pointer, followed by tag {@code k} at {@code n+1} and then
         * the existing suffix chain.
         */
        void addAlternatives(int n, int k, ForwardPointer fwdPointer) {
            int numTags = ChainCrf.this.mTagList.size();
            for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                if (kMinus1 == this.mBackPointers[n][k]) {
                    continue;
                }
                double score = this.mViterbiScores[n][kMinus1];
                double fwdScore = this.mTransitionScores[n][kMinus1][k];
                if (fwdPointer != null) {
                    fwdScore += fwdPointer.mScore;
                }
                ForwardPointer pointer = new ForwardPointer(k, fwdPointer, fwdScore);
                this.offer(score, pointer, n, kMinus1);
            }
        }

        /**
         * Materializes a state into a tagging. The prefix comes from the
         * back pointers, the suffix from the forward-pointer chain, and the
         * score is the state's total score minus {@code mLogZ}.
         */
        public ScoredTagging<E> toScoredTagging(NBestState state) {
            List<String> tags = new ArrayList<String>(this.mTokens.size());
            int k = state.mK;
            tags.add(ChainCrf.this.mTagList.get(k));
            for (int n = state.mN; n > 0; --n) {
                k = this.mBackPointers[n - 1][k];
                tags.add(ChainCrf.this.mTagList.get(k));
            }
            // Prefix was collected right-to-left.
            Collections.reverse(tags);
            for (ForwardPointer pointer = state.mForwardPointer; pointer != null; pointer = pointer.mPointer) {
                tags.add(ChainCrf.this.mTagList.get(pointer.mK));
            }
            return new ScoredTagging<E>(this.mTokens, tags, state.score() - this.mLogZ);
        }
    }

    /**
     * Queue entry for the n-best search: a prefix score {@code mScore} for
     * (position {@code mN}, tag {@code mK}), plus an optional suffix chain.
     * Its {@link #score()} is the prefix score plus the suffix head's score
     * when a suffix is present.
     */
    static class NBestState
    implements Scored {
        final double mScore;
        final ForwardPointer mForwardPointer;
        final int mN;
        final int mK;

        NBestState(double score, ForwardPointer forwardPointer, int n, int k) {
            mK = k;
            mN = n;
            mForwardPointer = forwardPointer;
            mScore = score;
        }

        @Override
        public double score() {
            if (mForwardPointer == null) {
                return mScore;
            }
            return mScore + mForwardPointer.mScore;
        }
    }

    /**
     * Externalization proxy for {@link ChainCrf}.
     *
     * <p>Wire format, in order:
     * <ol>
     * <li>the tag count, then each tag as UTF;</li>
     * <li>per-tag legal-start flags;</li>
     * <li>per-tag legal-end flags;</li>
     * <li>the row-major legal-transition matrix;</li>
     * <li>the coefficient vectors;</li>
     * <li>the feature symbol table and the feature extractor, as objects;</li>
     * <li>the add-intercept flag.</li>
     * </ol>
     * {@link #read(ObjectInput)} reads exactly one coefficient vector per
     * tag.
     */
    static class Serializer<F>
    extends AbstractExternalizable {
        static final long serialVersionUID = -4140295941325870709L;
        final ChainCrf<F> mCrf;

        public Serializer(ChainCrf<F> crf) {
            this.mCrf = crf;
        }

        /** No-arg constructor, presumably required by the externalization machinery. */
        public Serializer() {
            this(null);
        }

        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            // Accessed through the typed reference: via a raw ChainCrf the
            // field erases to List and the String for-each does not compile.
            List<String> tagList = this.mCrf.mTagList;
            int numTags = tagList.size();
            out.writeInt(numTags);
            for (String tag : tagList) {
                out.writeUTF(tag);
            }
            for (int i = 0; i < numTags; ++i) {
                out.writeBoolean(this.mCrf.mLegalTagStarts[i]);
            }
            for (int i = 0; i < numTags; ++i) {
                out.writeBoolean(this.mCrf.mLegalTagEnds[i]);
            }
            for (int i = 0; i < numTags; ++i) {
                for (int j = 0; j < numTags; ++j) {
                    out.writeBoolean(this.mCrf.mLegalTagTransitions[i][j]);
                }
            }
            // read() consumes exactly numTags vectors, so the format relies
            // on mCoefficients.length == numTags.
            for (Vector v : this.mCrf.mCoefficients) {
                out.writeObject(v);
            }
            out.writeObject(this.mCrf.mFeatureSymbolTable);
            out.writeObject(this.mCrf.mFeatureExtractor);
            out.writeBoolean(this.mCrf.mAddInterceptFeature);
        }

        @Override
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            int numTags = in.readInt();
            String[] tags = new String[numTags];
            for (int i = 0; i < numTags; ++i) {
                tags[i] = in.readUTF();
            }
            boolean[] legalTagStarts = readBooleans(in, numTags);
            boolean[] legalTagEnds = readBooleans(in, numTags);
            boolean[][] legalTagTransitions = new boolean[numTags][];
            for (int i = 0; i < numTags; ++i) {
                legalTagTransitions[i] = readBooleans(in, numTags);
            }
            Vector[] coefficients = new Vector[numTags];
            for (int i = 0; i < numTags; ++i) {
                coefficients[i] = (Vector) in.readObject();
            }
            SymbolTable featureSymbolTable = (SymbolTable) in.readObject();
            // Unchecked: generic type is erased on the wire; it matches the
            // ChainCrf<F> that was written.
            @SuppressWarnings("unchecked")
            ChainCrfFeatureExtractor<F> featureExtractor = (ChainCrfFeatureExtractor<F>) in.readObject();
            boolean addInterceptFeature = in.readBoolean();
            return new ChainCrf<F>(tags, legalTagStarts, legalTagEnds, legalTagTransitions, coefficients, featureSymbolTable, featureExtractor, addInterceptFeature);
        }

        /** Reads {@code n} consecutive booleans from {@code in}. */
        private static boolean[] readBooleans(ObjectInput in, int n) throws IOException {
            boolean[] result = new boolean[n];
            for (int i = 0; i < n; ++i) {
                result[i] = in.readBoolean();
            }
            return result;
        }
    }
}

