/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization;

import java.util.ArrayList;
import java.util.concurrent.ExecutorService;
import jsat.linear.ConstantVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.FunctionP;
import jsat.math.FunctionVec;
import jsat.math.optimization.LBFGS;
import jsat.math.optimization.Optimizer2;
import jsat.utils.DoubleList;

/**
 * Modified Orthant-Wise Limited-memory Quasi-Newton (mOWL-QN) optimizer for
 * objectives of the form {@code F(x) = f(x) + lambda * sum_i m_i * |x_i|},
 * where {@code m_i} is the per-coordinate multiplier set through
 * {@link #setLambdaMultipler(Vec)} (1 for every coordinate when none is set).
 * <p>
 * Each iteration takes one of two steps:
 * <ul>
 * <li>An L-BFGS step restricted to the current orthant (the default).</li>
 * <li>A proximal (soft-thresholding) gradient step. This is used when some
 * non-zero coordinate lies within {@code eps_k = min(||v_k||, eps)} of zero
 * and the negated pseudo-gradient pushes it across zero.</li>
 * </ul>
 * The update rules appear to follow Gong &amp; Ye, "A Modified Orthant-Wise
 * Limited Memory Quasi-Newton Method with Convergence Analysis" (ICML 2015).
 * NOTE(review): confirm against the paper before changing constants.
 * <p>
 * All per-run optimization state lives in locals of {@code optimize}. The
 * instance itself only holds configuration.
 */
