package weka.classifiers.functions;

import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Random;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.BatchPredictor;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionMetadata;
import weka.core.Utils;
import weka.dl4j.FileIterationListener;
import weka.dl4j.iterators.AbstractDataSetIterator;
import weka.dl4j.iterators.ConvolutionalInstancesIterator;
import weka.dl4j.iterators.DefaultInstancesIterator;
import weka.dl4j.iterators.ImageDataSetIterator;
import weka.dl4j.layers.DenseLayer;
import weka.dl4j.layers.OutputLayer;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

/* loaded from: input_file:weka/classifiers/functions/Dl4jMlpClassifier.class */
/**
 * Weka classifier that trains a DeepLearning4J multilayer network, supporting both nominal
 * (classification) and numeric (regression) class attributes.
 *
 * <p>Training data is pre-processed by replacing missing values, converting nominal attributes
 * to binary ones, and then either standardizing (the default; applied with the class attribute
 * temporarily unset) or normalizing. If no usable training data remains (no instance with a
 * class value, or no attribute besides the class), a {@link ZeroR} model is built instead.
 *
 * <p>The network itself is transient; the custom {@code writeObject}/{@code readObject} methods
 * store it through DL4J's {@link ModelSerializer}.
 */
public class Dl4jMlpClassifier extends RandomizableClassifier implements BatchPredictor, CapabilitiesHandler {
    protected static final long serialVersionUID = -6363254116597574265L;

    /**
     * Missing-value filter fitted on the training data. Also acts as the "network trained" flag
     * for (de)serialization: the network is only written/read when this field is non-null.
     */
    protected ReplaceMissingValues m_replaceMissing;

    /** Standardize or Normalize filter fitted on the binarized training data. */
    protected Filter m_normalize;

    /** Nominal-to-binary filter fitted on the training data. */
    protected NominalToBinary m_nominalToBinary;

    /** Fallback model used when no usable training data was available; null otherwise. */
    protected ZeroR m_zeroR;

    /** The trained network; transient, serialized separately via ModelSerializer. */
    protected transient MultiLayerNetwork m_model;

    protected final Logger m_log = LoggerFactory.getLogger(Dl4jMlpClassifier.class);

    /**
     * File receiving loss information. Defaults to the working directory, which
     * buildClassifier skips because it is a directory (i.e. no log file is written).
     */
    protected File m_logFile = new File(System.getProperty("user.dir"));

    /** Layer specification; the last layer must be an {@link OutputLayer}. */
    protected Layer[] m_layers = {new OutputLayer()};

    protected OptimizationAlgorithm m_algo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;

    protected int m_numEpochs = 10;

    /** Converts Weka instances into DL4J mini-batches. */
    protected AbstractDataSetIterator m_iterator = new DefaultInstancesIterator();

    /** Whether Standardize (true) or Normalize (false) is applied after nominal-to-binary. */
    protected boolean m_standardizeInsteadOfNormalize = true;

    /**
     * Linear map from the network's (filtered) class scale back to the original class values:
     * {@code original = m_x1 * output + m_x0}. Only applied when there is a single output
     * (numeric class).
     */
    protected double m_x1 = 1.0d;
    protected double m_x0 = 0.0d;

    /**
     * Command-line entry point; delegates to Weka's standard classifier runner.
     *
     * @param strArr Weka command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new Dl4jMlpClassifier(), strArr);
    }

    /**
     * Returns a short description of this classifier for Weka's GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Classification and regression with multilayer perceptrons using DeepLearning4J.";
    }

    /**
     * Describes the data this classifier accepts. With an {@link ImageDataSetIterator} only
     * string attributes are enabled; otherwise nominal, numeric and date attributes (plus
     * missing values) are. Nominal, numeric and date classes, including missing class values,
     * are always accepted.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        if (getDataSetIterator() instanceof ImageDataSetIterator) {
            capabilities.enable(Capabilities.Capability.STRING_ATTRIBUTES);
        } else {
            capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
            capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
            capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
            capabilities.enable(Capabilities.Capability.MISSING_VALUES);
            capabilities.enableDependency(Capabilities.Capability.STRING_ATTRIBUTES);
        }
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.DATE_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return capabilities;
    }

    /**
     * Writes the default fields, followed by the network (without updater state) when a network
     * has been trained, as signalled by a non-null {@code m_replaceMissing}.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
        if (this.m_replaceMissing != null) {
            ModelSerializer.writeModel(this.m_model, objectOutputStream, false);
        }
    }

    /**
     * Reads the default fields and, when {@code m_replaceMissing} is non-null, restores the
     * network. The context class loader is temporarily switched to this class's loader while
     * doing so and always restored afterwards.
     */
    private void readObject(ObjectInputStream objectInputStream) throws ClassNotFoundException, IOException {
        objectInputStream.defaultReadObject();
        if (this.m_replaceMissing != null) {
            ClassLoader previousLoader = Thread.currentThread().getContextClassLoader();
            try {
                Thread.currentThread().setContextClassLoader(getClass().getClassLoader());
                this.m_model = ModelSerializer.restoreMultiLayerNetwork(objectInputStream, false);
            } finally {
                Thread.currentThread().setContextClassLoader(previousLoader);
            }
        }
    }

