/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.hierarchical;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.KClustererBase;
import jsat.clustering.dissimilarity.LanceWilliamsDissimilarity;
import jsat.clustering.dissimilarity.WardsDissimilarity;
import jsat.clustering.hierarchical.PriorityHAC;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.math.OnLineStatistics;
import jsat.utils.FakeExecutor;
import jsat.utils.IndexTable;
import jsat.utils.IntDoubleMap;
import jsat.utils.IntDoubleMapArray;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDouble;

/**
 * Hierarchical agglomerative clustering driven by a chain of nearest
 * neighbours. Clusters are merged with a {@link LanceWilliamsDissimilarity}
 * update rule; distances between two un-merged points come straight from the
 * configured {@link DistanceMetric}.
 *
 * <p>After a call to one of the {@code cluster} methods the full merge history
 * is retained, so {@link #getClusterDesignations(int[], int)} can cut the
 * hierarchy at a different number of clusters without re-clustering.
 */
public class NNChainHAC extends KClustererBase {
    /** Rule used to compute the dissimilarity between a merged cluster and every other cluster. */
    private LanceWilliamsDissimilarity distMeasure;
    /** Metric used for the distance between two clusters that each still hold a single point. */
    private DistanceMetric dm;
    /**
     * Merge history of the last clustering run, stored as (kept, removed)
     * index pairs. {@code null} until a clustering has been performed.
     */
    private int[] merges;

    /**
     * Creates a clusterer using {@link WardsDissimilarity} and the
     * {@link EuclideanDistance}.
     */
    public NNChainHAC() {
        this(new WardsDissimilarity());
    }

    /**
     * Creates a clusterer using the given dissimilarity and the
     * {@link EuclideanDistance}.
     *
     * @param distMeasure the Lance-Williams dissimilarity used to merge clusters
     */
    public NNChainHAC(LanceWilliamsDissimilarity distMeasure) {
        this(distMeasure, new EuclideanDistance());
    }

    /**
     * Creates a clusterer using the given dissimilarity and base metric.
     *
     * @param distMeasure the Lance-Williams dissimilarity used to merge clusters
     * @param distance the metric used between individual data points
     */
    public NNChainHAC(LanceWilliamsDissimilarity distMeasure, DistanceMetric distance) {
        this.distMeasure = distMeasure;
        this.dm = distance;
    }

    /**
     * Copy constructor. The dissimilarity and metric are cloned and the merge
     * history, if any, is copied.
     *
     * @param toCopy the instance to copy
     */
    protected NNChainHAC(NNChainHAC toCopy) {
        this.distMeasure = toCopy.distMeasure.clone();
        this.dm = toCopy.dm.clone();
        if (toCopy.merges != null) {
            this.merges = Arrays.copyOf(toCopy.merges, toCopy.merges.length);
        }
    }

    @Override
    public NNChainHAC clone() {
        return new NNChainHAC(this);
    }

    /**
     * Clusters on the calling thread, letting the algorithm pick the number of
     * clusters.
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return this.cluster(dataSet, new FakeExecutor(), designations);
    }

    /**
     * Clusters with the number of clusters chosen from the range
     * {@code [2, (int) sqrt(sampleSize)]}.
     */
    @Override
    public int[] cluster(DataSet dataSet, ExecutorService threadpool, int[] designations) {
        int highK = (int) Math.sqrt(dataSet.getSampleSize());
        return this.cluster(dataSet, 2, highK, threadpool, designations);
    }

    /**
     * Looks up the current dissimilarity between clusters {@code a} and
     * {@code j}.
     *
     * @param a index of the first cluster
     * @param j index of the second cluster
     * @param size current number of points in each cluster
     * @param vecs the data vectors
     * @param cache acceleration cache for the distance metric
     * @param d_xk per-cluster maps of stored dissimilarities; an entry is
     *        {@code null} for a cluster that has no stored distances
     * @return the dissimilarity between the two clusters
     */
    private double getDist(int a, int j, int[] size, List<Vec> vecs, List<Double> cache, List<Map<Integer, Double>> d_xk) {
        if (size[j] == 1 && size[a] == 1) {
            // Both are still single points: ask the metric directly.
            return this.dm.dist(a, j, vecs, cache);
        }
        // The value is stored under one of the two clusters; try a's map
        // first and fall back to j's map.
        Map<Integer, Double> distsFromA = d_xk.get(a);
        if (distsFromA != null) {
            Double stored = distsFromA.get(j);
            if (stored != null) {
                return stored;
            }
        }
        return d_xk.get(j).get(a);
    }

