/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.ExponetialDecay;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;
import jsat.utils.random.RandomUtil;

/**
 * Learning Vector Quantization (LVQ) classifier.
 * <p>
 * For every class, {@link #getRepresentativesPerClass()} prototype vectors are seeded from that
 * class's training points. The prototypes are then moved toward or away from training points over
 * up to {@link #getIterations()} passes. Classification returns the class of the single nearest
 * prototype that survived training. The update rule is selected with {@link LVQVersion}; every
 * later version also applies the rules of the earlier ones.
 */
public class LVQ implements Classifier, Parameterized {
    private static final long serialVersionUID = -3911765006048793222L;
    /** Suggested number of training passes (not applied automatically by any constructor). */
    public static final int DEFAULT_ITERATIONS = 200;
    /** Initial learning rate used by {@link #LVQ(DistanceMetric, int)}. */
    public static final double DEFAULT_LEARNING_RATE = 0.1;
    /** Default relative window used by the LVQ2 / LVQ2.1 / LVQ3 closeness tests. */
    public static final double DEFAULT_EPS = 0.3;
    /** Default multiplier applied to the learning rate for the LVQ3 same-class update. */
    public static final double DEFAULT_MSCALE = 0.30000000000000004;
    /** Update rule used by the shorter constructors. */
    public static final LVQVersion DEFAULT_LVQ_METHOD = LVQVersion.LVQ3;
    /** Number of prototypes per class used by {@link #LVQ(DistanceMetric, int)}. */
    public static final int DEFAULT_REPS_PER_CLASS = 3;
    /** Default per-pass prototype movement at or below which training stops early. */
    public static final double DEFAULT_STOPPING_DIST = 0.001;
    /** Default strategy used to pick the initial prototypes of each class. */
    public static final SeedSelectionMethods.SeedSelection DEFAULT_SEED_SELECTION = SeedSelectionMethods.SeedSelection.KPP;

    /** Produces the learning rate of each pass from the initial rate. */
    private DecayRate learningDecay;
    /** Maximum number of passes over the training set. */
    private int iterations;
    /** Initial learning rate, fed to {@link #learningDecay}. */
    private double learningRate;
    /** Metric used both during training and for the final nearest-prototype search. */
    protected DistanceMetric dm;
    /** Which update rules are applied. */
    private LVQVersion lvqVersion;
    /** Relative window for the distance-closeness tests. */
    private double eps;
    /** Learning-rate multiplier for the LVQ3 same-class update. */
    private double mScale;
    /** Training stops once no prototype moves farther than this in a pass. */
    private double stoppingDist;
    /** Number of prototypes seeded for each class. */
    private int representativesPerClass;
    /** Prototype vectors; null until trained. */
    protected Vec[] weights;
    /** Class label of each prototype, parallel to {@link #weights}. */
    protected int[] weightClass;
    /** Number of times each prototype was credited during the most recent training pass. */
    protected int[] wins;
    /** Strategy used to seed the prototypes of each class. */
    private SeedSelectionMethods.SeedSelection seedSelection;
    /** Nearest-neighbour structure over the surviving prototypes, paired with their index. */
    protected VectorCollection<VecPaired<Vec, Integer>> vc;
    /** Factory used to build {@link #vc}. */
    private VectorCollectionFactory<VecPaired<Vec, Integer>> vcf;

    /**
     * Creates an LVQ classifier with the default learning rate, prototypes per class, update rule
     * and an {@link ExponetialDecay} learning-rate schedule.
     *
     * @param dm the distance metric to use
     * @param iterations the maximum number of training passes
     */
    public LVQ(DistanceMetric dm, int iterations) {
        this(dm, iterations, DEFAULT_LEARNING_RATE, DEFAULT_REPS_PER_CLASS);
    }

    /**
     * Creates an LVQ classifier with the default update rule and an {@link ExponetialDecay}
     * learning-rate schedule.
     *
     * @param dm the distance metric to use
     * @param iterations the maximum number of training passes
     * @param learningRate the initial learning rate
     * @param representativesPerClass the number of prototypes seeded per class
     */
    public LVQ(DistanceMetric dm, int iterations, double learningRate, int representativesPerClass) {
        this(dm, iterations, learningRate, representativesPerClass, DEFAULT_LVQ_METHOD, new ExponetialDecay());
    }