    public File getLogFile() {
        return this.m_logFile;
    }

    @OptionMetadata(displayName = "log file", description = "The name of the log file to write loss information to (default = no log file).", commandLineParamName = "logFile", commandLineParamSynopsis = "-logFile <string>", displayOrder = 1)
    public void setLogFile(File file) {
        this.m_logFile = file;
    }

    public Layer[] getLayers() {
        return this.m_layers;
    }

    @OptionMetadata(displayName = "layer specification", description = "The specification of the layers.", commandLineParamName = "layers", commandLineParamSynopsis = "-layers <string>", displayOrder = 2)
    public void setLayers(Layer[] layerArr) {
        this.m_layers = layerArr;
    }

    public int getNumEpochs() {
        return this.m_numEpochs;
    }

    @OptionMetadata(description = "The number of epochs to perform", displayName = "number of epochs", commandLineParamName = "numEpochs", commandLineParamSynopsis = "-numEpochs <int>", displayOrder = 4)
    public void setNumEpochs(int i) {
        this.m_numEpochs = i;
    }

    @OptionMetadata(description = "Optimization algorithm (LINE_GRADIENT_DESCENT, CONJUGATE_GRADIENT, HESSIAN_FREE, LBFGS, STOCHASTIC_GRADIENT_DESCENT)", displayName = "optimization algorithm", commandLineParamName = "algorithm", commandLineParamSynopsis = "-algorithm <string>", displayOrder = 5)
    public OptimizationAlgorithm getOptimizationAlgorithm() {
        return this.m_algo;
    }

    public void setOptimizationAlgorithm(OptimizationAlgorithm optimizationAlgorithm) {
        this.m_algo = optimizationAlgorithm;
    }

    @OptionMetadata(description = "The dataset iterator to use", displayName = "dataset iterator", commandLineParamName = "iterator", commandLineParamSynopsis = "-iterator <string>", displayOrder = 6)
    public AbstractDataSetIterator getDataSetIterator() {
        return this.m_iterator;
    }

    public void setDataSetIterator(AbstractDataSetIterator abstractDataSetIterator) {
        this.m_iterator = abstractDataSetIterator;
    }

    /**
     * Returns the number of output units of a dense or output layer.
     *
     * @param layer the layer to inspect
     * @return the layer's nOut, or -1 for any other layer type
     */
    protected int getNumUnits(Layer layer) {
        if (layer instanceof DenseLayer) {
            return ((DenseLayer) layer).getNOut();
        }
        if (layer instanceof OutputLayer) {
            return ((OutputLayer) layer).getNOut();
        }
        return -1;
    }

    /**
     * Sets the number of inputs of a dense or output layer; other layer types are left as-is.
     *
     * @param layer the layer to configure
     * @param i     the number of incoming units
     */
    protected void setNumIncoming(Layer layer, int i) {
        if (layer instanceof DenseLayer) {
            ((DenseLayer) layer).setNIn(i);
        } else if (layer instanceof OutputLayer) {
            ((OutputLayer) layer).setNIn(i);
        }
    }

