/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.visualization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.visualization.TSNE;
import jsat.datatransform.visualization.VisualizationTransform;
import jsat.distributions.Uniform;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.random.XORWOW;

/**
 * Visualization transform that embeds a dataset into a low dimensional space
 * (2 dimensions by default), following the LargeVis approach of Tang et al.,
 * "Visualizing Large-scale and High-dimensional Data" (WWW 2016).
 * <p>
 * Neighbor affinities in the source space are computed by
 * {@link TSNE#computeP}. The embedding is then optimized by stochastic
 * gradient steps. Each step samples one positive edge (a point and one of its
 * neighbors) and up to {@link #getNegativeSamples()} negative samples.
 */
public class LargeViz
implements VisualizationTransform {
    /**
     * Upper bound on the number of draws spent looking for a single admissible
     * negative sample. Without it, tiny datasets (e.g. N = 2, where every
     * candidate is one of the two edge endpoints) make the rejection loop spin
     * forever.
     */
    private static final int MAX_NEGATIVE_SAMPLE_ATTEMPTS = 1000;

    private DistanceMetric dm_source = new EuclideanDistance();
    private DistanceMetric dm_embed = new EuclideanDistance();
    private double perplexity = 50.0;
    private int dt = 2;
    private int M = 5;
    private double gamma = 7.0;

    /**
     * Sets the perplexity. It is passed to {@link TSNE#computeP} and also sets
     * the neighbor count: {@code knn = min(floor(3 * perplexity), N - 1)}.
     *
     * @param perplexity a positive, finite value
     * @throws IllegalArgumentException if {@code perplexity} is not positive,
     *         or is NaN or infinite
     */
    public void setPerplexity(double perplexity) {
        if (perplexity <= 0.0 || Double.isNaN(perplexity) || Double.isInfinite(perplexity)) {
            throw new IllegalArgumentException("perplexity must be positive, not " + perplexity);
        }
        this.perplexity = perplexity;
    }

    /**
     * @return the perplexity used when computing neighbor affinities
     */
    public double getPerplexity() {
        return this.perplexity;
    }

    /**
     * Sets the metric used in the original space. It is passed to
     * {@link TSNE#computeP} to find neighbors and their affinities.
     *
     * @param dm the source-space distance metric
     */
    public void setDistanceMetricSource(DistanceMetric dm) {
        this.dm_source = dm;
    }

    /**
     * Sets the metric used to measure distances between embedded points
     * during optimization.
     *
     * @param dm the embedding-space distance metric
     */
    public void setDistanceMetricEmbedding(DistanceMetric dm) {
        this.dm_embed = dm;
    }

    /**
     * Sets how many negative samples are drawn for each positive edge in
     * every gradient step.
     *
     * @param M the number of negative samples, at least 1
     * @throws IllegalArgumentException if {@code M < 1}
     */
    public void setNegativeSamples(int M) {
        if (M < 1) {
            throw new IllegalArgumentException("Number of negative samples must be positive, not " + M);
        }
        this.M = M;
    }

    /**
     * @return the number of negative samples drawn per positive edge
     */
    public int getNegativeSamples() {
        return this.M;
    }

    /**
     * Sets the weight applied to the gradient contribution of every negative
     * sample.
     *
     * @param gamma a positive, finite weight
     * @throws IllegalArgumentException if {@code gamma} is not positive, or is
     *         NaN or infinite
     */
    public void setGamma(double gamma) {
        if (Double.isInfinite(gamma) || Double.isNaN(gamma) || gamma <= 0.0) {
            throw new IllegalArgumentException("Gamma must be positive, not " + gamma);
        }
        this.gamma = gamma;
    }

    /**
     * @return the weight applied to negative-sample gradients
     */
    public double getGamma() {
        return this.gamma;
    }

    @Override
    public int getTargetDimension() {
        return this.dt;
    }

    /**
     * Sets the dimension of the embedding.
     *
     * @param target the desired dimension
     * @return {@code false} (leaving the setting unchanged) if {@code target < 2},
     *         {@code true} otherwise
     */
    @Override
    public boolean setTargetDimension(int target) {
        if (target < 2) {
            return false;
        }
        this.dt = target;
        return true;
    }

    /**
     * Embeds {@code d} by delegating to
     * {@link #transform(DataSet, ExecutorService)} with a {@link FakeExecutor}.
     */
    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> d) {
        return this.transform(d, new FakeExecutor());
    }

    /**
     * Embeds {@code d} into {@link #getTargetDimension()} dimensions.
     * <p>
     * Neighbor affinities are computed by {@link TSNE#computeP} over
     * {@code knn = min(floor(3 * perplexity), N - 1)} neighbors per point. After
     * that, {@code 1000 * N} gradient steps are split across
     * {@code max(min(N / (200 * M), SystemInfo.LogicalCores), 1)} tasks
     * submitted to {@code ex}. The tasks update the shared embedding vectors
     * without any locking. The learning rate decays linearly from 1 to a
     * floor of 1e-4.
     * <p>
     * If the waiting thread is interrupted, the interruption is logged and
     * the thread's interrupt flag is restored. The method then returns the
     * embedding as it currently stands, while tasks may still be running.
     *
     * @param d  the data to embed
     * @param ex executor that runs the optimization tasks
     * @return a shallow clone of {@code d} whose numeric vectors are replaced by
     *         the embedding; categorical values and weights are kept
     * @throws IllegalStateException if an optimization task threw a runtime
     *         exception (the task's exception is the cause)
     */
    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> d, ExecutorService ex) {
        XORWOW rand = new XORWOW();
        // one generator per worker thread
        final ThreadLocal<Random> localRand = new ThreadLocal<Random>() {
            @Override
            protected Random initialValue() {
                return new XORWOW();
            }
        };

        final int N = d.getSampleSize();
        final int knn = (int) Math.min(Math.floor(3.0 * this.perplexity), (double) (N - 1));
        double[][] nearMePij = new double[N][knn];
        final int[][] nearMe = new int[N][knn];
        TSNE.computeP(d, ex, rand, knn, nearMe, nearMePij, this.dm_source, this.perplexity);

        final double[][] nearMeSample = new double[N][knn];
        final double[] negSampleWeight = new double[N];
        buildSamplingTables(nearMePij, nearMeSample, negSampleWeight);

        // small random initialization, scaled down with the target dimension
        final List<DenseVector> embedding = new ArrayList<DenseVector>(N);
        Uniform initDistribution = new Uniform(-5.0E-5 / (double) this.dt, 5.0E-5 / (double) this.dt);
        for (int p = 0; p < N; ++p) {
            embedding.add(initDistribution.sampleVec(this.dt, rand));
        }

        final int threadsToUse = Math.max(Math.min(N / (200 * this.M), SystemInfo.LogicalCores), 1);
        final long iterations = 1000L * (long) N;
        final CountDownLatch latch = new CountDownLatch(threadsToUse);
        // first runtime failure seen in any task; guarded by its own monitor
        final RuntimeException[] workerFailure = new RuntimeException[1];
        for (int id = 0; id < threadsToUse; ++id) {
            ex.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        optimizeEmbedding(localRand.get(), nearMe, nearMeSample, negSampleWeight,
                                embedding, iterations, threadsToUse);
                    } catch (RuntimeException failure) {
                        synchronized (workerFailure) {
                            if (workerFailure[0] == null) {
                                workerFailure[0] = failure;
                            }
                        }
                    } finally {
                        // always release the waiting caller, even if the task failed
                        latch.countDown();
                    }
                }
            });
        }
        try {
            latch.await();
        } catch (InterruptedException ex1) {
            Logger.getLogger(LargeViz.class.getName()).log(Level.SEVERE, null, ex1);
            // restore the flag so callers can still observe the interruption
            Thread.currentThread().interrupt();
        }
        synchronized (workerFailure) {
            if (workerFailure[0] != null) {
                throw new IllegalStateException("LargeViz optimization failed in a worker task", workerFailure[0]);
            }
        }

        DataSet<Type> toRet = d.shallowClone();
        // map each original point (by identity) to its row in the embedding
        final IdentityHashMap<DataPoint, Integer> indexMap = new IdentityHashMap<DataPoint, Integer>(N);
        for (int p = 0; p < N; ++p) {
            indexMap.put(d.getDataPoint(p), p);
        }
        toRet.applyTransform(new DataTransform() {
            @Override
            public DataPoint transform(DataPoint dp) {
                int row = indexMap.get(dp);
                return new DataPoint(embedding.get(row), dp.getCategoricalValues(), dp.getCategoricalData(), dp.getWeight());
            }

            @Override
            public void fit(DataSet data) {
            }

            @Override
            public DataTransform clone() {
                return this;
            }
        });
        return (Type) toRet;
    }

    /**
     * Builds the two cumulative tables used for sampling during optimization.
     *
     * @param nearMePij       per-point affinities to each neighbor slot, as
     *                        filled in by {@link TSNE#computeP}
     * @param nearMeSample    output: for each point, the cumulative affinities
     *                        over its neighbor slots, divided by that point's
     *                        total affinity
     * @param negSampleWeight output: cumulative table over all points, where
     *                        each point's weight is its total affinity raised
     *                        to the 0.75 power; normalized by the grand total
     */
    private static void buildSamplingTables(double[][] nearMePij, double[][] nearMeSample, double[] negSampleWeight) {
        final int N = nearMePij.length;
        double negSum = 0.0;
        for (int i = 0; i < N; ++i) {
            final double[] pij = nearMePij[i];
            final double[] cdf = nearMeSample[i];
            final int knn = pij.length;
            // padded by knn * MIN_VALUE so an all-zero row does not give a zero total
            final double sum = DenseVector.toDenseVec(pij).sum() + (double) knn * Double.MIN_VALUE;

            // cumulative affinities; each step is also padded by one ulp of its affinity
            cdf[0] = pij[0];
            for (int j = 1; j < knn; ++j) {
                cdf[j] = Math.ulp(pij[j]) + pij[j] + cdf[j - 1];
            }
            // normalize every slot, including slot 0, so the table is a proper CDF
            for (int j = 0; j < knn; ++j) {
                cdf[j] /= sum;
            }

            final double weight = Math.pow(sum, 0.75);
            negSum += weight;
            negSampleWeight[i] = i > 0 ? weight + negSampleWeight[i - 1] : weight;
        }
        for (int i = 0; i < N; ++i) {
            negSampleWeight[i] /= negSum;
        }
    }

    /**
     * Runs one task's share of the gradient steps. The loop goes from 0 to
     * {@code iterations} in increments of {@code stride}. Each step picks a
     * point i uniformly, draws one of its neighbors j from i's neighbor CDF,
     * draws up to {@link #M} negative samples k, and then updates y_i, y_j and
     * each y_k in place. The shared vectors are mutated without any locking.
     *
     * @param rand            this task's random source
     * @param nearMe          neighbor indices of every point
     * @param nearMeSample    per-point neighbor CDFs
     * @param negSampleWeight global negative-sampling CDF
     * @param embedding       the shared embedding being optimized
     * @param iterations      total number of steps across all tasks
     * @param stride          number of tasks sharing those steps
     */
    private void optimizeEmbedding(Random rand, int[][] nearMe, double[][] nearMeSample, double[] negSampleWeight,
            List<DenseVector> embedding, long iterations, int stride) {
        final int N = embedding.size();
        // scratch gradients; each is fully overwritten before it is read
        final Vec grad_i = new DenseVector(this.dt);
        final Vec grad_j = new DenseVector(this.dt);
        final Vec grad_k = new DenseVector(this.dt);
        for (long iteration = 0L; iteration < iterations; iteration += (long) stride) {
            final double eta = Math.max(1.0 - (double) iteration / (double) iterations, 1.0E-4);

            // positive edge: i uniform, j drawn from i's neighbor CDF
            final int i = rand.nextInt(N);
            final int knn = nearMeSample[i].length;
            int slot = Arrays.binarySearch(nearMeSample[i], rand.nextDouble());
            if (slot < 0) {
                slot = -slot - 1;
            }
            if (slot >= knn) {
                // draw fell past the last CDF entry; pick a neighbor slot uniformly instead
                slot = rand.nextInt(knn);
            }
            final int j = nearMe[i][slot];

            final Vec y_i = embedding.get(i);
            final Vec y_j = embedding.get(j);
            final double dist_ij = this.dm_embed.dist(i, j, embedding, null);
            if (dist_ij <= 0.0) {
                continue;
            }
            y_i.copyTo(grad_j);
            grad_j.mutableSubtract(y_j);
            grad_j.mutableMultiply(-2.0 * dist_ij / (dist_ij * dist_ij + 1.0));
            grad_j.copyTo(grad_i);

            for (int s = 0; s < this.M; ++s) {
                final int k = sampleNegative(rand, i, j, negSampleWeight, nearMe[i], nearMeSample[i]);
                if (k < 0) {
                    // nothing admissible was found within the attempt budget
                    break;
                }
                final Vec y_k = embedding.get(k);
                final double dist_ik = this.dm_embed.dist(i, k, embedding, null);
                if (dist_ik < 1.0E-12) {
                    continue;
                }
                y_i.copyTo(grad_k);
                grad_k.mutableSubtract(y_k);
                grad_k.mutableMultiply(2.0 * this.gamma / (dist_ik * (dist_ik * dist_ik + 1.0)));
                grad_i.mutableAdd(grad_k);
                y_k.mutableSubtract(eta, grad_k);
            }
            y_i.mutableAdd(eta, grad_i);
            y_j.mutableAdd(-eta, grad_j);
        }
    }

    /**
     * Draws a negative sample for the positive edge (i, j) from the global
     * negative-sampling CDF, redrawing rejected candidates.
     *
     * @return an index that is neither {@code i} nor {@code j} and is not
     *         excluded by {@link #isExcludedNeighbor}, or -1 if no such index
     *         was drawn within {@link #MAX_NEGATIVE_SAMPLE_ATTEMPTS} tries
     */
    private static int sampleNegative(Random rand, int i, int j, double[] negSampleWeight,
            int[] neighborsOfI, double[] neighborCdfOfI) {
        final int last = negSampleWeight.length - 1;
        for (int attempt = 0; attempt < MAX_NEGATIVE_SAMPLE_ATTEMPTS; ++attempt) {
            int k = Arrays.binarySearch(negSampleWeight, rand.nextDouble());
            if (k < 0) {
                k = -k - 1;
            }
            if (k > last) {
                // defensive: keep k a valid index even if the table does not end at 1.0
                k = last;
            }
            if (k != i && k != j && !isExcludedNeighbor(k, neighborsOfI, neighborCdfOfI)) {
                return k;
            }
        }
        return -1;
    }

    /**
     * @return {@code true} if {@code k} appears among {@code neighbors} in a
     *         slot whose cumulative sampling value is below 0.98
     */
    private static boolean isExcludedNeighbor(int k, int[] neighbors, double[] neighborCdf) {
        for (int s = 0; s < neighbors.length; ++s) {
            if (neighbors[s] == k && neighborCdf[s] < 0.98) {
                return true;
            }
        }
        return false;
    }
}

