package adams.ml.dl4j.modelgenerator;

import adams.core.Randomizable;
import adams.core.base.BaseBoolean;
import adams.core.base.BaseDouble;
import adams.core.base.BaseInteger;
import adams.core.option.OptionUtils;
import adams.ml.dl4j.model.Dl4jMlpClassifier;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import weka.classifiers.functions.DL4JMultiLayerNetwork;
import weka.dl4j.layers.BatchNormalization;
import weka.dl4j.layers.DenseLayer;
import weka.dl4j.layers.OutputLayer;

/* loaded from: input_file:adams/ml/dl4j/modelgenerator/RandomDl4jMlpClassifierNetworks.class */
/**
 * Generates random {@link Dl4jMlpClassifier} networks.
 * <p>
 * Every network starts as a shallow copy of the default network; parameters
 * for which a non-empty list of candidates has been supplied get replaced by
 * a randomly chosen candidate. Parameters with an empty list keep whatever
 * the default network/layers define. Networks with identical layer-wise
 * configuration (compared via their YAML representation) are returned only
 * once.
 */
public class RandomDl4jMlpClassifierNetworks extends AbstractModelGenerator implements Randomizable {
    private static final long serialVersionUID = -3934375675313822523L;

    /** the template for the networks (its layers get replaced). */
    protected Dl4jMlpClassifier m_DefaultNetwork;

    /** the template for all layers but the last one. */
    protected DenseLayer m_DefaultDenseLayer;

    /** the template for the last layer. */
    protected OutputLayer m_DefaultOutputLayer;

    /** the seed for the random number generator. */
    protected long m_Seed;

    /** the number of networks to generate (before duplicates get removed). */
    protected int m_NumNetworks;

    /** the layer counts to choose from. */
    protected BaseInteger[] m_NumLayers;

    /** the node counts (per dense layer) to choose from. */
    protected BaseInteger[] m_NumNodes;

    /** the learning rates to choose from. */
    protected BaseDouble[] m_LearningRate;

    /** the regularization flags to choose from. */
    protected BaseBoolean[] m_UseRegularization;

    /** the L1 values to choose from. */
    protected BaseDouble[] m_L1;

    /** the L2 values to choose from. */
    protected BaseDouble[] m_L2;

    /** the drop types to choose from. */
    protected Dl4jMlpClassifier.DropType[] m_DropType;

    /** the drop out values to choose from. */
    protected BaseDouble[] m_DropOut;

    /** the epochs at which the learning rate schedule changes the rate. */
    protected BaseInteger[] m_LearningRateScheduleEpochs;

    /** the divisors to choose from when computing the learning rate schedule. */
    protected BaseDouble[] m_LearningRateScheduleDivisors;

    /** whether to insert a batch normalization layer after each non-output layer. */
    protected boolean m_InsertBatchNormLayers;

    /**
     * Returns a string describing the object.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Generates random " + Dl4jMlpClassifier.class.getName() + " networks.\n"
            + "Randomly selects items from the supplied parameter lists if more than one item supplied.";
    }

    /**
     * Adds the options of this generator to the option manager, after the
     * inherited ones.
     */
    public void defineOptions() {
        super.defineOptions();
        m_OptionManager.add("default-network", "defaultNetwork", new Dl4jMlpClassifier());
        m_OptionManager.add("default-dense-layer", "defaultDenseLayer", new DenseLayer());
        m_OptionManager.add("default-output-layer", "defaultOutputLayer", new OutputLayer());
        m_OptionManager.add("seed", "seed", 1L);
        m_OptionManager.add("num-networks", "numNetworks", 1, 1, (Number) null);
        m_OptionManager.add("num-layers", "numLayers", new BaseInteger[0]);
        m_OptionManager.add("num-nodes", "numNodes", new BaseInteger[0]);
        m_OptionManager.add("learning-rate", "learningRate", new BaseDouble[0]);
        m_OptionManager.add("use-regularization", "useRegularization", new BaseBoolean[0]);
        m_OptionManager.add("l1", "L1", new BaseDouble[0]);
        m_OptionManager.add("l2", "L2", new BaseDouble[0]);
        m_OptionManager.add(DL4JMultiLayerNetwork.DROP_TYPE, "dropType", new Dl4jMlpClassifier.DropType[0]);
        m_OptionManager.add(DL4JMultiLayerNetwork.DROP_OUT, "dropOut", new BaseDouble[0]);
        m_OptionManager.add("learning-rate-schedule-epochs", "learningRateScheduleEpochs", new BaseInteger[0]);
        m_OptionManager.add("learning-rate-schedule-divisors", "learningRateScheduleDivisors", new BaseDouble[0]);
        m_OptionManager.add("insert-batch-norm-layers", "insertBatchNormLayers", false);
    }

