/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;

/**
 * Kernelized Pegasos: a stochastic sub-gradient learner for binary kernel SVMs.
 * <p>
 * Training runs a fixed number of iterations. Each iteration samples one training
 * point uniformly at random and computes its current margin,
 * {@code y_i / (lambda * t) * sum_{j != i} alpha_j * y_j * K(x_i, x_j)}. When that margin
 * is below 1, the point's coefficient is incremented. After training, only points with
 * a non-zero coefficient are kept, and each coefficient is multiplied by the point's
 * +1/-1 label.
 * <p>
 * Only two-class problems are supported. Category 1 is treated as the positive class.
 * Data point weights are not supported.
 */
public class PegasosK
extends SupportVectorLearner
implements BinaryScoreClassifier,
Parameterized {
    private static final long serialVersionUID = 5405460830472328107L;
    /** Regularization constant lambda; validated to be positive and finite. */
    private double regularization;
    /** Number of stochastic update steps performed by {@link #trainC}. */
    private int iterations;

    /**
     * Creates a new learner that uses no kernel cache.
     *
     * @param regularization the positive, finite regularization constant
     * @param iterations the number of training iterations
     * @param kernel the kernel to use
     * @throws ArithmeticException if {@code regularization} is not positive and finite
     */
    public PegasosK(double regularization, int iterations, KernelTrick kernel) {
        this(regularization, iterations, kernel, SupportVectorLearner.CacheMode.NONE);
    }

    /**
     * Creates a new learner.
     *
     * @param regularization the positive, finite regularization constant
     * @param iterations the number of training iterations
     * @param kernel the kernel to use
     * @param cacheMode the kernel cache mode to use during training
     * @throws ArithmeticException if {@code regularization} is not positive and finite
     */
    public PegasosK(double regularization, int iterations, KernelTrick kernel, SupportVectorLearner.CacheMode cacheMode) {
        super(kernel, cacheMode);
        this.setRegularization(regularization);
        this.setIterations(iterations);
    }

    /**
     * Sets the number of training iterations. A value less than 1 means training
     * performs no updates, so the trained model keeps no support vectors.
     *
     * @param iterations the number of training iterations
     */
    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    /**
     * Returns the number of training iterations.
     *
     * @return the number of training iterations
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the regularization constant lambda.
     *
     * @param regularization the positive, finite regularization constant
     * @throws ArithmeticException if {@code regularization} is NaN, infinite, or not positive
     */
    public void setRegularization(double regularization) {
        if (Double.isNaN(regularization) || Double.isInfinite(regularization) || regularization <= 0.0) {
            throw new ArithmeticException("Regularization must be a positive constant, not " + regularization);
        }
        this.regularization = regularization;
    }

    /**
     * Returns the regularization constant lambda.
     *
     * @return the regularization constant
     */
    public double getRegularization() {
        return this.regularization;
    }

    /**
     * Returns a copy of this learner. If the learner has been trained, the support
     * vectors are cloned one by one and the coefficients are copied.
     */
    @Override
    public PegasosK clone() {
        PegasosK clone = new PegasosK(this.regularization, this.iterations, this.getKernel().clone(), this.getCacheMode());
        if (this.vecs != null) {
            clone.vecs = new ArrayList<Vec>(this.vecs);
            clone.alphas = new double[this.alphas.length];
            for (int i = 0; i < this.vecs.size(); ++i) {
                clone.vecs.set(i, ((Vec) this.vecs.get(i)).clone());
                clone.alphas[i] = this.alphas[i];
            }
        }
        return clone;
    }

    /**
     * Classifies a point: category 1 when {@link #getScore} is strictly positive,
     * otherwise category 0. The result is a hard 0/1 assignment.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.alphas == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        CategoricalResults cr = new CategoricalResults(2);
        double sum = this.getScore(data);
        if (sum > 0.0) {
            cr.setProb(1, 1.0);
        } else {
            cr.setProb(0, 1.0);
        }
        return cr;
    }

    /**
     * Returns the raw decision value for a point, computed by the parent's
     * {@code kEvalSum} over the stored support vectors and signed coefficients.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public double getScore(DataPoint dp) {
        // Same untrained check as classify, so direct score callers get a clear error.
        if (this.alphas == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        return this.kEvalSum(dp.getNumericalValues());
    }

    /**
     * Trains the model. Each iteration splits the margin computation into blocks that
     * are evaluated on {@code threadPool}.
     *
     * @param dataSet a two-class data set; category 1 is treated as the positive label
     * @param threadPool the executor used to evaluate blocks of the margin sum
     * @throws FailedToFitException if the data set is not binary, is empty, or a worker
     *         task fails or is interrupted
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("Pegasos only supports binary classification problems");
        }
        final int m = dataSet.getSampleSize();
        if (m == 0) {
            // Otherwise rand.nextInt(0) below would throw an IllegalArgumentException.
            throw new FailedToFitException("Pegasos requires at least one training point");
        }
        try {
            Random rand = new Random();
            this.alphas = new double[m];
            final int[] sign = new int[m];
            this.vecs = new ArrayList<Vec>(m);
            for (int i = 0; i < m; ++i) {
                this.vecs.add(dataSet.getDataPoint(i).getNumericalValues());
                sign[i] = dataSet.getDataPointCategory(i) == 1 ? 1 : -1;
            }

            final List<Future<Double>> futures = new ArrayList<Future<Double>>(SystemInfo.LogicalCores);
            // Ceiling of m / LogicalCores, so at most LogicalCores blocks cover [0, m).
            final int blockSize = m / SystemInfo.LogicalCores + (m % SystemInfo.LogicalCores == 0 ? 0 : 1);
            // Re-apply the current cache mode now that vecs is populated.
            // NOTE(review): presumably this makes the parent (re)build its kernel cache; confirm in SupportVectorLearner.
            this.setCacheMode(this.getCacheMode());

            for (int t = 1; t <= this.iterations; ++t) {
                final int i = rand.nextInt(m);
                futures.clear();
                for (int start = 0; start < m; start += blockSize) {
                    futures.add(threadPool.submit(new PredictPart(i, start, Math.min(start + blockSize, m), sign)));
                }
                double sum = 0.0;
                for (Future<Double> future : futures) {
                    sum += future.get();
                }
                // Margin of point i: y_i / (lambda * t) * sum_j alpha_j * y_j * K(x_i, x_j).
                double margin = sum * (sign[i] / (this.regularization * t));
                if (margin < 1.0) {
                    this.alphas[i] += 1.0;
                }
            }

            // Move the points with non-zero coefficients to the front, folding each label into its coefficient.
            int pos = 0;
            for (int i = 0; i < m; ++i) {
                if (this.alphas[i] == 0.0) {
                    continue;
                }
                this.alphas[pos] = this.alphas[i] * sign[i];
                ListUtils.swap(this.vecs, pos, i);
                ++pos;
            }
            this.alphas = Arrays.copyOf(this.alphas, pos);
            this.vecs = new ArrayList<Vec>(this.vecs.subList(0, pos));
            // NOTE(review): presumably this drops the training-time cache; confirm in SupportVectorLearner.
            this.setCacheMode(null);
            this.setAlphas(this.alphas);
        }
        catch (ExecutionException ex) {
            throw new FailedToFitException(ex);
        }
        catch (InterruptedException ex) {
            // Restore the interrupt flag so callers further up can still see it.
            Thread.currentThread().interrupt();
            throw new FailedToFitException(ex);
        }
    }

    /**
     * Trains the model single-threaded, running each block on a {@link FakeExecutor}.
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    /**
     * Returns {@code false}; data point weights are not supported.
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Computes a partial margin sum for the sampled point {@code i}: the sum over
     * {@code j} in {@code [start, end)} of {@code alphas[j] * sign[j] * K(x_i, x_j)}.
     * Terms with {@code j == i} or a zero coefficient are skipped.
     * NOTE(review): the published kernel Pegasos sums over all j, including j == i; this
     * implementation excludes the point's own term.
     */
    private class PredictPart
    implements Callable<Double> {
        private final int i;
        private final int start;
        private final int end;
        private final int[] sign;

        /**
         * @param i index of the sampled training point
         * @param start first index of the block (inclusive)
         * @param end last index of the block (exclusive)
         * @param sign +1/-1 label of every training point
         */
        public PredictPart(int i, int start, int end, int[] sign) {
            this.i = i;
            this.start = start;
            this.end = end;
            this.sign = sign;
        }

        @Override
        public Double call() {
            double val = 0.0;
            for (int j = this.start; j < this.end; ++j) {
                if (j == this.i || PegasosK.this.alphas[j] == 0.0) {
                    continue;
                }
                // Each term is weighted by the label of point j, not of the sampled point i.
                val += PegasosK.this.alphas[j] * this.sign[j] * PegasosK.this.kEval(this.i, j);
            }
            return val;
        }
    }
}

