/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.svm.SVMnoBias;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.clustering.kmeans.ElkanKernelKMeans;
import jsat.clustering.kmeans.KernelKMeans;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;

/**
 * Binary kernel SVM classifier that is trained level by level.
 *
 * <p>For each level {@code l}, going from the start level down to the end
 * level, a sample of the training set is clustered with kernel k-means into at
 * most {@code k^l} groups, every remaining point is assigned to its closest
 * cluster, and one {@link SVMnoBias} is trained per cluster. From the second
 * level on, the alphas found on the previous level are used as a warm start.
 * When the end level is {@code 0} a final {@link SVMnoBias} is trained on the
 * whole data set, warm started from the last set of alphas.
 *
 * <p>At prediction time the query is routed to the model of its closest
 * cluster (or to the single model when only one exists).
 *
 * <p>NOTE(review): the class name suggests the "divide-and-conquer SVM"
 * scheme; confirm against the upstream library documentation.
 */
public class DCSVM
extends SupportVectorLearner
implements Classifier,
Parameterized {
    /** Regularization constant handed to every sub-problem solver. */
    private double C = 1.0;
    /**
     * Copied by the copy constructor but not otherwise read in this class.
     * NOTE(review): presumably meant to be forwarded to the sub-solvers as well.
     */
    private double tolerance = 0.001;
    /** Clustering produced on the most recently processed level. */
    private KernelKMeans clusters;
    /** Number of points sampled for clustering (see {@link #setClusterSampleSize(int)}). */
    private int m = 2000;
    /** Level at which training starts. */
    private int l_max = 4;
    /** Level at which training stops. */
    private int l_early = 3;
    /** Base of the per-level cluster count: level {@code l} asks for {@code k^l} clusters. */
    private int k = 4;
    /** Trained sub-models, keyed by cluster id. */
    private Map<Integer, SVMnoBias> early_models;
    /** Cache budget given to each sub-solver; non-positive disables caching. */
    private long cache_size = 0L;

    /**
     * Creates a new learner for the given kernel. The cache budget is set to
     * half of the memory the runtime currently reports as free.
     *
     * @param k the kernel to use
     */
    public DCSVM(KernelTrick k) {
        super(k, SupportVectorLearner.CacheMode.ROWS);
        this.cache_size = Runtime.getRuntime().freeMemory() / 2L;
    }

    /**
     * Creates a new learner using a default-constructed {@link RBFKernel}.
     */
    public DCSVM() {
        this(new RBFKernel());
    }

    /**
     * Copy constructor. The clustering and every trained sub-model are cloned.
     *
     * @param toCopy the object to copy
     */
    public DCSVM(DCSVM toCopy) {
        super(toCopy);
        this.C = toCopy.C;
        this.tolerance = toCopy.tolerance;
        if (toCopy.clusters != null) {
            this.clusters = toCopy.clusters.clone();
        }
        this.cache_size = toCopy.cache_size;
        this.m = toCopy.m;
        this.l_early = toCopy.l_early;
        this.l_max = toCopy.l_max;
        this.k = toCopy.k;
        if (toCopy.early_models != null) {
            this.early_models = new ConcurrentHashMap<Integer, SVMnoBias>();
            for (Map.Entry<Integer, SVMnoBias> x : toCopy.early_models.entrySet()) {
                this.early_models.put(x.getKey(), x.getValue().clone());
            }
        }
    }

    /**
     * Sets the level at which training starts. Training walks from this level
     * down to the end level, so if the end level is larger than the start level
     * no level is processed.
     *
     * @param l_max the non-negative start level
     * @throws IllegalArgumentException if {@code l_max} is negative
     */
    public void setStartLevel(int l_max) {
        if (l_max < 0) {
            throw new IllegalArgumentException("l_max must be a non-negative integer, not " + l_max);
        }
        this.l_max = l_max;
    }

    /**
     * @return the level at which training starts
     */
    public int getStartLevel() {
        return this.l_max;
    }

    /**
     * Sets the level at which training stops. A value of {@code 0} additionally
     * triggers a final solve over the complete data set.
     *
     * @param l_early the non-negative end level
     * @throws IllegalArgumentException if {@code l_early} is negative
     */
    public void setEndLevel(int l_early) {
        if (l_early < 0) {
            throw new IllegalArgumentException("l_early must be a non-negative integer, not " + l_early);
        }
        this.l_early = l_early;
    }

    /**
     * @return the level at which training stops
     */
    public int getEndLevel() {
        return this.l_early;
    }

    /**
     * Sets how many points are sampled for the clustering step. On a level
     * where {@code N / k^l < 7} the value {@code 7 * k^l} is used instead.
     *
     * @param m the positive sample size
     * @throws IllegalArgumentException if {@code m} is not positive
     */
    public void setClusterSampleSize(int m) {
        if (m <= 0) {
            throw new IllegalArgumentException("Cluster Sample Size must be a positive integer, not " + m);
        }
        this.m = m;
    }

    /**
     * @return the number of points sampled for the clustering step
     */
    public int getClusterSampleSize() {
        return this.m;
    }

    /**
     * Classifies a data point. All probability mass is placed on class 1 when
     * the score is strictly positive and on class 0 otherwise.
     *
     * @param data the point to classify
     * @return a two-class result with probability 1 on the predicted class
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        double sum = this.getScore(data);
        if (sum > 0.0) {
            cr.setProb(1, 1.0);
        } else {
            cr.setProb(0, 1.0);
        }
        return cr;
    }

    /**
     * Returns the score of the sub-model responsible for the given point. With
     * more than one sub-model the closest cluster selects the model; otherwise
     * the model stored under key {@code 0} is used.
     *
     * @param dp the point to score
     * @return the score reported by the selected sub-model
     * @throws UntrainedModelException if the model has not been trained or no
     *         sub-model exists for the selected cluster
     */
    public double getScore(DataPoint dp) {
        if (this.vecs == null || this.early_models == null) {
            throw new UntrainedModelException("Classifier has yet to be trained");
        }
        Vec x = dp.getNumericalValues();
        int c = this.early_models.size() > 1 ? this.clusters.findClosestCluster(x, this.getKernel().getQueryInfo(x)) : 0;
        SVMnoBias model = this.early_models.get(c);
        if (model == null) {
            // Happens e.g. when the end level exceeds the start level, so no level was ever trained.
            throw new UntrainedModelException("No trained sub-model is available for cluster " + c);
        }
        return model.getScore(dp);
    }

    /**
     * Trains the classifier on the given data set.
     *
     * @param dataSet the training data
     * @param threadPool executor used for clustering, cluster assignment and
     *        the sub-problem solvers; a {@link FakeExecutor} selects a single
     *        assignment task, anything else selects one per logical core
     * @throws FailedToFitException if the calling thread is interrupted while
     *         waiting for the cluster assignment to finish
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        final int threads_to_use = threadPool instanceof FakeExecutor ? 1 : SystemInfo.LogicalCores;
        final int N = dataSet.getSampleSize();
        this.vecs = dataSet.getDataVectors();
        this.early_models = new ConcurrentHashMap<Integer, SVMnoBias>();
        this.setCacheMode(SupportVectorLearner.CacheMode.NONE);
        this.alphas = new double[N];
        final int[] group = new int[N];
        IntList indicies = new IntList();
        for (int l = this.l_max; l >= this.l_early; --l) {
            this.early_models.clear();
            ClassificationDataSet toCluster = new ClassificationDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories(), dataSet.getPredicting());
            int k_l = (int)Math.pow(this.k, l);
            int M = N / k_l < 7 ? k_l * 7 : this.m;

            // Candidates for the clustering sample: every point on the first
            // level, afterwards only points that currently have a non-zero alpha.
            indicies.clear();
            if (l == this.l_max) {
                ListUtils.addRange(indicies, 0, N, 1);
            } else {
                for (int i = 0; i < N; ++i) {
                    if (this.alphas[i] == 0.0) continue;
                    indicies.add(i);
                }
            }
            Collections.shuffle(indicies);
            // Sample through the shuffled index list so that position i of the
            // sample really is data point indicies.get(i); the cluster labels
            // are written back through the same mapping below.
            int sampleSize = Math.min(M, indicies.size());
            for (int i = 0; i < sampleSize; ++i) {
                int src = indicies.get(i);
                toCluster.addDataPoint(dataSet.getDataPoint(src), dataSet.getDataPointCategory(src));
            }

            this.clusters = new ElkanKernelKMeans(this.getKernel());
            this.clusters.setMaximumIterations(100);
            k_l = Math.min(k_l, toCluster.getSampleSize() / 2);
            int[] sub_results;
            if (k_l <= 1) {
                // Nothing to split: place every point in cluster 0.
                sub_results = new int[N];
                indicies.clear();
                ListUtils.addRange(indicies, 0, N, 1);
            } else {
                sub_results = this.clusters.cluster((DataSet)toCluster, k_l, threadPool, (int[])null);
            }

            // -1 marks points that were not part of the clustered sample.
            Arrays.fill(group, -1);
            HashSet<Integer> found_clusters = new HashSet<Integer>(k_l);
            for (int i = 0; i < sub_results.length; ++i) {
                group[indicies.get(i)] = sub_results[i];
                found_clusters.add(sub_results[i]);
            }
            this.assignUnclustered(group, N, threads_to_use, threadPool);

            for (int c : found_clusters) {
                ClassificationDataSet V_c = new ClassificationDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories(), dataSet.getPredicting());
                DoubleList V_alphas = new DoubleList();
                IntList orig_index = new IntList();
                for (int i = 0; i < N; ++i) {
                    if (group[i] != c) continue;
                    V_c.addDataPoint(dataSet.getDataPoint(i), dataSet.getDataPointCategory(i));
                    V_alphas.add(Math.abs(this.alphas[i]));
                    orig_index.add(i);
                }
                SVMnoBias svm = this.newSubSolver(V_alphas.size());
                if (l == this.l_max) {
                    // First level: no previous solution to start from.
                    svm.trainC(V_c, threadPool);
                } else {
                    // NOTE(review): the backing array may be longer than the list; confirm the solver only reads the first V_alphas.size() entries.
                    svm.trainC(V_c, V_alphas.getBackingArray(), threadPool);
                }
                this.early_models.put(c, svm);
                // Scatter the sub-problem solution back to the global alpha vector.
                for (int i = 0; i < orig_index.size(); ++i) {
                    this.alphas[orig_index.get(i)] = svm.alphas[i];
                }
            }
        }
        if (this.l_early == 0) {
            // Final solve over the complete data set, warm started from the last alphas.
            SVMnoBias svm = this.newSubSolver(dataSet.getSampleSize());
            svm.trainC(dataSet, Arrays.copyOf(this.alphas, this.alphas.length), threadPool);
            this.early_models.clear();
            this.early_models.put(0, svm);
            for (int i = 0; i < N; ++i) {
                this.alphas[i] = svm.alphas[i];
            }
        }
    }

    /**
     * Creates a sub-problem solver that shares this learner's kernel, cache
     * budget and regularization constant.
     *
     * @param rows number of data points the solver will be trained on
     * @return the configured, untrained solver
     */
    private SVMnoBias newSubSolver(int rows) {
        SVMnoBias svm = new SVMnoBias(this.getKernel());
        if (this.cache_size > 0L) {
            svm.setCacheSize(rows, this.cache_size);
        } else {
            svm.setCacheMode(SupportVectorLearner.CacheMode.NONE);
        }
        // Without this the value given to setC would never reach any solver.
        svm.setC(this.C);
        return svm;
    }

    /**
     * Assigns every point whose entry in {@code group} is negative to its
     * closest cluster. The work is striped over {@code threads_to_use} tasks
     * and this method blocks until all of them have finished.
     *
     * @param group cluster id per data point, updated in place
     * @param N number of data points
     * @param threads_to_use number of tasks to submit
     * @param threadPool executor that runs the tasks
     * @throws FailedToFitException if the calling thread is interrupted while waiting
     */
    private void assignUnclustered(final int[] group, final int N, final int threads_to_use, ExecutorService threadPool) {
        final CountDownLatch latch = new CountDownLatch(threads_to_use);
        // Slot for a worker failure; it is written before countDown() and read after await().
        final RuntimeException[] failure = new RuntimeException[1];
        for (int id = 0; id < threads_to_use; ++id) {
            final int ID = id;
            threadPool.submit(new Runnable(){

                @Override
                public void run() {
                    try {
                        for (int i = ID; i < N; i += threads_to_use) {
                            if (group[i] >= 0) continue;
                            List<Double> qi = null;
                            if (DCSVM.this.accelCache != null) {
                                // The acceleration cache holds the same number of entries per point.
                                int multiplier = DCSVM.this.accelCache.size() / N;
                                qi = DCSVM.this.accelCache.subList(i * multiplier, i * multiplier + multiplier);
                            }
                            group[i] = DCSVM.this.clusters.findClosestCluster((Vec)DCSVM.this.vecs.get(i), qi);
                        }
                    }
                    catch (RuntimeException ex) {
                        failure[0] = ex;
                    }
                    finally {
                        // Always release the latch so a failing worker cannot hang the caller.
                        latch.countDown();
                    }
                }
            });
        }
        try {
            latch.await();
        }
        catch (InterruptedException ex) {
            // Keep the interrupt visible to code further up the stack.
            Thread.currentThread().interrupt();
            throw new FailedToFitException(ex);
        }
        if (failure[0] != null) {
            throw failure[0];
        }
    }

    /**
     * Trains the classifier on the calling thread by delegating to
     * {@link #trainC(ClassificationDataSet, ExecutorService)} with a
     * {@link FakeExecutor}.
     *
     * @param dataSet the training data
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * @return a copy of this learner created through the copy constructor
     */
    @Override
    public DCSVM clone() {
        return new DCSVM(this);
    }

    /**
     * Sets the regularization constant that is handed to every sub-problem solver.
     *
     * @param C the positive regularization constant
     * @throws ArithmeticException if {@code C} is NaN or not positive
     */
    @Parameter.WarmParameter(prefLowToHigh=true)
    public void setC(double C) {
        // NaN compares false against everything, so it needs an explicit check.
        if (Double.isNaN(C) || C <= 0.0) {
            throw new ArithmeticException("C must be a positive constant");
        }
        this.C = C;
    }

    /**
     * @return the regularization constant
     */
    public double getC() {
        return this.C;
    }

    /**
     * @return the parameters obtained from this object via
     *         {@link Parameter#getParamsFromMethods(Object)}
     */
    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a single parameter by name.
     *
     * @param paramName the name of the parameter
     * @return the result of looking {@code paramName} up in the map built from {@link #getParameters()}
     */
    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

