package adams.ml.cntk.modelgenerator;

import adams.core.Randomizable;
import adams.core.Utils;
import adams.core.base.BaseBoolean;
import adams.core.base.BaseDouble;
import adams.core.base.BaseInteger;
import adams.core.base.BaseText;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:adams/ml/cntk/modelgenerator/RandomBrainScriptGenerator.class */
/**
 * Model generator that assembles random BrainScript networks from a dense
 * layer template and an output layer template.
 * <p>
 * For every network, the training section (the default network text) is
 * augmented with randomly chosen regularization, drop-out and learning rate
 * settings. A layer stack is then drawn using the configured layer and node
 * counts. The network builder declares the variables {@value #INPUT_DIM} and
 * {@value #OUTPUT_DIM} holding the input and output dimensions.
 * <p>
 * All random choices are drawn from a {@link Random} seeded with
 * {@link #getSeed()}.
 */
public class RandomBrainScriptGenerator extends AbstractBrainScriptModelGenerator implements Randomizable {
    private static final long serialVersionUID = 6117066358207451433L;

    /** Name of the BrainScript variable holding the input dimension. */
    public static final String INPUT_DIM = "inputDim";

    /** Name of the BrainScript variable holding the output dimension. */
    public static final String OUTPUT_DIM = "outputDim";

    /** Training section template that the random settings get added to. */
    protected BaseText m_DefaultNetwork;

    /** Template for the hidden (dense) layers; its output dimension gets replaced. */
    protected BaseText m_DefaultDenseLayer;

    /** Template for the final (output) layer, used as-is. */
    protected BaseText m_DefaultOutputLayer;

    /** Seed for the random number generator. */
    protected long m_Seed;

    /** Number of networks to generate before duplicates are removed. */
    protected int m_NumNetworks;

    /** Candidate total layer counts (including the output layer). */
    protected BaseInteger[] m_NumLayers;

    /** Candidate node counts for hidden layers after the first one. */
    protected BaseInteger[] m_NumNodes;

    /** Candidate (initial) learning rates. */
    protected BaseDouble[] m_LearningRate;

    /** Candidate flags for whether to add L1/L2 regularization. */
    protected BaseBoolean[] m_UseRegularization;

    /** Candidate L1 regularization weights. */
    protected BaseDouble[] m_L1;

    /** Candidate L2 regularization weights. */
    protected BaseDouble[] m_L2;

    /** Candidate drop-out rates. */
    protected BaseDouble[] m_DropOut;

    /** Epoch boundaries of the learning rate schedule; schedule unused if empty. */
    protected BaseInteger[] m_LearningRateScheduleEpochs;

    /** Candidate divisors applied to the learning rate at each schedule boundary. */
    protected BaseDouble[] m_LearningRateScheduleDivisors;

    /** Whether to insert a batch normalization layer after each hidden layer. */
    protected boolean m_InsertBatchNormLayers;

    /**
     * Returns a string describing the generator.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Generates random networks, just using dense layers and a linear layer as output.\nInserts the following variables for input and output dimensions:\n- input: inputDim\n- output: outputDim\n\n" + getBrainScriptInfo();
    }

    /**
     * Registers all options with the option manager.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("default-network", "defaultNetwork", new BaseText("SGD = {\n  #maxEpochs = 10\n  #minibatchSize = 200\n}"));
        this.m_OptionManager.add("default-dense-layer", "defaultDenseLayer", new BaseText("DenseLayer {}"));
        this.m_OptionManager.add("default-output-layer", "defaultOutputLayer", new BaseText("LinearLayer {outputDim, init=\"gaussian\" }"));
        this.m_OptionManager.add("seed", "seed", 1L);
        this.m_OptionManager.add("num-networks", "numNetworks", 1, 1, (Number) null);
        this.m_OptionManager.add("num-layers", "numLayers", new BaseInteger[0]);
        this.m_OptionManager.add("num-nodes", "numNodes", new BaseInteger[0]);
        this.m_OptionManager.add("learning-rate", "learningRate", new BaseDouble[0]);
        this.m_OptionManager.add("use-regularization", "useRegularization", new BaseBoolean[0]);
        this.m_OptionManager.add("l1", "L1", new BaseDouble[0]);
        this.m_OptionManager.add("l2", "L2", new BaseDouble[0]);
        this.m_OptionManager.add("drop-out", "dropOut", new BaseDouble[0]);
        this.m_OptionManager.add("learning-rate-schedule-epochs", "learningRateScheduleEpochs", new BaseInteger[0]);
        this.m_OptionManager.add("learning-rate-schedule-divisors", "learningRateScheduleDivisors", new BaseDouble[0]);
        this.m_OptionManager.add("insert-batch-norm-layers", "insertBatchNormLayers", false);
    }

    /**
     * Sets the training section template and resets the generator.
     *
     * @param baseText the template
     */
    public void setDefaultNetwork(BaseText baseText) {
        this.m_DefaultNetwork = baseText;
        reset();
    }