    /**
     * Creates a fully specified LVQ classifier. Epsilon, the LVQ3 scale, the stopping distance and
     * the seed selection take their {@code DEFAULT_*} values.
     *
     * @param dm the distance metric to use
     * @param iterations the maximum number of training passes
     * @param learningRate the initial learning rate
     * @param representativesPerClass the number of prototypes seeded per class
     * @param lvqVersion the update rule to apply
     * @param learningDecay the learning-rate schedule
     */
    public LVQ(DistanceMetric dm, int iterations, double learningRate, int representativesPerClass, LVQVersion lvqVersion, DecayRate learningDecay) {
        this.setLearningDecay(learningDecay);
        this.setIterations(iterations);
        this.setLearningRate(learningRate);
        this.setDistanceMetric(dm);
        this.setLVQMethod(lvqVersion);
        this.setEpsilonDistance(DEFAULT_EPS);
        this.setMScale(DEFAULT_MSCALE);
        // previously never assigned, leaving the early-stopping threshold at 0.0
        this.setStoppingDist(DEFAULT_STOPPING_DIST);
        this.setSeedSelection(DEFAULT_SEED_SELECTION);
        this.setVecCollectionFactory(new DefaultVectorCollectionFactory<VecPaired<Vec, Integer>>());
        this.setRepresentativesPerClass(representativesPerClass);
    }

    /**
     * Copy constructor. Clones the distance metric, any trained prototypes, the search structure
     * and the collection factory. The learning-decay object is shared with {@code toCopy}.
     *
     * @param toCopy the classifier to copy
     */
    protected LVQ(LVQ toCopy) {
        this(toCopy.dm.clone(), toCopy.iterations, toCopy.learningRate, toCopy.representativesPerClass, toCopy.lvqVersion, toCopy.learningDecay);
        if (toCopy.weights != null) {
            this.wins = Arrays.copyOf(toCopy.wins, toCopy.wins.length);
            this.weightClass = Arrays.copyOf(toCopy.weightClass, toCopy.weightClass.length);
            this.weights = new Vec[toCopy.weights.length];
            for (int i = 0; i < toCopy.weights.length; i++) {
                this.weights[i] = toCopy.weights[i].clone();
            }
        }
        this.setEpsilonDistance(toCopy.eps);
        this.setMScale(toCopy.getMScale());
        // previously not copied, so clones silently reverted to the default threshold
        this.setStoppingDist(toCopy.stoppingDist);
        this.setSeedSelection(toCopy.getSeedSelection());
        if (toCopy.vc != null) {
            this.vc = toCopy.vc.clone();
        }
        this.setVecCollectionFactory(toCopy.vcf.clone());
    }

    /**
     * Sets the learning-rate multiplier used by the LVQ3 same-class update.
     *
     * @param mScale a finite positive value
     * @throws ArithmeticException if {@code mScale} is not finite and positive
     */
    public void setMScale(double mScale) {
        if (mScale <= 0.0 || Double.isInfinite(mScale) || Double.isNaN(mScale)) {
            throw new ArithmeticException("Scale factor must be a positive constant, not " + mScale);
        }
        this.mScale = mScale;
    }

    /** @return the learning-rate multiplier used by the LVQ3 same-class update */
    public double getMScale() {
        return this.mScale;
    }

    /**
     * Sets the relative window used to decide whether the two nearest distances are close.
     *
     * @param eps a finite positive value
     * @throws ArithmeticException if {@code eps} is not finite and positive
     */
    public void setEpsilonDistance(double eps) {
        if (eps <= 0.0 || Double.isInfinite(eps) || Double.isNaN(eps)) {
            throw new ArithmeticException("eps factor must be a positive constant, not " + eps);
        }
        this.eps = eps;
    }

    /** @return the relative window used by the closeness tests */
    public double getEpsilonDistance() {
        return this.eps;
    }

    /**
     * Sets the initial learning rate.
     *
     * @param learningRate a finite positive value
     * @throws ArithmeticException if {@code learningRate} is not finite and positive
     */
    public void setLearningRate(double learningRate) {
        if (learningRate <= 0.0 || Double.isInfinite(learningRate) || Double.isNaN(learningRate)) {
            throw new ArithmeticException("learning rate must be a positive constant, not " + learningRate);
        }
        this.learningRate = learningRate;
    }

