/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.KClustererBase;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.random.RandomUtil;

/**
 * Mini-batch k-means clusterer. Each iteration draws a random batch of data
 * point indices, assigns every batch member to its nearest current mean, and
 * then moves each mean toward its assigned points using a per-mean learning
 * rate of {@code 1 / (number of points assigned to that mean so far)}.
 * After the last iteration every point of the data set is assigned to its
 * nearest mean.
 * <p>
 * Only the overloads that take an explicit number of clusters are implemented;
 * the remaining {@code cluster} overloads throw
 * {@link UnsupportedOperationException}.
 */
public class MiniBatchKMeans
extends KClustererBase {
    private static final long serialVersionUID = 412553399508594014L;
    /** Requested number of points sampled per iteration (capped at the data set size when clustering). */
    private int batchSize;
    /** Number of mini-batch update rounds to perform. */
    private int iterations;
    /** Metric used to find the nearest mean. */
    private DistanceMetric dm;
    /** Strategy used to pick the initial means. */
    private SeedSelectionMethods.SeedSelection seedSelection;
    /** Whether the learned means are kept after clustering; defaults to {@code true}. */
    private boolean storeMeans = true;
    /** Means from the most recent clustering run, or {@code null} if none are stored. */
    private List<Vec> means;

    /**
     * Creates a clusterer that uses {@link EuclideanDistance} and the
     * {@code KPP} seed selection.
     *
     * @param batchSize  number of points sampled per iteration, must be positive
     * @param iterations number of update rounds, must be positive
     * @throws ArithmeticException if either value is less than 1
     */
    public MiniBatchKMeans(int batchSize, int iterations) {
        this(new EuclideanDistance(), batchSize, iterations);
    }

    /**
     * Creates a clusterer that uses the given metric and the {@code KPP} seed
     * selection.
     *
     * @param dm         distance metric to use
     * @param batchSize  number of points sampled per iteration, must be positive
     * @param iterations number of update rounds, must be positive
     * @throws ArithmeticException if {@code batchSize} or {@code iterations} is less than 1
     */
    public MiniBatchKMeans(DistanceMetric dm, int batchSize, int iterations) {
        this(dm, batchSize, iterations, SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Creates a fully specified clusterer.
     *
     * @param dm            distance metric to use
     * @param batchSize     number of points sampled per iteration, must be positive
     * @param iterations    number of update rounds, must be positive
     * @param seedSelection strategy used to pick the initial means
     * @throws ArithmeticException if {@code batchSize} or {@code iterations} is less than 1
     */
    public MiniBatchKMeans(DistanceMetric dm, int batchSize, int iterations, SeedSelectionMethods.SeedSelection seedSelection) {
        this.setBatchSize(batchSize);
        this.setIterations(iterations);
        this.setDistanceMetric(dm);
        this.setSeedSelection(seedSelection);
    }

    /**
     * Copy constructor. The distance metric and every stored mean are cloned;
     * the remaining settings are copied by value.
     *
     * @param toCopy the clusterer to copy
     */
    public MiniBatchKMeans(MiniBatchKMeans toCopy) {
        this.batchSize = toCopy.batchSize;
        this.iterations = toCopy.iterations;
        this.dm = toCopy.dm.clone();
        this.seedSelection = toCopy.seedSelection;
        this.storeMeans = toCopy.storeMeans;
        if (toCopy.means != null) {
            this.means = new ArrayList<Vec>(toCopy.means.size());
            for (Vec v : toCopy.means) {
                this.means.add(v.clone());
            }
        }
    }

    /**
     * Controls whether the means are retained after clustering.
     *
     * @param storeMeans {@code true} to keep the means available through
     *                   {@link #getMeans()}, {@code false} to discard them
     *                   when a clustering run finishes
     */
    public void setStoreMeans(boolean storeMeans) {
        this.storeMeans = storeMeans;
    }

    /**
     * Returns the internal list of means (not a copy).
     *
     * @return the stored means, or {@code null} if no run has completed or
     *         the means were discarded
     */
    public List<Vec> getMeans() {
        return this.means;
    }

    /**
     * Sets the distance metric used to find the nearest mean.
     *
     * @param dm the metric to use
     */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /**
     * @return the distance metric in use
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Sets the number of points sampled per iteration. When clustering, the
     * effective batch size is capped at the number of points in the data set.
     *
     * @param batchSize the batch size, must be positive
     * @throws ArithmeticException if {@code batchSize} is less than 1
     */
    public void setBatchSize(int batchSize) {
        if (batchSize < 1) {
            throw new ArithmeticException("Batch size must be a positive value, not " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /**
     * @return the requested batch size
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the number of mini-batch update rounds.
     *
     * @param iterations the number of rounds, must be positive
     * @throws ArithmeticException if {@code iterations} is less than 1
     */
    public void setIterations(int iterations) {
        if (iterations < 1) {
            throw new ArithmeticException("Iterations must be a positive value, not " + iterations);
        }
        this.iterations = iterations;
    }

    /**
     * @return the number of update rounds
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the strategy used to pick the initial means.
     *
     * @param seedSelection the seed selection strategy
     */
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        this.seedSelection = seedSelection;
    }

    /**
     * @return the seed selection strategy
     */
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    /**
     * Not supported: the number of clusters must be given explicitly.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Not supported: the number of clusters must be given explicitly.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, ExecutorService threadpool, int[] designations) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Clusters the data set into the requested number of clusters, running the
     * nearest-mean searches on the given executor.
     * <p>
     * A task that fails is logged and clustering continues; the affected
     * entries then keep whatever value they previously held. If the calling
     * thread is interrupted while waiting, the method still waits for the
     * submitted tasks (they write into shared arrays) and restores the
     * interrupt status before returning.
     *
     * @param dataSet      the data to cluster
     * @param clusters     the number of clusters to form
     * @param threadpool   executor for the parallel work; {@code null} makes
     *                     the work run through a {@link FakeExecutor}
     * @param designations array to receive the cluster index of every data
     *                     point, or {@code null} to allocate a new one
     * @return the array holding the cluster index of every data point
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, ExecutorService threadpool, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        if (threadpool == null) {
            // Same executor the pool-less overload uses; avoids an NPE on submit below.
            threadpool = new FakeExecutor();
        }
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadpool);
        final List<Vec> source = dataSet.getDataVectors();
        final List<Double> distCache = threadpool instanceof FakeExecutor
                ? this.dm.getAccelerationCache(source)
                : this.dm.getAccelerationCache(source, threadpool);
        this.means = SeedSelectionMethods.selectIntialPoints(dataSet, clusters, this.dm, distCache, RandomUtil.getRandom(), this.seedSelection, threadpool);

        // Per-mean query info for the metric; empty placeholders when unsupported.
        final boolean accelerated = this.dm.supportsAcceleration();
        final List<List<Double>> meanQIs = new ArrayList<List<Double>>(this.means.size());
        for (Vec mean : this.means) {
            meanQIs.add(accelerated ? this.dm.getQueryInfo(mean) : Collections.<Double>emptyList());
        }

        // Number of points each mean has absorbed so far; drives the learning rate.
        final int[] counts = new int[this.means.size()];
        final int sampleSize = dataSet.getSampleSize();
        final int usedBatchSize = Math.min(this.batchSize, sampleSize);
        final IntList M = new IntList(usedBatchSize);
        final IntList allIndx = new IntList(source.size());
        ListUtils.addRange(allIndx, 0, source.size(), 1);
        final int[] nearestCenter = new int[usedBatchSize];
        final int cores = SystemInfo.LogicalCores;

        for (int iter = 0; iter < this.iterations; ++iter) {
            M.clear();
            ListUtils.randomSample(allIndx, M, usedBatchSize);

            // Split the batch into contiguous blocks. Fewer than 'cores' tasks are
            // created when the batch is small, so completion is tracked per task.
            List<Future<?>> pending = new ArrayList<Future<?>>(cores);
            int blockSize = usedBatchSize / cores;
            int extra = usedBatchSize % cores;
            int start = 0;
            while (start < usedBatchSize) {
                final int from = start;
                final int to = start + blockSize + (extra-- > 0 ? 1 : 0);
                start = to;
                pending.add(threadpool.submit(new Runnable() {
                    @Override
                    public void run() {
                        for (int i = from; i < to; ++i) {
                            nearestCenter[i] = MiniBatchKMeans.this.nearestMeanIndex(M.get(i), source, distCache, meanQIs);
                        }
                    }
                }));
            }
            awaitAll(pending);

            // Move each assigned mean toward its batch point: c = (1 - eta) * c + eta * x
            for (int j = 0; j < M.size(); ++j) {
                int c_i = nearestCenter[j];
                double eta = 1.0 / (double) (++counts[c_i]);
                Vec c = this.means.get(c_i);
                c.mutableMultiply(1.0 - eta);
                c.mutableAdd(eta, source.get(M.get(j)));
            }

            if (accelerated) {
                // The means moved, so their query info must be recomputed.
                for (int i = 0; i < this.means.size(); ++i) {
                    meanQIs.set(i, this.dm.getQueryInfo(this.means.get(i)));
                }
            }
        }

        // Final pass: assign every data point to its nearest mean.
        final int[] des = designations;
        List<Future<?>> pending = new ArrayList<Future<?>>(cores);
        int blockSize = sampleSize / cores;
        int extra = sampleSize % cores;
        int start = 0;
        while (start < sampleSize) {
            final int from = start;
            final int to = start + blockSize + (extra-- > 0 ? 1 : 0);
            start = to;
            pending.add(threadpool.submit(new Runnable() {
                @Override
                public void run() {
                    for (int i = from; i < to; ++i) {
                        des[i] = MiniBatchKMeans.this.nearestMeanIndex(i, source, distCache, meanQIs);
                    }
                }
            }));
        }
        awaitAll(pending);

        if (!this.storeMeans) {
            this.means = null;
        }
        return des;
    }

    /**
     * Clusters the data set into the requested number of clusters, running all
     * work through a {@link FakeExecutor}.
     *
     * @param dataSet      the data to cluster
     * @param clusters     the number of clusters to form
     * @param designations array to receive the cluster index of every data
     *                     point, or {@code null} to allocate a new one
     * @return the array holding the cluster index of every data point
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, int[] designations) {
        return this.cluster(dataSet, clusters, new FakeExecutor(), designations);
    }

    /**
     * Not supported: a range for the number of clusters cannot be searched.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, ExecutorService threadpool, int[] designations) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Not supported: a range for the number of clusters cannot be searched.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, int[] designations) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * @return a copy of this clusterer made with the copy constructor
     */
    @Override
    public MiniBatchKMeans clone() {
        return new MiniBatchKMeans(this);
    }

    /**
     * Finds the current mean closest to one data point.
     *
     * @param dataIndex index of the data point within {@code source}
     * @param source    all data vectors
     * @param distCache acceleration cache of the metric for {@code source}
     * @param meanQIs   query info of the metric for every mean
     * @return the index of the nearest mean, or {@code -1} if no distance
     *         compared less than positive infinity
     */
    private int nearestMeanIndex(int dataIndex, List<Vec> source, List<Double> distCache, List<List<Double>> meanQIs) {
        double minDist = Double.POSITIVE_INFINITY;
        int min = -1;
        for (int j = 0; j < this.means.size(); ++j) {
            double tmp = this.dm.dist(dataIndex, this.means.get(j), meanQIs.get(j), source, distCache);
            if (tmp < minDist) {
                minDist = tmp;
                min = j;
            }
        }
        return min;
    }

    /**
     * Blocks until every given task has finished. Failures are logged and do
     * not stop the wait. An interrupt is logged and remembered, the wait
     * continues, and the interrupt status is restored once all tasks are done.
     *
     * @param tasks the tasks to wait for
     */
    private static void awaitAll(List<Future<?>> tasks) {
        boolean interrupted = false;
        for (Future<?> task : tasks) {
            boolean finished = false;
            while (!finished) {
                try {
                    task.get();
                    finished = true;
                }
                catch (InterruptedException ex) {
                    // Keep waiting: the task still writes into arrays the caller reads next.
                    interrupted = true;
                    Logger.getLogger(MiniBatchKMeans.class.getName()).log(Level.SEVERE, null, ex);
                }
                catch (ExecutionException ex) {
                    Logger.getLogger(MiniBatchKMeans.class.getName()).log(Level.SEVERE, null, ex);
                    finished = true;
                }
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
}

