package weka.classifiers.functions;

import adams.core.MessageCollection;
import adams.core.SerializationHelper;
import adams.core.io.FileUtils;
import adams.core.io.PlaceholderFile;
import adams.data.conversion.DL4JJsonToModel;
import adams.data.conversion.DL4JYamlToModel;
import adams.flow.container.DL4JModelContainer;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.ml.dl4j.iterationlistener.IterationListenerConfigurator;
import adams.ml.dl4j.model.Dl4jMlpClassifier;
import adams.ml.dl4j.model.ModelType;
import adams.ml.dl4j.trainstopcriterion.AbstractTrainStopCriterion;
import adams.ml.dl4j.trainstopcriterion.MaxEpoch;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import weka.classifiers.DL4JFilteredMultiLayerNetworkProvider;
import weka.classifiers.DL4JMultiLayerNetworkProvider;
import weka.classifiers.RandomSplitGenerator;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.WekaOptionUtils;
import weka.dl4j.iterators.DefaultInstancesIterator;
import weka.dl4j.layers.DenseLayer;
import weka.dl4j.layers.OutputLayer;
import weka.filters.AllFilter;
import weka.filters.Filter;

/* loaded from: input_file:weka/classifiers/functions/DL4JMultiLayerNetwork.class */
/**
 * Weka classifier for classification and regression with multilayer perceptrons built on
 * DeepLearning4J's {@link MultiLayerNetwork}.
 *
 * <p>The network is either generated from the configured {@link #getLayers() layers} or, when
 * {@link #getModelFile() the model file} exists and is not a directory, loaded from that file.
 * Training data can optionally be pre-processed with a {@link #getFilter() filter} (ignored when it
 * is an {@link AllFilter}). If there is no usable training data, a {@link ZeroR} model is used
 * instead.
 *
 * <p>The network field is transient; custom {@code writeObject}/{@code readObject} methods store it
 * via DL4J's {@link ModelSerializer}.
 */
public class DL4JMultiLayerNetwork extends RandomizableClassifier implements DL4JFilteredMultiLayerNetworkProvider {

    // NOTE(review): this file imports both org.deeplearning4j.nn.api.Layer and
    // org.deeplearning4j.nn.conf.layers.Layer, which is a compile error in Java. This class only uses
    // the conf.layers variant, so the api.Layer import should be removed.

    protected static final long serialVersionUID = -6363254116597574265L;
    public static final String LAYER = "layer";
    public static final String OPTIMIZATION_ALGORITHM = "optimization-algorithm";
    public static final String TRAIN_STOP = "train-stop";
    public static final String MINI_BATCH_SIZE = "mini-batch-size";
    public static final String RANDOMIZE_BETWEEN_EPOCHS = "randomize-between-epochs";
    public static final String DROP_TYPE = "drop-type";
    public static final String DROP_OUT = "drop-out";
    public static final String ITERATION_LISTENER = "iteration-listener";
    public static final String MODEL_FILE = "model-file";
    public static final String MODEL_FILE_TYPE = "model-file-type";
    public static final String FILTER = "filter";
    public static final String TEST_INTERVAL = "test-interval";
    public static final String TEST_PERCENTAGE = "test-percentage";

    /** Fallback model, non-null only when the last build had too little data for the network. */
    protected ZeroR m_ZeroR;
    /** The DL4J network; transient, so it is (de)serialized separately via {@link #m_ModelData}. */
    protected transient MultiLayerNetwork m_Model;
    /** Copy of {@link #m_Filter} initialized on the training data; null when no filtering is done. */
    protected Filter m_ActualFilter;
    /** Empty copy of the (filtered) training data, i.e., its header. */
    protected Instances m_Header;
    /** True when the network was supplied via {@link #setTrainedMultiLayerNetwork}. */
    protected boolean m_ProgrammaticInitialization;
    /** Serialized network bytes; only non-null while the object is being (de)serialized. */
    protected byte[] m_ModelData = null;
    /** Whether {@link #m_Model} holds a network that must be written during serialization. */
    protected boolean m_Trained = false;
    protected Layer[] m_Layers = getDefaultLayers();
    protected OptimizationAlgorithm m_Algorithm = getDefaultOptimizationAlgorithm();
    protected AbstractTrainStopCriterion m_TrainStop = getDefaultTrainStop();
    protected int m_MiniBatchSize = getDefaultMiniBatchSize();
    protected boolean m_RandomizeBetweenEpochs = getDefaultRandomizeBetweenEpochs();
    protected boolean m_UseRegularization = getDefaultUseRegularization();
    protected Dl4jMlpClassifier.DropType m_DropType = getDefaultDropType();
    protected double m_DropOut = getDefaultDropOut();
    protected IterationListenerConfigurator[] m_IterationListeners = getDefaultIterationListeners();
    protected PlaceholderFile m_ModelFile = getDefaultModelFile();
    protected ModelFileType m_ModelFileType = getDefaultModelFileType();
    protected Filter m_Filter = getDefaultFilter();
    /** Converts Weka instances into DL4J data set iterators; created when building the model. */
    protected DefaultInstancesIterator m_Iterator = null;
    protected int m_TestInterval = getDefaultTestInterval();
    protected double m_TestPercentage = getDefaultTestPercentage();

