package adams.ml.dl4j.model;

import adams.core.Randomizable;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/* loaded from: input_file:adams/ml/dl4j/model/AnimalsLeNet.class */
/**
 * Model configurator that builds a LeNet-style convolutional network, adapted from the
 * Deeplearning4j "AnimalsClassification" example (the URL is reported by {@link #globalInfo()}).
 * <p>
 * The layer stack created by {@link #doConfigureModel(int, int)} is:
 * <ol>
 *   <li>{@code cnn1}: convolution, 5x5 kernel, stride 1x1, padding 0x0, 50 filters</li>
 *   <li>{@code maxpool1}: subsampling, 2x2 kernel, stride 2x2</li>
 *   <li>{@code cnn2}: convolution, 5x5 kernel, stride 5x5, padding 1x1, 100 filters</li>
 *   <li>{@code maxpool2}: subsampling, 2x2 kernel, stride 2x2</li>
 *   <li>dense layer with 500 outputs</li>
 *   <li>output layer using the configured output activation and loss function</li>
 * </ol>
 * The input type is fixed to 100x100 images with 3 channels. The global hyper-parameters
 * and the output layer's activation/loss function are exposed as options.
 */
public class AnimalsLeNet extends AbstractModelConfigurator implements Randomizable {
    private static final long serialVersionUID = -4915929902612899539L;

    /** Image height (in pixels) the network's input type is configured for. */
    private static final int INPUT_HEIGHT = 100;

    /** Image width (in pixels) the network's input type is configured for. */
    private static final int INPUT_WIDTH = 100;

    /** Number of image channels; used both as nIn of the first layer and in the input type. */
    private static final int INPUT_CHANNELS = 3;

    /** The number of iterations passed to the configuration builder. */
    protected int m_NumIterations;

    /** The learning rate passed to the configuration builder. */
    protected double m_LearningRate;

    /** The seed passed to the configuration builder. */
    protected long m_Seed;

    /** Whether regularization is enabled on the configuration builder. */
    protected boolean m_UseRegularization;

    /** The L2 value passed to the configuration builder. */
    protected double m_L2;

    /** The activation set on the global configuration builder. */
    protected Activation m_Activation;

    /** The weight initialization set on the global configuration builder. */
    protected WeightInit m_WeightInit;

    /** The activation of the output layer. */
    protected Activation m_OutputActivation;

    /** The loss function of the output layer. */
    protected LossFunctions.LossFunction m_OutputLossFunction;

    /**
     * Returns a short description of this configurator.
     *
     * @return the description, including the URL of the example the network is based on
     */
    public String globalInfo() {
        return "Convolution lenet network for classifying animals, based on this example:\nhttps://github.com/deeplearning4j/dl4j-examples/blob/ceff9965f8ae0cd8cb5cf32f8f49894a1be511fe/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/AnimalsClassification.java";
    }

    /**
     * Registers the options of this configurator with the option manager.
     * <p>
     * Numeric options are registered with a default value followed by what is presumably
     * a lower bound and an open ({@code null}) upper bound.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("num-iterations", "numIterations", 1, 1, (Number) null);
        this.m_OptionManager.add("learning-rate", "learningRate", Double.valueOf(1.0E-4d), Double.valueOf(0.0d), (Number) null);
        this.m_OptionManager.add("seed", "seed", 6L);
        this.m_OptionManager.add("use-regularization", "useRegularization", false);
        this.m_OptionManager.add("l2", "l2", Double.valueOf(0.005d), Double.valueOf(0.0d), (Number) null);
        this.m_OptionManager.add("activation", "activation", Activation.RELU);
        this.m_OptionManager.add("weight-init", "weightInit", WeightInit.XAVIER);
        this.m_OptionManager.add("output-activation", "outputActivation", Activation.SOFTMAX);
        this.m_OptionManager.add("output-loss-function", "outputLossFunction", LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD);
    }

    /**
     * Sets the number of iterations; values rejected by the option manager are ignored.
     *
     * @param value the number of iterations
     */
    public void setNumIterations(int value) {
        if (getOptionManager().isValid("numIterations", Integer.valueOf(value))) {
            this.m_NumIterations = value;
            reset();
        }
    }