    /**
     * Filters the training data, configures the network from the layer specification and
     * trains it for the configured number of epochs. Falls back to {@link ZeroR} when no
     * usable training data is left.
     *
     * <p>Model fields are only assigned once training has finished, so a failed build leaves
     * the classifier unbuilt rather than half-built (a filter without a matching network).
     *
     * @param instances the training data
     * @throws Exception if the data is not supported, the layer specification is invalid, all
     *                   class values are identical, or filtering/training fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        ClassLoader previousLoader = Thread.currentThread().getContextClassLoader();
        try {
            // Use this class's loader as the context loader while training (presumably so that
            // library code resolving classes through the context loader can find them); the
            // previous loader is always restored in the finally block.
            Thread.currentThread().setContextClassLoader(getClass().getClassLoader());
            getCapabilities().testWithFail(instances);
            if (this.m_layers.length == 0) {
                throw new Exception("No layers have been added!");
            }
            if (!(this.m_layers[this.m_layers.length - 1] instanceof OutputLayer)) {
                throw new Exception("Last layer in network must be an output layer!");
            }

            // Drop everything from a previous build, so a ZeroR fallback or a failed build never
            // leaves a stale network behind for toString/serialization.
            this.m_zeroR = null;
            this.m_replaceMissing = null;
            this.m_nominalToBinary = null;
            this.m_normalize = null;
            this.m_model = null;

            Instances data = new Instances(instances);
            data.deleteWithMissingClass();
            if (data.numInstances() == 0 || data.numAttributes() < 2) {
                this.m_zeroR = new ZeroR();
                this.m_zeroR.buildClassifier(data);
                return;
            }

            ReplaceMissingValues replaceMissing = new ReplaceMissingValues();
            replaceMissing.setInputFormat(data);
            Instances noMissing = Filter.useFilter(data, replaceMissing);

            // Locate two instances with different class values; they anchor the linear map from
            // the filtered class scale back to the original one (relies on the filters keeping
            // instance order).
            double firstClass = noMissing.instance(0).classValue();
            int other = 1;
            while (other < noMissing.numInstances() && noMissing.instance(other).classValue() == firstClass) {
                other++;
            }
            if (other == noMissing.numInstances()) {
                throw new Exception("All class values are the same. At least two class values should be different");
            }
            double otherClass = noMissing.instance(other).classValue();

            NominalToBinary nominalToBinary = new NominalToBinary();
            nominalToBinary.setInputFormat(noMissing);
            Instances binarized = Filter.useFilter(noMissing, nominalToBinary);

            Filter normalize;
            if (this.m_standardizeInsteadOfNormalize) {
                normalize = new Standardize();
                normalize.setOptions(new String[]{"-unset-class-temporarily"});
            } else {
                normalize = new Normalize();
            }
            normalize.setInputFormat(binarized);
            Instances train = Filter.useFilter(binarized, normalize);

            // Solve original = x1 * filtered + x0 from the two anchor instances.
            double firstFiltered = train.instance(0).classValue();
            double x1 = (firstClass - otherClass) / (firstFiltered - train.instance(other).classValue());
            double x0 = firstClass - (x1 * firstFiltered);

            // One generator shuffles the data and then seeds the network.
            Random rand = new Random(getSeed());
            train.randomize(rand);
            MultiLayerConfiguration conf = createConfiguration(train, rand.nextInt());
            if (getDebug()) {
                System.err.println(conf.toJson());
            }
            MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();
            if (getDebug()) {
                System.err.println(model.conf().toYaml());
            }
            model.setListeners(createListeners(train.numInstances()));

            DataSetIterator iterator = getDataSetIterator().getIterator(train, getSeed());
            for (int epoch = 1; epoch <= getNumEpochs(); epoch++) {
                model.fit(iterator);
                if (getDebug()) {
                    this.m_log.info("*** Completed epoch {} ***", epoch);
                }
                iterator.reset();
            }

            // Training succeeded: publish the fitted filters and network together.
            this.m_replaceMissing = replaceMissing;
            this.m_nominalToBinary = nominalToBinary;
            this.m_normalize = normalize;
            this.m_x1 = x1;
            this.m_x0 = x0;
            this.m_model = model;
        } finally {
            Thread.currentThread().setContextClassLoader(previousLoader);
        }
    }

    /**
     * Builds the network configuration for the filtered training data. Each layer's input size
     * is wired to the previous layer's output size (the first layer gets the iterator's
     * attribute count), and the last layer is sized to the number of classes.
     *
     * @param train the fully filtered training data
     * @param seed  seed for the network configuration
     * @return the network configuration
     * @throws Exception propagated from the configured data set iterator
     */
    private MultiLayerConfiguration createConfiguration(Instances train, int seed) throws Exception {
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        OptimizationAlgorithm algo = getOptimizationAlgorithm();
        builder.setOptimizationAlgo(algo != null ? algo : OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        builder.setSeed(seed);
        NeuralNetConfiguration.ListBuilder list = builder.list(getLayers());
        int numAttributes = getDataSetIterator().getNumAttributes(train);
        for (int k = 0; k < this.m_layers.length; k++) {
            Layer layer = this.m_layers[k];
            setNumIncoming(layer, k == 0 ? numAttributes : getNumUnits(this.m_layers[k - 1]));
            if (k == this.m_layers.length - 1) {
                layer.setNOut(train.numClasses());
            }
            list = list.layer(k, layer);
        }
        if (getDataSetIterator() instanceof ImageDataSetIterator) {
            ImageDataSetIterator images = (ImageDataSetIterator) getDataSetIterator();
            list.setInputType(InputType.convolutionalFlat(images.getWidth(), images.getHeight(), images.getNumChannels()));
        } else if (getDataSetIterator() instanceof ConvolutionalInstancesIterator) {
            ConvolutionalInstancesIterator conv = (ConvolutionalInstancesIterator) getDataSetIterator();
            list.setInputType(InputType.convolutionalFlat(conv.getWidth(), conv.getHeight(), conv.getNumChannels()));
        }
        return list.pretrain(false).backprop(true).build();
    }

    /**
     * Creates the training listeners: a score listener and, if a non-directory log file is
     * set, a file listener. Both are given the number of mini-batches per epoch.
     *
     * <p>A raw list is used so the DL4J listener interface need not be named here.
     *
     * @param numInstances number of training instances
     * @return the listeners to attach to the network
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private ArrayList createListeners(int numInstances) {
        // Divide in floating point: Math.ceil of an int quotient would be a no-op (and could
        // yield 0 when there are fewer instances than one batch).
        int batchesPerEpoch = (int) Math.ceil((double) numInstances / getDataSetIterator().getTrainBatchSize());
        ArrayList listeners = new ArrayList();
        listeners.add(new ScoreIterationListener(batchesPerEpoch));
        if (getLogFile() != null && !getLogFile().isDirectory()) {
            listeners.add(new FileIterationListener(getLogFile().getAbsolutePath(), batchesPerEpoch));
        }
        return listeners;
    }

    /**
     * Predicts the class distribution (or, for a numeric class, the single predicted value)
     * of an instance by running it through the fitted filters and the network.
     *
     * @param instance the instance to classify
     * @return class probabilities normalized to sum to 1, or a one-element array holding the
     *         numeric prediction mapped back to the original class scale
     * @throws Exception if the classifier has not been built or prediction fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_zeroR != null) {
            return this.m_zeroR.distributionForInstance(instance);
        }
        if (this.m_replaceMissing == null || this.m_model == null) {
            throw new IllegalStateException("No model built yet, call buildClassifier() first.");
        }
        this.m_replaceMissing.input(instance);
        this.m_nominalToBinary.input(this.m_replaceMissing.output());
        this.m_normalize.input(this.m_nominalToBinary.output());
        Instance filtered = this.m_normalize.output();

        // Wrap the single filtered instance so the data set iterator can convert it.
        Instances single = new Instances(filtered.dataset(), 0);
        single.add(filtered);
        DataSet batch = (DataSet) getDataSetIterator().getIterator(single, getSeed(), 1).next();
        INDArray predicted = this.m_model.output(batch.getFeatureMatrix(), false).getRow(0);

        double[] dist = new double[filtered.numClasses()];
        for (int k = 0; k < dist.length; k++) {
            dist[k] = predicted.getDouble(k);
        }
        if (dist.length > 1) {
            Utils.normalize(dist);
        } else {
            // Numeric class: undo the class filtering applied during training.
            dist[0] = (dist[0] * this.m_x1) + this.m_x0;
        }
        return dist;
    }

    /**
     * Describes the built model: the ZeroR fallback if it was used, otherwise the network's
     * layer-wise configuration as YAML.
     *
     * @return the description, or null if no model has been built
     */
    @Override
    public String toString() {
        if (this.m_zeroR != null) {
            return this.m_zeroR.toString();
        }
        if (this.m_replaceMissing != null && this.m_model != null) {
            return this.m_model.getLayerWiseConfigurations().toYaml();
        }
        return null;
    }
}