    /**
     * Cuts the merge history of the last clustering run at the requested
     * number of clusters.
     *
     * @param designations array to store the assignments in
     * @param clusters the number of clusters to produce
     * @return the result of
     *         {@link PriorityHAC#assignClusterDesignations(int[], int, int[])},
     *         or {@code null} if no clustering has been performed yet
     */
    public int[] getClusterDesignations(int[] designations, int clusters) {
        if (this.merges == null) {
            return null;
        }
        return PriorityHAC.assignClusterDesignations(designations, clusters, this.merges);
    }

    /**
     * Cuts the merge history of the last clustering run at the requested
     * number of clusters and groups the data points accordingly.
     *
     * @param clusters the number of clusters to produce
     * @param data the data set that was clustered
     * @return the data points grouped by cluster, or {@code null} if no
     *         clustering has been performed or the stored merge history does
     *         not match the size of {@code data}
     */
    public List<List<DataPoint>> getClusterDesignations(int clusters, DataSet data) {
        // N points produce 2N - 2 merge entries, so (length + 2) / 2 recovers N.
        if (this.merges == null || (this.merges.length + 2) / 2 != data.getSampleSize()) {
            return null;
        }
        int[] assignments = new int[data.getSampleSize()];
        assignments = this.getClusterDesignations(assignments, clusters);
        return NNChainHAC.createClusterListFromAssignmentArray(assignments, data);
    }

