/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.KMeans;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.random.RandomUtil;

/**
 * Hamerly's bounded k-means. Every data point {@code i} keeps an upper bound
 * {@code u[i]} on the distance to its assigned center and a lower bound
 * {@code l[i]} on the distance to its second-closest center. A point is only
 * re-examined when {@code u[i]} exceeds both {@code l[i]} and half the distance
 * from its center to the closest other center.
 * <p>
 * After the centers move, upper bounds grow by the movement of the assigned
 * center. Lower bounds shrink by the smaller of two quantities: the largest
 * movement among the other centers, and a geometric per-cluster bound computed
 * in {@code EnhancedUpdateBounds}.
 * <p>
 * NOTE(review): the geometric bound projects centers onto lines using dot
 * products, which presumes Euclidean geometry; confirm results when a
 * non-Euclidean {@link DistanceMetric} is used.
 */
public class HamerlyKMeans extends KMeans {
    private static final long serialVersionUID = -4960453870335145091L;

    /**
     * Creates a new clusterer.
     *
     * @param dm the distance metric used between points and centers
     * @param seedSelection the method used to pick initial centers when the
     *        supplied {@code means} do not already hold {@code k} centers
     * @param rand random source handed to the parent {@link KMeans} (used here
     *        for seed selection)
     */
    public HamerlyKMeans(DistanceMetric dm, SeedSelectionMethods.SeedSelection seedSelection, Random rand) {
        super(dm, seedSelection, rand);
    }

    /**
     * Creates a new clusterer using {@link RandomUtil#getRandom()} as its
     * random source.
     *
     * @param dm the distance metric used between points and centers
     * @param seedSelection the method used to pick initial centers
     */
    public HamerlyKMeans(DistanceMetric dm, SeedSelectionMethods.SeedSelection seedSelection) {
        this(dm, seedSelection, RandomUtil.getRandom());
    }

    /**
     * Creates a new clusterer using {@link EuclideanDistance} and
     * {@link SeedSelectionMethods.SeedSelection#KPP} seeding.
     */
    public HamerlyKMeans() {
        this(new EuclideanDistance(), SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Copy constructor; delegates to the parent {@link KMeans} copy constructor.
     *
     * @param toCopy the object to copy
     */
    public HamerlyKMeans(HamerlyKMeans toCopy) {
        super(toCopy);
    }

    /**
     * Runs Hamerly's bounded k-means until an iteration reassigns no point.
     * <p>
     * If {@code means} does not already contain {@code k} centers it is refilled
     * with the configured seed selection. Sparse centers are replaced by dense
     * copies.
     *
     * @param dataSet the data to cluster
     * @param accelCache acceleration cache for {@code dm}, or {@code null} to
     *        compute it here
     * @param k the number of clusters
     * @param means the initial centers; overwritten in place with the final centers
     * @param assignment output array receiving the center index of every point
     * @param exactTotal if {@code true} the returned error recomputes each point's
     *        distance to its center; otherwise the (possibly loose) upper bounds
     *        are used
     * @param threadpool executor for parallel work, or {@code null} to run on the
     *        calling thread
     * @param returnError whether to compute the sum of squared distances
     * @param dataPointWeights per-point weights, or {@code null} to use the data
     *        set's own weights
     * @return the sum of squared point-to-center distances when
     *         {@code returnError} is set, otherwise {@code 0.0}
     */
    @Override
    protected double cluster(final DataSet dataSet, List<Double> accelCache, final int k, final List<Vec> means, final int[] assignment, boolean exactTotal, ExecutorService threadpool, boolean returnError, Vec dataPointWeights) {
        final int N = dataSet.getSampleSize();
        final int D = dataSet.getNumNumericalVars();
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadpool);
        final Vec W = dataPointWeights == null ? dataSet.getDataWeights() : dataPointWeights;
        final List<Vec> X = dataSet.getDataVectors();
        final List<Double> distAccel = accelCache == null ? this.dm.getAccelerationCache(X, threadpool) : accelCache;
        final List<List<Double>> meanQI = new ArrayList<List<Double>>(k);

        if (means.size() != k) {
            means.clear();
            means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, this.dm, distAccel, this.rand, this.seedSelection, threadpool));
        }

