/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.linear.kernelized.CSKLR;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.random.XORWOW;

/**
 * Multi-epoch (batch) trainer for a binary, kernelized logistic-regression style
 * classifier whose per-example update rule is supplied by a {@link CSKLR.UpdateMode}.
 * <p>
 * Training behaviour:
 * <ul>
 * <li>Each epoch visits the training points in a random order.</li>
 * <li>For every update mode other than {@code NC}, a point is updated only with the
 * probability returned by the mode's {@code pt} method. Otherwise it is skipped.</li>
 * <li>An updated point's coefficient moves by {@code -eta * y * grad * weight}.</li>
 * <li>A running norm estimate is maintained. Whenever it exceeds {@link #getR()}, all
 * coefficients are rescaled so that the estimate equals {@code R}.</li>
 * <li>After training, only points with non-zero coefficients are kept as support
 * vectors.</li>
 * </ul>
 */
public class CSKLRBatch
extends SupportVectorLearner
implements Parameterized,
Classifier {
    private static final long serialVersionUID = -2305532659182911285L;
    /** Learning rate used to scale each coefficient update. */
    private double eta;
    /** Running estimate of the model norm, compared against {@link #R} during training. */
    private double curNorm;
    /** Maximum allowed value of the norm estimate before coefficients are rescaled. */
    private double R = 10.0;
    /** Reset to 0 at the start of training; not otherwise read in this class. */
    private int T = 0;
    /** Update rule (gradient and update probability) applied to each visited point. */
    private CSKLR.UpdateMode mode;
    /** Extra parameter forwarded to the update mode's {@code pt} and {@code grad}. */
    protected double gamma = 2.0;
    /** Number of full passes over the training data. */
    private int epochs = 10;

    /**
     * Creates a new batch learner.
     *
     * @param eta the learning rate; must be non-negative and finite
     * @param kernel the kernel used to compare data points
     * @param R the maximum norm estimate; must be non-negative and finite
     * @param mode the update rule applied to each visited point
     * @param cacheMode the kernel cache mode passed on to {@link SupportVectorLearner}
     */
    public CSKLRBatch(double eta, KernelTrick kernel, double R, CSKLR.UpdateMode mode, SupportVectorLearner.CacheMode cacheMode) {
        super(kernel, cacheMode);
        this.setEta(eta);
        this.setR(R);
        this.setMode(mode);
    }

    /**
     * Copy constructor.
     *
     * @param toClone the learner whose state is copied
     */
    protected CSKLRBatch(CSKLRBatch toClone) {
        super(toClone);
        this.curNorm = toClone.curNorm;
        this.epochs = toClone.epochs;
        this.eta = toClone.eta;
        this.R = toClone.R;
        this.T = toClone.T;
        this.mode = toClone.mode;
        this.gamma = toClone.gamma;
    }

    @Override
    public CSKLRBatch clone() {
        return new CSKLRBatch(this);
    }

    /**
     * Sets the number of passes made over the training data.
     * A value of zero or less performs no updates.
     *
     * @param epochs the number of epochs
     */
    public void setEpochs(int epochs) {
        this.epochs = epochs;
    }

    /**
     * @return the number of passes made over the training data
     */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * Sets the learning rate.
     *
     * @param eta the learning rate
     * @throws IllegalArgumentException if {@code eta} is negative, NaN, or infinite
     */
    public void setEta(double eta) {
        // NOTE(review): zero is accepted although the message states an open interval.
        if (eta < 0.0 || Double.isNaN(eta) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("The learning rate should be in (0, Inf), not " + eta);
        }
        this.eta = eta;
    }

    /**
     * @return the learning rate
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * Sets the maximum value the running norm estimate may reach before all
     * coefficients are rescaled.
     *
     * @param R the maximum norm
     * @throws IllegalArgumentException if {@code R} is negative, NaN, or infinite
     */
    public void setR(double R) {
        // NOTE(review): zero is accepted although the message states an open interval.
        if (R < 0.0 || Double.isNaN(R) || Double.isInfinite(R)) {
            throw new IllegalArgumentException("The max norm should be in (0, Inf), not " + R);
        }
        this.R = R;
    }

    /**
     * @return the maximum norm
     */
    public double getR() {
        return this.R;
    }

    /**
     * Sets the update rule applied during training.
     *
     * @param mode the update mode
     */
    public void setMode(CSKLR.UpdateMode mode) {
        this.mode = mode;
    }

    /**
     * @return the update rule applied during training
     */
    public CSKLR.UpdateMode getMode() {
        return this.mode;
    }

    /**
     * Sets the extra parameter forwarded to the update mode's {@code pt} and
     * {@code grad} methods.
     *
     * @param gamma the gamma value
     * @throws IllegalArgumentException if {@code gamma} is negative, NaN, or infinite
     */
    public void setGamma(double gamma) {
        if (gamma < 0.0 || Double.isNaN(gamma) || Double.isInfinite(gamma)) {
            throw new IllegalArgumentException("Gamma must be in (0, Infity), not " + gamma);
        }
        this.gamma = gamma;
    }

    /**
     * @return the gamma value forwarded to the update mode
     */
    public double getGamma() {
        return this.gamma;
    }

    /**
     * Returns a search distribution for {@code R}. The data set is currently ignored;
     * the result is always {@code LogUniform(1, 100000)}.
     *
     * @param d the data set (unused)
     * @return the distribution to sample {@code R} from
     */
    public static Distribution guessR(DataSet d) {
        return new LogUniform(1.0, 100000.0);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Classifies a point into one of two classes.
     * Class 0 corresponds to training label -1 and class 1 to label +1.
     *
     * @param data the point to classify
     * @return a two-class result whose probabilities sum to one
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        double p_0 = CSKLR.getScore(-1.0, this.getPreScore(data.getNumericalValues()));
        cr.setProb(0, p_0);
        cr.setProb(1, 1.0 - p_0);
        return cr;
    }

    /**
     * Trains the model. The thread pool is ignored and training runs on the calling
     * thread.
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.trainC(dataSet);
    }

    /**
     * Trains the model on a binary data set for {@link #getEpochs()} passes. Data point
     * weights scale each update.
     *
     * @param dataSet the training data; must have exactly two classes
     * @throws FailedToFitException if the data set does not have exactly two classes
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("CSKLR supports only binary classification");
        }
        int N = dataSet.getSampleSize();
        this.vecs = new ArrayList<Vec>(N);
        this.alphas = new double[N];
        for (int i = 0; i < N; ++i) {
            this.vecs.add(dataSet.getDataPoint(i).getNumericalValues());
        }
        this.curNorm = 0.0;
        this.T = 0;
        // XORWOW is a java.util.Random subclass. One generator drives both the visiting
        // order and the stochastic skip test.
        Random rand = new XORWOW();
        IntList sampleOrder = new IntList(N);
        ListUtils.addRange(sampleOrder, 0, N, 1);
        // NOTE(review): re-applying the current cache mode presumably (re)builds the
        // kernel cache for the freshly assigned vecs; confirm in SupportVectorLearner.
        this.setCacheMode(this.getCacheMode());
        for (int epoch = 0; epoch < this.epochs; ++epoch) {
            Collections.shuffle(sampleOrder, rand);
            for (int i : sampleOrder) {
                double weight = dataSet.getDataPoint(i).getWeight();
                // Map class index {0, 1} to label {-1, +1}.
                double y_t = dataSet.getDataPointCategory(i) * 2 - 1;
                Vec x_t = (Vec) this.vecs.get(i);
                double pre = this.getPreScore(x_t);
                double score = CSKLR.getScore(y_t, pre);
                // Every mode except NC updates only with probability pt.
                if (this.mode != CSKLR.UpdateMode.NC) {
                    double pt = this.mode.pt(y_t, score, pre, this.eta, this.gamma);
                    if (rand.nextDouble() > pt) {
                        continue;
                    }
                }
                double alpha_i = -this.eta * y_t * this.mode.grad(y_t, score, pre, this.gamma) * weight;
                this.alphas[i] += alpha_i;
                this.curNorm += Math.abs(alpha_i) * this.kEval(i, i);
                if (this.curNorm > this.R) {
                    this.rescaleToMaxNorm();
                }
            }
        }
        this.keepOnlySupportVectors();
        // NOTE(review): presumably releases the training-time kernel cache; confirm in
        // SupportVectorLearner.
        this.setCacheMode(null);
        this.setAlphas(this.alphas);
    }

    /**
     * Scales every coefficient by {@code R / curNorm} so the norm estimate equals R.
     */
    private void rescaleToMaxNorm() {
        double coef = this.R / this.curNorm;
        for (int j = 0; j < this.alphas.length; ++j) {
            this.alphas[j] *= coef;
        }
        // The estimate scales by the same factor: curNorm * (R / curNorm) == R.
        // The previous code stored coef here, collapsing the estimate below 1.
        this.curNorm = this.R;
    }

    /**
     * Moves points with non-zero coefficients to the front, preserving their relative
     * order, then truncates vecs and alphas to just those points. Zero and NaN
     * coefficients are dropped.
     */
    private void keepOnlySupportVectors() {
        int supportVectorCount = 0;
        for (int i = 0; i < this.alphas.length; ++i) {
            if (this.alphas[i] > 0.0 || this.alphas[i] < 0.0) {
                ListUtils.swap(this.vecs, supportVectorCount, i);
                this.alphas[supportVectorCount++] = this.alphas[i];
            }
        }
        this.vecs = new ArrayList<Vec>(this.vecs.subList(0, supportVectorCount));
        this.alphas = Arrays.copyOfRange(this.alphas, 0, supportVectorCount);
    }

    /**
     * Computes the raw (pre-link) score of a vector as the kernel-weighted sum over the
     * current support vectors.
     */
    private double getPreScore(Vec x) {
        return this.kEvalSum(x);
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }
}

