/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.dl4j.model;

import adams.core.Randomizable;
import adams.ml.dl4j.model.AbstractModelConfigurator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import weka.dl4j.layers.DenseLayer;
import weka.dl4j.layers.OutputLayer;

public class Dl4jMlpClassifier
extends AbstractModelConfigurator
implements Randomizable {
    private static final long serialVersionUID = -1020010172973169754L;
    protected long m_Seed;
    protected OptimizationAlgorithm m_OptimizationAlgorithm;
    protected boolean m_UseRegularization;
    protected double m_L1;
    protected double m_L2;
    protected DropType m_DropType;
    protected double m_DropOut;
    protected Layer[] m_Layers;

    public String globalInfo() {
        return "Configures a network as used by " + weka.classifiers.functions.Dl4jMlpClassifier.class.getName() + ".";
    }

    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("seed", "seed", (Object)1L);
        this.m_OptionManager.add("optimization-algorithm", "optimizationAlgorithm", (Object)OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        this.m_OptionManager.add("use-regularization", "useRegularization", (Object)false);
        this.m_OptionManager.add("l1", "l1", (Object)0.0, (Number)0.0, null);
        this.m_OptionManager.add("l2", "l2", (Object)0.0, (Number)0.0, null);
        this.m_OptionManager.add("drop-type", "dropType", (Object)DropType.NONE);
        this.m_OptionManager.add("drop-out", "dropOut", (Object)0.0, (Number)0.0, null);
        this.m_OptionManager.add("layer", "layers", (Object)new Layer[]{new OutputLayer()});
    }

    public void setSeed(long value) {
        this.m_Seed = value;
        this.reset();
    }

    public long getSeed() {
        return this.m_Seed;
    }

    public String seedTipText() {
        return "The seed value for the model.";
    }

    public void setOptimizationAlgorithm(OptimizationAlgorithm value) {
        this.m_OptimizationAlgorithm = value;
        this.reset();
    }

    public OptimizationAlgorithm getOptimizationAlgorithm() {
        return this.m_OptimizationAlgorithm;
    }

    public String optimizationAlgorithmTipText() {
        return "The optimization algorithm.";
    }

    public void setUseRegularization(boolean value) {
        this.m_UseRegularization = value;
        this.reset();
    }

    public boolean getUseRegularization() {
        return this.m_UseRegularization;
    }

    public String useRegularizationTipText() {
        return "If enabled, regularization is used.";
    }

    public void setL1(double value) {
        if (this.getOptionManager().isValid("l1", (Number)value)) {
            this.m_L1 = value;
            this.reset();
        }
    }

    public double getL1() {
        return this.m_L1;
    }

    public String l1TipText() {
        return "The L1 value.";
    }

    public void setL2(double value) {
        if (this.getOptionManager().isValid("l2", (Number)value)) {
            this.m_L2 = value;
            this.reset();
        }
    }

    public double getL2() {
        return this.m_L2;
    }

    public String l2TipText() {
        return "The L2 value.";
    }

    public void setDropType(DropType value) {
        this.m_DropType = value;
        this.reset();
    }

    public DropType getDropType() {
        return this.m_DropType;
    }

    public String dropTypeTipText() {
        return "The type of drop to use.";
    }

    public void setDropOut(double value) {
        if (this.getOptionManager().isValid("dropOut", (Number)value)) {
            this.m_DropOut = value;
            this.reset();
        }
    }

    public double getDropOut() {
        return this.m_DropOut;
    }

    public String dropOutTipText() {
        return "The drop-out value.";
    }

    public void setLayers(Layer[] layers) {
        this.m_Layers = layers;
        this.reset();
    }

    public Layer[] getLayers() {
        return this.m_Layers;
    }

    public String layersTipText() {
        return "The layers specification.";
    }

    protected int getNumUnits(Layer layer) {
        if (layer instanceof DenseLayer) {
            return ((DenseLayer)layer).getNOut();
        }
        if (layer instanceof OutputLayer) {
            return ((OutputLayer)layer).getNOut();
        }
        return -1;
    }

    protected void setNumIncoming(Layer layer, int numInputs) {
        if (layer instanceof DenseLayer) {
            ((DenseLayer)layer).setNIn(numInputs);
        } else if (layer instanceof OutputLayer) {
            ((OutputLayer)layer).setNIn(numInputs);
        }
    }

    protected Model doConfigureModel(int numInput, int numOutput) {
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        builder.setOptimizationAlgo(this.getOptimizationAlgorithm());
        builder.setSeed(this.m_Seed);
        if (this.m_UseRegularization) {
            builder.setUseRegularization(true);
            builder.setL1(this.m_L1);
            builder.setL2(this.m_L2);
        }
        switch (this.m_DropType) {
            case NONE: {
                builder.setDropOut(0.0);
                break;
            }
            case DROP_OUT: {
                builder.setDropOut(this.m_DropOut);
                break;
            }
            case DROP_CONNECT: {
                builder.setUseDropConnect(true);
                builder.setDropOut(this.m_DropOut);
                break;
            }
            default: {
                throw new IllegalStateException("Unhandled drop type: " + (Object)((Object)this.m_DropType));
            }
        }
        NeuralNetConfiguration.ListBuilder listbuilder = builder.list(this.getLayers());
        for (int i = 0; i < this.m_Layers.length; ++i) {
            if (i == 0) {
                this.setNumIncoming(this.m_Layers[i], numInput);
            } else {
                this.setNumIncoming(this.m_Layers[i], this.getNumUnits(this.m_Layers[i - 1]));
            }
            if (i == this.m_Layers.length - 1) {
                ((OutputLayer)this.m_Layers[i]).setNOut(numOutput);
            }
            listbuilder = listbuilder.layer(i, this.m_Layers[i]);
        }
        listbuilder = listbuilder.pretrain(false).backprop(true);
        return new MultiLayerNetwork(listbuilder.build());
    }

    public static enum DropType {
        NONE,
        DROP_CONNECT,
        DROP_OUT;

    }
}

