/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.ui.module.train;

import com.fasterxml.jackson.annotation.JsonIgnore;
import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;

public class TrainModuleUtils {
    /**
     * Builds the graph description for a sequential (multi-layer) configuration.
     *
     * <p>Position 0 of every returned list is a synthetic {@code "Input"} vertex. Each configured
     * layer follows in list order and has the immediately preceding vertex as its only input.
     * Layers without a name are labelled {@code "layer" + position} (1-based, because the input
     * vertex occupies position 0), while the original vertex name recorded for a layer is its
     * 0-based index in the configuration rendered as a string.
     *
     * @param config multi-layer configuration whose per-layer configurations are walked in order
     * @return the collected vertex names, types, input indexes, info maps and original names
     */
    public static GraphInfo buildGraphInfo(MultiLayerConfiguration config) {
        List<String> vertexNames = new ArrayList<String>();
        List<String> originalVertexName = new ArrayList<String>();
        List<String> layerTypes = new ArrayList<String>();
        List<List<Integer>> layerInputs = new ArrayList<List<Integer>>();
        List<Map<String, String>> layerInfo = new ArrayList<Map<String, String>>();

        // Synthetic input vertex: no original name, no inputs, no details.
        vertexNames.add("Input");
        originalVertexName.add(null);
        layerTypes.add("Input");
        layerInputs.add(Collections.<Integer>emptyList());
        layerInfo.add(Collections.<String, String>emptyMap());

        // Typed list (not a raw List) so the enhanced-for below type-checks.
        List<NeuralNetConfiguration> confs = config.getConfs();
        int layerIdx = 1;
        for (NeuralNetConfiguration c : confs) {
            Layer layer = c.getLayer();
            String layerName = layer.getLayerName();
            if (layerName == null) {
                layerName = "layer" + layerIdx;
            }
            vertexNames.add(layerName);
            originalVertexName.add(String.valueOf(layerIdx - 1));
            // Strip a trailing "Layer" from the class name, e.g. "DenseLayer" -> "Dense".
            layerTypes.add(layer.getClass().getSimpleName().replaceAll("Layer$", ""));
            // Each layer is fed by the vertex directly before it.
            layerInputs.add(Collections.singletonList(layerIdx - 1));
            layerInfo.add(TrainModuleUtils.getLayerInfo(c, layer));
            ++layerIdx;
        }
        return new GraphInfo(vertexNames, layerTypes, layerInputs, layerInfo, originalVertexName);
    }

