/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.ClustererBase;
import jsat.distributions.empirical.kernelfunc.GaussKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.distributions.multivariate.MetricKDE;
import jsat.distributions.multivariate.MultivariateKDE;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.utils.PoisonRunnable;
import jsat.utils.RunnableConsumer;
import jsat.utils.SystemInfo;

/**
 * Mean shift clustering.
 * <p>
 * Every data point is repeatedly moved to the kernel-weighted mean of the
 * neighbours returned by a {@link MultivariateKDE} (weights are
 * {@code -K'(d)} for the KDE's kernel function), until it stops moving or the
 * iteration budget runs out. Points whose final positions lie within
 * {@code 1e-3} of each other (L2 distance) receive the same cluster id.
 * Points whose neighbourhood contains at most one vector during a step are
 * treated as outliers and receive the designation {@code -1}.
 */
public class MeanShift
extends ClustererBase {
    private static final long serialVersionUID = 4061491342362690455L;
    public static final int DefaultMaxIterations = 1000;
    public static final double DefaultScaleBandwidthFactor = 1.0;
    private MultivariateKDE mkde;
    private int maxIterations = DefaultMaxIterations;
    private double scaleBandwidthFactor = DefaultScaleBandwidthFactor;

    /**
     * Creates a mean shift clusterer that uses the Euclidean distance.
     */
    public MeanShift() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a mean shift clusterer backed by a {@link MetricKDE} with a
     * Gaussian kernel and the given distance metric.
     *
     * @param dm the distance metric used to find neighbours
     */
    public MeanShift(DistanceMetric dm) {
        this(new MetricKDE(GaussKF.getInstance(), dm));
    }

    /**
     * Creates a mean shift clusterer backed by the given density estimator.
     * The estimator is used directly (not copied) and is re-fit on every call
     * to {@code cluster}.
     *
     * @param mkde the kernel density estimator that supplies neighbours and the kernel
     */
    public MeanShift(MultivariateKDE mkde) {
        this.mkde = mkde;
    }

    /**
     * Copy constructor; the density estimator is cloned.
     *
     * @param toCopy the clusterer to copy
     */
    public MeanShift(MeanShift toCopy) {
        this.mkde = toCopy.mkde.clone();
        this.maxIterations = toCopy.maxIterations;
        this.scaleBandwidthFactor = toCopy.scaleBandwidthFactor;
    }

    /**
     * Sets the maximum number of mean shift passes over the data.
     *
     * @param maxIterations the iteration limit, must be positive
     * @throws ArithmeticException if {@code maxIterations <= 0}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0) {
            throw new ArithmeticException("Invalid iteration count, " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * @return the maximum number of mean shift passes over the data
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the factor passed to {@link MultivariateKDE#scaleBandwidth} after the
     * estimator is fit to the data.
     *
     * @param scaleBandwidthFactor the bandwidth multiplier, must be finite and positive
     * @throws ArithmeticException if the factor is NaN, infinite, zero or negative
     */
    public void setScaleBandwidthFactor(double scaleBandwidthFactor) {
        // A zero or negative factor would give a degenerate (zero or negative) bandwidth.
        if (Double.isNaN(scaleBandwidthFactor) || Double.isInfinite(scaleBandwidthFactor)
                || scaleBandwidthFactor <= 0.0) {
            throw new ArithmeticException("Invalid scale factor, " + scaleBandwidthFactor);
        }
        this.scaleBandwidthFactor = scaleBandwidthFactor;
    }

    /**
     * @return the factor applied to the estimator's bandwidth
     */
    public double getScaleBandwidthFactor() {
        return this.scaleBandwidthFactor;
    }

    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return this.cluster(dataSet, null, designations);
    }

    /**
     * Clusters the data set.
     *
     * @param dataSet      the data to cluster
     * @param threadpool   pool to run the mean shift steps on, or {@code null} to run single-threaded
     * @param designations array to store results in; reused if at least as long as the
     *                     data set, otherwise a new one is allocated
     * @return the designations array: a cluster id per point, or {@code -1} for outliers
     * @throws FailedToFitException if the multi-threaded run is interrupted or its barrier breaks
     */
    @Override
    public int[] cluster(DataSet dataSet, ExecutorService threadpool, int[] designations) {
        final int n = dataSet.getSampleSize();
        if (designations == null || designations.length < n) {
            designations = new int[n];
        } else {
            // A reused array may still contain -1 (outlier) marks from an earlier run,
            // which assignmentStep would misinterpret as outliers of this run.
            Arrays.fill(designations, 0, n, 0);
        }
        if (n == 0) {
            return designations;
        }
        try {
            boolean[] converged = new boolean[n];
            KernelFunction k = this.mkde.getKernelFunction();
            if (threadpool == null) {
                this.mkde.setUsingData(dataSet);
            } else {
                this.mkde.setUsingData(dataSet, threadpool);
            }
            this.mkde.scaleBandwidth(this.scaleBandwidthFactor);
            // Working copies of the points; these are shifted in place toward the modes.
            Vec[] xit = new Vec[n];
            for (int i = 0; i < n; ++i) {
                xit[i] = dataSet.getDataPoint(i).getNumericalValues().clone();
            }
            if (threadpool == null) {
                DenseVector scratch = new DenseVector(dataSet.getNumNumericalVars());
                this.mainLoop(converged, xit, designations, scratch, k);
            } else {
                this.mainLoop(converged, xit, designations, k, threadpool);
            }
            this.assignmentStep(converged, xit, designations);
            return designations;
        }
        catch (InterruptedException ex) {
            // Restore the interrupt status so callers further up can still observe it.
            Thread.currentThread().interrupt();
            Logger.getLogger(MeanShift.class.getName()).log(Level.SEVERE, null, ex);
            throw new FailedToFitException(ex);
        }
        catch (BrokenBarrierException ex) {
            Logger.getLogger(MeanShift.class.getName()).log(Level.SEVERE, null, ex);
            throw new FailedToFitException(ex);
        }
    }

    /**
     * Groups the shifted points into clusters.
     * <p>
     * On entry every {@code converged} entry is expected to be {@code true}; the array
     * is repurposed as an "not yet assigned" flag and cleared as points get assigned.
     * Outliers ({@code designations[i] == -1}) are never used as a cluster base and
     * keep their {@code -1}. Each remaining unassigned point in index order starts a
     * new cluster that also takes every later unassigned point within {@code 1e-3}.
     */
    private void assignmentStep(boolean[] converged, Vec[] xit, int[] designations) {
        int curClusterID = 0;
        for (int base = 0; base < converged.length; ++base) {
            if (!converged[base] || designations[base] == -1) {
                continue;
            }
            // Assign the base unconditionally so progress is guaranteed even if its
            // position is NaN (a NaN distance to itself is never below the threshold).
            converged[base] = false;
            designations[base] = curClusterID;
            for (int i = base + 1; i < converged.length; ++i) {
                if (converged[i] && designations[i] != -1
                        && xit[base].pNormDist(2.0, xit[i]) < 0.001) {
                    converged[i] = false;
                    designations[i] = curClusterID;
                }
            }
            ++curClusterID;
        }
    }

    /**
     * Single-threaded mean shift loop: keeps stepping every unconverged point until all
     * have converged or {@code maxIterations} passes are done, then marks every point
     * as converged so that all of them take part in {@link #assignmentStep}.
     */
    private void mainLoop(boolean[] converged, Vec[] xit, int[] designations, Vec scratch, KernelFunction k) {
        boolean progress = true;
        int count = 0;
        while (progress && count++ < this.maxIterations) {
            progress = false;
            for (int i = 0; i < converged.length; ++i) {
                if (converged[i]) continue;
                progress = true;
                this.convergenceStep(xit, i, converged, designations, scratch, k);
            }
        }
        Arrays.fill(converged, true);
    }

    /**
     * Multi-threaded variant of the mean shift loop. Each pass submits
     * {@code SystemInfo.LogicalCores} {@link RunnableConsumer}s to the pool, feeds them
     * one job per unconverged point, then one {@link PoisonRunnable} per consumer, and
     * waits on a barrier shared with those poison jobs before starting the next pass.
     * Each job only writes index {@code i} of {@code xit}, {@code converged} and
     * {@code designations}; the barrier orders those writes before the next pass.
     * <p>
     * NOTE(review): this presumably requires the pool to be able to run
     * {@code SystemInfo.LogicalCores} tasks at once, since the consumers block on the
     * barrier; a smaller pool would appear to deadlock. Confirm against
     * {@code RunnableConsumer}/{@code PoisonRunnable}.
     */
    private void mainLoop(final boolean[] converged, final Vec[] xit, final int[] designations, final KernelFunction k, ExecutorService ex) throws InterruptedException, BrokenBarrierException {
        boolean progress = true;
        int count = 0;
        CyclicBarrier barrier = new CyclicBarrier(SystemInfo.LogicalCores + 1);
        ArrayBlockingQueue<Runnable> jobs = new ArrayBlockingQueue<Runnable>(SystemInfo.LogicalCores * 2);
        // One scratch vector per worker thread, since convergenceStep overwrites it.
        final ThreadLocal<Vec> localScratch = new ThreadLocal<Vec>(){

            @Override
            protected Vec initialValue() {
                return new DenseVector(xit[0].length());
            }
        };
        while (progress && count++ < this.maxIterations) {
            int i;
            progress = false;
            for (i = 0; i < SystemInfo.LogicalCores; ++i) {
                ex.submit(new RunnableConsumer(jobs));
            }
            for (i = 0; i < converged.length; ++i) {
                if (converged[i]) continue;
                progress = true;
                final int ii = i;
                jobs.put(new Runnable(){

                    @Override
                    public void run() {
                        MeanShift.this.convergenceStep(xit, ii, converged, designations, localScratch.get(), k);
                    }
                });
            }
            for (i = 0; i < SystemInfo.LogicalCores; ++i) {
                jobs.put(new PoisonRunnable(barrier));
            }
            barrier.await();
            barrier.reset();
        }
        Arrays.fill(converged, true);
    }

    /**
     * Performs one mean shift update of point {@code i}.
     * <p>
     * The neighbours of the current position are fetched with
     * {@code mkde.getNearbyRaw}; each is weighted by {@code g = -K'(p)}, where
     * {@code p} is the neighbour's paired value (presumably its, possibly
     * bandwidth-scaled, distance — TODO confirm). The point moves to the weighted
     * mean of the neighbours.
     * The point is marked converged when it moves less than {@code 1e-5}, or when
     * its weights sum to zero (no defined shift). A point whose neighbourhood
     * contains at most one vector is marked converged and as an outlier ({@code -1}).
     *
     * @param scratch work vector of the data dimension; overwritten
     */
    private void convergenceStep(Vec[] xit, int i, boolean[] converged, int[] designations, Vec scratch, KernelFunction k) {
        Vec xCur = xit[i];
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> contrib = this.mkde.getNearbyRaw(xCur);
        if (contrib.size() <= 1) {
            // Nothing to shift toward: treat as an outlier. An empty neighbourhood is
            // included because it would otherwise produce 0/0 below.
            converged[i] = true;
            designations[i] = -1;
            return;
        }
        double denom = 0.0;
        scratch.zeroOut();
        for (VecPaired<VecPaired<Vec, Integer>, Double> neighbor : contrib) {
            double g = -k.kPrime(neighbor.getPair());
            denom += g;
            scratch.mutableAdd(g, neighbor);
        }
        if (!(denom > 0.0)) {
            // All weights are zero (for a Gaussian kernel K'(0) == 0, so e.g. a point
            // whose only neighbours are exact duplicates). The shift is undefined, and
            // dividing would turn xCur into NaN, so stop moving this point instead.
            converged[i] = true;
            return;
        }
        scratch.mutableDivide(denom);
        if (scratch.pNormDist(2.0, xCur) < 1.0E-5) {
            converged[i] = true;
        }
        scratch.copyTo(xCur);
    }

    @Override
    public MeanShift clone() {
        return new MeanShift(this);
    }
}

