/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.ui.weights;

import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import javax.imageio.ImageIO;
import lombok.NonNull;
import org.datavec.api.util.ClassPathResource;
import org.datavec.image.loader.ImageLoader;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage;
import org.deeplearning4j.ui.weights.ConvolutionListenerPersistable;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link IterationListener} that periodically renders the activations of the convolutional layers of a
 * {@link MultiLayerNetwork} into one composite image and stores it through a {@link StatsStorageRouter}.
 *
 * <p>Every {@code freq} iterations a single random example is picked from the current minibatch. Its input
 * is drawn first, followed by one panel per convolutional layer containing that same example's feature maps
 * in grayscale. The composite is wrapped in a {@link ConvolutionListenerPersistable} and published via
 * {@link StatsStorageRouter#putStaticInfo}.
 */
public class ConvolutionalIterationListener implements IterationListener {
    /** A visualization is rendered when {@code iteration % freq == 0}; validated to be positive. */
    private int freq = 10;
    private static final Logger log = LoggerFactory.getLogger(ConvolutionalIterationListener.class);
    /** Number of visualizations published so far (not read elsewhere in this class). */
    private int minibatchNum = 0;
    /** Constructor argument; only consulted in the constructor to decide whether to attach storage to the UI. */
    private boolean openBrowser = true;
    /** URL of the activations page on the local UI server; only logged. */
    private String path;
    /** Not read in this class; NOTE(review): kept only in case serialized forms depend on it. */
    private boolean firstIteration = true;
    private Color borderColor = new Color(140, 140, 140);
    private Color bgColor = new Color(255, 255, 255);
    private final StatsStorageRouter ssr;
    private final String sessionID;
    private final String workerID;

    /**
     * Creates a listener backed by a new {@link MapDBStatsStorage} that is attached to the UI server.
     *
     * @param connectionInfo ignored; retained for backward compatibility
     * @param visualizationFrequency render every {@code visualizationFrequency} iterations; must be positive
     */
    public ConvolutionalIterationListener(UiConnectionInfo connectionInfo, int visualizationFrequency) {
        this((StatsStorageRouter) new MapDBStatsStorage(), visualizationFrequency, true);
    }

    /**
     * Creates a listener backed by a new {@link MapDBStatsStorage} that is attached to the UI server.
     *
     * @param visualizationFrequency render every {@code visualizationFrequency} iterations; must be positive
     */
    public ConvolutionalIterationListener(int visualizationFrequency) {
        this(visualizationFrequency, true);
    }

    /**
     * Creates a listener backed by a new {@link MapDBStatsStorage}.
     *
     * @param iterations render every {@code iterations} iterations; must be positive
     * @param openBrowser whether to attach the storage to the UI server
     */
    public ConvolutionalIterationListener(int iterations, boolean openBrowser) {
        this((StatsStorageRouter) new MapDBStatsStorage(), iterations, openBrowser);
    }

    /**
     * Creates a listener with a random session ID and a JVM/thread-derived worker ID.
     *
     * @param ssr router that receives the rendered images
     * @param iterations render every {@code iterations} iterations; must be positive
     * @param openBrowser whether to attach {@code ssr} to the UI server (only if it is a {@link StatsStorage})
     */
    public ConvolutionalIterationListener(StatsStorageRouter ssr, int iterations, boolean openBrowser) {
        this(ssr, iterations, openBrowser, null, null);
    }

    /**
     * Creates a listener.
     *
     * @param ssr router that receives the rendered images
     * @param iterations render every {@code iterations} iterations; must be positive
     * @param openBrowser whether to attach {@code ssr} to the UI server (only if it is a {@link StatsStorage})
     * @param sessionID session ID to report; a random UUID is used when {@code null}
     * @param workerID worker ID to report; {@code <JVM UID>_<thread id>} is used when {@code null}
     * @throws IllegalArgumentException if {@code iterations} is not positive
     */
    public ConvolutionalIterationListener(StatsStorageRouter ssr, int iterations, boolean openBrowser, String sessionID, String workerID) {
        // Fail fast: a non-positive frequency would otherwise make iterationDone() divide by zero.
        if (iterations <= 0) {
            throw new IllegalArgumentException("Visualization frequency must be positive, got " + iterations);
        }
        this.ssr = ssr;
        this.sessionID = sessionID == null ? UUID.randomUUID().toString() : sessionID;
        this.workerID = workerID == null ? UIDProvider.getJVMUID() + "_" + Thread.currentThread().getId() : workerID;
        String subPath = "activations";
        this.freq = iterations;
        this.openBrowser = openBrowser;
        UIServer uiServer = UIServer.getInstance();
        this.path = "http://localhost:" + uiServer.getPort() + "/" + subPath;
        if (openBrowser && ssr instanceof StatsStorage) {
            uiServer.attach((StatsStorage) ssr);
        }
        log.info("ConvolutionIterationListener path: {}", this.path);
    }

    /** This listener does not track invocation state; always returns {@code false}. */
    public boolean invoked() {
        return false;
    }

    /** No-op; this listener does not track invocation state. */
    public void invoke() {
    }

    /**
     * Renders and publishes a visualization every {@code freq} iterations.
     *
     * <p>Models that are not a {@link MultiLayerNetwork}, and networks without convolutional layers, are
     * skipped with a log message instead of aborting training.
     *
     * @param model the model being trained
     * @param iteration the current iteration number
     */
    public void iterationDone(Model model, int iteration) {
        if (iteration % this.freq != 0) {
            return;
        }
        if (!(model instanceof MultiLayerNetwork)) {
            log.warn("ConvolutionalIterationListener only supports MultiLayerNetwork, got {}",
                    model == null ? "null" : model.getClass().getName());
            return;
        }
        MultiLayerNetwork network = (MultiLayerNetwork) model;
        List<INDArray> tensors = new ArrayList<INDArray>();
        Random rnd = new Random();
        BufferedImage sourceImage = null;
        int sampleIdx = -1;
        for (Layer layer : network.getLayers()) {
            if (layer.type() != Layer.Type.CONVOLUTIONAL) {
                continue;
            }
            INDArray output = layer.activate();
            if (sampleIdx < 0) {
                // Pick one example (any index in the minibatch) and reuse it for every layer, so that all
                // panels show activations of the same input that is drawn as the source image.
                sampleIdx = rnd.nextInt(output.shape()[0]);
                INDArray inputs = ((ConvolutionLayer) layer).input();
                sourceImage = this.restoreRGBImage(inputs.tensorAlongDimension(sampleIdx, new int[]{3, 2, 1}));
            }
            tensors.add(output.tensorAlongDimension(sampleIdx, new int[]{3, 2, 1}));
        }
        if (tensors.isEmpty()) {
            log.debug("No convolutional layers found; skipping visualization at iteration {}", iteration);
            return;
        }
        BufferedImage render = this.rasterizeConvoLayers(tensors, sourceImage);
        ConvolutionListenerPersistable p = new ConvolutionListenerPersistable(this.sessionID, this.workerID, System.currentTimeMillis(), render);
        this.ssr.putStaticInfo((Persistable) p);
        ++this.minibatchNum;
    }

    /**
     * Composes the source image and one rendered panel per layer into a single RGB image.
     *
     * <p>Up to three layers are laid out left to right (landscape); more than three are stacked top to
     * bottom (portrait). Decorative arrow images are drawn between panels when they can be loaded.
     *
     * @param tensors3D one activation tensor per convolutional layer; must be non-empty
     * @param sourceImage the rendered input example, or {@code null} to omit it
     * @return the composite image
     * @throws IllegalArgumentException if {@code tensors3D} is empty
     */
    private BufferedImage rasterizeConvoLayers(@NonNull List<INDArray> tensors3D, BufferedImage sourceImage) {
        if (tensors3D == null) {
            throw new NullPointerException("tensors3D");
        }
        if (tensors3D.isEmpty()) {
            throw new IllegalArgumentException("tensors3D must contain at least one tensor");
        }
        int border = 1;
        int padding_row = 2;
        int padding_col = 80;
        // Panel geometry is derived from the first layer's tensor shape.
        int[] shape = tensors3D.get(0).shape();
        int numImages = shape[0];
        int height = shape[2];
        int width = shape[1];
        int maxHeight = 0;
        int totalWidth = 0;
        int iOffset = 1;
        Orientation orientation = tensors3D.size() > 3 ? Orientation.PORTRAIT : Orientation.LANDSCAPE;
        boolean landscape = orientation == Orientation.LANDSCAPE;

        List<BufferedImage> images = new ArrayList<BufferedImage>(tensors3D.size());
        for (INDArray tad : tensors3D) {
            BufferedImage image;
            if (landscape) {
                maxHeight = (height + border * 2 + padding_row) * numImages;
                image = this.renderMultipleImagesLandscape(tad, maxHeight, width, height);
                totalWidth += image.getWidth() + padding_col;
            } else {
                totalWidth = (width + border * 2 + padding_row) * numImages;
                image = this.renderMultipleImagesPortrait(tad, totalWidth, width, height);
                maxHeight += image.getHeight() + padding_col;
            }
            images.add(image);
        }
        if (landscape) {
            totalWidth += padding_col * 2;
        } else {
            maxHeight += padding_col * 2;
            if (sourceImage != null) {
                maxHeight += sourceImage.getHeight() + padding_col * 2;
            }
        }

        BufferedImage output = new BufferedImage(totalWidth, maxHeight, BufferedImage.TYPE_INT_RGB);
        Graphics2D graphics2D = output.createGraphics();
        try {
            graphics2D.setPaint(this.bgColor);
            graphics2D.fillRect(0, 0, output.getWidth(), output.getHeight());

            // Arrows are decorative: a missing resource just means no arrows are drawn.
            BufferedImage singleArrow = loadImageResource(landscape ? "arrow_sing.PNG" : "arrow_singi.PNG");
            BufferedImage multipleArrows = loadImageResource(landscape ? "arrow_mul.PNG" : "arrow_muli.PNG");

            if (sourceImage != null) {
                int srcW = sourceImage.getWidth();
                int srcH = sourceImage.getHeight();
                int srcX = landscape ? padding_col / 2 - srcW / 2 : totalWidth / 2 - srcW / 2;
                int srcY = landscape ? maxHeight / 2 - srcH / 2 : padding_col / 2 - srcH / 2;
                graphics2D.drawImage(sourceImage, srcX, srcY, null);
                graphics2D.setPaint(this.borderColor);
                graphics2D.drawRect(srcX, srcY, srcW, srcH);
                iOffset += landscape ? srcW : srcH;
                if (singleArrow != null) {
                    if (landscape) {
                        graphics2D.drawImage(singleArrow, iOffset + padding_col / 2 - singleArrow.getWidth() / 2, maxHeight / 2 - singleArrow.getHeight() / 2, null);
                    } else {
                        graphics2D.drawImage(singleArrow, totalWidth / 2 - singleArrow.getWidth() / 2, iOffset + padding_col / 2 - singleArrow.getHeight() / 2, null);
                    }
                }
                iOffset += padding_col;
            }

            for (int i = 0; i < images.size(); ++i) {
                BufferedImage curImage = images.get(i);
                // Connector arrows go between consecutive panels only, never after the last one.
                boolean drawConnector = singleArrow != null && multipleArrows != null && i < images.size() - 1;
                if (landscape) {
                    graphics2D.drawImage(curImage, iOffset, 1, null);
                    iOffset += curImage.getWidth() + padding_col;
                    if (drawConnector) {
                        graphics2D.drawImage(multipleArrows, iOffset - padding_col / 2 - multipleArrows.getWidth() / 2, maxHeight / 2 - multipleArrows.getHeight() / 2, null);
                    }
                } else {
                    graphics2D.drawImage(curImage, 1, iOffset, null);
                    iOffset += curImage.getHeight() + padding_col;
                    if (drawConnector) {
                        graphics2D.drawImage(multipleArrows, totalWidth / 2 - multipleArrows.getWidth() / 2, iOffset - padding_col / 2 - multipleArrows.getHeight() / 2, null);
                    }
                }
            }
        } finally {
            graphics2D.dispose();
        }
        return output;
    }

    /**
     * Loads an image from the classpath, closing the stream afterwards.
     *
     * @param resourceName classpath resource name
     * @return the decoded image, or {@code null} if the resource is missing or cannot be decoded
     */
    private static BufferedImage loadImageResource(String resourceName) {
        try (InputStream in = new ClassPathResource(resourceName).getInputStream()) {
            return ImageIO.read(in);
        } catch (IOException | RuntimeException e) {
            log.debug("Could not load image resource {}", resourceName, e);
            return null;
        }
    }

    /**
     * Renders every 2D slice of {@code tensor3D} in grayscale, filling rows left to right and wrapping at
     * {@code maxWidth}, with a strip at the bottom for up to five enlarged ("zoomed") slices.
     *
     * @param tensor3D activations of one layer for one example; slices are taken along dimension 0
     * @param maxWidth width of the output image
     * @param zoomWidth width of a zoomed slice
     * @param zoomHeight height of a zoomed slice
     * @return the rendered panel
     */
    private BufferedImage renderMultipleImagesPortrait(INDArray tensor3D, int maxWidth, int zoomWidth, int zoomHeight) {
        int border = 1;
        int padding_row = 2;
        int padding_col = 2;
        int zoomPadding = 20;
        int[] tShape = tensor3D.shape();
        int numRows = tShape[0] / tShape[2];
        // NOTE(review): the zoom strip reserves zoomWidth pixels of height but draws zoomHeight-tall images.
        int height = numRows * (tShape[1] + border + padding_col) + padding_col + zoomPadding + zoomWidth;
        BufferedImage outputImage = new BufferedImage(maxWidth, height, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D graphics2D = outputImage.createGraphics();
        try {
            graphics2D.setPaint(this.bgColor);
            graphics2D.fillRect(0, 0, outputImage.getWidth(), outputImage.getHeight());
            graphics2D.setPaint(this.borderColor);
            int columnOffset = 0;
            int rowOffset = 0;
            int numZoomed = 0;
            int limZoomed = 5;
            int zoomSpan = maxWidth / limZoomed;
            for (int z = 0; z < tShape[0]; ++z) {
                INDArray tad2D = tensor3D.tensorAlongDimension(z, new int[]{2, 1});
                int rWidth = tad2D.shape()[0];
                int rHeight = tad2D.shape()[1];
                int loc_height = rHeight + border * 2 + padding_row;
                int loc_width = rWidth + border * 2 + padding_col;
                BufferedImage currentImage = this.renderImageGrayscale(tad2D);
                // Wrap to the next row when the slice would overflow the panel width.
                if (columnOffset + loc_width > maxWidth) {
                    rowOffset += loc_height;
                    columnOffset = 0;
                }
                graphics2D.drawImage(currentImage, columnOffset + 1, rowOffset + 1, null);
                graphics2D.drawRect(columnOffset, rowOffset, rWidth, rHeight);
                // Enlarge every 7th slice (up to limZoomed) unless it is already at zoom size.
                if (z % 7 == 0 && z != 0 && numZoomed < limZoomed && rHeight != zoomHeight && rWidth != zoomWidth) {
                    int cX = zoomSpan * numZoomed + zoomWidth;
                    graphics2D.drawImage(currentImage, cX - 1, height - zoomWidth - 1, zoomWidth, zoomHeight, null);
                    graphics2D.drawRect(cX - 2, height - zoomWidth - 2, zoomWidth, zoomHeight);
                    graphics2D.drawLine(columnOffset + rWidth, rowOffset + rHeight, cX - 2, height - zoomWidth - 2);
                    ++numZoomed;
                }
                columnOffset += loc_width;
            }
        } finally {
            graphics2D.dispose();
        }
        return outputImage;
    }

    /**
     * Renders every 2D slice of {@code tensor3D} in grayscale, filling columns top to bottom and wrapping at
     * {@code maxHeight}, with a strip on the right for up to five enlarged ("zoomed") slices.
     *
     * @param tensor3D activations of one layer for one example; slices are taken along dimension 0
     * @param maxHeight height of the output image
     * @param zoomWidth width of a zoomed slice
     * @param zoomHeight height of a zoomed slice
     * @return the rendered panel
     */
    private BufferedImage renderMultipleImagesLandscape(INDArray tensor3D, int maxHeight, int zoomWidth, int zoomHeight) {
        int border = 1;
        int padding_row = 2;
        int padding_col = 2;
        int zoomPadding = 20;
        int[] tShape = tensor3D.shape();
        int numColumns = tShape[0] / tShape[1];
        int width = numColumns * (tShape[1] + border + padding_col) + padding_col + zoomPadding + zoomWidth;
        BufferedImage outputImage = new BufferedImage(width, maxHeight, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D graphics2D = outputImage.createGraphics();
        try {
            graphics2D.setPaint(this.bgColor);
            graphics2D.fillRect(0, 0, outputImage.getWidth(), outputImage.getHeight());
            graphics2D.setPaint(this.borderColor);
            int columnOffset = 0;
            int rowOffset = 0;
            int numZoomed = 0;
            int limZoomed = 5;
            int zoomSpan = maxHeight / limZoomed;
            for (int z = 0; z < tShape[0]; ++z) {
                INDArray tad2D = tensor3D.tensorAlongDimension(z, new int[]{2, 1});
                int rWidth = tad2D.shape()[0];
                int rHeight = tad2D.shape()[1];
                int loc_height = rHeight + border * 2 + padding_row;
                int loc_width = rWidth + border * 2 + padding_col;
                BufferedImage currentImage = this.renderImageGrayscale(tad2D);
                // Wrap to the next column when the slice would overflow the panel height.
                if (rowOffset + loc_height > maxHeight) {
                    columnOffset += loc_width;
                    rowOffset = 0;
                }
                graphics2D.drawImage(currentImage, columnOffset + 1, rowOffset + 1, null);
                graphics2D.drawRect(columnOffset, rowOffset, rWidth, rHeight);
                // Enlarge every 5th slice (up to limZoomed) unless it is already at zoom size.
                if (z % 5 == 0 && z != 0 && numZoomed < limZoomed && rHeight != zoomHeight && rWidth != zoomWidth) {
                    int cY = zoomSpan * numZoomed + zoomHeight;
                    graphics2D.drawImage(currentImage, width - zoomWidth - 1, cY - 1, zoomWidth, zoomHeight, null);
                    graphics2D.drawRect(width - zoomWidth - 2, cY - 2, zoomWidth, zoomHeight);
                    graphics2D.drawLine(columnOffset + rWidth, rowOffset + rHeight, width - zoomWidth - 2, cY - 2 + zoomHeight);
                    ++numZoomed;
                }
                rowOffset += loc_height;
            }
        } finally {
            graphics2D.dispose();
        }
        return outputImage;
    }

    /**
     * Converts an input example into an RGB image.
     *
     * <p>When dimension 0 has size 3, slices 2, 1 and 0 become the red, green and blue channels respectively
     * (NOTE(review): assumes BGR channel order — confirm against the data pipeline). Otherwise slice 0 is
     * rendered as grayscale. Values are scaled by 255 and clamped to [0, 255].
     *
     * @param tensor3D one input example; slices are taken along dimension 0
     * @return the rendered image
     */
    private BufferedImage restoreRGBImage(INDArray tensor3D) {
        INDArray arrayR;
        INDArray arrayG;
        INDArray arrayB;
        if (tensor3D.shape()[0] == 3) {
            arrayR = tensor3D.tensorAlongDimension(2, new int[]{2, 1});
            arrayG = tensor3D.tensorAlongDimension(1, new int[]{2, 1});
            arrayB = tensor3D.tensorAlongDimension(0, new int[]{2, 1});
        } else {
            arrayR = arrayG = arrayB = tensor3D.tensorAlongDimension(0, new int[]{2, 1});
        }
        int columns = arrayR.columns();
        int rows = arrayR.rows();
        BufferedImage imageToRender = new BufferedImage(columns, rows, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < rows; ++y) {
            // Fetch each row view once instead of once per pixel.
            INDArray rowR = arrayR.getRow(y);
            INDArray rowG = arrayG.getRow(y);
            INDArray rowB = arrayB.getRow(y);
            for (int x = 0; x < columns; ++x) {
                Color pix = new Color(toColorComponent(rowR.getDouble(x)), toColorComponent(rowG.getDouble(x)), toColorComponent(rowB.getDouble(x)));
                imageToRender.setRGB(x, y, pix.getRGB());
            }
        }
        return imageToRender;
    }

    /**
     * Renders a 2D array as an 8-bit grayscale image; values are scaled by 255 and clamped to [0, 255].
     *
     * @param array matrix to render (rows become image rows)
     * @return the rendered image
     */
    private BufferedImage renderImageGrayscale(INDArray array) {
        int columns = array.columns();
        int rows = array.rows();
        BufferedImage imageToRender = new BufferedImage(columns, rows, BufferedImage.TYPE_BYTE_GRAY);
        java.awt.image.WritableRaster raster = imageToRender.getRaster();
        for (int y = 0; y < rows; ++y) {
            INDArray row = array.getRow(y);
            for (int x = 0; x < columns; ++x) {
                raster.setSample(x, y, 0, toColorComponent(row.getDouble(x)));
            }
        }
        return imageToRender;
    }

    /**
     * Scales a value nominally in [0, 1] to an 8-bit channel value.
     *
     * <p>Out-of-range values (e.g. unbounded activations) are clamped instead of making {@link Color} throw
     * or wrapping around in a byte raster; NaN maps to 0.
     */
    private static int toColorComponent(double value) {
        int scaled = (int) (255.0 * value);
        return Math.max(0, Math.min(255, scaled));
    }

    /** Writes a 2D array as a grayscale PNG; failures are logged, not thrown. */
    private void writeImageGrayscale(INDArray array, File file) {
        try {
            ImageIO.write(this.renderImageGrayscale(array), "png", file);
        } catch (IOException e) {
            log.error("Failed to write grayscale image to {}", file, e);
        }
    }

    /** Writes an array via {@link ImageLoader#toImage} as a PNG; failures are logged, not thrown. */
    private void writeImage(INDArray array, File file) {
        BufferedImage image = ImageLoader.toImage(array);
        try {
            ImageIO.write(image, "png", file);
        } catch (IOException e) {
            log.error("Failed to write image to {}", file, e);
        }
    }

    /**
     * Writes each row of {@code array} to {@code file} as text.
     *
     * @throws RuntimeException wrapping any I/O failure
     */
    private void writeRows(INDArray array, File file) {
        try (PrintWriter writer = new PrintWriter(file)) {
            for (int x = 0; x < array.rows(); ++x) {
                writer.println("Row [" + x + "]: " + array.getRow(x));
            }
            writer.flush();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Layout direction of the composite image. */
    private enum Orientation {
        LANDSCAPE,
        PORTRAIT
    }
}