    /**
     * Sets the default network and resets the generator.
     *
     * @param dl4jMlpClassifier the network template
     */
    public void setDefaultNetwork(Dl4jMlpClassifier dl4jMlpClassifier) {
        m_DefaultNetwork = dl4jMlpClassifier;
        reset();
    }

    /**
     * Returns the default network.
     *
     * @return the network template
     */
    public Dl4jMlpClassifier getDefaultNetwork() {
        return m_DefaultNetwork;
    }

    /**
     * Returns a short description of the default network property.
     *
     * @return the description
     */
    public String defaultNetworkTipText() {
        return "The default network setup to use (minus layers).";
    }

    /**
     * Sets the default dense layer and resets the generator.
     *
     * @param denseLayer the dense layer template
     */
    public void setDefaultDenseLayer(DenseLayer denseLayer) {
        m_DefaultDenseLayer = denseLayer;
        reset();
    }

    /**
     * Returns the default dense layer.
     *
     * @return the dense layer template
     */
    public DenseLayer getDefaultDenseLayer() {
        return m_DefaultDenseLayer;
    }

    /**
     * Returns a short description of the default dense layer property.
     *
     * @return the description
     */
    public String defaultDenseLayerTipText() {
        return "The default dense layer setup to use (minus layers).";
    }

    /**
     * Sets the default output layer and resets the generator.
     *
     * @param outputLayer the output layer template
     */
    public void setDefaultOutputLayer(OutputLayer outputLayer) {
        m_DefaultOutputLayer = outputLayer;
        reset();
    }

    /**
     * Returns the default output layer.
     *
     * @return the output layer template
     */
    public OutputLayer getDefaultOutputLayer() {
        return m_DefaultOutputLayer;
    }

    /**
     * Returns a short description of the default output layer property.
     *
     * @return the description
     */
    public String defaultOutputLayerTipText() {
        return "The default output layer setup to use.";
    }

    /**
     * Sets the seed value and resets the generator.
     *
     * @param j the seed
     */
    public void setSeed(long j) {
        m_Seed = j;
        reset();
    }

    /**
     * Returns the seed value.
     *
     * @return the seed
     */
    public long getSeed() {
        return m_Seed;
    }

    /**
     * Returns a short description of the seed property.
     *
     * @return the description
     */
    public String seedTipText() {
        return "The seed value to use for initializing the random number generator";
    }

    /**
     * Sets the number of networks to generate. The value is only accepted
     * (and the generator reset) if the option manager considers it valid.
     *
     * @param i the number of networks
     */
    public void setNumNetworks(int i) {
        if (getOptionManager().isValid("numNetworks", Integer.valueOf(i))) {
            m_NumNetworks = i;
            reset();
        }
    }

    /**
     * Returns the number of networks to generate.
     *
     * @return the number of networks
     */
    public int getNumNetworks() {
        return m_NumNetworks;
    }

    /**
     * Returns a short description of the number of networks property.
     *
     * @return the description
     */
    public String numNetworksTipText() {
        return "The maximum number of networks to generate (duplicate networks will still get removed afterwards).";
    }

    /**
     * Sets the layer counts to choose from and resets the generator.
     *
     * @param baseIntegerArr the layer counts
     */
    public void setNumLayers(BaseInteger[] baseIntegerArr) {
        m_NumLayers = baseIntegerArr;
        reset();
    }

    /**
     * Returns the layer counts to choose from.
     *
     * @return the layer counts
     */
    public BaseInteger[] getNumLayers() {
        return m_NumLayers;
    }

    /**
     * Returns a short description of the number of layers property.
     *
     * @return the description
     */
    public String numLayersTipText() {
        return "The list of layers to choose from.";
    }

    /**
     * Sets the node counts to choose from and resets the generator.
     *
     * @param baseIntegerArr the node counts
     */
    public void setNumNodes(BaseInteger[] baseIntegerArr) {
        m_NumNodes = baseIntegerArr;
        reset();
    }

    /**
     * Returns the node counts to choose from.
     *
     * @return the node counts
     */
    public BaseInteger[] getNumNodes() {
        return m_NumNodes;
    }