    /**
     * Returns the number of iterations.
     *
     * @return the number of iterations
     */
    public int getNumIterations() {
        return this.m_NumIterations;
    }

    /**
     * Returns the tip text for the numIterations property.
     *
     * @return the tip text
     */
    public String numIterationsTipText() {
        return "The number of iterations to perform.";
    }

    /**
     * Sets the learning rate; values rejected by the option manager are ignored.
     *
     * @param value the learning rate
     */
    public void setLearningRate(double value) {
        if (getOptionManager().isValid("learningRate", Double.valueOf(value))) {
            this.m_LearningRate = value;
            reset();
        }
    }

    /**
     * Returns the learning rate.
     *
     * @return the learning rate
     */
    public double getLearningRate() {
        return this.m_LearningRate;
    }

    /**
     * Returns the tip text for the learningRate property.
     *
     * @return the tip text
     */
    public String learningRateTipText() {
        return "The learning rate to use.";
    }

    /**
     * Sets the seed; values rejected by the option manager are ignored.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        if (getOptionManager().isValid("seed", Long.valueOf(value))) {
            this.m_Seed = value;
            reset();
        }
    }

    /**
     * Returns the seed.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for the seed property.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "The seed value for the weight initialization.";
    }

    /**
     * Sets whether regularization is enabled.
     *
     * @param value true to enable regularization
     */
    public void setUseRegularization(boolean value) {
        this.m_UseRegularization = value;
        reset();
    }

    /**
     * Returns whether regularization is enabled.
     *
     * @return true if enabled
     */
    public boolean getUseRegularization() {
        return this.m_UseRegularization;
    }

    /**
     * Returns the tip text for the useRegularization property.
     *
     * @return the tip text
     */
    public String useRegularizationTipText() {
        return "If enabled, regularization is used.";
    }

    /**
     * Sets the L2 value; values rejected by the option manager are ignored.
     *
     * @param value the L2 value
     */
    public void setL2(double value) {
        if (getOptionManager().isValid("l2", Double.valueOf(value))) {
            this.m_L2 = value;
            reset();
        }
    }

    /**
     * Returns the L2 value.
     *
     * @return the L2 value
     */
    public double getL2() {
        return this.m_L2;
    }

    /**
     * Returns the tip text for the l2 property.
     *
     * @return the tip text
     */
    public String l2TipText() {
        return "The L2 value.";
    }

    /**
     * Sets the activation used on the global configuration builder.
     *
     * @param activation the activation
     */
    public void setActivation(Activation activation) {
        this.m_Activation = activation;
        reset();
    }

    /**
     * Returns the activation used on the global configuration builder.
     *
     * @return the activation
     */
    public Activation getActivation() {
        return this.m_Activation;
    }

    /**
     * Returns the tip text for the activation property.
     *
     * @return the tip text
     */
    public String activationTipText() {
        return "The activation to use.";
    }

    /**
     * Sets the weight initialization.
     *
     * @param weightInit the weight initialization
     */
    public void setWeightInit(WeightInit weightInit) {
        this.m_WeightInit = weightInit;
        reset();
    }

    /**
     * Returns the weight initialization.
     *
     * @return the weight initialization
     */
    public WeightInit getWeightInit() {
        return this.m_WeightInit;
    }

    /**
     * Returns the tip text for the weightInit property.
     *
     * @return the tip text
     */
    public String weightInitTipText() {
        return "The weight init to use.";
    }

    /**
     * Sets the activation of the output layer.
     *
     * @param activation the activation
     */
    public void setOutputActivation(Activation activation) {
        this.m_OutputActivation = activation;
        reset();
    }

    /**
     * Returns the activation of the output layer.
     *
     * @return the activation
     */
    public Activation getOutputActivation() {
        return this.m_OutputActivation;
    }

    /**
     * Returns the tip text for the outputActivation property.
     *
     * @return the tip text
     */
    public String outputActivationTipText() {
        return "The activation to use for the output layer.";
    }

    /**
     * Sets the loss function of the output layer.
     *
     * @param lossFunction the loss function
     */
    public void setOutputLossFunction(LossFunctions.LossFunction lossFunction) {
        this.m_OutputLossFunction = lossFunction;
        reset();
    }