public class ModifiedOWLQN
implements Optimizer2 {
    private static final double DEFAULT_EPS = 1.0E-12;
    private static final double DEFAULT_ALPHA_0 = 1.0;
    private static final double DEFAULT_BETA = 0.2;
    private static final double DEFAULT_GAMMA = 0.01;

    /** Number of (s, y) correction pairs kept in the L-BFGS history. */
    private int m = 10;
    /** Global L1 regularization strength; always finite and non-negative (see {@link #setLambda}). */
    private double lambda;
    /** Optional per-coordinate multipliers for {@link #lambda}; {@code null} means all ones. */
    private Vec lambdaMultipler = null;
    /** Upper bound on the "close to zero" threshold that triggers a proximal gradient step. */
    private double eps = DEFAULT_EPS;
    /** Initial trial step size of every backtracking line search. */
    private double alpha_0 = DEFAULT_ALPHA_0;
    /** Factor, in (0, 1), by which the trial step size shrinks while backtracking. */
    private double beta = DEFAULT_BETA;
    /** Sufficient-decrease constant used by both line searches. */
    private double gamma = DEFAULT_GAMMA;
    /** Maximum number of outer iterations per call to {@code optimize}. */
    private int maxIterations = 500;

    /** Creates an optimizer with no L1 regularization ({@code lambda = 0}). */
    public ModifiedOWLQN() {
        this(0.0);
    }

    /**
     * Creates an optimizer with the given L1 regularization strength.
     *
     * @param lambda the L1 penalty; must be finite and non-negative
     * @throws IllegalArgumentException if {@code lambda} is negative, infinite or NaN
     */
    public ModifiedOWLQN(double lambda) {
        this.setLambda(lambda);
    }

    /**
     * Copy constructor. The lambda multiplier vector, if any, is deep-copied.
     *
     * @param toCopy the optimizer whose settings are copied
     */
    protected ModifiedOWLQN(ModifiedOWLQN toCopy) {
        this(toCopy.lambda);
        if (toCopy.lambdaMultipler != null) {
            this.lambdaMultipler = toCopy.lambdaMultipler.clone();
        }
        this.eps = toCopy.eps;
        this.m = toCopy.m;
        this.alpha_0 = toCopy.alpha_0;
        this.beta = toCopy.beta;
        this.gamma = toCopy.gamma;
        this.maxIterations = toCopy.maxIterations;
    }

    /**
     * Sets the global L1 regularization strength.
     *
     * @param lambda the L1 penalty; must be finite and non-negative
     * @throws IllegalArgumentException if {@code lambda} is negative, infinite or NaN
     */
    public void setLambda(double lambda) {
        if (lambda < 0.0 || Double.isInfinite(lambda) || Double.isNaN(lambda)) {
            throw new IllegalArgumentException("lambda must be non-negative, not " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * Returns the global L1 regularization strength.
     *
     * @return the L1 penalty
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets optional per-coordinate multipliers for the L1 penalty. Coordinate
     * {@code i} is penalized by {@code lambda * lambdaMultipler.get(i) * |x_i|}.
     * The vector is stored by reference, not copied.
     * NOTE(review): the length and sign of the entries are not validated here.
     *
     * @param lambdaMultipler the multipliers, or {@code null} to use 1 for every coordinate
     */
    public void setLambdaMultipler(Vec lambdaMultipler) {
        this.lambdaMultipler = lambdaMultipler;
    }

    /**
     * Returns the per-coordinate L1 multipliers.
     *
     * @return the multiplier vector (by reference), or {@code null} if none is set
     */
    public Vec getLambdaMultipler() {
        return this.lambdaMultipler;
    }

    /**
     * Sets how many (s, y) correction pairs the L-BFGS approximation keeps.
     *
     * @param m the history size; must be at least 1
     * @throws IllegalArgumentException if {@code m < 1}
     */
    public void setM(int m) {
        if (m < 1) {
            throw new IllegalArgumentException("m must be positive, not " + m);
        }
        this.m = m;
    }

    /**
     * Returns the L-BFGS history size.
     *
     * @return the number of correction pairs kept
     */
    public int getM() {
        return this.m;
    }

    /**
     * Sets the upper bound of the near-zero threshold {@code eps_k = min(||v_k||, eps)}.
     * A proximal gradient step is taken instead of a quasi-Newton step when a
     * coordinate with {@code 0 < |x_i| < eps_k} is being pushed across zero.
     *
     * @param eps the threshold bound; must be finite and non-negative
     * @throws IllegalArgumentException if {@code eps} is negative, infinite or NaN
     */
    public void setEps(double eps) {
        if (eps < 0.0 || Double.isInfinite(eps) || Double.isNaN(eps)) {
            throw new IllegalArgumentException("eps must be non-negative, not " + eps);
        }
        this.eps = eps;
    }

    /**
     * Returns the near-zero threshold bound.
     *
     * @return the value of eps
     */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets the backtracking shrinkage factor applied to the trial step size.
     *
     * @param beta the shrinkage factor; must lie strictly inside (0, 1)
     * @throws IllegalArgumentException if {@code beta} is outside (0, 1) or NaN
     */
    public void setBeta(double beta) {
        if (beta <= 0.0 || beta >= 1.0 || Double.isNaN(beta)) {
            throw new IllegalArgumentException("shrinkage term must be in (0, 1), not " + beta);
        }
        this.beta = beta;
    }

    /**
     * Returns the backtracking shrinkage factor.
     *
     * @return the value of beta
     */
    public double getBeta() {
        return this.beta;
    }

    @Override
    public void optimize(double tolerance, Vec w, Vec x0, Function f, FunctionVec fp, FunctionVec fpp) {
        this.optimize(tolerance, w, x0, f, fp, fpp, null);
    }

    /**
     * Minimizes {@code f(x) + lambda * sum_i m_i |x_i|} starting from {@code x0}.
     *
     * @param tolerance stopping threshold. Iteration stops once the largest
     *        absolute gradient entry of {@code f} at the accepted point, the
     *        objective value before the step, or the L1 norm of the step
     *        drops below it.
     * @param w receives the final iterate
     * @param x0 starting point; not modified (only clones of it are used)
     * @param f smooth part of the objective
     * @param fp gradient of {@code f}
     * @param fpp ignored by this optimizer
     * @param ex optional executor. When non-null it is passed to {@code fp}, and
     *        to {@code f} if {@code f} is a {@link FunctionP}. When null, the
     *        executor-less overloads are used.
     */
    @Override
    public void optimize(double tolerance, Vec w, Vec x0, Function f, FunctionVec fp, FunctionVec fpp, ExecutorService ex) {
        Vec lambdaMul = this.lambdaMultipler != null
                ? this.lambdaMultipler
                : new ConstantVector(1.0, x0.length());

        Vec xCur = x0.clone();      // current iterate x_k
        Vec grad = x0.clone();      // gradient of f at x_k
        Vec gradNext = x0.clone();  // gradient of f at the accepted trial point
        Vec gradDiff = x0.clone();  // y_k = gradNext - grad
        Vec v = x0.clone();         // v_k: negated pseudo-gradient of F at x_k
        Vec d = x0.clone();         // d_k: L-BFGS direction computed from v_k
        Vec p = x0.clone();         // p_k: d_k with entries disagreeing in sign with v_k zeroed
        Vec xAlpha = x0.clone();    // line-search trial point
        Vec step = x0.clone();      // s_k = xAlpha - xCur

        DoubleList rhoHist = new DoubleList(this.m);
        ArrayList<Vec> sHist = new ArrayList<Vec>(this.m);
        ArrayList<Vec> yHist = new ArrayList<Vec>(this.m);
        double[] alphas = new double[this.m]; // scratch space for the two-loop recursion

        double fCur = this.objective(f, xCur, lambdaMul, ex);
        grad = gradient(fp, xCur, grad, ex);

        for (int k = 0; k < this.maxIterations; ++k) {
            double vNorm = this.negatedPseudoGradient(xCur, grad, lambdaMul, v);
            double eps_k = Math.min(vNorm, this.eps);

            double fAlpha;
            if (!needsProximalStep(xCur, v, eps_k)) {
                // Quasi-Newton step, kept inside the orthant of x_k (or of v_k where x_k is zero).
                LBFGS.twoLoopHp(v, rhoHist, sHist, yHist, d, alphas);
                projectToSigns(d, v, p);
                double rightSideMainTerm = this.gamma * v.dot(d);
                // Pre-divide so the first trial step of the loop is alpha_0.
                double alpha = this.alpha_0 / this.beta;
                do {
                    alpha *= this.beta;
                    xCur.copyTo(xAlpha);
                    xAlpha.mutableSubtract(-alpha, p);
                    clampToOrthant(xAlpha, xCur, v);
                    fAlpha = this.objective(f, xAlpha, lambdaMul, ex);
                } while (fAlpha > fCur - alpha * rightSideMainTerm);
                xAlpha.copyTo(step);
                step.mutableSubtract(xCur);
            } else {
                // Proximal gradient step: soft-threshold (x_k - alpha * grad) by lambda_i * alpha.
                double alpha = this.alpha_0 / this.beta;
                do {
                    alpha *= this.beta;
                    grad.copyTo(xAlpha);
                    xAlpha.mutableMultiply(-alpha);
                    xAlpha.mutableAdd(xCur);
                    this.softThreshold(xAlpha, lambdaMul, alpha);
                    xAlpha.copyTo(step);
                    step.mutableSubtract(xCur);
                    fAlpha = this.objective(f, xAlpha, lambdaMul, ex);
                } while (fAlpha > fCur - this.gamma / (2.0 * alpha) * step.dot(step));
            }

            // A NaN objective compares false against the sufficient-decrease bound, so
            // the line search "accepts" it. Stop here and keep the last finite iterate
            // rather than continuing from a NaN point.
            if (Double.isNaN(fAlpha)) {
                break;
            }

            gradNext = gradient(fp, xAlpha, gradNext, ex);
            double maxGrad = 0.0;
            for (int i = 0; i < gradNext.length(); ++i) {
                maxGrad = Math.max(maxGrad, Math.abs(gradNext.get(i)));
            }
            // NOTE(review): maxGrad uses the gradient of f only (it need not vanish at
            // an L1 optimum), and "fCur < tolerance" assumes a non-negative objective.
            boolean converged = maxGrad < tolerance || fCur < tolerance || step.pNorm(1.0) < tolerance;

            // Accept the trial point. It passed the sufficient-decrease test, and it is
            // where the convergence test above was evaluated, so it must be kept even
            // when stopping.
            fCur = fAlpha;
            xAlpha.copyTo(xCur);
            if (converged) {
                break;
            }

            gradNext.copyTo(gradDiff);
            gradDiff.mutableSubtract(grad);
            this.updateHistory(step, gradDiff, rhoHist, sHist, yHist);
            gradNext.copyTo(grad);
        }
        xCur.copyTo(w);
    }

    /** Evaluates the full objective: f(x) plus the weighted L1 penalty. */
    private double objective(Function f, Vec x, Vec lambdaMul, ExecutorService ex) {
        double smooth = ex != null && f instanceof FunctionP ? ((FunctionP) f).f(x, ex) : f.f(x);
        return smooth + this.getL1Penalty(x, lambdaMul);
    }

    /** Evaluates the gradient of f at x into target, returning the vector fp produced. */
    private static Vec gradient(FunctionVec fp, Vec x, Vec target, ExecutorService ex) {
        return ex != null ? fp.f(x, target, ex) : fp.f(x, target);
    }

    /**
     * Writes the negated pseudo-gradient of F at x into v.
     *
     * @return the Euclidean norm of the pseudo-gradient
     */
    private double negatedPseudoGradient(Vec x, Vec grad, Vec lambdaMul, Vec v) {
        double sqNorm = 0.0;
        for (int i = 0; i < grad.length(); ++i) {
            double pg = pseudoGradient(x.get(i), grad.get(i), this.lambda * lambdaMul.get(i));
            v.set(i, -pg);
            sqNorm += pg * pg;
        }
        return Math.sqrt(sqNorm);
    }

    /** Returns the L1 pseudo-gradient for one coordinate (zero when x_i = 0 and |g_i| is at most lambda_i). */
    private static double pseudoGradient(double x_i, double g_i, double lambda_i) {
        if (x_i > 0.0) {
            return g_i + lambda_i;
        }
        if (x_i < 0.0) {
            return g_i - lambda_i;
        }
        if (g_i + lambda_i < 0.0) {
            return g_i + lambda_i;
        }
        if (g_i - lambda_i > 0.0) {
            return g_i - lambda_i;
        }
        return 0.0;
    }

    /** Returns true if some coordinate with 0 < |x_i| < eps_k is being pushed across zero by v. */
    private static boolean needsProximalStep(Vec x, Vec v, double eps_k) {
        for (int i = 0; i < v.length(); ++i) {
            double x_i = x.get(i);
            double absX = Math.abs(x_i);
            if (0.0 < absX && absX < eps_k && x_i * v.get(i) < 0.0) {
                return true;
            }
        }
        return false;
    }

    /** Writes into out the entries of dir whose sign matches guide, and zero elsewhere. */
    private static void projectToSigns(Vec dir, Vec guide, Vec out) {
        for (int i = 0; i < out.length(); ++i) {
            double d_i = dir.get(i);
            out.set(i, Math.signum(d_i) == Math.signum(guide.get(i)) ? d_i : 0.0);
        }
    }

    /** Zeroes entries of trial that left the orthant of xCur (of v where xCur_i is zero). */
    private static void clampToOrthant(Vec trial, Vec xCur, Vec v) {
        for (int i = 0; i < trial.length(); ++i) {
            double x_i = xCur.get(i);
            double orthant = x_i != 0.0 ? x_i : v.get(i);
            if (Math.signum(trial.get(i)) != Math.signum(orthant)) {
                trial.set(i, 0.0);
            }
        }
    }

    /** Applies in place the soft-thresholding operator with per-coordinate threshold lambda_i * alpha. */
    private void softThreshold(Vec u, Vec lambdaMul, double alpha) {
        for (int i = 0; i < u.length(); ++i) {
            double u_i = u.get(i);
            double lambda_i = this.lambda * lambdaMul.get(i);
            u.set(i, Math.signum(u_i) * Math.max(0.0, Math.abs(u_i) - lambda_i * alpha));
        }
    }

    /**
     * Pushes the correction pair (s, y) onto the front of the L-BFGS history and
     * trims the history to {@link #m} pairs. The two-loop recursion needs
     * {@code s.y > 0}. A non-finite or non-positive {@code rho = 1 / (s.y)}
     * therefore resets the whole history instead of being stored.
     */
    private void updateHistory(Vec s, Vec y, DoubleList rhoHist, ArrayList<Vec> sHist, ArrayList<Vec> yHist) {
        double rho = 1.0 / s.dot(y);
        if (Double.isInfinite(rho) || Double.isNaN(rho) || rho <= 0.0) {
            rhoHist.clear();
            sHist.clear();
            yHist.clear();
            return;
        }
        rhoHist.add(0, Double.valueOf(rho));
        sHist.add(0, s.clone());
        yHist.add(0, y.clone());
        while (rhoHist.size() > this.m) {
            rhoHist.remove(this.m);
            sHist.remove(this.m);
            yHist.remove(this.m);
        }
    }

    /**
     * Computes {@code lambda * sum_i lambdaMul_i * |w_i|} over the entries
     * yielded by iterating {@code w}.
     *
     * @return the penalty, or 0 when {@code lambda <= 0}
     */
    private double getL1Penalty(Vec w, Vec lambdaMul) {
        if (this.lambda <= 0.0) {
            return 0.0;
        }
        double pen = 0.0;
        for (IndexValue iv : w) {
            pen += this.lambda * lambdaMul.get(iv.getIndex()) * Math.abs(iv.getValue());
        }
        return pen;
    }

    /**
     * Sets the maximum number of outer iterations.
     *
     * @param iterations the iteration cap; must be at least 1
     * @throws IllegalArgumentException if {@code iterations < 1}
     */
    @Override
    public void setMaximumIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("Number of iterations must be positive, not " + iterations);
        }
        this.maxIterations = iterations;
    }

    @Override
    public int getMaximumIterations() {
        return this.maxIterations;
    }

    @Override
    public ModifiedOWLQN clone() {
        return new ModifiedOWLQN(this);
    }
}

