/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.functions;

import moa.classifiers.AbstractClassifier;
import moa.classifiers.Regressor;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.StringUtils;
import moa.options.FloatOption;
import moa.options.MultiChoiceOption;
import weka.core.Instance;
import weka.core.Utils;

/**
 * Linear models learned by stochastic gradient descent (SGD).
 *
 * <p>When the class attribute is nominal, one binary linear model (weight vector plus bias) is
 * kept per class label and trained one-vs-all: the target is +1 for the instance's own class and
 * -1 for every other class. When the class attribute is numeric, a single model is trained. The
 * loss is selected by {@link #lossFunctionOption}: hinge loss ({@link #HINGE}), log loss
 * ({@link #LOGLOSS}) or squared loss ({@link #SQUAREDLOSS}). Every update applies an L2-style
 * multiplicative weight decay controlled by the lambda and learning-rate parameters.
 */
public class SGDMultiClass
extends AbstractClassifier
implements Regressor {
    private static final long serialVersionUID = -3732968666673530290L;
    /** Regularization strength used in the per-step weight decay. */
    protected double m_lambda = 1.0E-4;
    public FloatOption lambdaRegularizationOption = new FloatOption("lambdaRegularization", 'l', "Lambda regularization parameter .", 1.0E-4, 0.0, 2.147483647E9);
    /** Step size applied to each gradient update. */
    protected double m_learningRate = 0.01;
    public FloatOption learningRateOption = new FloatOption("learningRate", 'r', "Learning rate parameter.", 1.0E-4, 0.0, 2.147483647E9);
    /** One weight vector per model; position {@code i} holds the weight of attribute {@code i}. */
    protected DoubleVector[] m_weights;
    /** One bias term per model, parallel to {@link #m_weights}. */
    protected double[] m_bias;
    /** Step counter used in the weight-decay schedule; starts at 1 after {@link #reset()}. */
    protected double m_t;
    /** When non-zero, used instead of {@link #m_t} in the weight-decay denominator. */
    protected double m_numInstances;
    protected static final int HINGE = 0;
    protected static final int LOGLOSS = 1;
    protected static final int SQUAREDLOSS = 2;
    /** Selected loss: {@link #HINGE}, {@link #LOGLOSS} or (any other value) squared loss. */
    protected int m_loss = 0;
    public MultiChoiceOption lossFunctionOption = new MultiChoiceOption("lossFunction", 'o', "The loss function to use.", new String[]{"HINGE", "LOGLOSS", "SQUAREDLOSS"}, new String[]{"Hinge loss (SVM)", "Log loss (logistic regression)", "Squared loss (regression)"}, 0);

    @Override
    public String getPurposeString() {
        return "A stochastic gradient descent learner for various linear models, trained one-vs-all for multi-class targets (SVM, logistic regression and linear regression).";
    }

    /** Sets the regularization strength. */
    public void setLambda(double lambda) {
        this.m_lambda = lambda;
    }

    /** Returns the regularization strength. */
    public double getLambda() {
        return this.m_lambda;
    }

    /**
     * Sets the loss function.
     *
     * @param function {@link #HINGE}, {@link #LOGLOSS} or {@link #SQUAREDLOSS}
     */
    public void setLossFunction(int function) {
        this.m_loss = function;
    }

    /** Returns the loss function code ({@link #HINGE}, {@link #LOGLOSS} or {@link #SQUAREDLOSS}). */
    public int getLossFunction() {
        return this.m_loss;
    }

    /** Sets the SGD step size. */
    public void setLearningRate(double lr) {
        this.m_learningRate = lr;
    }

    /** Returns the SGD step size. */
    public double getLearningRate() {
        return this.m_learningRate;
    }

    /** Discards the learned models and restarts the step counter. */
    public void reset() {
        this.m_t = 1.0;
        this.m_weights = null;
        this.m_bias = null;
    }

    /**
     * Returns the derivative term used in the SGD update.
     *
     * <p>For hinge and log loss, {@code z} is the margin {@code y * (w.x + b)} and the result is
     * {@code -dL/dz}. For squared loss, {@code z} is the residual {@code target - prediction} and
     * the result is the residual itself.
     *
     * @param z margin (hinge/log loss) or residual (squared loss)
     * @return the factor by which the learning-rate-scaled input is added to the weights
     */
    protected double dloss(double z) {
        if (this.m_loss == HINGE) {
            return z < 1.0 ? 1.0 : 0.0;
        }
        if (this.m_loss == LOGLOSS) {
            // Two algebraically equal forms of 1 / (1 + e^z), chosen so exp() never overflows.
            if (z < 0.0) {
                return 1.0 / (Math.exp(z) + 1.0);
            }
            double t = Math.exp(-z);
            return t / (t + 1.0);
        }
        return z;
    }

    /**
     * Computes the dot product between an instance and a weight vector.
     *
     * <p>The weight for attribute {@code i} is stored at position {@code i} of {@code weights}
     * (the same layout the training update writes). The class attribute and missing values are
     * skipped. Attributes beyond the end of the weight vector contribute nothing.
     *
     * @param inst1 the (possibly sparse) instance
     * @param weights the weight vector, indexed by attribute index
     * @param classIndex index of the class attribute, excluded from the product
     * @return the dot product
     */
    protected static double dotProd(Instance inst1, DoubleVector weights, int classIndex) {
        double result = 0.0;
        int numWeights = weights.numValues();
        int numValues = inst1.numValues();
        for (int p = 0; p < numValues; ++p) {
            int attIndex = inst1.index(p);
            if (attIndex == classIndex || attIndex >= numWeights || inst1.isMissingSparse(p)) {
                continue;
            }
            result += inst1.valueSparse(p) * weights.getValue(attIndex);
        }
        return result;
    }

    @Override
    public void resetLearningImpl() {
        this.reset();
        this.setLambda(this.lambdaRegularizationOption.getValue());
        this.setLearningRate(this.learningRateOption.getValue());
        this.setLossFunction(this.lossFunctionOption.getChosenIndex());
    }

    /**
     * Performs one SGD step for every model on a labelled instance.
     *
     * <p>The models are created lazily from the first instance's class attribute. There is one
     * model per class value for a nominal class and a single model otherwise. Instances with a
     * missing class value are ignored and do not advance the step counter.
     *
     * @param instance the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.m_weights == null) {
            int length = instance.classAttribute().isNominal() ? instance.numClasses() : 1;
            this.m_weights = new DoubleVector[length];
            this.m_bias = new double[length];
            for (int i = 0; i < length; ++i) {
                this.m_weights[i] = new DoubleVector();
                this.m_bias[i] = 0.0;
            }
        }
        if (instance.classIsMissing()) {
            // No gradient step happens for any model, so the decay schedule must not advance.
            return;
        }
        for (int i = 0; i < this.m_weights.length; ++i) {
            this.trainOnInstanceImpl(instance, i);
        }
        this.m_t += 1.0;
    }

    /**
     * Performs one SGD step for the model of a single class label.
     *
     * <p>For a nominal class, the target is +1 if the instance belongs to {@code classLabel} and
     * -1 otherwise. For a numeric class, the residual between the class value and the prediction
     * drives the update. Does nothing if the class value is missing.
     *
     * @param instance the training instance
     * @param classLabel index of the model to update
     */
    public void trainOnInstanceImpl(Instance instance, int classLabel) {
        if (instance.classIsMissing()) {
            return;
        }
        DoubleVector weights = this.m_weights[classLabel];
        int classIndex = instance.classIndex();
        double prediction = dotProd(instance, weights, classIndex) + this.m_bias[classLabel];
        double y;
        double z;
        if (instance.classAttribute().isNominal()) {
            y = instance.classValue() != (double) classLabel ? -1.0 : 1.0;
            z = y * prediction;
        } else {
            // Regression: z is the residual and the sign is already carried by dloss(z).
            z = instance.classValue() - prediction;
            y = 1.0;
        }

        // Multiplicative L2 weight decay; the denominator grows so the decay shrinks over time.
        double denominator = this.m_numInstances == 0.0 ? this.m_t : this.m_numInstances;
        double multiplier = 1.0 - this.m_learningRate * this.m_lambda / denominator;
        for (int i = 0; i < weights.numValues(); ++i) {
            weights.setValue(i, weights.getValue(i) * multiplier);
        }

        // Hinge loss has zero gradient once the margin reaches 1, so skip the update then.
        if (this.m_loss != HINGE || z < 1.0) {
            double factor = this.m_learningRate * y * this.dloss(z);
            int n1 = instance.numValues();
            for (int p1 = 0; p1 < n1; ++p1) {
                int indS = instance.index(p1);
                if (indS == classIndex || instance.isMissingSparse(p1)) {
                    continue;
                }
                weights.addToValue(indS, factor * instance.valueSparse(p1));
            }
            this.m_bias[classLabel] += factor;
        }
    }

    /**
     * Scores an instance with every model.
     *
     * <p>For a numeric class, returns the single linear prediction. For a nominal class, returns
     * one score per model. With log loss the score is the logistic probability. Otherwise it is
     * 1 for a positive margin and 0 for a non-positive one.
     *
     * @param inst the instance to score
     * @return the votes; all zeros if no model has been trained yet
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        if (this.m_weights == null) {
            return new double[inst.numClasses()];
        }
        int classIndex = inst.classIndex();
        if (inst.classAttribute().isNumeric()) {
            return new double[]{dotProd(inst, this.m_weights[0], classIndex) + this.m_bias[0]};
        }
        double[] result = inst.classAttribute().isNominal() ? new double[inst.numClasses()] : new double[1];
        for (int i = 0; i < this.m_weights.length; ++i) {
            double z = dotProd(inst, this.m_weights[i], classIndex) + this.m_bias[i];
            if (this.m_loss == LOGLOSS) {
                // Two algebraically equal forms of the sigmoid, chosen so exp() never overflows.
                result[i] = z <= 0.0 ? 1.0 - 1.0 / (1.0 + Math.exp(z)) : 1.0 / (1.0 + Math.exp(-z));
            } else {
                result[i] = z <= 0.0 ? 0.0 : 1.0;
            }
        }
        return result;
    }

    @Override
    public void getModelDescription(StringBuilder result, int indent) {
        StringUtils.appendIndented(result, indent, this.toString());
        StringUtils.appendNewline(result);
    }

    /**
     * Renders the loss function and the weights and bias of every model.
     *
     * <p>When there is more than one model, each is preceded by a {@code Class i:} header.
     */
    @Override
    public String toString() {
        if (this.m_weights == null) {
            return "SGD: No model built yet.\n";
        }
        StringBuilder buff = new StringBuilder();
        buff.append("Loss function: ");
        if (this.m_loss == HINGE) {
            buff.append("Hinge loss (SVM)\n\n");
        } else if (this.m_loss == LOGLOSS) {
            buff.append("Log loss (logistic regression)\n\n");
        } else {
            buff.append("Squared loss (linear regression)\n\n");
        }
        boolean multipleModels = this.m_weights.length > 1;
        for (int c = 0; c < this.m_weights.length; ++c) {
            if (multipleModels) {
                if (c > 0) {
                    buff.append("\n\n");
                }
                buff.append("Class ").append(c).append(":\n");
            }
            this.appendModel(buff, c);
        }
        return buff.toString();
    }

    /** Appends the weights (one per line) and the signed bias of one model to {@code buff}. */
    private void appendModel(StringBuilder buff, int classLabel) {
        DoubleVector weights = this.m_weights[classLabel];
        for (int i = 0; i < weights.numValues(); ++i) {
            buff.append(i > 0 ? " + " : "   ");
            buff.append(Utils.doubleToString(weights.getValue(i), 12, 4)).append(" \n");
        }
        double bias = this.m_bias[classLabel];
        if (bias >= 0.0) {
            buff.append(" + ").append(Utils.doubleToString(bias, 12, 4));
        } else {
            buff.append(" - ").append(Utils.doubleToString(-bias, 12, 4));
        }
    }

    /** This learner reports no model-specific measurements. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[0];
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }
}

