/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.crf;

import edu.berkeley.nlp.classify.Encoding;
import edu.berkeley.nlp.classify.FeatureExtractor;
import edu.berkeley.nlp.classify.IndexLinearizer;
import edu.berkeley.nlp.crf.Counts;
import edu.berkeley.nlp.crf.LabeledInstanceSequence;
import edu.berkeley.nlp.math.DifferentiableFunction;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Pair;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Differentiable training objective for a CRF whose weights are indexed by (feature, label) pairs.
 *
 * <p>A weight vector {@code x} is addressed through an {@link IndexLinearizer} built from
 * {@code encoding.getNumFeatures()} and {@code encoding.getNumLabels()}. The value at {@code x} is
 *
 * <pre>
 *   - sum_k empirical[k] * x[k]  +  logNormalization(x)  +  sum_k x[k]^2 / (2 * sigma^2)
 * </pre>
 *
 * and the gradient is {@code expected - empirical + x / sigma^2}. The empirical counts, the
 * log-normalization term and the expected counts are all obtained from {@link Counts}.
 *
 * <p>The value and gradient for the most recently queried point are cached, so calling
 * {@link #valueAt} and then {@link #derivativeAt} on the same point costs one evaluation. The
 * training data is assumed not to change after the first evaluation, because its empirical counts
 * are computed once and reused. The caches are unsynchronized mutable state, so an instance should
 * not be shared across threads without external synchronization.
 */
public class CRFObjectiveFunction<V, E, F, L>
implements DifferentiableFunction {
    private final List<? extends LabeledInstanceSequence<V, E, L>> trainingData;
    private final Encoding<F, L> encoding;
    private final Counts<V, E, F, L> counts;
    private final IndexLinearizer il;
    private final double sigma;
    /** Objective value at {@link #lastX}. */
    double lastValue;
    /** Gradient at {@link #lastX}; never handed out directly so callers cannot corrupt it. */
    double[] lastDerivative;
    /** Private snapshot of the last evaluated point, or {@code null} before the first evaluation. */
    double[] lastX;
    /**
     * Empirical counts laid out by linear (feature, label) index. Computed on first use because
     * they do not depend on the weights.
     */
    private double[] empiricalCountVector;

    /**
     * Creates the objective.
     *
     * @param trainingData labeled sequences whose likelihood is being optimized
     * @param encoding feature/label indexing used to lay out the weight vector
     * @param vertexExtractor feature extractor passed to {@link Counts} for vertices
     * @param edgeExtractor feature extractor passed to {@link Counts} for edges
     * @param sigma scale of the L2 penalty {@code x^2 / (2 * sigma^2)}; only {@code sigma^2} is used
     * @throws IllegalArgumentException if {@code sigma} is zero or NaN, which would make every
     *     objective value and gradient NaN or infinite
     */
    public CRFObjectiveFunction(List<? extends LabeledInstanceSequence<V, E, L>> trainingData, Encoding<F, L> encoding, FeatureExtractor<V, F> vertexExtractor, FeatureExtractor<E, F> edgeExtractor, double sigma) {
        if (sigma == 0.0 || Double.isNaN(sigma)) {
            throw new IllegalArgumentException("sigma must be non-zero and not NaN, got " + sigma);
        }
        this.trainingData = trainingData;
        this.encoding = encoding;
        this.counts = new Counts<V, E, F, L>(encoding, vertexExtractor, edgeExtractor);
        this.il = new IndexLinearizer(encoding.getNumFeatures(), encoding.getNumLabels());
        this.sigma = sigma;
    }

    /** Returns the number of linear (feature, label) weight indexes reported by the linearizer. */
    @Override
    public int dimension() {
        return this.il.getNumLinearIndexes();
    }

    /**
     * Returns the regularized objective at {@code x}, reusing the cached result when {@code x}
     * equals the last evaluated point.
     *
     * @throws IllegalArgumentException if {@code x.length != dimension()}
     */
    @Override
    public double valueAt(double[] x) {
        this.ensureCache(x);
        return this.lastValue;
    }

    /**
     * Returns a fresh copy of the gradient at {@code x}, reusing the cached result when {@code x}
     * equals the last evaluated point.
     *
     * @throws IllegalArgumentException if {@code x.length != dimension()}
     */
    @Override
    public double[] derivativeAt(double[] x) {
        this.ensureCache(x);
        // Copy so that callers which update the gradient in place cannot corrupt the cache.
        return this.lastDerivative.clone();
    }

    /** Recomputes the cached value and gradient unless {@code x} matches the cached point. */
    private void ensureCache(double[] x) {
        if (x.length != this.dimension()) {
            throw new IllegalArgumentException("expected a point of dimension " + this.dimension() + " but got " + x.length);
        }
        if (this.requiresUpdate(this.lastX, x)) {
            Pair<Double, double[]> currentValueAndDerivative = this.calculate(x);
            this.lastValue = currentValueAndDerivative.getFirst();
            this.lastDerivative = currentValueAndDerivative.getSecond();
            // Snapshot the point. Keeping the caller's array would make in-place updates of it
            // look like cache hits and silently return stale values.
            this.lastX = x.clone();
        }
    }

    /**
     * Returns true when there is no cached point or it differs from {@code x}.
     * {@link Arrays#equals(double[], double[])} compares bit patterns, so a NaN-containing point
     * can still hit the cache, and 0.0 vs -0.0 merely forces a harmless recomputation.
     */
    private boolean requiresUpdate(double[] lastX, double[] x) {
        return lastX == null || !Arrays.equals(lastX, x);
    }

    /** Computes the objective value and its gradient at {@code x} from scratch. */
    private Pair<Double, double[]> calculate(double[] x) {
        double objective = 0.0;
        double[] derivatives = new double[this.dimension()];

        // Empirical term: contributes -empirical . x to the value and -empirical to the gradient.
        double[] empirical = this.getEmpiricalCountVector();
        for (int i = 0; i < empirical.length; ++i) {
            // Skip absent features so that an infinite weight with zero count does not yield NaN.
            if (empirical[i] != 0.0) {
                objective -= empirical[i] * x[i];
                derivatives[i] -= empirical[i];
            }
        }

        // Model term: the log-normalization value, whose gradient is the expected counts.
        Pair<Double, List<Counter<F>>> results = this.counts.getLogNormalizationAndExpectedCounts(this.trainingData, x);
        objective += results.getFirst().doubleValue();
        this.addCountsByLinearIndex(results.getSecond(), derivatives);

        // L2 penalty: x^2 / (2 sigma^2) on the value, x / sigma^2 on the gradient.
        double sigmaSquared = this.sigma * this.sigma;
        double twiceSigmaSquared = 2.0 * sigmaSquared;
        for (int i = 0; i < x.length; ++i) {
            double weight = x[i];
            objective += weight * weight / twiceSigmaSquared;
            derivatives[i] += weight / sigmaSquared;
        }
        return Pair.makePair(objective, derivatives);
    }

    /** Returns the empirical counts as a dense vector, computing it on the first call. */
    private double[] getEmpiricalCountVector() {
        if (this.empiricalCountVector == null) {
            double[] vector = new double[this.dimension()];
            this.addCountsByLinearIndex(this.counts.getEmpiricalCounts(this.trainingData), vector);
            this.empiricalCountVector = vector;
        }
        return this.empiricalCountVector;
    }

    /**
     * Adds each count into {@code target} at the linear index of its (feature, label) pair.
     * Element {@code l} of {@code countsByLabel} holds the feature counts for label index {@code l}.
     */
    private void addCountsByLinearIndex(List<Counter<F>> countsByLabel, double[] target) {
        for (int l = 0; l < countsByLabel.size(); ++l) {
            for (Map.Entry<F, Double> entry : countsByLabel.get(l).entrySet()) {
                int index = this.il.getLinearIndex(this.encoding.getFeatureIndex(entry.getKey()), l);
                target[index] += entry.getValue();
            }
        }
    }
}

