/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.visualization;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.datatransform.visualization.VisualizationTransform;
import jsat.linear.DenseMatrix;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.math.OnLineStatistics;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.random.XORWOW;

/**
 * Metric Multidimensional Scaling (MDS) fitted by iterative stress majorization.
 * <p>
 * Starting from a configuration drawn uniformly from {@code [0, 1)} in every
 * coordinate, each iteration applies the update {@code X <- (1/N) B(X) X}. For
 * {@code i != j}, {@code B_ij = -delta_ij / d_ij(X)}, or 0 when the embedded
 * distance is at most 1e-5. Each diagonal entry is {@code B_ii = -sum_{k != i} B_ik}.
 * Iteration stops after {@link #setMaxIterations(int) maxIterations} updates or
 * once the absolute change in raw stress,
 * {@code sum_{i<j} (d_ij(X) - delta_ij)^2}, is no greater than the
 * {@link #setTolerance(double) tolerance}.
 * <p>
 * The random number generator used for initialization is created without an
 * explicit seed.
 */
public class MDS
implements VisualizationTransform {
    /** Metric for distances between points in the embedded (low-dimensional) space. */
    private DistanceMetric embedMetric = new EuclideanDistance();
    /** Metric for distances between the original data points. */
    private DistanceMetric dm = new EuclideanDistance();
    /** Fitting stops once the absolute change in stress between iterations is at most this. */
    private double tolerance = 0.001;
    /** Upper bound on the number of majorization updates. */
    private int maxIterations = 300;
    /** Dimensionality of the produced embedding. */
    private int targetSize = 2;

    /**
     * Sets the convergence tolerance on the absolute change in stress.
     *
     * @param tolerance a finite, non-negative value
     * @throws IllegalArgumentException if {@code tolerance} is negative, infinite or NaN
     */
    public void setTolerance(double tolerance) {
        if (tolerance < 0.0 || Double.isInfinite(tolerance) || Double.isNaN(tolerance)) {
            throw new IllegalArgumentException("tolerance must be a non-negative value, not " + tolerance);
        }
        this.tolerance = tolerance;
    }

    /**
     * @return the convergence tolerance on the absolute change in stress
     */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * Sets the maximum number of majorization updates performed per fit.
     *
     * @param maxIterations a positive number of iterations
     * @throws IllegalArgumentException if {@code maxIterations} is less than one
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("maxIterations must be a positive value, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * @return the maximum number of majorization updates performed per fit
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the metric used to measure distances inside the embedding. The
     * setting applies to this instance only.
     *
     * @param embedMetric the metric to use in the embedded space
     * @throws NullPointerException if {@code embedMetric} is null
     */
    public void setEmbeddingMetric(DistanceMetric embedMetric) {
        if (embedMetric == null) {
            throw new NullPointerException("embedMetric must not be null");
        }
        this.embedMetric = embedMetric;
    }

    /**
     * @return the metric used to measure distances inside the embedding
     */
    public DistanceMetric getEmbeddingMetric() {
        return this.embedMetric;
    }

    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> d) {
        return this.transform(d, (ExecutorService)new FakeExecutor());
    }

    /**
     * Embeds a data set. Pairwise distances between its numeric vectors are
     * computed with the original-space metric. The result is a shallow clone of
     * {@code d} whose numeric features are replaced by the embedding coordinates.
     *
     * @param d the data set to embed
     * @param ex executor used to parallelize the distance computations
     * @return a shallow clone of {@code d} holding the embedded coordinates
     * @throws RuntimeException if a worker task fails or the calling thread is interrupted
     */
    @Override
    public <Type extends DataSet> Type transform(final DataSet<Type> d, ExecutorService ex) {
        final List<Vec> orig_vecs = d.getDataVectors();
        final DistanceMetric metric = this.dm;
        final List<Double> orig_distCache = metric.getAccelerationCache(orig_vecs, ex);
        final int N = orig_vecs.size();
        final int workers = SystemInfo.LogicalCores;
        final DenseMatrix delta = new DenseMatrix(N, N);
        List<Future<Void>> pending = new ArrayList<Future<Void>>(workers);
        for (int id = 0; id < workers; ++id) {
            final int start = id;
            pending.add(ex.submit(new Callable<Void>() {

                @Override
                public Void call() {
                    // rows are dealt round-robin; each worker fills (i, j) and (j, i) for its own i < j
                    for (int i = start; i < N; i += workers) {
                        for (int j = i + 1; j < N; ++j) {
                            double dist = metric.dist(i, j, orig_vecs, orig_distCache);
                            delta.set(i, j, dist);
                            delta.set(j, i, dist);
                        }
                    }
                    return null;
                }
            }));
        }
        // every task must finish (or fail loudly) before delta is used
        for (Future<Void> f : pending) {
            getResult(f);
        }
        SimpleDataSet embedded = this.transform(delta, ex);
        DataSet<Type> dataSet = d.shallowClone();
        dataSet.replaceNumericFeatures(embedded.getDataVectors());
        return (Type)dataSet;
    }

    /**
     * Single-threaded convenience overload of {@link #transform(Matrix, ExecutorService)}.
     *
     * @param delta square matrix of pairwise dissimilarities
     * @return the embedded points, one per row of {@code delta}
     */
    public SimpleDataSet transform(Matrix delta) {
        return this.transform(delta, (ExecutorService)new FakeExecutor());
    }

    /**
     * Embeds points given only their pairwise dissimilarities.
     *
     * @param delta square matrix of dissimilarities. Only entries above the
     *              diagonal ({@code delta.get(i, j)} with {@code j > i}) are read.
     * @param ex executor used to parallelize the distance and stress computations
     * @return a data set with one {@code targetSize}-dimensional point per row
     *         of {@code delta}, in row order
     * @throws IllegalArgumentException if {@code delta} is not square
     * @throws RuntimeException if a worker task fails or the calling thread is interrupted
     */
    public SimpleDataSet transform(final Matrix delta, ExecutorService ex) {
        if (delta.rows() != delta.cols()) {
            throw new IllegalArgumentException("delta must be a square matrix, not " + delta.rows() + "x" + delta.cols());
        }
        final int N = delta.rows();
        // read the metric once so the whole fit uses a single metric
        final DistanceMetric metric = this.embedMetric;
        XORWOW rand = new XORWOW();
        DenseMatrix X = new DenseMatrix(N, this.targetSize);
        // row views alias X, so overwriting X in place also updates these vectors
        final List<Vec> X_views = new ArrayList<Vec>(N);
        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < this.targetSize; ++j) {
                X.set(i, j, rand.nextDouble());
            }
            X_views.add(X.getRowView(i));
        }
        List<Double> X_rowCache = metric.getAccelerationCache(X_views, ex);
        double oldStress = stress(metric, X_views, X_rowCache, delta, ex);
        double stressChange = Double.POSITIVE_INFINITY;
        DenseMatrix B = new DenseMatrix(N, N);
        DenseMatrix X_new = new DenseMatrix(N, this.targetSize);
        for (int iter = 0; iter < this.maxIterations && stressChange > this.tolerance; ++iter) {
            fillB(metric, B, X_views, X_rowCache, delta, ex);
            // cleared before every product; multiply(A, C, ex) presumably accumulates into C
            X_new.zeroOut();
            B.multiply(X, X_new, ex);
            X_new.mutableMultiply(1.0 / N);
            X_new.copyTo(X);
            // reassign rather than clear()/addAll(): the cache may be null for metrics
            // that do not support acceleration (NOTE(review): confirm per metric)
            X_rowCache = metric.getAccelerationCache(X_views, ex);
            double newStress = stress(metric, X_views, X_rowCache, delta, ex);
            stressChange = Math.abs(oldStress - newStress);
            oldStress = newStress;
        }
        SimpleDataSet sds = new SimpleDataSet(new CategoricalData[0], this.targetSize);
        for (Vec v : X_views) {
            sds.add(new DataPoint(v));
        }
        return sds;
    }

    /**
     * Fills {@code B} with the majorization matrix for the current embedding.
     * <p>
     * Each off-diagonal entry is {@code -delta_ij / d_ij}. It is set to 0 when the
     * embedded distance {@code d_ij} is at most 1e-5, which avoids dividing by a
     * near-zero distance. Each diagonal entry is the negated sum of its row's
     * off-diagonal entries.
     */
    private static void fillB(final DistanceMetric metric, final Matrix B, final List<Vec> X_views,
            final List<Double> X_rowCache, final Matrix delta, ExecutorService ex) {
        final int N = B.rows();
        final int workers = SystemInfo.LogicalCores;
        List<Future<Void>> pending = new ArrayList<Future<Void>>(workers);
        for (int id = 0; id < workers; ++id) {
            final int start = id;
            pending.add(ex.submit(new Callable<Void>() {

                @Override
                public Void call() {
                    // each worker writes only (i, j) and (j, i) for its own rows i < j, so writes never overlap
                    for (int i = start; i < N; i += workers) {
                        for (int j = i + 1; j < N; ++j) {
                            double d_ij = metric.dist(i, j, X_views, X_rowCache);
                            double b_ij = d_ij > 1.0E-5 ? -delta.get(i, j) / d_ij : 0.0;
                            B.set(i, j, b_ij);
                            B.set(j, i, b_ij);
                        }
                    }
                    return null;
                }
            }));
        }
        for (Future<Void> f : pending) {
            getResult(f);
        }
        for (int i = 0; i < N; ++i) {
            double diag = 0.0;
            for (int k = 0; k < N; ++k) {
                if (k != i) {
                    diag -= B.get(i, k);
                }
            }
            B.set(i, i, diag);
        }
    }

    /**
     * Computes the raw stress {@code sum_{i<j} (d_ij - delta_ij)^2} of the
     * current embedding. Per-worker partial sums are added in submission order,
     * so the result does not depend on task completion order.
     */
    private static double stress(final DistanceMetric metric, final List<Vec> X_views,
            final List<Double> X_rowCache, final Matrix delta, ExecutorService ex) {
        final int N = delta.rows();
        final int workers = SystemInfo.LogicalCores;
        List<Future<Double>> partials = new ArrayList<Future<Double>>(workers);
        for (int id = 0; id < workers; ++id) {
            final int start = id;
            partials.add(ex.submit(new Callable<Double>() {

                @Override
                public Double call() {
                    double localStress = 0.0;
                    for (int i = start; i < N; i += workers) {
                        for (int j = i + 1; j < N; ++j) {
                            double diff = metric.dist(i, j, X_views, X_rowCache) - delta.get(i, j);
                            localStress += diff * diff;
                        }
                    }
                    return localStress;
                }
            }));
        }
        double total = 0.0;
        for (Future<Double> f : partials) {
            total += getResult(f);
        }
        return total;
    }

    /**
     * Waits for a worker task and returns its result. Failures are rethrown
     * rather than letting the fit continue with partially computed matrices.
     *
     * @throws RuntimeException wrapping the task's failure, or the interruption
     *         (after restoring the thread's interrupt status)
     */
    private static <T> T getResult(Future<T> future) {
        try {
            return future.get();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("MDS was interrupted while waiting for a worker task", e);
        }
        catch (ExecutionException e) {
            throw new RuntimeException("MDS worker task failed", e.getCause());
        }
    }

    @Override
    public int getTargetDimension() {
        return this.targetSize;
    }

    /**
     * Sets the dimensionality of the embedding.
     *
     * @param target the desired number of output dimensions
     * @return {@code false} (leaving the setting unchanged) if {@code target < 1}, otherwise {@code true}
     */
    @Override
    public boolean setTargetDimension(int target) {
        if (target < 1) {
            return false;
        }
        this.targetSize = target;
        return true;
    }
}

