/*
 * Decompiled with CFR 0.152.
 */
package adams.ml.cntk.modelgenerator;

import adams.core.Randomizable;
import adams.core.Utils;
import adams.core.base.BaseBoolean;
import adams.core.base.BaseDouble;
import adams.core.base.BaseInteger;
import adams.core.base.BaseText;
import adams.ml.cntk.modelgenerator.AbstractBrainScriptModelGenerator;
import adams.ml.cntk.modelgenerator.BrainScriptHelper;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;

/**
 * Generates random BrainScript networks that consist of dense layers followed
 * by a single output layer.
 * <p>
 * For every network, values for the training parameters (regularization,
 * drop-out, learning rate / learning rate schedule) and for the network
 * topology (number of layers, nodes per layer) are randomly picked from the
 * user-supplied lists. The random number generator is seeded with
 * {@link #getSeed()}. Duplicate scripts are removed from the final result.
 * <p>
 * The generated scripts declare the variables {@value #INPUT_DIM} and
 * {@value #OUTPUT_DIM}.
 */
public class RandomBrainScriptGenerator
extends AbstractBrainScriptModelGenerator
implements Randomizable {
    private static final long serialVersionUID = 6117066358207451433L;

    /** Name of the BrainScript variable holding the input dimension. */
    public static final String INPUT_DIM = "inputDim";

    /** Name of the BrainScript variable holding the output dimension. */
    public static final String OUTPUT_DIM = "outputDim";

    /** The default network setup (training section) that picked parameters get added to. */
    protected BaseText m_DefaultNetwork;

    /** The template used for every dense layer. */
    protected BaseText m_DefaultDenseLayer;

    /** The template used for the output layer. */
    protected BaseText m_DefaultOutputLayer;

    /** The seed for the random number generator. */
    protected long m_Seed;

    /** The number of networks to generate (before duplicates get removed). */
    protected int m_NumNetworks;

    /** The layer counts to choose from. */
    protected BaseInteger[] m_NumLayers;

    /** The node counts to choose from. */
    protected BaseInteger[] m_NumNodes;

    /** The learning rates to choose from. */
    protected BaseDouble[] m_LearningRate;

    /** The regularization flags to choose from. */
    protected BaseBoolean[] m_UseRegularization;

    /** The L1 regularization weights to choose from. */
    protected BaseDouble[] m_L1;

    /** The L2 regularization weights to choose from. */
    protected BaseDouble[] m_L2;

    /** The drop-out rates to choose from. */
    protected BaseDouble[] m_DropOut;

    /** The epochs of the learning rate schedule (used in the given order, not picked). */
    protected BaseInteger[] m_LearningRateScheduleEpochs;

    /** The divisors to choose from when computing the learning rate schedule. */
    protected BaseDouble[] m_LearningRateScheduleDivisors;

    /** Whether a batch normalization layer gets inserted after each non-output layer. */
    protected boolean m_InsertBatchNormLayers;

    /**
     * Returns a description of this generator, followed by the text returned
     * by {@code getBrainScriptInfo()}.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Generates random networks, just using dense layers and a linear layer as output.\nInserts the following variables for input and output dimensions:\n- input: inputDim\n- output: outputDim\n\n" + this.getBrainScriptInfo();
    }

    /**
     * Registers the options of this generator with the option manager, after
     * the superclass has registered its own.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("default-network", "defaultNetwork", (Object)new BaseText("SGD = {\n  #maxEpochs = 10\n  #minibatchSize = 200\n}"));
        this.m_OptionManager.add("default-dense-layer", "defaultDenseLayer", (Object)new BaseText("DenseLayer {}"));
        this.m_OptionManager.add("default-output-layer", "defaultOutputLayer", (Object)new BaseText("LinearLayer {outputDim, init=\"gaussian\" }"));
        this.m_OptionManager.add("seed", "seed", (Object)1L);
        this.m_OptionManager.add("num-networks", "numNetworks", (Object)1, (Number)1, null);
        this.m_OptionManager.add("num-layers", "numLayers", (Object)new BaseInteger[0]);
        this.m_OptionManager.add("num-nodes", "numNodes", (Object)new BaseInteger[0]);
        this.m_OptionManager.add("learning-rate", "learningRate", (Object)new BaseDouble[0]);
        this.m_OptionManager.add("use-regularization", "useRegularization", (Object)new BaseBoolean[0]);
        this.m_OptionManager.add("l1", "L1", (Object)new BaseDouble[0]);
        this.m_OptionManager.add("l2", "L2", (Object)new BaseDouble[0]);
        this.m_OptionManager.add("drop-out", "dropOut", (Object)new BaseDouble[0]);
        this.m_OptionManager.add("learning-rate-schedule-epochs", "learningRateScheduleEpochs", (Object)new BaseInteger[0]);
        this.m_OptionManager.add("learning-rate-schedule-divisors", "learningRateScheduleDivisors", (Object)new BaseDouble[0]);
        this.m_OptionManager.add("insert-batch-norm-layers", "insertBatchNormLayers", (Object)false);
    }

    /**
     * Sets the default network setup and calls {@code reset()}.
     *
     * @param value the network setup
     */
    public void setDefaultNetwork(BaseText value) {
        this.m_DefaultNetwork = value;
        this.reset();
    }

    /**
     * Returns the default network setup.
     *
     * @return the network setup
     */
    public BaseText getDefaultNetwork() {
        return this.m_DefaultNetwork;
    }

    /**
     * Returns the tip text for the default network option.
     *
     * @return the tip text
     */
    public String defaultNetworkTipText() {
        return "The default network setup to use (minus layers).";
    }

    /**
     * Sets the dense layer template and calls {@code reset()}.
     *
     * @param value the dense layer template
     */
    public void setDefaultDenseLayer(BaseText value) {
        this.m_DefaultDenseLayer = value;
        this.reset();
    }

    /**
     * Returns the dense layer template.
     *
     * @return the dense layer template
     */
    public BaseText getDefaultDenseLayer() {
        return this.m_DefaultDenseLayer;
    }

    /**
     * Returns the tip text for the default dense layer option.
     *
     * @return the tip text
     */
    public String defaultDenseLayerTipText() {
        // NOTE(review): "(minus layers)" reads like a copy of the network tip text - confirm wording
        return "The default dense layer setup to use (minus layers).";
    }

    /**
     * Sets the output layer template and calls {@code reset()}.
     *
     * @param value the output layer template
     */
    public void setDefaultOutputLayer(BaseText value) {
        this.m_DefaultOutputLayer = value;
        this.reset();
    }

    /**
     * Returns the output layer template.
     *
     * @return the output layer template
     */
    public BaseText getDefaultOutputLayer() {
        return this.m_DefaultOutputLayer;
    }

    /**
     * Returns the tip text for the default output layer option.
     *
     * @return the tip text
     */
    public String defaultOutputLayerTipText() {
        return "The default output layer setup to use.";
    }

    /**
     * Sets the seed for the random number generator and calls {@code reset()}.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        this.m_Seed = value;
        this.reset();
    }

    /**
     * Returns the seed for the random number generator.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for the seed option.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "The seed value to use for initializing the random number generator";
    }

    /**
     * Sets the number of networks to generate. The value is only accepted
     * (followed by a call to {@code reset()}) if the option manager reports it
     * as valid; otherwise the call has no effect.
     *
     * @param value the number of networks
     */
    public void setNumNetworks(int value) {
        if (this.getOptionManager().isValid("numNetworks", (Number)value)) {
            this.m_NumNetworks = value;
            this.reset();
        }
    }

    /**
     * Returns the number of networks to generate.
     *
     * @return the number of networks
     */
    public int getNumNetworks() {
        return this.m_NumNetworks;
    }

    /**
     * Returns the tip text for the number of networks option.
     *
     * @return the tip text
     */
    public String numNetworksTipText() {
        return "The maximum number of networks to generate (duplicate networks will still get removed afterwards).";
    }

    /**
     * Sets the layer counts to choose from and calls {@code reset()}.
     *
     * @param value the layer counts
     */
    public void setNumLayers(BaseInteger[] value) {
        this.m_NumLayers = value;
        this.reset();
    }

    /**
     * Returns the layer counts to choose from.
     *
     * @return the layer counts
     */
    public BaseInteger[] getNumLayers() {
        return this.m_NumLayers;
    }

    /**
     * Returns the tip text for the number of layers option.
     *
     * @return the tip text
     */
    public String numLayersTipText() {
        return "The list of layers to choose from.";
    }

    /**
     * Sets the node counts to choose from and calls {@code reset()}.
     *
     * @param value the node counts
     */
    public void setNumNodes(BaseInteger[] value) {
        this.m_NumNodes = value;
        this.reset();
    }

    /**
     * Returns the node counts to choose from.
     *
     * @return the node counts
     */
    public BaseInteger[] getNumNodes() {
        return this.m_NumNodes;
    }

    /**
     * Returns the tip text for the number of nodes option.
     *
     * @return the tip text
     */
    public String numNodesTipText() {
        return "The list of node counts to choose from.";
    }

    /**
     * Sets the learning rates to choose from and calls {@code reset()}.
     *
     * @param value the learning rates
     */
    public void setLearningRate(BaseDouble[] value) {
        this.m_LearningRate = value;
        this.reset();
    }

    /**
     * Returns the learning rates to choose from.
     *
     * @return the learning rates
     */
    public BaseDouble[] getLearningRate() {
        return this.m_LearningRate;
    }

    /**
     * Returns the tip text for the learning rate option.
     *
     * @return the tip text
     */
    public String learningRateTipText() {
        return "The list of learning rate values to choose from.";
    }

    /**
     * Sets the regularization flags to choose from and calls {@code reset()}.
     *
     * @param value the flags
     */
    public void setUseRegularization(BaseBoolean[] value) {
        this.m_UseRegularization = value;
        this.reset();
    }

    /**
     * Returns the regularization flags to choose from.
     *
     * @return the flags
     */
    public BaseBoolean[] getUseRegularization() {
        return this.m_UseRegularization;
    }

    /**
     * Returns the tip text for the regularization option.
     *
     * @return the tip text
     */
    public String useRegularizationTipText() {
        return "The list of regularization flags to choose from.";
    }

    /**
     * Sets the L1 values to choose from and calls {@code reset()}.
     *
     * @param value the L1 values
     */
    public void setL1(BaseDouble[] value) {
        this.m_L1 = value;
        this.reset();
    }

    /**
     * Returns the L1 values to choose from.
     *
     * @return the L1 values
     */
    public BaseDouble[] getL1() {
        return this.m_L1;
    }

    /**
     * Returns the tip text for the L1 option.
     *
     * @return the tip text
     */
    public String L1TipText() {
        return "The list of L1 values to choose from.";
    }

    /**
     * Sets the L2 values to choose from and calls {@code reset()}.
     *
     * @param value the L2 values
     */
    public void setL2(BaseDouble[] value) {
        this.m_L2 = value;
        this.reset();
    }

    /**
     * Returns the L2 values to choose from.
     *
     * @return the L2 values
     */
    public BaseDouble[] getL2() {
        return this.m_L2;
    }

    /**
     * Returns the tip text for the L2 option.
     *
     * @return the tip text
     */
    public String L2TipText() {
        return "The list of L2 values to choose from.";
    }

    /**
     * Sets the drop-out values to choose from and calls {@code reset()}.
     *
     * @param value the drop-out values
     */
    public void setDropOut(BaseDouble[] value) {
        this.m_DropOut = value;
        this.reset();
    }

    /**
     * Returns the drop-out values to choose from.
     *
     * @return the drop-out values
     */
    public BaseDouble[] getDropOut() {
        return this.m_DropOut;
    }

    /**
     * Returns the tip text for the drop-out option.
     *
     * @return the tip text
     */
    public String dropOutTipText() {
        return "The list of drop out values to choose from.";
    }

    /**
     * Sets the epochs of the learning rate schedule and calls {@code reset()}.
     *
     * @param value the epochs
     */
    public void setLearningRateScheduleEpochs(BaseInteger[] value) {
        this.m_LearningRateScheduleEpochs = value;
        this.reset();
    }

    /**
     * Returns the epochs of the learning rate schedule.
     *
     * @return the epochs
     */
    public BaseInteger[] getLearningRateScheduleEpochs() {
        return this.m_LearningRateScheduleEpochs;
    }

    /**
     * Returns the tip text for the learning rate schedule epochs option.
     *
     * @return the tip text
     */
    public String learningRateScheduleEpochsTipText() {
        return "The list of epochs for the learning rate schedule; requires learning rate to be specified (for initial value); not used if empty.";
    }

    /**
     * Sets the learning rate schedule divisors to choose from and calls
     * {@code reset()}.
     *
     * @param value the divisors
     */
    public void setLearningRateScheduleDivisors(BaseDouble[] value) {
        this.m_LearningRateScheduleDivisors = value;
        this.reset();
    }

    /**
     * Returns the learning rate schedule divisors to choose from.
     *
     * @return the divisors
     */
    public BaseDouble[] getLearningRateScheduleDivisors() {
        return this.m_LearningRateScheduleDivisors;
    }

    /**
     * Returns the tip text for the learning rate schedule divisors option.
     *
     * @return the tip text
     */
    public String learningRateScheduleDivisorsTipText() {
        return "The list of divisors for calculating learning rate used in the learning rate schedule to choose from; each subsequent value is calculated by dividing the previous one by the chosen divisor.";
    }

    /**
     * Sets whether batch normalization layers get inserted and calls
     * {@code reset()}.
     *
     * @param value true if to insert the layers
     */
    public void setInsertBatchNormLayers(boolean value) {
        this.m_InsertBatchNormLayers = value;
        this.reset();
    }

    /**
     * Returns whether batch normalization layers get inserted.
     *
     * @return true if the layers get inserted
     */
    public boolean getInsertBatchNormLayers() {
        return this.m_InsertBatchNormLayers;
    }

    /**
     * Returns the tip text for the batch normalization option.
     *
     * @return the tip text
     */
    public String insertBatchNormLayersTipText() {
        return "If enabled, batchnorm layers get inserted after each layer (except output layer).";
    }

    /**
     * Picks an element from the given array.
     * <p>
     * If the array holds exactly one element, that element is returned without
     * consuming a random number; otherwise the index is drawn via
     * {@link Random#nextInt(int)}.
     *
     * @param rand the random number generator to use
     * @param array the array (accessed reflectively) to pick from
     * @return the picked element
     * @throws IllegalArgumentException if the array is empty
     */
    protected Object pick(Random rand, Object array) {
        int length = Array.getLength(array);
        if (length == 0) {
            // Random.nextInt(0) would throw the same exception type, but with
            // the unhelpful message "bound must be positive"
            throw new IllegalArgumentException("Cannot pick a value from an empty array!");
        }
        if (length == 1) {
            return Array.get(array, 0);
        }
        return Array.get(array, rand.nextInt(length));
    }

    /**
     * Generates the configured number of BrainScript networks and removes
     * duplicates afterwards, so fewer scripts than requested may be returned.
     *
     * @param numInput the value emitted for the {@value #INPUT_DIM} variable
     * @param numOutput the value emitted for the {@value #OUTPUT_DIM} variable
     * @return the unique scripts, in order of generation
     * @throws IllegalArgumentException if an option that a value has to be
     *         picked from (number of layers, number of nodes, schedule
     *         divisors) is empty
     */
    @Override
    protected List<String> doGenerate(int numInput, int numOutput) {
        List<String> scripts = new ArrayList<String>();
        Random rand = new Random(this.m_Seed);
        while (scripts.size() < this.m_NumNetworks) {
            // the training parameters must be picked before the topology, as
            // the order of random number consumption determines the output
            String conf = this.generateTrainingConfig(rand);
            List<String> layers = new ArrayList<String>();
            List<String> names = new ArrayList<String>();
            this.generateLayers(rand, layers, names);
            scripts.add(this.assembleScript(numInput, numOutput, layers, names, conf));
        }
        return RandomBrainScriptGenerator.removeDuplicates(scripts);
    }

    /**
     * Builds the training section from the default network setup, adding the
     * randomly picked regularization, drop-out and learning rate parameters.
     *
     * @param rand the random number generator to use
     * @return the training section
     */
    private String generateTrainingConfig(Random rand) {
        String conf = this.m_DefaultNetwork.getValue();
        if (this.m_UseRegularization.length > 0 && ((BaseBoolean)this.pick(rand, this.m_UseRegularization)).booleanValue()) {
            if (this.m_L1.length > 0) {
                conf = BrainScriptHelper.addParam(conf, "  L1RegWeight = " + this.pick(rand, this.m_L1), true);
            }
            if (this.m_L2.length > 0) {
                conf = BrainScriptHelper.addParam(conf, "  L2RegWeight = " + this.pick(rand, this.m_L2), true);
            }
        }
        if (this.m_DropOut.length > 0) {
            conf = BrainScriptHelper.addParam(conf, "  dropoutRate = " + this.pick(rand, this.m_DropOut), true);
        }
        if (this.m_LearningRate.length > 0) {
            if (this.m_LearningRateScheduleEpochs.length > 0) {
                conf = BrainScriptHelper.addParam(conf, "  learningRatesPerSample = " + this.generateLearningRateSchedule(rand), true);
            } else {
                conf = BrainScriptHelper.addParam(conf, "  learningRatesPerSample = " + this.pick(rand, this.m_LearningRate), true);
            }
        }
        return conf;
    }

    /**
     * Builds the colon-separated learning rate schedule. Every entry apart from
     * the last has the form {@code rate*epochs}, with {@code epochs} being the
     * difference to the preceding schedule epoch; the last entry is just the rate.
     *
     * @param rand the random number generator to use
     * @return the schedule
     * @throws IllegalArgumentException if no divisors are available
     */
    private String generateLearningRateSchedule(Random rand) {
        RandomBrainScriptGenerator.requireNonEmpty(this.m_LearningRateScheduleDivisors, "learningRateScheduleDivisors");
        double lr = ((BaseDouble)this.pick(rand, this.m_LearningRate)).doubleValue();
        int priorEpoch = 0;
        int last = this.m_LearningRateScheduleEpochs.length - 1;
        List<String> schedule = new ArrayList<String>();
        for (int n = 0; n <= last; ++n) {
            // NOTE(review): the rate is divided before the first entry as well, so the
            // picked initial learning rate itself never appears in the schedule - confirm intent
            lr /= ((BaseDouble)this.pick(rand, this.m_LearningRateScheduleDivisors)).doubleValue();
            int epoch = this.m_LearningRateScheduleEpochs[n].intValue();
            if (n < last) {
                schedule.add(lr + "*" + (epoch - priorEpoch));
            } else {
                schedule.add("" + lr);
            }
            priorEpoch = epoch;
        }
        return Utils.flatten(schedule, ":");
    }

    /**
     * Randomly picks the number of layers and fills the two parallel lists with
     * the layer definitions and their names. The last layer is the output layer
     * (named "ol"); all others are dense layers (named "dl" + index), optionally
     * followed by a batch normalization layer (named "bnl" + index).
     *
     * @param rand the random number generator to use
     * @param layers the list to add the layer definitions to
     * @param names the list to add the layer names to
     * @throws IllegalArgumentException if no layer counts are available, or no
     *         node counts are available although one has to be picked
     */
    private void generateLayers(Random rand, List<String> layers, List<String> names) {
        RandomBrainScriptGenerator.requireNonEmpty(this.m_NumLayers, "numLayers");
        int numLayers = ((BaseInteger)this.pick(rand, this.m_NumLayers)).intValue();
        if (this.isLoggingEnabled()) {
            this.getLogger().info("# layers: " + numLayers);
        }
        // NOTE(review): a layer count below 1 produces no "ol" layer, although the
        // assembled script refers to it - presumably such counts are never configured
        for (int i = 0; i < numLayers; ++i) {
            boolean outputLayer = i == numLayers - 1;
            String layer;
            if (outputLayer) {
                layer = this.m_DefaultOutputLayer.getValue();
                names.add("ol");
            } else {
                layer = this.m_DefaultDenseLayer.getValue();
                names.add("dl" + i);
                if (i == 0) {
                    // the first dense layer uses the input dimension variable, no random pick
                    layer = BrainScriptHelper.setOutDim(layer, INPUT_DIM);
                } else {
                    RandomBrainScriptGenerator.requireNonEmpty(this.m_NumNodes, "numNodes");
                    layer = BrainScriptHelper.setOutDim(layer, ((BaseInteger)this.pick(rand, this.m_NumNodes)).intValue());
                }
            }
            layers.add(layer);
            if (this.m_InsertBatchNormLayers && !outputLayer) {
                layers.add("BatchNormalizationLayer");
                names.add("bnl" + i);
            }
        }
    }

    /**
     * Assembles the complete script: the network builder section with the
     * chained layers and the regression loss, followed by the training section.
     *
     * @param numInput the value emitted for the input dimension variable
     * @param numOutput the value emitted for the output dimension variable
     * @param layers the layer definitions
     * @param names the layer names, parallel to {@code layers}
     * @param conf the training section to append
     * @return the script
     */
    private String assembleScript(int numInput, int numOutput, List<String> layers, List<String> names, String conf) {
        StringBuilder model = new StringBuilder();
        model.append("BrainScriptNetworkBuilder = {\n");
        model.append("  ").append(INPUT_DIM).append(" = ").append(numInput).append("\n");
        model.append("  ").append(OUTPUT_DIM).append(" = ").append(numOutput).append("\n");
        model.append("\n");
        model.append("  ").append("model (features) {\n");
        for (int n = 0; n < layers.size(); ++n) {
            model.append("    ").append(names.get(n)).append(" = ").append(layers.get(n));
            // every layer gets applied to its predecessor, the first one to the features
            if (n == 0) {
                model.append(" (features)");
            } else {
                model.append(" (").append(names.get(n - 1)).append(")");
            }
            model.append("\n");
        }
        model.append("  ").append("}.ol\n");
        model.append("\n");
        model.append("  # inputs\n");
        model.append("  features = Input {inputDim}\n");
        // NOTE(review): "Input {labels}" looks self-referential; possibly "outputDim"
        // was intended - left unchanged, confirm against a working script
        model.append("  labels = Input {labels}\n");
        model.append("  \n");
        model.append("  # apply model to outputs\n");
        model.append("  ol = model (features)\n");
        model.append("  \n");
        model.append("  # define regression loss\n");
        model.append("  diff = labels - ol\n");
        model.append("  sqerr = ReduceSum (diff.*diff, axis=1)\n");
        model.append("  rmse = Sqrt (sqerr)\n");
        model.append("  \n");
        model.append("  # declare special nodes\n");
        model.append("  featureNodes    = (features)\n");
        model.append("  labelNodes      = (labels)\n");
        model.append("  criterionNodes  = (rmse)\n");
        model.append("  evaluationNodes = (rmse)\n");
        model.append("  outputNodes     = (ol)\n");
        model.append("\n");
        model.append("}\n");
        model.append("\n");
        model.append(conf);
        return model.toString();
    }

    /**
     * Returns a new list with only the first occurrence of each script, in the
     * original order.
     *
     * @param scripts the scripts to filter
     * @return the unique scripts
     */
    private static List<String> removeDuplicates(List<String> scripts) {
        HashSet<String> seen = new HashSet<String>();
        List<String> unique = new ArrayList<String>(scripts.size());
        for (String script : scripts) {
            // add() returns false if the script was encountered before
            if (seen.add(script)) {
                unique.add(script);
            }
        }
        return unique;
    }

    /**
     * Ensures that the array of an option, which a value must be picked from,
     * is not empty.
     *
     * @param array the array (accessed reflectively) to check
     * @param option the name of the option, used in the error message
     * @throws IllegalArgumentException if the array is empty
     */
    private static void requireNonEmpty(Object array, String option) {
        if (Array.getLength(array) == 0) {
            throw new IllegalArgumentException("Option '" + option + "' must list at least one value to choose from!");
        }
    }
}

