/*
 * Decompiled with CFR 0.152.
 */
package jsat;

import java.io.IOException;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.clustering.kmeans.KMeans;
import jsat.clustering.kmeans.NaiveKMeans;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceCounter;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.utils.IntSet;
import jsat.utils.SystemInfo;
import jsat.utils.random.XORWOW;

/**
 * Ad-hoc benchmark harness for k-means implementations, plus a couple of
 * helpers (synthetic data generation and comparison of computed means).
 *
 * <p>All members are static; the class holds no state.
 */
public class MINST_testKMeans {
    /**
     * Generates a synthetic three-dimensional "swiss roll" style data set.
     *
     * <p>For each point a uniform value {@code r} in {@code [0, 1)} is drawn and
     * mapped to an angle {@code t = (2r + 1) * 1.5 * PI}. The point is
     * {@code (t*cos(t), 12.6*u, t*sin(t))} where {@code u} is a second uniform
     * draw. The class label is {@code round(r * 10)}, clamped to {@code 9}.
     *
     * @param N number of points to generate; a non-positive value yields an
     *          empty data set
     * @return the populated data set
     */
    public static ClassificationDataSet swissRoll(int N) {
        int classes = 10;
        Random rand = new XORWOW();
        ClassificationDataSet cds = new ClassificationDataSet(3, new CategoricalData[0], new CategoricalData(classes));
        for (int i = 0; i < N; ++i) {
            double r = rand.nextDouble();
            double t = (r * 2.0 + 1.0) * 1.5 * Math.PI;
            // round(r * classes) reaches `classes` itself for r >= 0.95, which is one
            // past the last label index of a `classes`-valued target; clamp it.
            int label = Math.min(classes - 1, (int)Math.round(r * (double)classes));
            Vec point = DenseVector.toDenseVec(t * Math.cos(t), 12.6 * rand.nextDouble(), t * Math.sin(t));
            cds.addDataPoint(point, label);
        }
        return cds;
    }

    /**
     * Builds 100000 vectors via {@code DenseVector.random(300)} and benchmarks
     * {@link HamerlyKMeans} on them with {@code k = 20}.
     *
     * <p>The "Naive" and "Elkan" headings are still printed, but only the
     * Hamerly variant is actually run.
     *
     * @param args ignored
     * @throws IOException declared for compatibility; not thrown by the visible code
     * @throws InterruptedException declared for compatibility; not thrown by the visible code
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        ExecutorService ex = Executors.newFixedThreadPool(SystemInfo.LogicalCores);
        try {
            SimpleDataSet uni = new SimpleDataSet(new CategoricalData[0], 300);
            for (int i = 0; i < 100000; ++i) {
                uni.add(new DataPoint(DenseVector.random(300)));
            }
            DistanceCounter dm = new DistanceCounter(new EuclideanDistance());
            int k = 20;
            SeedSelectionMethods.SeedSelection seedingMethod = SeedSelectionMethods.SeedSelection.MEAN_QUANTILES;
            // Constructed as in the original, but never benchmarked.
            NaiveKMeans naive = new NaiveKMeans(dm, seedingMethod);
            System.out.println("Naive");
            System.out.println("Hamerly");
            HamerlyKMeans ham = new HamerlyKMeans(dm, seedingMethod);
            MINST_testKMeans.bench(ham, uni, k, dm, ex);
            System.out.println("Elkan");
        } finally {
            // Always release the pool: its worker threads would otherwise stay
            // alive if data generation or clustering throws.
            ex.shutdown();
        }
    }

    /**
     * Runs one clustering pass and prints the elapsed time and the number of
     * distance computations recorded by {@code dm}.
     *
     * @param naive   the k-means implementation to run (any {@link KMeans}, despite the name)
     * @param testSet the data to cluster
     * @param k       the number of clusters requested
     * @param dm      the counting metric; reset before the run and read afterwards.
     *                NOTE(review): the count is only meaningful if {@code naive}
     *                was built with this same instance - confirm at call sites.
     * @param ex      executor handed to {@code cluster}
     */
    public static void bench(KMeans naive, DataSet testSet, int k, DistanceCounter dm, ExecutorService ex) {
        dm.resetCounter();
        // nanoTime is monotonic, so the measurement is immune to wall-clock adjustments.
        long start = System.nanoTime();
        naive.cluster(testSet, k, ex);
        // Truncate to whole milliseconds to keep the original output precision.
        long elapsedMillis = (System.nanoTime() - start) / 1000000L;
        System.out.println("Time taken: " + (double)elapsedMillis / 1000.0 + " seconds");
        System.out.println("Distance calculations: " + dm.getCallCount());
    }

    /**
     * Checks that two lists of means agree up to ordering.
     *
     * <p>For every vector in {@code gt} the nearest vector in {@code test} is
     * located under {@code dm}. If that nearest distance exceeds {@code 1e-10},
     * {@code "\tERROR"} is printed; otherwise the matched index is marked as
     * used. After all of {@code gt} has been processed, {@code "\tERROR2"} is
     * printed if any index of {@code test} was never matched.
     *
     * @param gt   the reference ("ground truth") means
     * @param test the means being checked
     * @param dm   the metric used to compare vectors
     */
    public static void compareMeans(List<Vec> gt, List<Vec> test, DistanceMetric dm) {
        IntSet remaining = new IntSet();
        for (int i = 0; i < test.size(); ++i) {
            remaining.add(i);
        }
        for (Vec m : gt) {
            if (test.isEmpty()) {
                // Nothing to match against: report a mismatch instead of failing
                // on test.get(0).
                System.out.println("\tERROR");
                continue;
            }
            int closest = 0;
            double cd = dm.dist(m, test.get(closest));
            for (int i = 1; i < test.size(); ++i) {
                double tmp = dm.dist(m, test.get(i));
                if (tmp < cd) {
                    closest = i;
                    cd = tmp;
                }
            }
            if (cd > 1.0E-10) {
                System.out.println("\tERROR");
                continue;
            }
            remaining.remove(closest);
        }
        if (!remaining.isEmpty()) {
            System.out.println("\tERROR2");
        }
    }
}