    /** @return the initial learning rate */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the schedule that maps (pass, total passes, initial rate) to the rate of each pass.
     *
     * @param learningDecay the learning-rate schedule
     */
    public void setLearningDecay(DecayRate learningDecay) {
        this.learningDecay = learningDecay;
    }

    /** @return the learning-rate schedule */
    public DecayRate getLearningDecay() {
        return this.learningDecay;
    }

    /**
     * Sets the maximum number of training passes.
     *
     * @param iterations a non-negative number of passes
     * @throws ArithmeticException if {@code iterations} is negative
     */
    public void setIterations(int iterations) {
        if (iterations < 0) {
            throw new ArithmeticException("Can not perform a negative number of iterations");
        }
        this.iterations = iterations;
    }

    /** @return the maximum number of training passes */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets how many prototypes are seeded for each class.
     *
     * @param representativesPerClass a positive count
     * @throws ArithmeticException if {@code representativesPerClass} is not positive (zero would
     *         otherwise divide by zero in {@link #classify(DataPoint)})
     */
    public void setRepresentativesPerClass(int representativesPerClass) {
        if (representativesPerClass <= 0) {
            throw new ArithmeticException("representatives per class must be positive, not " + representativesPerClass);
        }
        this.representativesPerClass = representativesPerClass;
    }

    /** @return how many prototypes are seeded for each class */
    public int getRepresentativesPerClass() {
        return this.representativesPerClass;
    }

    /**
     * Sets which update rules are applied during training.
     *
     * @param lvqMethod the update rule
     */
    public void setLVQMethod(LVQVersion lvqMethod) {
        this.lvqVersion = lvqMethod;
    }

    /** @return the update rule applied during training */
    public LVQVersion getLVQMethod() {
        return this.lvqVersion;
    }

    /**
     * Sets the distance metric used for training and classification.
     *
     * @param dm the distance metric
     */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /** @return the distance metric used for training and classification */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Sets the early-stopping threshold. Training stops after a pass in which no prototype moved
     * farther than this distance.
     *
     * @param stoppingDist a finite non-negative value
     * @throws ArithmeticException if {@code stoppingDist} is negative, infinite or NaN
     */
    public void setStoppingDist(double stoppingDist) {
        if (stoppingDist < 0.0 || Double.isInfinite(stoppingDist) || Double.isNaN(stoppingDist)) {
            throw new ArithmeticException("stopping dist must be a zero or positive constant, not " + stoppingDist);
        }
        this.stoppingDist = stoppingDist;
    }

    /** @return the early-stopping threshold */
    public double getStoppingDist() {
        return this.stoppingDist;
    }

    /**
     * Sets the strategy used to pick the initial prototypes of each class.
     *
     * @param seedSelection the seeding strategy
     */
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        this.seedSelection = seedSelection;
    }

    /** @return the strategy used to pick the initial prototypes of each class */
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Sets the factory used to build the nearest-prototype search structure after training.
     *
     * @param vcf the vector collection factory
     */
    public void setVecCollectionFactory(VectorCollectionFactory<VecPaired<Vec, Integer>> vcf) {
        this.vcf = vcf;
    }

    /**
     * Classifies a point by the class of its nearest surviving prototype. The returned result puts
     * probability 1.0 on that class.
     *
     * @param data the point to classify
     * @return a hard assignment to the nearest prototype's class
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(this.weightClass.length / this.representativesPerClass);
        int index = this.vc.search(data.getNumericalValues(), 1).get(0).getVector().getPair();
        cr.setProb(this.weightClass[index], 1.0);
        return cr;
    }

    /**
     * Returns true when two distances are within a relative factor {@code eps} of each other.
     * Both ratios {@code minDist/minDist2} and {@code minDist2/minDist} must lie strictly inside
     * {@code (1 - eps, 1 + eps)}. A zero, infinite or NaN ratio fails the test.
     *
     * @param minDist the distance to the nearest prototype
     * @param minDist2 the distance to the runner-up prototype
     * @return whether the two distances are eps-close
     */
    protected boolean epsClose(double minDist, double minDist2) {
        return Math.min(minDist / minDist2, minDist2 / minDist) > 1.0 - this.eps && Math.max(minDist / minDist2, minDist2 / minDist) < 1.0 + this.eps;
    }

    /**
     * Trains the classifier.
     * <ol>
     * <li>Trains the distance metric if it needs it.</li>
     * <li>Seeds the prototypes of every class.</li>
     * <li>Runs up to {@link #getIterations()} passes of the selected LVQ update over every training
     * point, stopping early once no prototype moves farther than {@link #getStoppingDist()}.</li>
     * <li>Builds the search structure from the prototypes credited in the final pass. If none were
     * credited (for example, when zero iterations were requested), all prototypes are kept.</li>
     * </ol>
     *
     * @param dataSet the training data
     * @param threadPool optional executor; {@code null} or a {@link FakeExecutor} means single-threaded
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        boolean serial = threadPool == null || threadPool instanceof FakeExecutor;
        if (serial) {
            TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet);
        } else {
            TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadPool);
        }
        this.seedRepresentatives(dataSet, RandomUtil.getRandom());

        Vec[] weightsPrev = new Vec[this.weights.length];
        for (int j = 0; j < this.weights.length; j++) {
            weightsPrev[j] = this.weights[j].clone();
        }
        Vec tmp = this.weights[0].clone();
        for (int iteration = 0; iteration < this.iterations; iteration++) {
            for (int j = 0; j < this.weights.length; j++) {
                this.weights[j].copyTo(weightsPrev[j]);
            }
            // wins only reflect the most recent pass; they decide which prototypes survive
            Arrays.fill(this.wins, 0);
            double alpha = this.learningDecay.rate(iteration, this.iterations, this.learningRate);
            for (int i = 0; i < dataSet.getSampleSize(); i++) {
                Vec x = dataSet.getDataPoint(i).getNumericalValues();
                this.updateForPoint(x, dataSet.getDataPointCategory(i), alpha, tmp);
            }
            if (this.converged(weightsPrev)) {
                break;
            }
        }

        ArrayList<VecPaired<Vec, Integer>> finalLVs = new ArrayList<VecPaired<Vec, Integer>>(this.weights.length);
        for (int i = 0; i < this.weights.length; i++) {
            if (this.wins[i] > 0) {
                finalLVs.add(new VecPaired<Vec, Integer>(this.weights[i], i));
            }
        }
        if (finalLVs.isEmpty()) {
            // no prototype was credited (e.g. iterations == 0); keep all of them so classify()
            // does not search an empty collection
            for (int i = 0; i < this.weights.length; i++) {
                finalLVs.add(new VecPaired<Vec, Integer>(this.weights[i], i));
            }
        }
        this.vc = serial ? this.vcf.getVectorCollection(finalLVs, this.dm) : this.vcf.getVectorCollection(finalLVs, this.dm, threadPool);
    }

    /**
     * Allocates {@link #weights}, {@link #weightClass} and {@link #wins}, and seeds
     * {@link #representativesPerClass} prototypes from the samples of each class.
     */
    private void seedRepresentatives(ClassificationDataSet dataSet, Random rand) {
        int classCount = dataSet.getPredicting().getNumOfCategories();
        this.weights = new Vec[classCount * this.representativesPerClass];
        this.weightClass = new int[this.weights.length];
        this.wins = new int[this.weights.length];
        int curPos = 0;
        for (int curClass = 0; curClass < classCount; curClass++) {
            List<DataPoint> origSubList = dataSet.getSamples(curClass);
            ArrayList<DataPointPair<Integer>> subList = new ArrayList<DataPointPair<Integer>>(origSubList.size());
            for (DataPoint dp : origSubList) {
                subList.add(new DataPointPair<Integer>(dp, curClass));
            }
            ClassificationDataSet subSet = new ClassificationDataSet(subList, dataSet.getPredicting());
            List<Vec> classSeeds = SeedSelectionMethods.selectIntialPoints((DataSet) subSet, this.representativesPerClass, this.dm, rand, this.seedSelection);
            for (Vec v : classSeeds) {
                this.weights[curPos] = v.clone();
                this.weightClass[curPos] = curClass;
                curPos++;
            }
        }
    }

    /**
     * Applies one LVQ update for training point {@code x}, whose label is {@code trueClass}. The
     * branches below are checked in order and only the first that matches is applied.
     */
    private void updateForPoint(Vec x, int trueClass, double alpha, Vec tmp) {
        // Find the nearest and the true runner-up prototype. The runner-up was previously tracked
        // only for LVQ2, and only when a new minimum was found, which disabled LVQ2.1/LVQ3.
        int nearest = 0;
        int second = 0;
        double nearestDist = Double.POSITIVE_INFINITY;
        double secondDist = Double.POSITIVE_INFINITY;
        for (int j = 0; j < this.weights.length; j++) {
            double dist = this.dm.dist(x, this.weights[j]);
            if (dist < nearestDist) {
                secondDist = nearestDist;
                second = nearest;
                nearestDist = dist;
                nearest = j;
            } else if (dist < secondDist) {
                secondDist = dist;
                second = j;
            }
        }
        // With a single prototype, second == nearest and secondDist stays infinite, so every
        // two-prototype branch below is rejected by its class or ratio test.
        int nearestClass = this.weightClass[nearest];
        int secondClass = this.weightClass[second];
        int version = this.lvqVersion.ordinal();

        if (version >= LVQVersion.LVQ2.ordinal() && nearestClass != secondClass
                && trueClass == secondClass && this.epsClose(nearestDist, secondDist)) {
            // nearest is wrong, runner-up is right: push nearest away, pull runner-up in
            shift(this.weights[nearest], x, -alpha, tmp);
            shift(this.weights[second], x, alpha, tmp);
            this.wins[second]++;
        } else if (version >= LVQVersion.LVQ21.ordinal() && nearestClass != secondClass
                && trueClass == nearestClass && this.epsClose(nearestDist, secondDist)) {
            // nearest is right, runner-up is wrong: pull nearest in, push runner-up away
            shift(this.weights[nearest], x, alpha, tmp);
            this.wins[nearest]++;
            shift(this.weights[second], x, -alpha, tmp);
        } else if (version >= LVQVersion.LVQ3.ordinal() && nearestClass == secondClass
                && trueClass == nearestClass
                && Math.min(nearestDist / secondDist, secondDist / nearestDist) > (1.0 - this.eps) * (1.0 + this.eps)) {
            // both nearest prototypes share the point's class: pull both in by the scaled rate
            shift(this.weights[nearest], x, this.mScale * alpha, tmp);
            shift(this.weights[second], x, this.mScale * alpha, tmp);
            this.wins[nearest]++;
            this.wins[second]++;
        } else if (trueClass == nearestClass) {
            // LVQ1: correct nearest prototype moves toward the point
            shift(this.weights[nearest], x, alpha, tmp);
            this.wins[nearest]++;
        } else {
            // LVQ1: incorrect nearest prototype moves away from the point
            shift(this.weights[nearest], x, -alpha, tmp);
        }
    }

    /**
     * Performs {@code w += step * (x - w)} in place, using {@code tmp} as scratch space. A positive
     * step pulls {@code w} toward {@code x}; a negative step pushes it away.
     */
    private static void shift(Vec w, Vec x, double step, Vec tmp) {
        x.copyTo(tmp);
        tmp.mutableSubtract(w);
        w.mutableAdd(step, tmp);
    }

    /**
     * Returns true when no prototype moved farther than {@link #stoppingDist} since
     * {@code weightsPrev} was recorded.
     */
    private boolean converged(Vec[] weightsPrev) {
        for (int j = 0; j < this.weights.length; j++) {
            if (this.dm.dist(this.weights[j], weightsPrev[j]) > this.stoppingDist) {
                return false;
            }
        }
        return true;
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, null);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public LVQ clone() {
        return new LVQ(this);
    }

    /**
     * The LVQ update rules. Each constant enables its own rule in addition to those of all earlier
     * constants (the training code compares ordinals with {@code >=}).
     */
    public static enum LVQVersion {
        /** Move the nearest prototype toward the point if its class matches, away otherwise. */
        LVQ1,
        /**
         * Additionally, when the nearest prototype is wrong, the runner-up is right and the two
         * distances are eps-close: push the nearest away and pull the runner-up in.
         */
        LVQ2,
        /**
         * Additionally, when the nearest prototype is right, the runner-up is wrong and the two
         * distances are eps-close: pull the nearest in and push the runner-up away.
         */
        LVQ21,
        /**
         * Additionally, when both nearest prototypes share the point's class and their distance
         * ratio exceeds {@code (1 - eps) * (1 + eps)}: pull both in by {@code mScale} times the
         * learning rate.
         */
        LVQ3;

    }
}

