/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.FunctionP;
import jsat.math.FunctionVec;
import jsat.math.optimization.BacktrackingArmijoLineSearch;
import jsat.math.optimization.LineSearch;
import jsat.math.optimization.Optimizer2;
import jsat.utils.DoubleList;

/**
 * Limited-memory BFGS (L-BFGS) optimizer.
 * <p>
 * Rather than storing a dense inverse-Hessian approximation, this keeps the
 * {@link #getM() m} most recent correction pairs
 * <i>s<sub>k</sub> = x<sub>k+1</sub> - x<sub>k</sub></i> and
 * <i>y<sub>k</sub> = &nabla;f(x<sub>k+1</sub>) - &nabla;f(x<sub>k</sub>)</i>
 * and applies the implied inverse Hessian to the gradient with the two-loop
 * recursion ({@link #twoLoopHp}).
 * <p>
 * {@link #optimize} stops when the gradient norm is no longer greater than the
 * tolerance (infinity norm by default, see
 * {@link #setInftNormCriterion(boolean)}), when the maximum number of
 * iterations is reached, or when the line search returns a step smaller than
 * 1e-12.
 */
public class LBFGS implements Optimizer2 {
    /** Line-search step sizes below this are treated as "no further progress possible". */
    private static final double MIN_STEP = 1.0E-12;

    /** Maximum number of (s, y) correction pairs retained in the history. */
    private int m;
    /** Upper bound on the number of outer iterations performed by optimize. */
    private int maxIterations;
    /** Prototype line search; a clone of it is used for every optimize call. */
    private LineSearch lineSearch;
    /** If true, convergence is measured with the infinity norm of the gradient; otherwise the 2-norm. */
    private boolean inftNormCriterion = true;

    /**
     * Creates an L-BFGS optimizer that keeps 10 correction pairs, runs at most
     * 500 iterations and uses a {@link BacktrackingArmijoLineSearch}.
     */
    public LBFGS() {
        this(10);
    }

    /**
     * Creates an L-BFGS optimizer with at most 500 iterations and a
     * {@link BacktrackingArmijoLineSearch}.
     *
     * @param m the number of correction pairs to keep; must be at least 1
     */
    public LBFGS(int m) {
        this(m, 500, new BacktrackingArmijoLineSearch());
    }

    /**
     * Creates a fully configured L-BFGS optimizer.
     *
     * @param m             the number of correction pairs to keep; must be at least 1
     * @param maxIterations the maximum number of iterations; must be at least 1
     * @param lineSearch    the line search to use (cloned for every optimize call)
     * @throws IllegalArgumentException if {@code m} or {@code maxIterations} is less than 1
     */
    public LBFGS(int m, int maxIterations, LineSearch lineSearch) {
        this.setM(m);
        this.setMaximumIterations(maxIterations);
        this.setLineSearch(lineSearch);
    }

    /**
     * Two-loop recursion: computes {@code q = H * x_grad}, where {@code H} is
     * the L-BFGS inverse-Hessian approximation implied by the stored pairs.
     * <p>
     * The lists must be ordered most-recent-first (index 0 is the newest pair),
     * which is the order {@link #optimize} maintains. The initial matrix is the
     * identity scaled by {@code s_0·y_0 / y_0·y_0}. If {@code s} is empty,
     * {@code q} is simply a copy of {@code x_grad}.
     *
     * @param x_grad the vector to multiply, normally the current gradient; not modified
     * @param rho    {@code rho.get(i) == 1 / (s_i · y_i)}
     * @param s      position differences, most recent first
     * @param y      gradient differences, most recent first
     * @param q      output vector, overwritten with the result
     * @param alphas scratch space; length must be at least {@code s.size()}
     */
    public static void twoLoopHp(Vec x_grad, List<Double> rho, List<Vec> s, List<Vec> y, Vec q, double[] alphas) {
        x_grad.copyTo(q);
        if (s.isEmpty()) {
            return;
        }
        // First loop: newest to oldest pair.
        for (int i = 0; i < s.size(); ++i) {
            double alpha_i = rho.get(i) * s.get(i).dot(q);
            alphas[i] = alpha_i;
            q.mutableSubtract(alpha_i, y.get(i));
        }
        // Initial Hessian scaling, computed from the most recent pair.
        Vec s_0 = s.get(0);
        Vec y_0 = y.get(0);
        q.mutableMultiply(s_0.dot(y_0) / y_0.dot(y_0));
        // Second loop: oldest to newest pair.
        for (int i = s.size() - 1; i >= 0; --i) {
            double beta = rho.get(i) * y.get(i).dot(q);
            q.mutableAdd(alphas[i] - beta, s.get(i));
        }
    }

