/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.ClusterFailureException;
import jsat.clustering.KClustererBase;
import jsat.clustering.PAM;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.SystemInfo;
import jsat.utils.random.XORWOW;

public abstract class KMeans
extends KClustererBase
implements Parameterized {
    private static final long serialVersionUID = 8730927112084289722L;
    /** Seed selection method exposed as the library default. */
    public static final SeedSelectionMethods.SeedSelection DEFAULT_SEED_SELECTION = SeedSelectionMethods.SeedSelection.KPP;
    /** Smallest number of clusters tried when k is chosen automatically. */
    private static final int DEFAULT_LOW_K = 2;
    @Parameter.ParameterHolder
    protected DistanceMetric dm;
    protected SeedSelectionMethods.SeedSelection seedSelection;
    protected Random rand;
    /** When false, {@link #means} is cleared after each clustering call. */
    protected boolean storeMeans = true;
    protected boolean saveCentroidDistance = true;
    protected double[] nearestCentroidDist;
    /** Cluster centers from the most recent fixed-k clustering, or null. */
    protected List<Vec> means;
    /** Maximum number of iterations; defaults to effectively unlimited. */
    protected int MaxIterLimit = Integer.MAX_VALUE;

    /**
     * Creates a new k-means clusterer.
     *
     * @param dm the distance metric to use
     * @param seedSelection the method used to choose the initial seeds
     * @param rand the source of randomness
     */
    public KMeans(DistanceMetric dm, SeedSelectionMethods.SeedSelection seedSelection, Random rand) {
        this.dm = dm;
        this.setSeedSelection(seedSelection);
        this.rand = rand;
    }

    /**
     * Copy constructor. Deep-copies the distance metric, the cached centroid
     * distances and the stored means. It also carries over all configuration
     * flags so a clone behaves like the original. A fresh random source is
     * created rather than shared.
     *
     * @param toCopy the object to copy
     */
    public KMeans(KMeans toCopy) {
        this.dm = toCopy.dm.clone();
        this.seedSelection = toCopy.seedSelection;
        this.rand = new XORWOW();
        // Preserve user configuration; previously these silently reset to defaults on clone().
        this.storeMeans = toCopy.storeMeans;
        this.saveCentroidDistance = toCopy.saveCentroidDistance;
        this.MaxIterLimit = toCopy.MaxIterLimit;
        if (toCopy.nearestCentroidDist != null) {
            this.nearestCentroidDist = Arrays.copyOf(toCopy.nearestCentroidDist, toCopy.nearestCentroidDist.length);
        }
        if (toCopy.means != null) {
            this.means = new ArrayList<Vec>(toCopy.means.size());
            for (Vec v : toCopy.means) {
                this.means.add(v.clone());
            }
        }
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param iterLimit the iteration limit; must be at least 1
     * @throws IllegalArgumentException if {@code iterLimit < 1}
     */
    public void setIterationLimit(int iterLimit) {
        if (iterLimit < 1) {
            throw new IllegalArgumentException("Iterations must be a positive value, not " + iterLimit);
        }
        this.MaxIterLimit = iterLimit;
    }

    /**
     * @return the maximum number of iterations ({@link Integer#MAX_VALUE} unless changed)
     */
    public int getIterationLimit() {
        return this.MaxIterLimit;
    }

    /**
     * Controls whether the cluster centers are kept after clustering.
     *
     * @param storeMeans {@code true} to keep the means available through
     *        {@link #getMeans()}, {@code false} to discard them
     */
    public void setStoreMeans(boolean storeMeans) {
        this.storeMeans = storeMeans;
    }

    /**
     * Returns the internal (not copied) list of cluster centers from the most
     * recent clustering.
     *
     * @return the means, or {@code null} if none have been computed or
     *         {@link #setStoreMeans(boolean)} was set to {@code false}
     */
    public List<Vec> getMeans() {
        return this.means;
    }

    /**
     * @param seedSelection the method used to choose the initial seeds
     */
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        this.seedSelection = seedSelection;
    }

    /**
     * @return the method used to choose the initial seeds
     */
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    /**
     * @return the distance metric in use
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Runs the concrete k-means variant for a fixed number of clusters.
     *
     * @param dataSet the data to cluster
     * @param accelCache NOTE(review): always {@code null} from this class; its
     *        meaning is defined by subclasses — confirm
     * @param k the number of clusters to form
     * @param means list that receives the cluster centers (callers in this
     *        class always pass an empty list)
     * @param assignment array that receives the cluster index of each point
     * @param exactTotal NOTE(review): {@code true} only when the return value
     *        is used to compare values of k; presumably requests an exact total — confirm
     * @param threadpool executor to use; {@code null} is passed for the
     *        single-threaded paths (presumably meaning "run in the caller") — confirm
     * @param returnError NOTE(review): {@code true} only when the return value
     *        is used; presumably asks for the clustering error to be computed — confirm
     * @param dataPointWeights NOTE(review): always {@code null} from this class;
     *        presumably optional per-point weights — confirm
     * @return a score this class treats as the total distance of the clustering
     *         (lower is tighter); used to choose k automatically
     */
    protected abstract double cluster(DataSet dataSet, List<Double> accelCache, int k, List<Vec> means, int[] assignment, boolean exactTotal, ExecutorService threadpool, boolean returnError, Vec dataPointWeights);

    /**
     * Creates a list of {@code k} empty, independent lists.
     *
     * @param k the number of lists to create
     * @return a mutable list holding {@code k} empty mutable lists
     */
    protected static List<List<DataPoint>> getListOfLists(int k) {
        List<List<DataPoint>> ks = new ArrayList<List<DataPoint>>(k);
        for (int i = 0; i < k; ++i) {
            ks.add(new ArrayList<DataPoint>());
        }
        return ks;
    }

    /**
     * Upper bound for automatic k selection: {@code floor(sqrt(n / 2))} (with
     * integer {@code n / 2}). It is clamped so it never falls below
     * {@link #DEFAULT_LOW_K}; otherwise tiny data sets would produce an empty
     * or negative-sized search range.
     */
    private static int defaultHighK(DataSet dataSet) {
        return Math.max(DEFAULT_LOW_K, (int) Math.sqrt(dataSet.getSampleSize() / 2));
    }

    /**
     * Clusters the data, choosing k automatically from the range
     * {@code [2, max(2, floor(sqrt(n/2)))]}.
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return this.cluster(dataSet, DEFAULT_LOW_K, defaultHighK(dataSet), designations);
    }

    /**
     * Clusters the data in parallel, choosing k automatically from the range
     * {@code [2, max(2, floor(sqrt(n/2)))]}.
     */
    @Override
    public int[] cluster(DataSet dataSet, ExecutorService threadpool, int[] designations) {
        return this.cluster(dataSet, DEFAULT_LOW_K, defaultHighK(dataSet), threadpool, designations);
    }

    /**
     * Clusters the data into exactly {@code clusters} groups using the given
     * executor.
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, ExecutorService threadpool, int[] designations) {
        return this.clusterFixedK(dataSet, clusters, threadpool, designations);
    }

    /**
     * Clusters the data into exactly {@code clusters} groups in the calling
     * thread.
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, int[] designations) {
        return this.clusterFixedK(dataSet, clusters, null, designations);
    }

    /**
     * Shared implementation of the fixed-k overloads. It validates the input,
     * runs the abstract clustering into a fresh {@link #means} list, and drops
     * the means afterwards if {@link #storeMeans} is false.
     */
    private int[] clusterFixedK(DataSet dataSet, int clusters, ExecutorService threadpool, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < clusters) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        this.means = new ArrayList<Vec>(clusters);
        this.cluster(dataSet, null, clusters, this.means, designations, false, threadpool, false, null);
        if (!this.storeMeans) {
            this.means = null;
        }
        return designations;
    }

    /**
     * Chooses k from {@code [lowK, highK]}, evaluating candidates in parallel,
     * then produces the final clustering for the chosen k.
     * <p>
     * {@link SystemInfo#LogicalCores} workers each evaluate one candidate k at
     * a time on {@code threadpool}. The final clustering for the selected k
     * runs in the calling thread. If the calling thread is interrupted while
     * waiting, the wait continues and the interrupt status is restored before
     * this method returns.
     *
     * @throws ClusterFailureException if there are fewer points than {@code highK}
     * @throws RuntimeException any unchecked exception raised while evaluating a candidate k
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, ExecutorService threadpool, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < highK) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        double[] totDistances = new double[highK - lowK + 1];
        BlockingQueue<ClusterWorker> workerQue = new ArrayBlockingQueue<ClusterWorker>(SystemInfo.LogicalCores);
        for (int i = 0; i < SystemInfo.LogicalCores; ++i) {
            workerQue.add(new ClusterWorker(dataSet, workerQue));
        }
        boolean interrupted = false;
        try {
            int k = lowK;
            int received = 0;
            while (received < totDistances.length) {
                ClusterWorker worker;
                try {
                    worker = workerQue.take();
                } catch (InterruptedException ex) {
                    // Keep waiting so the result stays complete, but remember to restore the flag.
                    interrupted = true;
                    Logger.getLogger(KMeans.class.getName()).log(Level.SEVERE, null, ex);
                    continue;
                }
                // Surface a worker failure instead of waiting forever for a result that never comes.
                worker.rethrowFailure();
                if (worker.hasResult()) {
                    totDistances[worker.getK() - lowK] = worker.getResult();
                    ++received;
                }
                if (k <= highK) {
                    worker.setK(k++);
                    threadpool.submit(worker);
                }
            }
            return this.findK(lowK, highK, totDistances, dataSet, designations);
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Chooses k from {@code [lowK, highK]} sequentially in the calling thread,
     * then produces the final clustering for the chosen k. {@code designations}
     * is reused as scratch space for every candidate.
     *
     * @throws ClusterFailureException if there are fewer points than {@code highK}
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < highK) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        double[] totDistances = new double[highK - lowK + 1];
        for (int i = lowK; i <= highK; ++i) {
            totDistances[i - lowK] = this.cluster(dataSet, null, i, new ArrayList<Vec>(), designations, true, null, true, null);
        }
        return this.findK(lowK, highK, totDistances, dataSet, designations);
    }

    /**
     * Picks k from the total distances of each candidate, then clusters with it.
     * <p>
     * It computes the absolute change in total distance between consecutive
     * k values. If the largest change is not at least two standard deviations
     * above the mean change, {@code lowK} is chosen. Otherwise the first k whose
     * change is smaller than the largest change is chosen.
     *
     * @param totDistances {@code totDistances[i]} is the score for {@code k = lowK + i}
     * @return the designations produced by clustering with the chosen k
     */
    private int[] findK(int lowK, int highK, double[] totDistances, DataSet dataSet, int[] designations) {
        OnLineStatistics stats = new OnLineStatistics();
        // Double.MIN_VALUE is the smallest *positive* double, so zero changes never replace it.
        double maxChange = Double.MIN_VALUE;
        int maxChangeK = lowK;
        for (int i = 1; i < totDistances.length; ++i) {
            double change = Math.abs(totDistances[i] - totDistances[i - 1]);
            stats.add(change);
            if (change > maxChange) {
                maxChange = change;
                maxChangeK = i + lowK;
            }
        }
        double changeMean = stats.getMean();
        double changeDev = stats.getStandardDeviation();
        if (maxChange < changeDev * 2.0 + changeMean) {
            // No unusually large drop in total distance: fall back to the smallest k.
            maxChangeK = lowK;
        } else {
            for (int i = 1; i < totDistances.length; ++i) {
                double tmp = Math.abs(totDistances[i] - totDistances[i - 1]);
                if (tmp < maxChange) {
                    maxChangeK = i + lowK;
                    break;
                }
            }
        }
        return this.cluster(dataSet, maxChangeK, designations);
    }

    @Override
    public abstract KMeans clone();

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Evaluates one candidate k on a pool thread and then puts itself back on
     * the shared queue. It always re-enqueues itself, even on failure, so the
     * coordinating thread never blocks forever.
     */
    private class ClusterWorker
    implements Runnable {
        private final DataSet dataSet;
        private final int[] clusterIDs;
        private final BlockingQueue<ClusterWorker> putSelf;
        private int k;
        /** -1.0 until the first successful run; distances are never negative. */
        private volatile double result = -1.0;
        /** Throwable raised by the last run, or null. */
        private volatile Throwable failure;

        ClusterWorker(DataSet dataSet, BlockingQueue<ClusterWorker> que) {
            this.dataSet = dataSet;
            this.k = DEFAULT_LOW_K;
            this.putSelf = que;
            this.clusterIDs = new int[dataSet.getSampleSize()];
        }

        void setK(int k) {
            this.k = k;
        }

        int getK() {
            return this.k;
        }

        double getResult() {
            return this.result;
        }

        boolean hasResult() {
            return this.result != -1.0;
        }

        /** Rethrows the failure from the last run in the caller's thread, if any. */
        void rethrowFailure() {
            Throwable t = this.failure;
            if (t == null) {
                return;
            }
            if (t instanceof RuntimeException) {
                throw (RuntimeException) t;
            }
            if (t instanceof Error) {
                throw (Error) t;
            }
            throw new ClusterFailureException("Clustering failed for k=" + this.k + ": " + t);
        }

        @Override
        public void run() {
            try {
                this.result = KMeans.this.cluster(this.dataSet, null, this.k, new ArrayList<Vec>(), this.clusterIDs, true, null, true, null);
            } catch (Throwable t) {
                // Task boundary: hand the failure to the coordinating thread instead of losing it.
                this.failure = t;
            } finally {
                this.putSelf.add(this);
            }
        }
    }
}