    /** The format of the file that {@link #getModelFile()} points to. */
    public enum ModelFileType {
        /** Network configuration in YAML. */
        YAML,
        /** Network configuration in JSON. */
        JSON,
        /** Model written with DL4J's {@link ModelSerializer}. */
        DL4J_SERIALIZED_MODEL,
        /** Serialized Weka object implementing {@link DL4JMultiLayerNetworkProvider}. */
        WEKA_SERIALIZED_MODEL
    }

    /**
     * Returns a short description of this classifier for display in the GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Classification and regression with multilayer perceptrons using DeepLearning4J.";
    }

    /**
     * Lists the available command-line options, followed by those of the superclass.
     *
     * @return an enumeration of the options
     */
    public Enumeration listOptions() {
        Vector vector = new Vector();
        WekaOptionUtils.addOption(vector, layersTipText(), Utils.arrayToString(getDefaultLayers()), LAYER);
        WekaOptionUtils.addOption(vector, optimizationAlgorithmTipText(), "" + getDefaultOptimizationAlgorithm(), OPTIMIZATION_ALGORITHM);
        WekaOptionUtils.addOption(vector, trainStopTipText(), "" + getDefaultTrainStop(), TRAIN_STOP);
        WekaOptionUtils.addOption(vector, miniBatchSizeTipText(), "" + getDefaultMiniBatchSize(), MINI_BATCH_SIZE);
        WekaOptionUtils.addOption(vector, randomizeBetweenEpochsTipText(), "" + getDefaultRandomizeBetweenEpochs(), RANDOMIZE_BETWEEN_EPOCHS);
        WekaOptionUtils.addOption(vector, dropTypeTipText(), "" + getDefaultDropType(), DROP_TYPE);
        WekaOptionUtils.addOption(vector, dropOutTipText(), "" + getDefaultDropOut(), DROP_OUT);
        WekaOptionUtils.addOption(vector, iterationListenersTipText(), Utils.arrayToString(getDefaultIterationListeners()), ITERATION_LISTENER);
        WekaOptionUtils.addOption(vector, modelFileTipText(), "" + getDefaultModelFile(), MODEL_FILE);
        WekaOptionUtils.addOption(vector, modelFileTypeTipText(), "" + getDefaultModelFileType(), MODEL_FILE_TYPE);
        WekaOptionUtils.addOption(vector, filterTipText(), getDefaultFilter(), FILTER);
        WekaOptionUtils.addOption(vector, testIntervalTipText(), "" + getDefaultTestInterval(), TEST_INTERVAL);
        WekaOptionUtils.addOption(vector, testPercentageTipText(), "" + getDefaultTestPercentage(), TEST_PERCENTAGE);
        WekaOptionUtils.add(vector, super.listOptions());
        return WekaOptionUtils.toEnumeration(vector);
    }