    /**
     * Runs the optimization single-threaded; equivalent to calling the
     * executor-accepting overload with a {@code null} executor.
     */
    @Override
    public void optimize(double tolerance, Vec w, Vec x0, Function f, FunctionVec fp, FunctionVec fpp) {
        this.optimize(tolerance, w, x0, f, fp, fpp, null);
    }

    /**
     * Minimizes {@code f} starting from {@code x0} and writes the final point
     * into {@code w}.
     *
     * @param tolerance iteration continues while the gradient norm is greater than this
     * @param w         receives the solution
     * @param x0        the starting point; not modified (only clones of it are used)
     * @param f         the objective function
     * @param fp        the gradient of {@code f}
     * @param fpp       not used by this method
     * @param ex        optional executor for parallel function/gradient evaluation; may be {@code null}
     */
    @Override
    public void optimize(double tolerance, Vec w, Vec x0, Function f, FunctionVec fp, FunctionVec fpp, ExecutorService ex) {
        // Use a private copy of the configured line search for this run.
        LineSearch search = this.lineSearch.clone();
        // Single-element holder for f at the current point. The line search is
        // expected to store f at the accepted point into it, since it is never
        // recomputed here after the initial evaluation.
        double[] f_xVal = new double[1];
        // History of correction pairs, most recent first.
        DoubleList Rho = new DoubleList(this.m);
        List<Vec> S = new ArrayList<Vec>(this.m);
        List<Vec> Y = new ArrayList<Vec>(this.m);

        Vec x_prev = x0.clone();
        Vec x_cur = x0.clone();
        f_xVal[0] = ex != null && f instanceof FunctionP ? ((FunctionP) f).f(x_prev, ex) : f.f(x_prev);

        Vec x_grad = x0.clone();
        x_grad.zeroOut();
        Vec x_gradPrev = x_grad.clone();
        Vec p_k = x_grad.clone();
        Vec s_k = x_grad.clone();
        Vec y_k = x_grad.clone();
        x_grad = ex != null ? fp.f(x_cur, x_grad, ex) : fp.f(x_cur, x_grad);

        double[] alphas = new double[this.m];
        for (int iter = 0; this.gradConvgHelper(x_grad) > tolerance && iter < this.maxIterations; ++iter) {
            // Search direction p_k = -H * grad.
            LBFGS.twoLoopHp(x_grad, Rho, S, Y, p_k, alphas);
            p_k.mutableMultiply(-1.0);

            x_cur.copyTo(x_prev);
            x_grad.copyTo(x_gradPrev);
            double alpha_k = search.lineSearch(1.0, x_prev, x_gradPrev, p_k, f, fp, f_xVal[0],
                    x_gradPrev.dot(p_k), x_cur, f_xVal, x_grad, ex);
            if (alpha_k < MIN_STEP) {
                break;
            }
            if (!search.updatesGrad()) {
                if (ex != null) {
                    fp.f(x_cur, x_grad, ex);
                } else {
                    fp.f(x_cur, x_grad);
                }
            }

            // s_k = x_{k+1} - x_k
            x_cur.copyTo(s_k);
            s_k.mutableSubtract(x_prev);
            // y_k = grad_{k+1} - grad_k
            x_grad.copyTo(y_k);
            y_k.mutableSubtract(x_gradPrev);

            double sy = s_k.dot(y_k);
            double rho_k = 1.0 / sy;
            if (Double.isInfinite(rho_k) || Double.isNaN(rho_k)) {
                // Degenerate pair: discard the whole history so the next
                // direction falls back to steepest descent.
                Rho.clear();
                S.clear();
                Y.clear();
            } else if (sy > 0) {
                Rho.add(0, rho_k);
                S.add(0, s_k.clone());
                Y.add(0, y_k.clone());
                // Keep only the m most recent pairs (oldest sit at the end).
                while (Rho.size() > this.m) {
                    Rho.remove(this.m);
                    S.remove(this.m);
                    Y.remove(this.m);
                }
            }
            // Otherwise s_k·y_k < 0: the curvature condition fails. This is
            // possible because a backtracking Armijo search does not enforce
            // it. Storing this pair would make the implied inverse Hessian
            // indefinite and could yield an ascent direction, so it is skipped.
        }
        x_cur.copyTo(w);
    }

