/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.utils.FakeExecutor;
import jsat.utils.IndexTable;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Static helpers for choosing initial seed points (starting cluster centres) from a
 * {@link DataSet}. The strategy is chosen with a {@link SeedSelection} value.
 */
public class SeedSelectionMethods {
    private SeedSelectionMethods() {
    }

    /**
     * Selects {@code k} seeds and returns copies of their numerical vectors.
     *
     * @param d the data set to select from
     * @param k the number of seeds to select
     * @param dm the distance metric used by distance-based strategies
     * @param rand the source of randomness
     * @param selectionMethod the seed selection strategy
     * @return a list of {@code k} cloned vectors, in the order the seeds were chosen
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, Random rand, SeedSelection selectionMethod) {
        return SeedSelectionMethods.selectIntialPoints(d, k, dm, null, rand, selectionMethod);
    }

    /**
     * Selects {@code k} seeds serially and returns copies of their numerical vectors.
     *
     * @param d the data set to select from
     * @param k the number of seeds to select
     * @param dm the distance metric used by distance-based strategies
     * @param accelCache acceleration cache handed to {@code dm}; may be {@code null}
     * @param rand the source of randomness
     * @param selectionMethod the seed selection strategy
     * @return a list of {@code k} cloned vectors, in the order the seeds were chosen
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod) {
        return SeedSelectionMethods.selectIntialPoints(d, k, dm, accelCache, rand, selectionMethod, null);
    }

    /**
     * Selects {@code k} seeds, possibly in parallel, and returns copies of their numerical vectors.
     *
     * @param d the data set to select from
     * @param k the number of seeds to select
     * @param dm the distance metric used by distance-based strategies
     * @param rand the source of randomness
     * @param selectionMethod the seed selection strategy
     * @param threadpool the executor for parallel work; {@code null} means run serially
     * @return a list of {@code k} cloned vectors, in the order the seeds were chosen
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, Random rand, SeedSelection selectionMethod, ExecutorService threadpool) {
        return SeedSelectionMethods.selectIntialPoints(d, k, dm, null, rand, selectionMethod, threadpool);
    }

    /**
     * Selects {@code k} seeds, possibly in parallel, and returns copies of their numerical vectors.
     *
     * @param d the data set to select from
     * @param k the number of seeds to select
     * @param dm the distance metric used by distance-based strategies
     * @param accelCache acceleration cache handed to {@code dm}; may be {@code null}
     * @param rand the source of randomness
     * @param selectionMethod the seed selection strategy
     * @param threadpool the executor for parallel work; {@code null} means run serially
     * @return a list of {@code k} cloned vectors, in the order the seeds were chosen
     */
    public static List<Vec> selectIntialPoints(DataSet d, int k, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod, ExecutorService threadpool) {
        int[] indices = new int[k];
        SeedSelectionMethods.selectIntialPoints(d, indices, dm, accelCache, rand, selectionMethod, threadpool);
        return copyPoints(d, indices);
    }

