/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization.stochastic;

import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.math.optimization.stochastic.GradientUpdater;

/**
 * Adam ("adaptive moment estimation") gradient updater.
 * <p>
 * Keeps an exponentially decayed running average of the gradient ({@code m})
 * and of the element-wise squared gradient ({@code v}). Each coordinate then
 * moves by {@code -eta * alpha * c_t * m_i / (sqrt(v_i) + eps)}, where
 * {@code c_t} is a bias-correction factor for the zero initialization of both
 * averages. A scalar bias term gets the same treatment through
 * {@code mBias}/{@code vBias}.
 * <p>
 * Parameterization note: {@code beta_1} and {@code beta_2} are the weights given
 * to the <em>new</em> gradient in each running average. They correspond to
 * {@code 1 - beta_1} and {@code 1 - beta_2} in the notation of the original Adam
 * paper, so the defaults 0.1 and 0.001 correspond to the commonly quoted 0.9
 * and 0.999.
 * <p>
 * {@link #setup(int)} must be called before {@link #update(Vec, Vec, double)} or
 * {@link #update(Vec, Vec, double, double, double)}.
 */
public class Adam
implements GradientUpdater {
    private static final long serialVersionUID = 5352504067435579553L;
    // running average of the gradient (first moment), allocated by setup(int)
    private Vec m;
    // running average of the squared gradient (second moment), allocated by setup(int)
    private Vec v;
    // number of update steps taken since the last setup(int)
    private long t;
    private double alpha;
    private double beta_1;
    private double beta_2;
    private double eps;
    private double lambda;
    // second-moment running average for the scalar bias term
    private double vBias;
    // first-moment running average for the scalar bias term
    private double mBias;
    /** Default base step size. */
    public static final double DEFAULT_ALPHA = 2.0E-4;
    /** Default weight of the new gradient in the first-moment average. */
    public static final double DEFAULT_BETA_1 = 0.1;
    /** Default weight of the new squared gradient in the second-moment average. */
    public static final double DEFAULT_BETA_2 = 0.001;
    /** Default constant added to {@code sqrt(v)} to avoid division by zero. */
    public static final double DEFAULT_EPS = 1.0E-8;
    /** Default per-step decay applied to the first-moment momentum. */
    public static final double DEFAULT_LAMBDA = 1.0E-8;

    /**
     * Creates an Adam updater with the {@code DEFAULT_*} parameter values.
     */
    public Adam() {
        this(DEFAULT_ALPHA, DEFAULT_BETA_1, DEFAULT_BETA_2, DEFAULT_EPS, DEFAULT_LAMBDA);
    }

    /**
     * Creates an Adam updater with the given parameters.
     *
     * @param alpha  base step size, must be positive and finite
     * @param beta_1 weight of the new gradient in the first-moment average, in (0, 1]
     * @param beta_2 weight of the new squared gradient in the second-moment average, in (0, 1]
     * @param eps    constant added to {@code sqrt(v_i)} in the step denominator, must be
     *               positive and finite
     * @param lambda decay applied to the first-moment momentum over time, in (0, 1)
     * @throws IllegalArgumentException if any parameter is out of range, or if
     *         {@code (1-beta_1)^2 / sqrt(1-beta_2) < 1} does not hold
     */
    public Adam(double alpha, double beta_1, double beta_2, double eps, double lambda) {
        if (alpha <= 0.0 || Double.isInfinite(alpha) || Double.isNaN(alpha)) {
            throw new IllegalArgumentException("alpha must be a positive value, not " + alpha);
        }
        if (beta_1 <= 0.0 || beta_1 > 1.0 || Double.isInfinite(beta_1) || Double.isNaN(beta_1)) {
            throw new IllegalArgumentException("beta_1 must be in (0, 1], not " + beta_1);
        }
        if (beta_2 <= 0.0 || beta_2 > 1.0 || Double.isInfinite(beta_2) || Double.isNaN(beta_2)) {
            throw new IllegalArgumentException("beta_2 must be in (0, 1], not " + beta_2);
        }
        if (Math.pow(1.0 - beta_1, 2.0) / Math.sqrt(1.0 - beta_2) >= 1.0) {
            throw new IllegalArgumentException("the required property (1-beta_1)^2 / sqrt(1-beta_2) < 1, is not held by beta_1=" + beta_1 + " and beta_2=" + beta_2);
        }
        if (lambda <= 0.0 || lambda >= 1.0 || Double.isInfinite(lambda) || Double.isNaN(lambda)) {
            throw new IllegalArgumentException("lambda must be in (0, 1), not " + lambda);
        }
        // A coordinate whose gradient has always been zero has m_i == v_i == 0,
        // so a non-positive eps would turn its step into 0/0 (NaN) or a sign flip.
        if (eps <= 0.0 || Double.isInfinite(eps) || Double.isNaN(eps)) {
            throw new IllegalArgumentException("eps must be a positive value, not " + eps);
        }
        this.alpha = alpha;
        this.beta_1 = beta_1;
        this.beta_2 = beta_2;
        this.eps = eps;
        this.lambda = lambda;
    }

    /**
     * Copy constructor. The parameters, the step counter, the bias moments and
     * (if allocated) the moment vectors are all copied; the vectors are cloned
     * so the copy does not share state with {@code toCopy}.
     *
     * @param toCopy the updater to copy
     */
    public Adam(Adam toCopy) {
        this.alpha = toCopy.alpha;
        this.beta_1 = toCopy.beta_1;
        this.beta_2 = toCopy.beta_2;
        this.eps = toCopy.eps;
        this.lambda = toCopy.lambda;
        this.t = toCopy.t;
        this.mBias = toCopy.mBias;
        this.vBias = toCopy.vBias;
        if (toCopy.m != null) {
            this.m = toCopy.m.clone();
            this.v = toCopy.v.clone();
        }
    }

    /**
     * Applies one Adam step to {@code x} with no bias term (bias gradient 0).
     *
     * @param x    the weights, updated in place
     * @param grad the gradient of the loss with respect to {@code x}
     * @param eta  learning-rate multiplier applied on top of {@code alpha}
     */
    @Override
    public void update(Vec x, Vec grad, double eta) {
        this.update(x, grad, eta, 0.0, 0.0);
    }

    /**
     * Applies one Adam step to {@code x} and computes the step for a scalar bias term.
     *
     * @param x        the weights, updated in place
     * @param grad     the gradient of the loss with respect to {@code x}
     * @param eta      learning-rate multiplier applied on top of {@code alpha}
     * @param bias     the current bias value (not read by this implementation)
     * @param biasGrad the gradient of the loss with respect to the bias
     * @return the bias step {@code c * mBias / (sqrt(vBias) + eps)}; it is not applied
     *         here. NOTE(review): presumably the caller subtracts it from the bias,
     *         mirroring the negative increment applied to {@code x}. Confirm against
     *         the GradientUpdater contract.
     * @throws IllegalStateException if {@link #setup(int)} has not been called
     */
    @Override
    public double update(Vec x, Vec grad, double eta, double bias, double biasGrad) {
        if (this.m == null) {
            throw new IllegalStateException("setup(int) must be called before update");
        }
        ++this.t;
        // Weight on the new gradient: starts at beta_1 (t = 1) and moves toward 1
        // as lambda^(t-1) decays, i.e. the first-moment momentum fades over time.
        double beta_1t = 1.0 - (1.0 - this.beta_1) * Math.pow(this.lambda, this.t - 1L);
        // m <- (1 - beta_1t) * m + beta_1t * grad
        this.m.mutableMultiply(1.0 - beta_1t);
        this.m.mutableAdd(beta_1t, grad);
        // The bias term's first moment follows the same recurrence as m.
        this.mBias = (1.0 - beta_1t) * this.mBias + beta_1t * biasGrad;
        // v <- (1 - beta_2) * v + beta_2 * grad^2
        this.v.mutableMultiply(1.0 - this.beta_2);
        this.vBias = (1.0 - this.beta_2) * this.vBias + this.beta_2 * biasGrad * biasGrad;
        // Squares are added only for the entries grad's iterator yields
        // (presumably its non-zeros for sparse vectors; decay was applied to all above).
        for (IndexValue iv : grad) {
            double g_i = iv.getValue();
            this.v.increment(iv.getIndex(), this.beta_2 * (g_i * g_i));
        }
        // Step size with bias correction for the zero-initialized averages:
        // eta * alpha * sqrt(1 - (1-beta_2)^t) / (1 - (1-beta_1)^t)
        double cnst = eta * this.alpha * Math.sqrt(1.0 - Math.pow(1.0 - this.beta_2, this.t)) / (1.0 - Math.pow(1.0 - this.beta_1, this.t));
        for (int i = 0; i < this.m.length(); ++i) {
            x.increment(i, -cnst * this.m.get(i) / (Math.sqrt(this.v.get(i)) + this.eps));
        }
        return cnst * this.mBias / (Math.sqrt(this.vBias) + this.eps);
    }

    /**
     * @return an independent copy of this updater, including its internal state
     */
    @Override
    public Adam clone() {
        return new Adam(this);
    }

    /**
     * Allocates zeroed moment vectors of dimension {@code d} and resets the
     * step counter and the bias moments.
     *
     * @param d the number of weights that will be updated
     */
    @Override
    public void setup(int d) {
        this.t = 0L;
        // NOTE(review): wrapped in ScaledVector, presumably so the per-step
        // mutableMultiply decay is cheap regardless of dimension; confirm.
        this.m = new ScaledVector(new DenseVector(d));
        this.v = new ScaledVector(new DenseVector(d));
        this.mBias = 0.0;
        this.vBias = 0.0;
    }
}