    /**
     * Parses the options; any option not present is reset to its default.
     *
     * @param strArr the options to parse
     * @throws Exception if parsing fails
     */
    public void setOptions(String[] strArr) throws Exception {
        setLayers((Layer[]) WekaOptionUtils.parse(strArr, LAYER, getDefaultLayers(), Layer.class));
        setOptimizationAlgorithm((OptimizationAlgorithm) WekaOptionUtils.parse(strArr, OPTIMIZATION_ALGORITHM, getDefaultOptimizationAlgorithm()));
        setTrainStop((AbstractTrainStopCriterion) WekaOptionUtils.parse(strArr, TRAIN_STOP, getDefaultTrainStop()));
        setMiniBatchSize(WekaOptionUtils.parse(strArr, MINI_BATCH_SIZE, getDefaultMiniBatchSize()));
        setRandomizeBetweenEpochs(Utils.getFlag(RANDOMIZE_BETWEEN_EPOCHS, strArr));
        setDropType((Dl4jMlpClassifier.DropType) WekaOptionUtils.parse(strArr, DROP_TYPE, getDefaultDropType()));
        setDropOut(WekaOptionUtils.parse(strArr, DROP_OUT, getDefaultDropOut()));
        setIterationListeners((IterationListenerConfigurator[]) WekaOptionUtils.parse(strArr, ITERATION_LISTENER, getDefaultIterationListeners(), IterationListenerConfigurator.class));
        setModelFile(WekaOptionUtils.parse(strArr, MODEL_FILE, getDefaultModelFile()));
        setModelFileType((ModelFileType) WekaOptionUtils.parse(strArr, MODEL_FILE_TYPE, getDefaultModelFileType()));
        setFilter((Filter) WekaOptionUtils.parse(strArr, FILTER, getDefaultFilter()));
        setTestInterval(WekaOptionUtils.parse(strArr, TEST_INTERVAL, getDefaultTestInterval()));
        setTestPercentage(WekaOptionUtils.parse(strArr, TEST_PERCENTAGE, getDefaultTestPercentage()));
        super.setOptions(strArr);
    }

    /**
     * Returns the current options.
     *
     * <p>When the network was supplied programmatically, only the superclass options are returned.
     * Otherwise, either the layer/optimizer/drop settings (model file is a directory) or the
     * model file and its type are returned, followed by the common training options.
     *
     * @return the options
     */
    public String[] getOptions() {
        ArrayList arrayList = new ArrayList();
        if (!m_ProgrammaticInitialization) {
            if (getModelFile().isDirectory()) {
                WekaOptionUtils.add(arrayList, LAYER, getLayers());
                WekaOptionUtils.add(arrayList, OPTIMIZATION_ALGORITHM, getOptimizationAlgorithm());
                WekaOptionUtils.add(arrayList, DROP_TYPE, getDropType());
                WekaOptionUtils.add(arrayList, DROP_OUT, getDropOut());
            } else {
                WekaOptionUtils.add(arrayList, MODEL_FILE, getModelFile());
                WekaOptionUtils.add(arrayList, MODEL_FILE_TYPE, getModelFileType());
            }
            WekaOptionUtils.add(arrayList, TRAIN_STOP, getTrainStop());
            WekaOptionUtils.add(arrayList, MINI_BATCH_SIZE, getMiniBatchSize());
            WekaOptionUtils.add(arrayList, RANDOMIZE_BETWEEN_EPOCHS, getRandomizeBetweenEpochs());
            WekaOptionUtils.add(arrayList, ITERATION_LISTENER, getIterationListeners());
            WekaOptionUtils.add(arrayList, FILTER, getFilter());
            WekaOptionUtils.add(arrayList, TEST_PERCENTAGE, getTestPercentage());
            // the test interval is only meaningful when a test split is used
            if (getTestPercentage() > 0.0d) {
                WekaOptionUtils.add(arrayList, TEST_INTERVAL, getTestInterval());
            }
        }
        WekaOptionUtils.add(arrayList, super.getOptions());
        return WekaOptionUtils.toArray(arrayList);
    }