    /**
     * Selects the norm used for the convergence test.
     *
     * @param inftNormCriterion {@code true} for the infinity norm (max |g_i|),
     *                          {@code false} for the Euclidean norm
     */
    public void setInftNormCriterion(boolean inftNormCriterion) {
        this.inftNormCriterion = inftNormCriterion;
    }

    /**
     * @return {@code true} if convergence uses the infinity norm of the
     *         gradient, {@code false} if it uses the 2-norm
     */
    public boolean isInftNormCriterion() {
        return this.inftNormCriterion;
    }

    /**
     * Computes the gradient norm that is compared against the tolerance.
     *
     * @param grad the gradient
     * @return max |g_i| if {@link #isInftNormCriterion()} is true, otherwise the 2-norm
     */
    private double gradConvgHelper(Vec grad) {
        if (!this.inftNormCriterion) {
            return grad.pNorm(2.0);
        }
        double max = 0.0;
        for (IndexValue iv : grad) {
            max = Math.max(max, Math.abs(iv.getValue()));
        }
        return max;
    }

    /**
     * Sets the number of correction pairs kept in the history.
     *
     * @param m the history size; must be at least 1
     * @throws IllegalArgumentException if {@code m < 1}
     */
    public void setM(int m) {
        if (m < 1) {
            throw new IllegalArgumentException("m must be positive, not " + m);
        }
        this.m = m;
    }

    /**
     * @return the number of correction pairs kept in the history
     */
    public int getM() {
        return this.m;
    }

    /**
     * Sets the line search. It is cloned at the start of every optimize call.
     *
     * @param lineSearch the line search to use
     */
    public void setLineSearch(LineSearch lineSearch) {
        this.lineSearch = lineSearch;
    }

    /**
     * @return the line search prototype used by this optimizer
     */
    public LineSearch getLineSearch() {
        return this.lineSearch;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param iterations the iteration limit; must be at least 1
     * @throws IllegalArgumentException if {@code iterations < 1}
     */
    @Override
    public void setMaximumIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("Number of iterations must be positive, not " + iterations);
        }
        this.maxIterations = iterations;
    }

    /**
     * @return the maximum number of iterations
     */
    @Override
    public int getMaximumIterations() {
        return this.maxIterations;
    }

    /**
     * Returns an independent copy with the same history size, iteration
     * limit, convergence norm and a cloned line search.
     */
    @Override
    public LBFGS clone() {
        LBFGS copy = new LBFGS(this.m, this.maxIterations, this.lineSearch.clone());
        // Not set by the constructor, so it must be copied explicitly.
        copy.inftNormCriterion = this.inftNormCriterion;
        return copy;
    }
}

