/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.ui.flow;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.deeplearning4j.ui.UiUtils;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.flow.beans.Description;
import org.deeplearning4j.ui.flow.beans.LayerInfo;
import org.deeplearning4j.ui.flow.beans.LayerParams;
import org.deeplearning4j.ui.flow.beans.ModelInfo;
import org.deeplearning4j.ui.flow.beans.ModelState;
import org.deeplearning4j.ui.flow.data.FlowStaticPersistable;
import org.deeplearning4j.ui.flow.data.FlowUpdatePersistable;
import org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage;
import org.deeplearning4j.ui.weights.HistogramBin;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Deprecated
public class FlowIterationListener
implements IterationListener {
    private static final Logger log = LoggerFactory.getLogger(FlowIterationListener.class);
    private static final String FORMAT = "%02d:%02d:%02d";
    public static final String INPUT = "INPUT";
    private int frequency = 1;
    private boolean firstIteration = true;
    private ModelState modelState = new ModelState();
    private AtomicLong iterationCount = new AtomicLong(0L);
    private long lastTime = System.currentTimeMillis();
    private long currTime;
    private long initTime = System.currentTimeMillis();
    private static final List<String> colors = Collections.unmodifiableList(Arrays.asList("#9966ff", "#ff9933", "#ffff99", "#3366ff", "#0099cc", "#669999", "#66ffff"));
    private final StatsStorageRouter ssr;
    private final String sessionID;
    private final String workerID;
    private boolean openBrowser;

    protected FlowIterationListener() {
        this(1);
    }

    public FlowIterationListener(int frequency) {
        this((StatsStorageRouter)new MapDBStatsStorage(), frequency, null, null, true);
    }

    @Deprecated
    public FlowIterationListener(@NonNull String address, int port, int frequency) {
        this(frequency);
        if (address == null) {
            throw new NullPointerException("address");
        }
    }

    public FlowIterationListener(StatsStorageRouter ssr, int frequency, String sessionID, String workerID, boolean openBrowser) {
        this.frequency = frequency;
        this.ssr = ssr;
        this.sessionID = sessionID == null ? UUID.randomUUID().toString() : sessionID;
        this.workerID = workerID == null ? UIDProvider.getJVMUID() + "_" + Thread.currentThread().getId() : workerID;
        this.openBrowser = openBrowser;
        if (ssr instanceof StatsStorage && openBrowser) {
            UIServer.getInstance().attach((StatsStorage)ssr);
        }
        System.out.println("FlowIterationListener path: http://localhost:" + UIServer.getInstance().getPort() + "/flow");
    }

    @Deprecated
    public FlowIterationListener(@NonNull UiConnectionInfo connectionInfo, int frequency) {
        this(frequency);
        if (connectionInfo == null) {
            throw new NullPointerException("connectionInfo");
        }
    }

    public boolean invoked() {
        return false;
    }

    public void invoke() {
    }

    public synchronized void iterationDone(Model model, int iteration) {
        if (this.iterationCount.incrementAndGet() % (long)this.frequency == 0L) {
            this.currTime = System.currentTimeMillis();
            if (this.firstIteration) {
                ModelInfo info = this.buildModelInfo(model);
                FlowStaticPersistable staticInfo = new FlowStaticPersistable(this.sessionID, this.workerID, System.currentTimeMillis(), info);
                this.ssr.putStaticInfo((Persistable)staticInfo);
            }
            this.buildModelState(model);
            FlowUpdatePersistable updateInfo = new FlowUpdatePersistable(this.sessionID, this.workerID, System.currentTimeMillis(), this.modelState);
            this.ssr.putUpdate((Persistable)updateInfo);
            if (this.firstIteration && this.openBrowser) {
                UIServer uiServer = UIServer.getInstance();
                String path = "http://localhost:" + uiServer.getPort() + "/flow?sid=" + this.sessionID;
                try {
                    UiUtils.tryOpenBrowser(path, log);
                }
                catch (Exception exception) {
                    // empty catch block
                }
                this.firstIteration = false;
            }
        }
        this.lastTime = System.currentTimeMillis();
    }

    protected List<LayerInfo> flattenToY(ModelInfo model, GraphVertex[] vertices, List<String> currentInput, int currentY) {
        ArrayList<LayerInfo> results = new ArrayList<LayerInfo>();
        int x = 0;
        for (int v = 0; v < vertices.length; ++v) {
            GraphVertex vertex = vertices[v];
            VertexIndices[] indices = vertex.getInputVertices();
            if (indices == null) continue;
            for (int i = 0; i < indices.length; ++i) {
                GraphVertex cv = vertices[indices[i].getVertexIndex()];
                String inputName = cv.getVertexName();
                for (String input : currentInput) {
                    if (!inputName.equals(input)) continue;
                    try {
                        LayerInfo connection;
                        LayerInfo info = model.getLayerInfoByName(vertex.getVertexName());
                        if (info == null) {
                            info = this.getLayerInfo(vertex.getLayer(), x, currentY, 121);
                        }
                        info.setName(vertex.getVertexName());
                        if (vertex.getLayer() == null) {
                            info.setLayerType(vertex.getClass().getSimpleName());
                        }
                        if (info.getName().endsWith("-merge")) {
                            info.setLayerType("MERGE");
                        }
                        if (model.getLayerInfoByName(vertex.getVertexName()) == null) {
                            ++x;
                            model.addLayer(info);
                            results.add(info);
                        }
                        if ((connection = model.getLayerInfoByName(input)) == null) continue;
                        connection.addConnection(info);
                    }
                    catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }
        }
        return results;
    }

    protected void buildModelState(Model model) {
        long timeSpent = this.currTime - this.lastTime;
        float timeSec = (float)timeSpent / 1000.0f;
        INDArray input = model.input();
        long tadLength = Shape.getTADLength((int[])input.shape(), (int[])ArrayUtil.range((int)1, (int)input.rank()));
        long numSamples = input.lengthLong() / tadLength;
        this.modelState.addPerformanceSamples((float)numSamples / timeSec);
        this.modelState.addPerformanceBatches(1.0f / timeSec);
        this.modelState.setIterationTime(timeSpent);
        this.modelState.addScore((float)model.score());
        this.modelState.setScore((float)model.score());
        this.modelState.setTrainingTime(this.parseTime(System.currentTimeMillis() - this.initTime));
        LinkedHashMap newGrad = new LinkedHashMap();
        LinkedHashMap newParams = new LinkedHashMap();
        Map params = model.paramTable();
        Layer[] layers = null;
        if (model instanceof MultiLayerNetwork) {
            layers = ((MultiLayerNetwork)model).getLayers();
        } else if (model instanceof ComputationGraph) {
            layers = ((ComputationGraph)model).getLayers();
        }
        ArrayList<Double> lrs = new ArrayList<Double>();
        if (layers != null) {
            for (Layer layer : layers) {
                lrs.add(layer.conf().getLayer().getLearningRate());
            }
            this.modelState.setLearningRates(lrs);
        }
        LinkedHashMap<Integer, LayerParams> layerParamsMap = new LinkedHashMap<Integer, LayerParams>();
        for (Map.Entry entry : params.entrySet()) {
            String param = (String)entry.getKey();
            if (!Character.isDigit(param.charAt(0))) continue;
            int layer = Integer.parseInt(param.replaceAll("\\_.*$", ""));
            String key = param.replaceAll("^.*?_", "").toLowerCase();
            if (!layerParamsMap.containsKey(layer)) {
                layerParamsMap.put(layer, new LayerParams());
            }
            HistogramBin histogram = new HistogramBin.Builder(((INDArray)entry.getValue()).dup()).setBinCount(14).setRounding(6).build();
            if (key.equalsIgnoreCase("w")) {
                ((LayerParams)layerParamsMap.get(layer)).setW(histogram.getData());
                continue;
            }
            if (key.equalsIgnoreCase("rw")) {
                ((LayerParams)layerParamsMap.get(layer)).setRW(histogram.getData());
                continue;
            }
            if (key.equalsIgnoreCase("rwf")) {
                ((LayerParams)layerParamsMap.get(layer)).setRWF(histogram.getData());
                continue;
            }
            if (!key.equalsIgnoreCase("b")) continue;
            ((LayerParams)layerParamsMap.get(layer)).setB(histogram.getData());
        }
        this.modelState.setLayerParams(layerParamsMap);
    }

    protected ModelInfo buildModelInfo(Model model) {
        ModelInfo modelInfo = new ModelInfo();
        if (model instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph)model;
            List inputs = graph.getConfiguration().getNetworkInputs();
            int x = 0;
            for (String input : inputs) {
                long tadLength;
                long numSamples;
                GraphVertex vertex = graph.getVertex(input);
                if (vertex.getInputs() == null || vertex.getInputs().length == 0) {
                    numSamples = 0L;
                    tadLength = 0L;
                } else {
                    INDArray gInput = vertex.getInputs()[0];
                    tadLength = Shape.getTADLength((int[])gInput.shape(), (int[])ArrayUtil.range((int)1, (int)gInput.rank()));
                    numSamples = gInput.lengthLong() / tadLength;
                }
                Layer[] builder = new StringBuilder();
                builder.append("Vertex name: ").append(input).append("<br/>");
                builder.append("Model input").append("<br/>");
                builder.append("Input size: ").append(tadLength).append("<br/>");
                builder.append("Batch size: ").append(numSamples).append("<br/>");
                LayerInfo info = new LayerInfo();
                info.setId(0L);
                info.setName(input);
                info.setY(0);
                info.setX(x);
                info.setLayerType(INPUT);
                info.setDescription(new Description());
                info.getDescription().setMainLine("Model input");
                info.getDescription().setText(builder.toString());
                modelInfo.addLayer(info);
                ++x;
            }
            GraphVertex[] vertices = graph.getVertices();
            ArrayList<String> needle = new ArrayList<String>();
            for (int y = 1; y < vertices.length; ++y) {
                if (needle.isEmpty()) {
                    needle.addAll(inputs);
                }
                List<LayerInfo> layersForGridY = this.flattenToY(modelInfo, vertices, needle, y);
                needle.clear();
                for (LayerInfo layerInfo : layersForGridY) {
                    needle.add(layerInfo.getName());
                }
                if (!needle.isEmpty()) {
                    continue;
                }
                break;
            }
        } else if (model instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork)model;
            INDArray input = model.input();
            long tadLength = Shape.getTADLength((int[])input.shape(), (int[])ArrayUtil.range((int)1, (int)input.rank()));
            long numSamples = input.lengthLong() / tadLength;
            StringBuilder builder = new StringBuilder();
            builder.append("Model input").append("<br/>");
            builder.append("Input size: ").append(tadLength).append("<br/>");
            builder.append("Batch size: ").append(numSamples).append("<br/>");
            LayerInfo info = new LayerInfo();
            info.setId(0L);
            info.setName("Input");
            info.setY(0);
            info.setX(0);
            info.setLayerType(INPUT);
            info.setDescription(new Description());
            info.getDescription().setMainLine("Model input");
            info.getDescription().setText(builder.toString());
            info.addConnection(0, 1);
            modelInfo.addLayer(info);
            int y = 1;
            boolean x = false;
            for (Layer layer : network.getLayers()) {
                LayerInfo layerInfo = this.getLayerInfo(layer, 0, y, y);
                layerInfo.addConnection(0, y + 1);
                modelInfo.addLayer(layerInfo);
                ++y;
            }
            LayerInfo layerInfo = modelInfo.getLayerInfoByCoords(0, y - 1);
            layerInfo.dropConnections();
        }
        for (LayerInfo layerInfo : modelInfo.getLayers()) {
            if (layerInfo.getConnections().size() != 0) continue;
            layerInfo.setLayerType("OUTPUT");
        }
        AtomicInteger cnt = new AtomicInteger(0);
        for (String layerType : modelInfo.getLayerTypes()) {
            String curColor = colors.get(cnt.getAndIncrement());
            if (cnt.get() >= colors.size()) {
                cnt.set(0);
            }
            for (LayerInfo layerInfo : modelInfo.getLayersByType(layerType)) {
                if (layerType.equals(INPUT)) {
                    layerInfo.setColor("#99ff66");
                    continue;
                }
                if (layerType.equals("OUTPUT")) {
                    layerInfo.setColor("#e6e6e6");
                    continue;
                }
                layerInfo.setColor(curColor);
            }
        }
        return modelInfo;
    }

    private LayerInfo getLayerInfo(Layer layer, int x, int y, int order) {
        LayerInfo info = new LayerInfo();
        info.setX(x);
        info.setY(y);
        try {
            info.setName(layer.conf().getLayer().getLayerName());
        }
        catch (Exception exception) {
            // empty catch block
        }
        if (info.getName() == null || info.getName().isEmpty()) {
            info.setName("unnamed");
        }
        info.setId((long)order);
        Description description = new Description();
        info.setDescription(description);
        try {
            info.setLayerType(layer.getClass().getSimpleName().replaceAll("Layer$", ""));
        }
        catch (Exception e) {
            info.setLayerType("n/a");
            return info;
        }
        StringBuilder mainLine = new StringBuilder();
        StringBuilder subLine = new StringBuilder();
        StringBuilder fullLine = new StringBuilder();
        if (layer.type().equals((Object)Layer.Type.CONVOLUTIONAL)) {
            ConvolutionLayer layer1 = (ConvolutionLayer)layer.conf().getLayer();
            mainLine.append("K: " + Arrays.toString(layer1.getKernelSize()) + " S: " + Arrays.toString(layer1.getStride()) + " P: " + Arrays.toString(layer1.getPadding()));
            subLine.append("nIn/nOut: [" + layer1.getNIn() + "/" + layer1.getNOut() + "]");
            fullLine.append("Kernel size: ").append(Arrays.toString(layer1.getKernelSize())).append("<br/>");
            fullLine.append("Stride: ").append(Arrays.toString(layer1.getStride())).append("<br/>");
            fullLine.append("Padding: ").append(Arrays.toString(layer1.getPadding())).append("<br/>");
            fullLine.append("Inputs number: ").append(layer1.getNIn()).append("<br/>");
            fullLine.append("Outputs number: ").append(layer1.getNOut()).append("<br/>");
        } else if (layer.conf().getLayer() instanceof SubsamplingLayer) {
            SubsamplingLayer layer1 = (SubsamplingLayer)layer.conf().getLayer();
            fullLine.append("Kernel size: ").append(Arrays.toString(layer1.getKernelSize())).append("<br/>");
            fullLine.append("Stride: ").append(Arrays.toString(layer1.getStride())).append("<br/>");
            fullLine.append("Padding: ").append(Arrays.toString(layer1.getPadding())).append("<br/>");
            fullLine.append("Pooling type: ").append(layer1.getPoolingType().toString()).append("<br/>");
        } else if (layer.conf().getLayer() instanceof FeedForwardLayer) {
            FeedForwardLayer layer1 = (FeedForwardLayer)layer.conf().getLayer();
            mainLine.append("nIn/nOut: [" + layer1.getNIn() + "/" + layer1.getNOut() + "]");
            subLine.append(info.getLayerType());
            fullLine.append("Inputs number: ").append(layer1.getNIn()).append("<br/>");
            fullLine.append("Outputs number: ").append(layer1.getNOut()).append("<br/>");
        } else if (layer instanceof BaseOutputLayer) {
            mainLine.append("Outputs: [" + ((BaseOutputLayer)layer.conf().getLayer()).getNOut() + "]");
            fullLine.append("Outputs number: ").append(((BaseOutputLayer)layer.conf().getLayer()).getNOut()).append("<br/>");
        }
        subLine.append(" A: [").append(layer.conf().getLayer().getActivationFunction()).append("]");
        fullLine.append("Activation function: ").append("<b>").append(layer.conf().getLayer().getActivationFunction()).append("</b>").append("<br/>");
        description.setMainLine(mainLine.toString());
        description.setSubLine(subLine.toString());
        description.setText(fullLine.toString());
        return info;
    }

    protected String parseTime(long milliseconds) {
        return String.format(FORMAT, TimeUnit.MILLISECONDS.toHours(milliseconds), TimeUnit.MILLISECONDS.toMinutes(milliseconds) - TimeUnit.HOURS.toMinutes(TimeUnit.MILLISECONDS.toHours(milliseconds)), TimeUnit.MILLISECONDS.toSeconds(milliseconds) - TimeUnit.MINUTES.toSeconds(TimeUnit.MILLISECONDS.toMinutes(milliseconds)));
    }
}

