/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.dl4j.model;

import adams.core.Randomizable;
import adams.ml.dl4j.model.AbstractModelConfigurator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class SimpleRegressionMultiLayerNetwork
extends AbstractModelConfigurator
implements Randomizable {
    private static final long serialVersionUID = -4915929902612899539L;
    protected int m_NumIterations;
    protected double m_LearningRate;
    protected long m_Seed;
    protected int m_HiddenNodes;
    protected Activation m_Activation;
    protected Activation m_OutputActivation;
    protected WeightInit m_WeightInit;
    protected OptimizationAlgorithm m_OptimizationAlgorithm;
    protected Updater m_Updater;
    protected LossFunctions.LossFunction m_LossFunction;

    public String globalInfo() {
        return "A simple multilayer network, adapted from this regression example:\nhttps://github.com/deeplearning4j/dl4j-examples/blob/bde80477139bbf74bea729f66e6dcd59944933ee/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/regression/SingleTimestepRegressionExample.java";
    }

    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("num-iterations", "numIterations", (Object)1, (Number)1, null);
        this.m_OptionManager.add("learning-rate", "learningRate", (Object)0.0015, (Number)0.0, null);
        this.m_OptionManager.add("seed", "seed", (Object)140L);
        this.m_OptionManager.add("hidden-nodes", "hiddenNodes", (Object)10, (Number)1, null);
        this.m_OptionManager.add("activation", "activation", (Object)Activation.TANH);
        this.m_OptionManager.add("output-activation", "outputActivation", (Object)Activation.IDENTITY);
        this.m_OptionManager.add("weight-init", "weightInit", (Object)WeightInit.XAVIER);
        this.m_OptionManager.add("optimization-algorithm", "optimizationAlgorithm", (Object)OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        this.m_OptionManager.add("updater", "updater", (Object)Updater.NESTEROVS);
        this.m_OptionManager.add("loss-function", "lossFunction", (Object)LossFunctions.LossFunction.MSE);
    }

    public void setNumIterations(int value) {
        if (this.getOptionManager().isValid("numIterations", (Number)value)) {
            this.m_NumIterations = value;
            this.reset();
        }
    }

    public int getNumIterations() {
        return this.m_NumIterations;
    }

    public String numIterationsTipText() {
        return "The number of iterations to perform.";
    }

    public void setLearningRate(double value) {
        if (this.getOptionManager().isValid("learningRate", (Number)value)) {
            this.m_LearningRate = value;
            this.reset();
        }
    }

    public double getLearningRate() {
        return this.m_LearningRate;
    }

    public String learningRateTipText() {
        return "The learning rate to use.";
    }

    public void setSeed(long value) {
        if (this.getOptionManager().isValid("seed", (Number)value)) {
            this.m_Seed = value;
            this.reset();
        }
    }

    public long getSeed() {
        return this.m_Seed;
    }

    public String seedTipText() {
        return "The seed value for the weight initialization.";
    }

    public void setHiddenNodes(int value) {
        this.m_HiddenNodes = value;
        this.reset();
    }

    public int getHiddenNodes() {
        return this.m_HiddenNodes;
    }

    public String hiddenNodesTipText() {
        return "The number of hidden nodes.";
    }

    public void setActivation(Activation value) {
        this.m_Activation = value;
        this.reset();
    }

    public Activation getActivation() {
        return this.m_Activation;
    }

    public String activationTipText() {
        return "The activation to use.";
    }

    public void setOutputActivation(Activation value) {
        this.m_OutputActivation = value;
        this.reset();
    }

    public Activation getOutputActivation() {
        return this.m_OutputActivation;
    }

    public String outputActivationTipText() {
        return "The activation to use for the output layer.";
    }

    public void setWeightInit(WeightInit value) {
        this.m_WeightInit = value;
        this.reset();
    }

    public WeightInit getWeightInit() {
        return this.m_WeightInit;
    }

    public String weightInitTipText() {
        return "The weight init to use.";
    }

    public void setUpdater(Updater value) {
        this.m_Updater = value;
        this.reset();
    }

    public Updater getUpdater() {
        return this.m_Updater;
    }

    public String updaterTipText() {
        return "The updater to use.";
    }

    public void setLossFunction(LossFunctions.LossFunction value) {
        this.m_LossFunction = value;
        this.reset();
    }

    public LossFunctions.LossFunction getLossFunction() {
        return this.m_LossFunction;
    }

    public String lossFunctionTipText() {
        return "The loss function to use.";
    }

    public void setOptimizationAlgorithm(OptimizationAlgorithm value) {
        this.m_OptimizationAlgorithm = value;
        this.reset();
    }

    public OptimizationAlgorithm getOptimizationAlgorithm() {
        return this.m_OptimizationAlgorithm;
    }

    public String optimizationAlgorithmTipText() {
        return "The optimization algorithm to use.";
    }

    @Override
    protected Model doConfigureModel(int numInput, int numOutput) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(this.m_Seed).optimizationAlgo(this.m_OptimizationAlgorithm).iterations(this.m_NumIterations).list().layer(0, (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(numInput)).nOut(this.m_HiddenNodes)).biasLearningRate(0.01)).activation(this.m_Activation)).learningRate(this.m_LearningRate)).weightInit(this.m_WeightInit)).updater(this.m_Updater)).build()).layer(1, (Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder(this.m_LossFunction).nIn(this.m_HiddenNodes)).nOut(1)).biasLearningRate(0.01)).activation(this.m_OutputActivation)).learningRate(this.m_LearningRate)).weightInit(this.m_WeightInit)).updater(this.m_Updater)).build()).build();
        return new MultiLayerNetwork(conf);
    }
}

