/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.linear.StochasticSTLinearL1;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.regression.RegressionDataSet;

/**
 * An L1-regularized linear model for binary classification or regression, trained with a
 * stochastic mirror-descent update that uses a p-norm link ({@code p = 2 ln d}) and
 * soft-thresholding (the SMIDAS scheme).
 *
 * <p>Before training, every feature is linearly mapped into {@code [minScaled, maxScaled]} using
 * the per-feature minimum and maximum observed in the training data ({@code obvMin} /
 * {@code obvMax}).
 */
public class SMIDAS
extends StochasticSTLinearL1 {
    private static final long serialVersionUID = -4888083541600164597L;
    /** Step size used for every stochastic update; validated by {@link #setEta(double)}. */
    private double eta;

    /**
     * Creates a learner with 1000 epochs, {@code lambda = 1e-14}, {@code DEFAULT_LOSS} and
     * feature rescaling enabled.
     *
     * @param eta the learning rate; must be positive and finite
     */
    public SMIDAS(double eta) {
        this(eta, 1000, 1.0E-14, DEFAULT_LOSS);
    }

    /**
     * Creates a learner with feature rescaling enabled.
     *
     * @param eta the learning rate; must be positive and finite
     * @param epochs the number of stochastic updates performed during training
     * @param lambda the L1 regularization strength
     * @param loss the loss whose derivative drives the updates
     */
    public SMIDAS(double eta, int epochs, double lambda, StochasticSTLinearL1.Loss loss) {
        this(eta, epochs, lambda, loss, true);
    }

    /**
     * Creates a fully specified learner.
     *
     * @param eta the learning rate; must be positive and finite
     * @param epochs the number of stochastic updates performed during training
     * @param lambda the L1 regularization strength
     * @param loss the loss whose derivative drives the updates
     * @param reScale whether features are rescaled; when {@code false}, training rejects any
     *        observed feature value outside {@code [-1, 1]}
     */
    public SMIDAS(double eta, int epochs, double lambda, StochasticSTLinearL1.Loss loss, boolean reScale) {
        this.setEta(eta);
        this.setEpochs(epochs);
        this.setLambda(lambda);
        this.setLoss(loss);
        this.setReScale(reScale);
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the new learning rate
     * @throws ArithmeticException if {@code eta} is NaN, infinite, or not strictly positive
     */
    public void setEta(double eta) {
        if (Double.isNaN(eta) || Double.isInfinite(eta) || eta <= 0.0) {
            throw new ArithmeticException("convergence parameter must be a positive value");
        }
        this.eta = eta;
    }

    /**
     * @return the current learning rate
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * Classifies a point by passing the inherited {@code wDot} score through the loss.
     * NOTE(review): {@code wDot} is defined in the superclass; presumably it applies the same
     * feature scaling used during training. Verify this.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        Vec x = data.getNumericalValues();
        return this.loss.classify(this.wDot(x));
    }

    /**
     * Predicts a real value by passing the inherited {@code wDot} score through the loss.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public double regress(DataPoint data) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        Vec x = data.getNumericalValues();
        return this.loss.regress(this.wDot(x));
    }

    /** Trains single-threaded; {@code threadPool} is ignored. */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.trainC(dataSet);
    }

    /**
     * Trains a binary classifier.
     *
     * @throws FailedToFitException if there are fewer than 3 numeric features, the problem is
     *         not binary, or the data set is empty
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        if (dataSet.getNumNumericalVars() < 3) {
            throw new FailedToFitException("SMIDAS requires at least 3 features");
        }
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("SMIDAS only supports binary classification problems");
        }
        Vec[] x = this.setUpVecs(dataSet);
        this.scaleVecs(x);
        double[] target = new double[x.length];
        for (int i = 0; i < target.length; ++i) {
            // map the class index {0, 1} onto the label {-1, +1}
            target[i] = dataSet.getDataPointCategory(i) * 2 - 1;
        }
        this.train(x, target);
    }

    /** Trains single-threaded; {@code threadPool} is ignored. */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        this.train(dataSet);
    }

    /**
     * Trains a regression model.
     *
     * @throws FailedToFitException if there are fewer than 3 numeric features or the data set
     *         is empty
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        if (dataSet.getNumNumericalVars() < 3) {
            throw new FailedToFitException("SMIDAS requires at least 3 features");
        }
        Vec[] x = this.setUpVecs(dataSet);
        this.scaleVecs(x);
        double[] target = new double[x.length];
        for (int i = 0; i < target.length; ++i) {
            target[i] = dataSet.getTargetValue(i);
        }
        this.train(x, target);
    }

    /**
     * Replaces every entry of {@code x} with a copy mapped into
     * {@code [minScaled, maxScaled]} using {@code obvMin} / {@code obvMax}.
     *
     * <p>The input vectors themselves are never modified. {@code getNumericalValues()} may
     * hand out a data point's backing vector, and scaling it in place would corrupt the
     * caller's data set.
     *
     * <p>NOTE(review): scaling is applied even when {@code reScale} is false (setUpVecs only
     * validates the [-1, 1] range in that case). Confirm this matches the inherited
     * {@code wDot} used at prediction time.
     */
    private void scaleVecs(Vec[] x) {
        DenseVector obvMinV = DenseVector.toDenseVec(this.obvMin);
        DenseVector obvMaxV = DenseVector.toDenseVec(this.obvMax);
        // multipliers_j = (maxScaled - minScaled) / (obvMax_j - obvMin_j); setUpVecs
        // guarantees a non-zero denominator
        DenseVector multipliers = new DenseVector(obvMaxV.length());
        multipliers.mutableAdd(this.maxScaled - this.minScaled);
        multipliers.mutablePairwiseDivide(obvMaxV.subtract(obvMinV));
        boolean allZeroMins = true;
        for (double min : this.obvMin) {
            if (min != 0.0) {
                allZeroMins = false;
                break;
            }
        }
        // when no shift is needed, multiply a copy instead of subtracting, which keeps sparsity
        boolean multiplyOnly = allZeroMins && this.minScaled == 0.0;
        for (int i = 0; i < x.length; ++i) {
            Vec scaled = multiplyOnly ? x[i].clone() : x[i].subtract(obvMinV);
            scaled.mutablePairwiseMultiply(multipliers);
            if (!multiplyOnly) {
                scaled.mutableAdd(this.minScaled);
            }
            x[i] = scaled;
        }
    }

    /**
     * Runs the SMIDAS optimization on pre-scaled inputs and stores the result in {@code w}
     * and {@code bias}.
     *
     * <p>Each of the {@code epochs} iterations draws ONE uniformly random example.
     * {@code theta} is the dual (mirror) vector: it takes a gradient step, is
     * soft-thresholded by {@code eta * lambda}, and is mapped back to {@code w} through the
     * p-norm link.
     *
     * @param x scaled feature vectors; must be non-empty
     * @param y targets aligned with {@code x}
     */
    private void train(Vec[] x, double[] y) {
        final int m = x.length;
        final int d = x[0].length();
        // p = 2 ln d exceeds 2 only when d >= 3, presumably why callers require 3 features
        final double p = 2.0 * Math.log(d);
        final double shrinkage = this.eta * this.lambda;
        DenseVector theta = new DenseVector(d);
        double thetaBias = 0.0;
        this.w = new DenseVector(d);
        // reset so a previous fit's bias cannot leak into the first gradient step
        this.bias = 0.0;
        Random rand = new Random();
        for (int t = 0; t < this.epochs; ++t) {
            int i = rand.nextInt(m);
            double lossScore = this.loss.deriv(this.w.dot(x[i]) + this.bias, y[i]);
            theta.mutableSubtract(this.eta * lossScore, x[i]);
            thetaBias -= this.eta * lossScore;
            // soft-threshold every non-zero coordinate towards 0 (index loop: no mutation
            // while iterating the vector)
            for (int j = 0; j < d; ++j) {
                double thetaJ = theta.get(j);
                if (thetaJ != 0.0) {
                    theta.set(j, Math.signum(thetaJ) * Math.max(0.0, Math.abs(thetaJ) - shrinkage));
                }
            }
            thetaBias = Math.signum(thetaBias) * Math.max(0.0, Math.abs(thetaBias) - shrinkage);
            double thetaNorm = theta.pNorm(p);
            if (thetaNorm > 0.0) {
                double logThetaNorm = Math.log(thetaNorm);
                for (int j = 0; j < d; ++j) {
                    this.w.set(j, linkInverse(theta.get(j), p, logThetaNorm));
                }
                this.bias = linkInverse(thetaBias, p, logThetaNorm);
            } else {
                // the link is undefined for a zero norm: restart from the all-zero model
                theta.zeroOut();
                thetaBias = 0.0;
                this.w.zeroOut();
                this.bias = 0.0;
            }
        }
    }

    /**
     * Computes {@code sign(v) * |v|^(p-1) / ||theta||_p^(p-2)} in log space to avoid
     * overflow. For {@code p > 1}, an input of 0 yields 0.
     */
    private static double linkInverse(double v, double p, double logThetaNorm) {
        return Math.signum(v) * Math.exp((p - 1.0) * Math.log(Math.abs(v)) - (p - 2.0) * logThetaNorm);
    }

    /**
     * Training never reads data point weights, so weighted data is not supported. Callers
     * must resample rather than rely on weights being honored.
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /** Returns a copy of this learner, including its trained state and scaling bounds. */
    @Override
    public SMIDAS clone() {
        SMIDAS clone = new SMIDAS(this.eta, this.epochs, this.lambda, this.loss, this.reScale);
        if (this.w != null) {
            clone.w = this.w.clone();
        }
        clone.bias = this.bias;
        clone.minScaled = this.minScaled;
        clone.maxScaled = this.maxScaled;
        if (this.obvMin != null) {
            clone.obvMin = Arrays.copyOf(this.obvMin, this.obvMin.length);
        }
        if (this.obvMax != null) {
            clone.obvMax = Arrays.copyOf(this.obvMax, this.obvMax.length);
        }
        return clone;
    }

    /**
     * Collects the numeric vectors of {@code dataSet} and records the per-feature observed
     * range in {@code obvMin} / {@code obvMax}.
     *
     * <p>A feature missing from a point's iteration holds an implicit 0 there, so 0 is
     * included in that feature's range. The check counts visits, which makes it correct
     * whether or not the vectors' iterators skip zeros. Degenerate ranges (constant
     * features) are widened to length 1 so scaling never divides by zero.
     *
     * @return the data set's vectors, as returned by {@code getNumericalValues()} (not copied)
     * @throws FailedToFitException if the data set is empty, or if {@code reScale} is false and
     *         any observed value lies outside {@code [-1, 1]}
     */
    private Vec[] setUpVecs(DataSet dataSet) {
        final int n = dataSet.getSampleSize();
        final int d = dataSet.getNumNumericalVars();
        if (n == 0) {
            throw new FailedToFitException("SMIDAS requires at least one training example");
        }
        this.obvMin = new double[d];
        Arrays.fill(this.obvMin, Double.POSITIVE_INFINITY);
        this.obvMax = new double[d];
        Arrays.fill(this.obvMax, Double.NEGATIVE_INFINITY);
        int[] visits = new int[d];
        Vec[] x = new Vec[n];
        for (int i = 0; i < n; ++i) {
            x[i] = dataSet.getDataPoint(i).getNumericalValues();
            for (IndexValue iv : x[i]) {
                int j = iv.getIndex();
                double v = iv.getValue();
                this.obvMin[j] = Math.min(this.obvMin[j], v);
                this.obvMax[j] = Math.max(this.obvMax[j], v);
                visits[j]++;
            }
        }
        for (int j = 0; j < d; ++j) {
            if (visits[j] < n) {
                this.obvMin[j] = Math.min(this.obvMin[j], 0.0);
                this.obvMax[j] = Math.max(this.obvMax[j], 0.0);
            }
        }
        if (!this.reScale) {
            for (double min : this.obvMin) {
                if (!(min < -1.0)) continue;
                throw new FailedToFitException("Values must be in the range [-1,1], " + min + " violation encountered");
            }
            for (double max : this.obvMax) {
                if (!(max > 1.0)) continue;
                throw new FailedToFitException("Values must be in the range [-1,1], " + max + " violation encountered");
            }
        }
        // widen only after validation so the [-1, 1] check sees the real data range
        for (int j = 0; j < d; ++j) {
            if (!(this.obvMax[j] > this.obvMin[j])) {
                this.obvMax[j] = this.obvMin[j] + 1.0;
            }
        }
        return x;
    }
}

