/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.node.ArrayNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Configuration of a multi-layer network.
 *
 * <p>Holds the ordered per-layer {@link NeuralNetConfiguration}s and the {@link InputPreProcessor}s,
 * keyed by the index of the layer whose input they transform. It also holds the network-wide
 * training settings: {@code pretrain}, {@code backprop}, {@link BackpropType} and the
 * {@code tbpttFwdLength}/{@code tbpttBackLength} lengths.
 *
 * <p>Instances are normally created through {@link Builder}. They can be serialized with
 * {@link #toJson()}/{@link #fromJson(String)} and {@link #toYaml()}/{@link #fromYaml(String)}.
 */
public class MultiLayerConfiguration implements Serializable, Cloneable {
    private static final Logger log = LoggerFactory.getLogger(MultiLayerConfiguration.class);

    /**
     * Set once the default-change warnings from {@code Builder.validate()} have been logged.
     * Check-then-set is not atomic, so concurrent builds may still log them more than once.
     */
    private static final AtomicBoolean defaultChangeWarningPrinted = new AtomicBoolean(false);

    protected List<NeuralNetConfiguration> confs;
    protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<Integer, InputPreProcessor>();
    protected boolean pretrain = false;
    protected boolean backprop = true;
    protected BackpropType backpropType = BackpropType.Standard;
    protected int tbpttFwdLength = 20;
    protected int tbpttBackLength = 20;

    /**
     * Serializes this configuration to YAML using {@link NeuralNetConfiguration#mapperYaml()}.
     *
     * @return the YAML representation
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
     */
    public String toYaml() {
        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
        try {
            return mapper.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserializes a configuration from YAML using {@link NeuralNetConfiguration#mapperYaml()}.
     *
     * <p>Unlike {@link #fromJson(String)}, no legacy loss-function patching is performed.
     *
     * @param yaml YAML text to parse
     * @return the deserialized configuration
     * @throws RuntimeException wrapping the {@link IOException} if parsing fails
     */
    public static MultiLayerConfiguration fromYaml(String yaml) {
        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
        try {
            return mapper.readValue(yaml, MultiLayerConfiguration.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Serializes this configuration to JSON using {@link NeuralNetConfiguration#mapper()}.
     *
     * @return the JSON representation
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
     */
    public String toJson() {
        ObjectMapper mapper = NeuralNetConfiguration.mapper();
        try {
            return mapper.writeValueAsString(this);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserializes a configuration from JSON using {@link NeuralNetConfiguration#mapper()}.
     *
     * <p>This method provides backward compatibility with older configurations; the log messages
     * refer to them as "pre-0.6.0". Each {@link BaseOutputLayer} whose loss function is
     * {@code null} after deserialization is patched as follows:
     * <ul>
     *   <li>The raw JSON is searched for the enum-valued {@code lossFunction} field under that
     *       layer's {@code "output"} or {@code "rnnoutput"} node.</li>
     *   <li>{@code MSE}, {@code XENT}, {@code NEGATIVELOGLIKELIHOOD} and {@code MCXENT} are mapped to
     *       {@link LossMSE}, {@link LossBinaryXENT}, {@link LossNegativeLogLikelihood} and
     *       {@link LossMCXENT} respectively.</li>
     *   <li>Any other value, or a raw JSON tree that cannot be walked, is logged as a warning and
     *       the loss function is left {@code null}.</li>
     * </ul>
     *
     * @param json JSON text to parse
     * @return the deserialized (and possibly patched) configuration
     * @throws RuntimeException wrapping the {@link IOException} if the JSON cannot be mapped to a
     *     {@code MultiLayerConfiguration}
     */
    public static MultiLayerConfiguration fromJson(String json) {
        ObjectMapper mapper = NeuralNetConfiguration.mapper();
        MultiLayerConfiguration conf;
        try {
            conf = mapper.readValue(json, MultiLayerConfiguration.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        if (conf.getConfs() == null) {
            return conf;
        }

        // The raw "confs" array is only needed when a legacy output layer is present,
        // so the tree is parsed lazily and at most once.
        JsonNode rawConfs = null;
        boolean rawConfsParsed = false;
        int layerIndex = 0;
        for (NeuralNetConfiguration nnc : conf.getConfs()) {
            Layer layer = nnc.getLayer();
            if (layer instanceof BaseOutputLayer && ((BaseOutputLayer) layer).getLossFn() == null) {
                if (!rawConfsParsed) {
                    try {
                        rawConfs = mapper.readTree(json).get("confs");
                    } catch (IOException e) {
                        log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
                        // The same text would fail again for every later layer: stop patching.
                        break;
                    }
                    rawConfsParsed = true;
                }
                if (rawConfs instanceof ArrayNode) {
                    JsonNode layerConfNode = ((ArrayNode) rawConfs).get(layerIndex);
                    if (layerConfNode == null) {
                        // The raw array is shorter than the deserialized layer list; nothing further can be matched.
                        return conf;
                    }
                    restoreLegacyLossFunction((BaseOutputLayer) layer, layerConfNode.get("layer"));
                } else {
                    log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", rawConfs != null ? rawConfs.getClass() : null);
                }
            }
            layerIndex++;
        }
        return conf;
    }

    /**
     * Sets a class-based loss function on {@code outputLayer}, derived from the legacy enum-valued
     * {@code lossFunction} field of its raw JSON layer node. Does nothing when no such field is
     * present; logs a warning when the value cannot be parsed or mapped.
     *
     * @param outputLayer deserialized output layer whose loss function is currently {@code null}
     * @param layerNode   the {@code "layer"} node of the matching {@code confs} entry; may be {@code null}
     */
    private static void restoreLegacyLossFunction(BaseOutputLayer outputLayer, JsonNode layerNode) {
        if (layerNode == null) {
            return;
        }
        JsonNode lossFunctionNode = null;
        if (layerNode.has("output")) {
            lossFunctionNode = layerNode.get("output").get("lossFunction");
        } else if (layerNode.has("rnnoutput")) {
            lossFunctionNode = layerNode.get("rnnoutput").get("lossFunction");
        }
        if (lossFunctionNode == null) {
            return;
        }

        LossFunctions.LossFunction lossFunction;
        try {
            lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionNode.asText());
        } catch (IllegalArgumentException e) {
            log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", e);
            return;
        }

        ILossFunction restored;
        switch (lossFunction) {
            case MSE:
                restored = new LossMSE();
                break;
            case XENT:
                restored = new LossBinaryXENT();
                break;
            case NEGATIVELOGLIKELIHOOD:
                restored = new LossNegativeLogLikelihood();
                break;
            case MCXENT:
                restored = new LossMCXENT();
                break;
            default:
                log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", lossFunction);
                return;
        }
        outputLayer.setLossFn(restored);
    }

    /** Returns the JSON representation produced by {@link #toJson()}. */
    @Override
    public String toString() {
        return this.toJson();
    }

    /**
     * Returns the configuration of the layer at index {@code i}.
     *
     * @param i zero-based layer index
     * @return the layer configuration
     */
    public NeuralNetConfiguration getConf(int i) {
        return this.confs.get(i);
    }

    /**
     * Returns a copy of this configuration.
     *
     * <p>The layer list and pre-processor map are rebuilt, and each element is copied via its own
     * {@code clone()}. {@code null} elements are kept as {@code null}.
     *
     * @return the copy
     */
    @Override
    public MultiLayerConfiguration clone() {
        try {
            MultiLayerConfiguration clone = (MultiLayerConfiguration) super.clone();
            if (clone.confs != null) {
                List<NeuralNetConfiguration> copiedConfs = new ArrayList<NeuralNetConfiguration>(clone.confs.size());
                for (NeuralNetConfiguration nnc : clone.confs) {
                    copiedConfs.add(nnc == null ? null : nnc.clone());
                }
                clone.confs = copiedConfs;
            }
            if (clone.inputPreProcessors != null) {
                Map<Integer, InputPreProcessor> copiedPreProcessors = new HashMap<Integer, InputPreProcessor>();
                for (Map.Entry<Integer, InputPreProcessor> entry : clone.inputPreProcessors.entrySet()) {
                    InputPreProcessor preProcessor = entry.getValue();
                    copiedPreProcessors.put(entry.getKey(), preProcessor == null ? null : preProcessor.clone());
                }
                clone.inputPreProcessors = copiedPreProcessors;
            }
            return clone;
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the pre-processor applied to the input of layer {@code curr}.
     *
     * @param curr zero-based layer index
     * @return the pre-processor, or {@code null} if none is registered for that layer
     */
    public InputPreProcessor getInputPreProcess(int curr) {
        return this.inputPreProcessors.get(curr);
    }

    /** Returns the per-layer configurations (live list, not a copy). */
    public List<NeuralNetConfiguration> getConfs() {
        return this.confs;
    }

    /** Returns the pre-processors keyed by layer index (live map, not a copy). */
    public Map<Integer, InputPreProcessor> getInputPreProcessors() {
        return this.inputPreProcessors;
    }

    /** Returns whether pretraining is enabled. */
    public boolean isPretrain() {
        return this.pretrain;
    }

    /** Returns whether backprop is enabled. */
    public boolean isBackprop() {
        return this.backprop;
    }

    /** Returns the backprop type. */
    public BackpropType getBackpropType() {
        return this.backpropType;
    }

    /** Returns the value of {@code tbpttFwdLength}. */
    public int getTbpttFwdLength() {
        return this.tbpttFwdLength;
    }

    /** Returns the value of {@code tbpttBackLength}. */
    public int getTbpttBackLength() {
        return this.tbpttBackLength;
    }

    /** Replaces the per-layer configurations (stored by reference). */
    public void setConfs(List<NeuralNetConfiguration> confs) {
        this.confs = confs;
    }

    /** Replaces the pre-processor map (stored by reference). */
    public void setInputPreProcessors(Map<Integer, InputPreProcessor> inputPreProcessors) {
        this.inputPreProcessors = inputPreProcessors;
    }

    /** Enables or disables pretraining. */
    public void setPretrain(boolean pretrain) {
        this.pretrain = pretrain;
    }

    /** Enables or disables backprop. */
    public void setBackprop(boolean backprop) {
        this.backprop = backprop;
    }

    /** Sets the backprop type. */
    public void setBackpropType(BackpropType backpropType) {
        this.backpropType = backpropType;
    }

    /** Sets {@code tbpttFwdLength}. */
    public void setTbpttFwdLength(int tbpttFwdLength) {
        this.tbpttFwdLength = tbpttFwdLength;
    }

    /** Sets {@code tbpttBackLength}. */
    public void setTbpttBackLength(int tbpttBackLength) {
        this.tbpttBackLength = tbpttBackLength;
    }

    /**
     * Compares every field, read through the getters, for equality. Uses {@link #canEqual(Object)}
     * so that the comparison stays symmetric with subclasses.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MultiLayerConfiguration)) {
            return false;
        }
        MultiLayerConfiguration other = (MultiLayerConfiguration) o;
        return other.canEqual(this)
                && Objects.equals(this.getConfs(), other.getConfs())
                && Objects.equals(this.getInputPreProcessors(), other.getInputPreProcessors())
                && this.isPretrain() == other.isPretrain()
                && this.isBackprop() == other.isBackprop()
                && Objects.equals(this.getBackpropType(), other.getBackpropType())
                && this.getTbpttFwdLength() == other.getTbpttFwdLength()
                && this.getTbpttBackLength() == other.getTbpttBackLength();
    }

    /**
     * Returns whether {@code other} may be equal to this instance; subclasses adding state should
     * override this method.
     */
    protected boolean canEqual(Object other) {
        return other instanceof MultiLayerConfiguration;
    }

    /** Hash consistent with {@link #equals(Object)} (prime 59, 43 for null, 79/97 for booleans). */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        List<NeuralNetConfiguration> layerConfs = this.getConfs();
        result = result * prime + (layerConfs == null ? 43 : layerConfs.hashCode());
        Map<Integer, InputPreProcessor> preProcessors = this.getInputPreProcessors();
        result = result * prime + (preProcessors == null ? 43 : preProcessors.hashCode());
        result = result * prime + (this.isPretrain() ? 79 : 97);
        result = result * prime + (this.isBackprop() ? 79 : 97);
        BackpropType type = this.getBackpropType();
        result = result * prime + (type == null ? 43 : type.hashCode());
        result = result * prime + this.getTbpttFwdLength();
        result = result * prime + this.getTbpttBackLength();
        return result;
    }

    /** All-fields constructor. */
    private MultiLayerConfiguration(List<NeuralNetConfiguration> confs, Map<Integer, InputPreProcessor> inputPreProcessors, boolean pretrain, boolean backprop, BackpropType backpropType, int tbpttFwdLength, int tbpttBackLength) {
        this.confs = confs;
        this.inputPreProcessors = inputPreProcessors;
        this.pretrain = pretrain;
        this.backprop = backprop;
        this.backpropType = backpropType;
        this.tbpttFwdLength = tbpttFwdLength;
        this.tbpttBackLength = tbpttBackLength;
    }

    /** Creates an empty configuration with the field defaults (used by deserialization and the builder). */
    public MultiLayerConfiguration() {
    }

    /**
     * Fluent builder for {@link MultiLayerConfiguration}.
     */
    public static class Builder {
        protected List<NeuralNetConfiguration> confs = new ArrayList<NeuralNetConfiguration>();
        /** Stored, compared and printed, but not read by {@link #build()}. */
        protected double dampingFactor = 100.0;
        protected Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<Integer, InputPreProcessor>();
        protected boolean pretrain = false;
        protected boolean backprop = true;
        protected BackpropType backpropType = BackpropType.Standard;
        protected int tbpttFwdLength = 20;
        protected int tbpttBackLength = 20;
        protected InputType inputType;
        /** {height, width, depth} passed to {@link ConvolutionLayerSetup} by {@link #build()}. */
        @Deprecated
        protected int[] cnnInputSize;

        /** Registers {@code processor} for the input of layer {@code layer}, in the current map. */
        public Builder inputPreProcessor(Integer layer, InputPreProcessor processor) {
            this.inputPreProcessors.put(layer, processor);
            return this;
        }

        /** Replaces the whole pre-processor map (stored by reference). */
        public Builder inputPreProcessors(Map<Integer, InputPreProcessor> processors) {
            this.inputPreProcessors = processors;
            return this;
        }

        /** Enables or disables backprop. */
        public Builder backprop(boolean backprop) {
            this.backprop = backprop;
            return this;
        }

        /** Sets the backprop type. */
        public Builder backpropType(BackpropType type) {
            this.backpropType = type;
            return this;
        }

        /** Sets {@code tbpttFwdLength}. */
        public Builder tBPTTForwardLength(int forwardLength) {
            this.tbpttFwdLength = forwardLength;
            return this;
        }

        /** Sets {@code tbpttBackLength}. */
        public Builder tBPTTBackwardLength(int backwardLength) {
            this.tbpttBackLength = backwardLength;
            return this;
        }

        /** Enables or disables pretraining. */
        public Builder pretrain(boolean pretrain) {
            this.pretrain = pretrain;
            return this;
        }

        /** Replaces the per-layer configurations (stored by reference). */
        public Builder confs(List<NeuralNetConfiguration> confs) {
            this.confs = confs;
            return this;
        }

        /** Deprecated: sets the CNN input size; prefer {@link #setInputType(InputType)}. */
        @Deprecated
        public Builder cnnInputSize(int height, int width, int depth) {
            this.cnnInputSize = new int[]{height, width, depth};
            return this;
        }

        /** Deprecated: sets the CNN input size as {height, width, depth}; {@code null} is ignored. */
        @Deprecated
        public Builder cnnInputSize(int[] cnnInputSize) {
            if (cnnInputSize != null) {
                this.cnnInputSize = cnnInputSize;
            }
            return this;
        }

        /** Sets the network input type used to infer pre-processors and {@code nIn} values in {@link #build()}. */
        public Builder setInputType(InputType inputType) {
            this.inputType = inputType;
            return this;
        }

        /**
         * Logs the default-change warnings. The warnings are logged only once, guarded by
         * {@code defaultChangeWarningPrinted}; later calls log nothing.
         */
        private void validate() {
            if (defaultChangeWarningPrinted.get()) {
                return;
            }
            boolean printed = false;
            if (this.pretrain && !this.backprop) {
                log.warn("Warning: pretrain is set to true and if finetune is needed set backprop to true.");
                printed = true;
            } else if (!this.pretrain) {
                log.warn("Warning: new network default sets pretrain to false.");
                printed = true;
            }
            if (this.backprop) {
                log.warn("Warning: new network default sets backprop to true.");
                printed = true;
            }
            if (printed) {
                defaultChangeWarningPrinted.set(true);
            }
        }

        /**
         * Builds the configuration.
         *
         * <p>Steps, in order:
         * <ol>
         *   <li>If the deprecated {@code cnnInputSize} is set, {@link ConvolutionLayerSetup} is
         *       applied to this builder. Otherwise, if neither an {@link InputType} nor a
         *       pre-processor for layer 0 is set, an input type is inferred from the first layer's
         *       positive {@code nIn} (see {@link #inferInputTypeFromFirstLayer(Layer)}).</li>
         *   <li>If an input type is known, it is propagated through the layers:
         *     <ul>
         *       <li>missing pre-processors are added wherever a layer returns one for its input
         *           type;</li>
         *       <li>each layer's {@code nIn} is set via {@code setNIn(type, false)};</li>
         *       <li>each layer's output type becomes the next layer's input type.</li>
         *     </ul>
         *   </li>
         *   <li>If pretraining is enabled, it is also enabled on every {@link BasePretrainNetwork}
         *       layer.</li>
         *   <li>The one-time default-change warnings are logged.</li>
         *   <li>The global ND4J random seed is set to the first layer configuration's seed.</li>
         * </ol>
         *
         * <p>The returned configuration shares, rather than copies, this builder's layer list and
         * pre-processor map. The layer configurations themselves are mutated in place.
         *
         * @return the new configuration
         * @throws IllegalStateException if no layer configurations have been specified
         */
        public MultiLayerConfiguration build() {
            if (this.confs == null || this.confs.isEmpty()) {
                throw new IllegalStateException("Cannot build MultiLayerConfiguration: no layer configurations (confs) have been specified");
            }

            if (this.cnnInputSize != null) {
                new ConvolutionLayerSetup(this, this.cnnInputSize[0], this.cnnInputSize[1], this.cnnInputSize[2]);
            } else if (this.inputType == null && this.inputPreProcessors.get(0) == null) {
                // Stays null when no input type can be inferred.
                this.inputType = inferInputTypeFromFirstLayer(this.confs.get(0).getLayer());
            }

            if (this.inputType != null) {
                InputType currentInputType = this.inputType;
                for (int i = 0; i < this.confs.size(); i++) {
                    Layer layer = this.confs.get(i).getLayer();
                    if (this.inputPreProcessors.get(i) == null) {
                        InputPreProcessor inferred = layer.getPreProcessorForInputType(currentInputType);
                        if (inferred != null) {
                            this.inputPreProcessors.put(i, inferred);
                        }
                    }
                    InputPreProcessor preProcessor = this.inputPreProcessors.get(i);
                    if (preProcessor != null) {
                        currentInputType = preProcessor.getOutputType(currentInputType);
                    }
                    layer.setNIn(currentInputType, false);
                    currentInputType = layer.getOutputType(i, currentInputType);
                }
            }

            if (this.isPretrain()) {
                for (NeuralNetConfiguration nnc : this.confs) {
                    if (nnc.getLayer() instanceof BasePretrainNetwork) {
                        nnc.setPretrain(this.pretrain);
                    }
                }
            }

            MultiLayerConfiguration conf = new MultiLayerConfiguration();
            conf.confs = this.confs;
            conf.pretrain = this.pretrain;
            conf.backprop = this.backprop;
            this.validate();
            conf.inputPreProcessors = this.inputPreProcessors;
            conf.backpropType = this.backpropType;
            conf.tbpttFwdLength = this.tbpttFwdLength;
            conf.tbpttBackLength = this.tbpttBackLength;
            Nd4j.getRandom().setSeed(conf.getConf(0).getSeed());
            return conf;
        }

        /**
         * Infers an input type from the first layer's {@code nIn}, when it is positive.
         *
         * @param firstLayer the first layer
         * @return {@link InputType#recurrent(int)} for a {@link BaseRecurrentLayer};
         *     {@link InputType#feedForward(int)} for a {@link DenseLayer}, {@link EmbeddingLayer}
         *     or {@link OutputLayer}; otherwise, or when {@code nIn <= 0}, {@code null}
         */
        private static InputType inferInputTypeFromFirstLayer(Layer firstLayer) {
            if (firstLayer instanceof BaseRecurrentLayer) {
                int nIn = ((BaseRecurrentLayer) firstLayer).getNIn();
                return nIn > 0 ? InputType.recurrent(nIn) : null;
            }
            if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer || firstLayer instanceof OutputLayer) {
                int nIn = ((FeedForwardLayer) firstLayer).getNIn();
                return nIn > 0 ? InputType.feedForward(nIn) : null;
            }
            return null;
        }

        /** Returns the per-layer configurations (live list). */
        public List<NeuralNetConfiguration> getConfs() {
            return this.confs;
        }

        /** Returns the damping factor. */
        public double getDampingFactor() {
            return this.dampingFactor;
        }

        /** Returns the pre-processor map (live map). */
        public Map<Integer, InputPreProcessor> getInputPreProcessors() {
            return this.inputPreProcessors;
        }

        /** Returns whether pretraining is enabled. */
        public boolean isPretrain() {
            return this.pretrain;
        }

        /** Returns whether backprop is enabled. */
        public boolean isBackprop() {
            return this.backprop;
        }

        /** Returns the backprop type. */
        public BackpropType getBackpropType() {
            return this.backpropType;
        }

        /** Returns {@code tbpttFwdLength}. */
        public int getTbpttFwdLength() {
            return this.tbpttFwdLength;
        }

        /** Returns {@code tbpttBackLength}. */
        public int getTbpttBackLength() {
            return this.tbpttBackLength;
        }

        /** Returns the configured or inferred input type, possibly {@code null}. */
        public InputType getInputType() {
            return this.inputType;
        }

        /** Deprecated: returns the CNN input size array, possibly {@code null}. */
        @Deprecated
        public int[] getCnnInputSize() {
            return this.cnnInputSize;
        }

        /** Replaces the per-layer configurations (stored by reference). */
        public void setConfs(List<NeuralNetConfiguration> confs) {
            this.confs = confs;
        }

        /** Sets the damping factor. */
        public void setDampingFactor(double dampingFactor) {
            this.dampingFactor = dampingFactor;
        }

        /** Replaces the pre-processor map (stored by reference). */
        public void setInputPreProcessors(Map<Integer, InputPreProcessor> inputPreProcessors) {
            this.inputPreProcessors = inputPreProcessors;
        }

        /** Enables or disables pretraining. */
        public void setPretrain(boolean pretrain) {
            this.pretrain = pretrain;
        }

        /** Enables or disables backprop. */
        public void setBackprop(boolean backprop) {
            this.backprop = backprop;
        }

        /** Sets the backprop type. */
        public void setBackpropType(BackpropType backpropType) {
            this.backpropType = backpropType;
        }

        /** Sets {@code tbpttFwdLength}. */
        public void setTbpttFwdLength(int tbpttFwdLength) {
            this.tbpttFwdLength = tbpttFwdLength;
        }

        /** Sets {@code tbpttBackLength}. */
        public void setTbpttBackLength(int tbpttBackLength) {
            this.tbpttBackLength = tbpttBackLength;
        }

        /** Deprecated: sets the CNN input size array (stored by reference, {@code null} allowed). */
        @Deprecated
        public void setCnnInputSize(int[] cnnInputSize) {
            this.cnnInputSize = cnnInputSize;
        }

        /** Compares every builder field, read through the getters, for equality. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Builder)) {
                return false;
            }
            Builder other = (Builder) o;
            return other.canEqual(this)
                    && Objects.equals(this.getConfs(), other.getConfs())
                    && Double.compare(this.getDampingFactor(), other.getDampingFactor()) == 0
                    && Objects.equals(this.getInputPreProcessors(), other.getInputPreProcessors())
                    && this.isPretrain() == other.isPretrain()
                    && this.isBackprop() == other.isBackprop()
                    && Objects.equals(this.getBackpropType(), other.getBackpropType())
                    && this.getTbpttFwdLength() == other.getTbpttFwdLength()
                    && this.getTbpttBackLength() == other.getTbpttBackLength()
                    && Objects.equals(this.getInputType(), other.getInputType())
                    && Arrays.equals(this.getCnnInputSize(), other.getCnnInputSize());
        }

        /** Returns whether {@code other} may be equal to this builder. */
        protected boolean canEqual(Object other) {
            return other instanceof Builder;
        }

        /** Hash consistent with {@link #equals(Object)} (prime 59, 43 for null, 79/97 for booleans). */
        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            List<NeuralNetConfiguration> layerConfs = this.getConfs();
            result = result * prime + (layerConfs == null ? 43 : layerConfs.hashCode());
            long dampingBits = Double.doubleToLongBits(this.getDampingFactor());
            result = result * prime + (int) (dampingBits ^ (dampingBits >>> 32));
            Map<Integer, InputPreProcessor> preProcessors = this.getInputPreProcessors();
            result = result * prime + (preProcessors == null ? 43 : preProcessors.hashCode());
            result = result * prime + (this.isPretrain() ? 79 : 97);
            result = result * prime + (this.isBackprop() ? 79 : 97);
            BackpropType type = this.getBackpropType();
            result = result * prime + (type == null ? 43 : type.hashCode());
            result = result * prime + this.getTbpttFwdLength();
            result = result * prime + this.getTbpttBackLength();
            InputType input = this.getInputType();
            result = result * prime + (input == null ? 43 : input.hashCode());
            result = result * prime + Arrays.hashCode(this.getCnnInputSize());
            return result;
        }

        /** Returns a debug representation listing every builder field. */
        @Override
        public String toString() {
            return "MultiLayerConfiguration.Builder(confs=" + this.getConfs() + ", dampingFactor=" + this.getDampingFactor() + ", inputPreProcessors=" + this.getInputPreProcessors() + ", pretrain=" + this.isPretrain() + ", backprop=" + this.isBackprop() + ", backpropType=" + this.getBackpropType() + ", tbpttFwdLength=" + this.getTbpttFwdLength() + ", tbpttBackLength=" + this.getTbpttBackLength() + ", inputType=" + this.getInputType() + ", cnnInputSize=" + Arrays.toString(this.getCnnInputSize()) + ")";
        }
    }
}