    /**
     * Returns a short description of the number of nodes property.
     *
     * @return the description
     */
    public String numNodesTipText() {
        return "The list of node counts to choose from.";
    }

    /**
     * Sets the learning rates to choose from and resets the generator.
     *
     * @param baseDoubleArr the learning rates
     */
    public void setLearningRate(BaseDouble[] baseDoubleArr) {
        m_LearningRate = baseDoubleArr;
        reset();
    }

    /**
     * Returns the learning rates to choose from.
     *
     * @return the learning rates
     */
    public BaseDouble[] getLearningRate() {
        return m_LearningRate;
    }

    /**
     * Returns a short description of the learning rate property.
     *
     * @return the description
     */
    public String learningRateTipText() {
        return "The list of learning rate values to choose from.";
    }

    /**
     * Sets the regularization flags to choose from and resets the generator.
     *
     * @param baseBooleanArr the flags
     */
    public void setUseRegularization(BaseBoolean[] baseBooleanArr) {
        m_UseRegularization = baseBooleanArr;
        reset();
    }

    /**
     * Returns the regularization flags to choose from.
     *
     * @return the flags
     */
    public BaseBoolean[] getUseRegularization() {
        return m_UseRegularization;
    }

    /**
     * Returns a short description of the regularization property.
     *
     * @return the description
     */
    public String useRegularizationTipText() {
        return "The list of regularization flags to choose from.";
    }

    /**
     * Sets the L1 values to choose from and resets the generator.
     *
     * @param baseDoubleArr the L1 values
     */
    public void setL1(BaseDouble[] baseDoubleArr) {
        m_L1 = baseDoubleArr;
        reset();
    }

    /**
     * Returns the L1 values to choose from.
     *
     * @return the L1 values
     */
    public BaseDouble[] getL1() {
        return m_L1;
    }

    /**
     * Returns a short description of the L1 property.
     *
     * @return the description
     */
    public String L1TipText() {
        return "The list of L1 values to choose from.";
    }

    /**
     * Sets the L2 values to choose from and resets the generator.
     *
     * @param baseDoubleArr the L2 values
     */
    public void setL2(BaseDouble[] baseDoubleArr) {
        m_L2 = baseDoubleArr;
        reset();
    }

    /**
     * Returns the L2 values to choose from.
     *
     * @return the L2 values
     */
    public BaseDouble[] getL2() {
        return m_L2;
    }

    /**
     * Returns a short description of the L2 property.
     *
     * @return the description
     */
    public String L2TipText() {
        return "The list of L2 values to choose from.";
    }

    /**
     * Sets the drop types to choose from and resets the generator.
     *
     * @param dropTypeArr the drop types
     */
    public void setDropType(Dl4jMlpClassifier.DropType[] dropTypeArr) {
        m_DropType = dropTypeArr;
        reset();
    }

    /**
     * Returns the drop types to choose from.
     *
     * @return the drop types
     */
    public Dl4jMlpClassifier.DropType[] getDropType() {
        return m_DropType;
    }

    /**
     * Returns a short description of the drop type property.
     *
     * @return the description
     */
    public String dropTypeTipText() {
        return "The list of drop type values to choose from.";
    }

    /**
     * Sets the drop out values to choose from and resets the generator.
     *
     * @param baseDoubleArr the drop out values
     */
    public void setDropOut(BaseDouble[] baseDoubleArr) {
        m_DropOut = baseDoubleArr;
        reset();
    }

    /**
     * Returns the drop out values to choose from.
     *
     * @return the drop out values
     */
    public BaseDouble[] getDropOut() {
        return m_DropOut;
    }

    /**
     * Returns a short description of the drop out property.
     *
     * @return the description
     */
    public String dropOutTipText() {
        return "The list of drop out values to choose from.";
    }

    /**
     * Sets the epochs of the learning rate schedule and resets the generator.
     *
     * @param baseIntegerArr the epochs
     */
    public void setLearningRateScheduleEpochs(BaseInteger[] baseIntegerArr) {
        m_LearningRateScheduleEpochs = baseIntegerArr;
        reset();
    }

    /**
     * Returns the epochs of the learning rate schedule.
     *
     * @return the epochs
     */
    public BaseInteger[] getLearningRateScheduleEpochs() {
        return m_LearningRateScheduleEpochs;
    }

    /**
     * Returns a short description of the learning rate schedule epochs property.
     *
     * @return the description
     */
    public String learningRateScheduleEpochsTipText() {
        return "The list of epochs for the learning rate schedule; not used if empty.";
    }

