/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import adams.core.MessageCollection;
import adams.core.SerializationHelper;
import adams.core.io.FileUtils;
import adams.core.io.PlaceholderFile;
import adams.core.option.OptionHandler;
import adams.data.conversion.DL4JJsonToModel;
import adams.data.conversion.DL4JYamlToModel;
import adams.flow.container.DL4JModelContainer;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.ml.dl4j.iterationlistener.IterationListenerConfigurator;
import adams.ml.dl4j.model.Dl4jMlpClassifier;
import adams.ml.dl4j.model.ModelType;
import adams.ml.dl4j.trainstopcriterion.AbstractTrainStopCriterion;
import adams.ml.dl4j.trainstopcriterion.MaxEpoch;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import weka.classifiers.Classifier;
import weka.classifiers.DL4JFilteredMultiLayerNetworkProvider;
import weka.classifiers.DL4JMultiLayerNetworkProvider;
import weka.classifiers.RandomSplitGenerator;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.functions.dl4j.Utils;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.WekaOptionUtils;
import weka.dl4j.iterators.DefaultInstancesIterator;
import weka.dl4j.layers.DenseLayer;
import weka.dl4j.layers.OutputLayer;
import weka.filters.AllFilter;
import weka.filters.Filter;

/**
 * Classification and regression with multilayer perceptrons using DeepLearning4J.
 * <p>
 * How the network is obtained:
 * <ul>
 *   <li>If the model file points to an existing file (not a directory), the network is loaded
 *       from it according to the model file type.</li>
 *   <li>Otherwise the network is generated from the configured layers.</li>
 * </ul>
 * If the filter is not an {@link AllFilter}, a copy of it is trained on the data and applied
 * before training and prediction.
 * <p>
 * If no instances with a class value remain, or there are fewer than two attributes, a
 * {@link ZeroR} model is built and used instead of a network.
 * <p>
 * The network field is transient. During Java serialization it is converted to a byte array
 * with DL4J's {@link ModelSerializer} and restored from that array when deserialized.
 */
