/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.dl4j;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Random;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
// NOTE(review): this import and org.deeplearning4j.nn.conf.layers.Layer below share the simple
// name "Layer", which is a compile-time error (JLS 7.5.1); one of them should be dropped.
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DBNIrisExample {
    private static Logger log = LoggerFactory.getLogger(DBNIrisExample.class);

    public static void main(String[] args) throws Exception {
        Nd4j.MAX_SLICES_TO_PRINT = -1;
        Nd4j.MAX_ELEMENTS_PER_SLICE = -1;
        int numRows = 4;
        boolean numColumns = true;
        int outputNum = 3;
        int numSamples = 150;
        int batchSize = 150;
        int iterations = 5;
        int splitTrainNum = (int)((double)batchSize * 0.8);
        int seed = 123;
        int listenerFreq = 1;
        log.info("Load data....");
        IrisDataSetIterator iter = new IrisDataSetIterator(batchSize, numSamples);
        DataSet next = (DataSet)iter.next();
        next.shuffle();
        next.normalizeZeroMeanZeroUnitVariance();
        log.info("Split data....");
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet train = testAndTrain.getTrain();
        DataSet test = testAndTrain.getTest();
        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations).learningRate((double)1.0E-6f).optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).l1(0.1).regularization(true).l2(2.0E-4).useDropConnect(true).list().layer(0, (Layer)((RBM.Builder)((RBM.Builder)((RBM.Builder)((RBM.Builder)((RBM.Builder)((RBM.Builder)((RBM.Builder)new RBM.Builder(RBM.HiddenUnit.RECTIFIED, RBM.VisibleUnit.GAUSSIAN).nIn(4)).nOut(3)).weightInit(WeightInit.XAVIER)).k(1).activation("relu")).lossFunction(LossFunctions.LossFunction.RMSE_XENT)).updater(Updater.ADAGRAD)).dropOut(0.5)).build()).layer(1, (Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(3)).nOut(outputNum)).activation("softmax")).build()).build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        System.out.println(model.conf().toYaml());
        model.init();
        model.setListeners(new IterationListener[]{new ScoreIterationListener(listenerFreq)});
        log.info("Train model....");
        model.fit((org.nd4j.linalg.dataset.api.DataSet)train);
        log.info("Evaluate weights....");
        for (org.deeplearning4j.nn.api.Layer layer : model.getLayers()) {
            INDArray w = layer.getParam("W");
            log.info("Weights: " + w);
        }
        log.info("Evaluate model....");
        Evaluation eval = new Evaluation(outputNum);
        eval.eval(test.getLabels(), model.output(test.getFeatureMatrix(), Layer.TrainingMode.TEST));
        log.info(eval.stats());
        log.info("****************Example finished********************");
        OutputStream fos = Files.newOutputStream(Paths.get("coefficients.bin", new String[0]), new OpenOption[0]);
        DataOutputStream dos = new DataOutputStream(fos);
        Nd4j.write((INDArray)model.params(), (DataOutputStream)dos);
        dos.flush();
        dos.close();
        FileUtils.writeStringToFile((File)new File("conf.json"), (String)model.getLayerWiseConfigurations().toJson());
        MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson((String)FileUtils.readFileToString((File)new File("conf.json")));
        DataInputStream dis = new DataInputStream(new FileInputStream("coefficients.bin"));
        INDArray newParams = Nd4j.read((DataInputStream)dis);
        dis.close();
        MultiLayerNetwork savedNetwork = new MultiLayerNetwork(confFromJson);
        savedNetwork.init();
        savedNetwork.setParams(newParams);
        System.out.println("Original network params " + model.params());
        System.out.println(savedNetwork.params());
    }
}

