/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.compression.CompressionType;
import org.nd4j.linalg.factory.BaseNDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.JCublasNDArray;
import org.nd4j.linalg.jcublas.blas.CudaBlas;
import org.nd4j.linalg.jcublas.blas.JcublasLapack;
import org.nd4j.linalg.jcublas.blas.JcublasLevel1;
import org.nd4j.linalg.jcublas.blas.JcublasLevel2;
import org.nd4j.linalg.jcublas.blas.JcublasLevel3;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.linalg.jcublas.complex.ComplexDouble;
import org.nd4j.linalg.jcublas.complex.ComplexFloat;
import org.nd4j.linalg.jcublas.complex.JCublasComplexNDArray;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JCublasNDArrayFactory
extends BaseNDArrayFactory {
    /** Native operations bound to the device backend, used for flatten/concat/pullRows/etc. */
    private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Class logger; constant, so declared {@code final}. */
    private static final Logger log = LoggerFactory.getLogger(JCublasNDArrayFactory.class);

    /**
     * Creates a factory through the no-argument {@link BaseNDArrayFactory} constructor.
     */
    public JCublasNDArrayFactory() {
        super();
    }

    /**
     * Creates a factory with the given default data type and (boxed) default ordering.
     *
     * @param dtype default data type handed to {@link BaseNDArrayFactory}
     * @param order default ordering handed to {@link BaseNDArrayFactory}
     */
    public JCublasNDArrayFactory(DataBuffer.Type dtype, Character order) {
        super(dtype, order);
    }

    /**
     * Creates a factory with the given default data type and primitive default ordering.
     *
     * @param dtype default data type handed to {@link BaseNDArrayFactory}
     * @param order default ordering handed to {@link BaseNDArrayFactory}
     */
    public JCublasNDArrayFactory(DataBuffer.Type dtype, char order) {
        super(dtype, order);
    }

    /** Installs the CUDA {@link CudaBlas} implementation as this factory's BLAS wrapper. */
    public void createBlas() {
        blas = new CudaBlas();
    }

    /** Installs {@link JcublasLevel1} as this factory's level-1 BLAS implementation. */
    public void createLevel1() {
        level1 = new JcublasLevel1();
    }

    /** Installs {@link JcublasLevel2} as this factory's level-2 BLAS implementation. */
    public void createLevel2() {
        level2 = new JcublasLevel2();
    }

    /** Installs {@link JcublasLevel3} as this factory's level-3 BLAS implementation. */
    public void createLevel3() {
        level3 = new JcublasLevel3();
    }

    /** Installs {@link JcublasLapack} as this factory's LAPACK implementation. */
    public void createLapack() {
        lapack = new JcublasLapack();
    }

    /**
     * Creates a {@link JCublasNDArray} of the given shape from {@code buffer}.
     *
     * @param shape  array shape
     * @param buffer backing data buffer
     * @return the new array
     */
    public INDArray create(int[] shape, DataBuffer buffer) {
        INDArray result = new JCublasNDArray(shape, buffer);
        return result;
    }

    /**
     * Creates a single-precision complex scalar.
     *
     * @param real real part
     * @param imag imaginary part
     * @return a {@link ComplexFloat}
     */
    public IComplexFloat createFloat(float real, float imag) {
        IComplexFloat value = new ComplexFloat(real, imag);
        return value;
    }

    /**
     * Creates a double-precision complex scalar.
     *
     * @param real real part
     * @param imag imaginary part
     * @return a {@link ComplexDouble}
     */
    public IComplexDouble createDouble(double real, double imag) {
        IComplexDouble value = new ComplexDouble(real, imag);
        return value;
    }

    /**
     * Creates a {@link JCublasNDArray} from a 2D Java double array.
     *
     * @param data source values
     * @return the new array
     */
    public INDArray create(double[][] data) {
        INDArray result = new JCublasNDArray(data);
        return result;
    }

    /**
     * Creates a {@link JCublasNDArray} from a 2D Java double array with an explicit ordering.
     *
     * @param data     source values
     * @param ordering array ordering
     * @return the new array
     */
    public INDArray create(double[][] data, char ordering) {
        INDArray result = new JCublasNDArray(data, ordering);
        return result;
    }

    /**
     * Creates a {@link JCublasComplexNDArray} from a real-valued array.
     *
     * @param arr source array
     * @return the complex array
     */
    public IComplexNDArray createComplex(INDArray arr) {
        IComplexNDArray result = new JCublasComplexNDArray(arr);
        return result;
    }

    /**
     * Creates a complex array from complex scalars, with complex strides computed for the
     * global default ordering ({@code Nd4j.order()}).
     *
     * @param data  complex values
     * @param shape array shape
     * @return the complex array
     */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape) {
        char defaultOrder = Nd4j.order().charValue();
        int[] complexStrides = Nd4j.getComplexStrides(shape, defaultOrder);
        return new JCublasComplexNDArray(data, shape, complexStrides);
    }

    /**
     * Creates a complex array of the given shape from a list of complex arrays.
     *
     * @param arrs  source complex arrays
     * @param shape resulting shape
     * @return the complex array
     */
    public IComplexNDArray createComplex(List<IComplexNDArray> arrs, int[] shape) {
        IComplexNDArray result = new JCublasComplexNDArray(arrs, shape);
        return result;
    }

    /**
     * Creates a {@link JCublasNDArray} from a data buffer.
     *
     * @param data backing buffer
     * @return the new array
     */
    public INDArray create(DataBuffer data) {
        INDArray result = new JCublasNDArray(data);
        return result;
    }

    /**
     * Creates a {@link JCublasComplexNDArray} from a data buffer.
     *
     * @param data backing buffer
     * @return the complex array
     */
    public IComplexNDArray createComplex(DataBuffer data) {
        IComplexNDArray result = new JCublasComplexNDArray(data);
        return result;
    }

    /**
     * Creates a 2D {@code rows x columns} complex array over {@code data}.
     *
     * @param data    backing buffer
     * @param rows    number of rows
     * @param columns number of columns
     * @param stride  strides
     * @param offset  offset into the buffer
     * @return the complex array
     */
    public IComplexNDArray createComplex(DataBuffer data, int rows, int columns, int[] stride, int offset) {
        int[] matrixShape = {rows, columns};
        return new JCublasComplexNDArray(data, matrixShape, stride, offset);
    }

    /**
     * Creates a 2D {@code rows x columns} array over {@code data}.
     *
     * @param data    backing buffer
     * @param rows    number of rows
     * @param columns number of columns
     * @param stride  strides
     * @param offset  offset into the buffer
     * @return the new array
     */
    public INDArray create(DataBuffer data, int rows, int columns, int[] stride, int offset) {
        int[] matrixShape = {rows, columns};
        return new JCublasNDArray(data, matrixShape, stride, offset);
    }

    /**
     * Creates a complex array over {@code data} with explicit shape, strides and offset.
     *
     * @param data   backing buffer
     * @param shape  array shape
     * @param stride strides
     * @param offset offset into the buffer
     * @return the complex array
     */
    public IComplexNDArray createComplex(DataBuffer data, int[] shape, int[] stride, int offset) {
        IComplexNDArray result = new JCublasComplexNDArray(data, shape, stride, offset);
        return result;
    }

    /**
     * Creates a complex array from interleaved float data with explicit shape, strides and offset.
     *
     * @param data   float values
     * @param shape  array shape
     * @param stride strides
     * @param offset offset into the data
     * @return the complex array
     */
    public IComplexNDArray createComplex(float[] data, int[] shape, int[] stride, int offset) {
        IComplexNDArray result = new JCublasComplexNDArray(data, shape, stride, offset);
        return result;
    }

    /**
     * Creates a {@link JCublasNDArray} of the given shape and ordering.
     *
     * @param shape    array shape
     * @param ordering array ordering
     * @return the new array
     */
    public INDArray create(int[] shape, char ordering) {
        INDArray result = new JCublasNDArray(shape, ordering);
        return result;
    }

    /**
     * Creates an array of the given shape and ordering, passing {@code false} as the final
     * constructor flag (presumably "do not initialize"), with zero offset and default strides
     * for {@code ordering}.
     *
     * @param shape    array shape
     * @param ordering array ordering
     * @return the new array
     */
    public INDArray createUninitialized(int[] shape, char ordering) {
        int[] strides = Nd4j.getStrides(shape, ordering);
        final int zeroOffset = 0;
        return new JCublasNDArray(shape, strides, zeroOffset, ordering, false);
    }

    /**
     * Creates an array over {@code data} with explicit shape, strides, offset and ordering.
     *
     * @param data      backing buffer
     * @param newShape  array shape
     * @param newStride strides
     * @param offset    offset into the buffer
     * @param ordering  array ordering
     * @return the new array
     */
    public INDArray create(DataBuffer data, int[] newShape, int[] newStride, int offset, char ordering) {
        INDArray result = new JCublasNDArray(data, newShape, newStride, offset, ordering);
        return result;
    }

    /**
     * Creates a complex array over {@code data} with explicit shape, strides, offset and ordering.
     *
     * @param data       backing buffer
     * @param newDims    array shape
     * @param newStrides strides
     * @param offset     offset into the buffer
     * @param ordering   array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(DataBuffer data, int[] newDims, int[] newStrides, int offset, char ordering) {
        IComplexNDArray result = new JCublasComplexNDArray(data, newDims, newStrides, offset, ordering);
        return result;
    }

    /**
     * Creates a complex array from float data with the given (boxed) ordering.
     *
     * @param data  float values
     * @param order array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(float[] data, Character order) {
        IComplexNDArray result = new JCublasComplexNDArray(data, order);
        return result;
    }

    /**
     * Creates an array from float data with shape, offset and ordering.
     *
     * @param data   float values
     * @param shape  array shape
     * @param offset offset into the data
     * @param order  array ordering (must be non-null; it is unboxed)
     * @return the new array
     */
    public INDArray create(float[] data, int[] shape, int offset, Character order) {
        char primitiveOrder = order.charValue();
        return new JCublasNDArray(data, shape, offset, primitiveOrder);
    }

    /**
     * Creates a 2D {@code rows x columns} array from float data.
     *
     * @param data     float values
     * @param rows     number of rows
     * @param columns  number of columns
     * @param stride   strides
     * @param offset   offset into the data
     * @param ordering array ordering
     * @return the new array
     */
    public INDArray create(float[] data, int rows, int columns, int[] stride, int offset, char ordering) {
        int[] matrixShape = {rows, columns};
        return new JCublasNDArray(data, matrixShape, stride, offset, ordering);
    }

    /**
     * Creates an array from double data with the given shape and ordering.
     *
     * @param data     double values
     * @param shape    array shape
     * @param ordering array ordering
     * @return the new array
     */
    public INDArray create(double[] data, int[] shape, char ordering) {
        INDArray result = new JCublasNDArray(data, shape, ordering);
        return result;
    }

    /**
     * Creates an array of the given shape and ordering from a list of arrays.
     *
     * @param list     source arrays
     * @param shape    resulting shape
     * @param ordering array ordering
     * @return the new array
     */
    public INDArray create(List<INDArray> list, int[] shape, char ordering) {
        INDArray result = new JCublasNDArray(list, shape, ordering);
        return result;
    }

    /**
     * Creates an array from double data with the given shape, starting at {@code offset},
     * using the global default ordering ({@code Nd4j.order()}) and its default strides.
     *
     * <p>The previous implementation called {@code new JCublasNDArray(data, shape, (char) offset)}.
     * That resolves to the {@code (double[], int[], char ordering)} constructor, so the offset was
     * discarded and its low 16 bits were reinterpreted as an ordering character. This version
     * forwards the offset to the strided constructor instead, matching the float overload
     * {@code create(float[], int[], int)}.
     *
     * @param data   double values
     * @param shape  array shape
     * @param offset offset into {@code data}
     * @return the new array
     */
    public INDArray create(double[] data, int[] shape, int offset) {
        char ordering = Nd4j.order().charValue();
        int[] strides = Nd4j.getStrides(shape, ordering);
        return new JCublasNDArray(data, shape, strides, offset, ordering);
    }

    /**
     * Creates an array from double data with explicit shape, strides, offset and ordering.
     *
     * @param data     double values
     * @param shape    array shape
     * @param stride   strides
     * @param offset   offset into the data
     * @param ordering array ordering
     * @return the new array
     */
    public INDArray create(double[] data, int[] shape, int[] stride, int offset, char ordering) {
        INDArray result = new JCublasNDArray(data, shape, stride, offset, ordering);
        return result;
    }

    /**
     * Creates a complex array from complex scalars with explicit shape, strides and offset.
     *
     * @param data   complex values
     * @param shape  array shape
     * @param stride strides
     * @param offset offset
     * @return the complex array
     */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape, int[] stride, int offset) {
        IComplexNDArray result = new JCublasComplexNDArray(data, shape, stride, offset);
        return result;
    }

    /**
     * Creates a complex array from complex scalars with explicit shape, strides, offset and ordering.
     *
     * @param data     complex values
     * @param shape    array shape
     * @param stride   strides
     * @param offset   offset
     * @param ordering array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape, int[] stride, int offset, char ordering) {
        IComplexNDArray result = new JCublasComplexNDArray(data, shape, stride, offset, ordering);
        return result;
    }

    /**
     * Creates a complex array from complex scalars with explicit shape, strides and ordering,
     * starting at offset 0.
     *
     * <p>The previous implementation passed {@code (int) ordering} as the <em>offset</em> argument
     * (for example 99 for {@code 'c'}) and never forwarded the ordering. It now uses the
     * {@code (data, shape, stride, offset, ordering)} constructor with a zero offset.
     *
     * @param data     complex values
     * @param shape    array shape
     * @param stride   strides
     * @param ordering array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape, int[] stride, char ordering) {
        return new JCublasComplexNDArray(data, shape, stride, 0, ordering);
    }

    /**
     * Creates a complex array from complex scalars with shape, offset and ordering.
     *
     * @param data     complex values
     * @param shape    array shape
     * @param offset   offset
     * @param ordering array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape, int offset, char ordering) {
        IComplexNDArray result = new JCublasComplexNDArray(data, shape, offset, ordering);
        return result;
    }

    /**
     * Creates a complex array from complex scalars with shape and ordering.
     *
     * @param data     complex values
     * @param shape    array shape
     * @param ordering array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(IComplexNumber[] data, int[] shape, char ordering) {
        IComplexNDArray result = new JCublasComplexNDArray(data, shape, ordering);
        return result;
    }

    /**
     * Creates an array from float data with explicit shape, strides and offset.
     *
     * @param data   float values
     * @param shape  array shape
     * @param stride strides
     * @param offset offset into the data
     * @return the new array
     */
    public INDArray create(float[] data, int[] shape, int[] stride, int offset) {
        INDArray result = new JCublasNDArray(data, shape, stride, offset);
        return result;
    }

    /**
     * Creates a complex array from double data, first narrowing it to a float copy.
     *
     * @param data   double values (converted via {@code ArrayUtil.floatCopyOf})
     * @param shape  array shape
     * @param stride strides
     * @param offset offset
     * @return the complex array
     */
    public IComplexNDArray createComplex(double[] data, int[] shape, int[] stride, int offset) {
        float[] asFloats = ArrayUtil.floatCopyOf(data);
        return new JCublasComplexNDArray(asFloats, shape, stride, offset);
    }

    /**
     * Creates an array from double data with explicit shape, strides and offset.
     *
     * @param data   double values
     * @param shape  array shape
     * @param stride strides
     * @param offset offset into the data
     * @return the new array
     */
    public INDArray create(double[] data, int[] shape, int[] stride, int offset) {
        INDArray result = new JCublasNDArray(data, shape, stride, offset);
        return result;
    }

    /**
     * Creates an array of the given shape over {@code data}.
     *
     * @param data  backing buffer
     * @param shape array shape
     * @return the new array
     */
    public INDArray create(DataBuffer data, int[] shape) {
        INDArray result = new JCublasNDArray(data, shape);
        return result;
    }

    /**
     * Creates a complex array of the given shape over {@code data}.
     *
     * @param data  backing buffer
     * @param shape array shape
     * @return the complex array
     */
    public IComplexNDArray createComplex(DataBuffer data, int[] shape) {
        IComplexNDArray result = new JCublasComplexNDArray(data, shape);
        return result;
    }

    /**
     * Creates a complex array over {@code data} with explicit shape and strides.
     *
     * @param data   backing buffer
     * @param shape  array shape
     * @param stride strides
     * @return the complex array
     */
    public IComplexNDArray createComplex(DataBuffer data, int[] shape, int[] stride) {
        IComplexNDArray result = new JCublasComplexNDArray(data, shape, stride);
        return result;
    }

    /**
     * Creates an array over {@code data} with explicit shape, strides and offset.
     *
     * @param data   backing buffer
     * @param shape  array shape
     * @param stride strides
     * @param offset offset into the buffer
     * @return the new array
     */
    public INDArray create(DataBuffer data, int[] shape, int[] stride, int offset) {
        INDArray result = new JCublasNDArray(data, shape, stride, offset);
        return result;
    }

    /**
     * Creates an array of the given shape from a list of arrays. When this factory's
     * {@code order} field is {@code 'f'}, Fortran strides are computed explicitly; otherwise
     * the constructor's defaults are used.
     *
     * @param list  source arrays
     * @param shape resulting shape
     * @return the new array
     */
    public INDArray create(List<INDArray> list, int[] shape) {
        boolean fortranOrder = order == 'f';
        if (!fortranOrder) {
            return new JCublasNDArray(list, shape);
        }
        int[] fortranStrides = ArrayUtil.calcStridesFortran(shape);
        return new JCublasNDArray(list, shape, fortranStrides);
    }

    /**
     * Creates a complex array from double data (narrowed to a float copy) with explicit shape,
     * strides, offset and ordering.
     *
     * @param data     double values
     * @param shape    array shape
     * @param stride   strides
     * @param offset   offset
     * @param ordering array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(double[] data, int[] shape, int[] stride, int offset, char ordering) {
        float[] asFloats = ArrayUtil.floatCopyOf(data);
        return new JCublasComplexNDArray(asFloats, shape, stride, offset, ordering);
    }

    /**
     * Creates a complex array from double data (narrowed to a float copy) with shape, offset
     * and ordering.
     *
     * @param data     double values
     * @param shape    array shape
     * @param offset   offset
     * @param ordering array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(double[] data, int[] shape, int offset, char ordering) {
        float[] asFloats = ArrayUtil.floatCopyOf(data);
        return new JCublasComplexNDArray(asFloats, shape, offset, ordering);
    }

    /**
     * Creates a complex array over {@code buffer} with shape, offset and ordering.
     *
     * @param buffer   backing buffer
     * @param shape    array shape
     * @param offset   offset into the buffer
     * @param ordering array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(DataBuffer buffer, int[] shape, int offset, char ordering) {
        IComplexNDArray result = new JCublasComplexNDArray(buffer, shape, offset, ordering);
        return result;
    }

    /**
     * Creates a complex array from double data (narrowed to a float copy) with shape and offset.
     *
     * @param data   double values
     * @param shape  array shape
     * @param offset offset
     * @return the complex array
     */
    public IComplexNDArray createComplex(double[] data, int[] shape, int offset) {
        float[] asFloats = ArrayUtil.floatCopyOf(data);
        return new JCublasComplexNDArray(asFloats, shape, offset);
    }

    /**
     * Creates a complex array over {@code buffer} with shape and offset.
     *
     * @param buffer backing buffer
     * @param shape  array shape
     * @param offset offset into the buffer
     * @return the complex array
     */
    public IComplexNDArray createComplex(DataBuffer buffer, int[] shape, int offset) {
        IComplexNDArray result = new JCublasComplexNDArray(buffer, shape, offset);
        return result;
    }

    /**
     * Creates an array from float data with shape and offset.
     *
     * @param data   float values
     * @param shape  array shape
     * @param offset offset into the data
     * @return the new array
     */
    public INDArray create(float[] data, int[] shape, int offset) {
        INDArray result = new JCublasNDArray(data, shape, offset);
        return result;
    }

    /**
     * Creates a complex array from float data with complex strides computed for {@code ordering}.
     *
     * @param data     float values
     * @param shape    array shape
     * @param offset   offset
     * @param ordering array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(float[] data, int[] shape, int offset, char ordering) {
        int[] complexStrides = Nd4j.getComplexStrides(shape, ordering);
        return new JCublasComplexNDArray(data, shape, complexStrides, offset, ordering);
    }

    /**
     * Creates a complex array from float data with shape and offset.
     *
     * @param data   float values
     * @param shape  array shape
     * @param offset offset
     * @return the complex array
     */
    public IComplexNDArray createComplex(float[] data, int[] shape, int offset) {
        IComplexNDArray result = new JCublasComplexNDArray(data, shape, offset);
        return result;
    }

    /**
     * Creates a complex array from float data with explicit shape, strides, offset and ordering.
     *
     * @param data     float values
     * @param shape    array shape
     * @param stride   strides
     * @param offset   offset
     * @param ordering array ordering
     * @return the complex array
     */
    public IComplexNDArray createComplex(float[] data, int[] shape, int[] stride, int offset, char ordering) {
        IComplexNDArray result = new JCublasComplexNDArray(data, shape, stride, offset, ordering);
        return result;
    }

    /**
     * Creates a {@link JCublasNDArray} from a 2D Java float array.
     *
     * @param floats source values
     * @return the new array
     */
    public INDArray create(float[][] floats) {
        INDArray result = new JCublasNDArray(floats);
        return result;
    }

    /**
     * Creates a {@link JCublasNDArray} from a 2D Java float array with an explicit ordering.
     *
     * @param data     source values
     * @param ordering array ordering
     * @return the new array
     */
    public INDArray create(float[][] data, char ordering) {
        INDArray result = new JCublasNDArray(data, ordering);
        return result;
    }

    /**
     * Builds a complex array from interleaved {@code (real, imaginary)} float pairs.
     * Element {@code k} of the result is {@code dim[2k] + i*dim[2k+1]}.
     *
     * @param dim interleaved real/imaginary values; its length must be even
     * @return a complex array holding {@code dim.length / 2} elements
     * @throws IllegalArgumentException if {@code dim.length} is odd
     */
    public IComplexNDArray createComplex(float[] dim) {
        int totalValues = dim.length;
        if (totalValues % 2 != 0) {
            throw new IllegalArgumentException("Complex nd array buffers must have an even number of elements");
        }
        int complexCount = totalValues / 2;
        IComplexNDArray result = Nd4j.createComplex(complexCount);
        for (int k = 0; k < complexCount; k++) {
            double re = dim[2 * k];
            double im = dim[2 * k + 1];
            result.putScalar(k, (IComplexNumber) Nd4j.createDouble(re, im));
        }
        return result;
    }

    /**
     * Creates an array from float data with explicit shape, strides, offset and ordering.
     *
     * @param data     float values
     * @param shape    array shape
     * @param stride   strides
     * @param offset   offset into the data
     * @param ordering array ordering
     * @return the new array
     */
    public INDArray create(float[] data, int[] shape, int[] stride, int offset, char ordering) {
        INDArray result = new JCublasNDArray(data, shape, stride, offset, ordering);
        return result;
    }

    /**
     * Creates an array over {@code buffer} with shape and offset.
     *
     * @param buffer backing buffer
     * @param shape  array shape
     * @param offset offset into the buffer
     * @return the new array
     */
    public INDArray create(DataBuffer buffer, int[] shape, int offset) {
        INDArray result = new JCublasNDArray(buffer, shape, offset);
        return result;
    }

    /**
     * Flattens {@code matrices} into one row array using this factory's {@code order()}.
     *
     * @param matrices arrays to flatten
     * @return the flattened row array
     */
    public INDArray toFlattened(Collection<INDArray> matrices) {
        char ordering = order();
        return toFlattened(ordering, matrices);
    }

    /**
     * Copies every array in {@code matrices}, in iteration order, into a single newly created
     * {@code [1, totalLength]} array with the given ordering.
     *
     * <p>An input whose ordering equals {@code order} and whose element-wise stride is 1 (as is
     * the output's) is copied with one {@code memcpyAsync}. Every other input goes through the
     * native {@code flattenDouble}/{@code flattenFloat}/{@code flattenHalf} call for its data type.
     *
     * @param order    ordering of the output array; also passed to the native flatten call
     * @param matrices the arrays to flatten
     * @return the flattened array
     */
    public INDArray toFlattened(char order, Collection<INDArray> matrices) {
        // Flush queued grid ops first (presumably so pending work on the inputs is executed).
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        // Total number of elements across all inputs = length of the output row.
        int length = 0;
        for (INDArray m : matrices) {
            length += m.length();
        }
        INDArray ret = Nd4j.create((int[])new int[]{1, length}, (char)order);
        // Element index in ret at which the next input is written.
        int linearIndex = 0;
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        for (INDArray m : matrices) {
            CudaContext context = allocator.getFlowController().prepareAction(ret, m);
            if (m.ordering() == order && ret.elementWiseStride() == m.elementWiseStride() && ret.elementWiseStride() == 1) {
                // Fast path: contiguous copy. The destination offset is in bytes: element index
                // times element size (8 for DOUBLE, 4 for FLOAT, 2 otherwise, i.e. the HALF path).
                // NOTE(review): the source is m's *host* pointer; confirm the host copy is current
                // when m was last written on the device.
                allocator.memcpyAsync(ret.data(), new CudaPointer(allocator.getHostPointer(m).address()), AllocationUtils.getRequiredMemory(AllocationUtils.buildAllocationShape(m)), linearIndex * (m.data().dataType() == DataBuffer.Type.DOUBLE ? 8 : (m.data().dataType() == DataBuffer.Type.FLOAT ? 4 : 2)));
                linearIndex += m.length();
            } else {
                Pointer hostYShapeInfo = AddressRetriever.retrieveHostPointer(m.shapeInfoDataBuffer());
                // Extra native arguments: ret's host shape info, the context stream, device id,
                // the context's scratch buffers, then m's and ret's host shape info.
                PointerPointer extras = new PointerPointer(new Pointer[]{AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer())});
                // Dispatch on the input's data type; anything not DOUBLE/FLOAT is treated as HALF.
                if (m.data().dataType() == DataBuffer.Type.DOUBLE) {
                    this.nativeOps.flattenDouble(extras, linearIndex, order, (DoublePointer)allocator.getPointer(ret, context), (IntPointer)allocator.getPointer(ret.shapeInfoDataBuffer(), context), (DoublePointer)allocator.getPointer(m, context), (IntPointer)allocator.getPointer(m.shapeInfoDataBuffer(), context));
                } else if (m.data().dataType() == DataBuffer.Type.FLOAT) {
                    this.nativeOps.flattenFloat(extras, linearIndex, order, (FloatPointer)allocator.getPointer(ret, context), (IntPointer)allocator.getPointer(ret.shapeInfoDataBuffer(), context), (FloatPointer)allocator.getPointer(m, context), (IntPointer)allocator.getPointer(m.shapeInfoDataBuffer(), context));
                } else {
                    this.nativeOps.flattenHalf(extras, linearIndex, order, (ShortPointer)allocator.getPointer(ret, context), (IntPointer)allocator.getPointer(ret.shapeInfoDataBuffer(), context), (ShortPointer)allocator.getPointer(m, context), (IntPointer)allocator.getPointer(m.shapeInfoDataBuffer(), context));
                }
                linearIndex += m.length();
            }
            // NOTE(review): ret is assigned from Nd4j.create above and never reassigned, so this
            // null check appears to be dead code (likely a decompiler artifact).
            if (ret == null) continue;
            allocator.registerAction(context, ret, m);
        }
        return ret;
    }

    /**
     * Concatenates the given arrays along {@code dimension} into a newly allocated array, using
     * the native concat call that matches the output's data type.
     *
     * <p>All inputs must have the same rank and agree in every dimension except
     * {@code dimension}. Compressed inputs are decompressed in place first. A single input is
     * returned as-is, without a copy.
     *
     * <p>All shape validation now happens before the output is allocated and before a CUDA
     * context is prepared. An invalid input therefore no longer leaves an action prepared
     * without a matching {@code registerAction}.
     *
     * @param dimension dimension to concatenate along
     * @param toConcat  arrays to concatenate; must be non-empty
     * @return the concatenated array, or {@code toConcat[0]} if exactly one array was given
     * @throws IllegalArgumentException if no arrays are given, ranks differ, or shapes disagree
     *                                  outside {@code dimension}
     */
    public INDArray concat(int dimension, INDArray ... toConcat) {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        if (toConcat == null || toConcat.length == 0) {
            throw new IllegalArgumentException("No arrays provided for concatenation");
        }
        if (toConcat.length == 1) {
            return toConcat[0];
        }
        int sumAlongDim = 0;
        for (int i = 0; i < toConcat.length; ++i) {
            if (toConcat[i].isCompressed()) {
                Nd4j.getCompressor().decompressi(toConcat[i]);
            }
            sumAlongDim += toConcat[i].size(dimension);
        }
        int[] outputShape = ArrayUtil.copy((int[])toConcat[0].shape());
        outputShape[dimension] = sumAlongDim;
        // Validate every input against the expected output shape before touching device state.
        for (int i = 0; i < toConcat.length; ++i) {
            if (toConcat[i].rank() != outputShape.length) {
                throw new IllegalArgumentException("Illegal concatenation at array " + i + ": rank " + toConcat[i].rank() + " does not match expected rank " + outputShape.length);
            }
            for (int j = 0; j < outputShape.length; ++j) {
                if (j == dimension || toConcat[i].size(j) == outputShape[j]) continue;
                throw new IllegalArgumentException("Illegal concatenation at array " + i + " and shape element " + j);
            }
        }
        INDArray ret = Nd4j.createUninitialized((int[])outputShape, (char)Nd4j.order().charValue());
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(ret, toConcat);
        long[] shapeInfoPointers = new long[toConcat.length];
        long[] dataPointers = new long[toConcat.length];
        long[] tadPointers = new long[toConcat.length];
        long[] offsetsPointers = new long[toConcat.length];
        long[] hostShapeInfoPointers = new long[toConcat.length];
        TADManager tadManager = Nd4j.getExecutioner().getTADManager();
        for (int i = 0; i < toConcat.length; ++i) {
            shapeInfoPointers[i] = AddressRetriever.retrieveDeviceAddress(toConcat[i].shapeInfoDataBuffer(), context);
            dataPointers[i] = AtomicAllocator.getInstance().getPointer(toConcat[i], context).address();
            hostShapeInfoPointers[i] = AtomicAllocator.getInstance().getHostPointer(toConcat[i].shapeInfoDataBuffer()).address();
            Pair tadBuffers = tadManager.getTADOnlyShapeInfo(toConcat[i], new int[]{dimension});
            long devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context).address();
            DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
            long devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context).address();
            tadPointers[i] = devTadShapeInfo;
            offsetsPointers[i] = devTadOffsets;
        }
        Pair zBuffers = tadManager.getTADOnlyShapeInfo(ret, new int[]{dimension});
        Pointer dZ = AtomicAllocator.getInstance().getPointer(ret, context);
        Pointer dZShapeInfo = AddressRetriever.retrieveDevicePointer(ret.shapeInfoDataBuffer(), context);
        // 8-byte-element buffers used only to ship arrays of 64-bit addresses to the device.
        CudaDoubleDataBuffer tempData = new CudaDoubleDataBuffer(toConcat.length);
        CudaDoubleDataBuffer tempShapes = new CudaDoubleDataBuffer(toConcat.length);
        CudaDoubleDataBuffer tempTAD = new CudaDoubleDataBuffer(toConcat.length);
        CudaDoubleDataBuffer tempOffsets = new CudaDoubleDataBuffer(toConcat.length);
        AtomicAllocator.getInstance().memcpyBlocking(tempData, (Pointer)new LongPointer(dataPointers), dataPointers.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(tempShapes, (Pointer)new LongPointer(shapeInfoPointers), shapeInfoPointers.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(tempTAD, (Pointer)new LongPointer(tadPointers), tadPointers.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(tempOffsets, (Pointer)new LongPointer(offsetsPointers), offsetsPointers.length * 8, 0L);
        Pointer dataPointer = AtomicAllocator.getInstance().getPointer(tempData, context);
        Pointer shapesPointer = AtomicAllocator.getInstance().getPointer(tempShapes, context);
        Pointer tadPointer = AtomicAllocator.getInstance().getPointer(tempTAD, context);
        Pointer offsetPointer = AtomicAllocator.getInstance().getPointer(tempOffsets, context);
        PointerPointer extras = new PointerPointer(new Pointer[]{AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), AddressRetriever.retrieveHostPointer(toConcat[0].shapeInfoDataBuffer()), AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()), new LongPointer(hostShapeInfoPointers), AtomicAllocator.getInstance().getPointer((DataBuffer)zBuffers.getFirst(), context), AtomicAllocator.getInstance().getPointer((DataBuffer)zBuffers.getSecond(), context)});
        if (ret.data().dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.concatDouble(extras, dimension, toConcat.length, new PointerPointer(new Pointer[]{dataPointer}), new PointerPointer(new Pointer[]{shapesPointer}), (DoublePointer)dZ, (IntPointer)dZShapeInfo, new PointerPointer(new Pointer[]{tadPointer}), new PointerPointer(new Pointer[]{offsetPointer}));
        } else if (ret.data().dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.concatFloat(extras, dimension, toConcat.length, new PointerPointer(new Pointer[]{dataPointer}), new PointerPointer(new Pointer[]{shapesPointer}), (FloatPointer)dZ, (IntPointer)dZShapeInfo, new PointerPointer(new Pointer[]{tadPointer}), new PointerPointer(new Pointer[]{offsetPointer}));
        } else {
            this.nativeOps.concatHalf(extras, dimension, toConcat.length, new PointerPointer(new Pointer[]{dataPointer}), new PointerPointer(new Pointer[]{shapesPointer}), (ShortPointer)dZ, (IntPointer)dZShapeInfo, new PointerPointer(new Pointer[]{tadPointer}), new PointerPointer(new Pointer[]{offsetPointer}));
        }
        allocator.registerAction(context, ret, toConcat);
        return ret;
    }

    /**
     * Gathers TADs of {@code source} using the global default ordering for the result.
     *
     * @param source          2D source array
     * @param sourceDimension dimension defining the TADs (0 or 1)
     * @param indexes         TAD indexes to gather
     * @return the gathered array
     */
    public INDArray pullRows(INDArray source, int sourceDimension, int[] indexes) {
        char defaultOrder = Nd4j.order().charValue();
        return pullRows(source, sourceDimension, indexes, defaultOrder);
    }

    /**
     * Gathers the TADs of a 2D {@code source} selected by {@code indexes} into a new array, using
     * the native pullRows call that matches the output's data type.
     *
     * <p>With {@code sourceDimension == 1} the result has shape
     * {@code [indexes.length, source.shape()[1]]} (row gather). With {@code sourceDimension == 0}
     * it has shape {@code [source.shape()[0], indexes.length]} (column gather).
     *
     * <p>Indexes are now bounds-checked on the host. Previously an out-of-range index was passed
     * straight to the native gather and read outside {@code source}.
     *
     * @param source          2D source array
     * @param sourceDimension 0 or 1
     * @param indexes         TAD indexes to gather; each must be in {@code [0, numTads)}
     * @param order           ordering of the result
     * @return the gathered array
     * @throws IllegalStateException         if {@code indexes} is null or empty
     * @throws UnsupportedOperationException if {@code source} is not 2D or
     *                                       {@code sourceDimension} is not 0 or 1
     * @throws IllegalArgumentException      if any index is out of range
     */
    public INDArray pullRows(INDArray source, int sourceDimension, int[] indexes, char order) {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        if (indexes == null || indexes.length < 1) {
            throw new IllegalStateException("Indexes can't be null or zero-length");
        }
        if (source.rank() != 2) {
            throw new UnsupportedOperationException("2D input is expected");
        }
        int[] shape = null;
        if (sourceDimension == 1) {
            shape = new int[]{indexes.length, source.shape()[sourceDimension]};
        } else if (sourceDimension == 0) {
            shape = new int[]{source.shape()[sourceDimension], indexes.length};
        } else {
            throw new UnsupportedOperationException("2D input is expected");
        }
        // Number of selectable TADs: the extent of the *other* dimension (rows for dim 1,
        // columns for dim 0).
        int numTads = source.shape()[1 - sourceDimension];
        for (int i = 0; i < indexes.length; ++i) {
            if (indexes[i] < 0 || indexes[i] >= numTads) {
                throw new IllegalArgumentException("Index " + indexes[i] + " at position " + i + " is out of bounds for " + numTads + " TADs along dimension " + sourceDimension);
            }
        }
        INDArray ret = Nd4j.createUninitialized((int[])shape, (char)order);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(ret, source);
        Pointer x = AtomicAllocator.getInstance().getPointer(source, context);
        Pointer xShape = AtomicAllocator.getInstance().getPointer(source.shapeInfoDataBuffer(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(ret, context);
        Pointer zShape = AtomicAllocator.getInstance().getPointer(ret.shapeInfoDataBuffer(), context);
        PointerPointer extras = new PointerPointer(new Pointer[]{AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer()});
        CudaIntDataBuffer tempIndexes = new CudaIntDataBuffer(indexes.length);
        AtomicAllocator.getInstance().memcpyBlocking(tempIndexes, (Pointer)new IntPointer(indexes), indexes.length * 4, 0L);
        Pointer pIndex = AtomicAllocator.getInstance().getPointer(tempIndexes, context);
        TADManager tadManager = Nd4j.getExecutioner().getTADManager();
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(source, new int[]{sourceDimension});
        Pair zTadBuffers = tadManager.getTADOnlyShapeInfo(ret, new int[]{sourceDimension});
        Pointer tadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        Pointer zTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)zTadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer tadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer zTadOffsets = AtomicAllocator.getInstance().getPointer((DataBuffer)zTadBuffers.getSecond(), context);
        if (ret.data().dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.pullRowsDouble(extras, (DoublePointer)x, (IntPointer)xShape, (DoublePointer)z, (IntPointer)zShape, indexes.length, (IntPointer)pIndex, (IntPointer)tadShapeInfo, (IntPointer)tadOffsets, (IntPointer)zTadShapeInfo, (IntPointer)zTadOffsets);
        } else if (ret.data().dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.pullRowsFloat(extras, (FloatPointer)x, (IntPointer)xShape, (FloatPointer)z, (IntPointer)zShape, indexes.length, (IntPointer)pIndex, (IntPointer)tadShapeInfo, (IntPointer)tadOffsets, (IntPointer)zTadShapeInfo, (IntPointer)zTadOffsets);
        } else {
            this.nativeOps.pullRowsHalf(extras, (ShortPointer)x, (IntPointer)xShape, (ShortPointer)z, (IntPointer)zShape, indexes.length, (IntPointer)pIndex, (IntPointer)tadShapeInfo, (IntPointer)tadOffsets, (IntPointer)zTadShapeInfo, (IntPointer)zTadOffsets);
        }
        allocator.registerAction(context, ret, source);
        return ret;
    }

    /**
     * Averages {@code arrays} element-wise into {@code target} with the native
     * {@code averageDouble}/{@code averageFloat}/{@code averageHalf} call selected by
     * {@code target}'s data type.
     *
     * <p>With a single input, {@code target.assign(arrays[0])} is returned instead.
     *
     * @param target output array; its length defines the required input length
     * @param arrays inputs; every one must have the same length as {@code target}
     * @return {@code target}
     * @throws RuntimeException if {@code arrays} is null/empty or any length differs
     */
    public INDArray average(INDArray target, INDArray[] arrays) {
        if (arrays == null || arrays.length == 0) {
            throw new RuntimeException("Input arrays are missing");
        }
        if (arrays.length == 1) {
            return target.assign(arrays[0]);
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        long len = target.lengthLong();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(target, arrays);
        // Extra native arguments: no host shape info (null), the context stream, device id.
        PointerPointer extras = new PointerPointer(new Pointer[]{null, context.getOldStream(), allocator.getDeviceIdPointer()});
        Pointer z = AtomicAllocator.getInstance().getPointer(target, context);
        long[] xPointers = new long[arrays.length];
        for (int i = 0; i < arrays.length; ++i) {
            if (arrays[i].lengthLong() != len) {
                throw new RuntimeException("All arrays should have equal length for averaging");
            }
            AllocationPoint point = allocator.getAllocationPoint(arrays[i]);
            // NOTE(review): this takes the allocation's base device pointer, not
            // allocator.getPointer(arrays[i], context); for views with a non-zero offset this
            // may address the wrong first element — confirm.
            xPointers[i] = point.getPointers().getDevicePointer().address();
            // Inputs are marked device-written, presumably because the native call (last arg
            // `true`) writes the average back into each input — TODO confirm.
            point.tickDeviceWrite();
        }
        // 8-byte-element buffer used only to ship the 64-bit input addresses to the device.
        CudaDoubleDataBuffer tempX = new CudaDoubleDataBuffer(arrays.length);
        allocator.memcpyBlocking(tempX, (Pointer)new LongPointer(xPointers), xPointers.length * 8, 0L);
        PointerPointer x = new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context));
        if (target.data().dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.averageDouble(extras, x, (DoublePointer)z, arrays.length, len, true);
        } else if (target.data().dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.averageFloat(extras, x, (FloatPointer)z, arrays.length, len, true);
        } else {
            this.nativeOps.averageHalf(extras, x, (ShortPointer)z, arrays.length, len, true);
        }
        allocator.getFlowController().registerAction(context, target, arrays);
        // Touches tempX after the native call, presumably to keep it reachable (not GC'd)
        // until the native work has been issued — confirm.
        tempX.address();
        return target;
    }

    /**
     * Averages the given arrays into a newly created array (see {@code average(INDArray[])}).
     *
     * @param arrays inputs of equal length
     * @return the averaged array
     */
    public INDArray average(Collection<INDArray> arrays) {
        INDArray[] asArray = arrays.toArray(new INDArray[arrays.size()]);
        return average(asArray);
    }

    /**
     * Averages the given arrays into a new uninitialized array shaped and ordered like the
     * first input.
     *
     * @param arrays inputs of equal length
     * @return the averaged array
     * @throws RuntimeException if {@code arrays} is null or empty
     */
    public INDArray average(INDArray[] arrays) {
        boolean noInputs = arrays == null || arrays.length == 0;
        if (noInputs) {
            throw new RuntimeException("Input arrays are missing");
        }
        INDArray first = arrays[0];
        INDArray target = Nd4j.createUninitialized(first.shape(), first.ordering());
        return average(target, arrays);
    }

    /**
     * Averages the given arrays into {@code target}.
     *
     * @param target output array
     * @param arrays inputs of equal length
     * @return {@code target}
     */
    public INDArray average(INDArray target, Collection<INDArray> arrays) {
        INDArray[] asArray = arrays.toArray(new INDArray[arrays.size()]);
        return average(target, asArray);
    }

    /**
     * Shuffles the TADs of a single array along {@code dimension}.
     *
     * @param array     array to shuffle in place
     * @param rnd       random source for the permutation
     * @param dimension TAD dimensions
     */
    public void shuffle(INDArray array, Random rnd, int ... dimension) {
        Collection<INDArray> single = Collections.singletonList(array);
        shuffle(single, rnd, dimension);
    }

    /**
     * Shuffles TADs of each array in place with one shared random permutation, using the native
     * {@code shuffleDouble}/{@code shuffleFloat}/{@code shuffleHalf} call.
     *
     * <p>The permutation ({@code ArrayUtil.buildInterleavedVector}) and the TAD count come from
     * the first array and the first dimension set. If {@code dimensions} has more than one
     * entry, array {@code i} uses {@code dimensions.get(i)}; otherwise all arrays use
     * {@code dimensions.get(0)}.
     *
     * @param arrays     arrays to shuffle; must be non-empty
     * @param rnd        random source for the permutation
     * @param dimensions one dimension set shared by all arrays, or one per array
     * @throws RuntimeException      if {@code arrays} or {@code dimensions} is null/empty
     * @throws IllegalStateException if several dimension sets are given but their count differs
     *                               from the number of arrays
     */
    public void shuffle(List<INDArray> arrays, Random rnd, List<int[]> dimensions) {
        if (dimensions == null || dimensions.size() == 0) {
            throw new RuntimeException("Dimension can't be null or 0-length");
        }
        if (arrays == null || arrays.size() == 0) {
            throw new RuntimeException("No input arrays provided");
        }
        if (dimensions.size() > 1 && arrays.size() != dimensions.size()) {
            throw new IllegalStateException("Number of dimensions do not match number of arrays to shuffle");
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = null;
        // NOTE(review): prepareAction is called per array but only the last returned context is
        // kept and used for every pointer lookup and registerAction below — confirm intended.
        for (int x = 0; x < arrays.size(); ++x) {
            context = allocator.getFlowController().prepareAction(arrays.get(x), new INDArray[0]);
        }
        // TAD length and count are derived from the first array and first dimension set only.
        int tadLength = 1;
        for (int i = 0; i < dimensions.get(0).length; ++i) {
            tadLength *= arrays.get(0).shape()[dimensions.get(0)[i]];
        }
        int numTads = arrays.get(0).length() / tadLength;
        int[] map = ArrayUtil.buildInterleavedVector((Random)rnd, (int)numTads);
        CudaIntDataBuffer shuffle = new CudaIntDataBuffer(map);
        Pointer shuffleMap = allocator.getPointer(shuffle, context);
        // Extra native arguments: no host shape info (null), the context stream, device id.
        PointerPointer extras = new PointerPointer(new Pointer[]{null, context.getOldStream(), allocator.getDeviceIdPointer()});
        long[] xPointers = new long[arrays.size()];
        long[] xShapes = new long[arrays.size()];
        long[] tadShapes = new long[arrays.size()];
        long[] tadOffsets = new long[arrays.size()];
        // Collect device addresses of each array's data, shape info, TAD shape info and offsets.
        for (int i = 0; i < arrays.size(); ++i) {
            INDArray array = arrays.get(i);
            Pointer x = AtomicAllocator.getInstance().getPointer(array, context);
            Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer(), context);
            TADManager tadManager = Nd4j.getExecutioner().getTADManager();
            int[] dimension = dimensions.size() > 1 ? dimensions.get(i) : dimensions.get(0);
            Pair tadBuffers = tadManager.getTADOnlyShapeInfo(array, dimension);
            Pointer tadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
            DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
            Pointer tadOffset = AtomicAllocator.getInstance().getPointer(offsets, context);
            xPointers[i] = x.address();
            xShapes[i] = xShapeInfo.address();
            tadShapes[i] = tadShapeInfo.address();
            tadOffsets[i] = tadOffset.address();
        }
        // 8-byte-element buffers used only to ship the 64-bit address arrays to the device.
        CudaDoubleDataBuffer tempX = new CudaDoubleDataBuffer(arrays.size());
        CudaDoubleDataBuffer tempShapes = new CudaDoubleDataBuffer(arrays.size());
        CudaDoubleDataBuffer tempTAD = new CudaDoubleDataBuffer(arrays.size());
        CudaDoubleDataBuffer tempOffsets = new CudaDoubleDataBuffer(arrays.size());
        AtomicAllocator.getInstance().memcpyBlocking(tempX, (Pointer)new LongPointer(xPointers), xPointers.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(tempShapes, (Pointer)new LongPointer(xShapes), xPointers.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(tempTAD, (Pointer)new LongPointer(tadShapes), xPointers.length * 8, 0L);
        AtomicAllocator.getInstance().memcpyBlocking(tempOffsets, (Pointer)new LongPointer(tadOffsets), xPointers.length * 8, 0L);
        // The same buffers are passed as both input and output, so the shuffle is in place.
        // NOTE(review): dispatch uses the global Nd4j.dataType(), not the arrays' own data type
        // (the other native paths in this class use ret/target.data().dataType()).
        if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            this.nativeOps.shuffleDouble(extras, new PointerPointer(allocator.getPointer(tempX, context)), new PointerPointer(allocator.getPointer(tempShapes, context)), new PointerPointer(allocator.getPointer(tempX, context)), new PointerPointer(allocator.getPointer(tempShapes, context)), arrays.size(), (IntPointer)shuffleMap, new PointerPointer(allocator.getPointer(tempTAD, context)), new PointerPointer(allocator.getPointer(tempOffsets, context)));
        } else if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            this.nativeOps.shuffleFloat(extras, new PointerPointer(allocator.getPointer(tempX, context)), new PointerPointer(allocator.getPointer(tempShapes, context)), new PointerPointer(allocator.getPointer(tempX, context)), new PointerPointer(allocator.getPointer(tempShapes, context)), arrays.size(), (IntPointer)shuffleMap, new PointerPointer(allocator.getPointer(tempTAD, context)), new PointerPointer(allocator.getPointer(tempOffsets, context)));
        } else {
            this.nativeOps.shuffleHalf(extras, new PointerPointer(allocator.getPointer(tempX, context)), new PointerPointer(allocator.getPointer(tempShapes, context)), new PointerPointer(allocator.getPointer(tempX, context)), new PointerPointer(allocator.getPointer(tempShapes, context)), arrays.size(), (IntPointer)shuffleMap, new PointerPointer(allocator.getPointer(tempTAD, context)), new PointerPointer(allocator.getPointer(tempOffsets, context)));
        }
        for (int f = 0; f < arrays.size(); ++f) {
            allocator.getFlowController().registerAction(context, arrays.get(f), new INDArray[0]);
        }
        // No-op reads after the native call, presumably to keep these temporary buffers
        // reachable (not GC'd) until the native work has been issued — confirm.
        shuffle.address();
        tempX.dataType();
        tempShapes.dataType();
        tempOffsets.dataType();
        tempTAD.dataType();
    }

    /**
     * Collection-based convenience overload of {@code shuffle}.
     *
     * <p>Copies {@code sourceArrays} into a list and wraps the single {@code dimension}
     * array as a one-element dimension list, then delegates to the list-based overload.
     *
     * @param sourceArrays the arrays to shuffle
     * @param rnd the random source forwarded to the delegate
     * @param dimension the dimension(s) used for every array
     */
    public void shuffle(Collection<INDArray> sourceArrays, Random rnd, int ... dimension) {
        List<INDArray> arrays = new ArrayList<>(sourceArrays);
        List<int[]> dimensions = Collections.singletonList(dimension);
        this.shuffle(arrays, rnd, dimensions);
    }

    /**
     * Converts the data held by {@code source} from {@code typeSrc} to {@code typeDst} in place.
     *
     * <p>The array's buffer is replaced with the converted buffer. Its compressed flag is set
     * according to whether that new buffer is a {@link CompressedDataBuffer}.
     *
     * @param typeSrc the current encoding of the array's data
     * @param source the array to convert; must not be a view
     * @param typeDst the encoding to convert to
     * @return {@code source} itself, now backed by the converted buffer
     * @throws UnsupportedOperationException if {@code source} is a view
     */
    public INDArray convertDataEx(DataBuffer.TypeEx typeSrc, INDArray source, DataBuffer.TypeEx typeDst) {
        if (source.isView()) {
            throw new UnsupportedOperationException("Impossible to compress View. Consider using dup() before. ");
        }
        DataBuffer converted = this.convertDataEx(typeSrc, source.data(), typeDst);
        source.setData(converted);
        // The flag mirrors whether the replacement buffer holds compressed data.
        boolean isCompressed = converted instanceof CompressedDataBuffer;
        source.markAsCompressed(isCompressed);
        return source;
    }

    /**
     * Pointer-level conversion: forwards to the native {@code convertTypes} call.
     *
     * <p>Both encodings are passed as their {@code TypeEx} ordinals, and {@code null} is
     * supplied for the extra-pointers argument.
     *
     * @param typeSrc encoding of the data at {@code source}
     * @param source pointer to the input data
     * @param typeDst encoding to write at {@code target}
     * @param target pointer receiving the converted data
     * @param length value forwarded as the native length argument (callers in this file pass
     *     the target buffer's {@code length()})
     */
    public void convertDataEx(DataBuffer.TypeEx typeSrc, Pointer source, DataBuffer.TypeEx typeDst, Pointer target, long length) {
        int srcCode = typeSrc.ordinal();
        int dstCode = typeDst.ordinal();
        this.nativeOps.convertTypes(null, srcCode, source, length, dstCode, target);
    }

    /**
     * Buffer-level overload: converts {@code source} into {@code target} through their
     * address pointers.
     *
     * <p>The length handed to the pointer-level overload is {@code target.length()}.
     *
     * @param typeSrc encoding of {@code source}
     * @param source buffer holding the input data
     * @param typeDst encoding to produce in {@code target}
     * @param target pre-allocated buffer receiving the converted data
     */
    public void convertDataEx(DataBuffer.TypeEx typeSrc, DataBuffer source, DataBuffer.TypeEx typeDst, DataBuffer target) {
        Pointer from = source.addressPointer();
        Pointer to = target.addressPointer();
        long count = target.length();
        this.convertDataEx(typeSrc, from, typeDst, to, count);
    }

    /**
     * Converts {@code source} from {@code typeSrc} to {@code typeDst} into a newly allocated
     * buffer, which is returned.
     *
     * <p>The destination depends on the ordinal of {@code typeDst}:
     * <ul>
     *   <li>Ordinals 0-5 produce a lossy {@link CompressedDataBuffer} of
     *       {@code source.length() * elementSize} bytes.</li>
     *   <li>Ordinals 6 and 7 produce a regular buffer from {@code Nd4j.createBuffer} with
     *       {@code source.length()} elements.</li>
     * </ul>
     *
     * @param typeSrc encoding of {@code source}
     * @param source buffer to convert
     * @param typeDst target encoding
     * @return the freshly allocated, converted buffer
     * @throws UnsupportedOperationException if the ordinal of {@code typeDst} is above 7
     */
    public DataBuffer convertDataEx(DataBuffer.TypeEx typeSrc, DataBuffer source, DataBuffer.TypeEx typeDst) {
        // Bytes per element of the target encoding, selected by TypeEx ordinal.
        // NOTE(review): ordinal-based dispatch silently breaks if TypeEx constants are reordered.
        int elementSize;
        switch (typeDst.ordinal()) {
            case 0:
            case 1:
            case 2:
                elementSize = 1;
                break;
            case 3:
            case 4:
            case 5:
                elementSize = 2;
                break;
            case 6:
                elementSize = 4;
                break;
            case 7:
                elementSize = 8;
                break;
            default:
                throw new UnsupportedOperationException("Unknown target TypeEx: " + typeDst.name());
        }

        // Flush any queued grid operations (blocking) before the data is touched.
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }

        // Only non-compressed sources are host-synchronized here.
        if (!(source instanceof CompressedDataBuffer)) {
            AtomicAllocator.getInstance().synchronizeHostData(source);
        }

        DataBuffer buffer;
        if (typeDst.ordinal() >= 6) {
            // NOTE(review): createBuffer(length, false) presumably allocates in the global
            // Nd4j.dataType(); confirm its element width agrees with typeDst.
            buffer = Nd4j.createBuffer((long) source.length(), false);
            AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(buffer);
            point.tickHostWrite();
        } else {
            // Ordinals below 6 are treated as lossy compressed targets.
            long compressedLength = source.length() * (long) elementSize;
            BytePointer pointer = new BytePointer(compressedLength);
            CompressionDescriptor descriptor = new CompressionDescriptor(source, typeDst.name());
            descriptor.setCompressionType(CompressionType.LOSSY);
            descriptor.setCompressedLength(compressedLength);
            buffer = new CompressedDataBuffer((Pointer) pointer, descriptor);
        }

        this.convertDataEx(typeSrc, source, typeDst, buffer);
        return buffer;
    }
}