    /** Clusters into exactly {@code clusters} groups. */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, ExecutorService threadpool, int[] designations) {
        return this.cluster(dataSet, clusters, clusters, threadpool, designations);
    }

    /** Clusters into exactly {@code clusters} groups on the calling thread. */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, int[] designations) {
        return this.cluster(dataSet, clusters, new FakeExecutor(), designations);
    }

    /**
     * Builds the complete merge hierarchy for {@code dataSet} and assigns each
     * point to a cluster, with the number of clusters chosen in
     * {@code [lowK, highK]}.
     *
     * <p>Work is spread over {@code threadpool} only while enough clusters
     * remain; a {@code null} pool or a {@link FakeExecutor} forces the serial
     * path. An empty data set is not supported: the merge array is allocated
     * with length {@code 2N - 2}, which is negative for {@code N == 0}.
     *
     * @param dataSet the data to cluster
     * @param lowK the smallest number of clusters to consider
     * @param highK the largest number of clusters to consider
     * @param threadpool source of worker threads, may be {@code null}
     * @param designations array to store the assignments in, or {@code null}
     *        to allocate a new one
     * @return the array holding the cluster assignment of each data point
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, ExecutorService threadpool, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        final int N = dataSet.getSampleSize();
        this.merges = new int[N * 2 - 2];

        // Merge log, in the order the merges are discovered.
        IntList mergeRemoved = new IntList(N);
        IntList mergeKept = new IntList(N);
        double[] mergedDistance = new double[N - 1];
        int mergeCount = 0;

        // Number of points currently in each cluster.
        final int[] size = new int[N];
        Arrays.fill(size, 1);

        // S holds the indices of the clusters that are still alive.
        final IntList S = new IntList(N);
        ListUtils.addRange(S, 0, N, 1);

        // distMap.get(i) holds stored dissimilarities for cluster i; it stays
        // null until i becomes the result of a merge.
        final List<Map<Integer, Double>> distMap = new ArrayList<Map<Integer, Double>>(N);
        for (int i = 0; i < N; ++i) {
            distMap.add(null);
        }

        final List<Vec> vecs = dataSet.getDataVectors();
        final List<Double> cache = this.dm.getAccelerationCache(vecs, threadpool);
        final boolean serialOnly = threadpool == null || threadpool instanceof FakeExecutor;

        int[] chain = new int[N];
        int chainPos = 0;

        while (S.size() > 1) {
            int a;
            int b;
            if (chainPos <= 3) {
                // Chain too short to resume: restart it from the first two live clusters.
                a = S.getI(0);
                chainPos = 0;
                chain[chainPos++] = a;
                b = S.getI(1);
            } else {
                // Resume from the part of the chain in front of the three-entry
                // tail left by the previous merge.
                a = chain[chainPos - 4];
                b = chain[chainPos - 3];
                chainPos -= 3;
            }

            // Grow the chain: find a's nearest neighbour c (b wins ties because
            // only strictly smaller distances replace it), then step to c.
            double dist_ab;
            do {
                int c = b;
                double minDist = this.getDist(a, c, size, vecs, cache, distMap);
                if (serialOnly || S.size() < SystemInfo.LogicalCores * 2) {
                    for (int j : S) {
                        if (j == a || j == c) {
                            continue;
                        }
                        double dist = this.getDist(a, j, size, vecs, cache, distMap);
                        if (dist < minDist) {
                            minDist = dist;
                            c = j;
                        }
                    }
                } else {
                    final AtomicInteger sharedC = new AtomicInteger(b);
                    final AtomicDouble sharedMinDist = new AtomicDouble(minDist);
                    final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
                    final int A = a;
                    final int B = b;
                    for (int id = 0; id < SystemInfo.LogicalCores; ++id) {
                        final int ID = id;
                        threadpool.submit(new Runnable() {
                            @Override
                            public void run() {
                                // Each worker scans every LogicalCores-th entry of S.
                                double localMinDist = Double.POSITIVE_INFINITY;
                                int localC = -1;
                                for (int indx = ID; indx < S.size(); indx += SystemInfo.LogicalCores) {
                                    int j = S.getI(indx);
                                    if (j == A || j == B) {
                                        continue;
                                    }
                                    double dist = NNChainHAC.this.getDist(A, j, size, vecs, cache, distMap);
                                    if (dist < localMinDist) {
                                        localMinDist = dist;
                                        localC = j;
                                    }
                                }
                                // Cheap unlocked check first, then re-check under the lock.
                                if (localMinDist < sharedMinDist.get()) {
                                    synchronized (S) {
                                        if (localMinDist < sharedMinDist.get()) {
                                            sharedMinDist.set(localMinDist);
                                            sharedC.set(localC);
                                        }
                                    }
                                }
                                latch.countDown();
                            }
                        });
                    }
                    try {
                        latch.await();
                        c = sharedC.get();
                        minDist = sharedMinDist.get();
                    } catch (InterruptedException ex) {
                        // NOTE(review): the interruption is only logged; c and
                        // minDist keep their pre-search values and the loop goes on.
                        Logger.getLogger(NNChainHAC.class.getName()).log(Level.SEVERE, null, ex);
                    }
                }
                dist_ab = minDist;
                b = a;
                a = c;
                chain[chainPos++] = a;
                // Stop once the chain ends in (x, y, x): x and y are mutual nearest neighbours.
            } while (chainPos < 3 || a != chain[chainPos - 3]);

            // Merge a and b; the merged cluster keeps the smaller index.
            final int n = Math.min(a, b);
            final int removed = Math.max(a, b);
            mergeRemoved.add(removed);
            mergeKept.add(n);
            mergedDistance[mergeCount++] = dist_ab;
            S.removeAll(Arrays.asList(a, b));

            // Re-label the discarded index near the end of the chain.
            for (int i = Math.max(0, chainPos - 5); i < chainPos; ++i) {
                if (chain[i] == removed) {
                    chain[i] = n;
                }
            }

            final int size_a = size[a];
            final int size_b = size[b];
            final boolean singleThread = serialOnly || S.size() <= SystemInfo.LogicalCores * 10;

            // Storage for the merged cluster's distances. The array-backed map
            // is used when many clusters remain and always on the parallel path.
            // NOTE(review): presumably because workers then write disjoint keys
            // concurrently — confirm against IntDoubleMapArray.
            final Map<Integer, Double> mapN;
            if (S.isEmpty()) {
                mapN = null;
            } else if (S.size() * 100 >= N || !singleThread) {
                mapN = new IntDoubleMapArray(N);
            } else {
                mapN = new IntDoubleMap(S.size());
            }

            if (singleThread) {
                for (int x : S) {
                    double d_ax = this.getDist(a, x, size, vecs, cache, distMap);
                    double d_bx = this.getDist(b, x, size, vecs, cache, distMap);
                    double d_xn = this.distMeasure.dissimilarity(size_a, size_b, size[x], dist_ab, d_ax, d_bx);
                    Map<Integer, Double> distsFromX = distMap.get(x);
                    if (distsFromX != null) {
                        // Drop the entry of the index that no longer exists.
                        // This must be `removed`, not `b`: when a > b the
                        // discarded index is a, and removing b would leave a's
                        // stale entry behind.
                        distsFromX.remove(removed);
                        distsFromX.put(n, d_xn);
                        // Few entries left: switch to the sparse map.
                        if (distsFromX.size() * 50 < N && !(distsFromX instanceof IntDoubleMap)) {
                            distMap.set(x, new IntDoubleMap(distsFromX));
                        }
                    }
                    mapN.put(x, d_xn);
                }
            } else {
                final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
                final int A = a;
                final int B = b;
                final double dist_AB = dist_ab;
                for (int id = 0; id < SystemInfo.LogicalCores; ++id) {
                    final int ID = id;
                    threadpool.submit(new Runnable() {
                        @Override
                        public void run() {
                            // Each worker updates every LogicalCores-th entry of S.
                            for (int indx = ID; indx < S.size(); indx += SystemInfo.LogicalCores) {
                                int x = S.getI(indx);
                                double d_ax = NNChainHAC.this.getDist(A, x, size, vecs, cache, distMap);
                                double d_bx = NNChainHAC.this.getDist(B, x, size, vecs, cache, distMap);
                                double d_xn = NNChainHAC.this.distMeasure.dissimilarity(size_a, size_b, size[x], dist_AB, d_ax, d_bx);
                                Map<Integer, Double> distsFromX = distMap.get(x);
                                if (distsFromX != null) {
                                    // Same as the serial path: drop the discarded index.
                                    distsFromX.remove(removed);
                                    distsFromX.put(n, d_xn);
                                    if (distsFromX.size() * 50 < N && !(distsFromX instanceof IntDoubleMap)) {
                                        distMap.set(x, new IntDoubleMap(distsFromX));
                                    }
                                }
                                mapN.put(x, d_xn);
                            }
                            latch.countDown();
                        }
                    });
                }
                try {
                    latch.await();
                } catch (InterruptedException ex) {
                    // NOTE(review): only logged; workers may still be running
                    // when the loop continues.
                    Logger.getLogger(NNChainHAC.class.getName()).log(Level.SEVERE, null, ex);
                }
            }

            distMap.set(removed, null);
            distMap.set(n, mapN);
            size[n] = size_a + size_b;
            S.add(n);
        }

        this.fixMergeOrderAndAssign(mergedDistance, mergeKept, mergeRemoved, lowK, N, highK, designations);
        return designations;
    }

    /**
     * Reorders the recorded merges by merge distance, stores them in
     * {@link #merges}, picks a number of clusters in {@code [lowK, highK]} and
     * writes the resulting assignments into {@code designations}.
     *
     * @param mergedDistance distance at which each merge happened; reordered in place
     * @param merge_kept surviving index of each merge; reordered in place
     * @param merge_removed discarded index of each merge; reordered in place
     * @param lowK the smallest number of clusters to consider
     * @param N the number of data points
     * @param highK the largest number of clusters to consider
     * @param designations array to store the assignments in
     */
    private void fixMergeOrderAndAssign(double[] mergedDistance, IntList merge_kept, IntList merge_removed, int lowK, int N, int highK, int[] designations) {
        // The merges were found in chain order; apply one shared ordering to
        // all three parallel sequences.
        // NOTE(review): IndexTable presumably sorts ascending by distance — confirm.
        IndexTable order = new IndexTable(mergedDistance);
        order.apply(merge_kept);
        order.apply(merge_removed);
        order.apply(mergedDistance);

        // Pairs are written back to front: merge i lands 2*i entries before the end.
        for (int i = 0; i < order.length(); ++i) {
            int removedSlot = this.merges.length - i * 2 - 1;
            this.merges[removedSlot] = merge_removed.get(i);
            this.merges[removedSlot - 1] = merge_kept.get(i);
        }

        // Guess K: among the candidate cluster counts in the window, take the
        // one whose merge distance is the most standard deviations above the
        // running mean of the distances seen so far.
        OnLineStatistics distChange = new OnLineStatistics();
        // Double.MIN_VALUE is the smallest positive double, so only strictly
        // positive scores can win; otherwise lowK is kept.
        double maxStndDevs = Double.MIN_VALUE;
        int clusterSize = lowK;
        for (int i = 0; i < mergedDistance.length; ++i) {
            distChange.add(mergedDistance[i]);
            int curK = N - i;
            if (curK < lowK || curK > highK) {
                continue;
            }
            double stndDevs = (mergedDistance[i] - distChange.getMean()) / distChange.getStandardDeviation();
            // A NaN score (e.g. zero deviation) fails this comparison and is skipped.
            if (stndDevs > maxStndDevs) {
                maxStndDevs = stndDevs;
                clusterSize = curK;
            }
        }
        PriorityHAC.assignClusterDesignations(designations, clusterSize, this.merges);
    }

    /**
     * Clusters on the calling thread with the number of clusters chosen in
     * {@code [lowK, highK]}.
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, int[] designations) {
        return this.cluster(dataSet, lowK, highK, new FakeExecutor(), designations);
    }
}

