/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.clustering.ClusterFailureException;
import jsat.clustering.ClustererBase;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollectionUtils;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;
import jsat.utils.IntSet;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;

/**
 * Clusters data with FLAME (Fuzzy clustering by Local Approximation of
 * MEmberships).
 * <p>
 * The density of each point is measured by the sum of the distances to its
 * {@code k} nearest neighbors; a smaller sum means a denser region.
 * <ul>
 * <li>A point whose sum is no greater than that of any of its neighbors is a
 * cluster supporting object (CSO) and seeds its own cluster.</li>
 * <li>A point whose sum is strictly greater than all of its neighbors', and also
 * above {@code mean + stndDevs * standardDeviation} of all sums, is an
 * outlier.</li>
 * <li>CSOs that have an outlier among their neighbors are discarded.</li>
 * </ul>
 * Every remaining point's fuzzy memberships are then repeatedly re-approximated
 * from its neighbors' memberships. Each point is finally assigned to the cluster
 * it belongs to most. Clusters left with a single member are dissolved and
 * their members reassigned. Outliers receive the designation {@code -1}.
 */
public class FLAME
extends ClustererBase
implements Parameterized {
    private static final long serialVersionUID = 2393091020100706517L;
    /** Small constant added to membership sums so normalization never divides by zero. */
    private static final double NORMALIZATION_EPS = 1.0E-6;
    private DistanceMetric dm;
    private int k;
    private int maxIterations;
    private VectorCollectionFactory<VecPaired<Vec, Integer>> vectorCollectionFactory = new DefaultVectorCollectionFactory<VecPaired<Vec, Integer>>();
    private double stndDevs = 2.5;
    private double eps = 1.0E-6;

    /**
     * Creates a new FLAME clusterer.
     *
     * @param dm the distance metric used for the nearest neighbor searches
     * @param k the number of nearest neighbors considered for each point
     * @param maxIterations the maximum number of membership approximation iterations
     * @throws IllegalArgumentException if {@code k} or {@code maxIterations} is not positive
     */
    public FLAME(DistanceMetric dm, int k, int maxIterations) {
        this.setDistanceMetric(dm);
        this.setK(k);
        this.setMaxIterations(maxIterations);
    }

    /**
     * Copy constructor. The distance metric is cloned; the vector collection
     * factory is shared with {@code toCopy}.
     *
     * @param toCopy the object to copy
     */
    public FLAME(FLAME toCopy) {
        this.dm = toCopy.dm.clone();
        this.maxIterations = toCopy.maxIterations;
        this.vectorCollectionFactory = toCopy.vectorCollectionFactory;
        this.k = toCopy.k;
        this.stndDevs = toCopy.stndDevs;
        this.eps = toCopy.eps;
    }

    /**
     * Sets the maximum number of membership approximation iterations.
     *
     * @param maxIterations the maximum number of iterations, at least 1
     * @throws IllegalArgumentException if {@code maxIterations} is less than 1
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("Must perform a positive number of iterations, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * @return the maximum number of membership approximation iterations
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the number of nearest neighbors used for densities and memberships.
     *
     * @param k the number of neighbors, at least 1
     * @throws IllegalArgumentException if {@code k} is less than 1
     */
    public void setK(int k) {
        // k < 1 leaves no neighbor weights, and the weight normalization yields NaN
        if (k < 1) {
            throw new IllegalArgumentException("k must be a positive number of neighbors, not " + k);
        }
        this.k = k;
    }

    /**
     * @return the number of nearest neighbors considered for each point
     */
    public int getK() {
        return this.k;
    }

    /**
     * Sets the convergence tolerance. Iteration stops once the total membership
     * change of an iteration differs from that of the previous iteration by less
     * than this value.
     *
     * @param eps the convergence tolerance
     * @throws IllegalArgumentException if {@code eps} is NaN
     */
    public void setEps(double eps) {
        if (Double.isNaN(eps)) {
            throw new IllegalArgumentException("Eps can not be NaN");
        }
        this.eps = eps;
    }

    /**
     * @return the convergence tolerance
     */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets how many standard deviations above the mean a point's neighbor
     * distance sum must be for it to be considered an outlier.
     *
     * @param stndDevs a non negative, finite number of standard deviations
     * @throws IllegalArgumentException if {@code stndDevs} is negative, infinite or NaN
     */
    public void setStndDevs(double stndDevs) {
        if (stndDevs < 0.0 || Double.isInfinite(stndDevs) || Double.isNaN(stndDevs)) {
            throw new IllegalArgumentException("Standard Deviations must be non negative");
        }
        this.stndDevs = stndDevs;
    }

    /**
     * @return the number of standard deviations used for the outlier threshold
     */
    public double getStndDevs() {
        return this.stndDevs;
    }

    /**
     * @param dm the distance metric used for the nearest neighbor searches
     */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /**
     * @return the distance metric used for the nearest neighbor searches
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * @param vectorCollectionFactory the factory creating the collection used for nearest neighbor searches
     */
    public void setVectorCollectionFactory(VectorCollectionFactory<VecPaired<Vec, Integer>> vectorCollectionFactory) {
        this.vectorCollectionFactory = vectorCollectionFactory;
    }

    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return this.cluster(dataSet, new FakeExecutor(), designations);
    }

    /**
     * Clusters {@code dataSet}.
     *
     * @param dataSet the data to cluster
     * @param threadpool the executor running the neighbor search and the membership updates
     * @param designations an array to store the result in; replaced if {@code null} or too short
     * @return the designations: a cluster id per point, or {@code -1} for outliers
     * @throws ClusterFailureException if interrupted while waiting, or if a worker task failed
     */
    @Override
    public int[] cluster(DataSet dataSet, ExecutorService threadpool, int[] designations) {
        final int n = dataSet.getSampleSize();
        if (designations == null || designations.length < n) {
            designations = new int[n];
        }
        Arrays.fill(designations, -1);
        if (n == 0) {
            return designations;
        }
        try {
            ArrayList<VecPaired<Vec, Integer>> vecs = new ArrayList<VecPaired<Vec, Integer>>(n);
            for (int i = 0; i < n; ++i) {
                vecs.add(new VecPaired<Vec, Integer>(dataSet.getDataPoint(i).getNumericalValues(), i));
            }
            TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadpool);
            VectorCollection<VecPaired<Vec, Integer>> vc;
            List<List<VecPaired<VecPaired<Vec, Integer>, Double>>> allNNs;
            // k + 1 neighbors are requested because index 0 of each list is treated as the point itself
            if (threadpool instanceof FakeExecutor) {
                vc = this.vectorCollectionFactory.getVectorCollection(vecs, this.dm);
                allNNs = VectorCollectionUtils.allNearestNeighbors(vc, vecs, this.k + 1);
            } else {
                vc = this.vectorCollectionFactory.getVectorCollection(vecs, this.dm, threadpool);
                allNNs = VectorCollectionUtils.allNearestNeighbors(vc, vecs, this.k + 1, threadpool);
            }

            // Step 1: densities, neighbor weights, CSOs and outliers
            double[] density = new double[n];
            double[][] weights = new double[n][this.k];
            OnLineStatistics densityStats = new OnLineStatistics();
            computeDensityAndWeights(allNNs, density, weights, densityStats);

            IntSet outliers = new IntSet();
            double threshold = densityStats.getMean() + densityStats.getStandardDeviation() * this.stndDevs;
            HashMap<Integer, Integer> csoIds = findCSOs(allNNs, density, threshold, this.k, outliers);
            int numClusters = csoIds.size();

            // Step 2: initial memberships; the last column is the outlier cluster
            double[][] fuzzy = new double[n][numClusters + 1];
            double[][] fuzzy2 = new double[n][];
            for (int i = 0; i < n; ++i) {
                Integer cso = csoIds.get(i);
                if (cso != null) {
                    fuzzy[i][cso] = 1.0;
                } else if (outliers.contains(i)) {
                    fuzzy[i][numClusters] = 1.0;
                } else {
                    Arrays.fill(fuzzy[i], 1.0 / (double) (numClusters + 1));
                }
                // CSO and outlier rows are never rewritten, so both buffers must
                // start out holding them, or they would be lost after the first swap.
                fuzzy2[i] = fuzzy[i].clone();
            }

            // Step 3: iteratively approximate the memberships of all other points
            double prevScore = Double.POSITIVE_INFINITY;
            for (int iter = 0; iter < this.maxIterations; ++iter) {
                double score = updateMemberships(fuzzy, fuzzy2, weights, allNNs, csoIds, outliers, threadpool);
                // swap first so that 'fuzzy' always holds the newest memberships
                double[][] tmp = fuzzy;
                fuzzy = fuzzy2;
                fuzzy2 = tmp;
                if (Math.abs(prevScore - score) < this.eps) {
                    break;
                }
                prevScore = score;
            }

            // Step 4: hard assignments
            assignClusters(fuzzy, weights, allNNs, numClusters, designations);
            return designations;
        }
        catch (InterruptedException interruptedException) {
            Thread.currentThread().interrupt();
            ClusterFailureException failure = new ClusterFailureException();
            failure.initCause(interruptedException);
            throw failure;
        }
    }

    /**
     * Returns how many entries of a neighbor list are usable. Index 0 is
     * skipped because it is treated as the query point, and the count is capped
     * at {@code k}, so entries 1..result are valid.
     */
    private static int neighborCount(List<?> knns, int k) {
        return Math.max(0, Math.min(knns.size() - 1, k));
    }

    /**
     * Fills {@code density[i]} with the sum of the distances from point i to its
     * neighbors, and adds each sum to {@code densityStats}. It also fills
     * {@code weights[i]} with inverse squared distances, normalized to sum to
     * one. Each inverse squared distance is capped so that zero distances stay
     * finite. Absent neighbors keep a weight of zero.
     */
    private static void computeDensityAndWeights(List<List<VecPaired<VecPaired<Vec, Integer>, Double>>> allNNs,
            double[] density, double[][] weights, OnLineStatistics densityStats) {
        for (int i = 0; i < density.length; ++i) {
            List<VecPaired<VecPaired<Vec, Integer>, Double>> knns = allNNs.get(i);
            double[] w = weights[i];
            int neighbors = neighborCount(knns, w.length);
            // the cap keeps the sum of up to k capped weights below Double.MAX_VALUE
            double maxWeight = Double.MAX_VALUE / (double) (w.length + 1);
            double sum = 0.0;
            for (int j = 0; j < neighbors; ++j) {
                double dist = knns.get(j + 1).getPair();
                density[i] += dist;
                w[j] = Math.min(1.0 / (dist * dist), maxWeight);
                sum += w[j];
            }
            densityStats.add(density[i]);
            if (sum > 0.0) {
                for (int j = 0; j < neighbors; ++j) {
                    w[j] /= sum;
                }
            }
        }
    }

    /**
     * Classifies points as CSO candidates or outliers. Outliers are added to
     * {@code outliers}. CSO candidates with an outlier among their neighbors are
     * dropped.
     *
     * @return a map from each surviving CSO's point index to its cluster id;
     *         ids are consecutive from 0 in increasing point order
     */
    private static HashMap<Integer, Integer> findCSOs(List<List<VecPaired<VecPaired<Vec, Integer>, Double>>> allNNs,
            double[] density, double threshold, int k, IntSet outliers) {
        ArrayList<Integer> candidates = new ArrayList<Integer>();
        for (int i = 0; i < density.length; ++i) {
            List<VecPaired<VecPaired<Vec, Integer>, Double>> knns = allNNs.get(i);
            int neighbors = neighborCount(knns, k);
            boolean lowest = true;
            boolean highest = true;
            for (int j = 1; j <= neighbors && (lowest || highest); ++j) {
                int jNN = knns.get(j).getVector().getPair();
                if (density[i] > density[jNN]) {
                    lowest = false;
                } else {
                    highest = false;
                }
            }
            if (lowest) {
                candidates.add(i);
            } else if (highest && density[i] > threshold) {
                outliers.add(i);
            }
        }
        HashMap<Integer, Integer> csoIds = new HashMap<Integer, Integer>();
        for (int i : candidates) {
            if (!hasOutlierNeighbor(allNNs.get(i), k, outliers)) {
                csoIds.put(i, csoIds.size());
            }
        }
        return csoIds;
    }

    /** Returns {@code true} if any usable neighbor in {@code knns} is in {@code outliers}. */
    private static boolean hasOutlierNeighbor(List<VecPaired<VecPaired<Vec, Integer>, Double>> knns, int k,
            IntSet outliers) {
        int neighbors = neighborCount(knns, k);
        for (int j = 1; j <= neighbors; ++j) {
            if (outliers.contains(knns.get(j).getVector().getPair())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Performs one membership approximation pass. Each point that is neither a
     * CSO nor an outlier gets a new row in {@code to}, computed from the rows in
     * {@code from}. The work is split across {@link SystemInfo#LogicalCores}
     * tasks submitted to {@code threadpool}.
     *
     * @return the summed absolute change of all rewritten memberships
     * @throws InterruptedException if interrupted while waiting for the tasks
     * @throws ClusterFailureException if a task threw a RuntimeException
     */
    private static double updateMemberships(final double[][] from, final double[][] to, final double[][] weights,
            final List<List<VecPaired<VecPaired<Vec, Integer>, Double>>> allNNs,
            final HashMap<Integer, Integer> csoIds, final IntSet outliers, ExecutorService threadpool)
            throws InterruptedException {
        final int workers = SystemInfo.LogicalCores;
        final AtomicDoubleArray score = new AtomicDoubleArray(1);
        final CountDownLatch latch = new CountDownLatch(workers);
        // written before countDown() and read after await(), which orders the accesses
        final RuntimeException[] failure = new RuntimeException[1];
        for (int w = 0; w < workers; ++w) {
            final int id = w;
            threadpool.submit(new Runnable() {

                @Override
                public void run() {
                    try {
                        double localScore = 0.0;
                        for (int i = id; i < from.length; i += workers) {
                            if (outliers.contains(i) || csoIds.containsKey(i)) {
                                continue;
                            }
                            localScore += approximateMembership(from, from[i], to[i], weights[i], allNNs.get(i));
                        }
                        score.addAndGet(0, localScore);
                    } catch (RuntimeException ex) {
                        failure[0] = ex;
                    } finally {
                        // always count down, or a failing task would hang await() forever
                        latch.countDown();
                    }
                }
            });
        }
        latch.await();
        if (failure[0] != null) {
            ClusterFailureException cfe = new ClusterFailureException();
            cfe.initCause(failure[0]);
            throw cfe;
        }
        return score.get(0);
    }

    /**
     * Writes into {@code target} the weighted sum of the neighbors' membership
     * rows from {@code from}, normalized so that it (approximately) sums to one.
     *
     * @return the L1 distance between {@code old} and the new memberships
     */
    private static double approximateMembership(double[][] from, double[] old, double[] target, double[] w,
            List<VecPaired<VecPaired<Vec, Integer>, Double>> knns) {
        Arrays.fill(target, 0.0);
        int neighbors = neighborCount(knns, w.length);
        for (int j = 0; j < neighbors; ++j) {
            int jNN = knns.get(j + 1).getVector().getPair();
            double[] fuzzyNN = from[jNN];
            double weight = w[j];
            for (int z = 0; z < target.length; ++z) {
                target[z] += weight * fuzzyNN[z];
            }
        }
        double sum = 0.0;
        for (int z = 0; z < target.length; ++z) {
            sum += target[z];
        }
        double change = 0.0;
        for (int z = 0; z < target.length; ++z) {
            target[z] /= (sum + NORMALIZATION_EPS);
            change += Math.abs(old[z] - target[z]);
        }
        return change;
    }

    /**
     * Converts the final fuzzy memberships into hard designations.
     * <ul>
     * <li>Each point takes the column with the largest strictly positive
     * membership. The outlier column, or all-zero memberships, give {@code -1}.</li>
     * <li>Clusters with at most one member are dissolved, and the surviving
     * clusters are renumbered consecutively.</li>
     * <li>Members of a dissolved cluster are reassigned by their neighbors'
     * weighted memberships. Only surviving clusters and the outlier column are
     * candidates.</li>
     * </ul>
     */
    private static void assignClusters(double[][] fuzzy, double[][] weights,
            List<List<VecPaired<VecPaired<Vec, Integer>, Double>>> allNNs, int numClusters, int[] designations) {
        int n = fuzzy.length;
        int[] clusterSizes = new int[numClusters];
        for (int i = 0; i < n; ++i) {
            int pos = -1;
            double maxVal = 0.0;
            for (int z = 0; z < fuzzy[i].length; ++z) {
                if (fuzzy[i][z] > maxVal) {
                    maxVal = fuzzy[i][z];
                    pos = z;
                }
            }
            if (pos == -1 || pos == numClusters) {
                designations[i] = -1;
            } else {
                designations[i] = pos;
                clusterSizes[pos]++;
            }
        }

        int[] newId = new int[numClusters];
        int survivors = 0;
        for (int c = 0; c < numClusters; ++c) {
            newId[c] = clusterSizes[c] > 1 ? survivors++ : -1;
        }
        if (survivors == numClusters) {
            return; // nothing dissolved: the renumbering would be the identity
        }

        double[] votes = new double[numClusters + 1];
        for (int i = 0; i < n; ++i) {
            int d = designations[i];
            if (d < 0) {
                continue; // outliers stay outliers
            }
            if (newId[d] != -1) {
                designations[i] = newId[d];
                continue;
            }
            // the votes must be cleared for every point
            Arrays.fill(votes, 0.0);
            List<VecPaired<VecPaired<Vec, Integer>, Double>> knns = allNNs.get(i);
            double[] w = weights[i];
            int neighbors = neighborCount(knns, w.length);
            for (int j = 0; j < neighbors; ++j) {
                int jNN = knns.get(j + 1).getVector().getPair();
                double[] fuzzyNN = fuzzy[jNN];
                for (int z = 0; z < votes.length; ++z) {
                    votes[z] += w[j] * fuzzyNN[z];
                }
            }
            int best = -1;
            double bestVal = -1.0;
            for (int z = 0; z <= numClusters; ++z) {
                if (z < numClusters && newId[z] == -1) {
                    continue; // dissolved clusters can not receive points
                }
                if (votes[z] > bestVal) {
                    bestVal = votes[z];
                    best = z;
                }
            }
            designations[i] = (best < 0 || best == numClusters) ? -1 : newId[best];
        }
    }

    @Override
    public FLAME clone() {
        return new FLAME(this);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

