/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.image.loader;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteOrder;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
import org.bytedeco.javacpp.lept;
import org.bytedeco.javacpp.opencv_core;
import org.bytedeco.javacpp.opencv_imgcodecs;
import org.bytedeco.javacpp.opencv_imgproc;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Image loader that decodes images with native libraries (OpenCV via JavaCPP, falling back to
 * Leptonica when OpenCV cannot decode the bytes) and converts them into ND4J arrays of shape
 * {@code [1, channels, rows, cols]}.
 */
public class NativeImageLoader extends BaseImageLoader {
    /** File extensions (both lower and upper case) this loader accepts. */
    public static final String[] ALLOWED_FORMATS = new String[]{"bmp", "gif", "jpg", "jpeg", "jp2", "pbm", "pgm", "ppm", "pnm", "png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM", "PNG", "TIF", "TIFF", "EXR", "WEBP"};

    /** Mat/Frame converter; only created when an {@link ImageTransform} is supplied. */
    OpenCVFrameConverter.ToMat converter = null;

    public NativeImageLoader() {
    }

    /**
     * @param height target height in pixels (images are resized when both height and width are > 0)
     * @param width  target width in pixels
     */
    public NativeImageLoader(int height, int width) {
        this.height = height;
        this.width = width;
    }

    /**
     * @param height   target height in pixels
     * @param width    target width in pixels
     * @param channels target number of channels (converted with cvtColor when > 0 and different)
     */
    public NativeImageLoader(int height, int width, int channels) {
        this.height = height;
        this.width = width;
        this.channels = channels;
    }

    /**
     * @param centerCropIfNeeded when true, non-square images are center-cropped to a square
     *                           before scaling
     */
    public NativeImageLoader(int height, int width, int channels, boolean centerCropIfNeeded) {
        this(height, width, channels);
        this.centerCropIfNeeded = centerCropIfNeeded;
    }

    /**
     * @param imageTransform transform applied to every image before channel conversion,
     *                       cropping and scaling
     */
    public NativeImageLoader(int height, int width, int channels, ImageTransform imageTransform) {
        this(height, width, channels);
        this.imageTransform = imageTransform;
        this.converter = new OpenCVFrameConverter.ToMat();
    }

    @Override
    public String[] getAllowedFormats() {
        return ALLOWED_FORMATS;
    }

    @Override
    public INDArray asRowVector(File f) throws IOException {
        return this.asMatrix(f).ravel();
    }

    @Override
    public INDArray asRowVector(InputStream is) throws IOException {
        return this.asMatrix(is).ravel();
    }

    /** Same as {@link #asMatrix(opencv_core.Mat)} but flattened with {@code ravel()}. */
    public INDArray asRowVector(opencv_core.Mat image) throws IOException {
        return this.asMatrix(image).ravel();
    }