        // Work on dense centers; oldMeans[j] holds center j's position before its latest move.
        final Vec[] oldMeans = new Vec[means.size()];
        final double[] distanceMoved = new double[means.size()];
        for (int j = 0; j < means.size(); ++j) {
            if (means.get(j).isSparse())
                means.set(j, new DenseVector(means.get(j)));
            oldMeans[j] = new DenseVector(means.get(j));
        }

        // cP[j]: weighted sum of the points assigned to center j
        final Vec[] cP = new Vec[k];
        // scratch vectors (tmpVecs is filled in by Initialize)
        final Vec[] tmpVecs = new Vec[k];
        final Vec[] tmpVecs2 = new Vec[k];
        for (int j = 0; j < tmpVecs2.length; ++j)
            tmpVecs2[j] = new DenseVector(oldMeans[0].length());
        // q[j]: total weight assigned to center j
        final AtomicDoubleArray q = new AtomicDoubleArray(k);
        // p[j]: distance center j moved in the latest update
        final double[] p = new double[k];
        // s[j]: distance from center j to its closest other center
        final double[] s = new double[k];
        // u[i] / l[i]: upper bound to the assigned center / lower bound to the second-closest center
        final double[] u = new double[N];
        final double[] l = new double[N];

        // Per-thread accumulators for changes to cP, flushed into cP under its monitor.
        final ThreadLocal<Vec[]> localDeltas = new ThreadLocal<Vec[]>() {
            @Override
            protected Vec[] initialValue() {
                Vec[] toRet = new Vec[means.size()];
                for (int j = 0; j < k; ++j)
                    toRet[j] = new DenseVector(D);
                return toRet;
            }
        };

        this.Initialize(dataSet, q, means, tmpVecs, cP, u, l, assignment, threadpool, localDeltas, X, distAccel, meanQI, W);

        final AtomicInteger updates = new AtomicInteger(N);
        while (updates.get() > 0) {
            this.moveCenters(means, oldMeans, tmpVecs, cP, q, p, meanQI);
            updates.set(0);
            this.updateS(s, distanceMoved, means, oldMeans, threadpool, meanQI);

            // m[j]: largest upper bound among the points assigned to center j. At this point
            // u has not yet been updated for the move, so it bounds distances to oldMeans[j].
            double[] m = new double[means.size()];
            for (int i = 0; i < N; ++i)
                m[assignment[i]] = Math.max(m[assignment[i]], u[i]);
            double[] updateB = new double[m.length];
            this.EnhancedUpdateBounds(means, distanceMoved, m, s, oldMeans, tmpVecs, tmpVecs2, updateB, p, assignment, u, l);

            if (threadpool == null) {
                int localUpdates = 0;
                for (int i = 0; i < N; ++i)
                    localUpdates += this.mainLoopWork(dataSet, i, s, assignment, u, l, q, cP, X, distAccel, means, meanQI, W);
                updates.set(localUpdates);
            } else {
                final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
                for (int id = 0; id < SystemInfo.LogicalCores; ++id) {
                    final int ID = id;
                    threadpool.submit(new Runnable() {
                        @Override
                        public void run() {
                            Vec[] deltas = localDeltas.get();
                            int localUpdates = 0;
                            // Each worker handles every LogicalCores-th point starting at ID.
                            for (int i = ID; i < N; i += SystemInfo.LogicalCores)
                                localUpdates += mainLoopWork(dataSet, i, s, assignment, u, l, q, deltas, X, distAccel, means, meanQI, W);
                            if (localUpdates > 0) {
                                updates.getAndAdd(localUpdates);
                                for (int c = 0; c < cP.length; ++c) {
                                    synchronized (cP[c]) {
                                        cP[c].mutableAdd(deltas[c]);
                                    }
                                    deltas[c].zeroOut();
                                }
                            }
                            latch.countDown();
                        }
                    });
                }
                try {
                    latch.await();
                }
                catch (InterruptedException ex) {
                    Logger.getLogger(HamerlyKMeans.class.getName()).log(Level.SEVERE, null, ex);
                }
            }
        }