    /**
     * Returns the capabilities.
     *
     * <p>Without a filter (an {@link AllFilter}), only numeric/date attributes and nominal, numeric
     * or date classes (missing class values allowed) are supported. With a filter, the filter's
     * capabilities are used.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities;
        if (m_Filter instanceof AllFilter) {
            capabilities = super.getCapabilities();
            capabilities.disableAll();
            capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
            capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
            capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
            capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
            capabilities.enable(Capabilities.Capability.DATE_CLASS);
            capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        } else {
            capabilities = m_Filter.getCapabilities();
            capabilities.setOwner(this);
        }
        return capabilities;
    }

    /**
     * Serializes this object, storing a trained network via DL4J's {@link ModelSerializer} since
     * the network field itself is transient.
     *
     * @param objectOutputStream the stream to write to
     * @throws IOException if writing fails
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        try {
            if (m_Trained && m_Model != null) {
                ByteArrayOutputStream bytes = new ByteArrayOutputStream();
                ModelSerializer.writeModel(m_Model, bytes, false);
                m_ModelData = bytes.toByteArray();
            }
            objectOutputStream.defaultWriteObject();
        } finally {
            // never keep the (potentially large) byte copy around, even if writing failed
            m_ModelData = null;
        }
    }

    /**
     * Deserializes this object and restores the network from the stored bytes, if any.
     *
     * @param objectInputStream the stream to read from
     * @throws ClassNotFoundException if a class cannot be found
     * @throws IOException if reading fails
     */
    private void readObject(ObjectInputStream objectInputStream) throws ClassNotFoundException, IOException {
        objectInputStream.defaultReadObject();
        if (m_ModelData != null) {
            m_Model = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(m_ModelData), false);
            m_ModelData = null;
        }
    }

    /** @return the default layers: a single output layer */
    protected Layer[] getDefaultLayers() {
        return new Layer[]{new OutputLayer()};
    }

    /** @param layerArr the layers of the network; the last one must be an output layer */
    public void setLayers(Layer[] layerArr) {
        m_Layers = layerArr;
    }

    /** @return the layers of the network */
    public Layer[] getLayers() {
        return m_Layers;
    }

    /** @return the tip text for the layers property */
    public String layersTipText() {
        return "The layers for the network (last one must be an output layer).";
    }

    /** @return the default stopping criterion: {@link MaxEpoch} */
    protected AbstractTrainStopCriterion getDefaultTrainStop() {
        return new MaxEpoch();
    }

    /** @param abstractTrainStopCriterion the criterion for stopping training */
    public void setTrainStop(AbstractTrainStopCriterion abstractTrainStopCriterion) {
        m_TrainStop = abstractTrainStopCriterion;
    }

    /** @return the criterion for stopping training */
    public AbstractTrainStopCriterion getTrainStop() {
        return m_TrainStop;
    }

    /** @return the tip text for the train stop property */
    public String trainStopTipText() {
        return "The criterion for stopping training.";
    }

    /** @return the default mini-batch size (100) */
    protected int getDefaultMiniBatchSize() {
        return 100;
    }

    /** @param i the mini-batch size; values below 1 use the iterator without an explicit batch size */
    public void setMiniBatchSize(int i) {
        m_MiniBatchSize = i;
    }

    /** @return the mini-batch size */
    public int getMiniBatchSize() {
        return m_MiniBatchSize;
    }

    /** @return the tip text for the mini-batch size property */
    public String miniBatchSizeTipText() {
        return "The size to use for mini-batches.";
    }

    /** @return the default for randomizing between epochs (false) */
    protected boolean getDefaultRandomizeBetweenEpochs() {
        return false;
    }

    /** @param z whether to use a new random seed for the data iterator in each epoch */
    public void setRandomizeBetweenEpochs(boolean z) {
        m_RandomizeBetweenEpochs = z;
    }

    /** @return whether the data gets randomized between epochs */
    public boolean getRandomizeBetweenEpochs() {
        return m_RandomizeBetweenEpochs;
    }

    /** @return the tip text for the randomize-between-epochs property */
    public String randomizeBetweenEpochsTipText() {
        return "If enabled, the data gets randomized between epochs.";
    }

    /** @return the default optimization algorithm: stochastic gradient descent */
    protected OptimizationAlgorithm getDefaultOptimizationAlgorithm() {
        return OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
    }

    /** @param optimizationAlgorithm the optimization algorithm to use */
    public void setOptimizationAlgorithm(OptimizationAlgorithm optimizationAlgorithm) {
        m_Algorithm = optimizationAlgorithm;
    }

    /** @return the optimization algorithm */
    public OptimizationAlgorithm getOptimizationAlgorithm() {
        return m_Algorithm;
    }

    /** @return the tip text for the optimization algorithm property */
    public String optimizationAlgorithmTipText() {
        return "The optimization algorithm to use.";
    }

    /** @return the default for using regularization (false) */
    protected boolean getDefaultUseRegularization() {
        return false;
    }

    /** @param z whether regularization is enabled in the generated network configuration */
    public void setUseRegularization(boolean z) {
        m_UseRegularization = z;
    }

    /** @return whether regularization is used */
    public boolean getUseRegularization() {
        return m_UseRegularization;
    }

    /** @return the tip text for the use regularization property */
    public String useRegularizationTipText() {
        return "If enabled, regularization is used.";
    }

    /** @return the default drop type: none */
    protected Dl4jMlpClassifier.DropType getDefaultDropType() {
        return Dl4jMlpClassifier.DropType.NONE;
    }

    /** @param dropType the type of drop to use */
    public void setDropType(Dl4jMlpClassifier.DropType dropType) {
        m_DropType = dropType;
    }

    /** @return the type of drop */
    public Dl4jMlpClassifier.DropType getDropType() {
        return m_DropType;
    }

    /** @return the tip text for the drop type property */
    public String dropTypeTipText() {
        return "The type of drop to use.";
    }

    /** @return the default drop-out value (0) */
    protected double getDefaultDropOut() {
        return 0.0d;
    }

    /** @param d the drop-out value; values outside [0, 1] are ignored */
    public void setDropOut(double d) {
        if (d < 0.0d || d > 1.0d) {
            return;
        }
        m_DropOut = d;
    }

    /** @return the drop-out value */
    public double getDropOut() {
        return m_DropOut;
    }

    /** @return the tip text for the drop-out property */
    public String dropOutTipText() {
        return "The drop-out value.";
    }

    /** @return the default iteration listener configurators: none */
    protected IterationListenerConfigurator[] getDefaultIterationListeners() {
        return new IterationListenerConfigurator[0];
    }

    /** @param iterationListenerConfiguratorArr the configurators for iteration listeners */
    public void setIterationListeners(IterationListenerConfigurator[] iterationListenerConfiguratorArr) {
        m_IterationListeners = iterationListenerConfiguratorArr;
    }

    /** @return the configurators for iteration listeners */
    public IterationListenerConfigurator[] getIterationListeners() {
        return m_IterationListeners;
    }

    /** @return the tip text for the iteration listeners property */
    public String iterationListenersTipText() {
        return "The configurators for iteration listeners to use.";
    }

    /** @return the default filter: {@link AllFilter}, i.e., no filtering */
    protected Filter getDefaultFilter() {
        return new AllFilter();
    }

    /** @param filter the filter to apply to the data; ignored if an {@link AllFilter} */
    public void setFilter(Filter filter) {
        m_Filter = filter;
    }

    /** @return the filter to apply to the data */
    public Filter getFilter() {
        return m_Filter;
    }

    /** @return the tip text for the filter property */
    public String filterTipText() {
        return "The filter to apply to the data, ignored if " + AllFilter.class.getName() + ".";
    }

    /** @return the default model file (an empty placeholder file) */
    protected PlaceholderFile getDefaultModelFile() {
        return new PlaceholderFile();
    }

    /** @param placeholderFile the file to load the model from; ignored if a directory */
    public void setModelFile(PlaceholderFile placeholderFile) {
        m_ModelFile = placeholderFile;
    }

    /** @return the file to load the model from */
    public PlaceholderFile getModelFile() {
        return m_ModelFile;
    }

    /** @return the tip text for the model file property */
    public String modelFileTipText() {
        return "The (optional) file to load a model from (trained or just structure), which will be used as template for the network structure; ignored if pointing to a directory.";
    }

    /** @return the default model file type: YAML */
    protected ModelFileType getDefaultModelFileType() {
        return ModelFileType.YAML;
    }

    /** @param modelFileType the type of the model file */
    public void setModelFileType(ModelFileType modelFileType) {
        m_ModelFileType = modelFileType;
    }

    /** @return the type of the model file */
    public ModelFileType getModelFileType() {
        return m_ModelFileType;
    }

    /** @return the tip text for the model file type property */
    public String modelFileTypeTipText() {
        return "The type of the model file; in case of '" + ModelFileType.WEKA_SERIALIZED_MODEL + "', the object must implement '" + DL4JMultiLayerNetworkProvider.class.getName() + "'.";
    }

    /** @return the default test interval (100 epochs) */
    protected int getDefaultTestInterval() {
        return 100;
    }

    /** @param i the interval (in epochs) for evaluating on the test split; values below 1 are ignored */
    public void setTestInterval(int i) {
        if (i > 0) {
            m_TestInterval = i;
        }
    }

    /** @return the interval (in epochs) for evaluating on the test split */
    public int getTestInterval() {
        return m_TestInterval;
    }

    /** @return the tip text for the test interval property */
    public String testIntervalTipText() {
        return "The interval (of epochs) to test the model, if a test percentage is specified.";
    }

    /** @return the default test percentage (0, i.e., no test split) */
    protected double getDefaultTestPercentage() {
        return 0.0d;
    }

    /** @param d the fraction of training data to hold out, in [0, 1); other values are ignored */
    public void setTestPercentage(double d) {
        if (d < 0.0d || d >= 1.0d) {
            return;
        }
        m_TestPercentage = d;
    }

    /** @return the fraction of training data held out for testing */
    public double getTestPercentage() {
        return m_TestPercentage;
    }

    /** @return the tip text for the test percentage property */
    public String testPercentageTipText() {
        return "The percentage (0-1) of the training data to set aside for evaluating the model; no testing performed if 0.";
    }

    /**
     * Returns the number of output units of a dense or output layer.
     *
     * @param layer the layer to inspect
     * @return the number of outputs, or -1 for other layer types
     */
    protected int getNumUnits(Layer layer) {
        if (layer instanceof DenseLayer) {
            return ((DenseLayer) layer).getNOut();
        }
        if (layer instanceof OutputLayer) {
            return ((OutputLayer) layer).getNOut();
        }
        return -1;
    }

    /**
     * Sets the number of incoming units of a dense or output layer; other layer types are left
     * untouched.
     *
     * @param layer the layer to update
     * @param i the number of incoming units
     */
    protected void setNumIncoming(Layer layer, int i) {
        if (layer instanceof DenseLayer) {
            ((DenseLayer) layer).setNIn(i);
        } else if (layer instanceof OutputLayer) {
            ((OutputLayer) layer).setNIn(i);
        }
    }

    /**
     * Reads the model file as a single newline-joined string.
     *
     * @return the file content
     */
    private String readModelFileAsString() {
        return adams.core.Utils.flatten(FileUtils.loadFromFile(m_ModelFile), "\n");
    }

    /**
     * Loads the network from {@link #getModelFile()} according to {@link #getModelFileType()}.
     *
     * @return the loaded network
     * @throws Exception if loading or converting fails
     */
    protected MultiLayerNetwork loadModel() throws Exception {
        if (getDebug()) {
            System.out.println("Loading model (" + m_ModelFileType + "): " + m_ModelFile);
        }
        switch (m_ModelFileType) {
            case YAML: {
                DL4JYamlToModel yamlConversion = new DL4JYamlToModel();
                yamlConversion.setType(ModelType.MULTI_LAYER_NETWORK);
                yamlConversion.setInput(readModelFileAsString());
                String yamlError = yamlConversion.convert();
                if (yamlError != null) {
                    throw new Exception("Failed to convert YAML model loaded from '" + m_ModelFile + "': " + yamlError);
                }
                return (MultiLayerNetwork) yamlConversion.getOutput();
            }
            case JSON: {
                DL4JJsonToModel jsonConversion = new DL4JJsonToModel();
                jsonConversion.setType(ModelType.MULTI_LAYER_NETWORK);
                jsonConversion.setInput(readModelFileAsString());
                String jsonError = jsonConversion.convert();
                if (jsonError != null) {
                    throw new Exception("Failed to convert JSON model loaded from '" + m_ModelFile + "': " + jsonError);
                }
                return (MultiLayerNetwork) jsonConversion.getOutput();
            }
            case WEKA_SERIALIZED_MODEL:
                return ((DL4JMultiLayerNetworkProvider) SerializationHelper.read(m_ModelFile.getAbsolutePath())).getMultiLayerNetwork();
            case DL4J_SERIALIZED_MODEL:
                return ModelSerializer.restoreMultiLayerNetwork(m_ModelFile.getAbsoluteFile());
            default:
                throw new IllegalStateException("Unhandled model file type: " + m_ModelFileType);
        }
    }

    /**
     * Builds and initializes a new network from the configured layers.
     *
     * <p>Note: the configured layer objects are modified in place. Their number of inputs is wired
     * to the data/previous layer, and the last layer's number of outputs is set to the number of
     * classes.
     *
     * @param instances the (filtered) training data, used for determining input/output sizes
     * @return the initialized network
     * @throws Exception if no layers are defined or the last layer is not an output layer
     */
    protected MultiLayerNetwork generateModel(Instances instances) throws Exception {
        if (getDebug()) {
            System.out.println("Generating model...");
        }
        if (m_Layers == null || m_Layers.length == 0) {
            throw new Exception("No layers have been defined!");
        }
        int last = m_Layers.length - 1;
        if (!(m_Layers[last] instanceof OutputLayer)) {
            throw new Exception("Last layer in network must be an output layer!");
        }

        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        builder.setOptimizationAlgo(getOptimizationAlgorithm());
        builder.setSeed(getSeed());
        if (m_UseRegularization) {
            builder.setUseRegularization(true);
        }
        switch (m_DropType) {
            case NONE:
                builder.setDropOut(0.0d);
                break;
            case DROP_OUT:
                builder.setDropOut(m_DropOut);
                break;
            case DROP_CONNECT:
                builder.setUseDropConnect(true);
                builder.setDropOut(m_DropOut);
                break;
            default:
                throw new IllegalStateException("Unhandled drop type: " + m_DropType);
        }

        NeuralNetConfiguration.ListBuilder list = builder.list(getLayers());
        int numInputs = m_Iterator.getNumAttributes(instances);
        for (int i = 0; i < m_Layers.length; i++) {
            // first layer is fed by the data, all others by their predecessor
            int incoming = (i == 0) ? numInputs : getNumUnits(m_Layers[i - 1]);
            setNumIncoming(m_Layers[i], incoming);
            if (i == last) {
                m_Layers[i].setNOut(instances.numClasses());
            }
            list = list.layer(i, m_Layers[i]);
        }

        MultiLayerNetwork network = new MultiLayerNetwork(list.pretrain(false).backprop(true).build());
        network.init();
        return network;
    }

    /**
     * Evaluates the network on the test set when due and wraps the current state in a container
     * for the stopping criterion.
     *
     * @param trainSet the training data
     * @param testSet the test data, may be null
     * @param epoch the 0-based epoch that has just finished
     * @param nominalClass whether the class is nominal (classification evaluation) or not (regression)
     * @return the container
     */
    private DL4JModelContainer createModelContainer(DataSet trainSet, DataSet testSet, int epoch, boolean nominalClass) {
        boolean evaluate = m_TestPercentage > 0.0d && epoch % m_TestInterval == 0 && testSet != null;
        if (!evaluate) {
            return new DL4JModelContainer(m_Model, (DataSet) null, Integer.valueOf(epoch));
        }
        if (getDebug()) {
            System.out.println("Evaluating on test set...");
        }
        // train=false is inference (test) mode, same as in distributionForInstance
        INDArray predictions = m_Model.output(testSet.getFeatureMatrix(), false);
        if (nominalClass) {
            Evaluation evaluation = new Evaluation(trainSet.numOutcomes());
            evaluation.eval(testSet.getLabels(), predictions);
            return new DL4JModelContainer(m_Model, trainSet, Integer.valueOf(epoch), evaluation);
        }
        RegressionEvaluation regressionEvaluation = new RegressionEvaluation(trainSet.numOutcomes());
        regressionEvaluation.eval(testSet.getLabels(), predictions);
        return new DL4JModelContainer(m_Model, trainSet, Integer.valueOf(epoch), regressionEvaluation);
    }

    /**
     * Trains the network on the given data.
     *
     * <p>Instances with missing class are removed. If no data (or no attribute besides the class)
     * remains, a {@link ZeroR} model is built instead. Otherwise, the data is filtered (if a filter
     * is set), optionally split into train/test, and the network is fitted epoch by epoch until the
     * stopping criterion is met.
     *
     * @param instances the training data
     * @throws Exception if the data is not supported or training fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        m_Iterator = new DefaultInstancesIterator();
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        m_ZeroR = null;
        m_ActualFilter = null;
        if (data.numInstances() == 0 || data.numAttributes() < 2) {
            System.err.println("Not enough data, using ZeroR model!");
            // discard any network from an earlier build so it is neither serialized nor reported
            m_Model = null;
            m_Trained = false;
            m_ZeroR = new ZeroR();
            m_ZeroR.buildClassifier(data);
            return;
        }

        if (!(m_Filter instanceof AllFilter)) {
            m_ActualFilter = Filter.makeCopy(m_Filter);
            m_ActualFilter.setInputFormat(data);
            data = Filter.useFilter(data, m_ActualFilter);
        }
        m_Header = new Instances(data, 0);
        boolean nominalClass = m_Header.classAttribute().isNominal();

        if (!m_ModelFile.exists() || m_ModelFile.isDirectory()) {
            m_Model = generateModel(data);
        } else {
            m_Model = loadModel();
        }
        m_Model.init();
        if (getDebug()) {
            System.out.println("Network:\n" + m_Model.getLayerWiseConfigurations().toYaml());
        }

        ArrayList listeners = new ArrayList();
        for (IterationListenerConfigurator configurator : m_IterationListeners) {
            listeners.addAll(configurator.configureIterationListeners());
        }
        m_Model.setListeners(listeners);

        Instances train;
        Instances test;
        if (m_TestPercentage > 0.0d) {
            WekaTrainTestSetContainer split = new RandomSplitGenerator(data, m_Seed, 1.0d - m_TestPercentage).next();
            train = (Instances) split.getValue("Train");
            test = (Instances) split.getValue("Test");
        } else {
            train = data;
            test = null;
        }
        DataSet trainSet = weka.classifiers.functions.dl4j.Utils.instancesToDataSet(train);
        DataSet testSet = (test != null) ? weka.classifiers.functions.dl4j.Utils.instancesToDataSet(test) : null;

        Random random = new Random(getSeed());
        int seed = getSeed();
        MessageCollection messages = new MessageCollection();
        int epoch = 0;
        boolean stop;
        do {
            if (m_RandomizeBetweenEpochs) {
                seed = random.nextInt();
            }
            m_Model.fit(m_MiniBatchSize < 1 ? m_Iterator.getIterator(train, seed) : m_Iterator.getIterator(train, seed, m_MiniBatchSize));
            if (getDebug() && (epoch + 1) % 100 == 0) {
                System.out.println("Epoch #" + (epoch + 1) + " finished");
            }
            stop = m_TrainStop.checkStopping(createModelContainer(trainSet, testSet, epoch, nominalClass), messages);
            if (stop && getDebug()) {
                System.out.println("Training stopped:\n" + messages);
            }
            epoch++;
        } while (!stop);

        m_Trained = true;
        m_ProgrammaticInitialization = false;
    }

    /**
     * Computes the prediction for the instance: the normalized class distribution for nominal
     * classes, or the single predicted value otherwise.
     *
     * @param instance the instance to predict
     * @return the distribution/prediction
     * @throws Exception if no model is available or prediction fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (m_ZeroR != null) {
            return m_ZeroR.distributionForInstance(instance);
        }
        if (m_Model == null) {
            throw new IllegalStateException("No model built yet!");
        }
        Instance inst = instance;
        if (m_ActualFilter != null) {
            m_ActualFilter.input(inst);
            m_ActualFilter.batchFinished();
            inst = m_ActualFilter.output();
        }
        // fall back on the training header if the instance has no dataset attached
        Instances single = new Instances(inst.dataset() != null ? inst.dataset() : m_Header, 0);
        single.add(inst);
        DataSet batch = (DataSet) m_Iterator.getIterator(single, getSeed(), 1).next();
        INDArray row = m_Model.output(batch.getFeatureMatrix(), false).getRow(0);
        double[] result = new double[inst.numClasses()];
        for (int i = 0; i < result.length; i++) {
            result[i] = row.getDouble(i);
        }
        if (result.length > 1) {
            Utils.normalize(result);
        }
        return result;
    }

    @Override // weka.classifiers.DL4JMultiLayerNetworkProvider
    public MultiLayerNetwork getMultiLayerNetwork() {
        return m_Model;
    }

    /**
     * Sets an already trained network, which disables option output in {@link #getOptions()}.
     *
     * @param multiLayerNetwork the network, may be null
     */
    @Override // weka.classifiers.DL4JMultiLayerNetworkProvider
    public void setTrainedMultiLayerNetwork(MultiLayerNetwork multiLayerNetwork) {
        m_Model = multiLayerNetwork;
        m_ProgrammaticInitialization = true;
        m_Trained = m_Model != null;
        m_Iterator = new DefaultInstancesIterator();
    }

    /**
     * Sets the training data; only its header is kept.
     *
     * @param instances the training data
     */
    @Override // weka.classifiers.DL4JMultiLayerNetworkProvider
    public void setTrainingData(Instances instances) {
        m_Header = new Instances(instances, 0);
    }

    @Override // weka.classifiers.DL4JFilteredMultiLayerNetworkProvider
    public Filter getPreFilter() {
        return m_ActualFilter;
    }

    @Override // weka.classifiers.DL4JFilteredMultiLayerNetworkProvider
    public void setTrainedPreFilter(Filter filter) {
        m_ActualFilter = filter;
    }

    /**
     * Returns the ZeroR model description when that fallback is used, a notice when nothing has
     * been built yet, and otherwise the network configuration as YAML.
     *
     * @return the description
     */
    public String toString() {
        if (m_ZeroR != null) {
            return m_ZeroR.toString();
        }
        if (m_Model == null) {
            return "No model built yet.";
        }
        return m_Model.getLayerWiseConfigurations().toYaml();
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new DL4JMultiLayerNetwork(), strArr);
    }
}