    /**
     * Sets the divisors to choose from for the learning rate schedule and
     * resets the generator.
     *
     * @param baseDoubleArr the divisors
     */
    public void setLearningRateScheduleDivisors(BaseDouble[] baseDoubleArr) {
        m_LearningRateScheduleDivisors = baseDoubleArr;
        reset();
    }

    /**
     * Returns the divisors to choose from for the learning rate schedule.
     *
     * @return the divisors
     */
    public BaseDouble[] getLearningRateScheduleDivisors() {
        return m_LearningRateScheduleDivisors;
    }

    /**
     * Returns a short description of the learning rate schedule divisors property.
     *
     * @return the description
     */
    public String learningRateScheduleDivisorsTipText() {
        return "The list of divisors for calculating learning rate used in the learning rate schedule to choose from; each subsequent value is calculated by dividing the previous one by the chosen divisor.";
    }

    /**
     * Sets whether batch normalization layers get inserted and resets the
     * generator.
     *
     * @param z true if to insert the layers
     */
    public void setInsertBatchNormLayers(boolean z) {
        m_InsertBatchNormLayers = z;
        reset();
    }

    /**
     * Returns whether batch normalization layers get inserted.
     *
     * @return true if the layers get inserted
     */
    public boolean getInsertBatchNormLayers() {
        return m_InsertBatchNormLayers;
    }

    /**
     * Returns a short description of the batch normalization property.
     *
     * @return the description
     */
    public String insertBatchNormLayersTipText() {
        return "If enabled, batchnorm layers get inserted after each layer (except output layer).";
    }

    /**
     * Checks the setup: on top of the inherited checks, the lists of layer
     * and node counts must not be empty, and a learning rate schedule
     * requires at least one divisor.
     *
     * @return null if the setup is OK, otherwise the error message
     */
    protected String check() {
        String check = super.check();
        if (check != null) {
            return check;
        }
        if (m_NumLayers.length == 0) {
            return "List of layers is empty!";
        }
        if (m_NumNodes.length == 0) {
            return "List of nodes is empty!";
        }
        // without divisors the schedule cannot be computed
        if (m_LearningRateScheduleEpochs.length > 0 && m_LearningRateScheduleDivisors.length == 0) {
            return "Learning rate schedule epochs supplied, but list of divisors is empty!";
        }
        return null;
    }

    /**
     * Picks an element from the array. An array with a single element does
     * not consume a random number.
     *
     * @param random the random number generator to use
     * @param obj the array to pick from
     * @return the chosen element
     * @throws IllegalArgumentException if the array is empty (or the object is not an array)
     */
    protected Object pick(Random random, Object obj) {
        int length = Array.getLength(obj);
        if (length == 0) {
            throw new IllegalArgumentException("Cannot pick an item from an empty array!");
        }
        if (length == 1) {
            return Array.get(obj, 0);
        }
        return Array.get(obj, random.nextInt(length));
    }

    /**
     * Generates the configured number of random networks and removes
     * duplicates afterwards. The random number generator is initialized with
     * the current seed on each call.
     *
     * @param i passed on as first argument to {@code configureModel} of each network
     * @param i2 passed on as second argument to {@code configureModel} of each network
     * @return the generated models, without duplicates
     */
    protected List<Model> doGenerate(int i, int i2) {
        Random random = new Random(m_Seed);
        List<Model> generated = new ArrayList<>();
        // always m_NumNetworks attempts; duplicates reduce the final count
        for (int n = 0; n < m_NumNetworks; n++) {
            Dl4jMlpClassifier network = createNetwork(random);
            generated.add(network.configureModel(i, i2));
        }
        return removeDuplicates(generated);
    }

