/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.knn;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.EigenValueDecomposition;
import jsat.linear.Matrix;
import jsat.linear.RowColumnOps;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.VecPairedComparable;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.BoundedSortedList;
import jsat.utils.FakeExecutor;

public class DANN
implements Classifier,
Parameterized {
    private static final long serialVersionUID = -272865942127664672L;
    public static final int DEFAULT_KN = 40;
    public static final int DEFAULT_K = 1;
    public static final double DEFAULT_EPS = 1.0;
    public static final int DEFAULT_ITERATIONS = 1;
    private int kn;
    private int k;
    private int maxIterations;
    private double eps;
    private VectorCollectionFactory<VecPaired<Vec, Integer>> vcf;
    private CategoricalData predicting;
    private VectorCollection<VecPaired<Vec, Integer>> vc;
    private List<VecPaired<Vec, Integer>> vecList;

    public DANN() {
        this(40, 1);
    }

    public DANN(int kn, int k) {
        this(kn, k, 1.0);
    }

    public DANN(int kn, int k, double eps) {
        this(kn, k, eps, new DefaultVectorCollectionFactory<VecPaired<Vec, Integer>>());
    }

    public DANN(int kn, int k, double eps, VectorCollectionFactory<VecPaired<Vec, Integer>> vcf) {
        this(kn, k, eps, 1, vcf);
    }

    public DANN(int kn, int k, double eps, int maxIterations, VectorCollectionFactory<VecPaired<Vec, Integer>> vcf) {
        this.setK(k);
        this.setKn(kn);
        this.setEpsilon(eps);
        this.setMaxIterations(maxIterations);
        this.vcf = vcf;
    }

    public void setK(int k) {
        if (k < 1) {
            throw new ArithmeticException("Number of neighbors must be positive");
        }
        this.k = k;
    }

    public int getK() {
        return this.k;
    }

    public void setKn(int kn) {
        if (kn < 2) {
            throw new ArithmeticException("At least 2 neighbors are needed to adapat the metric");
        }
        this.kn = kn;
    }

    public int getKn() {
        return this.kn;
    }

    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new RuntimeException("At least one iteration must be performed");
        }
        this.maxIterations = maxIterations;
    }

    public int getMaxIterations() {
        return this.maxIterations;
    }

    public void setEpsilon(double eps) {
        if (eps < 0.0 || Double.isInfinite(eps) || Double.isNaN(eps)) {
            throw new ArithmeticException("Regularization must be a positive value");
        }
        this.eps = eps;
    }

    public double getEpsilon() {
        return this.eps;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        int n = data.numNumericalValues();
        DenseMatrix sigma = Matrix.eye(n);
        DenseMatrix B = new DenseMatrix(n, n);
        DenseMatrix W = new DenseMatrix(n, n);
        Vec query = data.getNumericalValues();
        DenseVector scratch0 = new DenseVector(n);
        double[] weights = new double[this.kn];
        double[] priors = new double[this.predicting.getNumOfCategories()];
        int[] classCount = new int[priors.length];
        DenseVector mean = new DenseVector(((Matrix)sigma).rows());
        Vec[] classMeans = new Vec[this.predicting.getNumOfCategories()];
        for (int i = 0; i < classMeans.length; ++i) {
            classMeans[i] = new DenseVector(((Vec)mean).length());
        }
        for (int iter = 0; iter < this.maxIterations; ++iter) {
            int i;
            mean.zeroOut();
            Arrays.fill(priors, 0.0);
            Arrays.fill(classCount, 0);
            for (int i2 = 0; i2 < classMeans.length; ++i2) {
                classMeans[i2].zeroOut();
            }
            double sumOfWeights = 0.0;
            ((Matrix)B).zeroOut();
            ((Matrix)W).zeroOut();
            List<VecPaired<VecPaired<Vec, Integer>, Double>> vecs = iter == 0 ? this.vc.search(query, this.kn) : this.brute(query, sigma, this.kn);
            double d = vecs.get(vecs.size() - 1).getPair();
            for (i = 0; i < vecs.size(); ++i) {
                int j;
                VecPaired<VecPaired<Vec, Integer>, Double> vec = vecs.get(i);
                weights[i] = Math.pow(Math.pow(1.0 - Math.pow(vec.getPair(), 2.0) / d, 3.0), 3.0);
                sumOfWeights += weights[i];
                mean.mutableAdd(vec);
                int n2 = j = vec.getVector().getPair().intValue();
                priors[n2] = priors[n2] + weights[i];
                classMeans[j].mutableAdd(vec);
                int n3 = j;
                classCount[n3] = classCount[n3] + 1;
            }
            ((Vec)mean).mutableDivide(this.kn);
            i = 0;
            while (i < classMeans.length) {
                if ((double)classCount[i] != 0.0) {
                    classMeans[i].mutableDivide(classCount[i]);
                }
                int n4 = i++;
                priors[n4] = priors[n4] / sumOfWeights;
            }
            for (int j = 0; j < classMeans.length; ++j) {
                if (!(priors[j] > 0.0)) continue;
                classMeans[j].copyTo(scratch0);
                scratch0.mutableSubtract(mean);
                Matrix.OuterProductUpdate(B, scratch0, scratch0, priors[j]);
                for (int i3 = 0; i3 < vecs.size(); ++i3) {
                    VecPaired<VecPaired<Vec, Integer>, Double> x = vecs.get(i3);
                    if (x.getVector().getPair() != j) continue;
                    x.copyTo(scratch0);
                    scratch0.mutableSubtract(classMeans[j]);
                    Matrix.OuterProductUpdate(W, scratch0, scratch0, weights[i3]);
                }
            }
            ((Matrix)W).mutableMultiply(1.0 / sumOfWeights);
            RowColumnOps.addDiag(B, 0, ((Matrix)B).rows(), this.eps);
            for (i = 0; i < priors.length; ++i) {
                if (priors[i] != 1.0) continue;
                cr.setProb(i, 1.0);
                return cr;
            }
            EigenValueDecomposition evd = new EigenValueDecomposition(W);
            Matrix D = evd.getD();
            for (int i4 = 0; i4 < D.rows(); ++i4) {
                D.set(i4, i4, Math.pow(D.get(i4, i4), -0.5));
            }
            Matrix VT = evd.getVT();
            Matrix WW = VT.transposeMultiply(D).multiply(VT);
            ((Matrix)sigma).zeroOut();
            WW.multiply(B).multiply(WW, sigma);
        }
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> knn = this.brute(query, sigma, this.k);
        for (VecPaired<VecPaired<Vec, Integer>, Double> vecPaired : knn) {
            cr.incProb(vecPaired.getVector().getPair(), 1.0);
        }
        cr.normalize();
        return cr;
    }

    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.predicting = dataSet.getPredicting();
        this.vecList = new ArrayList<VecPaired<Vec, Integer>>(dataSet.getSampleSize());
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            this.vecList.add(new VecPaired<Vec, Integer>(dataSet.getDataPoint(i).getNumericalValues(), dataSet.getDataPointCategory(i)));
        }
        this.vc = threadPool == null || threadPool instanceof FakeExecutor ? this.vcf.getVectorCollection(this.vecList, new EuclideanDistance()) : this.vcf.getVectorCollection(this.vecList, new EuclideanDistance(), threadPool);
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, null);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public Classifier clone() {
        DANN clone = new DANN(this.kn, this.k, this.maxIterations, this.vcf.clone());
        if (this.predicting != null) {
            clone.predicting = this.predicting.clone();
        }
        if (this.vc != null) {
            clone.vc = this.vc.clone();
        }
        if (this.vecList != null) {
            clone.vecList = new ArrayList<VecPaired<Vec, Integer>>(this.vecList);
        }
        return clone;
    }

    private double dist(Matrix sigma, Vec query, Vec mean, Vec scratch0, Vec scartch1) {
        query.copyTo(scratch0);
        scratch0.mutableSubtract(mean);
        scartch1.zeroOut();
        sigma.multiply(scratch0, 1.0, scartch1);
        return scratch0.dot(scartch1);
    }

    private List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> brute(Vec query, Matrix sigma, int num) {
        DenseVector scartch0 = new DenseVector(query.length());
        DenseVector scartch1 = new DenseVector(query.length());
        BoundedSortedList<VecPairedComparable<VecPaired<Vec, Integer>, Double>> knn = new BoundedSortedList<VecPairedComparable<VecPaired<Vec, Integer>, Double>>(num, num);
        for (VecPaired<Vec, Integer> v : this.vecList) {
            double d = this.dist(sigma, query, v, scartch0, scartch1);
            knn.add(new VecPairedComparable<VecPaired<Vec, Integer>, Double>(v, d));
        }
        return knn;
    }

    public static Distribution guessK(DataSet d) {
        return new UniformDiscrete(1, 25);
    }

    public static Distribution guessKn(DataSet d) {
        return new UniformDiscrete(40, Math.max(d.getSampleSize() / 5, 50));
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