    /**
     * Converts a Leptonica image into an 8-bit-per-channel OpenCV {@code Mat}.
     *
     * <p>Colormapped images are expanded, and images with fewer than 8 bits per pixel are first
     * converted to 8 bits per pixel. Any intermediate PIX created here is destroyed before
     * returning; the caller still owns {@code pix}.
     *
     * @param pix source image (not destroyed by this method)
     * @return a newly allocated Mat holding a copy of the pixel data
     * @throws IllegalArgumentException if the depth is below 8 but not 1, 2 or 4 bits
     */
    static opencv_core.Mat convert(lept.PIX pix) {
        lept.PIX tempPix = null;
        if (pix.colormap() != null) {
            // 2 is Leptonica's REMOVE_CMAP_TO_FULL_COLOR.
            tempPix = pix = lept.pixRemoveColormap(pix, 2);
        } else if (pix.d() < 8) {
            lept.PIX pix8;
            switch (pix.d()) {
                case 1:
                    pix8 = lept.pixConvert1To8(null, pix, (byte) 0, (byte) 255);
                    break;
                case 2:
                    pix8 = lept.pixConvert2To8(pix, (byte) 0, (byte) 85, (byte) 170, (byte) 255, 0);
                    break;
                case 4:
                    pix8 = lept.pixConvert4To8(pix, 0);
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported PIX depth: " + pix.d());
            }
            tempPix = pix = pix8;
        }
        try {
            int height = pix.h();
            int width = pix.w();
            int channels = pix.d() / 8;
            // View over the PIX buffer; wpl = 32-bit words per line, so the row step is 4 * wpl bytes.
            opencv_core.Mat mat = new opencv_core.Mat(height, width, opencv_core.CV_8UC(channels), (Pointer) pix.data(), 4L * pix.wpl());
            opencv_core.Mat mat2 = new opencv_core.Mat(height, width, opencv_core.CV_8UC(channels));
            // On little-endian hosts the channel bytes of each multi-channel pixel are stored
            // reversed, so they are swapped; otherwise they are copied as-is. Only `channels`
            // pairs are passed, because pairs beyond that would reference missing channels.
            // NOTE(review): no reordering is applied for single-channel data; confirm the byte
            // order of 8 bpp Leptonica rows on little-endian hosts.
            int[] swap = {0, channels - 1, 1, channels - 2, 2, channels - 3, 3, channels - 4};
            int[] copy = {0, 0, 1, 1, 2, 2, 3, 3};
            int[] fromTo = channels > 1 && ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN) ? swap : copy;
            long pairs = Math.min(channels, fromTo.length / 2);
            opencv_core.mixChannels(mat, 1L, mat2, 1L, fromTo, pairs);
            return mat2;
        } finally {
            if (tempPix != null) {
                lept.pixDestroy(tempPix);
            }
        }
    }

    @Override
    public INDArray asMatrix(File f) throws IOException {
        try (InputStream bis = new BufferedInputStream(new FileInputStream(f))) {
            return this.asMatrix(bis);
        }
    }

    /**
     * Reads the whole stream and decodes it, first with OpenCV and then with Leptonica if OpenCV
     * cannot decode it.
     *
     * @throws IOException if reading fails or neither library can decode the bytes
     */
    @Override
    public INDArray asMatrix(InputStream is) throws IOException {
        byte[] bytes = IOUtils.toByteArray(is);
        // 6 == IMREAD_ANYDEPTH (2) | IMREAD_ANYCOLOR (4)
        opencv_core.Mat image = opencv_imgcodecs.imdecode(new opencv_core.Mat(bytes), 6);
        if (image == null || image.empty()) {
            lept.PIX pix = lept.pixReadMem(bytes, bytes.length);
            if (pix == null) {
                throw new IOException("Could not decode image from input stream");
            }
            try {
                image = NativeImageLoader.convert(pix);
            } finally {
                lept.pixDestroy(pix);
            }
        }
        return this.asMatrix(image);
    }

    /**
     * Converts an OpenCV image into an array of shape {@code [1, channels, rows, cols]}.
     *
     * <p>The steps run in this order: the optional image transform, then conversion to
     * {@code this.channels} channels (when > 0 and different), then optional center cropping,
     * then scaling to {@code this.height x this.width} (when both are > 0). Pixel values are
     * copied unnormalized.
     *
     * @throws IOException if the requested channel conversion is not supported
     */
    public INDArray asMatrix(opencv_core.Mat image) throws IOException {
        if (this.imageTransform != null && this.converter != null) {
            ImageWritable writable = new ImageWritable(this.converter.convert(image));
            writable = this.imageTransform.transform(writable);
            image = this.converter.convert(writable.getFrame());
        }
        if (this.channels > 0 && image.channels() != this.channels) {
            int code = colorConversionCode(image.channels(), this.channels);
            if (code < 0) {
                throw new IOException("Cannot convert from " + image.channels() + " to " + this.channels + " channels.");
            }
            opencv_core.Mat converted = new opencv_core.Mat();
            opencv_imgproc.cvtColor(image, converted, code);
            image = converted;
        }
        if (this.centerCropIfNeeded) {
            image = this.centerCropIfNeeded(image);
        }
        image = this.scalingIfNeed(image);

        int rows = image.rows();
        int cols = image.cols();
        int numChannels = image.channels();
        Indexer idx = image.createIndexer();
        INDArray ret = Nd4j.create(new int[]{numChannels, rows, cols});
        Pointer pointer = ret.data().pointer();
        int[] stride = ret.stride();
        long[] sizes = {numChannels, rows, cols};
        long[] strides = {stride[0], stride[1], stride[2]};

        // Fast path: write directly into the array's float/double buffer with typed indexers.
        boolean copied = false;
        if (pointer instanceof FloatPointer) {
            FloatIndexer dst = FloatIndexer.create((FloatPointer) pointer, sizes, strides);
            copied = copyPixels(idx, dst, numChannels, rows, cols);
        } else if (pointer instanceof DoublePointer) {
            DoubleIndexer dst = DoubleIndexer.create((DoublePointer) pointer, sizes, strides);
            copied = copyPixels(idx, dst, numChannels, rows, cols);
        }
        if (!copied) {
            // Generic (slow) path. ret is always rank 3, so always address it with 3 indices.
            for (int k = 0; k < numChannels; ++k) {
                for (int i = 0; i < rows; ++i) {
                    for (int j = 0; j < cols; ++j) {
                        ret.putScalar(k, i, j, idx.getDouble(new long[]{i, j, k}));
                    }
                }
            }
        }
        // Touches the Mat after all reads through its indexer; presumably this keeps it (and its
        // native buffer) reachable until here. NOTE(review): confirm this is still needed.
        image.data();
        Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST);
        return ret.reshape(ArrayUtil.combine(new int[][]{{1}, ret.shape()}));
    }

    /**
     * Maps a source/target channel count pair to the cvtColor code used by this loader.
     * Returns -1 when the pair is not supported.
     */
    private static int colorConversionCode(int fromChannels, int toChannels) {
        switch (fromChannels) {
            case 1:
                switch (toChannels) {
                    case 3:
                        return 8;  // COLOR_GRAY2BGR
                    case 4:
                        return 9;  // COLOR_GRAY2BGRA / COLOR_GRAY2RGBA
                    default:
                        return -1;
                }
            case 3:
                switch (toChannels) {
                    case 1:
                        return 6;  // COLOR_BGR2GRAY
                    case 4:
                        // NOTE(review): COLOR_BGR2RGBA also swaps R and B; confirm this is intended.
                        return 2;  // COLOR_BGR2RGBA
                    default:
                        return -1;
                }
            case 4:
                switch (toChannels) {
                    case 1:
                        return 11; // COLOR_RGBA2GRAY
                    case 3:
                        return 3;  // COLOR_RGBA2BGR
                    default:
                        return -1;
                }
            default:
                return -1;
        }
    }

    /**
     * Copies {@code src[i][j][k]} (row, col, channel) into {@code dst[k][i][j]} (channel, row, col).
     *
     * @return false if {@code src} is not one of the specially handled indexer types
     */
    private static boolean copyPixels(Indexer src, FloatIndexer dst, long channels, long rows, long cols) {
        if (src instanceof UByteIndexer) {
            UByteIndexer in = (UByteIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (float) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof UShortIndexer) {
            UShortIndexer in = (UShortIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (float) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof IntIndexer) {
            IntIndexer in = (IntIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (float) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof FloatIndexer) {
            FloatIndexer in = (FloatIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        return false;
    }

    /**
     * Double-precision counterpart of {@link #copyPixels(Indexer, FloatIndexer, long, long, long)}.
     *
     * @return false if {@code src} is not one of the specially handled indexer types
     */
    private static boolean copyPixels(Indexer src, DoubleIndexer dst, long channels, long rows, long cols) {
        if (src instanceof UByteIndexer) {
            UByteIndexer in = (UByteIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (double) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof UShortIndexer) {
            UShortIndexer in = (UShortIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (double) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof IntIndexer) {
            IntIndexer in = (IntIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (double) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        if (src instanceof FloatIndexer) {
            FloatIndexer in = (FloatIndexer) src;
            for (long k = 0; k < channels; ++k) {
                for (long i = 0; i < rows; ++i) {
                    for (long j = 0; j < cols; ++j) {
                        dst.put(k, i, j, (double) in.get(i, j, k));
                    }
                }
            }
            return true;
        }
        return false;
    }

    /**
     * Crops a non-square image to the largest centered square.
     *
     * <p>When the two sides differ by an odd number of pixels, the extra pixel is trimmed from
     * the right or bottom edge. Square images are returned as an equivalent full-size ROI.
     *
     * @param img source image
     * @return an ROI view of {@code img} (via {@code Mat.apply(Rect)})
     */
    protected opencv_core.Mat centerCropIfNeeded(opencv_core.Mat img) {
        int x = 0;
        int y = 0;
        int height = img.rows();
        int width = img.cols();
        int diff = Math.abs(width - height) / 2;
        if (width > height) {
            x = diff;
            width = height;
        } else if (height > width) {
            y = diff;
            height = width;
        }
        return img.apply(new opencv_core.Rect(x, y, width, height));
    }

    /** Resizes to this loader's configured height and width (see {@link #scalingIfNeed(opencv_core.Mat, int, int)}). */
    protected opencv_core.Mat scalingIfNeed(opencv_core.Mat image) {
        return this.scalingIfNeed(image, this.height, this.width);
    }

    /**
     * Resizes {@code image} to {@code dstHeight x dstWidth} when both are positive and differ
     * from the current size; otherwise returns {@code image} unchanged.
     */
    protected opencv_core.Mat scalingIfNeed(opencv_core.Mat image, int dstHeight, int dstWidth) {
        if (dstHeight <= 0 || dstWidth <= 0 || (image.rows() == dstHeight && image.cols() == dstWidth)) {
            return image;
        }
        opencv_core.Mat scaled = new opencv_core.Mat();
        opencv_imgproc.resize(image, scaled, new opencv_core.Size(dstWidth, dstHeight));
        return scaled;
    }
}