    /**
     * Creates a copy of the default network with randomly chosen network-level
     * parameters and randomly configured layers.
     */
    private Dl4jMlpClassifier createNetwork(Random random) {
        Dl4jMlpClassifier network = (Dl4jMlpClassifier) OptionUtils.shallowCopy(m_DefaultNetwork);
        if (m_UseRegularization.length > 0) {
            network.setUseRegularization(((BaseBoolean) pick(random, m_UseRegularization)).booleanValue());
        }
        if (network.getUseRegularization()) {
            if (m_L1.length > 0) {
                network.setL1(((BaseDouble) pick(random, m_L1)).doubleValue());
            }
            if (m_L2.length > 0) {
                network.setL2(((BaseDouble) pick(random, m_L2)).doubleValue());
            }
        }
        if (m_DropType.length > 0) {
            network.setDropType((Dl4jMlpClassifier.DropType) pick(random, m_DropType));
        }
        if (m_DropOut.length > 0) {
            network.setDropOut(((BaseDouble) pick(random, m_DropOut)).doubleValue());
        }

        int numLayers = ((BaseInteger) pick(random, m_NumLayers)).intValue();
        if (isLoggingEnabled()) {
            getLogger().info("# layers: " + numLayers);
        }
        List<Layer> layers = new ArrayList<>();
        for (int index = 0; index < numLayers; index++) {
            // the last layer is always the output layer
            Layer layer = createLayer(random, index, index == numLayers - 1, network.getUseRegularization());
            layers.add(layer);
            if (m_InsertBatchNormLayers && !(layer instanceof OutputLayer)) {
                layers.add(new BatchNormalization());
            }
        }
        network.setLayers(layers.toArray(new Layer[0]));
        return network;
    }

    /**
     * Creates a single layer from the matching default layer and applies the
     * randomly chosen per-layer parameters.
     */
    private Layer createLayer(Random random, int index, boolean output, boolean useRegularization) {
        Layer layer;
        if (output) {
            layer = (OutputLayer) OptionUtils.shallowCopy(m_DefaultOutputLayer);
            layer.setLayerName("output-" + index);
        } else {
            layer = (DenseLayer) OptionUtils.shallowCopy(m_DefaultDenseLayer);
            layer.setLayerName("dense-" + index);
            ((DenseLayer) layer).setNOut(((BaseInteger) pick(random, m_NumNodes)).intValue());
        }
        if (m_LearningRate.length > 0) {
            ((BaseLayer) layer).setLearningRate(((BaseDouble) pick(random, m_LearningRate)).doubleValue());
        }
        if (useRegularization) {
            if (m_L1.length > 0) {
                ((BaseLayer) layer).setL1(((BaseDouble) pick(random, m_L1)).doubleValue());
            }
            if (m_L2.length > 0) {
                ((BaseLayer) layer).setL2(((BaseDouble) pick(random, m_L2)).doubleValue());
            }
        }
        applyDropOut(random, layer);
        applyLearningRateSchedule(random, (BaseLayer) layer);
        return layer;
    }

    /**
     * Picks a drop type (if any supplied) and sets the drop out value of the
     * layer accordingly.
     */
    private void applyDropOut(Random random, Layer layer) {
        if (m_DropType.length == 0) {
            return;
        }
        Dl4jMlpClassifier.DropType dropType = (Dl4jMlpClassifier.DropType) pick(random, m_DropType);
        switch (dropType) {
            case NONE:
                layer.setDropOut(0.0d);
                break;
            case DROP_OUT:
            case DROP_CONNECT:
                // no candidates supplied: keep the value of the default layer
                if (m_DropOut.length > 0) {
                    layer.setDropOut(((BaseDouble) pick(random, m_DropOut)).doubleValue());
                }
                break;
            default:
                throw new IllegalStateException("Unhandled drop type: " + dropType);
        }
    }

    /**
     * Computes the learning rate schedule for the layer: starting from the
     * layer's learning rate, the rate for each epoch is the previous rate
     * divided by a randomly chosen divisor.
     */
    private void applyLearningRateSchedule(Random random, BaseLayer layer) {
        // check() rejects epochs without divisors; guard anyway to avoid picking from an empty list
        if (m_LearningRateScheduleEpochs.length == 0 || m_LearningRateScheduleDivisors.length == 0) {
            return;
        }
        double learningRate = layer.getLearningRate();
        HashMap<Integer, Double> schedule = new HashMap<>();
        for (BaseInteger epoch : m_LearningRateScheduleEpochs) {
            learningRate /= ((BaseDouble) pick(random, m_LearningRateScheduleDivisors)).doubleValue();
            schedule.put(Integer.valueOf(epoch.intValue()), Double.valueOf(learningRate));
        }
        layer.setLearningRateSchedule(schedule);
    }

    /**
     * Returns a new list with only the first occurrence of each model, using
     * the YAML of the layer-wise configuration for comparison.
     */
    private List<Model> removeDuplicates(List<Model> models) {
        HashSet<String> seen = new HashSet<>();
        List<Model> unique = new ArrayList<>(models.size());
        for (Model model : models) {
            String yaml = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toYaml();
            if (seen.add(yaml)) {
                unique.add(model);
            }
        }
        return unique;
    }
}