    /**
     * Returns the training section template.
     *
     * @return the template
     */
    public BaseText getDefaultNetwork() {
        return this.m_DefaultNetwork;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String defaultNetworkTipText() {
        return "The default network setup to use (minus layers).";
    }

    /**
     * Sets the dense layer template and resets the generator.
     *
     * @param baseText the template
     */
    public void setDefaultDenseLayer(BaseText baseText) {
        this.m_DefaultDenseLayer = baseText;
        reset();
    }

    /**
     * Returns the dense layer template.
     *
     * @return the template
     */
    public BaseText getDefaultDenseLayer() {
        return this.m_DefaultDenseLayer;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String defaultDenseLayerTipText() {
        return "The default dense layer setup to use (minus layers).";
    }

    /**
     * Sets the output layer template and resets the generator.
     *
     * @param baseText the template
     */
    public void setDefaultOutputLayer(BaseText baseText) {
        this.m_DefaultOutputLayer = baseText;
        reset();
    }

    /**
     * Returns the output layer template.
     *
     * @return the template
     */
    public BaseText getDefaultOutputLayer() {
        return this.m_DefaultOutputLayer;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String defaultOutputLayerTipText() {
        return "The default output layer setup to use.";
    }

    /**
     * Sets the seed for the random number generator and resets the generator.
     *
     * @param j the seed
     */
    public void setSeed(long j) {
        this.m_Seed = j;
        reset();
    }

    /**
     * Returns the seed for the random number generator.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "The seed value to use for initializing the random number generator";
    }

    /**
     * Sets the number of networks to generate and resets the generator.
     * Values rejected by the option manager are ignored.
     *
     * @param i the number of networks (at least 1)
     */
    public void setNumNetworks(int i) {
        if (getOptionManager().isValid("numNetworks", Integer.valueOf(i))) {
            this.m_NumNetworks = i;
            reset();
        }
    }

    /**
     * Returns the number of networks to generate.
     *
     * @return the number of networks
     */
    public int getNumNetworks() {
        return this.m_NumNetworks;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String numNetworksTipText() {
        return "The maximum number of networks to generate (duplicate networks will still get removed afterwards).";
    }

    /**
     * Sets the candidate layer counts and resets the generator.
     *
     * @param baseIntegerArr the layer counts
     */
    public void setNumLayers(BaseInteger[] baseIntegerArr) {
        this.m_NumLayers = baseIntegerArr;
        reset();
    }

    /**
     * Returns the candidate layer counts.
     *
     * @return the layer counts
     */
    public BaseInteger[] getNumLayers() {
        return this.m_NumLayers;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String numLayersTipText() {
        return "The list of layers to choose from.";
    }

    /**
     * Sets the candidate node counts and resets the generator.
     *
     * @param baseIntegerArr the node counts
     */
    public void setNumNodes(BaseInteger[] baseIntegerArr) {
        this.m_NumNodes = baseIntegerArr;
        reset();
    }

    /**
     * Returns the candidate node counts.
     *
     * @return the node counts
     */
    public BaseInteger[] getNumNodes() {
        return this.m_NumNodes;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String numNodesTipText() {
        return "The list of node counts to choose from.";
    }

    /**
     * Sets the candidate learning rates and resets the generator.
     *
     * @param baseDoubleArr the learning rates
     */
    public void setLearningRate(BaseDouble[] baseDoubleArr) {
        this.m_LearningRate = baseDoubleArr;
        reset();
    }

    /**
     * Returns the candidate learning rates.
     *
     * @return the learning rates
     */
    public BaseDouble[] getLearningRate() {
        return this.m_LearningRate;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String learningRateTipText() {
        return "The list of learning rate values to choose from.";
    }

    /**
     * Sets the candidate regularization flags and resets the generator.
     *
     * @param baseBooleanArr the flags
     */
    public void setUseRegularization(BaseBoolean[] baseBooleanArr) {
        this.m_UseRegularization = baseBooleanArr;
        reset();
    }

    /**
     * Returns the candidate regularization flags.
     *
     * @return the flags
     */
    public BaseBoolean[] getUseRegularization() {
        return this.m_UseRegularization;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String useRegularizationTipText() {
        return "The list of regularization flags to choose from.";
    }

    /**
     * Sets the candidate L1 weights and resets the generator.
     *
     * @param baseDoubleArr the L1 weights
     */
    public void setL1(BaseDouble[] baseDoubleArr) {
        this.m_L1 = baseDoubleArr;
        reset();
    }

    /**
     * Returns the candidate L1 weights.
     *
     * @return the L1 weights
     */
    public BaseDouble[] getL1() {
        return this.m_L1;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String L1TipText() {
        return "The list of L1 values to choose from.";
    }

    /**
     * Sets the candidate L2 weights and resets the generator.
     *
     * @param baseDoubleArr the L2 weights
     */
    public void setL2(BaseDouble[] baseDoubleArr) {
        this.m_L2 = baseDoubleArr;
        reset();
    }

    /**
     * Returns the candidate L2 weights.
     *
     * @return the L2 weights
     */
    public BaseDouble[] getL2() {
        return this.m_L2;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String L2TipText() {
        return "The list of L2 values to choose from.";
    }

    /**
     * Sets the candidate drop-out rates and resets the generator.
     *
     * @param baseDoubleArr the drop-out rates
     */
    public void setDropOut(BaseDouble[] baseDoubleArr) {
        this.m_DropOut = baseDoubleArr;
        reset();
    }

    /**
     * Returns the candidate drop-out rates.
     *
     * @return the drop-out rates
     */
    public BaseDouble[] getDropOut() {
        return this.m_DropOut;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String dropOutTipText() {
        return "The list of drop out values to choose from.";
    }

    /**
     * Sets the epoch boundaries of the learning rate schedule and resets the generator.
     *
     * @param baseIntegerArr the epoch boundaries
     */
    public void setLearningRateScheduleEpochs(BaseInteger[] baseIntegerArr) {
        this.m_LearningRateScheduleEpochs = baseIntegerArr;
        reset();
    }

    /**
     * Returns the epoch boundaries of the learning rate schedule.
     *
     * @return the epoch boundaries
     */
    public BaseInteger[] getLearningRateScheduleEpochs() {
        return this.m_LearningRateScheduleEpochs;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String learningRateScheduleEpochsTipText() {
        return "The list of epochs for the learning rate schedule; requires learning rate to be specified (for initial value); not used if empty.";
    }

    /**
     * Sets the candidate schedule divisors and resets the generator.
     *
     * @param baseDoubleArr the divisors
     */
    public void setLearningRateScheduleDivisors(BaseDouble[] baseDoubleArr) {
        this.m_LearningRateScheduleDivisors = baseDoubleArr;
        reset();
    }

    /**
     * Returns the candidate schedule divisors.
     *
     * @return the divisors
     */
    public BaseDouble[] getLearningRateScheduleDivisors() {
        return this.m_LearningRateScheduleDivisors;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String learningRateScheduleDivisorsTipText() {
        return "The list of divisors for calculating learning rate used in the learning rate schedule to choose from; each subsequent value is calculated by dividing the previous one by the chosen divisor.";
    }

    /**
     * Sets whether to insert batch normalization layers and resets the generator.
     *
     * @param z true to insert them
     */
    public void setInsertBatchNormLayers(boolean z) {
        this.m_InsertBatchNormLayers = z;
        reset();
    }

    /**
     * Returns whether batch normalization layers get inserted.
     *
     * @return true if inserted
     */
    public boolean getInsertBatchNormLayers() {
        return this.m_InsertBatchNormLayers;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return the tip text
     */
    public String insertBatchNormLayersTipText() {
        return "If enabled, batchnorm layers get inserted after each layer (except output layer).";
    }

    /**
     * Picks a random element from the given array.
     * <p>
     * A single-element array is returned directly without drawing from the
     * random number generator. Callers rely on this, since it keeps the
     * sequence of subsequent random choices unchanged.
     *
     * @param random the random number generator
     * @param obj    the array to pick from (must not be empty)
     * @return the picked element
     */
    protected Object pick(Random random, Object obj) {
        return Array.getLength(obj) == 1 ? Array.get(obj, 0) : Array.get(obj, random.nextInt(Array.getLength(obj)));
    }

    /**
     * Like {@link #pick(Random, Object)}, but fails with a descriptive message
     * if no values were configured for the option.
     *
     * @param random the random number generator
     * @param values the array to pick from
     * @param option the command-line name of the option, used in the error message
     * @return the picked element
     * @throws IllegalArgumentException if {@code values} is empty
     */
    private Object pickRequired(Random random, Object values, String option) {
        if (Array.getLength(values) == 0) {
            throw new IllegalArgumentException("No values specified for option '" + option + "' to choose from!");
        }
        return pick(random, values);
    }

    /**
     * Generates the networks.
     *
     * @param inputDim  the input dimension, written as variable {@value #INPUT_DIM}
     * @param outputDim the output dimension, written as variable {@value #OUTPUT_DIM}
     * @return the generated scripts, duplicates removed (first occurrence kept, order preserved)
     * @throws IllegalArgumentException if a required option (e.g., num-layers) has no values
     */
    @Override
    protected List<String> doGenerate(int inputDim, int outputDim) {
        List<String> networks = new ArrayList<>();
        Random random = new Random(this.m_Seed);
        while (networks.size() < this.m_NumNetworks) {
            // The order matters: the training setup draws its random numbers before the layers do.
            String training = generateTrainingSetup(random);
            String network = generateNetwork(random, inputDim, outputDim);
            networks.add(network + training);
        }
        return removeDuplicates(networks);
    }

    /**
     * Adds randomly chosen regularization, drop-out and learning rate settings
     * to the training section template.
     *
     * @param random the random number generator
     * @return the training section
     */
    private String generateTrainingSetup(Random random) {
        String result = this.m_DefaultNetwork.getValue();

        if (this.m_UseRegularization.length > 0 && ((BaseBoolean) pick(random, this.m_UseRegularization)).booleanValue()) {
            if (this.m_L1.length > 0) {
                result = BrainScriptHelper.addParam(result, "  L1RegWeight = " + pick(random, this.m_L1), true);
            }
            if (this.m_L2.length > 0) {
                result = BrainScriptHelper.addParam(result, "  L2RegWeight = " + pick(random, this.m_L2), true);
            }
        }

        if (this.m_DropOut.length > 0) {
            result = BrainScriptHelper.addParam(result, "  dropoutRate = " + pick(random, this.m_DropOut), true);
        }

        if (this.m_LearningRate.length > 0) {
            String rates;
            if (this.m_LearningRateScheduleEpochs.length > 0) {
                rates = generateLearningRateSchedule(random);
            } else {
                rates = "" + pick(random, this.m_LearningRate);
            }
            result = BrainScriptHelper.addParam(result, "  learningRatesPerSample = " + rates, true);
        }

        return result;
    }

    /**
     * Builds the {@code learningRatesPerSample} value from the schedule epochs.
     * <p>
     * The picked initial learning rate applies up to the first epoch boundary.
     * At every boundary, the current rate is divided by a randomly picked divisor.
     * For example, epochs {@code e1, e2} produce {@code r0*e1:r1*(e2-e1):r2}, where
     * {@code r1 = r0/d1}, {@code r2 = r1/d2}. The last rate carries no repeat
     * count and therefore applies to all remaining epochs.
     * Epoch boundaries are expected in increasing order; this is not validated.
     *
     * @param random the random number generator
     * @return the schedule string
     * @throws IllegalArgumentException if no schedule divisors are configured
     */
    private String generateLearningRateSchedule(Random random) {
        double rate = ((BaseDouble) pick(random, this.m_LearningRate)).doubleValue();
        int previous = 0;
        List<String> segments = new ArrayList<>();
        for (BaseInteger epoch : this.m_LearningRateScheduleEpochs) {
            int current = epoch.intValue();
            segments.add(rate + "*" + (current - previous));
            previous = current;
            rate /= ((BaseDouble) pickRequired(random, this.m_LearningRateScheduleDivisors, "learning-rate-schedule-divisors")).doubleValue();
        }
        segments.add("" + rate);
        return Utils.flatten(segments, ":");
    }

    /**
     * Draws a random layer stack and renders the network builder section.
     * <p>
     * All layers but the last use the dense layer template, each optionally
     * followed by a batch normalization layer. The first dense layer's output
     * dimension is set to {@value #INPUT_DIM}; later ones get a randomly picked
     * node count. The last layer uses the output layer template.
     *
     * @param random    the random number generator
     * @param inputDim  the input dimension
     * @param outputDim the output dimension
     * @return the network builder section, followed by an empty line
     * @throws IllegalArgumentException if no layer counts (or, when needed, node counts) are configured
     */
    private String generateNetwork(Random random, int inputDim, int outputDim) {
        List<String> layers = new ArrayList<>();
        List<String> names = new ArrayList<>();

        int numLayers = ((BaseInteger) pickRequired(random, this.m_NumLayers, "num-layers")).intValue();
        if (isLoggingEnabled()) {
            getLogger().info("# layers: " + numLayers);
        }

        for (int n = 0; n < numLayers; n++) {
            if (n == numLayers - 1) {
                names.add("ol");
                layers.add(this.m_DefaultOutputLayer.getValue());
            } else {
                String dense = this.m_DefaultDenseLayer.getValue();
                names.add("dl" + n);
                if (n == 0) {
                    layers.add(BrainScriptHelper.setOutDim(dense, INPUT_DIM));
                } else {
                    int nodes = ((BaseInteger) pickRequired(random, this.m_NumNodes, "num-nodes")).intValue();
                    layers.add(BrainScriptHelper.setOutDim(dense, nodes));
                }
                if (this.m_InsertBatchNormLayers) {
                    names.add("bnl" + n);
                    layers.add("BatchNormalizationLayer");
                }
            }
        }

        StringBuilder sb = new StringBuilder();
        sb.append("BrainScriptNetworkBuilder = {\n");
        sb.append("  ").append(INPUT_DIM).append(" = ").append(inputDim).append("\n");
        sb.append("  ").append(OUTPUT_DIM).append(" = ").append(outputDim).append("\n");
        sb.append("\n");
        sb.append("  ").append("model (features) {\n");
        for (int n = 0; n < layers.size(); n++) {
            // each layer is applied to the previous layer's output; the first one to the features
            String input = (n == 0) ? "features" : names.get(n - 1);
            sb.append("    ").append(names.get(n)).append(" = ").append(layers.get(n));
            sb.append(" (").append(input).append(")");
            sb.append("\n");
        }
        sb.append("  ").append("}.ol\n");
        sb.append("\n");
        sb.append("  # inputs\n");
        sb.append("  features = Input {inputDim}\n");
        // labels are sized by the output dimension (previously "Input {labels}", a self-reference)
        sb.append("  labels = Input {").append(OUTPUT_DIM).append("}\n");
        sb.append("  \n");
        sb.append("  # apply model to outputs\n");
        sb.append("  ol = model (features)\n");
        sb.append("  \n");
        sb.append("  # define regression loss\n");
        sb.append("  diff = labels - ol\n");
        sb.append("  sqerr = ReduceSum (diff.*diff, axis=1)\n");
        sb.append("  rmse = Sqrt (sqerr)\n");
        sb.append("  \n");
        sb.append("  # declare special nodes\n");
        sb.append("  featureNodes    = (features)\n");
        sb.append("  labelNodes      = (labels)\n");
        sb.append("  criterionNodes  = (rmse)\n");
        sb.append("  evaluationNodes = (rmse)\n");
        sb.append("  outputNodes     = (ol)\n");
        sb.append("\n");
        sb.append("}\n");
        sb.append("\n");
        return sb.toString();
    }

    /**
     * Removes duplicate scripts, keeping the first occurrence of each and
     * preserving the original order.
     *
     * @param networks the scripts
     * @return a new list without duplicates
     */
    private static List<String> removeDuplicates(List<String> networks) {
        List<String> result = new ArrayList<>(networks.size());
        HashSet<String> seen = new HashSet<>();
        for (String network : networks) {
            if (seen.add(network)) {
                result.add(network);
            }
        }
        return result;
    }
}
