/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Random;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.BatchPredictor;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionMetadata;
import weka.core.Utils;
import weka.dl4j.FileIterationListener;
import weka.dl4j.iterators.AbstractDataSetIterator;
import weka.dl4j.iterators.ConvolutionalInstancesIterator;
import weka.dl4j.iterators.DefaultInstancesIterator;
import weka.dl4j.iterators.ImageDataSetIterator;
import weka.dl4j.layers.DenseLayer;
import weka.dl4j.layers.OutputLayer;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

public class Dl4jMlpClassifier
extends RandomizableClassifier
implements BatchPredictor,
CapabilitiesHandler {
    protected static final long serialVersionUID = -6363254116597574265L;
    protected final Logger m_log = LoggerFactory.getLogger(Dl4jMlpClassifier.class);
    protected ReplaceMissingValues m_replaceMissing;
    protected Filter m_normalize;
    protected NominalToBinary m_nominalToBinary;
    protected ZeroR m_zeroR;
    protected transient MultiLayerNetwork m_model;
    protected File m_logFile = new File(System.getProperty("user.dir"));
    protected Layer[] m_layers = new Layer[]{new OutputLayer()};
    protected OptimizationAlgorithm m_algo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
    protected int m_numEpochs = 10;
    protected AbstractDataSetIterator m_iterator = new DefaultInstancesIterator();
    protected boolean m_standardizeInsteadOfNormalize = true;
    protected double m_x1 = 1.0;
    protected double m_x0 = 0.0;

    public static void main(String[] argv) {
        Dl4jMlpClassifier.runClassifier((Classifier)new Dl4jMlpClassifier(), (String[])argv);
    }

    public String globalInfo() {
        return "Classification and regression with multilayer perceptrons using DeepLearning4J.";
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        if (this.getDataSetIterator() instanceof ImageDataSetIterator) {
            result.enable(Capabilities.Capability.STRING_ATTRIBUTES);
        } else {
            result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
            result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
            result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
            result.enable(Capabilities.Capability.MISSING_VALUES);
            result.enableDependency(Capabilities.Capability.STRING_ATTRIBUTES);
        }
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    private void writeObject(ObjectOutputStream oos) throws IOException {
        oos.defaultWriteObject();
        if (this.m_replaceMissing != null) {
            ModelSerializer.writeModel((Model)this.m_model, (OutputStream)oos, (boolean)false);
        }
    }

    private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
        ois.defaultReadObject();
        if (this.m_replaceMissing != null) {
            ClassLoader origLoader = Thread.currentThread().getContextClassLoader();
            try {
                Thread.currentThread().setContextClassLoader(((Object)((Object)this)).getClass().getClassLoader());
                this.m_model = ModelSerializer.restoreMultiLayerNetwork((InputStream)ois, (boolean)false);
            }
            finally {
                Thread.currentThread().setContextClassLoader(origLoader);
            }
        }
    }

    public File getLogFile() {
        return this.m_logFile;
    }

    @OptionMetadata(displayName="log file", description="The name of the log file to write loss information to (default = no log file).", commandLineParamName="logFile", commandLineParamSynopsis="-logFile <string>", displayOrder=1)
    public void setLogFile(File logFile) {
        this.m_logFile = logFile;
    }

    public Layer[] getLayers() {
        return this.m_layers;
    }

    @OptionMetadata(displayName="layer specification", description="The specification of the layers.", commandLineParamName="layers", commandLineParamSynopsis="-layers <string>", displayOrder=2)
    public void setLayers(Layer[] layers) {
        this.m_layers = layers;
    }

    public int getNumEpochs() {
        return this.m_numEpochs;
    }

    @OptionMetadata(description="The number of epochs to perform", displayName="number of epochs", commandLineParamName="numEpochs", commandLineParamSynopsis="-numEpochs <int>", displayOrder=4)
    public void setNumEpochs(int numEpochs) {
        this.m_numEpochs = numEpochs;
    }

    @OptionMetadata(description="Optimization algorithm (LINE_GRADIENT_DESCENT, CONJUGATE_GRADIENT, HESSIAN_FREE, LBFGS, STOCHASTIC_GRADIENT_DESCENT)", displayName="optimization algorithm", commandLineParamName="algorithm", commandLineParamSynopsis="-algorithm <string>", displayOrder=5)
    public OptimizationAlgorithm getOptimizationAlgorithm() {
        return this.m_algo;
    }

    public void setOptimizationAlgorithm(OptimizationAlgorithm optimAlgorithm) {
        this.m_algo = optimAlgorithm;
    }

    @OptionMetadata(description="The dataset iterator to use", displayName="dataset iterator", commandLineParamName="iterator", commandLineParamSynopsis="-iterator <string>", displayOrder=6)
    public AbstractDataSetIterator getDataSetIterator() {
        return this.m_iterator;
    }

    public void setDataSetIterator(AbstractDataSetIterator iterator) {
        this.m_iterator = iterator;
    }

    protected int getNumUnits(Layer layer) {
        if (layer instanceof DenseLayer) {
            return ((DenseLayer)layer).getNOut();
        }
        if (layer instanceof OutputLayer) {
            return ((OutputLayer)layer).getNOut();
        }
        return -1;
    }

    protected void setNumIncoming(Layer layer, int numInputs) {
        if (layer instanceof DenseLayer) {
            ((DenseLayer)layer).setNIn(numInputs);
        } else if (layer instanceof OutputLayer) {
            ((OutputLayer)layer).setNIn(numInputs);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void buildClassifier(Instances data) throws Exception {
        ClassLoader orig = Thread.currentThread().getContextClassLoader();
        try {
            int index;
            Thread.currentThread().setContextClassLoader(((Object)((Object)this)).getClass().getClassLoader());
            this.getCapabilities().testWithFail(data);
            if (this.m_layers.length == 0) {
                throw new Exception("No layers have been added!");
            }
            if (!(this.m_layers[this.m_layers.length - 1] instanceof OutputLayer)) {
                throw new Exception("Last layer in network must be an output layer!");
            }
            data = new Instances(data);
            data.deleteWithMissingClass();
            this.m_zeroR = null;
            if (data.numInstances() == 0 || data.numAttributes() < 2) {
                this.m_zeroR = new ZeroR();
                this.m_zeroR.buildClassifier(data);
                return;
            }
            this.m_replaceMissing = new ReplaceMissingValues();
            this.m_replaceMissing.setInputFormat(data);
            data = Filter.useFilter((Instances)data, (Filter)this.m_replaceMissing);
            double y0 = data.instance(0).classValue();
            for (index = 1; index < data.numInstances() && data.instance(index).classValue() == y0; ++index) {
            }
            if (index == data.numInstances()) {
                throw new Exception("All class values are the same. At least two class values should be different");
            }
            double y1 = data.instance(index).classValue();
            this.m_nominalToBinary = new NominalToBinary();
            this.m_nominalToBinary.setInputFormat(data);
            data = Filter.useFilter((Instances)data, (Filter)this.m_nominalToBinary);
            if (this.m_standardizeInsteadOfNormalize) {
                this.m_normalize = new Standardize();
                this.m_normalize.setOptions(new String[]{"-unset-class-temporarily"});
            } else {
                this.m_normalize = new Normalize();
            }
            this.m_normalize.setInputFormat(data);
            data = Filter.useFilter((Instances)data, (Filter)this.m_normalize);
            double z0 = data.instance(0).classValue();
            double z1 = data.instance(index).classValue();
            this.m_x1 = (y0 - y1) / (z0 - z1);
            this.m_x0 = y0 - this.m_x1 * z0;
            Random rand = new Random(this.getSeed());
            data.randomize(rand);
            NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
            if (this.getOptimizationAlgorithm() == null) {
                builder.setOptimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
            } else {
                builder.setOptimizationAlgo(this.getOptimizationAlgorithm());
            }
            builder.setSeed((long)rand.nextInt());
            NeuralNetConfiguration.ListBuilder ip = builder.list(this.getLayers());
            int numInputAttributes = this.getDataSetIterator().getNumAttributes(data);
            for (int x = 0; x < this.m_layers.length; ++x) {
                if (x == 0) {
                    this.setNumIncoming(this.m_layers[x], numInputAttributes);
                } else {
                    this.setNumIncoming(this.m_layers[x], this.getNumUnits(this.m_layers[x - 1]));
                }
                if (x == this.m_layers.length - 1) {
                    ((OutputLayer)this.m_layers[x]).setNOut(data.numClasses());
                }
                ip = ip.layer(x, this.m_layers[x]);
            }
            if (this.getDataSetIterator() instanceof ImageDataSetIterator) {
                ImageDataSetIterator idsi = (ImageDataSetIterator)this.getDataSetIterator();
                ip.setInputType(InputType.convolutionalFlat((int)idsi.getWidth(), (int)idsi.getHeight(), (int)idsi.getNumChannels()));
            } else if (this.getDataSetIterator() instanceof ConvolutionalInstancesIterator) {
                ConvolutionalInstancesIterator cii = (ConvolutionalInstancesIterator)this.getDataSetIterator();
                ip.setInputType(InputType.convolutionalFlat((int)cii.getWidth(), (int)cii.getHeight(), (int)cii.getNumChannels()));
            }
            ip = ip.pretrain(false).backprop(true);
            MultiLayerConfiguration conf = ip.build();
            if (this.getDebug()) {
                System.err.println(conf.toJson());
            }
            this.m_model = new MultiLayerNetwork(conf);
            this.m_model.init();
            if (this.getDebug()) {
                System.err.println(this.m_model.conf().toYaml());
            }
            ArrayList<Object> listeners = new ArrayList<Object>();
            listeners.add(new ScoreIterationListener(data.numInstances() / this.getDataSetIterator().getTrainBatchSize()));
            if (this.getLogFile() != null && !this.getLogFile().isDirectory()) {
                int numMiniBatches = (int)Math.ceil((double)data.numInstances() / (double)this.getDataSetIterator().getTrainBatchSize());
                listeners.add(new FileIterationListener(this.getLogFile().getAbsolutePath(), numMiniBatches));
            }
            this.m_model.setListeners(listeners);
            DataSetIterator iter = this.getDataSetIterator().getIterator(data, this.getSeed());
            for (int i = 0; i < this.getNumEpochs(); ++i) {
                this.m_model.fit(iter);
                if (this.getDebug()) {
                    this.m_log.info("*** Completed epoch {} ***", (Object)(i + 1));
                }
                iter.reset();
            }
        }
        finally {
            Thread.currentThread().setContextClassLoader(orig);
        }
    }

    public double[] distributionForInstance(Instance inst) throws Exception {
        if (this.m_zeroR != null) {
            return this.m_zeroR.distributionForInstance(inst);
        }
        this.m_replaceMissing.input(inst);
        inst = this.m_replaceMissing.output();
        this.m_nominalToBinary.input(inst);
        inst = this.m_nominalToBinary.output();
        this.m_normalize.input(inst);
        inst = this.m_normalize.output();
        Instances insts = new Instances(inst.dataset(), 0);
        insts.add(inst);
        DataSet ds = (DataSet)this.getDataSetIterator().getIterator(insts, this.getSeed(), 1).next();
        INDArray predicted = this.m_model.output(ds.getFeatureMatrix(), false);
        predicted = predicted.getRow(0);
        double[] preds = new double[inst.numClasses()];
        for (int i = 0; i < preds.length; ++i) {
            preds[i] = predicted.getDouble(i);
        }
        if (preds.length > 1) {
            Utils.normalize((double[])preds);
        } else {
            preds[0] = preds[0] * this.m_x1 + this.m_x0;
        }
        return preds;
    }

    public String toString() {
        if (this.m_replaceMissing != null) {
            return this.m_model.getLayerWiseConfigurations().toYaml();
        }
        return null;
    }
}

