/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.SimpleWeightVectorModel;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.classifiers.linear.LinearTools;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.Uniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.lossfunctions.LogisticLoss;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * Binary logistic-regression classifier with elastic-net regularization.
 * <p>
 * Training alternates an outer Newton-type loop with an inner randomized coordinate-descent loop
 * that builds the search direction, followed by a backtracking line search. The structure appears
 * to follow the newGLMNET method of Yuan, Ho &amp; Lin (2012); confirm against the paper before
 * relying on specific convergence properties.
 * <p>
 * The logistic loss is scaled by {@link #setC(double) C}. {@link #setAlpha(double) alpha} weights
 * the L1 penalty and {@code 1 - alpha} weights the L2 penalty. Category 0 is mapped to label -1 and
 * category 1 to label +1. NOTE(review): nothing rejects data sets with more than two categories.
 * The coordinate order is shuffled with {@link Collections#shuffle(List)}, so repeated training
 * runs are not guaranteed to be bit-for-bit identical.
 */
public class NewGLMNET
implements WarmClassifier,
Parameterized,
SingleWeightVectorModel {
    private static final long serialVersionUID = 4133368677783573518L;
    /** Default factor by which the line-search step size shrinks after each rejected step. */
    private static final double DEFAULT_BETA = 0.5;
    /** Default lower bound added to the diagonal Hessian entries used by coordinate descent. */
    private static final double DEFAULT_V = 1.0E-12;
    /** Default value of {@link #gamma}. */
    private static final double DEFAULT_GAMMA = 0.0;
    /** Default sufficient-decrease constant of the line search. */
    private static final double DEFAULT_SIGMA = 0.01;
    /** Default relative convergence tolerance of the outer loop. */
    public static final double DEFAULT_EPS = 0.01;
    /** Default maximum number of outer iterations. */
    public static final int DEFAULT_MAX_OUTER_ITER = 100;

    /** Learned weight vector; {@code null} until the model has been trained. */
    private Vec w;
    /** Learned bias term; stays 0 when trained with {@link #useBias} disabled. */
    private double b;
    /** Line-search shrink factor: the t-th trial step size is {@code beta^t}. */
    private double beta = DEFAULT_BETA;
    /** Minimum value added to each diagonal Hessian entry, keeping coordinate steps finite. */
    private double v = DEFAULT_V;
    /** NOTE(review): copied by the copy constructor but never read during training. */
    private double gamma = DEFAULT_GAMMA;
    /** Sufficient-decrease constant of the line search. */
    private double sigma = DEFAULT_SIGMA;
    /** Weight applied to the logistic loss term. */
    private double C;
    /** Elastic-net mixing weight: L1 weight is {@code alpha}, L2 weight is {@code 1 - alpha}. */
    private double alpha;
    /** Maximum number of outer iterations. */
    private int maxOuterIters = DEFAULT_MAX_OUTER_ITER;
    /** Outer-loop stopping tolerance, relative to the optimality measure of the first iteration. */
    private double e_out = DEFAULT_EPS;
    /** Whether an unregularized bias term is learned. */
    private boolean useBias = true;
    /** Maximum number of step-size reductions attempted by the line search. */
    private int maxLineSearchSteps = 20;

    /** Creates a model with {@code C = 1} and {@code alpha = 1}. */
    public NewGLMNET() {
        this(1.0);
    }

    /**
     * Creates a model with the given loss weight and {@code alpha = 1}.
     *
     * @param C weight of the logistic loss term; must be positive and finite
     */
    public NewGLMNET(double C) {
        this(C, 1.0);
    }

    /**
     * Creates a model with the given loss weight and elastic-net mixing weight.
     *
     * @param C     weight of the logistic loss term; must be positive and finite
     * @param alpha L1/L2 mixing weight in [0, 1]
     */
    public NewGLMNET(double C, double alpha) {
        this.setC(C);
        this.setAlpha(alpha);
    }

    /**
     * Copy constructor; deep-copies the weight vector (if trained) and copies every setting.
     *
     * @param toCopy the model to copy
     */
    protected NewGLMNET(NewGLMNET toCopy) {
        if (toCopy.w != null) {
            this.w = toCopy.w.clone();
        }
        this.b = toCopy.b;
        this.beta = toCopy.beta;
        this.v = toCopy.v;
        this.gamma = toCopy.gamma;
        this.sigma = toCopy.sigma;
        this.C = toCopy.C;
        this.e_out = toCopy.e_out;
        this.maxOuterIters = toCopy.maxOuterIters;
        this.alpha = toCopy.alpha;
        this.useBias = toCopy.useBias;
        this.maxLineSearchSteps = toCopy.maxLineSearchSteps;
    }

    /**
     * Sets the weight applied to the logistic loss term. Larger values weaken the regularization
     * relative to the loss. Marked as a warm-start parameter.
     *
     * @param C positive, finite loss weight
     * @throws IllegalArgumentException if {@code C} is not positive, is infinite, or is NaN
     */
    @Parameter.WarmParameter(prefLowToHigh=true)
    public void setC(double C) {
        if (C <= 0.0 || Double.isInfinite(C) || Double.isNaN(C)) {
            throw new IllegalArgumentException("Regularization term C must be a positive value, not " + C);
        }
        this.C = C;
    }

    /** @return the weight applied to the logistic loss term */
    public double getC() {
        return this.C;
    }

    /**
     * Sets the elastic-net mixing weight: {@code alpha} weights the L1 penalty and
     * {@code 1 - alpha} weights the L2 penalty.
     *
     * @param alpha mixing weight in [0, 1]
     * @throws IllegalArgumentException if {@code alpha} is outside [0, 1] or NaN
     */
    public void setAlpha(double alpha) {
        if (alpha < 0.0 || alpha > 1.0 || Double.isNaN(alpha)) {
            throw new IllegalArgumentException("alpha must be in [0, 1], not " + alpha);
        }
        this.alpha = alpha;
    }

    /** @return the elastic-net mixing weight */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the maximum number of outer iterations.
     *
     * @param maxOuterIters positive iteration limit
     * @throws IllegalArgumentException if {@code maxOuterIters < 1}
     */
    public void setMaxIters(int maxOuterIters) {
        if (maxOuterIters < 1) {
            throw new IllegalArgumentException("Number of training iterations must be positive, not " + maxOuterIters);
        }
        this.maxOuterIters = maxOuterIters;
    }

    /** @return the maximum number of outer iterations */
    public int getMaxIters() {
        return this.maxOuterIters;
    }

    /**
     * Sets the outer-loop tolerance. Training stops once the optimality measure falls to
     * {@code e_out} times its value at the reference point.
     *
     * @param e_out positive tolerance
     * @throws IllegalArgumentException if {@code e_out} is not positive or is NaN
     */
    public void setTolerance(double e_out) {
        if (e_out <= 0.0 || Double.isNaN(e_out)) {
            throw new IllegalArgumentException("convergence tolerance paramter must be positive, not " + e_out);
        }
        this.e_out = e_out;
    }

    /** @return the outer-loop convergence tolerance */
    public double getTolerance() {
        return this.e_out;
    }

    /**
     * Sets whether a bias term is learned.
     *
     * @param useBias {@code true} to learn a bias, {@code false} to keep it at 0
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /** @return whether a bias term is learned */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * Classifies a point with the logistic link applied to {@code w . x + b}.
     * The model must have been trained first; {@link #w} is {@code null} before then.
     *
     * @param data the point to classify
     * @return the class probabilities produced by {@link LogisticLoss#classify(double)}
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        return LogisticLoss.classify(this.w.dot(data.getNumericalValues()) + this.b);
    }

    /** Trains from scratch; the thread pool is ignored and training runs on the calling thread. */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.trainC(dataSet);
    }

    /** Warm-start training; the thread pool is ignored and training runs on the calling thread. */
    @Override
    public void trainC(ClassificationDataSet dataSet, Classifier warmSolution, ExecutorService threadPool) {
        this.trainC(dataSet, warmSolution);
    }

    /**
     * Trains starting from the first weight vector and bias of {@code warmSolution}.
     *
     * @param dataSet      training data
     * @param warmSolution model whose weights seed training
     * @throws FailedToFitException if {@code warmSolution} is not a {@link SimpleWeightVectorModel},
     *                              or its weight vector does not match the number of numeric features
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, Classifier warmSolution) {
        if (!(warmSolution instanceof SimpleWeightVectorModel)) {
            throw new FailedToFitException("Warm solution is not of a SimpleWeightVectorModel");
        }
        SimpleWeightVectorModel swv = (SimpleWeightVectorModel) warmSolution;
        this.train(dataSet, swv.getRawWeight(0), swv.getBias(0), true);
    }

    /** Trains from the all-zero solution. */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.train(dataSet, null, 0.0, false);
    }

    /**
     * Fits the model to {@code dataSet}.
     *
     * @param dataSet training data; category 0 becomes label -1, category 1 becomes label +1
     * @param w_init  initial weights, only read when {@code useInit} is true
     * @param b_init  initial bias, only used when {@code useInit} and {@link #useBias} are true
     * @param useInit whether to start from {@code w_init}/{@code b_init} rather than zero
     * @throws FailedToFitException if {@code w_init} does not hold one weight per numeric feature
     */
    private void train(ClassificationDataSet dataSet, Vec w_init, double b_init, boolean useInit) {
        final int n = dataSet.getNumNumericalVars();
        final int l = dataSet.getSampleSize();
        if (useInit) {
            if (w_init.length() != n) {
                throw new FailedToFitException("Warm solution has " + w_init.length()
                        + " weights, but the data set has " + n + " numeric features");
            }
            this.w = new DenseVector(w_init);
            this.b = this.useBias ? b_init : 0.0;
        } else {
            this.w = new DenseVector(n);
            this.b = 0.0;
        }
        List<Vec> X = dataSet.getDataVectors();
        double first_M_bar = 0.0;
        double e_in = 1.0;
        // Per-sample caches: w.x + b, exp(w.x + b), exp at the line-search trial point, and d.x (+ d_bias).
        double[] w_dot_x = new double[l];
        double[] exp_w_dot_x = new double[l];
        double[] exp_w_dot_x_plus_dx = new double[l];
        double[] d_dot_x = new double[l];
        // D_part[i] = 1 / (1 + exp(w.x_i + b)); D[i] = exp(w.x_i + b) * D_part[i]^2
        double[] D_part = new double[l];
        double[] D = new double[l];
        double[] H = new double[n];
        double H_bias = 0.0;
        double[] delta_L = new double[n];
        double delta_L_bias = 0.0;
        float[] y = new float[l];
        for (int i = 0; i < l; ++i) {
            y[i] = dataSet.getDataPointCategory(i) * 2 - 1;
            if (useInit) {
                w_dot_x[i] = this.w.dot(X.get(i)) + this.b;
                exp_w_dot_x[i] = Math.exp(w_dot_x[i]);
                exp_w_dot_x_plus_dx[i] = exp_w_dot_x[i];
                double tmp = exp_w_dot_x[i];
                D_part[i] = 1.0 / (1.0 + tmp);
                D[i] = tmp * D_part[i] * D_part[i];
            } else {
                // w = 0 and b = 0, so every cached quantity has a closed form
                w_dot_x[i] = 0.0;
                exp_w_dot_x[i] = 1.0;
                exp_w_dot_x_plus_dx[i] = 1.0;
                D_part[i] = 0.5;
                D[i] = 0.25;
            }
        }
        // Only differences of these running values are used below, so their starting base cancels out.
        double w_norm_1 = this.w.pNorm(1.0);
        double w_norm_2 = this.w.pNorm(2.0);

        List<Vec> columnsOfX = new ArrayList<Vec>(Arrays.asList(dataSet.getNumericColumns()));
        // Per-feature sum of values over the negative-label samples (a constant part of the gradient).
        double[] col_neg_class_sum = new double[n];
        for (int j = 0; j < n; ++j) {
            for (IndexValue iv : columnsOfX.get(j)) {
                if (y[iv.getIndex()] == -1.0f) {
                    col_neg_class_sum[j] += iv.getValue();
                }
            }
        }
        double col_neg_class_sum_bias = 0.0;
        if (this.useBias) {
            for (int i = 0; i < l; ++i) {
                if (y[i] == -1.0f) {
                    col_neg_class_sum_bias += 1.0;
                }
            }
        }

        final double l2w = 1.0 - this.alpha;
        double M_out = Double.POSITIVE_INFINITY;
        DenseVector d = new DenseVector(n);
        double d_bias = 0.0;
        boolean prevLineSearchFail = false;
        for (int k = 0; k < this.maxOuterIters; ++k) {
            // Working set J: start with every feature, drop those that look settled at zero.
            IntList J = new IntList(n);
            ListUtils.addRange(J, 0, n, 1);
            double M = 0.0;
            double M_bar = 0.0;
            Iterator<Integer> j_iter = J.iterator();
            while (j_iter.hasNext()) {
                int j = j_iter.next();
                double w_j = this.w.get(j);
                double delta_j_L = 0.0;
                double deltaSqrd_L = 0.0;
                for (IndexValue x_i : columnsOfX.get(j)) {
                    int i = x_i.getIndex();
                    double val = x_i.getValue();
                    delta_j_L += -val * D_part[i];
                    deltaSqrd_L += val * val * D[i];
                }
                delta_j_L = l2w * w_j + this.C * (delta_j_L + col_neg_class_sum[j]);
                delta_L[j] = delta_j_L;
                H[j] = this.C * deltaSqrd_L + Math.max(this.v, l2w);
                // Minimum-norm subgradient of the objective with respect to w_j.
                double deltaS_j_fw;
                if (w_j > 0.0) {
                    deltaS_j_fw = delta_j_L + this.alpha;
                } else if (w_j < 0.0) {
                    deltaS_j_fw = delta_j_L - this.alpha;
                } else {
                    deltaS_j_fw = Math.signum(delta_j_L) * Math.max(Math.abs(delta_j_L) - this.alpha, 0.0);
                }
                if (w_j == 0.0 && Math.abs(delta_j_L) < this.alpha - M_out / (double) l) {
                    j_iter.remove();
                    continue;
                }
                M = Math.max(M, Math.abs(deltaS_j_fw));
                M_bar += Math.abs(deltaS_j_fw);
            }
            if (this.useBias) {
                double delta_b_L = 0.0;
                double deltaSqrd_b_L = 0.0;
                for (int i = 0; i < l; ++i) {
                    delta_b_L += -D_part[i];
                    deltaSqrd_b_L += D[i];
                }
                delta_L_bias = this.C * (delta_b_L + col_neg_class_sum_bias);
                H_bias = this.C * deltaSqrd_b_L + this.v;
                // The bias is unregularized, so its subgradient is just the gradient.
                M = Math.max(M, Math.abs(delta_L_bias));
                M_bar += Math.abs(delta_L_bias);
            }
            if (k == 0) {
                // Warm starts measure progress relative to the all-zero solution, not the warm point.
                first_M_bar = useInit
                        ? this.getM_Bar_for_w0(n, l, columnsOfX, col_neg_class_sum, col_neg_class_sum_bias)
                        : M_bar;
                e_in = first_M_bar;
            }
            if (M_bar <= this.e_out * first_M_bar) {
                break;
            }
            M_out = M;

            // Inner loop: coordinate descent on the quadratic model to find the direction (d, d_bias).
            double M_in = Double.POSITIVE_INFINITY;
            IntList T = new IntList(J);
            d.zeroOut();
            d_bias = 0.0;
            int smallZInARow = 0;
            for (int p = 0; p < 1000; ++p) {
                double m = 0.0;
                double m_bar = 0.0;
                double max_abs_z = 0.0;
                Collections.shuffle(T);
                Iterator<Integer> T_iter = T.iterator();
                // Per-coordinate step clamp; grows as the active set T shrinks.
                double dynRange = (double) n * 5.0 / (double) T.size();
                while (T_iter.hasNext()) {
                    int j = T_iter.next();
                    double w_j = this.w.get(j);
                    double d_j = d.get(j);
                    double delta_qBar_j = 0.0;
                    for (IndexValue iv : columnsOfX.get(j)) {
                        int i = iv.getIndex();
                        delta_qBar_j += iv.getValue() * D[i] * d_dot_x[i];
                    }
                    delta_qBar_j *= this.C;
                    delta_qBar_j += delta_L[j];
                    double wPd_j = w_j + d_j;
                    double deltaS_q_k_j;
                    if (wPd_j > 0.0) {
                        deltaS_q_k_j = delta_qBar_j + this.alpha;
                    } else if (wPd_j < 0.0) {
                        deltaS_q_k_j = delta_qBar_j - this.alpha;
                    } else {
                        // NOTE(review): the l2w * d_j term is added to the model gradient only in this
                        // branch (kept exactly as before); confirm whether it was meant to apply always.
                        delta_qBar_j += l2w * d_j;
                        deltaS_q_k_j = Math.signum(delta_qBar_j) * Math.max(Math.abs(delta_qBar_j) - this.alpha, 0.0);
                    }
                    double deltaSqrd_q_jj = H[j];
                    if (wPd_j == 0.0 && Math.abs(delta_qBar_j) < this.alpha - M_in / (double) l) {
                        T_iter.remove();
                        continue;
                    }
                    m = Math.max(m, Math.abs(deltaS_q_k_j));
                    m_bar += Math.abs(deltaS_q_k_j);
                    // Closed-form minimizer of the one-dimensional L1-regularized quadratic.
                    double z;
                    if (delta_qBar_j + this.alpha <= deltaSqrd_q_jj * wPd_j) {
                        z = -(delta_qBar_j + this.alpha) / deltaSqrd_q_jj;
                    } else if (delta_qBar_j - this.alpha >= deltaSqrd_q_jj * wPd_j) {
                        z = -(delta_qBar_j - this.alpha) / deltaSqrd_q_jj;
                    } else {
                        z = -wPd_j;
                    }
                    if (Math.abs(z) < 1.0E-11) {
                        continue;
                    }
                    z = Math.min(Math.max(z, -dynRange), dynRange);
                    max_abs_z = Math.max(max_abs_z, Math.abs(z));
                    d.increment(j, z);
                    for (IndexValue iv : columnsOfX.get(j)) {
                        d_dot_x[iv.getIndex()] += z * iv.getValue();
                    }
                }
                if (this.useBias) {
                    double delta_qBar_b = 0.0;
                    for (int i = 0; i < l; ++i) {
                        delta_qBar_b += D[i] * d_dot_x[i];
                    }
                    delta_qBar_b *= this.C;
                    delta_qBar_b += delta_L_bias;
                    m = Math.max(m, Math.abs(delta_qBar_b));
                    m_bar += Math.abs(delta_qBar_b);
                    double z = -delta_qBar_b / H_bias;
                    if (Math.abs(z) > 1.0E-11) {
                        z = Math.min(Math.max(z, -dynRange), dynRange);
                        max_abs_z = Math.max(max_abs_z, Math.abs(z));
                        d_bias += z;
                        for (int i = 0; i < l; ++i) {
                            d_dot_x[i] += z;
                        }
                    }
                }
                // Heuristic early exit when the coordinate updates stay tiny for several passes.
                boolean breakInnerLoopAnyway = false;
                if (max_abs_z == 0.0) {
                    breakInnerLoopAnyway = true;
                } else if (max_abs_z <= 1.0E-6) {
                    if (smallZInARow++ >= 3) {
                        breakInnerLoopAnyway = true;
                    }
                } else if (max_abs_z <= 0.001) {
                    if (smallZInARow++ >= 30) {
                        breakInnerLoopAnyway = true;
                    }
                } else {
                    smallZInARow = 0;
                }
                if (m_bar <= e_in || breakInnerLoopAnyway) {
                    if (T.size() == J.size()) {
                        // Converged on the full working set; if that took a single pass, tighten e_in.
                        if (p == 0) {
                            e_in /= 4.0;
                        }
                        break;
                    }
                    // Converged on the shrunken set: re-check every coordinate of J before stopping.
                    T.clear();
                    T.addAll(J);
                    M_in = Double.POSITIVE_INFINITY;
                    continue;
                }
                M_in = m;
            }

            // Line search along (d, d_bias): predicted change and norms at the full step.
            double wPd_norm_1 = w_norm_1;
            double wPd_norm_2 = w_norm_2;
            double delta_L_dot_d = 0.0;
            for (IndexValue iv : d) {
                int j = iv.getIndex();
                double w_j = this.w.get(j);
                double d_j = iv.getValue();
                wPd_norm_1 -= Math.abs(w_j);
                wPd_norm_1 += Math.abs(w_j + d_j);
                // NOTE(review): this tracks the squared L2 norm, and delta_L already contains
                // l2w * w_j; confirm the L2 term is meant to be counted in both places below.
                wPd_norm_2 -= w_j * w_j;
                wPd_norm_2 += (w_j + d_j) * (w_j + d_j);
                delta_L_dot_d += d_j * delta_L[j];
            }
            delta_L_dot_d += d_bias * delta_L_bias;
            double breakCondition = this.sigma * (delta_L_dot_d + this.alpha * (wPd_norm_1 - w_norm_1) + l2w * (wPd_norm_2 - w_norm_2));
            double lambda = 1.0;
            int t = 0;
            double wPlambda_d_norm_1 = wPd_norm_1;
            double wPlambda_d_norm_2 = wPd_norm_2;
            while (t < this.maxLineSearchSteps) {
                // Change in the summed logistic loss when moving from w to w + lambda * d.
                double newTerm = 0.0;
                for (int i = 0; i < l; ++i) {
                    double exp_lamda_d_dot_x = Math.exp(lambda * d_dot_x[i]);
                    exp_w_dot_x_plus_dx[i] = exp_w_dot_x[i] * exp_lamda_d_dot_x;
                    newTerm += Math.log((exp_w_dot_x_plus_dx[i] + 1.0) / (exp_w_dot_x_plus_dx[i] + exp_lamda_d_dot_x));
                    if (y[i] == -1.0f) {
                        newTerm += lambda * d_dot_x[i];
                    }
                }
                double objectiveChange = l2w * (wPlambda_d_norm_2 - w_norm_2)
                        + this.alpha * (wPlambda_d_norm_1 - w_norm_1) + this.C * newTerm;
                if (objectiveChange <= lambda * breakCondition) {
                    break;
                }
                lambda = Math.pow(this.beta, ++t);
                wPlambda_d_norm_1 = w_norm_1;
                wPlambda_d_norm_2 = w_norm_2;
                for (IndexValue iv : d) {
                    double w_j = this.w.get(iv.getIndex());
                    double lambda_d_j = lambda * iv.getValue();
                    wPlambda_d_norm_1 -= Math.abs(w_j);
                    wPlambda_d_norm_1 += Math.abs(w_j + lambda_d_j);
                    wPlambda_d_norm_2 -= w_j * w_j;
                    wPlambda_d_norm_2 += (w_j + lambda_d_j) * (w_j + lambda_d_j);
                }
            }
            if (t == this.maxLineSearchSteps) {
                // Two exhausted line searches in a row: stop training.
                if (prevLineSearchFail) {
                    break;
                }
                prevLineSearchFail = true;
                // The search exits right after shrinking lambda, so the trial-point cache still
                // holds the previous step size; refresh it so it matches the step applied below.
                for (int i = 0; i < l; ++i) {
                    exp_w_dot_x_plus_dx[i] = exp_w_dot_x[i] * Math.exp(lambda * d_dot_x[i]);
                }
            } else {
                prevLineSearchFail = false;
            }
            this.w.mutableAdd(lambda, d);
            this.b += lambda * d_bias;
            w_norm_1 = wPlambda_d_norm_1;
            w_norm_2 = wPlambda_d_norm_2;
            System.arraycopy(exp_w_dot_x_plus_dx, 0, exp_w_dot_x, 0, l);
            for (int i = 0; i < l; ++i) {
                w_dot_x[i] += lambda * d_dot_x[i];
                D_part[i] = 1.0 / (1.0 + exp_w_dot_x[i]);
                D[i] = exp_w_dot_x[i] * D_part[i] * D_part[i];
            }
            Arrays.fill(d_dot_x, 0.0);
        }
    }

    /**
     * Computes the outer-loop optimality measure at the all-zero solution (w = 0, b = 0). Warm-started
     * training uses it as the reference for its stopping test. At w = 0 every {@code D_part} entry is
     * 0.5, so each gradient entry reduces to a column sum.
     *
     * @param n                      number of numeric features
     * @param l                      number of samples
     * @param columnsOfX             feature columns of the data
     * @param col_neg_class_sum      per-feature sums over negative-label samples
     * @param col_neg_class_sum_bias number of negative-label samples (0 when no bias is used)
     * @return the summed absolute minimum-norm subgradient at zero
     */
    private double getM_Bar_for_w0(int n, int l, List<Vec> columnsOfX, double[] col_neg_class_sum, double col_neg_class_sum_bias) {
        double M_bar = 0.0;
        for (int j = 0; j < n; ++j) {
            double delta_j_L = this.C * (-columnsOfX.get(j).sum() * 0.5 + col_neg_class_sum[j]);
            double deltaS_j_fw = Math.signum(delta_j_L) * Math.max(Math.abs(delta_j_L) - this.alpha, 0.0);
            M_bar += Math.abs(deltaS_j_fw);
        }
        if (this.useBias) {
            // Each of the l samples contributes -0.5; the product is exact for any realistic l.
            double delta_b_L = this.C * (-0.5 * l + col_neg_class_sum_bias);
            M_bar += Math.abs(delta_b_L);
        }
        return M_bar;
    }

    /** Training never reads data point weights, so weighted data is not supported. */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public NewGLMNET clone() {
        return new NewGLMNET(this);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /** @return the learned weight vector (not a copy); {@code null} before training */
    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    /** @return the learned bias term */
    @Override
    public double getBias() {
        return this.b;
    }

    /**
     * @param index weight-vector index; any value below 1 refers to the single vector
     * @return the learned weight vector
     * @throws IndexOutOfBoundsException if {@code index >= 1}
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @param index weight-vector index; any value below 1 refers to the single bias
     * @return the learned bias term
     * @throws IndexOutOfBoundsException if {@code index >= 1}
     */
    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /** Warm solutions may come from a different data set. */
    @Override
    public boolean warmFromSameDataOnly() {
        return false;
    }

    /**
     * Suggests a search distribution for {@code alpha}.
     *
     * @param d the data set (unused)
     * @return a uniform distribution on [0.25, 0.75]
     */
    public static Distribution guessAlpha(DataSet d) {
        return new Uniform(0.25, 0.75);
    }

    /**
     * Suggests a search distribution for {@code C}. With
     * {@code minC = 1 / (2 * maxLambda * N)}, it returns a log-uniform distribution on
     * [10 * minC, 1000 * minC], where {@code maxLambda} comes from
     * {@link LinearTools#maxLambdaLogisticL1}.
     *
     * @param d the data set; must be a {@link ClassificationDataSet}
     * @return a log-uniform distribution over candidate values of C
     */
    public static Distribution guessC(DataSet d) {
        double maxLambda = LinearTools.maxLambdaLogisticL1((ClassificationDataSet)d);
        double minC = 1.0 / (2.0 * maxLambda * (double)d.getSampleSize());
        return new LogUniform(minC * 10.0, minC * 1000.0);
    }
}

