/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.SimpleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.lossfunctions.HingeLoss;
import jsat.lossfunctions.LossC;
import jsat.lossfunctions.LossFunc;
import jsat.lossfunctions.LossMC;
import jsat.lossfunctions.LossR;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.PowerDecay;
import jsat.math.optimization.stochastic.GradientUpdater;
import jsat.math.optimization.stochastic.SimpleSGD;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.BaseUpdateableRegressor;
import jsat.regression.RegressionDataSet;
import jsat.regression.UpdateableRegressor;

/**
 * A linear model trained online by stochastic gradient descent. The same
 * model serves classification or regression depending on the
 * {@link LossFunc} supplied:
 * <ul>
 * <li>classification requires a {@link LossC} loss; problems with more than
 * two classes additionally require a {@link LossMC} loss, in which case one
 * weight vector, bias and gradient updater is kept per class;</li>
 * <li>regression requires a {@link LossR} loss and uses a single weight
 * vector and bias.</li>
 * </ul>
 * Each update (1) applies optional L2 regularization ({@code lambda0}) by
 * multiplying every weight vector by {@code 1 - eta_t * lambda0}, (2) takes a
 * loss-gradient step through a per-weight-vector clone of the configured
 * {@link GradientUpdater}, and (3) applies optional L1 regularization
 * ({@code lambda1}) to the features present in the example. The L1 step keeps
 * a running total of the penalty that could have been applied ({@code l1U})
 * and of the penalty actually applied to each weight ({@code l1Q}), and never
 * lets a weight cross zero; this appears to follow the cumulative-penalty
 * scheme of Tsuruoka et al. (2009).
 */
