/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.ClusterFailureException;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.KMeans;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DenseSparseMetric;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.random.RandomUtil;

public class ElkanKMeans
extends KMeans {
    private static final long serialVersionUID = -1629432283103273051L;
    /**
     * The distance metric viewed as a {@link DenseSparseMetric}. When non-null,
     * the bound-update steps compute point-to-centroid distances through it
     * (using a per-centroid summary constant) instead of through {@code dm}.
     */
    private DenseSparseMetric dmds;
    /**
     * Whether the dense/sparse distance path should be used when the metric is
     * a {@link DenseSparseMetric}. Defaults to {@code false}.
     */
    private boolean useDenseSparse = false;

    public ElkanKMeans(DistanceMetric dm, Random rand, SeedSelectionMethods.SeedSelection seedSelection) {
        super(dm, seedSelection, rand);
        if (!dm.isSubadditive()) {
            throw new ClusterFailureException("KMeans implementation requires the triangle inequality");
        }
    }

    /**
     * Creates a new clusterer using {@code DEFAULT_SEED_SELECTION} for the
     * initial centroids.
     *
     * @param dm the distance metric to use
     * @param rand the source of randomness
     */
    public ElkanKMeans(DistanceMetric dm, Random rand) {
        this(dm, rand, DEFAULT_SEED_SELECTION);
    }

    /**
     * Creates a new clusterer whose random source is obtained from
     * {@code RandomUtil.getRandom()}.
     *
     * @param dm the distance metric to use
     */
    public ElkanKMeans(DistanceMetric dm) {
        this(dm, RandomUtil.getRandom());
    }

    /**
     * Creates a new clusterer that uses a fresh {@link EuclideanDistance}.
     */
    public ElkanKMeans() {
        this(new EuclideanDistance());
    }

    public ElkanKMeans(ElkanKMeans toCopy) {
        super(toCopy);
        if (toCopy.dmds != null) {
            this.dmds = (DenseSparseMetric)toCopy.dmds.clone();
        }
        this.useDenseSparse = toCopy.useDenseSparse;
    }

    /**
     * Sets whether the dense/sparse distance path should be used. The flag
     * only takes effect in {@code cluster} when the distance metric is a
     * {@link DenseSparseMetric}.
     *
     * @param useDenseSparse {@code true} to request the dense/sparse path
     */
    public void setUseDenseSparse(boolean useDenseSparse) {
        this.useDenseSparse = useDenseSparse;
    }

    /**
     * Returns the current value of the dense/sparse flag.
     *
     * @return {@code true} if the dense/sparse distance path has been requested
     */
    public boolean isUseDenseSparse() {
        return this.useDenseSparse;
    }

    /**
     * Runs the bound-accelerated k-means iterations on the given data.
     *
     * @param dataSet the data to cluster
     * @param accelCache a distance acceleration cache for the data points, or
     *        {@code null} to have one computed from the metric
     * @param k the number of clusters
     * @param means the centroids; if its size differs from {@code k} it is
     *        cleared and re-seeded. Sparse entries are replaced by dense copies
     *        and the list is updated in place with the final centroids
     * @param assignment output array receiving the cluster index of every point
     * @param exactTotal when an error is requested, whether it is computed from
     *        freshly evaluated distances ({@code true}) or from the maintained
     *        upper bounds ({@code false})
     * @param threadpool executor used for parallel work, or {@code null} to run
     *        everything on the calling thread
     * @param returnError whether the sum of squared distances is computed
     * @param dataPointWeights per-point weights, or {@code null} to use the
     *        weights stored in the data set
     * @return the sum of squared distances (0 when {@code returnError} is
     *         false), or {@code Double.MAX_VALUE} if any exception occurred
     */
    @Override
    protected double cluster(final DataSet dataSet, List<Double> accelCache, final int k, final List<Vec> means, final int[] assignment, boolean exactTotal, ExecutorService threadpool, boolean returnError, Vec dataPointWeights) {
        try {
            final int N = dataSet.getSampleSize();
            final int D = dataSet.getNumNumericalVars();
            if (N < k) {
                throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
            }
            final Vec W = dataPointWeights == null ? dataSet.getDataWeights() : dataPointWeights;
            TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet);
            final List<Vec> X = dataSet.getDataVectors();
            final boolean noRealPool = threadpool == null || threadpool instanceof FakeExecutor;

            final List<Double> distAccelCache;
            if (accelCache != null) {
                distAccelCache = accelCache;
            } else if (noRealPool) {
                distAccelCache = this.dm.getAccelerationCache(X);
            } else {
                distAccelCache = this.dm.getAccelerationCache(X, threadpool);
            }

            // Seed the centroids unless the caller supplied exactly k of them.
            if (means.size() != k) {
                means.clear();
                if (noRealPool) {
                    means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, this.dm, distAccelCache, this.rand, this.seedSelection));
                } else {
                    means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, this.dm, distAccelCache, this.rand, this.seedSelection, threadpool));
                }
            }
            // Centroids are updated in place, so keep them dense.
            for (int i = 0; i < means.size(); ++i) {
                if (means.get(i).isSparse()) {
                    means.set(i, new DenseVector(means.get(i)));
                }
            }

            final double[][] lowerBound = new double[N][k];
            final double[] upperBound = new double[N];
            final double[][] centroidSelfDistances = new double[k][k];
            final double[] sC = new double[k];
            this.calculateCentroidDistances(k, centroidSelfDistances, means, sC, null, threadpool);

            final AtomicDoubleArray meanCount = new AtomicDoubleArray(k);
            final Vec[] oldMeans = new Vec[k];
            final Vec[] meanSums = new Vec[k];
            final List<List<Double>> meanQIs = new ArrayList<List<Double>>(k);
            final boolean accelerated = this.dm.supportsAcceleration();
            for (int i = 0; i < k; ++i) {
                oldMeans[i] = means.get(i).clone();
                if (accelerated) {
                    meanQIs.add(this.dm.getQueryInfo(means.get(i)));
                } else {
                    meanQIs.add(Collections.<Double>emptyList());
                }
                meanSums[i] = new DenseVector(D);
            }

            // Decide the dense/sparse path afresh on every run so that a value
            // left over from an earlier run (or a copy) cannot outlive the flag.
            this.dmds = (this.useDenseSparse && this.dm instanceof DenseSparseMetric) ? (DenseSparseMetric) this.dm : null;
            final double[] meanSummaryConsts;
            if (this.dmds != null) {
                meanSummaryConsts = new double[means.size()];
                // The first pass below skips the centroid-distance refresh, so
                // the constants must be valid before any distance is requested.
                for (int i = 0; i < meanSummaryConsts.length; ++i) {
                    meanSummaryConsts[i] = this.dmds.getVectorConstant(means.get(i));
                }
            } else {
                meanSummaryConsts = null;
            }

            // Forces at least two passes regardless of whether a change was seen.
            int atLeast = 2;
            final AtomicBoolean changeOccurred = new AtomicBoolean(true);
            final boolean[] r = new boolean[N];
            // Per-thread accumulators for centroid sum changes, merged in step 4.
            final ThreadLocal<Vec[]> localDeltas = new ThreadLocal<Vec[]>(){

                @Override
                protected Vec[] initialValue() {
                    Vec[] toRet = new Vec[k];
                    for (int i = 0; i < toRet.length; ++i) {
                        toRet[i] = new DenseVector(D);
                    }
                    return toRet;
                }
            };

            if (threadpool == null) {
                this.initialClusterSetUp(k, N, X, means, lowerBound, upperBound, centroidSelfDistances, assignment, meanCount, meanSums, distAccelCache, meanQIs, W);
            } else {
                this.initialClusterSetUp(k, N, X, means, lowerBound, upperBound, centroidSelfDistances, assignment, meanCount, meanSums, distAccelCache, meanQIs, localDeltas, threadpool, W);
            }

            int iterLimit = this.MaxIterLimit;
            while ((changeOccurred.get() || atLeast > 0) && iterLimit-- >= 0) {
                --atLeast;
                changeOccurred.set(false);
                // Step 1: centroid-to-centroid distances (already done for the first pass).
                if (iterLimit < this.MaxIterLimit - 1) {
                    this.calculateCentroidDistances(k, centroidSelfDistances, means, sC, meanSummaryConsts, threadpool);
                }
                if (threadpool == null) {
                    for (int q = 0; q < N; ++q) {
                        // Step 2: nothing to do when the upper bound is within half the nearest centroid gap.
                        if (upperBound[q] <= sC[assignment[q]]) {
                            continue;
                        }
                        Vec v = X.get(q);
                        for (int c = 0; c < k; ++c) {
                            // Step 3: only centroids that could still be closer are examined.
                            if (c != assignment[q] && upperBound[q] > lowerBound[q][c] && upperBound[q] > centroidSelfDistances[assignment[q]][c] * 0.5) {
                                this.step3aBoundsUpdate(X, r, q, v, means, assignment, upperBound, lowerBound, meanSummaryConsts, distAccelCache, meanQIs);
                                this.step3bUpdate(X, upperBound, q, lowerBound, c, centroidSelfDistances, assignment, v, means, localDeltas, meanCount, changeOccurred, meanSummaryConsts, distAccelCache, meanQIs, W);
                            }
                        }
                    }
                    this.step4UpdateCentroids(meanSums, localDeltas);
                } else {
                    final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
                    for (int id = 0; id < SystemInfo.LogicalCores; ++id) {
                        final int ID = id;
                        threadpool.submit(new Runnable(){

                            @Override
                            public void run() {
                                // Each task handles the points ID, ID + cores, ID + 2 * cores, ...
                                for (int q = ID; q < N; q += SystemInfo.LogicalCores) {
                                    if (upperBound[q] <= sC[assignment[q]]) {
                                        continue;
                                    }
                                    Vec v = dataSet.getDataPoint(q).getNumericalValues();
                                    for (int c = 0; c < k; ++c) {
                                        if (c != assignment[q] && upperBound[q] > lowerBound[q][c] && upperBound[q] > centroidSelfDistances[assignment[q]][c] * 0.5) {
                                            ElkanKMeans.this.step3aBoundsUpdate(X, r, q, v, means, assignment, upperBound, lowerBound, meanSummaryConsts, distAccelCache, meanQIs);
                                            ElkanKMeans.this.step3bUpdate(X, upperBound, q, lowerBound, c, centroidSelfDistances, assignment, v, means, localDeltas, meanCount, changeOccurred, meanSummaryConsts, distAccelCache, meanQIs, W);
                                        }
                                    }
                                }
                                ElkanKMeans.this.step4UpdateCentroids(meanSums, localDeltas);
                                latch.countDown();
                            }
                        });
                    }
                    try {
                        latch.await();
                    }
                    catch (InterruptedException ex) {
                        // Keep the interrupt visible to whoever owns this thread.
                        Thread.currentThread().interrupt();
                        throw new ClusterFailureException("Clustering failed");
                    }
                }
                this.step5_6_distanceMovedBoundsUpdate(k, oldMeans, means, meanSums, meanCount, N, lowerBound, upperBound, assignment, r, meanQIs, threadpool);
            }

            double totalDistance = 0.0;
            if (returnError) {
                this.nearestCentroidDist = this.saveCentroidDistance ? new double[N] : null;
                if (exactTotal) {
                    for (int i = 0; i < N; ++i) {
                        double dist = this.dm.dist(i, means.get(assignment[i]), meanQIs.get(assignment[i]), X, distAccelCache);
                        totalDistance += Math.pow(dist, 2.0);
                        if (this.saveCentroidDistance) {
                            this.nearestCentroidDist[i] = dist;
                        }
                    }
                } else {
                    for (int i = 0; i < N; ++i) {
                        totalDistance += Math.pow(upperBound[i], 2.0);
                        if (this.saveCentroidDistance) {
                            this.nearestCentroidDist[i] = upperBound[i];
                        }
                    }
                }
            }
            return totalDistance;
        }
        catch (Exception ex) {
            // Any failure (including the checks above) is logged and reported
            // to the caller as the worst possible score.
            Logger.getLogger(ElkanKMeans.class.getName()).log(Level.SEVERE, null, ex);
            return Double.MAX_VALUE;
        }
    }

    /**
     * Single-threaded initial assignment. Every point is assigned to its
     * nearest centroid, its upper bound is set to that distance, and the lower
     * bound of every centroid whose distance was actually computed is set to
     * that distance. Centroids that are at least twice as far from the current
     * best centroid as the point is are skipped without computing a distance.
     *
     * @param k number of centroids
     * @param N number of data points
     * @param dataSet the data points
     * @param means the current centroids
     * @param lowerBound per point, per centroid lower bounds (written)
     * @param upperBound per point upper bound (written)
     * @param centroidSelfDistances centroid-to-centroid distances; only entries
     *        {@code [i][z]} with {@code z > i} are read
     * @param assignment receives the chosen centroid index of every point
     * @param meanCount receives the summed weight of each cluster
     * @param meanSums receives the weighted vector sum of each cluster
     * @param distAccelCache acceleration cache for the data points
     * @param meanQIs query information for each centroid
     * @param W the weight of each data point
     */
    private void initialClusterSetUp(int k, int N, List<Vec> dataSet, List<Vec> means, double[][] lowerBound, double[] upperBound, double[][] centroidSelfDistances, int[] assignment, AtomicDoubleArray meanCount, Vec[] meanSums, List<Double> distAccelCache, List<List<Double>> meanQIs, Vec W) {
        final boolean[] pruned = new boolean[k];
        for (int point = 0; point < N; ++point) {
            Arrays.fill(pruned, false);
            int nearest = -1;
            double nearestDist = Double.MAX_VALUE;
            for (int c = 0; c < k; ++c) {
                if (!pruned[c]) {
                    final double dist = this.dm.dist(point, means.get(c), meanQIs.get(c), dataSet, distAccelCache);
                    lowerBound[point][c] = dist;
                    if (dist < nearestDist) {
                        nearestDist = dist;
                        upperBound[point] = dist;
                        nearest = c;
                        // Any later centroid at least 2 * dist away from this one cannot be closer.
                        final double cutoff = 2.0 * dist;
                        for (int later = c + 1; later < k; ++later) {
                            if (centroidSelfDistances[c][later] >= cutoff) {
                                pruned[later] = true;
                            }
                        }
                    }
                }
            }
            assignment[point] = nearest;
            final double pointWeight = W.get(point);
            meanCount.addAndGet(nearest, pointWeight);
            meanSums[nearest].mutableAdd(pointWeight, dataSet.get(point));
        }
    }

    /**
     * Multi-threaded initial assignment. The points are split into contiguous
     * blocks, one task per block; each task assigns its points to the nearest
     * centroid, fills in the bounds, accumulates weighted sums in its
     * thread-local delta vectors and finally merges them into {@code meanSums}
     * under a per-centroid lock.
     *
     * @param k number of centroids
     * @param N number of data points
     * @param dataSet the data points
     * @param means the current centroids
     * @param lowerBound per point, per centroid lower bounds (written)
     * @param upperBound per point upper bound (written)
     * @param centroidSelfDistances centroid-to-centroid distances; only entries
     *        {@code [i][z]} with {@code z > i} are read
     * @param assignment receives the chosen centroid index of every point
     * @param meanCount receives the summed weight of each cluster
     * @param meanSums receives the weighted vector sum of each cluster
     * @param distAccelCache acceleration cache for the data points
     * @param meanQIs query information for each centroid
     * @param localDeltas per-thread accumulators; left zeroed by each task
     * @param threadpool executor the block tasks are submitted to
     * @param W the weight of each data point
     * @throws ClusterFailureException if the calling thread is interrupted
     *         while waiting for the tasks; the interrupt status is restored
     */
    private void initialClusterSetUp(final int k, int N, final List<Vec> dataSet, final List<Vec> means, final double[][] lowerBound, final double[] upperBound, final double[][] centroidSelfDistances, final int[] assignment, final AtomicDoubleArray meanCount, final Vec[] meanSums, final List<Double> distAccelCache, final List<List<Double>> meanQIs, final ThreadLocal<Vec[]> localDeltas, ExecutorService threadpool, final Vec W) {
        final int blockSize = N / SystemInfo.LogicalCores;
        // The first (N % cores) blocks take one extra point each.
        int extra = N % SystemInfo.LogicalCores;
        int pos = 0;
        final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
        while (pos < N) {
            final int from = pos;
            final int to = pos + blockSize + (extra-- > 0 ? 1 : 0);
            pos = to;
            threadpool.submit(new Runnable(){

                @Override
                public void run() {
                    Vec[] deltas = localDeltas.get();
                    boolean[] skip = new boolean[k];
                    for (int q = from; q < to; ++q) {
                        Vec v = dataSet.get(q);
                        double minDistance = Double.MAX_VALUE;
                        int index = -1;
                        Arrays.fill(skip, false);
                        for (int i = 0; i < k; ++i) {
                            if (skip[i]) {
                                continue;
                            }
                            double d = ElkanKMeans.this.dm.dist(q, means.get(i), meanQIs.get(i), dataSet, distAccelCache);
                            lowerBound[q][i] = d;
                            if (d < minDistance) {
                                minDistance = upperBound[q] = d;
                                index = i;
                                // Later centroids at least 2 * d away from this one cannot be closer.
                                for (int z = i + 1; z < k; ++z) {
                                    if (centroidSelfDistances[i][z] >= 2.0 * d) {
                                        skip[z] = true;
                                    }
                                }
                            }
                        }
                        assignment[q] = index;
                        double weight = W.get(q);
                        meanCount.addAndGet(index, weight);
                        deltas[index].mutableAdd(weight, v);
                    }
                    // Merge this thread's partial sums into the shared sums.
                    for (int i = 0; i < deltas.length; ++i) {
                        synchronized (meanSums[i]) {
                            meanSums[i].mutableAdd(deltas[i]);
                        }
                        deltas[i].zeroOut();
                    }
                    latch.countDown();
                }
            });
        }
        // With fewer points than cores fewer tasks were submitted; release the unused counts.
        while (pos++ < SystemInfo.LogicalCores) {
            latch.countDown();
        }
        try {
            latch.await();
        }
        catch (InterruptedException ex) {
            Logger.getLogger(ElkanKMeans.class.getName()).log(Level.SEVERE, null, ex);
            // Continuing would use bounds and assignments that workers may
            // still be writing; restore the flag and abort instead.
            Thread.currentThread().interrupt();
            throw new ClusterFailureException("Clustering failed");
        }
    }

    /**
     * Folds the calling thread's accumulated centroid deltas into the shared
     * cluster sums and clears them. Deltas with no non-zero entries are left
     * untouched. Each shared sum is updated while holding its own monitor.
     *
     * @param meanSums the shared per-cluster vector sums
     * @param localDeltas the per-thread pending changes to those sums
     */
    private void step4UpdateCentroids(Vec[] meanSums, ThreadLocal<Vec[]> localDeltas) {
        final Vec[] pending = localDeltas.get();
        for (int c = 0; c < pending.length; ++c) {
            final Vec delta = pending[c];
            if (delta.nnz() != 0) {
                final Vec target = meanSums[c];
                synchronized (target) {
                    target.mutableAdd(delta);
                }
                delta.zeroOut();
            }
        }
    }

    /**
     * Recomputes every centroid from its cluster sum and weight, measures how
     * far each centroid moved, and relaxes the bounds accordingly: each lower
     * bound is reduced by the movement of its centroid (floored at zero), each
     * upper bound is increased by the movement of the point's assigned
     * centroid, and every point is flagged in {@code r}.
     * A cluster whose weight is at most {@code 1.0E-14} gets a zero centroid.
     *
     * @param k number of centroids
     * @param oldMeans receives a copy of the centroids before the update
     * @param means the centroids, updated in place
     * @param meanSums the weighted vector sum of each cluster
     * @param meanCount the summed weight of each cluster
     * @param N number of data points
     * @param lowerBound per point, per centroid lower bounds (updated)
     * @param upperBound per point upper bound (updated)
     * @param assignment the current centroid index of every point
     * @param r per point flag, set to {@code true} for every point
     * @param meanQIs query information per centroid, refreshed when the metric
     *        supports acceleration
     * @param threadpool executor for parallel work, or {@code null} to run on
     *        the calling thread
     * @throws ClusterFailureException if the calling thread is interrupted
     *         while waiting for the tasks; the interrupt status is restored
     */
    private void step5_6_distanceMovedBoundsUpdate(int k, final Vec[] oldMeans, final List<Vec> means, final Vec[] meanSums, final AtomicDoubleArray meanCount, final int N, final double[][] lowerBound, final double[] upperBound, final int[] assignment, final boolean[] r, final List<List<Double>> meanQIs, ExecutorService threadpool) {
        final double[] distancesMoved = new double[k];
        if (threadpool == null) {
            for (int i = 0; i < k; ++i) {
                means.get(i).copyTo(oldMeans[i]);
            }
            for (int i = 0; i < k; ++i) {
                meanSums[i].copyTo(means.get(i));
                double count = meanCount.get(i);
                if (count <= 1.0E-14) {
                    means.get(i).zeroOut();
                } else {
                    means.get(i).mutableDivide(count);
                }
            }
            for (int i = 0; i < k; ++i) {
                distancesMoved[i] = this.dm.dist(oldMeans[i], means.get(i));
                if (this.dm.supportsAcceleration()) {
                    meanQIs.set(i, this.dm.getQueryInfo(means.get(i)));
                }
            }
            for (int c = 0; c < k; ++c) {
                for (int q = 0; q < N; ++q) {
                    lowerBound[q][c] = Math.max(lowerBound[q][c] - distancesMoved[c], 0.0);
                }
            }
            for (int q = 0; q < N; ++q) {
                upperBound[q] += distancesMoved[assignment[q]];
                r[q] = true;
            }
            return;
        }
        try {
            // Step 5: one task per centroid updates the centroid and its lower-bound column.
            final CountDownLatch latch1 = new CountDownLatch(k);
            for (int i = 0; i < k; ++i) {
                final int c = i;
                threadpool.submit(new Runnable(){

                    @Override
                    public void run() {
                        means.get(c).copyTo(oldMeans[c]);
                        meanSums[c].copyTo(means.get(c));
                        double count = meanCount.get(c);
                        if (count <= 1.0E-14) {
                            means.get(c).zeroOut();
                        } else {
                            means.get(c).mutableDivide(meanCount.get(c));
                        }
                        distancesMoved[c] = ElkanKMeans.this.dm.dist(oldMeans[c], means.get(c));
                        if (ElkanKMeans.this.dm.supportsAcceleration()) {
                            meanQIs.set(c, ElkanKMeans.this.dm.getQueryInfo(means.get(c)));
                        }
                        for (int q = 0; q < N; ++q) {
                            lowerBound[q][c] = Math.max(lowerBound[q][c] - distancesMoved[c], 0.0);
                        }
                        latch1.countDown();
                    }
                });
            }
            latch1.await();
            // Step 6: upper bounds need every distancesMoved entry, hence the second phase.
            final CountDownLatch latch2 = new CountDownLatch(SystemInfo.LogicalCores);
            final int blockSize = N / SystemInfo.LogicalCores;
            for (int id = 0; id < SystemInfo.LogicalCores; ++id) {
                final int start = id * blockSize;
                // The last block also takes the remainder.
                final int end = id == SystemInfo.LogicalCores - 1 ? N : start + blockSize;
                threadpool.submit(new Runnable(){

                    @Override
                    public void run() {
                        for (int q = start; q < end; ++q) {
                            upperBound[q] += distancesMoved[assignment[q]];
                            r[q] = true;
                        }
                        latch2.countDown();
                    }
                });
            }
            latch2.await();
        }
        catch (InterruptedException ex) {
            Logger.getLogger(ElkanKMeans.class.getName()).log(Level.SEVERE, null, ex);
            // Submitted tasks may still be mutating the centroids and bounds, so
            // re-running the update serially here would apply it twice and race
            // with them. Restore the flag and abort instead.
            Thread.currentThread().interrupt();
            throw new ClusterFailureException("Clustering failed");
        }
    }

    /**
     * If point {@code q} is flagged in {@code r}, clears the flag and replaces
     * both its upper bound and the lower bound to its assigned centroid with a
     * freshly computed distance. Does nothing when the flag is not set.
     *
     * @param X the data points
     * @param r per point flags
     * @param q index of the point
     * @param v the vector of point {@code q}; only used on the dense/sparse path
     * @param means the current centroids
     * @param assignment the current centroid index of every point
     * @param upperBound per point upper bound (written)
     * @param lowerBound per point, per centroid lower bounds (written)
     * @param meanSummaryConsts per centroid constants for the dense/sparse
     *        path; only read when that path is active
     * @param distAccelCache acceleration cache for the data points
     * @param meanQIs query information for each centroid
     */
    private void step3aBoundsUpdate(List<Vec> X, boolean[] r, int q, Vec v, List<Vec> means, int[] assignment, double[] upperBound, double[][] lowerBound, double[] meanSummaryConsts, List<Double> distAccelCache, List<List<Double>> meanQIs) {
        if (!r[q]) {
            return;
        }
        r[q] = false;
        final int owner = assignment[q];
        final double exact;
        if (this.dmds != null) {
            exact = this.dmds.dist(meanSummaryConsts[owner], means.get(owner), v);
        } else {
            exact = this.dm.dist(q, means.get(owner), meanQIs.get(owner), X, distAccelCache);
        }
        lowerBound[q][owner] = exact;
        upperBound[q] = exact;
    }

    /**
     * Considers moving point {@code q} to centroid {@code c}. When the upper
     * bound exceeds either the lower bound to {@code c} or half the distance
     * between the assigned centroid and {@code c}, the distance to {@code c}
     * is computed and stored as the new lower bound. If it is smaller than the
     * upper bound the point is reassigned: its weighted vector is moved between
     * the thread-local deltas, the cluster weights are adjusted, the upper
     * bound is tightened and {@code changeOccurred} is set.
     *
     * @param X the data points
     * @param upperBound per point upper bound (may be written)
     * @param q index of the point
     * @param lowerBound per point, per centroid lower bounds (may be written)
     * @param c index of the candidate centroid
     * @param centroidSelfDistances centroid-to-centroid distances
     * @param assignment the current centroid index of every point (may be written)
     * @param v the vector of point {@code q}
     * @param means the current centroids
     * @param localDeltas per-thread pending changes to the cluster sums
     * @param meanCount the summed weight of each cluster
     * @param changeOccurred set to {@code true} when the point is reassigned
     * @param meanSummaryConsts per centroid constants for the dense/sparse
     *        path; only read when that path is active
     * @param distAccelCache acceleration cache for the data points
     * @param meanQIs query information for each centroid
     * @param W the weight of each data point
     */
    private void step3bUpdate(List<Vec> X, double[] upperBound, int q, double[][] lowerBound, int c, double[][] centroidSelfDistances, int[] assignment, Vec v, List<Vec> means, ThreadLocal<Vec[]> localDeltas, AtomicDoubleArray meanCount, AtomicBoolean changeOccurred, double[] meanSummaryConsts, List<Double> distAccelCache, List<List<Double>> meanQIs, Vec W) {
        // Negated comparisons are kept as-is so NaN values behave as before.
        if (!(upperBound[q] > lowerBound[q][c]) && !(upperBound[q] > centroidSelfDistances[assignment[q]][c] / 2.0)) {
            return;
        }
        final double candidateDist;
        if (this.dmds != null) {
            candidateDist = this.dmds.dist(meanSummaryConsts[c], means.get(c), v);
        } else {
            candidateDist = this.dm.dist(q, means.get(c), meanQIs.get(c), X, distAccelCache);
        }
        lowerBound[q][c] = candidateDist;
        if (!(candidateDist < upperBound[q])) {
            return;
        }
        final Vec[] pending = localDeltas.get();
        final double pointWeight = W.get(q);
        final int previous = assignment[q];
        // Take the point out of its old cluster ...
        pending[previous].mutableSubtract(pointWeight, v);
        meanCount.addAndGet(previous, -pointWeight);
        // ... and put it into the new one.
        pending[c].mutableAdd(pointWeight, v);
        meanCount.addAndGet(c, pointWeight);
        assignment[q] = c;
        upperBound[q] = candidateDist;
        changeOccurred.set(true);
    }

    /**
     * Fills the symmetric centroid-to-centroid distance matrix, derives for
     * each centroid half the distance to its nearest other centroid, and
     * optionally refreshes the per-centroid dense/sparse constants.
     *
     * @param k number of centroids
     * @param centroidSelfDistances {@code k x k} matrix receiving the pairwise
     *        distances; both {@code [i][z]} and {@code [z][i]} are written, the
     *        diagonal is left untouched
     * @param means the current centroids
     * @param sC receives, per centroid, half the smallest distance to any other
     *        centroid ({@code Double.MAX_VALUE / 2} when there is no other)
     * @param meanSummaryConsts if non-null, receives the dense/sparse vector
     *        constant of every centroid; requires {@code dmds} to be set
     * @param threadpool executor for the pairwise distance jobs, or
     *        {@code null} to compute them on the calling thread
     * @throws ClusterFailureException if the calling thread is interrupted
     *         while waiting for the jobs; the interrupt status is restored
     */
    private void calculateCentroidDistances(int k, final double[][] centroidSelfDistances, final List<Vec> means, double[] sC, final double[] meanSummaryConsts, ExecutorService threadpool) {
        final List<Double> meanAccelCache = this.dm.supportsAcceleration() ? this.dm.getAccelerationCache(means) : null;

        // One constant per centroid, independent of how the pairs are scheduled,
        // so the last centroid is covered on the threaded path as well.
        if (meanSummaryConsts != null) {
            for (int i = 0; i < k; ++i) {
                meanSummaryConsts[i] = this.dmds.getVectorConstant(means.get(i));
            }
        }

        if (threadpool != null) {
            // Number of entries strictly above the diagonal.
            final int jobs = k * (k - 1) / 2;
            final CountDownLatch latch = new CountDownLatch(jobs);
            for (int i = 0; i < k; ++i) {
                final int ii = i;
                for (int z = i + 1; z < k; ++z) {
                    final int zz = z;
                    threadpool.submit(new Runnable(){

                        @Override
                        public void run() {
                            double d = ElkanKMeans.this.dm.dist(ii, zz, means, meanAccelCache);
                            // Mirror the value: later steps index the matrix in both orders.
                            centroidSelfDistances[ii][zz] = d;
                            centroidSelfDistances[zz][ii] = d;
                            latch.countDown();
                        }
                    });
                }
            }
            try {
                latch.await();
            }
            catch (InterruptedException ex) {
                Logger.getLogger(ElkanKMeans.class.getName()).log(Level.SEVERE, null, ex);
                // The matrix may be only partly filled; restore the flag and abort.
                Thread.currentThread().interrupt();
                throw new ClusterFailureException("Clustering failed");
            }
        } else {
            for (int i = 0; i < k; ++i) {
                for (int z = i + 1; z < k; ++z) {
                    double d = this.dm.dist(i, z, means, meanAccelCache);
                    centroidSelfDistances[i][z] = d;
                    centroidSelfDistances[z][i] = d;
                }
            }
        }

        for (int i = 0; i < k; ++i) {
            double sCmin = Double.MAX_VALUE;
            for (int z = 0; z < k; ++z) {
                if (z != i) {
                    sCmin = Math.min(sCmin, centroidSelfDistances[i][z]);
                }
            }
            sC[i] = sCmin / 2.0;
        }
    }

    /**
     * Creates a copy of this object through the copy constructor.
     *
     * @return a new {@code ElkanKMeans} built from this instance
     */
    @Override
    public ElkanKMeans clone() {
        return new ElkanKMeans(this);
    }
}