    /**
     * Builds the graph description for a computation-graph configuration.
     *
     * <p>Network inputs are emitted first (their name doubles as their type, with no inputs and
     * no details), followed by every configured vertex in the iteration order of
     * {@code config.getVertices()}. Input references are translated from vertex names into
     * positions within the returned lists.
     *
     * @param config computation-graph configuration supplying vertices, vertex inputs and
     *               network input names
     * @return the collected vertex names, types, input indexes, info maps and original names
     */
    public static GraphInfo buildGraphInfo(ComputationGraphConfiguration config) {
        List<String> layerNames = new ArrayList<String>();
        List<String> layerTypes = new ArrayList<String>();
        List<List<Integer>> layerInputs = new ArrayList<List<Integer>>();
        List<Map<String, String>> layerInfo = new ArrayList<Map<String, String>>();
        List<String> originalVertexName = new ArrayList<String>();

        // Typed views (not raw collections) so iteration below type-checks without casts.
        Map<String, GraphVertex> vertices = config.getVertices();
        Map<String, List<String>> vertexInputs = config.getVertexInputs();
        List<String> networkInputs = config.getNetworkInputs();

        // Maps a vertex/input name to its position in the output lists.
        Map<String, Integer> vertexToIndexMap = new HashMap<String, Integer>();
        int vertexCount = 0;
        for (String inputName : networkInputs) {
            vertexToIndexMap.put(inputName, vertexCount++);
            layerNames.add(inputName);
            originalVertexName.add(inputName);
            layerTypes.add(inputName);
            layerInputs.add(Collections.<Integer>emptyList());
            layerInfo.add(Collections.<String, String>emptyMap());
        }
        // Index every vertex up front so a vertex may reference one that is iterated later.
        for (String vertexName : vertices.keySet()) {
            vertexToIndexMap.put(vertexName, vertexCount++);
        }

        for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
            String vertexName = entry.getKey();
            GraphVertex gv = entry.getValue();
            layerNames.add(vertexName);

            List<String> inputsThisVertex = vertexInputs.get(vertexName);
            List<Integer> inputIndexes = new ArrayList<Integer>();
            for (String inputName : inputsThisVertex) {
                inputIndexes.add(vertexToIndexMap.get(inputName));
            }
            layerInputs.add(inputIndexes);

            if (gv instanceof LayerVertex) {
                // Layer-backed vertex: report the layer's type and details.
                NeuralNetConfiguration c = ((LayerVertex) gv).getLayerConf();
                Layer layer = c.getLayer();
                layerTypes.add(layer.getClass().getSimpleName().replaceAll("Layer$", ""));
                layerInfo.add(TrainModuleUtils.getLayerInfo(c, layer));
            } else {
                // Any other vertex: report the vertex class name and no details.
                layerTypes.add(gv.getClass().getSimpleName());
                layerInfo.add(Collections.<String, String>emptyMap());
            }
            originalVertexName.add(vertexName);
        }
        return new GraphInfo(layerNames, layerTypes, layerInputs, layerInfo, originalVertexName);
    }

    /**
     * Builds the graph description for a single-layer configuration.
     *
     * <p>Position 0 is always a synthetic {@code "Input"} vertex. If the configured layer is a
     * {@link VariationalAutoencoder}, it is expanded into a chain of vertices: one per encoder
     * layer, the latent variable {@code "z"}, one per decoder layer, and the reconstruction
     * {@code "x"}; each vertex takes the previous one as its only input. Any other layer becomes
     * a single vertex fed by the input vertex.
     *
     * @param config configuration holding the single layer to describe
     * @return the collected vertex names, types, input indexes, info maps and original names
     */
    public static GraphInfo buildGraphInfo(NeuralNetConfiguration config) {
        List<String> vertexNames = new ArrayList<String>();
        List<String> originalVertexName = new ArrayList<String>();
        List<String> layerTypes = new ArrayList<String>();
        List<List<Integer>> layerInputs = new ArrayList<List<Integer>>();
        List<Map<String, String>> layerInfo = new ArrayList<Map<String, String>>();

        // Synthetic input vertex: no original name, no inputs, no details.
        vertexNames.add("Input");
        originalVertexName.add(null);
        layerTypes.add("Input");
        layerInputs.add(Collections.<Integer>emptyList());
        layerInfo.add(Collections.<String, String>emptyMap());

        Layer layer = config.getLayer();
        if (layer instanceof VariationalAutoencoder) {
            VariationalAutoencoder va = (VariationalAutoencoder) layer;
            int[] encLayerSizes = va.getEncoderLayerSizes();
            int[] decLayerSizes = va.getDecoderLayerSizes();
            // Position the next vertex will occupy; its input is always layerIndex - 1.
            int layerIndex = 1;

            for (int i = 0; i < encLayerSizes.length; ++i) {
                vertexNames.add("encoder_" + i);
                originalVertexName.add("e" + i);
                layerTypes.add("VAE-Encoder");
                layerInputs.add(Collections.singletonList(layerIndex - 1));
                ++layerIndex;

                // First encoder layer reads the network input; later ones read the previous layer.
                int encIn = i == 0 ? va.getNIn() : encLayerSizes[i - 1];
                int encOut = encLayerSizes[i];
                Map<String, String> encoderInfo = new LinkedHashMap<String, String>();
                encoderInfo.put("Input Size", String.valueOf(encIn));
                encoderInfo.put("Layer Size", String.valueOf(encOut));
                // (inputs + 1) * outputs: the +1 accounts for one bias per output unit.
                encoderInfo.put("Num Parameters", String.valueOf((encIn + 1) * encOut));
                encoderInfo.put("Activation Function", va.getActivationFn().toString());
                layerInfo.add(encoderInfo);
            }

            vertexNames.add("z");
            originalVertexName.add("pZX");
            layerTypes.add("VAE-LatentVariable");
            layerInputs.add(Collections.singletonList(layerIndex - 1));
            ++layerIndex;
            int latentIn = encLayerSizes[encLayerSizes.length - 1];
            int latentOut = va.getNOut();
            Map<String, String> latentInfo = new LinkedHashMap<String, String>();
            latentInfo.put("Input Size", String.valueOf(latentIn));
            latentInfo.put("Layer Size", String.valueOf(latentOut));
            // Twice the dense-layer count, as in the original computation.
            latentInfo.put("Num Parameters", String.valueOf((latentIn + 1) * latentOut * 2));
            latentInfo.put("Activation Function", va.getPzxActivationFn().toString());
            layerInfo.add(latentInfo);

            for (int i = 0; i < decLayerSizes.length; ++i) {
                vertexNames.add("decoder_" + i);
                // Decoder vertices use a "d" prefix so they do not collide with the encoder's
                // "e" + i entries. NOTE(review): presumably matches the VAE decoder parameter
                // key prefix - confirm against the parameter initializer.
                originalVertexName.add("d" + i);
                layerTypes.add("VAE-Decoder");
                layerInputs.add(Collections.singletonList(layerIndex - 1));
                ++layerIndex;

                // First decoder layer reads the latent variable; later ones the previous layer.
                int decIn = i == 0 ? va.getNOut() : decLayerSizes[i - 1];
                // Size comes from the decoder array (indexing the encoder array here reported
                // wrong sizes and could run past its end).
                int decOut = decLayerSizes[i];
                Map<String, String> decoderInfo = new LinkedHashMap<String, String>();
                decoderInfo.put("Input Size", String.valueOf(decIn));
                decoderInfo.put("Layer Size", String.valueOf(decOut));
                decoderInfo.put("Num Parameters", String.valueOf((decIn + 1) * decOut));
                decoderInfo.put("Activation Function", va.getActivationFn().toString());
                layerInfo.add(decoderInfo);
            }

            vertexNames.add("x");
            originalVertexName.add("pXZ");
            layerTypes.add("VAE-Reconstruction");
            layerInputs.add(Collections.singletonList(layerIndex - 1));
            int reconIn = decLayerSizes[decLayerSizes.length - 1];
            int reconOut = va.getNIn();
            Map<String, String> reconstructionInfo = new LinkedHashMap<String, String>();
            reconstructionInfo.put("Input Size", String.valueOf(reconIn));
            reconstructionInfo.put("Layer Size", String.valueOf(reconOut));
            // Output width is whatever the reconstruction distribution needs for nIn values.
            reconstructionInfo.put("Num Parameters", String.valueOf(
                    (reconIn + 1) * va.getOutputDistribution().distributionInputSize(va.getNIn())));
            reconstructionInfo.put("Distribution", va.getOutputDistribution().toString());
            layerInfo.add(reconstructionInfo);
        } else {
            String layerName = layer.getLayerName();
            if (layerName == null) {
                layerName = "layer0";
            }
            vertexNames.add(layerName);
            originalVertexName.add("0");
            // Strip a trailing "Layer" from the class name, e.g. "DenseLayer" -> "Dense".
            layerTypes.add(layer.getClass().getSimpleName().replaceAll("Layer$", ""));
            layerInputs.add(Collections.singletonList(0));
            layerInfo.add(TrainModuleUtils.getLayerInfo(config, layer));
        }
        return new GraphInfo(vertexNames, layerTypes, layerInputs, layerInfo, originalVertexName);
    }

    /**
     * Collects human-readable details about a layer as an insertion-ordered map.
     *
     * <p>Feed-forward layers contribute input/output size, parameter count and activation
     * function. On top of that, exactly one of the following may add entries: convolution layers
     * (kernel size, stride, padding), subsampling layers (the same plus pooling type), or output
     * layers (loss function). Layers matching none of these yield an empty map.
     *
     * @param c     configuration passed to the layer's parameter initializer to count parameters
     * @param layer layer to describe
     * @return label-to-value map in insertion order; never {@code null}
     */
    private static Map<String, String> getLayerInfo(NeuralNetConfiguration c, Layer layer) {
        Map<String, String> map = new LinkedHashMap<String, String>();

        if (layer instanceof FeedForwardLayer) {
            FeedForwardLayer feedForward = (FeedForwardLayer) layer;
            map.put("Input size", String.valueOf(feedForward.getNIn()));
            map.put("Output size", String.valueOf(feedForward.getNOut()));
            map.put("Num Parameters", String.valueOf(feedForward.initializer().numParams(c)));
            map.put("Activation Function", feedForward.getActivationFn().toString());
        }

        // Each branch uses its own correctly typed local; a single FeedForwardLayer variable
        // cannot expose the convolution/subsampling accessors.
        if (layer instanceof ConvolutionLayer) {
            ConvolutionLayer convolution = (ConvolutionLayer) layer;
            map.put("Kernel size", Arrays.toString(convolution.getKernelSize()));
            map.put("Stride", Arrays.toString(convolution.getStride()));
            map.put("Padding", Arrays.toString(convolution.getPadding()));
        } else if (layer instanceof SubsamplingLayer) {
            SubsamplingLayer subsampling = (SubsamplingLayer) layer;
            map.put("Kernel size", Arrays.toString(subsampling.getKernelSize()));
            map.put("Stride", Arrays.toString(subsampling.getStride()));
            map.put("Padding", Arrays.toString(subsampling.getPadding()));
            map.put("Pooling Type", subsampling.getPoolingType().toString());
        } else if (layer instanceof BaseOutputLayer) {
            BaseOutputLayer outputLayer = (BaseOutputLayer) layer;
            map.put("Loss Function", outputLayer.getLossFn().toString());
        }
        return map;
    }

    /**
     * Mutable holder for a network graph description made of five lists: vertex names, vertex
     * types, per-vertex input indexes, per-vertex detail maps, and original vertex names.
     *
     * <p>The {@code originalVertexName} field is annotated with {@code @JsonIgnore}. Equality,
     * hashing and the string form all take the five fields into account, read through their
     * getters.
     */
    public static class GraphInfo {
        private List<String> vertexNames;
        private List<String> vertexTypes;
        private List<List<Integer>> vertexInputs;
        private List<Map<String, String>> vertexInfo;
        @JsonIgnore
        private List<String> originalVertexName;

        /**
         * Creates a graph description; the given lists are stored as-is (no copies are made).
         *
         * @param vertexNames        names of the vertices
         * @param vertexTypes        type labels of the vertices
         * @param vertexInputs       input indexes for each vertex
         * @param vertexInfo         detail map for each vertex
         * @param originalVertexName original names of the vertices
         */
        @ConstructorProperties({"vertexNames", "vertexTypes", "vertexInputs", "vertexInfo", "originalVertexName"})
        public GraphInfo(List<String> vertexNames, List<String> vertexTypes, List<List<Integer>> vertexInputs, List<Map<String, String>> vertexInfo, List<String> originalVertexName) {
            this.vertexNames = vertexNames;
            this.vertexTypes = vertexTypes;
            this.vertexInputs = vertexInputs;
            this.vertexInfo = vertexInfo;
            this.originalVertexName = originalVertexName;
        }

        /** @return the vertex names */
        public List<String> getVertexNames() {
            return vertexNames;
        }

        /** @param vertexNames the vertex names to store */
        public void setVertexNames(List<String> vertexNames) {
            this.vertexNames = vertexNames;
        }

        /** @return the vertex type labels */
        public List<String> getVertexTypes() {
            return vertexTypes;
        }

        /** @param vertexTypes the vertex type labels to store */
        public void setVertexTypes(List<String> vertexTypes) {
            this.vertexTypes = vertexTypes;
        }

        /** @return the input indexes of each vertex */
        public List<List<Integer>> getVertexInputs() {
            return vertexInputs;
        }

        /** @param vertexInputs the input indexes of each vertex to store */
        public void setVertexInputs(List<List<Integer>> vertexInputs) {
            this.vertexInputs = vertexInputs;
        }

        /** @return the detail map of each vertex */
        public List<Map<String, String>> getVertexInfo() {
            return vertexInfo;
        }

        /** @param vertexInfo the detail map of each vertex to store */
        public void setVertexInfo(List<Map<String, String>> vertexInfo) {
            this.vertexInfo = vertexInfo;
        }

        /** @return the original vertex names */
        public List<String> getOriginalVertexName() {
            return originalVertexName;
        }

        /** @param originalVertexName the original vertex names to store */
        public void setOriginalVertexName(List<String> originalVertexName) {
            this.originalVertexName = originalVertexName;
        }

        /**
         * Compares all five fields, via their getters, against another {@code GraphInfo}.
         *
         * @param o object to compare with
         * @return {@code true} if {@code o} is a {@code GraphInfo} that accepts this instance via
         *         {@link #canEqual(Object)} and every field is equal (or both {@code null})
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof GraphInfo)) {
                return false;
            }
            GraphInfo that = (GraphInfo) o;
            if (!that.canEqual(this)) {
                return false;
            }
            // Short-circuiting keeps the field-by-field evaluation order.
            return sameValue(this.getVertexNames(), that.getVertexNames())
                    && sameValue(this.getVertexTypes(), that.getVertexTypes())
                    && sameValue(this.getVertexInputs(), that.getVertexInputs())
                    && sameValue(this.getVertexInfo(), that.getVertexInfo())
                    && sameValue(this.getOriginalVertexName(), that.getOriginalVertexName());
        }

        /**
         * Hook letting subclasses veto equality with instances of other types.
         *
         * @param other candidate object
         * @return {@code true} if {@code other} is a {@code GraphInfo}
         */
        protected boolean canEqual(Object other) {
            return other instanceof GraphInfo;
        }

        /**
         * Combines the five fields with multiplier 59, using 43 in place of a {@code null} field.
         *
         * @return the hash code
         */
        @Override
        public int hashCode() {
            int hash = 1;
            hash = hash * 59 + fieldHash(this.getVertexNames());
            hash = hash * 59 + fieldHash(this.getVertexTypes());
            hash = hash * 59 + fieldHash(this.getVertexInputs());
            hash = hash * 59 + fieldHash(this.getVertexInfo());
            hash = hash * 59 + fieldHash(this.getOriginalVertexName());
            return hash;
        }

        /** @return a string listing all five fields in declaration order */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("TrainModuleUtils.GraphInfo(vertexNames=");
            text.append(this.getVertexNames());
            text.append(", vertexTypes=").append(this.getVertexTypes());
            text.append(", vertexInputs=").append(this.getVertexInputs());
            text.append(", vertexInfo=").append(this.getVertexInfo());
            text.append(", originalVertexName=").append(this.getOriginalVertexName());
            return text.append(")").toString();
        }

        /** Null-safe equality: two nulls are equal, otherwise delegates to {@code a.equals(b)}. */
        private static boolean sameValue(Object a, Object b) {
            if (a == null) {
                return b == null;
            }
            return a.equals(b);
        }

        /** Hash of a single field; {@code null} contributes the constant 43. */
        private static int fieldHash(Object value) {
            if (value == null) {
                return 43;
            }
            return value.hashCode();
        }
    }
}