public class LinearSGD
extends BaseUpdateableClassifier
implements UpdateableRegressor,
Parameterized,
SimpleWeightVectorModel {
    private static final long serialVersionUID = -59695592724956535L;
    /** Loss whose derivative drives each gradient step. */
    private LossFunc loss;
    /** Template updater; each weight vector gets its own clone in setUp. */
    private GradientUpdater gradientUpdater;
    /** Base learning rate passed to the decay schedule. */
    private double eta;
    /** Schedule mapping (time step, eta) to the learning rate used at that step. */
    private DecayRate decay;
    /** Weight vectors: one for binary classification/regression, one per class otherwise. */
    private Vec[] ws;
    /** Per-weight-vector gradient updaters, cloned from {@link #gradientUpdater}. */
    private GradientUpdater[] gus;
    /** Bias term for each weight vector. */
    private double[] bs;
    /** Number of updates performed since the last setUp. */
    private int time;
    /** L2 regularization strength. */
    private double lambda0;
    /** L1 regularization strength. */
    private double lambda1;
    /** Accumulated L1 penalty that could have been applied to any weight so far. */
    private double l1U;
    /** Per weight vector, per feature: L1 penalty actually applied so far (null until needed). */
    private double[][] l1Q;
    /** Whether a bias term is learned. */
    private boolean useBias = true;

    /**
     * Creates a linear SVM-style model: hinge loss, lambda0 = 1e-4, no L1
     * regularization.
     */
    public LinearSGD() {
        this(new HingeLoss(), 1.0E-4, 0.0);
    }

    /**
     * Creates a model with learning rate 0.001 decayed by
     * {@code new PowerDecay(1.0, 0.1)}.
     *
     * @param loss    the loss function to minimize
     * @param lambda0 the L2 regularization strength (non-negative)
     * @param lambda1 the L1 regularization strength (non-negative)
     */
    public LinearSGD(LossFunc loss, double lambda0, double lambda1) {
        this(loss, 0.001, new PowerDecay(1.0, 0.1), lambda0, lambda1);
    }

    /**
     * Creates a model using plain {@link SimpleSGD} updates.
     *
     * @param loss    the loss function to minimize
     * @param eta     the base learning rate (positive and finite)
     * @param decay   the learning-rate decay schedule
     * @param lambda0 the L2 regularization strength (non-negative)
     * @param lambda1 the L1 regularization strength (non-negative)
     * @throws IllegalArgumentException if any argument is invalid
     */
    public LinearSGD(LossFunc loss, double eta, DecayRate decay, double lambda0, double lambda1) {
        this.setLoss(loss);
        this.setEta(eta);
        this.setEtaDecay(decay);
        this.setGradientUpdater(new SimpleSGD());
        this.setLambda0(lambda0);
        this.setLambda1(lambda1);
    }

    /**
     * Copy constructor. The loss, decay schedule, gradient updaters, weights
     * and L1 bookkeeping are deep-copied so the copy can be trained
     * independently of the original.
     *
     * @param toClone the model to copy
     */
    public LinearSGD(LinearSGD toClone) {
        // NOTE(review): the implicit no-arg super constructor is used, so any
        // state held by BaseUpdateableClassifier (e.g. the epoch count read via
        // getEpochs()) is not copied here — confirm whether that is intended.
        this.loss = toClone.loss.clone();
        this.eta = toClone.eta;
        this.decay = toClone.decay.clone();
        this.time = toClone.time;
        this.lambda0 = toClone.lambda0;
        this.lambda1 = toClone.lambda1;
        this.l1U = toClone.l1U;
        this.useBias = toClone.useBias;
        // Clone the template so that reconfiguring the copy's updater cannot
        // silently change the original's (previously the reference was shared).
        this.gradientUpdater = toClone.gradientUpdater.clone();
        if (toClone.l1Q != null) {
            this.l1Q = new double[toClone.l1Q.length][];
            for (int k = 0; k < toClone.l1Q.length; ++k) {
                this.l1Q[k] = Arrays.copyOf(toClone.l1Q[k], toClone.l1Q[k].length);
            }
        }
        if (toClone.ws != null) {
            this.ws = new Vec[toClone.ws.length];
            this.bs = new double[toClone.bs.length];
            this.gus = new GradientUpdater[toClone.gus.length];
            for (int k = 0; k < this.ws.length; ++k) {
                this.ws[k] = toClone.ws[k].clone();
                this.bs[k] = toClone.bs[k];
                this.gus[k] = toClone.gus[k].clone();
            }
        }
    }

    /**
     * Sets the template gradient updater. It is cloned once per weight vector
     * when the model is set up, so changes take effect on the next setUp.
     *
     * @param gradientUpdater the updater to use
     * @throws IllegalArgumentException if {@code gradientUpdater} is null
     */
    public void setGradientUpdater(GradientUpdater gradientUpdater) {
        if (gradientUpdater == null) {
            throw new IllegalArgumentException("Gradient updater must be non-null");
        }
        this.gradientUpdater = gradientUpdater;
    }

    /** @return the template gradient updater */
    public GradientUpdater getGradientUpdater() {
        return this.gradientUpdater;
    }

    /**
     * Sets the learning-rate decay schedule.
     *
     * @param decay the schedule to use
     * @throws IllegalArgumentException if {@code decay} is null
     */
    public void setEtaDecay(DecayRate decay) {
        // Fail fast: a null schedule would otherwise surface as an NPE on the
        // first update or when cloning.
        if (decay == null) {
            throw new IllegalArgumentException("Decay rate must be non-null");
        }
        this.decay = decay;
    }

    /** @return the learning-rate decay schedule */
    public DecayRate getEtaDecay() {
        return this.decay;
    }

    /**
     * Sets the base learning rate.
     *
     * @param eta the learning rate; must be positive and finite
     * @throws IllegalArgumentException if {@code eta} is not positive and finite
     */
    public void setEta(double eta) {
        if (eta <= 0.0 || Double.isNaN(eta) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("eta must be a positive constant, not " + eta);
        }
        this.eta = eta;
    }

    /** @return the base learning rate */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the loss function. Its type determines whether the model can be
     * used for binary classification ({@link LossC}), multi-class
     * classification ({@link LossMC}) or regression ({@link LossR}).
     *
     * @param loss the loss function
     * @throws IllegalArgumentException if {@code loss} is null
     */
    public void setLoss(LossFunc loss) {
        // Fail fast: a null loss would otherwise surface as an NPE in setUp.
        if (loss == null) {
            throw new IllegalArgumentException("Loss function must be non-null");
        }
        this.loss = loss;
    }

    /** @return the loss function */
    public LossFunc getLoss() {
        return this.loss;
    }

    /**
     * Sets the L2 regularization strength.
     *
     * @param lambda0 a non-negative, finite value; 0 disables L2 regularization
     * @throws IllegalArgumentException if {@code lambda0} is negative or not finite
     */
    public void setLambda0(double lambda0) {
        if (lambda0 < 0.0 || Double.isNaN(lambda0) || Double.isInfinite(lambda0)) {
            throw new IllegalArgumentException("Lambda0 must be non-negative, not " + lambda0);
        }
        this.lambda0 = lambda0;
    }

    /** @return the L2 regularization strength */
    public double getLambda0() {
        return this.lambda0;
    }

    /**
     * Sets the L1 regularization strength.
     *
     * @param lambda1 a non-negative, finite value; 0 disables L1 regularization
     * @throws IllegalArgumentException if {@code lambda1} is negative or not finite
     */
    public void setLambda1(double lambda1) {
        if (lambda1 < 0.0 || Double.isNaN(lambda1) || Double.isInfinite(lambda1)) {
            throw new IllegalArgumentException("Lambda1 must be non-negative, not " + lambda1);
        }
        this.lambda1 = lambda1;
    }

    /** @return the L1 regularization strength */
    public double getLambda1() {
        return this.lambda1;
    }

    /**
     * Sets whether a bias term is learned.
     *
     * @param useBias true to learn a bias term
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /** @return whether a bias term is learned */
    public boolean isUseBias() {
        return this.useBias;
    }

    @Override
    public LinearSGD clone() {
        return new LinearSGD(this);
    }

    /**
     * Prepares the model for classification. Binary problems use a single
     * weight vector; problems with more categories use one per category.
     *
     * @throws FailedToFitException if the loss cannot handle the problem, or
     *         if there are no numeric attributes
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (!(this.loss instanceof LossC)) {
            throw new FailedToFitException("Loss function " + this.loss.getClass().getSimpleName() + " only supports regression");
        }
        int numCategories = predicting.getNumOfCategories();
        int numModels;
        if (numCategories == 2) {
            numModels = 1;
        } else {
            if (!(this.loss instanceof LossMC)) {
                throw new FailedToFitException("Loss function " + this.loss.getClass().getSimpleName() + " only supports binary classification");
            }
            numModels = numCategories;
        }
        this.ws = new Vec[numModels];
        this.bs = new double[numModels];
        this.gus = new GradientUpdater[numModels];
        this.setUpShared(numericAttributes);
    }

    /**
     * Prepares the model for regression with a single weight vector.
     *
     * @throws FailedToFitException if the loss is not a regression loss, or if
     *         there are no numeric attributes
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes) {
        if (!(this.loss instanceof LossR)) {
            // Leading space added: the class name was previously fused to "does".
            throw new FailedToFitException("Loss function " + this.loss.getClass().getSimpleName() + " does not support regression");
        }
        this.ws = new Vec[1];
        this.bs = new double[1];
        this.gus = new GradientUpdater[1];
        this.setUpShared(numericAttributes);
    }

    /**
     * Allocates zeroed weight vectors and per-vector gradient updaters, and
     * resets the step counter and L1 bookkeeping. Expects ws, bs and gus to
     * already be allocated with the desired length.
     */
    private void setUpShared(int numericAttributes) {
        if (numericAttributes <= 0) {
            throw new FailedToFitException("LinearSGD requires numeric features to use");
        }
        for (int k = 0; k < this.ws.length; ++k) {
            // ScaledVector lets applyL2Reg rescale a whole vector via mutableMultiply.
            this.ws[k] = new ScaledVector(new DenseVector(numericAttributes));
            this.gus[k] = this.gradientUpdater.clone();
            this.gus[k].setup(this.ws[k].length());
        }
        this.time = 0;
        this.l1U = 0.0;
        this.l1Q = this.lambda1 > 0.0 ? new double[this.ws.length][this.ws[0].length()] : null;
    }

    /**
     * Performs one SGD step on a classification example.
     *
     * @param dataPoint   the example; only its numeric features are used
     * @param targetClass the true class index
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        double eta_t = this.decay.rate(this.time++, this.eta);
        Vec x = dataPoint.getNumericalValues();
        this.applyL2Reg(eta_t);
        if (this.ws.length == 1) {
            // Binary case: map class {0, 1} to label {-1, +1}.
            double y = targetClass * 2 - 1;
            double lossD = ((LossC) this.loss).getDeriv(this.ws[0].dot(x) + this.bs[0], y);
            this.performGradientUpdate(0, eta_t, lossD, x);
        } else {
            LossMC lossMC = (LossMC) this.loss;
            DenseVector pred = this.rawScores(x);
            // Let the loss post-process the raw scores, then overwrite them in
            // place with the per-class loss derivatives.
            lossMC.process(pred, pred);
            lossMC.deriv(pred, pred, targetClass);
            for (IndexValue iv : pred) {
                this.performGradientUpdate(iv.getIndex(), eta_t, iv.getValue(), x);
            }
        }
        this.applyL1Reg(eta_t, x);
    }

    /** Returns w_k . x + b_k for every weight vector k. */
    private DenseVector rawScores(Vec x) {
        DenseVector scores = new DenseVector(this.ws.length);
        for (int k = 0; k < this.ws.length; ++k) {
            scores.set(k, this.ws[k].dot(x) + this.bs[k]);
        }
        return scores;
    }

    /**
     * Applies the gradient lossD * x to weight vector i via its updater and,
     * if enabled, updates the matching bias.
     */
    private void performGradientUpdate(int i, double eta_t, double lossD, Vec x) {
        ScaledVector grad = new ScaledVector(lossD, x);
        if (this.useBias) {
            // The value returned by the updater is subtracted from the bias.
            this.bs[i] -= this.gus[i].update(this.ws[i], grad, eta_t, this.bs[i], lossD);
        } else {
            this.gus[i].update(this.ws[i], grad, eta_t);
        }
    }

    /**
     * Performs one SGD step on a regression example.
     *
     * @param dataPoint   the example; only its numeric features are used
     * @param targetValue the true target value
     */
    @Override
    public void update(DataPoint dataPoint, double targetValue) {
        double eta_t = this.decay.rate(this.time++, this.eta);
        Vec x = dataPoint.getNumericalValues();
        this.applyL2Reg(eta_t);
        double lossD = ((LossR) this.loss).getDeriv(this.ws[0].dot(x) + this.bs[0], targetValue);
        this.performGradientUpdate(0, eta_t, lossD, x);
        this.applyL1Reg(eta_t, x);
    }

    /**
     * Classifies a data point. The model must have been set up (ws is null
     * otherwise).
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        if (this.ws.length == 1) {
            return ((LossC) this.loss).getClassification(this.ws[0].dot(x) + this.bs[0]);
        }
        LossMC lossMC = (LossMC) this.loss;
        DenseVector pred = this.rawScores(x);
        lossMC.process(pred, pred);
        return lossMC.getClassification(pred);
    }

    /**
     * Predicts a regression value. The model must have been set up (ws is
     * null otherwise).
     */
    @Override
    public double regress(DataPoint data) {
        Vec x = data.getNumericalValues();
        return ((LossR) this.loss).getRegression(this.ws[0].dot(x) + this.bs[0]);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /** Shrinks every weight vector by the factor (1 - eta_t * lambda0). */
    private void applyL2Reg(double eta_t) {
        if (this.lambda0 > 0.0) {
            for (Vec v : this.ws) {
                v.mutableMultiply(1.0 - eta_t * this.lambda0);
            }
        }
    }

    /**
     * Applies the L1 penalty to the weights of the features present in x.
     * Each such weight is pulled toward zero by the outstanding penalty
     * (the accumulated l1U minus what was already applied, tracked in l1Q) and
     * is clipped at zero rather than being allowed to change sign.
     */
    private void applyL1Reg(double eta_t, Vec x) {
        if (this.lambda1 <= 0.0) {
            return;
        }
        if (this.l1Q == null) {
            // setUpShared only allocates this when lambda1 > 0 at setUp time;
            // lambda1 may have been raised afterwards via setLambda1.
            this.l1Q = new double[this.ws.length][this.ws[0].length()];
        }
        this.l1U += eta_t * this.lambda1;
        for (int k = 0; k < this.ws.length; ++k) {
            Vec w_k = this.ws[k];
            double[] q_k = this.l1Q[k];
            for (IndexValue iv : x) {
                int i = iv.getIndex();
                double z = w_k.get(i);
                double newW_i = 0.0;
                if (z > 0.0) {
                    newW_i = Math.max(0.0, z - (this.l1U + q_k[i]));
                } else if (z < 0.0) {
                    newW_i = Math.min(0.0, z + (this.l1U - q_k[i]));
                }
                q_k[i] += newW_i - z;
                w_k.set(i, newW_i);
            }
        }
    }

    /** Trains on a regression data set; the thread pool is not used. */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        this.train(dataSet);
    }

    /** Trains on a regression data set for {@link #getEpochs()} passes. */
    @Override
    public void train(RegressionDataSet dataSet) {
        BaseUpdateableRegressor.trainEpochs(dataSet, this, this.getEpochs());
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    @Override
    public Vec getRawWeight(int index) {
        return this.ws[index];
    }

    @Override
    public double getBias(int index) {
        return this.bs[index];
    }

    @Override
    public int numWeightsVecs() {
        return this.ws.length;
    }

    /**
     * Suggests a search distribution for lambda0: log-uniform on [1e-7, 1e-2].
     * The data set is not consulted.
     */
    public static Distribution guessLambda0(DataSet d) {
        return new LogUniform(1.0E-7, 0.01);
    }

    /**
     * Suggests a search distribution for lambda1: log-uniform on
     * [1e-7 / N, 1e-3 / N], where N is the sample size (treated as 1 for an
     * empty data set to avoid dividing by zero).
     */
    public static Distribution guessLambda1(DataSet d) {
        int N = Math.max(1, d.getSampleSize());
        return new LogUniform(1.0E-7 / (double) N, 0.001 / (double) N);
    }
}