    /**
     * Returns the loss function of the output layer.
     *
     * @return the loss function
     */
    public LossFunctions.LossFunction getOutputLossFunction() {
        return this.m_OutputLossFunction;
    }

    /**
     * Returns the tip text for the outputLossFunction property.
     *
     * @return the tip text
     */
    public String outputLossFunctionTipText() {
        return "The loss function to use for the output layer.";
    }

    /**
     * Creates a convolution layer with an explicit number of inputs.
     *
     * @param name   the layer name
     * @param in     the number of inputs (nIn)
     * @param out    the number of outputs (nOut, i.e., filters)
     * @param kernel the kernel size
     * @param stride the stride
     * @param pad    the padding
     * @param bias   the initial bias value
     * @return the layer
     */
    protected ConvolutionLayer convInit(String name, int in, int out, int[] kernel, int[] stride, int[] pad, double bias) {
        return new ConvolutionLayer.Builder(kernel, stride, pad)
            .name(name)
            .nIn(in)
            .nOut(out)
            .biasInit(bias)
            .build();
    }

    /**
     * Creates a convolution layer with a fixed 5x5 kernel; nIn is left to the builder/input type.
     *
     * @param name   the layer name
     * @param out    the number of outputs (nOut, i.e., filters)
     * @param stride the stride
     * @param pad    the padding
     * @param bias   the initial bias value
     * @return the layer
     */
    protected ConvolutionLayer conv5x5(String name, int out, int[] stride, int[] pad, double bias) {
        return new ConvolutionLayer.Builder(new int[]{5, 5}, stride, pad)
            .name(name)
            .nOut(out)
            .biasInit(bias)
            .build();
    }

    /**
     * Creates a subsampling layer with the given kernel and a fixed 2x2 stride.
     * No pooling type is set explicitly, so the builder's default applies.
     *
     * @param name   the layer name
     * @param kernel the kernel size
     * @return the layer
     */
    protected SubsamplingLayer maxPool(String name, int[] kernel) {
        return new SubsamplingLayer.Builder(kernel, new int[]{2, 2})
            .name(name)
            .build();
    }

    /**
     * Builds the LeNet-style network described in the class documentation.
     *
     * @param numInput  not used by this configurator; the input type is fixed to
     *                  100x100x3. NOTE(review): presumably the number of input
     *                  attributes - confirm with the caller
     * @param numOutput the number of outputs of the final layer (presumably the number of class labels)
     * @return a new MultiLayerNetwork wrapping the configuration
     */
    @Override // adams.ml.dl4j.model.AbstractModelConfigurator
    protected Model doConfigureModel(int numInput, int numOutput) {
        return new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
            .seed(this.m_Seed)
            .iterations(this.m_NumIterations)
            .regularization(this.m_UseRegularization)
            .l2(this.m_L2)
            .activation(this.m_Activation)
            .learningRate(this.m_LearningRate)
            .weightInit(this.m_WeightInit)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.RMSPROP)
            .list()
            .layer(0, convInit("cnn1", INPUT_CHANNELS, 50, new int[]{5, 5}, new int[]{1, 1}, new int[]{0, 0}, 0.0d))
            .layer(1, maxPool("maxpool1", new int[]{2, 2}))
            // NOTE(review): conv5x5 takes (stride, pad), so cnn2 runs with stride 5x5 and
            // padding 1x1 (unlike cnn1's stride 1x1) - kept unchanged; confirm this is intended
            .layer(2, conv5x5("cnn2", 100, new int[]{5, 5}, new int[]{1, 1}, 0.0d))
            // name was previously misspelled "maxool2"
            .layer(3, maxPool("maxpool2", new int[]{2, 2}))
            .layer(4, new DenseLayer.Builder().nOut(500).build())
            .layer(5, new OutputLayer.Builder(this.m_OutputLossFunction)
                .nOut(numOutput)
                .activation(this.m_OutputActivation)
                .build())
            .backprop(true)
            .pretrain(false)
            .setInputType(InputType.convolutional(INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNELS))
            .build());
    }
}