    /**
     * Fills {@code indices} with the data-point indices of the chosen seeds, serially.
     *
     * @param d the data set to select from
     * @param indices output array; its length is the number of seeds to select
     * @param dm the distance metric used by distance-based strategies
     * @param rand the source of randomness
     * @param selectionMethod the seed selection strategy
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, Random rand, SeedSelection selectionMethod) {
        SeedSelectionMethods.selectIntialPoints(d, indices, dm, null, rand, selectionMethod);
    }

    /**
     * Fills {@code indices} with the data-point indices of the chosen seeds, serially.
     *
     * @param d the data set to select from
     * @param indices output array; its length is the number of seeds to select
     * @param dm the distance metric used by distance-based strategies
     * @param accelCache acceleration cache handed to {@code dm}; may be {@code null}
     * @param rand the source of randomness
     * @param selectionMethod the seed selection strategy
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod) {
        SeedSelectionMethods.selectIntialPoints(d, indices, dm, accelCache, rand, selectionMethod, null);
    }

    /**
     * Fills {@code indices} with the data-point indices of the chosen seeds, possibly in parallel.
     *
     * @param d the data set to select from
     * @param indices output array; its length is the number of seeds to select
     * @param dm the distance metric used by distance-based strategies
     * @param rand the source of randomness
     * @param selectionMethod the seed selection strategy
     * @param threadpool the executor for parallel work; {@code null} means run serially
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, Random rand, SeedSelection selectionMethod, ExecutorService threadpool) {
        SeedSelectionMethods.selectIntialPoints(d, indices, dm, null, rand, selectionMethod, threadpool);
    }

    /**
     * Fills {@code indices} with the data-point indices of the chosen seeds.
     *
     * <p>If the work is interrupted or a worker task fails, the error is logged at
     * {@link Level#SEVERE} and {@code indices} may be only partially filled. On
     * interruption, the calling thread's interrupt flag is restored.
     *
     * @param d the data set to select from
     * @param indices output array; its length is the number of seeds to select
     * @param dm the distance metric used by distance-based strategies
     * @param accelCache acceleration cache handed to {@code dm}; may be {@code null}
     * @param rand the source of randomness
     * @param selectionMethod the seed selection strategy
     * @param threadpool the executor for parallel work; {@code null} means run serially
     * @throws IllegalArgumentException if {@link SeedSelection#RANDOM} or {@link SeedSelection#KPP}
     *         is asked for more seeds than the data set has points (those strategies need
     *         distinct indices and would otherwise never terminate)
     */
    public static void selectIntialPoints(DataSet d, int[] indices, DistanceMetric dm, List<Double> accelCache, Random rand, SeedSelection selectionMethod, ExecutorService threadpool) {
        try {
            int k = indices.length;
            if (selectionMethod == SeedSelection.RANDOM) {
                SeedSelectionMethods.randomSelection(indices, rand, d, k);
            } else if (selectionMethod == SeedSelection.KPP) {
                // The serial version avoids scheduling overhead when no real pool is available.
                if (threadpool == null || threadpool instanceof FakeExecutor) {
                    SeedSelectionMethods.kppSelection(indices, rand, d, k, dm, accelCache);
                } else {
                    SeedSelectionMethods.kppSelection(indices, rand, d, k, dm, accelCache, threadpool);
                }
            } else if (selectionMethod == SeedSelection.FARTHEST_FIRST) {
                SeedSelectionMethods.ffSelection(indices, rand, d, k, dm, accelCache, orSerial(threadpool));
            } else if (selectionMethod == SeedSelection.MEAN_QUANTILES) {
                SeedSelectionMethods.mqSelection(indices, d, k, dm, accelCache, orSerial(threadpool));
            }
        }
        catch (InterruptedException ex) {
            // Restore the interrupt status so callers can still observe the interruption.
            Thread.currentThread().interrupt();
            Logger.getLogger(SeedSelectionMethods.class.getName()).log(Level.SEVERE, null, ex);
        }
        catch (ExecutionException ex) {
            Logger.getLogger(SeedSelectionMethods.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    /** Returns cloned numerical vectors of the data points at {@code indices}, in order. */
    private static List<Vec> copyPoints(DataSet d, int[] indices) {
        List<Vec> vecs = new ArrayList<Vec>(indices.length);
        for (int index : indices) {
            vecs.add(d.getDataPoint(index).getNumericalValues().clone());
        }
        return vecs;
    }

    /** Returns {@code threadpool}, or a {@link FakeExecutor} if it is {@code null}. */
    private static ExecutorService orSerial(ExecutorService threadpool) {
        return threadpool == null ? new FakeExecutor() : threadpool;
    }

    /** Fails fast when {@code k} distinct indices cannot be drawn from {@code n} points. */
    private static void requireEnoughPoints(int k, int n, SeedSelection method) {
        if (k > n) {
            throw new IllegalArgumentException("Cannot select " + k + " distinct seeds with " + method
                    + " from a data set with only " + n + " points");
        }
    }

    /**
     * Splits {@code [0, n)} into at most {@code workers} contiguous, non-empty ranges
     * {@code {from, to}}. The first {@code n % workers} ranges get one extra element.
     */
    private static List<int[]> partition(int n, int workers) {
        List<int[]> ranges = new ArrayList<int[]>(workers);
        int blockSize = n / workers;
        int extra = n % workers;
        int pos = 0;
        while (pos < n) {
            int from = pos;
            pos = Math.min(pos + blockSize + (extra-- > 0 ? 1 : 0), n);
            ranges.add(new int[]{from, pos});
        }
        return ranges;
    }

    /** Fills {@code indices} with {@code k} distinct indices drawn uniformly at random. */
    private static void randomSelection(int[] indices, Random rand, DataSet d, int k) {
        int n = d.getSampleSize();
        requireEnoughPoints(k, n, SeedSelection.RANDOM);
        IntSet chosen = new IntSet(k);
        while (chosen.size() != k) {
            chosen.add(rand.nextInt(n));
        }
        int pos = 0;
        for (int idx : chosen) {
            indices[pos++] = idx;
        }
    }

    /**
     * Keeps the first {@code chosenCount} entries of {@code indices} and tops the set up to
     * {@code k} distinct indices drawn uniformly from {@code [0, n)}. The result is then
     * written back in the set's iteration order. Used when all remaining squared distances
     * are (near) zero, so distance-weighted sampling is no longer meaningful.
     */
    private static void fillRemainingRandomly(int[] indices, int chosenCount, int k, Random rand, int n) {
        IntSet ind = new IntSet();
        for (int i = 0; i < chosenCount; ++i) {
            ind.add(indices[i]);
        }
        while (ind.size() < k) {
            ind.add(rand.nextInt(n));
        }
        int pos = 0;
        for (int idx : ind) {
            indices[pos++] = idx;
        }
    }

    /**
     * Draws an index with probability proportional to {@code weights[i]}, given that
     * {@code total} is the (tracked) sum of the weights. If the running sum never reaches
     * the target, the last index is returned.
     */
    private static int sampleProportionally(double[] weights, double total, Random rand) {
        double target = rand.nextDouble() * total;
        int i = 0;
        double running = weights[0];
        while (running < target && i < weights.length - 1) {
            running += weights[++i];
        }
        return i;
    }

    /**
     * Serial k-means++ seeding. The first seed is uniform at random; each later seed is drawn
     * with probability proportional to its squared distance to the closest seed already chosen.
     */
    private static void kppSelection(int[] indices, Random rand, DataSet d, int k, DistanceMetric dm, List<Double> accelCache) {
        int n = d.getSampleSize();
        requireEnoughPoints(k, n, SeedSelection.KPP);
        if (k == 0) {
            return;
        }
        indices[0] = rand.nextInt(n);
        // Squared distance from each point to its closest chosen seed.
        double[] closestDist = new double[n];
        double sqrdDistSum = 0.0;
        List<Vec> vecs = d.getDataVectors();
        for (int j = 1; j < k; ++j) {
            int newMeanIndx = indices[j - 1];
            // closestDist starts at 0, so the first pass must overwrite every entry.
            boolean forceCompute = j == 1;
            for (int i = 0; i < n; ++i) {
                double newDist = dm.dist(newMeanIndx, i, vecs, accelCache);
                newDist *= newDist;
                if (newDist < closestDist[i] || forceCompute) {
                    sqrdDistSum -= closestDist[i];
                    sqrdDistSum += newDist;
                    closestDist[i] = newDist;
                }
            }
            if (sqrdDistSum <= 1.0E-6) {
                fillRemainingRandomly(indices, j, k, rand, n);
                return;
            }
            indices[j] = sampleProportionally(closestDist, sqrdDistSum, rand);
        }
    }

    /**
     * Parallel k-means++ seeding. It matches the serial version, except that the distance
     * update for each new seed is split across {@code SystemInfo.LogicalCores} tasks. Each
     * task writes a disjoint range of {@code closestDist} and reports how its part of the
     * sum changed.
     */
    private static void kppSelection(int[] indices, Random rand, DataSet d, int k, final DistanceMetric dm, final List<Double> accelCache, ExecutorService threadpool) throws InterruptedException, ExecutionException {
        final int n = d.getSampleSize();
        requireEnoughPoints(k, n, SeedSelection.KPP);
        if (k == 0) {
            return;
        }
        indices[0] = rand.nextInt(n);
        final double[] closestDist = new double[n];
        double sqrdDistSum = 0.0;
        final List<Vec> X = d.getDataVectors();
        final int workers = SystemInfo.LogicalCores;
        List<Future<Double>> futureChanges = new ArrayList<Future<Double>>(workers);
        for (int j = 1; j < k; ++j) {
            final int newMeanIndx = indices[j - 1];
            final boolean forceCompute = j == 1;
            futureChanges.clear();
            for (int id = 0; id < workers; ++id) {
                final int from = ParallelUtils.getStartBlock(X.size(), id, workers);
                final int to = ParallelUtils.getEndBlock(X.size(), id, workers);
                futureChanges.add(threadpool.submit(new Callable<Double>() {

                    @Override
                    public Double call() throws Exception {
                        double sqrdDistChanges = 0.0;
                        for (int i = from; i < to; ++i) {
                            double newDist = dm.dist(newMeanIndx, i, X, accelCache);
                            newDist *= newDist;
                            if (newDist < closestDist[i] || forceCompute) {
                                sqrdDistChanges -= closestDist[i];
                                sqrdDistChanges += newDist;
                                closestDist[i] = newDist;
                            }
                        }
                        return sqrdDistChanges;
                    }
                }));
            }
            // Future.get() also makes each task's writes to closestDist visible here.
            for (Future<Double> change : futureChanges) {
                sqrdDistSum += change.get();
            }
            if (sqrdDistSum <= 1.0E-6) {
                fillRemainingRandomly(indices, j, k, rand, n);
                return;
            }
            indices[j] = sampleProportionally(closestDist, sqrdDistSum, rand);
        }
    }

    /**
     * Farthest-first seeding. The first seed is uniform at random; each later seed is the
     * point whose distance to its closest chosen seed is largest.
     */
    private static void ffSelection(int[] indices, Random rand, DataSet d, int k, final DistanceMetric dm, final List<Double> accelCache, ExecutorService threadpool) throws InterruptedException, ExecutionException {
        final int n = d.getSampleSize();
        if (k == 0) {
            return;
        }
        indices[0] = rand.nextInt(n);
        final double[] closestDist = new double[n];
        Arrays.fill(closestDist, Double.POSITIVE_INFINITY);
        final List<Vec> X = d.getDataVectors();
        List<int[]> blocks = partition(n, SystemInfo.LogicalCores);
        List<Future<Integer>> futures = new ArrayList<Future<Integer>>(blocks.size());
        for (int j = 1; j < k; ++j) {
            final int newMeanIndx = indices[j - 1];
            futures.clear();
            for (int[] block : blocks) {
                final int from = block[0];
                final int to = block[1];
                futures.add(threadpool.submit(new Callable<Integer>() {

                    @Override
                    public Integer call() throws Exception {
                        // Update this range's closest distances and return its farthest point.
                        double maxDist = Double.NEGATIVE_INFINITY;
                        int max = -1;
                        for (int i = from; i < to; ++i) {
                            double newDist = dm.dist(newMeanIndx, i, X, accelCache);
                            closestDist[i] = Math.min(newDist, closestDist[i]);
                            if (closestDist[i] > maxDist) {
                                maxDist = closestDist[i];
                                max = i;
                            }
                        }
                        return max;
                    }
                }));
            }
            int max = -1;
            double maxDist = Double.NEGATIVE_INFINITY;
            for (Future<Integer> future : futures) {
                int localMax = future.get();
                if (closestDist[localMax] > maxDist) {
                    max = localMax;
                    maxDist = closestDist[localMax];
                }
            }
            indices[j] = max;
        }
    }

    /**
     * Mean-quantile seeding. Points are ranked by their distance to the data set's mean
     * vector, and seed {@code l} is the point at rank {@code l * n / k}.
     */
    private static void mqSelection(int[] indices, DataSet d, int k, final DistanceMetric dm, final List<Double> accelCache, ExecutorService threadpool) throws InterruptedException, ExecutionException {
        final int n = d.getSampleSize();
        final double[] meanDist = new double[n];
        final Vec newMean = MatrixStatistics.meanVector(d);
        final List<Double> meanQI = dm.getQueryInfo(newMean);
        final List<Vec> X = d.getDataVectors();
        // One future per submitted block, so waiting can never block on tasks that were not
        // submitted. Failures in a task also surface here as ExecutionException.
        List<int[]> blocks = partition(n, SystemInfo.LogicalCores);
        List<Future<Void>> futures = new ArrayList<Future<Void>>(blocks.size());
        for (int[] block : blocks) {
            final int from = block[0];
            final int to = block[1];
            futures.add(threadpool.submit(new Callable<Void>() {

                @Override
                public Void call() throws Exception {
                    for (int i = from; i < to; ++i) {
                        meanDist[i] = dm.dist(i, newMean, meanQI, X, accelCache);
                    }
                    return null;
                }
            }));
        }
        for (Future<Void> future : futures) {
            future.get();
        }
        IndexTable indxTbl = new IndexTable(meanDist);
        for (int l = 0; l < k; ++l) {
            // Compute in long: l * n can overflow int on large data sets.
            indices[l] = indxTbl.index((int) ((long) l * n / k));
        }
    }

    /** The available seed selection strategies. */
    public static enum SeedSelection {
        /** Distinct indices chosen uniformly at random. */
        RANDOM,
        /** k-means++: each seed sampled proportionally to its squared distance to the nearest chosen seed. */
        KPP,
        /** Each seed is the point farthest from its nearest already-chosen seed. */
        FARTHEST_FIRST,
        /** Seeds taken at evenly spaced ranks of the points' distances to the mean vector. */
        MEAN_QUANTILES;

    }
}

