/*
 * Decompiled with CFR 0.152.
 */
package jsat;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import jsat.ARFFLoader;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.datatransform.LinearTransform;
import jsat.datatransform.visualization.LargeViz;
import jsat.linear.DenseVector;
import jsat.utils.SystemInfo;
import jsat.utils.random.XORWOW;

/**
 * Scratch driver that embeds the MNIST training set into a low-dimensional
 * space with {@link LargeViz} and plots the result through a generated
 * matplotlib script. Also provides a synthetic "swiss roll" data set generator.
 */
public class MINST_TSNE {
    /**
     * Generates a synthetic 3-dimensional swiss-roll data set.
     * <p>
     * For each point a uniform value {@code r} in [0, 1) is drawn and mapped to
     * an angle {@code t = (2r + 1) * 1.5 * PI}, i.e. {@code t} lies in
     * [1.5&pi;, 4.5&pi;). The point is {@code (t cos t, h, t sin t)} where
     * {@code h} is drawn uniformly from [0, 12.6). The class label is the index
     * of the equal-width bin of {@code r}, so labels follow the position along
     * the roll and always lie in {@code [0, classes - 1]}.
     *
     * @param N the number of points to generate; values &lt;= 0 yield an empty data set
     * @return a data set with 3 numeric features, no categorical features and 10 classes
     */
    public static ClassificationDataSet swissRoll(int N) {
        int classes = 10;
        Random rand = new XORWOW();
        ClassificationDataSet cds = new ClassificationDataSet(3, new CategoricalData[0], new CategoricalData(classes));
        for (int i = 0; i < N; ++i) {
            double r = rand.nextDouble();
            double t = (r * 2.0 + 1.0) * 1.5 * Math.PI;
            // Floor (not round) so that r in [0, 1) maps onto the valid labels 0..classes-1;
            // rounding would produce the out-of-range label `classes` for r >= 0.95.
            // The clamp is a guard in case the generator ever returns exactly 1.0.
            int label = Math.min(classes - 1, (int) (r * (double) classes));
            cds.addDataPoint(DenseVector.toDenseVec(t * Math.cos(t), 12.6 * rand.nextDouble(), t * Math.sin(t)), label);
        }
        return cds;
    }

    /**
     * Loads the MNIST training ARFF file from a hard-coded directory, rescales
     * it, computes a LargeViz embedding using one worker thread per logical
     * core, writes a matplotlib scatter-plot script to {@code tmp.py} in the
     * working directory and launches {@code python} on it without waiting for
     * the interpreter to finish.
     *
     * @param args ignored
     * @throws IOException if the script cannot be written or the Python process cannot be started
     * @throws InterruptedException declared for signature compatibility
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        String path = "/Users/eman7613/Desktop/";
        String train = "MNISTtrain.arff";
        // NOTE(review): the trailing 0 presumably selects which categorical attribute is the class label - confirm.
        ClassificationDataSet trainSet = new ClassificationDataSet(ARFFLoader.loadArffFile(new File(path, train)), 0);
        trainSet.applyTransform(new LinearTransform(trainSet));
        System.out.println("N: " + trainSet.getSampleSize());
        System.out.println("D: " + trainSet.getNumNumericalVars());
        System.out.println("Loaded!");
        LargeViz tsne = new LargeViz();
        tsne.setPerplexity(15.0);
        ExecutorService ex = Executors.newFixedThreadPool(SystemInfo.LogicalCores);
        try {
            ClassificationDataSet embeded = tsne.transform(trainSet, ex);
            System.out.println("Embedded!");
            // NOTE(review): presumably rescales every embedded coordinate into [0, 1] - confirm against LinearTransform.
            embeded.applyTransform(new LinearTransform(embeded, 0.0, 1.0));
            File file = new File("tmp.py");
            // try-with-resources so the script file is flushed and closed even if a write fails
            try (BufferedWriter bw = new BufferedWriter(new FileWriter(file))) {
                bw.write("import numpy as np\n");
                bw.write("import matplotlib.pyplot as plt\n");
                // Relies on the toString() of the vector list being a valid Python nested-list literal.
                bw.write("X = np.array(" + embeded.getDataVectors() + ")\n");
                bw.write("y = [");
                for (int i = 0; i < embeded.getSampleSize(); ++i) {
                    if (i > 0) {
                        bw.write(", ");
                    }
                    bw.write("" + embeded.getDataPointCategory(i));
                }
                bw.write("]\n");
                bw.write("plt.scatter(X.T[0], X.T[1], c=y)\n");
                bw.write("plt.show()");
            }
            System.out.println("file: " + file.getAbsolutePath());
            // Pass the script path as its own argument so directories containing spaces work,
            // and inherit stdio so interpreter errors are visible and the child never blocks
            // on a pipe nobody reads.
            new ProcessBuilder("python", file.getAbsolutePath()).inheritIO().start();
        } finally {
            // Always release the worker threads; otherwise a failure above would leave
            // non-daemon pool threads keeping the JVM alive.
            ex.shutdownNow();
        }
    }
}