public class DL4JMultiLayerNetwork
extends RandomizableClassifier
implements DL4JFilteredMultiLayerNetworkProvider {
    protected static final long serialVersionUID = -6363254116597574265L;

    // command-line option names
    public static final String LAYER = "layer";
    public static final String OPTIMIZATION_ALGORITHM = "optimization-algorithm";
    public static final String USE_REGULARIZATION = "use-regularization";
    public static final String TRAIN_STOP = "train-stop";
    public static final String MINI_BATCH_SIZE = "mini-batch-size";
    public static final String RANDOMIZE_BETWEEN_EPOCHS = "randomize-between-epochs";
    public static final String DROP_TYPE = "drop-type";
    public static final String DROP_OUT = "drop-out";
    public static final String ITERATION_LISTENER = "iteration-listener";
    public static final String MODEL_FILE = "model-file";
    public static final String MODEL_FILE_TYPE = "model-file-type";
    public static final String FILTER = "filter";
    public static final String TEST_INTERVAL = "test-interval";
    public static final String TEST_PERCENTAGE = "test-percentage";

    /** Fallback model, used when there was not enough data to train a network; null otherwise. */
    protected ZeroR m_ZeroR;
    /** The network; transient, it is (de)serialized through m_ModelData. */
    protected transient MultiLayerNetwork m_Model;
    /** Serialized network; only non-null while the object is being (de)serialized. */
    protected byte[] m_ModelData = null;
    /** Whether m_Model holds a trained (or programmatically injected) network. */
    protected boolean m_Trained = false;
    protected Layer[] m_Layers = this.getDefaultLayers();
    protected OptimizationAlgorithm m_Algorithm = this.getDefaultOptimizationAlgorithm();
    protected AbstractTrainStopCriterion m_TrainStop = this.getDefaultTrainStop();
    protected int m_MiniBatchSize = this.getDefaultMiniBatchSize();
    protected boolean m_RandomizeBetweenEpochs = this.getDefaultRandomizeBetweenEpochs();
    protected boolean m_UseRegularization = this.getDefaultUseRegularization();
    protected Dl4jMlpClassifier.DropType m_DropType = this.getDefaultDropType();
    protected double m_DropOut = this.getDefaultDropOut();
    protected IterationListenerConfigurator[] m_IterationListeners = this.getDefaultIterationListeners();
    protected PlaceholderFile m_ModelFile = this.getDefaultModelFile();
    protected ModelFileType m_ModelFileType = this.getDefaultModelFileType();
    /** The filter as configured by the user (template). */
    protected Filter m_Filter = this.getDefaultFilter();
    /** The trained copy of m_Filter; null when m_Filter is an AllFilter or no filter was trained. */
    protected Filter m_ActualFilter;
    /** Converts Instances into DL4J data set iterators. */
    protected DefaultInstancesIterator m_Iterator = null;
    protected int m_TestInterval = this.getDefaultTestInterval();
    protected double m_TestPercentage = this.getDefaultTestPercentage();
    /** Header of the (filtered) training data. */
    protected Instances m_Header;
    /** Set when the network was injected via setTrainedMultiLayerNetwork; suppresses own options in getOptions. */
    protected boolean m_ProgrammaticInitialization;

    /**
     * Returns a short description of the classifier.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Classification and regression with multilayer perceptrons using DeepLearning4J.";
    }

    /**
     * Lists the available command-line options, followed by those of the superclass.
     *
     * @return the options
     */
    @Override
    public Enumeration listOptions() {
        Vector result = new Vector();
        WekaOptionUtils.addOption(result, this.layersTipText(), weka.core.Utils.arrayToString((Object)this.getDefaultLayers()), LAYER);
        WekaOptionUtils.addOption(result, this.optimizationAlgorithmTipText(), "" + this.getDefaultOptimizationAlgorithm(), OPTIMIZATION_ALGORITHM);
        WekaOptionUtils.addOption(result, this.useRegularizationTipText(), "" + this.getDefaultUseRegularization(), USE_REGULARIZATION);
        WekaOptionUtils.addOption(result, this.trainStopTipText(), "" + this.getDefaultTrainStop(), TRAIN_STOP);
        WekaOptionUtils.addOption(result, this.miniBatchSizeTipText(), "" + this.getDefaultMiniBatchSize(), MINI_BATCH_SIZE);
        WekaOptionUtils.addOption(result, this.randomizeBetweenEpochsTipText(), "" + this.getDefaultRandomizeBetweenEpochs(), RANDOMIZE_BETWEEN_EPOCHS);
        WekaOptionUtils.addOption(result, this.dropTypeTipText(), "" + this.getDefaultDropType(), DROP_TYPE);
        WekaOptionUtils.addOption(result, this.dropOutTipText(), "" + this.getDefaultDropOut(), DROP_OUT);
        WekaOptionUtils.addOption(result, this.iterationListenersTipText(), weka.core.Utils.arrayToString((Object)this.getDefaultIterationListeners()), ITERATION_LISTENER);
        WekaOptionUtils.addOption(result, this.modelFileTipText(), "" + this.getDefaultModelFile(), MODEL_FILE);
        WekaOptionUtils.addOption(result, this.modelFileTypeTipText(), "" + this.getDefaultModelFileType(), MODEL_FILE_TYPE);
        WekaOptionUtils.addOption(result, this.filterTipText(), (weka.core.OptionHandler)this.getDefaultFilter(), FILTER);
        WekaOptionUtils.addOption(result, this.testIntervalTipText(), "" + this.getDefaultTestInterval(), TEST_INTERVAL);
        WekaOptionUtils.addOption(result, this.testPercentageTipText(), "" + this.getDefaultTestPercentage(), TEST_PERCENTAGE);
        WekaOptionUtils.add(result, (Enumeration)super.listOptions());
        return WekaOptionUtils.toEnumeration(result);
    }

    /**
     * Parses the options. Absent options fall back on their defaults. Boolean options are
     * plain flags.
     *
     * @param options the options to parse
     * @throws Exception if parsing fails
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.setLayers((Layer[])WekaOptionUtils.parse(options, LAYER, (Object)this.getDefaultLayers(), Layer.class));
        this.setOptimizationAlgorithm((OptimizationAlgorithm)WekaOptionUtils.parse(options, OPTIMIZATION_ALGORITHM, (Enum)this.getDefaultOptimizationAlgorithm()));
        this.setUseRegularization(weka.core.Utils.getFlag(USE_REGULARIZATION, options));
        this.setTrainStop((AbstractTrainStopCriterion)WekaOptionUtils.parse(options, TRAIN_STOP, (OptionHandler)this.getDefaultTrainStop()));
        this.setMiniBatchSize(WekaOptionUtils.parse(options, MINI_BATCH_SIZE, (int)this.getDefaultMiniBatchSize()));
        this.setRandomizeBetweenEpochs(weka.core.Utils.getFlag(RANDOMIZE_BETWEEN_EPOCHS, options));
        this.setDropType((Dl4jMlpClassifier.DropType)WekaOptionUtils.parse(options, DROP_TYPE, (Enum)this.getDefaultDropType()));
        this.setDropOut(WekaOptionUtils.parse(options, DROP_OUT, (double)this.getDefaultDropOut()));
        this.setIterationListeners((IterationListenerConfigurator[])WekaOptionUtils.parse(options, ITERATION_LISTENER, (Object)this.getDefaultIterationListeners(), IterationListenerConfigurator.class));
        this.setModelFile(WekaOptionUtils.parse(options, MODEL_FILE, (PlaceholderFile)this.getDefaultModelFile()));
        this.setModelFileType((ModelFileType)WekaOptionUtils.parse(options, MODEL_FILE_TYPE, (Enum)this.getDefaultModelFileType()));
        this.setFilter((Filter)WekaOptionUtils.parse(options, FILTER, (weka.core.OptionHandler)this.getDefaultFilter()));
        this.setTestInterval(WekaOptionUtils.parse(options, TEST_INTERVAL, (int)this.getDefaultTestInterval()));
        this.setTestPercentage(WekaOptionUtils.parse(options, TEST_PERCENTAGE, (double)this.getDefaultTestPercentage()));
        super.setOptions(options);
    }

    /**
     * Returns the current options.
     * <p>
     * If the network was injected programmatically, only the superclass options are returned.
     * The network-structure options are left out only when a model file will actually be loaded
     * (same test as buildClassifier). The model file options are left out when the model file
     * is a directory.
     *
     * @return the options
     */
    @Override
    public String[] getOptions() {
        ArrayList result = new ArrayList();
        if (!this.m_ProgrammaticInitialization) {
            if (!this.isModelFileUsed()) {
                WekaOptionUtils.add(result, LAYER, (Object)this.getLayers());
                WekaOptionUtils.add(result, OPTIMIZATION_ALGORITHM, (Enum)this.getOptimizationAlgorithm());
                WekaOptionUtils.add(result, USE_REGULARIZATION, (boolean)this.getUseRegularization());
                WekaOptionUtils.add(result, DROP_TYPE, (Enum)this.getDropType());
                WekaOptionUtils.add(result, DROP_OUT, (double)this.getDropOut());
            }
            if (!this.getModelFile().isDirectory()) {
                WekaOptionUtils.add(result, MODEL_FILE, (File)this.getModelFile());
                WekaOptionUtils.add(result, MODEL_FILE_TYPE, (Enum)this.getModelFileType());
            }
            WekaOptionUtils.add(result, TRAIN_STOP, (OptionHandler)this.getTrainStop());
            WekaOptionUtils.add(result, MINI_BATCH_SIZE, (int)this.getMiniBatchSize());
            WekaOptionUtils.add(result, RANDOMIZE_BETWEEN_EPOCHS, (boolean)this.getRandomizeBetweenEpochs());
            WekaOptionUtils.add(result, ITERATION_LISTENER, (Object)this.getIterationListeners());
            WekaOptionUtils.add(result, FILTER, (weka.core.OptionHandler)this.getFilter());
            WekaOptionUtils.add(result, TEST_PERCENTAGE, (double)this.getTestPercentage());
            if (this.getTestPercentage() > 0.0) {
                WekaOptionUtils.add(result, TEST_INTERVAL, (int)this.getTestInterval());
            }
        }
        WekaOptionUtils.add(result, super.getOptions());
        return WekaOptionUtils.toArray(result);
    }

    /**
     * Returns whether buildClassifier will load the network from the model file instead of
     * generating it from the configured layers.
     *
     * @return true if the model file exists and is not a directory
     */
    private boolean isModelFileUsed() {
        return this.m_ModelFile.exists() && !this.m_ModelFile.isDirectory();
    }

    /**
     * Returns the capabilities.
     * <p>
     * With an {@link AllFilter}, the classifier's own capabilities are returned. These are
     * numeric/date attributes; nominal/numeric/date class; missing class values. Otherwise the
     * filter's capabilities are returned, with this classifier as owner.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result;
        if (this.m_Filter instanceof AllFilter) {
            result = super.getCapabilities();
            result.disableAll();
            result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
            result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
            result.enable(Capabilities.Capability.NOMINAL_CLASS);
            result.enable(Capabilities.Capability.NUMERIC_CLASS);
            result.enable(Capabilities.Capability.DATE_CLASS);
            result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        } else {
            result = this.m_Filter.getCapabilities();
            result.setOwner((CapabilitiesHandler)this);
        }
        return result;
    }

    /**
     * Custom serialization: stores a trained network as a byte array (written with DL4J's
     * ModelSerializer, without updater state). The temporary buffer is always cleared again,
     * even if writing fails.
     *
     * @param oos the stream to write to
     * @throws IOException if serialization fails
     */
    private void writeObject(ObjectOutputStream oos) throws IOException {
        try {
            if (this.m_Trained && this.m_Model != null) {
                ByteArrayOutputStream bos = new ByteArrayOutputStream();
                ModelSerializer.writeModel((Model)this.m_Model, (OutputStream)bos, false);
                this.m_ModelData = bos.toByteArray();
            }
            oos.defaultWriteObject();
        }
        finally {
            this.m_ModelData = null;
        }
    }

    /**
     * Custom deserialization: restores the network from the serialized byte array, if present.
     * The temporary buffer is always cleared afterwards.
     *
     * @param ois the stream to read from
     * @throws ClassNotFoundException if a class cannot be found
     * @throws IOException if deserialization fails
     */
    private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
        ois.defaultReadObject();
        if (this.m_ModelData != null) {
            try {
                ByteArrayInputStream bis = new ByteArrayInputStream(this.m_ModelData);
                this.m_Model = ModelSerializer.restoreMultiLayerNetwork((InputStream)bis, false);
            }
            finally {
                this.m_ModelData = null;
            }
        }
    }

    /**
     * Returns the default layers: a single output layer.
     *
     * @return the default layers
     */
    protected Layer[] getDefaultLayers() {
        return new Layer[]{new OutputLayer()};
    }

    /**
     * Sets the layers for the network (the last one must be an output layer).
     *
     * @param value the layers
     */
    public void setLayers(Layer[] value) {
        this.m_Layers = value;
    }

    /**
     * Returns the layers for the network.
     *
     * @return the layers
     */
    public Layer[] getLayers() {
        return this.m_Layers;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String layersTipText() {
        return "The layers for the network (last one must be an output layer).";
    }

    /**
     * Returns the default stopping criterion: {@link MaxEpoch} with its own defaults.
     *
     * @return the default criterion
     */
    protected AbstractTrainStopCriterion getDefaultTrainStop() {
        return new MaxEpoch();
    }

    /**
     * Sets the criterion for stopping training.
     *
     * @param value the criterion
     */
    public void setTrainStop(AbstractTrainStopCriterion value) {
        this.m_TrainStop = value;
    }

    /**
     * Returns the criterion for stopping training.
     *
     * @return the criterion
     */
    public AbstractTrainStopCriterion getTrainStop() {
        return this.m_TrainStop;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String trainStopTipText() {
        return "The criterion for stopping training.";
    }

    /**
     * Returns the default mini-batch size (100).
     *
     * @return the default size
     */
    protected int getDefaultMiniBatchSize() {
        return 100;
    }

    /**
     * Sets the mini-batch size. Values below 1 make training use the iterator's own default
     * batching.
     *
     * @param value the size
     */
    public void setMiniBatchSize(int value) {
        this.m_MiniBatchSize = value;
    }

    /**
     * Returns the mini-batch size.
     *
     * @return the size
     */
    public int getMiniBatchSize() {
        return this.m_MiniBatchSize;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String miniBatchSizeTipText() {
        return "The size to use for mini-batches.";
    }

    /**
     * Returns the default for randomizing between epochs (off).
     *
     * @return the default
     */
    protected boolean getDefaultRandomizeBetweenEpochs() {
        return false;
    }

    /**
     * Sets whether to draw a new seed for the data iterator before each epoch.
     *
     * @param value true to randomize
     */
    public void setRandomizeBetweenEpochs(boolean value) {
        this.m_RandomizeBetweenEpochs = value;
    }

    /**
     * Returns whether to draw a new seed for the data iterator before each epoch.
     *
     * @return true if randomizing
     */
    public boolean getRandomizeBetweenEpochs() {
        return this.m_RandomizeBetweenEpochs;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String randomizeBetweenEpochsTipText() {
        return "If enabled, the data gets randomized between epochs.";
    }

    /**
     * Returns the default optimization algorithm (stochastic gradient descent).
     *
     * @return the default algorithm
     */
    protected OptimizationAlgorithm getDefaultOptimizationAlgorithm() {
        return OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
    }

    /**
     * Sets the optimization algorithm used for generated networks.
     *
     * @param value the algorithm
     */
    public void setOptimizationAlgorithm(OptimizationAlgorithm value) {
        this.m_Algorithm = value;
    }

    /**
     * Returns the optimization algorithm used for generated networks.
     *
     * @return the algorithm
     */
    public OptimizationAlgorithm getOptimizationAlgorithm() {
        return this.m_Algorithm;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String optimizationAlgorithmTipText() {
        return "The optimization algorithm to use.";
    }

    /**
     * Returns the default for regularization (off).
     *
     * @return the default
     */
    protected boolean getDefaultUseRegularization() {
        return false;
    }

    /**
     * Sets whether generated networks use regularization.
     *
     * @param value true to use regularization
     */
    public void setUseRegularization(boolean value) {
        this.m_UseRegularization = value;
    }

    /**
     * Returns whether generated networks use regularization.
     *
     * @return true if regularization is used
     */
    public boolean getUseRegularization() {
        return this.m_UseRegularization;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String useRegularizationTipText() {
        return "If enabled, regularization is used.";
    }

    /**
     * Returns the default drop type (none).
     *
     * @return the default
     */
    protected Dl4jMlpClassifier.DropType getDefaultDropType() {
        return Dl4jMlpClassifier.DropType.NONE;
    }

    /**
     * Sets the drop type (none, drop-out, drop-connect) for generated networks.
     *
     * @param value the type
     */
    public void setDropType(Dl4jMlpClassifier.DropType value) {
        this.m_DropType = value;
    }

    /**
     * Returns the drop type for generated networks.
     *
     * @return the type
     */
    public Dl4jMlpClassifier.DropType getDropType() {
        return this.m_DropType;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String dropTypeTipText() {
        return "The type of drop to use.";
    }

    /**
     * Returns the default drop-out value (0).
     *
     * @return the default
     */
    protected double getDefaultDropOut() {
        return 0.0;
    }

    /**
     * Sets the drop-out value. Values outside [0, 1] are ignored.
     *
     * @param value the drop-out
     */
    public void setDropOut(double value) {
        if (value >= 0.0 && value <= 1.0) {
            this.m_DropOut = value;
        }
    }

    /**
     * Returns the drop-out value.
     *
     * @return the drop-out
     */
    public double getDropOut() {
        return this.m_DropOut;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String dropOutTipText() {
        return "The drop-out value.";
    }

    /**
     * Returns the default iteration listener configurators (none).
     *
     * @return the default
     */
    protected IterationListenerConfigurator[] getDefaultIterationListeners() {
        return new IterationListenerConfigurator[0];
    }

    /**
     * Sets the configurators whose iteration listeners get attached to the network.
     *
     * @param value the configurators
     */
    public void setIterationListeners(IterationListenerConfigurator[] value) {
        this.m_IterationListeners = value;
    }

    /**
     * Returns the iteration listener configurators.
     *
     * @return the configurators
     */
    public IterationListenerConfigurator[] getIterationListeners() {
        return this.m_IterationListeners;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String iterationListenersTipText() {
        return "The configurators for iteration listeners to use.";
    }

    /**
     * Returns the default filter: {@link AllFilter}, which means no filtering.
     *
     * @return the default
     */
    protected Filter getDefaultFilter() {
        return new AllFilter();
    }

    /**
     * Sets the filter to apply to the data (ignored if an AllFilter).
     *
     * @param value the filter
     */
    public void setFilter(Filter value) {
        this.m_Filter = value;
    }

    /**
     * Returns the filter to apply to the data.
     *
     * @return the filter
     */
    public Filter getFilter() {
        return this.m_Filter;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String filterTipText() {
        return "The filter to apply to the data, ignored if " + AllFilter.class.getName() + ".";
    }

    /**
     * Returns the default model file: an empty {@link PlaceholderFile}.
     *
     * @return the default
     */
    protected PlaceholderFile getDefaultModelFile() {
        return new PlaceholderFile();
    }

    /**
     * Sets the file to load the network from. The file is only used if it exists and is not a
     * directory.
     *
     * @param value the file
     */
    public void setModelFile(PlaceholderFile value) {
        this.m_ModelFile = value;
    }

    /**
     * Returns the file to load the network from.
     *
     * @return the file
     */
    public PlaceholderFile getModelFile() {
        return this.m_ModelFile;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String modelFileTipText() {
        return "The (optional) file to load a model from (trained or just structure), which will be used as template for the network structure; ignored if pointing to a directory.";
    }

    /**
     * Returns the default model file type (YAML).
     *
     * @return the default
     */
    protected ModelFileType getDefaultModelFileType() {
        return ModelFileType.YAML;
    }

    /**
     * Sets the format of the model file.
     *
     * @param value the type
     */
    public void setModelFileType(ModelFileType value) {
        this.m_ModelFileType = value;
    }

    /**
     * Returns the format of the model file.
     *
     * @return the type
     */
    public ModelFileType getModelFileType() {
        return this.m_ModelFileType;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String modelFileTypeTipText() {
        return "The type of the model file; in case of '" + ModelFileType.WEKA_SERIALIZED_MODEL + "', the object must implement '" + DL4JMultiLayerNetworkProvider.class.getName() + "'.";
    }

    /**
     * Returns the default test interval (100 epochs).
     *
     * @return the default
     */
    protected int getDefaultTestInterval() {
        return 100;
    }

    /**
     * Sets how often (in epochs) the test set is evaluated. Values below 1 are ignored.
     *
     * @param value the interval
     */
    public void setTestInterval(int value) {
        if (value > 0) {
            this.m_TestInterval = value;
        }
    }

    /**
     * Returns how often (in epochs) the test set is evaluated.
     *
     * @return the interval
     */
    public int getTestInterval() {
        return this.m_TestInterval;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String testIntervalTipText() {
        return "The interval (of epochs) to test the model, if a test percentage is specified.";
    }

    /**
     * Returns the default test percentage (0, meaning no testing).
     *
     * @return the default
     */
    protected double getDefaultTestPercentage() {
        return 0.0;
    }

    /**
     * Sets the fraction of the training data held out for testing. Values outside [0, 1) are
     * ignored.
     *
     * @param value the fraction
     */
    public void setTestPercentage(double value) {
        if (value >= 0.0 && value < 1.0) {
            this.m_TestPercentage = value;
        }
    }

    /**
     * Returns the fraction of the training data held out for testing.
     *
     * @return the fraction
     */
    public double getTestPercentage() {
        return this.m_TestPercentage;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String testPercentageTipText() {
        return "The percentage (0-1) of the training data to set aside for evaluating the model; no testing performed if 0.";
    }

    /**
     * Returns the number of outgoing units of a layer.
     *
     * @param layer the layer to inspect
     * @return nOut for weka.dl4j dense/output layers, -1 for any other layer type
     */
    protected int getNumUnits(Layer layer) {
        if (layer instanceof DenseLayer) {
            return ((DenseLayer)layer).getNOut();
        }
        if (layer instanceof OutputLayer) {
            return ((OutputLayer)layer).getNOut();
        }
        return -1;
    }

    /**
     * Sets the number of incoming units on weka.dl4j dense/output layers; other layer types
     * are left untouched.
     *
     * @param layer the layer to update
     * @param numInputs the number of inputs
     */
    protected void setNumIncoming(Layer layer, int numInputs) {
        if (layer instanceof DenseLayer) {
            ((DenseLayer)layer).setNIn(numInputs);
        } else if (layer instanceof OutputLayer) {
            ((OutputLayer)layer).setNIn(numInputs);
        }
    }

    /**
     * Reads the model file as text, with lines joined by newlines.
     *
     * @return the file content
     * @throws Exception if the file could not be read
     */
    private String readModelFileContent() throws Exception {
        List lines = FileUtils.loadFromFile((File)this.m_ModelFile);
        if (lines == null) {
            throw new Exception("Failed to read model file: " + this.m_ModelFile);
        }
        return adams.core.Utils.flatten(lines, "\n");
    }

    /**
     * Loads the network from the model file according to the model file type.
     *
     * @return the network, never null
     * @throws Exception if loading fails, the serialized object is not a
     *                   {@link DL4JMultiLayerNetworkProvider}, or no network is obtained
     */
    protected MultiLayerNetwork loadModel() throws Exception {
        MultiLayerNetwork result = null;
        if (this.getDebug()) {
            System.out.println("Loading model (" + this.m_ModelFileType + "): " + this.m_ModelFile);
        }
        switch (this.m_ModelFileType) {
            case YAML: {
                DL4JYamlToModel yaml = new DL4JYamlToModel();
                yaml.setType(ModelType.MULTI_LAYER_NETWORK);
                yaml.setInput((Object)this.readModelFileContent());
                String msg = yaml.convert();
                if (msg != null) {
                    throw new Exception("Failed to convert YAML model loaded from '" + this.m_ModelFile + "': " + msg);
                }
                result = (MultiLayerNetwork)yaml.getOutput();
                break;
            }
            case JSON: {
                DL4JJsonToModel json = new DL4JJsonToModel();
                json.setType(ModelType.MULTI_LAYER_NETWORK);
                json.setInput((Object)this.readModelFileContent());
                String msg = json.convert();
                if (msg != null) {
                    throw new Exception("Failed to convert JSON model loaded from '" + this.m_ModelFile + "': " + msg);
                }
                result = (MultiLayerNetwork)json.getOutput();
                break;
            }
            case WEKA_SERIALIZED_MODEL: {
                Object obj = SerializationHelper.read(this.m_ModelFile.getAbsolutePath());
                if (!(obj instanceof DL4JMultiLayerNetworkProvider)) {
                    throw new Exception("Object deserialized from '" + this.m_ModelFile + "' does not implement '"
                        + DL4JMultiLayerNetworkProvider.class.getName() + "': "
                        + (obj == null ? "null" : obj.getClass().getName()));
                }
                result = ((DL4JMultiLayerNetworkProvider)obj).getMultiLayerNetwork();
                break;
            }
            case DL4J_SERIALIZED_MODEL: {
                result = ModelSerializer.restoreMultiLayerNetwork((File)this.m_ModelFile.getAbsoluteFile());
                break;
            }
            default: {
                throw new IllegalStateException("Unhandled model file type: " + this.m_ModelFileType);
            }
        }
        if (result == null) {
            throw new Exception("No network obtained from model file '" + this.m_ModelFile + "' (" + this.m_ModelFileType + ")!");
        }
        return result;
    }

    /**
     * Generates and initializes a network from the configured layers.
     * <p>
     * Note that the configured layer objects are modified in place. Their number of inputs is
     * chained from the data and from the previous layer, and the output layer's number of
     * outputs is set to the number of classes. Requires m_Iterator to be set.
     *
     * @param data the (filtered) training data
     * @return the initialized network
     * @throws Exception if no layers are defined or the last layer is not an output layer
     */
    protected MultiLayerNetwork generateModel(Instances data) throws Exception {
        if (this.getDebug()) {
            System.out.println("Generating model...");
        }
        if (this.m_Layers.length == 0) {
            throw new Exception("No layers have been defined!");
        }
        if (!(this.m_Layers[this.m_Layers.length - 1] instanceof OutputLayer)) {
            throw new Exception("Last layer in network must be an output layer!");
        }
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        builder.setOptimizationAlgo(this.getOptimizationAlgorithm());
        builder.setSeed((long)this.getSeed());
        if (this.m_UseRegularization) {
            builder.setUseRegularization(true);
        }
        switch (this.m_DropType) {
            case NONE: {
                builder.setDropOut(0.0);
                break;
            }
            case DROP_OUT: {
                builder.setDropOut(this.m_DropOut);
                break;
            }
            case DROP_CONNECT: {
                builder.setUseDropConnect(true);
                builder.setDropOut(this.m_DropOut);
                break;
            }
            default: {
                throw new IllegalStateException("Unhandled drop type: " + this.m_DropType);
            }
        }
        NeuralNetConfiguration.ListBuilder listbuilder = builder.list(this.getLayers());
        int numInputAtts = this.m_Iterator.getNumAttributes(data);
        for (int i = 0; i < this.m_Layers.length; ++i) {
            // first layer is fed by the data, every other layer by its predecessor
            if (i == 0) {
                this.setNumIncoming(this.m_Layers[i], numInputAtts);
            } else {
                this.setNumIncoming(this.m_Layers[i], this.getNumUnits(this.m_Layers[i - 1]));
            }
            if (i == this.m_Layers.length - 1) {
                ((OutputLayer)this.m_Layers[i]).setNOut(data.numClasses());
            }
            listbuilder = listbuilder.layer(i, this.m_Layers[i]);
        }
        listbuilder = listbuilder.pretrain(false).backprop(true);
        MultiLayerConfiguration conf = listbuilder.build();
        MultiLayerNetwork result = new MultiLayerNetwork(conf);
        result.init();
        return result;
    }

    /**
     * Builds the classifier.
     * <p>
     * Any state from a previous build or injected network is discarded first. Instances with
     * a missing class are removed. If no instances remain, or there are fewer than two
     * attributes, a ZeroR model is built instead.
     * <p>
     * Otherwise the optional filter is trained and applied. The network is then loaded or
     * generated and trained epoch by epoch until the stopping criterion fires. If a test
     * percentage is set, it is evaluated on a held-out split every test-interval epochs.
     *
     * @param data the training data
     * @throws Exception if the data is incompatible or training fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);

        // discard state left over from a previous build or an injected network
        this.m_Model = null;
        this.m_Trained = false;
        this.m_ZeroR = null;
        this.m_ActualFilter = null;
        this.m_ProgrammaticInitialization = false;
        this.m_Iterator = new DefaultInstancesIterator();

        data = new Instances(data);
        data.deleteWithMissingClass();
        if (data.numInstances() == 0 || data.numAttributes() < 2) {
            System.err.println("Not enough data, using ZeroR model!");
            this.m_ZeroR = new ZeroR();
            this.m_ZeroR.buildClassifier(data);
            return;
        }

        if (!(this.m_Filter instanceof AllFilter)) {
            this.m_ActualFilter = Filter.makeCopy(this.m_Filter);
            this.m_ActualFilter.setInputFormat(data);
            data = Filter.useFilter(data, this.m_ActualFilter);
        }
        this.m_Header = new Instances(data, 0);
        boolean nominal = this.m_Header.classAttribute().isNominal();

        this.m_Model = this.isModelFileUsed() ? this.loadModel() : this.generateModel(data);
        this.m_Model.init();
        if (this.getDebug()) {
            System.out.println("Network:\n" + this.m_Model.getLayerWiseConfigurations().toYaml());
        }

        ArrayList listeners = new ArrayList();
        for (IterationListenerConfigurator l : this.m_IterationListeners) {
            listeners.addAll(l.configureIterationListeners());
        }
        this.m_Model.setListeners(listeners);

        Instances train;
        Instances test;
        if (this.m_TestPercentage > 0.0) {
            RandomSplitGenerator split = new RandomSplitGenerator(data, (long)this.m_Seed, 1.0 - this.m_TestPercentage);
            WekaTrainTestSetContainer trainTest = split.next();
            train = (Instances)trainTest.getValue("Train");
            test = (Instances)trainTest.getValue("Test");
        } else {
            train = data;
            test = null;
        }
        DataSet dtrain = Utils.instancesToDataSet(train);
        DataSet dtest = test != null ? Utils.instancesToDataSet(test) : null;

        Random rand = new Random(this.getSeed());
        int seed = this.getSeed();
        MessageCollection triggers = new MessageCollection();
        boolean stop;
        int epoch = 0;
        do {
            if (this.m_RandomizeBetweenEpochs) {
                seed = rand.nextInt();
            }
            DataSetIterator iter = this.m_MiniBatchSize < 1
                ? this.m_Iterator.getIterator(train, seed)
                : this.m_Iterator.getIterator(train, seed, this.m_MiniBatchSize);
            this.m_Model.fit(iter);
            if (this.getDebug() && (epoch + 1) % 100 == 0) {
                System.out.println("Epoch #" + (epoch + 1) + " finished");
            }

            Evaluation evalCls = null;
            RegressionEvaluation evalReg = null;
            if (this.m_TestPercentage > 0.0 && epoch % this.m_TestInterval == 0 && dtest != null) {
                if (this.getDebug()) {
                    System.out.println("Evaluating on test set...");
                }
                INDArray output = this.m_Model.output(dtest.getFeatureMatrix(), org.deeplearning4j.nn.api.Layer.TrainingMode.TEST);
                if (nominal) {
                    evalCls = new Evaluation(dtrain.numOutcomes());
                    evalCls.eval(dtest.getLabels(), output);
                } else {
                    evalReg = new RegressionEvaluation(dtrain.numOutcomes());
                    evalReg.eval(dtest.getLabels(), output);
                }
            }

            DL4JModelContainer modelCont;
            if (evalCls != null) {
                modelCont = new DL4JModelContainer((Object)this.m_Model, dtrain, Integer.valueOf(epoch), evalCls);
            } else if (evalReg != null) {
                modelCont = new DL4JModelContainer((Object)this.m_Model, dtrain, Integer.valueOf(epoch), evalReg);
            } else {
                modelCont = new DL4JModelContainer((Object)this.m_Model, null, Integer.valueOf(epoch));
            }
            stop = this.m_TrainStop.checkStopping(modelCont, triggers);
            if (stop && this.getDebug()) {
                System.out.println("Training stopped:\n" + triggers);
            }
            ++epoch;
        } while (!stop);

        this.m_Trained = true;
    }

    /**
     * Returns the class distribution (or the regression value) for an instance.
     * <p>
     * The ZeroR model is used if the last build fell back on it. Otherwise the instance is run
     * through the trained filter (if any) and then through the network. For multi-class output
     * the result is normalized.
     *
     * @param inst the instance to predict
     * @return the prediction
     * @throws IllegalStateException if no network is available, or the instance has no dataset
     *                               and no training header is known
     * @throws Exception if filtering or prediction fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        if (this.m_ZeroR != null) {
            return this.m_ZeroR.distributionForInstance(inst);
        }
        if (this.m_Model == null) {
            throw new IllegalStateException("No model built yet!");
        }
        if (this.m_ActualFilter != null) {
            this.m_ActualFilter.input(inst);
            this.m_ActualFilter.batchFinished();
            inst = this.m_ActualFilter.output();
        }
        // a stand-alone instance carries no dataset; fall back on the training header
        Instances format = inst.dataset() != null ? inst.dataset() : this.m_Header;
        if (format == null) {
            throw new IllegalStateException("Instance has no dataset and no training header is available!");
        }
        if (this.m_Iterator == null) {
            this.m_Iterator = new DefaultInstancesIterator();
        }
        Instances insts = new Instances(format, 0);
        insts.add(inst);
        DataSet dataset = (DataSet)this.m_Iterator.getIterator(insts, this.getSeed(), 1).next();
        INDArray predicted = this.m_Model.output(dataset.getFeatureMatrix(), false).getRow(0);
        double[] preds = new double[insts.numClasses()];
        for (int i = 0; i < preds.length; ++i) {
            preds[i] = predicted.getDouble(i);
        }
        if (preds.length > 1) {
            weka.core.Utils.normalize(preds);
        }
        return preds;
    }

    /**
     * Returns the network, or null if none is available.
     *
     * @return the network
     */
    @Override
    public MultiLayerNetwork getMultiLayerNetwork() {
        return this.m_Model;
    }

    /**
     * Injects an already trained network. This disables any ZeroR fallback from an earlier
     * build and suppresses this classifier's own options in getOptions.
     *
     * @param model the trained network, may be null
     */
    @Override
    public void setTrainedMultiLayerNetwork(MultiLayerNetwork model) {
        this.m_Model = model;
        this.m_ZeroR = null;
        this.m_ProgrammaticInitialization = true;
        this.m_Trained = this.m_Model != null;
        this.m_Iterator = new DefaultInstancesIterator();
    }

    /**
     * Stores the header of the given training data.
     *
     * @param data the training data
     */
    @Override
    public void setTrainingData(Instances data) {
        this.m_Header = new Instances(data, 0);
    }

    /**
     * Returns the trained pre-filter, or null if none is in use.
     *
     * @return the filter
     */
    @Override
    public Filter getPreFilter() {
        return this.m_ActualFilter;
    }

    /**
     * Injects an already trained pre-filter.
     *
     * @param filter the filter
     */
    @Override
    public void setTrainedPreFilter(Filter filter) {
        this.m_ActualFilter = filter;
    }

    /**
     * Returns the network configuration as YAML. Returns the ZeroR model description if that
     * fallback is active, or a notice if no model has been built.
     *
     * @return the description
     */
    @Override
    public String toString() {
        if (this.m_ZeroR != null) {
            return this.m_ZeroR.toString();
        }
        if (this.m_Model == null) {
            return "No model built yet!";
        }
        return this.m_Model.getLayerWiseConfigurations().toYaml();
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        DL4JMultiLayerNetwork.runClassifier((Classifier)new DL4JMultiLayerNetwork(), args);
    }

    /**
     * Supported formats for the model file.
     */
    public static enum ModelFileType {
        YAML,
        JSON,
        DL4J_SERIALIZED_MODEL,
        WEKA_SERIALIZED_MODEL;

    }
}