        if (!returnError)
            return 0.0;

        double totalDistance = 0.0;
        this.nearestCentroidDist = this.saveCentroidDistance ? new double[N] : null;
        for (int i = 0; i < N; ++i) {
            double dist = exactTotal
                    ? this.dm.dist(i, means.get(assignment[i]), meanQI.get(assignment[i]), X, distAccel)
                    : u[i];
            totalDistance += Math.pow(dist, 2.0);
            if (this.saveCentroidDistance)
                this.nearestCentroidDist[i] = dist;
        }
        return totalDistance;
    }

    /**
     * Computes, for every cluster {@code i}, a bound {@code updateB[i]} on how much the
     * distance from a point of cluster {@code i} to some other center may have decreased
     * because that center moved. It then applies all bound updates via {@code UpdateBounds}.
     * <p>
     * The points of cluster {@code i} are taken to lie in the ball of radius {@code m[i]}
     * around {@code oldMeans[i]}. For each other center {@code j}, the ball center is
     * expressed in 2-D coordinates relative to the line through {@code oldMeans[j]} and
     * {@code means1.get(j)}. The scaling places the old and new positions of {@code j} at
     * (0, 1) and (0, -1). The ball center becomes (c_ix, c_iy) with radius r.
     * NOTE(review): the closed form below appears to follow Algorithm 2 of
     * Ryšavý &amp; Hamerly (2016); verify its constants against the paper.
     * <p>
     * If no other center is considered for cluster {@code i}, {@code updateB[i]} is
     * +infinity. {@code UpdateBounds} then falls back to the plain Hamerly decrement
     * (the largest movement among other centers).
     */
    private void EnhancedUpdateBounds(List<Vec> means1, double[] distanceMoved, double[] m, double[] s, Vec[] oldMeans, Vec[] tmpVecs, Vec[] tmpVecs2, double[] updateB, double[] p, int[] assignment, double[] u, double[] l) {
        // Centers ordered by movement, largest first. This order does not depend on i, so build it once.
        IndexTable c_order = new IndexTable(distanceMoved);
        c_order.reverse();
        for (int i = 0; i < means1.size(); ++i) {
            double update = Double.NEGATIVE_INFINITY;
            for (int order = 0; order < c_order.length(); ++order) {
                int j = c_order.index(order);
                // NOTE(review): the rationale for this filter is not visible in this file.
                if (j == i || 2.0 * m[j] + s[j] < distanceMoved[j])
                    continue;
                // A center cannot decrease any distance by more than it moved (triangle
                // inequality), and later centers moved even less.
                if (distanceMoved[j] <= update)
                    break;
                if (distanceMoved[j] <= 0.0) {
                    // A center that did not move changes no distance; every remaining center
                    // also has zero movement. Handled explicitly to avoid dividing by zero
                    // below, which would turn the bound into NaN.
                    update = Math.max(update, 0.0);
                    break;
                }
                // tmpVecs[i] = o_i - o_j ; tmpVecs2[i] = c'_j - o_j
                oldMeans[i].copyTo(tmpVecs[i]);
                means1.get(j).copyTo(tmpVecs2[i]);
                tmpVecs[i].mutableSubtract(oldMeans[j]);
                tmpVecs2[i].mutableSubtract(oldMeans[j]);
                // Position of o_i's projection onto the movement line: 0 at o_j, 1 at c'_j.
                double t = tmpVecs[i].dot(tmpVecs2[i]) / (distanceMoved[j] * distanceMoved[j]);
                // tmpVecs[i] becomes t*(c'_j - o_j) - (o_i - o_j): the offset from o_i to its projection.
                tmpVecs[i].mutableMultiply(-1.0);
                tmpVecs[i].mutableAdd(t, tmpVecs2[i]);
                // Distance from o_i to the movement line. The previous code took the norm of
                // tmpVecs2[i] (the movement itself), which pinned c_ix at 2.
                double dist = tmpVecs[i].pNorm(2.0);
                double c_ix = dist * 2.0 / distanceMoved[j];
                double c_iy = 1.0 - 2.0 * t;
                double r = m[i] * 2.0 / distanceMoved[j];
                double algo2_1_out;
                if (c_ix <= r) {
                    algo2_1_out = Math.max(0.0, Math.min(2.0, 2.0 * (r - c_iy)));
                } else {
                    if (c_iy > r)
                        c_iy -= 1.0;
                    double normSqrd = c_ix * c_ix + c_iy * c_iy;
                    algo2_1_out = 2.0 * (c_ix * r - c_iy * Math.sqrt(normSqrd - r * r)) / normSqrd;
                }
                // Undo the normalization back to the original distance scale.
                update = Math.max(algo2_1_out * (distanceMoved[j] / 2.0), update);
            }
            // -infinity here would make UpdateBounds raise every lower bound of the cluster to +infinity.
            updateB[i] = update == Double.NEGATIVE_INFINITY ? Double.POSITIVE_INFINITY : update;
        }
        this.UpdateBounds(p, assignment, u, l, updateB);
    }

    /**
     * Re-checks point {@code i} after the bound update. The point's bounds are consulted
     * first; only when they cannot rule out a change is its upper bound tightened, and
     * only if that is still insufficient are all centers scanned.
     *
     * @param deltas accumulators that receive the change in per-center weighted sums when
     *        the point moves (either {@code cP} itself or a thread-local copy)
     * @return 1 if the point was reassigned to a different center, otherwise 0
     */
    private int mainLoopWork(DataSet dataSet, int i, double[] s, int[] assignment, double[] u, double[] l, AtomicDoubleArray q, Vec[] deltas, List<Vec> X, List<Double> distAccel, List<Vec> means, List<List<Double>> meanQI, Vec W) {
        int a_i = assignment[i];
        double m = Math.max(s[a_i] / 2.0, l[i]);
        if (u[i] <= m)
            return 0;
        Vec x = X.get(i);
        u[i] = this.dm.dist(i, means.get(a_i), meanQI.get(a_i), X, distAccel);
        if (u[i] <= m)
            return 0;
        int new_a_i = this.PointAllCtrs(x, i, means, assignment, u, l, X, distAccel, meanQI);
        if (a_i == new_a_i)
            return 0;
        double w = W.get(i);
        q.addAndGet(a_i, -w);
        q.addAndGet(new_a_i, w);
        deltas[a_i].mutableSubtract(w, x);
        deltas[new_a_i].mutableAdd(w, x);
        return 1;
    }

    /**
     * Sets {@code distanceMoved[j]} to the distance between {@code oldMeans[j]} and
     * {@code means.get(j)}. Sets {@code s[j]} to the smallest distance from center
     * {@code j} to any other center ({@code Double.MAX_VALUE} when there is none).
     * Runs on {@code threadpool} when it is non-null, one task per center.
     */
    private void updateS(final double[] s, final double[] distanceMoved, final List<Vec> means, final Vec[] oldMeans, ExecutorService threadpool, List<List<Double>> meanQIs) {
        final int k = means.size();
        Arrays.fill(s, Double.MAX_VALUE);
        // When the metric supplies query info, concatenate it into one cache indexed like `means`.
        final DoubleList meanCache = meanQIs.get(0).isEmpty() ? null : new DoubleList(meanQIs.size());
        if (meanCache != null) {
            for (List<Double> qi : meanQIs)
                meanCache.addAll(qi);
        }

        if (threadpool == null) {
            for (int j = 0; j < k; ++j) {
                distanceMoved[j] = this.dm.dist(oldMeans[j], means.get(j));
                for (int jp = j + 1; jp < k; ++jp) {
                    double tmp = this.dm.dist(j, jp, means, meanCache);
                    s[j] = Math.min(s[j], tmp);
                    s[jp] = Math.min(s[jp], tmp);
                }
            }
            return;
        }

        final CountDownLatch latch = new CountDownLatch(k);
        final ThreadLocal<double[]> localS = new ThreadLocal<double[]>() {
            @Override
            protected double[] initialValue() {
                return new double[s.length];
            }
        };
        for (int j = 0; j < k; ++j) {
            final int J = j;
            threadpool.submit(new Runnable() {
                @Override
                public void run() {
                    double[] sTmp = localS.get();
                    Arrays.fill(sTmp, Double.POSITIVE_INFINITY);
                    distanceMoved[J] = HamerlyKMeans.this.dm.dist(oldMeans[J], means.get(J));
                    for (int jp = J + 1; jp < means.size(); ++jp) {
                        double tmp = HamerlyKMeans.this.dm.dist(J, jp, means, meanCache);
                        sTmp[J] = Math.min(sTmp[J], tmp);
                        sTmp[jp] = Math.min(sTmp[jp], tmp);
                    }
                    synchronized (s) {
                        for (int c = 0; c < s.length; ++c)
                            s[c] = Math.min(s[c], sTmp[c]);
                    }
                    latch.countDown();
                }
            });
        }
        try {
            latch.await();
        }
        catch (InterruptedException ex) {
            Logger.getLogger(HamerlyKMeans.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    /**
     * Allocates the per-center sum vectors ({@code cP}) and scratch vectors ({@code tmp}).
     * Records each center's query info in {@code meanQI} (an empty list when the metric
     * does not support acceleration). Then assigns every point to its closest center with
     * exact bounds, accumulating the assigned weights in {@code q} and the weighted sums
     * in {@code cP}.
     */
    private void Initialize(DataSet d, final AtomicDoubleArray q, final List<Vec> means, Vec[] tmp, final Vec[] cP, final double[] u, final double[] l, final int[] a, ExecutorService threadpool, final ThreadLocal<Vec[]> localDeltas, final List<Vec> X, final List<Double> distAccel, final List<List<Double>> meanQI, final Vec W) {
        for (int j = 0; j < means.size(); ++j) {
            cP[j] = new DenseVector(means.get(0).length());
            tmp[j] = cP[j].clone();
            if (this.dm.supportsAcceleration())
                meanQI.add(this.dm.getQueryInfo(means.get(j)));
            else
                meanQI.add(Collections.<Double>emptyList());
        }

        if (threadpool == null) {
            for (int i = 0; i < u.length; ++i) {
                Vec x = X.get(i);
                int j = this.PointAllCtrs(x, i, means, a, u, l, X, distAccel, meanQI);
                double w = W.get(i);
                q.addAndGet(j, w);
                cP[j].mutableAdd(w, x);
            }
            return;
        }

        final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
        for (int id = 0; id < SystemInfo.LogicalCores; ++id) {
            final int ID = id;
            threadpool.submit(new Runnable() {
                @Override
                public void run() {
                    Vec[] deltas = localDeltas.get();
                    for (int i = ID; i < u.length; i += SystemInfo.LogicalCores) {
                        Vec x = X.get(i);
                        int j = PointAllCtrs(x, i, means, a, u, l, X, distAccel, meanQI);
                        double w = W.get(i);
                        q.addAndGet(j, w);
                        deltas[j].mutableAdd(w, x);
                    }
                    for (int c = 0; c < cP.length; ++c) {
                        synchronized (cP[c]) {
                            cP[c].mutableAdd(deltas[c]);
                        }
                        deltas[c].zeroOut();
                    }
                    latch.countDown();
                }
            });
        }
        try {
            latch.await();
        }
        catch (InterruptedException ex) {
            Logger.getLogger(HamerlyKMeans.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    /**
     * Computes the distance from point {@code i} to every center and assigns the point to
     * the closest one. Resets its bounds exactly: {@code u[i]} becomes the closest distance
     * and {@code l[i]} the second-closest (+infinity when there is only one center).
     *
     * @return the index of the closest center, or -1 if no distance was below
     *         {@code Double.MAX_VALUE}
     */
    private int PointAllCtrs(Vec x, int i, List<Vec> means, int[] a, double[] u, double[] l, List<Vec> X, List<Double> distAccel, List<List<Double>> meanQI) {
        double secondLowest = Double.POSITIVE_INFINITY;
        double lowest = Double.MAX_VALUE;
        int lIndex = -1;
        for (int j = 0; j < means.size(); ++j) {
            double dist = this.dm.dist(i, means.get(j), meanQI.get(j), X, distAccel);
            if (!(dist < secondLowest))
                continue;
            if (dist < lowest) {
                secondLowest = lowest;
                lowest = dist;
                lIndex = j;
            } else {
                secondLowest = dist;
            }
        }
        a[i] = lIndex;
        u[i] = lowest;
        l[i] = secondLowest;
        return lIndex;
    }

    /**
     * Moves every center to the weighted mean of its points ({@code cP[j] / q[j]}).
     * Copies the previous position into {@code oldMeans[j]} and stores the distance moved
     * in {@code p[j]}. Refreshes the center's query info when the metric supports
     * acceleration.
     * <p>
     * A center with no assigned weight is moved to the zero vector.
     * NOTE(review): confirm that moving empty clusters to the origin (rather than leaving
     * them in place) is intended.
     */
    private void moveCenters(List<Vec> means, Vec[] oldMeans, Vec[] tmpSpace, Vec[] cP, AtomicDoubleArray q, double[] p, List<List<Double>> meanQI) {
        for (int j = 0; j < means.size(); ++j) {
            double count = q.get(j);
            means.get(j).copyTo(oldMeans[j]);
            if (count > 0.0) {
                cP[j].copyTo(tmpSpace[j]);
                tmpSpace[j].mutableDivide(count);
            } else {
                cP[j].zeroOut();
                tmpSpace[j].zeroOut();
            }
            p[j] = this.dm.dist(means.get(j), tmpSpace[j]);
            tmpSpace[j].copyTo(means.get(j));
            if (this.dm.supportsAcceleration())
                meanQI.set(j, this.dm.getQueryInfo(means.get(j)));
        }
    }

    /**
     * Applies the per-iteration bound updates. Every upper bound grows by the movement of
     * its assigned center. Every lower bound shrinks by the smaller of
     * {@code updateB[a[i]]} and the largest movement among the other centers.
     */
    private void UpdateBounds(double[] p, int[] a, double[] u, double[] l, double[] updateB) {
        // Find the centers with the largest (r) and second-largest (rP) movement.
        double secondHighest = Double.NEGATIVE_INFINITY;
        int shIndex = -1;
        double highest = -Double.MAX_VALUE;
        int hIndex = -1;
        for (int j = 0; j < p.length; ++j) {
            double dist = p[j];
            if (!(dist > secondHighest))
                continue;
            if (dist > highest) {
                secondHighest = highest;
                shIndex = hIndex;
                highest = dist;
                hIndex = j;
            } else {
                secondHighest = dist;
                shIndex = j;
            }
        }
        final int r = hIndex;
        final int rP = shIndex;
        for (int i = 0; i < u.length; ++i) {
            int j = a[i];
            u[i] += p[j];
            // Largest movement among the centers other than j.
            double othersMaxMove;
            if (r != j)
                othersMaxMove = p[r];
            else if (rP >= 0)
                othersMaxMove = p[rP];
            else
                othersMaxMove = 0.0; // no other center exists (k == 1); l[i] is already +infinity
            l[i] -= Math.min(othersMaxMove, updateB[j]);
        }
    }

    @Override
    public HamerlyKMeans clone() {
        return new HamerlyKMeans(this);
    }
}

