/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.SubMatrix;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;

/**
 * Online binary kernel classifier implementing the Projectron algorithm and,
 * when margin updates are enabled, its Projectron++ variant.
 * <p>
 * On a mistake, the new example is either added to the support set {@code S}
 * or, if its feature-space image lies within distance {@code eta} of the span
 * of {@code S}, replaced by its projection onto that span. The inverse Gram
 * matrix of {@code S} is maintained incrementally, so each projection needs
 * only a matrix-vector product rather than a fresh matrix inversion.
 * <p>
 * Reference: F. Orabona, J. Keshet and B. Caputo, "Bounded Kernel-Based
 * Online Learning", JMLR 10, 2009.
 */
public class Projectron
extends BaseUpdateableClassifier
implements BinaryScoreClassifier,
Parameterized {
    private static final long serialVersionUID = -4025799790045954359L;
    /** Kernel used to compare examples. */
    @Parameter.ParameterHolder
    private KernelTrick k;
    /** Projection threshold: mistakes closer than this to span(S) are projected instead of stored. */
    private double eta;
    /** Coefficient of each support vector in the decision function. */
    private DoubleList alpha;
    /** Support vectors. */
    private List<Vec> S;
    /** Kernel acceleration-cache values for the vectors in {@code S}. */
    private List<Double> cacheAccel;
    /** |S| x |S| view onto {@link #InvKExpanded} holding the inverse Gram matrix of S. */
    private Matrix InvK;
    /** Backing storage for {@link #InvK}; doubled in size whenever S fills it. */
    private Matrix InvKExpanded;
    /** Scratch buffer: after scoring x, {@code k_raw[i] = k(S_i, x)} for i &lt; |S|. */
    private double[] k_raw;
    /** Whether Projectron++ margin updates are applied to correct but low-margin examples. */
    private boolean useMarginUpdates;

    /**
     * Creates a Projectron++ learner with {@code eta = 0.1}.
     *
     * @param k the kernel to use
     */
    public Projectron(KernelTrick k) {
        this(k, 0.1);
    }

    /**
     * Creates a Projectron++ learner (margin updates enabled).
     *
     * @param k   the kernel to use
     * @param eta the projection threshold, in [0, Infinity)
     */
    public Projectron(KernelTrick k, double eta) {
        this(k, eta, true);
    }

    /**
     * Creates a new learner.
     *
     * @param k                the kernel to use
     * @param eta              the projection threshold, in [0, Infinity)
     * @param useMarginUpdates {@code true} for Projectron++, {@code false} for plain Projectron
     */
    public Projectron(KernelTrick k, double eta, boolean useMarginUpdates) {
        this.setKernel(k);
        this.setEta(eta);
        this.setUseMarginUpdates(useMarginUpdates);
    }

    /**
     * Deep-copy constructor.
     *
     * @param toCopy the learner to copy
     */
    protected Projectron(Projectron toCopy) {
        this.k = toCopy.k == null ? null : toCopy.k.clone();
        this.eta = toCopy.eta;
        // Previously omitted, which made every clone silently fall back to plain Projectron
        this.useMarginUpdates = toCopy.useMarginUpdates;
        if (toCopy.S != null) {
            this.alpha = new DoubleList(toCopy.alpha);
            this.S = new ArrayList<Vec>(toCopy.S);
            this.cacheAccel = new DoubleList(toCopy.cacheAccel);
            this.InvKExpanded = toCopy.InvKExpanded.clone();
            // InvK is still null when setUp has run but no example has been seen yet
            if (toCopy.InvK != null) {
                this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, toCopy.InvK.rows(), toCopy.InvK.cols());
            }
            this.k_raw = Arrays.copyOf(toCopy.k_raw, toCopy.k_raw.length);
        }
    }

    /**
     * Sets the kernel to use.
     *
     * @param k the kernel
     */
    public void setKernel(KernelTrick k) {
        this.k = k;
    }

    /**
     * Returns the kernel in use.
     *
     * @return the kernel
     */
    public KernelTrick getKernel() {
        return this.k;
    }

    /**
     * Sets the projection threshold. Larger values project more mistakes onto
     * the existing support set instead of adding new support vectors.
     *
     * @param eta the threshold, finite and non-negative
     * @throws IllegalArgumentException if {@code eta} is NaN, infinite or negative
     */
    public void setEta(double eta) {
        if (Double.isNaN(eta) || Double.isInfinite(eta) || eta < 0.0) {
            throw new IllegalArgumentException("eta must be in the range [0, Infinity), not " + eta);
        }
        this.eta = eta;
    }

    /**
     * Returns the projection threshold.
     *
     * @return eta
     */
    public double getEta() {
        return this.eta;
    }

    /**
     * Enables or disables Projectron++ margin updates.
     *
     * @param useMarginUpdates {@code true} to update on correct predictions with margin below 1
     */
    public void setUseMarginUpdates(boolean useMarginUpdates) {
        this.useMarginUpdates = useMarginUpdates;
    }

    /**
     * Returns whether Projectron++ margin updates are enabled.
     *
     * @return {@code true} if margin updates are used
     */
    public boolean isUseMarginUpdates() {
        return this.useMarginUpdates;
    }

    @Override
    public Projectron clone() {
        return new Projectron(this);
    }

    /**
     * Resets the model for a new binary problem.
     *
     * @throws IllegalArgumentException if there are no numeric features
     * @throws FailedToFitException     if the target does not have exactly two classes
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes < 1) {
            throw new IllegalArgumentException("Projectron requires numeric features");
        }
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("Projectron only supports binary classification");
        }
        final int initSize = 50;
        this.alpha = new DoubleList(initSize);
        this.cacheAccel = new DoubleList(initSize);
        this.S = new ArrayList<Vec>(initSize);
        this.InvKExpanded = new DenseMatrix(initSize, initSize);
        // Drop any view left over from a previous fit; it referenced the old storage
        this.InvK = null;
        this.k_raw = new double[initSize];
    }

    /**
     * Performs one online update with the labeled example.
     *
     * @param dataPoint   the example
     * @param targetClass its class, 0 or 1 (mapped to -1 / +1)
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        final Vec x_t = dataPoint.getNumericalValues();
        final List<Double> qi = this.k.getQueryInfo(x_t);
        // Also fills k_raw[0..|S|) with k(S_i, x_t), reused below as k_t
        final double score = this.getScore(x_t, qi, this.k_raw);
        final double y_t = targetClass * 2 - 1;

        if (this.S.isEmpty()) {
            this.addFirstSupportVector(x_t, qi, y_t);
            return;
        }

        final double margin = y_t * score;
        // Zero hinge loss: neither a mistake nor a margin error, so no update applies
        if (margin >= 1.0) {
            return;
        }
        final boolean mistake = Math.signum(score) != y_t;
        if (!mistake && !this.useMarginUpdates) {
            return;
        }

        final DenseVector k_t = new DenseVector(this.k_raw, 0, this.S.size());
        // d = K^{-1} k_t: coefficients of the projection of k(x_t, .) onto span(S)
        final Vec d = this.InvK.multiply(k_t);
        final double k_xt = this.k.eval(0, 0, Arrays.asList(x_t), qi);
        // Squared norm of the projection, k_t^T K^{-1} k_t
        final double k_t_d = k_t.dot(d);
        // Squared distance from k(x_t, .) to span(S); clamped against round-off
        final double deltaSqrd = Math.max(k_xt - k_t_d, 0.0);
        final double delta = Math.sqrt(deltaSqrd);

        if (mistake) {
            if (delta < this.eta) {
                this.addScaledProjection(d, y_t);
            } else {
                this.addSupportVector(x_t, qi, y_t, d, deltaSqrd);
            }
            return;
        }

        // Projectron++ margin error: 0 < margin < 1
        final double loss = 1.0 - margin;
        // Avoid 0/0 when eta == 0 and x_t already lies in span(S)
        final double deltaOverEta = delta == 0.0 ? 0.0 : delta / this.eta;
        // !(k_t_d > 0) also rejects NaN; a zero-norm projection would add nothing anyway
        if (loss < deltaOverEta || !(k_t_d > 0.0)) {
            return;
        }
        // Step size is the MINIMUM of the three bounds; the original used max, which could explode
        final double tau = Math.min(Math.min(loss / k_t_d, 2.0 * (loss - deltaOverEta) / k_t_d), 1.0);
        this.addScaledProjection(d, y_t * tau);
    }

    /** Adds {@code scale * d_i} to each support-vector coefficient alpha_i. */
    private void addScaledProjection(Vec d, double scale) {
        for (int i = 0; i < this.S.size(); ++i) {
            this.alpha.set(i, this.alpha.get(i) + scale * d.get(i));
        }
    }

    /** Seeds the empty support set with x_t and initializes K^{-1} = [1 / k(x_t, x_t)]. */
    private void addFirstSupportVector(Vec x_t, List<Double> qi, double y_t) {
        final double k_xx = this.k.eval(0, 0, Arrays.asList(x_t), qi);
        this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, 1, 1);
        // Inverse of a 1x1 Gram matrix. Equal to 1 for normalized kernels (e.g. RBF).
        // Fall back to 1 for a zero-norm point, whose kernel row is zero anyway.
        this.InvK.set(0, 0, k_xx > 0.0 ? 1.0 / k_xx : 1.0);
        this.S.add(x_t);
        this.alpha.add(y_t);
        if (qi != null) {
            this.cacheAccel.addAll(qi);
        }
    }

    /**
     * Appends x_t to S and updates InvK by the block-inverse rule
     * K'^{-1} = [[K^{-1}, 0], [0, 0]] + (1/deltaSqrd) [d; -1][d; -1]^T.
     */
    private void addSupportVector(Vec x_t, List<Double> qi, double y_t, Vec d, double deltaSqrd) {
        final int n = this.S.size();
        if (n == this.InvKExpanded.rows()) {
            // Backing storage is full: double it and copy the current inverse over
            final Matrix grown = new DenseMatrix(n * 2, n * 2);
            for (int i = 0; i < this.InvK.rows(); ++i) {
                for (int j = 0; j < this.InvK.cols(); ++j) {
                    grown.set(i, j, this.InvK.get(i, j));
                }
            }
            this.InvKExpanded = grown;
            this.k_raw = Arrays.copyOf(this.k_raw, n * 2);
        }
        // The new row/column n of the backing store is still zero at this point
        this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, n + 1, n + 1);
        final DenseVector dExp = new DenseVector(n + 1);
        for (int i = 0; i < d.length(); ++i) {
            dExp.set(i, d.get(i));
        }
        dExp.set(n, -1.0);
        if (deltaSqrd > 0.0) {
            Matrix.OuterProductUpdate(this.InvK, dExp, dExp, 1.0 / deltaSqrd);
        }
        this.S.add(x_t);
        this.alpha.add(y_t);
        // NOTE(review): guarded in case the kernel supplies no query info (null); confirm against KernelTrick impls
        if (qi != null) {
            this.cacheAccel.addAll(qi);
        }
    }

    /**
     * Returns +1-class probability 1 when the score is non-negative, otherwise
     * class 0 with probability 1.
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        if (this.getScore(data) < 0.0) {
            cr.setProb(0, 1.0);
        } else {
            cr.setProb(1, 1.0);
        }
        return cr;
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Computes sum_i alpha_i k(S_i, x), optionally recording each k(S_i, x).
     *
     * @param x      the query vector
     * @param qi     the kernel's query info for {@code x}
     * @param kStore if non-null, receives k(S_i, x) at index i; must hold at least |S| entries
     * @return the decision value
     */
    private double getScore(Vec x, List<Double> qi, double[] kStore) {
        double score = 0.0;
        for (int i = 0; i < this.S.size(); ++i) {
            double tmp = this.k.eval(i, x, qi, this.S, this.cacheAccel);
            if (kStore != null) {
                kStore[i] = tmp;
            }
            score += this.alpha.get(i) * tmp;
        }
        return score;
    }

    /** Returns the raw decision value; positive favors class 1. */
    @Override
    public double getScore(DataPoint dp) {
        return this.k.evalSum(this.S, this.cacheAccel, this.alpha.getBackingArray(), dp.getNumericalValues(), 0, this.S.size());
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

