/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.dl4j.modelgenerator;

import adams.core.Randomizable;
import adams.core.base.BaseBoolean;
import adams.core.base.BaseDouble;
import adams.core.base.BaseInteger;
import adams.core.option.OptionUtils;
import adams.ml.dl4j.model.Dl4jMlpClassifier;
import adams.ml.dl4j.modelgenerator.AbstractModelGenerator;
import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import weka.dl4j.layers.BatchNormalization;
import weka.dl4j.layers.DenseLayer;
import weka.dl4j.layers.OutputLayer;

/**
 * Model generator that produces randomly configured {@link Dl4jMlpClassifier} networks.
 * <p>
 * For every network, each hyper-parameter is picked at random (using the configured seed)
 * from the corresponding user-supplied list. An empty list leaves the value from the
 * default network/layer untouched. The exceptions are the number of layers and the number
 * of nodes, which must be supplied. The last layer of every network is the output layer.
 * Networks whose layer-wise configurations serialize to identical YAML are removed after
 * generation.
 */
public class RandomDl4jMlpClassifierNetworks
extends AbstractModelGenerator
implements Randomizable {
    private static final long serialVersionUID = -3934375675313822523L;

    /** the network template to copy (its layers get replaced). */
    protected Dl4jMlpClassifier m_DefaultNetwork;

    /** the template for the hidden (dense) layers. */
    protected DenseLayer m_DefaultDenseLayer;

    /** the template for the output layer. */
    protected OutputLayer m_DefaultOutputLayer;

    /** the seed for the random number generator. */
    protected long m_Seed;

    /** the maximum number of networks to generate (before duplicate removal). */
    protected int m_NumNetworks;

    /** the layer counts to choose from (the count includes the output layer). */
    protected BaseInteger[] m_NumLayers;

    /** the node counts for the dense layers to choose from. */
    protected BaseInteger[] m_NumNodes;

    /** the learning rates to choose from. */
    protected BaseDouble[] m_LearningRate;

    /** the regularization flags to choose from. */
    protected BaseBoolean[] m_UseRegularization;

    /** the L1 values to choose from. */
    protected BaseDouble[] m_L1;

    /** the L2 values to choose from. */
    protected BaseDouble[] m_L2;

    /** the drop types to choose from. */
    protected Dl4jMlpClassifier.DropType[] m_DropType;

    /** the drop-out values to choose from. */
    protected BaseDouble[] m_DropOut;

    /** the epochs of the learning rate schedule (schedule not used if empty). */
    protected BaseInteger[] m_LearningRateScheduleEpochs;

    /** the divisors for computing the scheduled learning rates to choose from. */
    protected BaseDouble[] m_LearningRateScheduleDivisors;

    /** whether to insert a batch normalization layer after each non-output layer. */
    protected boolean m_InsertBatchNormLayers;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Generates random " + Dl4jMlpClassifier.class.getName() + " networks.\n"
            + "Randomly selects items from the supplied parameter lists if more than one item supplied.";
    }

    /**
     * Adds the options for this generator.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();

        m_OptionManager.add("default-network", "defaultNetwork", new Dl4jMlpClassifier());
        m_OptionManager.add("default-dense-layer", "defaultDenseLayer", new DenseLayer());
        m_OptionManager.add("default-output-layer", "defaultOutputLayer", new OutputLayer());
        m_OptionManager.add("seed", "seed", 1L);
        m_OptionManager.add("num-networks", "numNetworks", 1, 1, null);
        m_OptionManager.add("num-layers", "numLayers", new BaseInteger[0]);
        m_OptionManager.add("num-nodes", "numNodes", new BaseInteger[0]);
        m_OptionManager.add("learning-rate", "learningRate", new BaseDouble[0]);
        m_OptionManager.add("use-regularization", "useRegularization", new BaseBoolean[0]);
        m_OptionManager.add("l1", "L1", new BaseDouble[0]);
        m_OptionManager.add("l2", "L2", new BaseDouble[0]);
        m_OptionManager.add("drop-type", "dropType", new Dl4jMlpClassifier.DropType[0]);
        m_OptionManager.add("drop-out", "dropOut", new BaseDouble[0]);
        m_OptionManager.add("learning-rate-schedule-epochs", "learningRateScheduleEpochs", new BaseInteger[0]);
        m_OptionManager.add("learning-rate-schedule-divisors", "learningRateScheduleDivisors", new BaseDouble[0]);
        m_OptionManager.add("insert-batch-norm-layers", "insertBatchNormLayers", false);
    }

    /**
     * Sets the default network setup.
     *
     * @param value the template network (its layers get replaced)
     */
    public void setDefaultNetwork(Dl4jMlpClassifier value) {
        m_DefaultNetwork = value;
        reset();
    }

    /**
     * Returns the default network setup.
     *
     * @return the template network
     */
    public Dl4jMlpClassifier getDefaultNetwork() {
        return m_DefaultNetwork;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String defaultNetworkTipText() {
        return "The default network setup to use (minus layers).";
    }

    /**
     * Sets the default dense layer setup.
     *
     * @param value the template for hidden layers
     */
    public void setDefaultDenseLayer(DenseLayer value) {
        m_DefaultDenseLayer = value;
        reset();
    }

    /**
     * Returns the default dense layer setup.
     *
     * @return the template for hidden layers
     */
    public DenseLayer getDefaultDenseLayer() {
        return m_DefaultDenseLayer;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String defaultDenseLayerTipText() {
        return "The default dense layer setup to use (the number of nodes gets picked from the list of node counts).";
    }

    /**
     * Sets the default output layer setup.
     *
     * @param value the template for the output layer
     */
    public void setDefaultOutputLayer(OutputLayer value) {
        m_DefaultOutputLayer = value;
        reset();
    }

    /**
     * Returns the default output layer setup.
     *
     * @return the template for the output layer
     */
    public OutputLayer getDefaultOutputLayer() {
        return m_DefaultOutputLayer;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String defaultOutputLayerTipText() {
        return "The default output layer setup to use.";
    }

    /**
     * Sets the seed value.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        m_Seed = value;
        reset();
    }

    /**
     * Returns the seed value.
     *
     * @return the seed
     */
    public long getSeed() {
        return m_Seed;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String seedTipText() {
        return "The seed value to use for initializing the random number generator";
    }

    /**
     * Sets the maximum number of networks to generate. Values rejected by the
     * option manager (below 1) are ignored.
     *
     * @param value the number of networks
     */
    public void setNumNetworks(int value) {
        if (getOptionManager().isValid("numNetworks", value)) {
            m_NumNetworks = value;
            reset();
        }
    }

    /**
     * Returns the maximum number of networks to generate.
     *
     * @return the number of networks
     */
    public int getNumNetworks() {
        return m_NumNetworks;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String numNetworksTipText() {
        return "The maximum number of networks to generate (duplicate networks will still get removed afterwards).";
    }

    /**
     * Sets the layer counts to choose from.
     *
     * @param value the layer counts (including the output layer)
     */
    public void setNumLayers(BaseInteger[] value) {
        m_NumLayers = value;
        reset();
    }

    /**
     * Returns the layer counts to choose from.
     *
     * @return the layer counts (including the output layer)
     */
    public BaseInteger[] getNumLayers() {
        return m_NumLayers;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String numLayersTipText() {
        return "The list of layers to choose from.";
    }

    /**
     * Sets the node counts to choose from for the dense layers.
     *
     * @param value the node counts
     */
    public void setNumNodes(BaseInteger[] value) {
        m_NumNodes = value;
        reset();
    }

    /**
     * Returns the node counts to choose from for the dense layers.
     *
     * @return the node counts
     */
    public BaseInteger[] getNumNodes() {
        return m_NumNodes;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String numNodesTipText() {
        return "The list of node counts to choose from.";
    }

    /**
     * Sets the learning rates to choose from.
     *
     * @param value the learning rates
     */
    public void setLearningRate(BaseDouble[] value) {
        m_LearningRate = value;
        reset();
    }

    /**
     * Returns the learning rates to choose from.
     *
     * @return the learning rates
     */
    public BaseDouble[] getLearningRate() {
        return m_LearningRate;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String learningRateTipText() {
        return "The list of learning rate values to choose from.";
    }

    /**
     * Sets the regularization flags to choose from.
     *
     * @param value the flags
     */
    public void setUseRegularization(BaseBoolean[] value) {
        m_UseRegularization = value;
        reset();
    }

    /**
     * Returns the regularization flags to choose from.
     *
     * @return the flags
     */
    public BaseBoolean[] getUseRegularization() {
        return m_UseRegularization;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String useRegularizationTipText() {
        return "The list of regularization flags to choose from.";
    }

    /**
     * Sets the L1 values to choose from.
     *
     * @param value the L1 values
     */
    public void setL1(BaseDouble[] value) {
        m_L1 = value;
        reset();
    }

    /**
     * Returns the L1 values to choose from.
     *
     * @return the L1 values
     */
    public BaseDouble[] getL1() {
        return m_L1;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String L1TipText() {
        return "The list of L1 values to choose from.";
    }

    /**
     * Sets the L2 values to choose from.
     *
     * @param value the L2 values
     */
    public void setL2(BaseDouble[] value) {
        m_L2 = value;
        reset();
    }

    /**
     * Returns the L2 values to choose from.
     *
     * @return the L2 values
     */
    public BaseDouble[] getL2() {
        return m_L2;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String L2TipText() {
        return "The list of L2 values to choose from.";
    }

    /**
     * Sets the drop types to choose from.
     *
     * @param value the drop types
     */
    public void setDropType(Dl4jMlpClassifier.DropType[] value) {
        m_DropType = value;
        reset();
    }

    /**
     * Returns the drop types to choose from.
     *
     * @return the drop types
     */
    public Dl4jMlpClassifier.DropType[] getDropType() {
        return m_DropType;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String dropTypeTipText() {
        return "The list of drop type values to choose from.";
    }

    /**
     * Sets the drop-out values to choose from.
     *
     * @param value the drop-out values
     */
    public void setDropOut(BaseDouble[] value) {
        m_DropOut = value;
        reset();
    }

    /**
     * Returns the drop-out values to choose from.
     *
     * @return the drop-out values
     */
    public BaseDouble[] getDropOut() {
        return m_DropOut;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String dropOutTipText() {
        return "The list of drop out values to choose from.";
    }

    /**
     * Sets the epochs of the learning rate schedule.
     *
     * @param value the epochs, schedule not used if empty
     */
    public void setLearningRateScheduleEpochs(BaseInteger[] value) {
        m_LearningRateScheduleEpochs = value;
        reset();
    }

    /**
     * Returns the epochs of the learning rate schedule.
     *
     * @return the epochs, schedule not used if empty
     */
    public BaseInteger[] getLearningRateScheduleEpochs() {
        return m_LearningRateScheduleEpochs;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String learningRateScheduleEpochsTipText() {
        return "The list of epochs for the learning rate schedule; not used if empty.";
    }

    /**
     * Sets the divisors for the learning rate schedule to choose from.
     *
     * @param value the divisors
     */
    public void setLearningRateScheduleDivisors(BaseDouble[] value) {
        m_LearningRateScheduleDivisors = value;
        reset();
    }

    /**
     * Returns the divisors for the learning rate schedule to choose from.
     *
     * @return the divisors
     */
    public BaseDouble[] getLearningRateScheduleDivisors() {
        return m_LearningRateScheduleDivisors;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String learningRateScheduleDivisorsTipText() {
        return "The list of divisors for calculating learning rate used in the learning rate schedule to choose from; each subsequent value is calculated by dividing the previous one by the chosen divisor.";
    }

    /**
     * Sets whether to insert batch normalization layers.
     *
     * @param value true if to insert a batchnorm layer after each non-output layer
     */
    public void setInsertBatchNormLayers(boolean value) {
        m_InsertBatchNormLayers = value;
        reset();
    }

    /**
     * Returns whether to insert batch normalization layers.
     *
     * @return true if to insert a batchnorm layer after each non-output layer
     */
    public boolean getInsertBatchNormLayers() {
        return m_InsertBatchNormLayers;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String insertBatchNormLayersTipText() {
        return "If enabled, batchnorm layers get inserted after each layer (except output layer).";
    }

    /**
     * Checks whether the setup is valid before generating networks.
     *
     * @return null if valid, otherwise the error message
     */
    @Override
    protected String check() {
        String result = super.check();
        if (result != null) {
            return result;
        }

        if (m_NumLayers.length == 0) {
            return "List of layers is empty!";
        }
        if (m_NumNodes.length == 0) {
            return "List of nodes is empty!";
        }
        // every network needs at least the output layer (created as the last layer)
        for (BaseInteger numLayers : m_NumLayers) {
            if (numLayers.intValue() < 1) {
                return "Number of layers must be at least 1, found: " + numLayers.intValue();
            }
        }
        // the schedule divides the learning rate by a picked divisor for every epoch
        if (m_LearningRateScheduleEpochs.length > 0 && m_LearningRateScheduleDivisors.length == 0) {
            return "List of learning rate schedule divisors is empty, but schedule epochs are defined!";
        }
        return null;
    }

    /**
     * Picks a random element from the array. A single-element array returns that
     * element without drawing from the random number generator, which keeps the
     * sequence of random numbers identical to previous versions.
     *
     * @param rand the random number generator to use
     * @param array the array to pick from
     * @return the picked element
     * @throws IllegalArgumentException if the array is empty or not an array
     */
    protected Object pick(Random rand, Object array) {
        int length = Array.getLength(array);
        if (length == 0) {
            throw new IllegalArgumentException("Cannot pick from an empty array!");
        }
        if (length == 1) {
            return Array.get(array, 0);
        }
        return Array.get(array, rand.nextInt(length));
    }

    /**
     * Generates up to {@link #getNumNetworks()} random networks and removes
     * duplicates (networks with identical YAML configuration).
     *
     * @param numInput the number of input nodes
     * @param numOutput the number of output nodes
     * @return the generated, de-duplicated models
     */
    protected List<Model> doGenerate(int numInput, int numOutput) {
        List<Model> result = new ArrayList<>();
        Random rand = new Random(m_Seed);

        while (result.size() < m_NumNetworks) {
            Dl4jMlpClassifier conf = configureNetwork(rand);
            int numLayers = ((BaseInteger) pick(rand, m_NumLayers)).intValue();
            if (isLoggingEnabled()) {
                getLogger().info("# layers: " + numLayers);
            }

            List<Layer> layers = new ArrayList<>();
            for (int i = 0; i < numLayers; i++) {
                BaseLayer layer = createLayer(rand, i, numLayers, conf.getUseRegularization());
                layers.add(layer);
                if (m_InsertBatchNormLayers && !(layer instanceof OutputLayer)) {
                    layers.add(new BatchNormalization());
                }
            }

            conf.setLayers(layers.toArray(new Layer[layers.size()]));
            result.add(conf.configureModel(numInput, numOutput));
        }

        return removeDuplicates(result);
    }

    /**
     * Copies the default network and applies the randomly picked network-level
     * parameters (regularization, L1/L2, drop type, drop-out).
     * Empty lists leave the template's values untouched.
     */
    private Dl4jMlpClassifier configureNetwork(Random rand) {
        Dl4jMlpClassifier conf = (Dl4jMlpClassifier) OptionUtils.shallowCopy(m_DefaultNetwork);
        if (m_UseRegularization.length > 0) {
            conf.setUseRegularization(((BaseBoolean) pick(rand, m_UseRegularization)).booleanValue());
        }
        if (conf.getUseRegularization()) {
            if (m_L1.length > 0) {
                conf.setL1(((BaseDouble) pick(rand, m_L1)).doubleValue());
            }
            if (m_L2.length > 0) {
                conf.setL2(((BaseDouble) pick(rand, m_L2)).doubleValue());
            }
        }
        if (m_DropType.length > 0) {
            conf.setDropType((Dl4jMlpClassifier.DropType) pick(rand, m_DropType));
        }
        if (m_DropOut.length > 0) {
            conf.setDropOut(((BaseDouble) pick(rand, m_DropOut)).doubleValue());
        }
        return conf;
    }

    /**
     * Creates the layer at the given position. The last layer is a copy of the
     * default output layer; all others are copies of the default dense layer.
     */
    private BaseLayer createLayer(Random rand, int index, int numLayers, boolean useRegularization) {
        BaseLayer layer;
        if (index == numLayers - 1) {
            layer = (OutputLayer) OptionUtils.shallowCopy(m_DefaultOutputLayer);
            layer.setLayerName("output-" + index);
        } else {
            DenseLayer dense = (DenseLayer) OptionUtils.shallowCopy(m_DefaultDenseLayer);
            dense.setLayerName("dense-" + index);
            dense.setNOut(((BaseInteger) pick(rand, m_NumNodes)).intValue());
            layer = dense;
        }

        if (m_LearningRate.length > 0) {
            layer.setLearningRate(((BaseDouble) pick(rand, m_LearningRate)).doubleValue());
        }
        if (useRegularization) {
            if (m_L1.length > 0) {
                layer.setL1(((BaseDouble) pick(rand, m_L1)).doubleValue());
            }
            if (m_L2.length > 0) {
                layer.setL2(((BaseDouble) pick(rand, m_L2)).doubleValue());
            }
        }
        applyDropType(rand, layer);
        applyLearningRateSchedule(rand, layer);
        return layer;
    }

    /** Applies a randomly picked drop type (and drop-out value, if any) to the layer. */
    private void applyDropType(Random rand, BaseLayer layer) {
        if (m_DropType.length == 0) {
            return;
        }
        Dl4jMlpClassifier.DropType dropType = (Dl4jMlpClassifier.DropType) pick(rand, m_DropType);
        switch (dropType) {
            case NONE:
                layer.setDropOut(0.0);
                break;
            case DROP_OUT:
            case DROP_CONNECT:
                // NOTE(review): drop connect also only sets the layer's drop-out value here;
                // presumably the network-level drop type decides how it is used - confirm.
                // An empty drop-out list keeps the template's value instead of failing.
                if (m_DropOut.length > 0) {
                    layer.setDropOut(((BaseDouble) pick(rand, m_DropOut)).doubleValue());
                }
                break;
            default:
                throw new IllegalStateException("Unhandled drop type: " + dropType);
        }
    }

    /**
     * Builds the learning rate schedule for the layer: for each configured epoch the
     * previous rate (starting with the layer's learning rate) gets divided by a
     * randomly picked divisor. Skipped if epochs or divisors are empty.
     */
    private void applyLearningRateSchedule(Random rand, BaseLayer layer) {
        if (m_LearningRateScheduleEpochs.length == 0 || m_LearningRateScheduleDivisors.length == 0) {
            return;
        }
        double lr = layer.getLearningRate();
        HashMap<Integer, Double> schedule = new HashMap<>();
        for (BaseInteger epoch : m_LearningRateScheduleEpochs) {
            lr /= ((BaseDouble) pick(rand, m_LearningRateScheduleDivisors)).doubleValue();
            schedule.put(epoch.intValue(), lr);
        }
        layer.setLearningRateSchedule(schedule);
    }

    /** Keeps only the first model of each distinct YAML configuration, preserving order. */
    private List<Model> removeDuplicates(List<Model> models) {
        HashSet<String> seen = new HashSet<>();
        List<Model> result = new ArrayList<>();
        for (Model model : models) {
            String yaml = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toYaml();
            if (seen.add(yaml)) {
                result.add(model);
            }
        }
        return result;
    }
}

