/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.math;

import edu.berkeley.nlp.mapper.AsynchronousMapper;
import edu.berkeley.nlp.mapper.SimpleMapper;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.ObjectiveItemDifferentiableFunction;
import edu.berkeley.nlp.math.Regularizer;
import edu.berkeley.nlp.util.CallbackFunction;
import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Logger;
import edu.berkeley.nlp.util.Option;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
 * Minimizes a sum of per-item differentiable objectives by stochastic gradient descent.
 *
 * <p>Each call to {@link #minimize} makes up to {@code numIters} passes over the items. Within a
 * pass, one {@link GradMapper} per item function is handed to
 * {@link AsynchronousMapper#doMapping}. Every mapped item takes a gradient step on the shared
 * weight vector, and that step is guarded by {@code weightLock}. After each pass the learning rate
 * is scaled:
 * <ul>
 *   <li>by {@code upAlphaMult} when the pass's summed objective is lower than the previous pass's;</li>
 *   <li>by {@code downAlphaMult} otherwise.</li>
 * </ul>
 * Optimization stops early once the learning rate falls below 1% of its initial value.
 *
 * <p>Originally decompiled from a class file targeting Java 6; keep the code Java 6 compatible.
 *
 * @param <I> the type of a training item
 */
public class StochasticObjectiveOptimizer<I> {
    /** Learning rate, as a fraction of the initial rate, below which optimization stops. */
    private static final double CONVERGENCE_ALPHA_FRACTION = 1e-2;

    /** Training items; set by {@link #minimize}. */
    Collection<I> items;
    /** Per-item objective functions; one {@link GradMapper} is created for each per pass. */
    List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns;
    /** Regularizer applied alongside every item; {@code null} means no regularization. */
    Regularizer regularizer;
    /** Learning rate at the start of {@link #minimize}. */
    double initAlpha = 0.5;
    /** Factor applied to the learning rate after a pass that improved the objective. */
    double upAlphaMult = 1.1;
    /** Factor applied to the learning rate after a pass that did not improve the objective. */
    double downAlphaMult = 0.5;
    /** Guards {@link #weights}, {@link #sumWeightVector} and {@link #numUpdates} during a pass. */
    Object weightLock = new Object();
    /** Current weight vector, updated in place by every gradient step. */
    double[] weights;
    /** Current learning rate. */
    double alpha;
    /** Optional callback invoked after every pass; may be {@code null}. */
    CallbackFunction iterDoneCallback;
    /** Whether to log the objective value and learning rate after every pass. */
    boolean printProgress = true;
    /** Random source used to shuffle items when {@link #shuffleData} is set. */
    Random rand;
    /** Seed for {@link #rand}; applied at the start of every {@link #minimize} call. */
    @Option
    public int randSeed = 0;
    /** When set, return (and report) the average of the weights over all updates. */
    @Option
    public boolean doAveraging = false;
    /** When set, visit the items in a freshly shuffled order on every pass. */
    @Option
    public boolean shuffleData = false;
    /** Running sum of the weight vector after each update, used for averaging. */
    double[] sumWeightVector;
    /** Number of gradient steps taken during the current {@link #minimize} call. */
    int numUpdates;

    public StochasticObjectiveOptimizer(double initAlpha, double upAlphaMult, double downAlphaMult) {
        this(initAlpha, upAlphaMult, downAlphaMult, true);
    }

    public StochasticObjectiveOptimizer(double initAlpha, double upAlphaMult, double downAlphaMult, boolean printProgress) {
        this.initAlpha = initAlpha;
        this.upAlphaMult = upAlphaMult;
        this.downAlphaMult = downAlphaMult;
        this.printProgress = printProgress;
        this.rand = new Random(this.randSeed);
    }

    /**
     * Registers a callback to be invoked after every pass.
     *
     * @param iterDoneCallback the callback, or {@code null} to disable it
     */
    public void setIterationCallback(CallbackFunction iterDoneCallback) {
        this.iterDoneCallback = iterDoneCallback;
    }

    /**
     * Performs one pass over all items, taking a gradient step per item.
     *
     * @return the objective summed over every mapper for this pass
     */
    private double doIter() {
        ArrayList<GradMapper> gradMappers = new ArrayList<GradMapper>(this.itemFns.size());
        for (ObjectiveItemDifferentiableFunction<I> itemFn : this.itemFns) {
            gradMappers.add(new GradMapper(itemFn));
        }
        // Copy the items even when not shuffling so the mapping never iterates the caller's collection.
        List<I> passItems = this.shuffleData ? CollectionUtils.shuffle(this.items, this.rand) : new ArrayList<I>(this.items);
        AsynchronousMapper.doMapping(passItems, gradMappers);
        double passVal = 0.0;
        for (GradMapper mapper : gradMappers) {
            passVal += mapper.val;
        }
        return passVal;
    }

    /**
     * Runs stochastic gradient descent starting from {@code initWeights}.
     *
     * @param initWeights starting weights; copied before optimization begins
     * @param numIters maximum number of passes over {@code items}
     * @param items training items
     * @param itemFns per-item objectives; one mapper is created for each per pass
     * @param regularizer regularizer applied with scale {@code 1/items.size()} per item, or {@code null}
     * @return the final weights, or, when {@link #doAveraging} is set, the average of the weights
     *     over all updates (a copy of the current weights if no update happened)
     */
    public double[] minimize(double[] initWeights, int numIters, Collection<I> items, List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns, Regularizer regularizer) {
        this.items = items;
        this.itemFns = itemFns;
        this.numUpdates = 0;
        this.regularizer = regularizer;
        this.alpha = this.initAlpha;
        // randSeed still holds its initializer value (0) while the constructor body runs. Seeding
        // only there would silently ignore any later assignment to this public @Option field.
        this.rand = new Random(this.randSeed);
        this.weights = DoubleArrays.clone(initWeights);
        this.sumWeightVector = DoubleArrays.constantArray(0.0, this.weights.length);
        final double minAlpha = this.initAlpha * CONVERGENCE_ALPHA_FRACTION;
        double lastVal = Double.POSITIVE_INFINITY;
        for (int iter = 0; iter < numIters; ++iter) {
            double val = this.doIter();
            double alphaMult = val < lastVal ? this.upAlphaMult : this.downAlphaMult;
            this.alpha *= alphaMult;
            lastVal = val;
            if (this.printProgress) {
                Logger.logs("[StochasticObjectiveOptimizer] Ended Iteration %d with value %.5f", iter + 1, val);
                Logger.logs("[StochasticObjectiveOptimizer] New Alpha: %.5f (scaled by %.5f)", this.alpha, alphaMult);
            }
            if (this.iterDoneCallback != null) {
                this.iterDoneCallback.callback(iter, this.currentWeights(), val, this.alpha);
            }
            if (this.alpha < minAlpha) {
                Logger.logs("[StochasticObjectiveOptimizer] alpha %.5f below tolerance %.5f, saying converged", this.alpha, minAlpha);
                break;
            }
        }
        return this.currentWeights();
    }

    /** Returns the averaged weights when {@link #doAveraging} is set, else the live weight vector. */
    private double[] currentWeights() {
        return this.doAveraging ? this.avgWeightVector() : this.weights;
    }

    /**
     * Returns a fresh array holding the average of the weight vector over all updates so far.
     * If no update has happened yet, returns a copy of the current weights rather than dividing
     * by zero.
     */
    private double[] avgWeightVector() {
        if (this.numUpdates == 0) {
            return DoubleArrays.clone(this.weights);
        }
        double[] avgWeights = DoubleArrays.clone(this.sumWeightVector);
        DoubleArrays.scale(avgWeights, 1.0 / (double) this.numUpdates);
        return avgWeights;
    }

    /** Returns the dimension reported by the first item function. */
    public int dimension() {
        return this.itemFns.get(0).dimension();
    }

    /**
     * Accumulates the objective value (without gradients) over the items it is mapped over.
     */
    class ValMapper
    implements SimpleMapper<I> {
        double val = 0.0;
        ObjectiveItemDifferentiableFunction<I> itemFn;

        ValMapper(ObjectiveItemDifferentiableFunction<I> itemFn) {
            this.itemFn = itemFn;
        }

        @Override
        public void map(I elem) {
            StochasticObjectiveOptimizer<I> opt = StochasticObjectiveOptimizer.this;
            this.val += this.itemFn.update(elem, null);
            // A null regularizer means "no regularization", matching GradMapper.
            if (opt.regularizer != null) {
                this.val += opt.regularizer.val(opt.weights, 1.0 / (double) opt.items.size());
            }
        }
    }

    /**
     * Takes one stochastic gradient step per mapped item and accumulates the objective value.
     */
    class GradMapper
    implements SimpleMapper<I> {
        double val = 0.0;
        ObjectiveItemDifferentiableFunction<I> itemFn;

        GradMapper(ObjectiveItemDifferentiableFunction<I> itemFn) {
            this.itemFn = itemFn;
        }

        @Override
        public void map(I elem) {
            StochasticObjectiveOptimizer<I> opt = StochasticObjectiveOptimizer.this;
            // Work on a private snapshot so the gradient computation itself runs outside the lock.
            double[] localWeights;
            synchronized (opt.weightLock) {
                localWeights = DoubleArrays.clone(opt.weights);
            }
            double[] localGrad = new double[opt.dimension()];
            this.itemFn.setWeights(localWeights);
            this.val += this.itemFn.update(elem, localGrad);
            if (opt.regularizer != null) {
                // Scaled by 1/|items|, presumably so one full pass contributes the regularizer once overall.
                this.val += opt.regularizer.update(localWeights, localGrad, 1.0 / (double) opt.items.size());
            }
            synchronized (opt.weightLock) {
                DoubleArrays.addInPlace(opt.weights, localGrad, -opt.alpha);
                DoubleArrays.addInPlace(opt.sumWeightVector, opt.weights);
                ++opt.numUpdates;
            }
        }
    }
}

