/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Properties;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.CopyOp;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jBlas;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CudaExecutioner
extends DefaultOpExecutioner {
    /** Device-side native bindings, shared (static) by every instance of this executioner. */
    protected static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    private static final Logger log = LoggerFactory.getLogger(CudaExecutioner.class);
    /** Source of TAD shape-info/offset buffers via {@code getTADOnlyShapeInfo(array, dimension)}. */
    protected static TADManager tadManager = new DeviceTADManager();
    /**
     * Per-thread extra-pointer array handed to native calls. The exec/invoke methods create it
     * lazily with 32 slots and refill it with {@code put(...)} on each call; slots a given call
     * does not write keep their previous contents.
     */
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();
    protected volatile transient Properties properties;

    /**
     * Exposes the device-side native bindings used by this executioner.
     *
     * @return the static {@code nativeOps} handle shared by all instances
     */
    public NativeOps getNativeOps() {
        final NativeOps deviceOps = CudaExecutioner.nativeOps;
        return deviceOps;
    }

    /**
     * Executes a broadcast op on the device for the given dimensions of {@code op.x()}.
     *
     * <p>Side effect: {@code dimension} is sorted in place, so the caller's array is reordered.
     * Each entry must be below {@code op.x().rank()} or equal {@code Integer.MAX_VALUE}; negative
     * entries are neither rejected nor wrapped here.
     *
     * <p>{@code op.y()} must be non-null: its shape info is dereferenced unconditionally in the
     * native call, even though {@code hostYShapeInfo} is computed with a null check.
     *
     * @param op        broadcast op; the native call writes into {@code op.z()}
     * @param dimension dimensions handed to the TAD manager and to the native broadcast call
     * @return {@code op.z()}
     * @throws ND4JIllegalStateException if a dimension is {@code >=} the rank of {@code op.x()}
     */
    public INDArray exec(BroadcastOp op, int ... dimension) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        // Lazily create this thread's reusable 32-slot extra-pointer array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Sorts the caller's array in place.
        Arrays.sort(dimension);
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < op.x().rank() || dimension[i] == Integer.MAX_VALUE) continue;
            throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        // TAD shape info (host + device) and device offsets of x for the requested dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;
        // Matching device TAD buffers for z.
        Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getSecond(), context);
        // Extra-pointer slots, in order: 0 host x shapeInfo, 1 context.getOldStream(), 2 device id
        // pointer, 3-6 context allocation/reduction/scalar/special buffers, 7 host y shapeInfo,
        // 8 host z shapeInfo, 9 host x TAD shapeInfo, 10 device x TAD shapeInfo, 11 device x TAD
        // offsets, 12 device z TAD shapeInfo, 13 device z TAD offsets.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ});
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        // Dispatch on x's data type; any type other than DOUBLE/FLOAT takes the half-precision path.
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execBroadcastDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)y, (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (DoublePointer)z, (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execBroadcastFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)y, (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (FloatPointer)z, (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
        } else {
            nativeOps.execBroadcastHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)y, (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (ShortPointer)z, (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
        }
        // Register the completed action on this context for z, x and y.
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Runs a reduction on the device, writing into the already-allocated {@code op.z()}.
     *
     * <p>Unlike {@code exec(Accumulation, int...)}, this method neither sorts {@code dimension} nor
     * wraps negative entries. It only rejects entries that are {@code >= op.x().rank()}, other than
     * {@code Integer.MAX_VALUE}. {@code op.z()} must be non-null because {@code ret.isScalar()} is
     * evaluated unconditionally.
     *
     * <p>The native entry point is chosen first by x's data type: DOUBLE, FLOAT, or otherwise the
     * half path. It is then chosen by op kind:
     * <ul>
     *   <li>{@link Variance} uses the summary-stats calls with its bias-correction flag;</li>
     *   <li>an op with a non-null y uses the reduce3 calls;</li>
     *   <li>everything else uses plain reduce.</li>
     * </ul>
     * If z is a scalar, the scalar native variant's return value is stored with {@code putScalar}
     * and passed to {@code op.setFinalResult}; no {@code registerAction} is issued on that path.
     * Otherwise the native call receives z's device pointer and the action is registered.
     *
     * <p>This method calls the profiling hooks itself, so when invoked from
     * {@code exec(Accumulation, int...)} the hooks run nested.
     *
     * @param op        reduction to run; {@code op.z()} receives the result
     * @param dimension dimensions to reduce along (expected non-negative; not normalized here)
     * @return {@code op.z()}
     * @throws ND4JIllegalStateException if a dimension is {@code >=} the rank of {@code op.x()}
     */
    protected INDArray naiveExec(Accumulation op, int ... dimension) {
        long st = this.profilingHookIn((Op)op);
        INDArray ret = op.z();
        // Upper-bound check only; negative entries are not rejected or wrapped here.
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < op.x().rank() || dimension[i] == Integer.MAX_VALUE) continue;
            throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        // TAD buffers of x for the requested dimensions; the offsets buffer may be null.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        // Lazily create this thread's reusable 32-slot extra-pointer array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Extra-pointer slots 0-11, in order: host x shapeInfo, context.getOldStream(), device id
        // pointer, context allocation/reduction/scalar/special buffers, host y shapeInfo, host z
        // shapeInfo, host x TAD shapeInfo, device x TAD shapeInfo, device x TAD offsets.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        // Ops with a y operand also get y's device TAD buffers in slots 12 and 13.
        // NOTE(review): when y is null, slots 12/13 of this reused thread-local array are not
        // written and keep whatever an earlier call left there. Confirm the native side ignores them.
        if (op.y() != null) {
            Pair yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);
            Pointer yDevTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)yTadBuffers.getFirst(), context);
            DataBuffer yOffsets = (DataBuffer)yTadBuffers.getSecond();
            Pointer yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context);
            xShapeInfoHostPointer.put(12L, yDevTadShapeInfo);
            xShapeInfoHostPointer.put(13L, yDevTadOffsets);
        }
        // Null when the op carries no extra arguments.
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        // In each branch below, the scalar path marks z as host-written, stores the native return
        // value with putScalar and records it as the final result. The non-scalar path passes z's
        // device pointer and registers the action.
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (op instanceof Variance) {
                if (ret.isScalar()) {
                    AtomicAllocator.getInstance().tickHostWrite(ret);
                    ret.putScalar(0, nativeOps.execSummaryStatsScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, ((Variance)op).isBiasCorrected()));
                    op.setFinalResult((Number)ret.getDouble(0));
                } else {
                    nativeOps.execSummaryStatsDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length, ((Variance)op).isBiasCorrected());
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                }
            } else if (op.y() != null) {
                if (ret.isScalar()) {
                    AtomicAllocator.getInstance().tickHostWrite(ret);
                    ret.putScalar(0, nativeOps.execReduce3ScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context)));
                    op.setFinalResult((Number)ret.getDouble(0));
                } else {
                    nativeOps.execReduce3Double(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (DoublePointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                }
            } else if (ret.isScalar()) {
                AtomicAllocator.getInstance().tickHostWrite(ret);
                ret.putScalar(0, nativeOps.execReduceScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs));
                op.setFinalResult((Number)ret.getDouble(0));
            } else {
                nativeOps.execReduceDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            if (op instanceof Variance) {
                if (ret.isScalar()) {
                    AtomicAllocator.getInstance().tickHostWrite(ret);
                    ret.putScalar(0, nativeOps.execSummaryStatsScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, ((Variance)op).isBiasCorrected()));
                    op.setFinalResult((Number)Float.valueOf(ret.getFloat(0)));
                } else {
                    nativeOps.execSummaryStatsFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length, ((Variance)op).isBiasCorrected());
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                }
            } else if (op.y() != null) {
                if (ret.isScalar()) {
                    AtomicAllocator.getInstance().tickHostWrite(ret);
                    ret.putScalar(0, nativeOps.execReduce3ScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context)));
                    op.setFinalResult((Number)Float.valueOf(ret.getFloat(0)));
                } else {
                    nativeOps.execReduce3Float(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (FloatPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                }
            } else if (ret.isScalar()) {
                AtomicAllocator.getInstance().tickHostWrite(ret);
                float resx = nativeOps.execReduceScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs);
                ret.putScalar(0, resx);
                op.setFinalResult((Number)Float.valueOf(ret.getFloat(0)));
            } else {
                nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        // Any data type other than DOUBLE/FLOAT falls through to the half-precision calls below.
        } else if (op instanceof Variance) {
            if (ret.isScalar()) {
                AtomicAllocator.getInstance().tickHostWrite(ret);
                ret.putScalar(0, nativeOps.execSummaryStatsScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, ((Variance)op).isBiasCorrected()));
                op.setFinalResult((Number)Float.valueOf(ret.getFloat(0)));
            } else {
                nativeOps.execSummaryStatsHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length, ((Variance)op).isBiasCorrected());
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        } else if (op.y() != null) {
            if (ret.isScalar()) {
                AtomicAllocator.getInstance().tickHostWrite(ret);
                ret.putScalar(0, nativeOps.execReduce3ScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context)));
                op.setFinalResult((Number)Float.valueOf(ret.getFloat(0)));
            } else {
                nativeOps.execReduce3Half(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (ShortPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        } else if (ret.isScalar()) {
            AtomicAllocator.getInstance().tickHostWrite(ret);
            ret.putScalar(0, nativeOps.execReduceScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs));
            op.setFinalResult((Number)Float.valueOf(ret.getFloat(0)));
        } else {
            nativeOps.execReduceHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
            AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        }
        this.profilingHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Runs a reduction ({@link Accumulation}) along {@code dimension} on the device.
     *
     * <p>Dimension handling:
     * <ul>
     *   <li>the requested dimensions are copied, so the caller's array is left untouched;</li>
     *   <li>negative entries are wrapped by adding {@code op.x().rank()};</li>
     *   <li>every entry is range-checked and the copy is then sorted;</li>
     *   <li>naming as many dimensions as x has axes is treated as a whole-array reduction
     *       ({@code Integer.MAX_VALUE}).</li>
     * </ul>
     * A result array of the reduced shape, filled with {@code op.zeroDouble()}, is installed as
     * {@code op.z()}. The native work is then delegated to {@code naiveExec}.
     *
     * @param op        reduction to execute; its {@code z} is replaced by a new array
     * @param dimension dimensions to reduce along; negative values count back from the last axis
     *                  and {@code Integer.MAX_VALUE} means "whole array"
     * @return {@code op.z()} after execution, or {@code op.noOp()} when x is a vector whose length
     *         already equals the reduced length
     * @throws ND4JIllegalStateException if an entry lies outside {@code [-rank, rank)} and is not
     *                                   {@code Integer.MAX_VALUE}
     */
    public INDArray exec(Accumulation op, int ... dimension) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        int rank = op.x().rank();
        // Work on a copy: sorting/wrapping must not rewrite the caller's varargs array.
        int[] requested = dimension;
        dimension = requested.clone();
        // Wrap negative axes BEFORE validating and sorting. Otherwise {0, -1} on a rank-3 array
        // ends up as the unsorted {2, 0}, and values below -rank slip past the range check.
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < 0) {
                dimension[i] += rank;
            }
        }
        for (int d : dimension) {
            if (d != Integer.MAX_VALUE && (d < 0 || d >= rank)) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(requested) + " contains element that higher then rank of op.X: [" + rank + "]");
            }
        }
        Arrays.sort(dimension);
        if (dimension.length == rank) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        // Whole-array reductions produce a 1x1 result. (The decompiled original never assigned
        // retShape on this branch, which fails Java's definite-assignment check.)
        int[] retShape = Shape.wholeArrayDimension(dimension)
                ? new int[]{1, 1}
                : ArrayUtil.removeIndex(op.x().shape(), dimension);
        // Keep the result at least 2-D: reducing along dimension 0 yields {1, n}, otherwise {n, 1}.
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
            return op.noOp();
        }
        // Identity values within +/-0.01 of zero use Nd4j.zeros; others are filled explicitly.
        double zero = op.zeroDouble();
        INDArray ret = zero > (double)-0.01f && zero < (double)0.01f
                ? Nd4j.zeros(retShape)
                : Nd4j.valueArrayOf(retShape, zero);
        op.setZ(ret);
        this.naiveExec(op, dimension);
        this.profilingHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Runs an index reduction ({@link IndexAccumulation}) along {@code dimension} on the device.
     *
     * <p>Dimension handling:
     * <ul>
     *   <li>the requested dimensions are copied, so the caller's array is left untouched;</li>
     *   <li>negative entries are wrapped by adding {@code op.x().rank()};</li>
     *   <li>every entry is range-checked and the copy is then sorted;</li>
     *   <li>naming as many dimensions as x has axes becomes the whole-array marker
     *       {@code Integer.MAX_VALUE}.</li>
     * </ul>
     * A result array of the reduced shape, filled with {@code op.zeroDouble()}, is installed as
     * {@code op.z()}. The native index-reduce call for x's data type is then invoked: DOUBLE,
     * FLOAT, or otherwise half.
     *
     * @param op        index reduction to execute; its {@code z} is replaced by a new array
     * @param dimension dimensions to reduce along; negative values count back from the last axis
     * @return {@code op.z()}, or {@code op.x()} when x is a vector whose length equals the reduced
     *         length
     * @throws ND4JIllegalStateException if an entry lies outside {@code [-rank, rank)} and is not
     *                                   {@code Integer.MAX_VALUE}
     */
    public INDArray exec(IndexAccumulation op, int ... dimension) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        int rank = op.x().rank();
        // Work on a copy: sorting/wrapping must not rewrite the caller's varargs array.
        int[] requested = dimension;
        dimension = requested.clone();
        // Wrap negative axes BEFORE validating and sorting, so the sorted order is the real one
        // and values below -rank are rejected.
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < 0) {
                dimension[i] += rank;
            }
        }
        for (int d : dimension) {
            if (d != Integer.MAX_VALUE && (d < 0 || d >= rank)) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(requested) + " contains element that higher then rank of op.X: [" + rank + "]");
            }
        }
        Arrays.sort(dimension);
        if (dimension.length == rank) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        // Whole-array reductions produce a 1x1 result. (The decompiled original never assigned
        // retShape on this branch, which fails Java's definite-assignment check.)
        int[] retShape = Shape.wholeArrayDimension(dimension)
                ? new int[]{1, 1}
                : ArrayUtil.removeIndex(op.x().shape(), dimension);
        // NOTE(review): this shortcut returns x itself rather than an index result. That looks
        // suspicious for an index reduction; verify it is intended.
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
            return op.x();
        }
        // Keep the result at least 2-D: reducing along dimension 0 yields {1, n}, otherwise {n, 1}.
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        // Identity values within +/-0.01 of zero use Nd4j.zeros; others are filled explicitly.
        double zero = op.zeroDouble();
        INDArray ret = zero > (double)-0.01f && zero < (double)0.01f
                ? Nd4j.zeros(retShape)
                : Nd4j.valueArrayOf(retShape, zero);
        op.setZ(ret);
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        // Extra-pointer slots 0-11, in the same order as the other reduce paths.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execIndexReduceDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, dimension.length);
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execIndexReduceFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, dimension.length);
        } else {
            nativeOps.execIndexReduceHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, dimension.length);
        }
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Executes {@code op} along {@code dimension} through the default executioner, after a
     * compression check.
     *
     * @param op        the op to execute
     * @param dimension target dimensions; sorted ascending in place, so the caller's array is
     *                  reordered
     * @return whatever {@code super.exec(op, dimension)} returns
     */
    public Op exec(Op op, int ... dimension) {
        checkForCompression(op);
        Arrays.sort(dimension, 0, dimension.length);
        return super.exec(op, dimension);
    }

    /**
     * Executes {@code op} with no explicit dimensions.
     *
     * <p>Complex inputs, {@code ExecutionMode.JAVA} and {@link CopyOp} take the Java fallback:
     * x and y are synchronized to the host, {@code super.exec(op)} runs, and z is then marked as
     * written on the host. All other ops are dispatched by type to the matching {@code invoke}
     * overload: transform, accumulation, scalar, broadcast, or index accumulation.
     *
     * @param op the op to execute
     * @return {@code op} on every path (the fallback path used to return {@code null}, unlike the
     *         native path)
     */
    public Op exec(Op op) {
        this.checkForCompression(op);
        if (op.x() instanceof IComplexNDArray || this.executionMode() == OpExecutioner.ExecutionMode.JAVA || op instanceof CopyOp) {
            // Only the inputs are synced to the host. z is presumably fully overwritten by
            // super.exec and is marked host-written afterwards.
            if (op.x() != null) {
                AtomicAllocator.getInstance().synchronizeHostData(op.x());
            }
            if (op.y() != null) {
                AtomicAllocator.getInstance().synchronizeHostData(op.y());
            }
            super.exec(op);
            if (op.z() != null) {
                AtomicAllocator.getInstance().tickHostWrite(op.z());
            }
            // Return the op itself, consistent with the native path below.
            return op;
        }
        if (op instanceof TransformOp) {
            TransformOp t = (TransformOp)op;
            this.invoke(t);
        } else if (op instanceof Accumulation) {
            Accumulation acc = (Accumulation)op;
            this.invoke(acc, null);
        } else if (op instanceof ScalarOp) {
            ScalarOp sc = (ScalarOp)op;
            this.invoke(sc);
        } else if (op instanceof BroadcastOp) {
            BroadcastOp broadcastOp = (BroadcastOp)op;
            this.invoke(broadcastOp);
        } else if (op instanceof IndexAccumulation) {
            IndexAccumulation indexAccumulation = (IndexAccumulation)op;
            // NOTE(review): invoke(IndexAccumulation, int[]) reads dimension.length before its own
            // null check, so this null argument looks like it would throw a NullPointerException.
            // Verify.
            this.invoke(indexAccumulation, null);
        }
        // NOTE(review): ops matching none of the branches above are returned unexecuted.
        return op;
    }

    /**
     * Runs a transform op natively and returns its output array.
     *
     * @param op the transform to execute
     * @return {@code op.z()} after the transform has been invoked
     */
    public INDArray execAndReturn(TransformOp op) {
        final Op asGenericOp = (Op)op;
        checkForCompression(asGenericOp);
        invoke(op);
        final INDArray output = op.z();
        return output;
    }

    /**
     * Runs a broadcast op natively using the dimensions stored on the op ({@code op.getDimension()}).
     *
     * <p>Unlike {@code exec(BroadcastOp, int...)}, this method neither sorts nor range-checks the
     * dimensions. {@code op.y()} must be non-null, because its shape info is dereferenced
     * unconditionally.
     *
     * @param op broadcast op; the native call writes into {@code op.z()}
     * @return always {@code null}, despite the declared {@code CudaContext} return type
     */
    protected CudaContext invoke(BroadcastOp op) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        // Lazily create this thread's reusable 32-slot extra-pointer array.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        // TAD shape info (host + device) and device offsets of x for the op's dimensions.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), op.getDimension());
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;
        // Matching device TAD buffers for z.
        Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), op.getDimension());
        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getSecond(), context);
        // Extra-pointer slots, in order: 0 host x shapeInfo, 1 context.getOldStream(), 2 device id
        // pointer, 3-6 context allocation/reduction/scalar/special buffers, 7 host y shapeInfo,
        // 8 host z shapeInfo, 9 host x TAD shapeInfo, 10 device x TAD shapeInfo, 11 device x TAD
        // offsets, 12 device z TAD shapeInfo, 13 device z TAD offsets.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ});
        Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
        Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(op.getDimension()), context);
        // Dispatch on x's data type; any type other than DOUBLE/FLOAT takes the half-precision path.
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execBroadcastDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)y, (IntPointer)yShapeInfo, (DoublePointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, op.getDimension().length);
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execBroadcastFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)y, (IntPointer)yShapeInfo, (FloatPointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, op.getDimension().length);
        } else {
            nativeOps.execBroadcastHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)y, (IntPointer)yShapeInfo, (ShortPointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, op.getDimension().length);
        }
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op)op, st);
        // NOTE(review): the context is not returned to the caller.
        return null;
    }

    /**
     * Executes an index-reduction op on the device.
     *
     * <p>A whole-array reduction is performed when {@code dimension} is {@code null}, when its first
     * element is {@link Integer#MAX_VALUE}, or when {@code op.z()} is a scalar; the resulting index is
     * stored with {@code op.setFinalResult}. Otherwise {@code dimension} is sorted in place and the
     * reduction is written into {@code op.z()} along those dimensions.
     *
     * @param op        the index accumulation to execute
     * @param dimension dimensions to reduce along; {@code null} requests a whole-array reduction
     * @return always {@code null}
     * @throws ND4JIllegalStateException if a dimension (other than {@link Integer#MAX_VALUE}) is not
     *                                   below the rank of {@code op.x()}
     */
    protected CudaContext invoke(IndexAccumulation op, int[] dimension) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // A null dimension is a legitimate whole-array request (handled below), so it must not be
        // dereferenced during validation.
        if (dimension != null) {
            for (int dim : dimension) {
                if (dim >= op.x().rank() && dim != Integer.MAX_VALUE) {
                    throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
                }
            }
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        // TAD info is still needed for the extra-pointer array; fall back to dimension 0 when none given.
        int[] fdimension = dimension;
        if (fdimension == null) {
            fdimension = new int[]{0};
        }
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), fdimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        if (op.z().isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) {
            // Whole-array reduction: the native call returns the index directly.
            if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                double result = nativeOps.execIndexReduceScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs);
                op.setFinalResult((int)result);
            } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
                float result = nativeOps.execIndexReduceScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs);
                op.setFinalResult((int)result);
            } else {
                float result = nativeOps.execIndexReduceScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs);
                op.setFinalResult((int)result);
            }
        } else {
            Arrays.sort(dimension);
            Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
            Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
            Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
            if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                nativeOps.execIndexReduceDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, dimension.length);
            } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
                nativeOps.execIndexReduceFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, dimension.length);
            } else {
                nativeOps.execIndexReduceHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, dimension.length);
            }
        }
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op)op, st);
        return null;
    }

    /**
     * Executes an accumulation (reduce / reduce3 / summary-stats) op on the device.
     *
     * <p>A {@code null} {@code dimension} is replaced by {@code {Integer.MAX_VALUE}} (whole-array
     * reduction); the array is sorted in place and validated against the rank of {@code op.x()}.
     * A fresh result array of the reduced shape, filled with {@code op.zeroDouble()}, is installed as
     * {@code op.z()}. If that result is a scalar, the value is stored with {@code op.setFinalResult};
     * otherwise the kernel writes into {@code op.z()}.
     *
     * <p>Ops with a non-null {@code op.y()} use the reduce3 kernels; {@link Variance} uses the
     * summary-stats kernels with its bias-correction flag.
     *
     * @param op        the accumulation to execute
     * @param dimension dimensions to reduce along; {@code null} for a whole-array reduction
     * @return the CUDA context used, or {@code null} when {@code op.x()} is a vector whose length
     *         equals the size of the reduced shape (nothing is executed in that case)
     * @throws ND4JIllegalStateException if a dimension (other than {@link Integer#MAX_VALUE}) is not
     *                                   below the rank of {@code op.x()}
     */
    protected CudaContext invoke(Accumulation op, int[] dimension) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (dimension == null) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        Arrays.sort(dimension);
        for (int dim : dimension) {
            if (dim >= op.x().rank() && dim != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
            }
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        if (op.y() != null) {
            // Pairwise reductions also need y's TAD info, stored in extra-pointer slots 12 and 13.
            Pair yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);
            Pointer yDevTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)yTadBuffers.getFirst(), context);
            DataBuffer yOffsets = (DataBuffer)yTadBuffers.getSecond();
            Pointer yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context);
            xShapeInfoHostPointer.put(12L, yDevTadShapeInfo);
            xShapeInfoHostPointer.put(13L, yDevTadOffsets);
        }
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
        // Shape of the reduction result. A whole-array reduction yields 1x1. (The decompiled form only
        // assigned retShape in the non-whole-array branch, which is not definitely assigned.)
        int[] retShape = Shape.wholeArrayDimension((int[])dimension) ? new int[]{1, 1} : ArrayUtil.removeIndex((int[])op.x().shape(), (int[])dimension);
        if (retShape.length == 1) {
            // Promote a 1-d result to a row vector when axis 0 was reduced, otherwise a column vector.
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod((int[])retShape)) {
            return null;
        }
        // Near-zero initial values get a plain zeros() allocation; anything else is filled explicitly.
        INDArray ret = op.zeroDouble() > (double)-0.01f && op.zeroDouble() < (double)0.01f ? Nd4j.zeros((int[])retShape) : Nd4j.valueArrayOf((int[])retShape, (double)op.zeroDouble());
        op.setZ(ret);
        if (op.z().isScalar()) {
            if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                if (op instanceof Variance) {
                    double result = nativeOps.execSummaryStatsScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, ((Variance)op).isBiasCorrected());
                    op.setFinalResult((Number)result);
                } else if (op.y() != null) {
                    Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
                    Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
                    double result = nativeOps.execReduce3ScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)y, (IntPointer)yShapeInfo);
                    op.setFinalResult((Number)result);
                } else {
                    double result = nativeOps.execReduceScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs);
                    op.setFinalResult((Number)result);
                }
            } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
                if (op instanceof Variance) {
                    float result = nativeOps.execSummaryStatsScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, ((Variance)op).isBiasCorrected());
                    op.setFinalResult((Number)Float.valueOf(result));
                } else if (op.y() != null) {
                    Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
                    Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
                    float result = nativeOps.execReduce3ScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)y, (IntPointer)yShapeInfo);
                    op.setFinalResult((Number)Float.valueOf(result));
                } else {
                    float result = nativeOps.execReduceScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs);
                    op.setFinalResult((Number)Float.valueOf(result));
                }
            } else if (op instanceof Variance) {
                float result = nativeOps.execSummaryStatsScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, ((Variance)op).isBiasCorrected());
                op.setFinalResult((Number)Float.valueOf(result));
            } else if (op.y() != null) {
                Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
                Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
                float result = nativeOps.execReduce3ScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)y, (IntPointer)yShapeInfo);
                op.setFinalResult((Number)Float.valueOf(result));
            } else {
                float result = nativeOps.execReduceScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs);
                op.setFinalResult((Number)Float.valueOf(result));
            }
        } else {
            Pointer result = AtomicAllocator.getInstance().getPointer(op.z(), context);
            Pointer resultShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
            Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
            if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                if (op.y() != null) {
                    Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
                    Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
                    nativeOps.execReduce3Double(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)y, (IntPointer)yShapeInfo, (DoublePointer)result, (IntPointer)resultShapeInfo, (IntPointer)dimensionPointer, dimension.length);
                } else if (op instanceof Variance) {
                    nativeOps.execSummaryStatsDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)result, (IntPointer)resultShapeInfo, (IntPointer)dimensionPointer, dimension.length, ((Variance)op).isBiasCorrected());
                } else {
                    nativeOps.execReduceDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)result, (IntPointer)resultShapeInfo, (IntPointer)dimensionPointer, dimension.length);
                }
            } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
                if (op.y() != null) {
                    Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
                    Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
                    nativeOps.execReduce3Float(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)y, (IntPointer)yShapeInfo, (FloatPointer)result, (IntPointer)resultShapeInfo, (IntPointer)dimensionPointer, dimension.length);
                } else if (op instanceof Variance) {
                    nativeOps.execSummaryStatsFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)result, (IntPointer)resultShapeInfo, (IntPointer)dimensionPointer, dimension.length, ((Variance)op).isBiasCorrected());
                } else {
                    nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)result, (IntPointer)resultShapeInfo, (IntPointer)dimensionPointer, dimension.length);
                }
            } else if (op.y() != null) {
                Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
                Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
                nativeOps.execReduce3Half(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)y, (IntPointer)yShapeInfo, (ShortPointer)result, (IntPointer)resultShapeInfo, (IntPointer)dimensionPointer, dimension.length);
            } else if (op instanceof Variance) {
                nativeOps.execSummaryStatsHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)result, (IntPointer)resultShapeInfo, (IntPointer)dimensionPointer, dimension.length, ((Variance)op).isBiasCorrected());
            } else {
                nativeOps.execReduceHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)result, (IntPointer)resultShapeInfo, (IntPointer)dimensionPointer, dimension.length);
            }
        }
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op)op, st);
        return context;
    }

    /**
     * Executes a scalar op along the given dimensions.
     *
     * <p>TAD shape info and offsets for both {@code op.x()} and {@code op.z()} are computed for
     * {@code dimension} and passed through the extra-pointer array. {@code op.y()} is handed to the
     * native scalar kernel in the scalar-operand position. {@code dimension} is sorted in place.
     * Only HALF, FLOAT and DOUBLE inputs dispatch a kernel; other data types are not executed.
     *
     * @param op        the scalar op to execute
     * @param dimension dimensions defining the TADs; must be non-null
     * @return always {@code null}
     */
    protected CudaContext intercept(ScalarOp op, int[] dimension) {
        long st = this.profilingHookIn((Op)op);
        // Same lazy init as the other entry points, so a direct call does not NPE on extraz.get().
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        Arrays.sort(dimension);
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        // TAD offsets may be absent; treat them like every other invoke path does instead of
        // handing a null buffer to getPointer.
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
        Pointer devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffersZ.getFirst(), context);
        DataBuffer offsetsZ = (DataBuffer)tadBuffersZ.getSecond();
        Pointer devTadOffsetsZ = offsetsZ == null ? null : AtomicAllocator.getInstance().getPointer(offsetsZ, context);
        PointerPointer extraPointers = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ});
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        if (op.x().data().dataType() == DataBuffer.Type.HALF) {
            nativeOps.execScalarHalf(extraPointers, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)z, (IntPointer)zShapeInfo, (ShortPointer)y, (ShortPointer)extraArgs, (IntPointer)dimensionPointer, dimension.length);
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execScalarFloat(extraPointers, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)z, (IntPointer)zShapeInfo, (FloatPointer)y, (FloatPointer)extraArgs, (IntPointer)dimensionPointer, dimension.length);
        } else if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execScalarDouble(extraPointers, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)z, (IntPointer)zShapeInfo, (DoublePointer)y, (DoublePointer)extraArgs, (IntPointer)dimensionPointer, dimension.length);
        }
        AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op)op, st);
        return null;
    }

    /**
     * Executes a whole-array scalar op on the device.
     *
     * <p>Ops that carry dimensions are delegated to {@code intercept(ScalarOp, int[])}. Otherwise
     * the element-wise-stride kernel is chosen when {@code op.x()} has a positive element-wise stride
     * and {@code op.x()} and {@code op.z()} share an ordering; the shape-info kernel is used in all
     * other cases. Data types other than DOUBLE and FLOAT are dispatched to the HALF kernels.
     *
     * @param op the scalar op to execute
     * @return always {@code null}
     */
    protected CudaContext invoke(ScalarOp op) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (op.getDimension() != null) {
            this.intercept(op, op.getDimension());
            return null;
        }
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext ctx = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer yHostShape = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer zHostShape = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer xBuf = allocator.getPointer(op.x(), ctx);
        Pointer xShape = allocator.getPointer(op.x().shapeInfoDataBuffer(), ctx);
        Pointer extra = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(), ctx) : null;
        Pointer zBuf = allocator.getPointer(op.z(), ctx);
        Pointer zShape = allocator.getPointer(op.z().shapeInfoDataBuffer(), ctx);
        PointerPointer extraPointers = this.extraz.get().put(new Pointer[]{
                AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()),
                ctx.getOldStream(),
                allocator.getDeviceIdPointer(),
                ctx.getBufferAllocation(),
                ctx.getBufferReduction(),
                ctx.getBufferScalar(),
                ctx.getBufferSpecial(),
                yHostShape,
                zHostShape,
                null,
                null});

        DataBuffer.Type type = op.x().data().dataType();
        // The strided kernel is only valid for a positive EWS with matching x/z ordering.
        boolean strided = op.x().elementWiseStride() >= 1 && op.z().ordering() == op.x().ordering();

        if (type == DataBuffer.Type.DOUBLE) {
            double value = op.scalar().doubleValue();
            if (strided) {
                nativeOps.execScalarDouble(extraPointers, op.opNum(), (DoublePointer)xBuf, op.x().elementWiseStride(), (DoublePointer)zBuf, op.z().elementWiseStride(), value, (DoublePointer)extra, op.n());
            } else {
                nativeOps.execScalarDouble(extraPointers, op.opNum(), (DoublePointer)xBuf, (IntPointer)xShape, (DoublePointer)zBuf, (IntPointer)zShape, value, (DoublePointer)extra);
            }
        } else if (type == DataBuffer.Type.FLOAT) {
            float value = op.scalar().floatValue();
            if (strided) {
                nativeOps.execScalarFloat(extraPointers, op.opNum(), (FloatPointer)xBuf, op.x().elementWiseStride(), (FloatPointer)zBuf, op.z().elementWiseStride(), value, (FloatPointer)extra, op.n());
            } else {
                nativeOps.execScalarFloat(extraPointers, op.opNum(), (FloatPointer)xBuf, (IntPointer)xShape, (FloatPointer)zBuf, (IntPointer)zShape, value, (FloatPointer)extra);
            }
        } else {
            // Half-precision kernels take the scalar as a float.
            float value = op.scalar().floatValue();
            if (strided) {
                nativeOps.execScalarHalf(extraPointers, op.opNum(), (ShortPointer)xBuf, op.x().elementWiseStride(), (ShortPointer)zBuf, op.z().elementWiseStride(), value, (ShortPointer)extra, op.n());
            } else {
                nativeOps.execScalarHalf(extraPointers, op.opNum(), (ShortPointer)xBuf, (IntPointer)xShape, (ShortPointer)zBuf, (IntPointer)zShape, value, (ShortPointer)extra);
            }
        }

        allocator.registerAction(ctx, op.z(), op.x(), op.y());
        this.profilingHookOut((Op)op, st);
        return null;
    }

    /**
     * Executes a transform op on the device: unary when {@code op.y()} is {@code null}, pairwise
     * otherwise.
     *
     * <p>Op numbers 38-41 also receive TAD shape info and offsets via the extra-pointer array. Op
     * numbers 38-40 use TADs of {@code op.x()} along dimensions 0 and 1. For op number 41 with
     * non-null extra args, {@code extraArgs[1]} is read as the target dimension (negative values are
     * wrapped by the rank of {@code op.x()}). A zero-filled result array of the reduced shape is
     * allocated, and TADs of {@code op.z()} along that dimension are used.
     *
     * <p>The element-wise-stride kernel is used when the operands' strides and orderings permit it
     * and the op is not "exec special"; otherwise the shape-info kernel runs.
     *
     * @param op the transform op to execute
     * @return always {@code null}
     */
    protected CudaContext invoke(TransformOp op) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        INDArray ret = null;
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer dimensionDevPointer = null;
        Pointer dimensionHostPointer = null;
        Pointer retPointer = null;
        int[] dimension = null;
        if (op.opNum() == 41 && op.extraArgs() != null) {
            // Op 41 takes its target dimension from extraArgs[1]; negative values count from the end.
            dimension = new int[]{(Integer)op.extraArgs()[1]};
            for (int i = 0; i < dimension.length; ++i) {
                if (dimension[i] < 0) {
                    dimension[i] += op.x().rank();
                }
            }
            if (dimension.length == op.x().rank()) {
                dimension = new int[]{Integer.MAX_VALUE};
            }
            // Reduced result shape; a whole-array reduction yields 1x1. (The decompiled form never
            // assigned retShape in the whole-array branch, so it was not definitely assigned.)
            int[] retShape = Shape.wholeArrayDimension((int[])dimension) ? new int[]{1, 1} : ArrayUtil.removeIndex((int[])op.x().shape(), (int[])dimension);
            if (retShape.length == 1) {
                // Promote a 1-d result to a row vector when axis 0 was reduced, otherwise a column vector.
                retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
            } else if (retShape.length == 0) {
                retShape = new int[]{1, 1};
            }
            ret = Nd4j.zeros((int[])retShape);
            // NOTE(review): the slot named hostYShapeInfo receives a device pointer to ret's shape info
            // here; presumably that is what the native op expects for op 41 — confirm.
            hostYShapeInfo = AtomicAllocator.getInstance().getPointer(ret.shapeInfoDataBuffer(), context);
            DataBuffer dimensionBuffer = AtomicAllocator.getInstance().getConstantBuffer(dimension);
            dimensionDevPointer = AtomicAllocator.getInstance().getPointer(dimensionBuffer, context);
            dimensionHostPointer = AtomicAllocator.getInstance().getHostPointer(dimensionBuffer);
            retPointer = AtomicAllocator.getInstance().getPointer(ret, context);
        }
        Pointer hostTadShapeInfo = null;
        Pointer devTadShapeInfo = null;
        Pointer hostMaxTadShapeInfo = null;
        Pointer devMaxTadShapeInfo = null;
        Pointer devTadOffsets = null;
        Pointer devMaxTadOffsets = null;
        if (op.opNum() >= 38 && op.opNum() <= 41) {
            DataBuffer offsets;
            if (op.opNum() != 41) {
                // Ops 38-40: TADs of x along dimension 0, plus a second ("max") set along dimension 1.
                Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), new int[]{0});
                Pair tadMaxBuffers = tadManager.getTADOnlyShapeInfo(op.x(), new int[]{1});
                hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
                devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
                hostMaxTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadMaxBuffers.getFirst());
                devMaxTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadMaxBuffers.getFirst(), context);
                offsets = (DataBuffer)tadBuffers.getSecond();
                devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
                DataBuffer maxOffsets = (DataBuffer)tadMaxBuffers.getSecond();
                devMaxTadOffsets = maxOffsets == null ? null : AtomicAllocator.getInstance().getPointer(maxOffsets, context);
            } else {
                Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
                hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
                devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
                offsets = (DataBuffer)tadBuffers.getSecond();
                devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
            }
        }
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, hostMaxTadShapeInfo, devMaxTadShapeInfo, devMaxTadOffsets, dimensionDevPointer, dimensionHostPointer, retPointer});
        if (op.y() != null) {
            Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
            Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
            if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                // NOTE(review): unlike the FLOAT/HALF paths, this condition does not require equal x/y
                // element-wise strides — confirm that the strided double kernel handles differing strides.
                if (op.x().elementWiseStride() >= 1 && op.y().elementWiseStride() >= 1 && !op.isExecSpecial() && op.x().ordering() == op.y().ordering() && op.x().ordering() == op.z().ordering()) {
                    nativeOps.execPairwiseTransformDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, op.x().elementWiseStride(), (DoublePointer)y, op.y().elementWiseStride(), (DoublePointer)z, op.z().elementWiseStride(), (DoublePointer)extraArgs, op.n());
                } else {
                    nativeOps.execPairwiseTransformDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)y, (IntPointer)yShapeInfo, (DoublePointer)z, (IntPointer)zShapeInfo, (DoublePointer)extraArgs);
                }
            } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
                if (op.x().elementWiseStride() >= 1 && op.y().elementWiseStride() >= 1 && op.x().elementWiseStride() == op.y().elementWiseStride() && !op.isExecSpecial() && op.x().ordering() == op.y().ordering() && op.x().ordering() == op.z().ordering()) {
                    nativeOps.execPairwiseTransformFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, op.x().elementWiseStride(), (FloatPointer)y, op.y().elementWiseStride(), (FloatPointer)z, op.z().elementWiseStride(), (FloatPointer)extraArgs, op.n());
                } else {
                    nativeOps.execPairwiseTransformFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)y, (IntPointer)yShapeInfo, (FloatPointer)z, (IntPointer)zShapeInfo, (FloatPointer)extraArgs);
                }
            } else if (op.x().elementWiseStride() >= 1 && op.y().elementWiseStride() >= 1 && op.x().elementWiseStride() == op.y().elementWiseStride() && !op.isExecSpecial() && op.x().ordering() == op.y().ordering() && op.x().ordering() == op.z().ordering()) {
                nativeOps.execPairwiseTransformHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, op.x().elementWiseStride(), (ShortPointer)y, op.y().elementWiseStride(), (ShortPointer)z, op.z().elementWiseStride(), (ShortPointer)extraArgs, op.n());
            } else {
                nativeOps.execPairwiseTransformHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)y, (IntPointer)yShapeInfo, (ShortPointer)z, (IntPointer)zShapeInfo, (ShortPointer)extraArgs);
            }
        } else if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && op.z().ordering() == op.x().ordering()) {
                nativeOps.execTransformDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, op.x().elementWiseStride(), (DoublePointer)z, op.z().elementWiseStride(), (DoublePointer)extraArgs, op.n());
            } else {
                nativeOps.execTransformDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)z, (IntPointer)zShapeInfo, (DoublePointer)extraArgs);
            }
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            if (op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && op.z().ordering() == op.x().ordering()) {
                nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, op.x().elementWiseStride(), (FloatPointer)z, op.z().elementWiseStride(), (FloatPointer)extraArgs, op.n());
            } else {
                nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)z, (IntPointer)zShapeInfo, (FloatPointer)extraArgs);
            }
        } else if (op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && op.z().ordering() == op.x().ordering()) {
            nativeOps.execTransformHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, op.x().elementWiseStride(), (ShortPointer)z, op.z().elementWiseStride(), (ShortPointer)extraArgs, op.n());
        } else {
            nativeOps.execTransformHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)z, (IntPointer)zShapeInfo, (ShortPointer)extraArgs);
        }
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        // NOTE(review): these calls have no visible effect on results; they appear to exist to keep
        // extraArgs and ret reachable until after the native call — confirm before removing.
        if (extraArgs != null) {
            extraArgs.address();
        }
        if (ret != null) {
            ret.elementWiseStride();
        }
        this.profilingHookOut((Op)op, st);
        return null;
    }

    /**
     * Allocates the int "params surface" buffer for a batch and attaches it to that batch.
     *
     * <p>The buffer length is the sample aggregate's required batch memory size multiplied by 4,
     * and it is created with the factory's boolean flag set to {@code false} (presumably "do not
     * initialise"; confirm against the data buffer factory).
     *
     * @param batch batch whose sample aggregate determines the size; receives the new buffer
     * @return the newly created surface buffer, also registered via {@code setParamsSurface}
     */
    protected <T extends Aggregate> DataBuffer getBuffer(Batch<T> batch) {
        long surfaceLength = batch.getSample().getRequiredBatchMemorySize() * 4L;
        DataBuffer surface = Nd4j.getDataBufferFactory().createInt(surfaceLength, false);
        batch.setParamsSurface(surface);
        return surface;
    }

    /**
     * Executes a batch of aggregates with a single native call.
     *
     * <p>All per-aggregate parameters are packed into one surface buffer (see {@code getBuffer}),
     * written through its host pointer, and the device pointer of that buffer is handed to
     * {@code execAggregateBatch{Float,Double,Half}} according to {@link Nd4j#dataType()}.
     * If the data type is none of FLOAT/DOUBLE/HALF, no native call is made.
     *
     * <p>Surface layout, in order (each section sized with {@code Batch.getBatchLimit() * 16}
     * slots per entry; the factor 16 is not explained here, so treat it as a reserved slot count):
     * <ol>
     *   <li>header: {@code maxTypes} (5) ints per aggregate holding the counts of arguments,
     *       shapes, indexing args, real args and int arrays;</li>
     *   <li>indexing arguments (ints);</li>
     *   <li>int-array arguments (ints, {@code maxIntArrays * maxArraySize} per aggregate);</li>
     *   <li>real arguments (float/double/half elements);</li>
     *   <li>argument pointers, then shape-buffer pointers.</li>
     * </ol>
     *
     * @param batch the batch to execute; its sample aggregate supplies the section sizes
     */
    public <T extends Aggregate> void exec(Batch<T> batch) {
        DataBuffer surfaceBuffer = this.getBuffer(batch);
        CudaContext context = (CudaContext)AtomicAllocator.getInstance().getDeviceContext().getContext();
        // Host-side view of the surface, used to write every section below.
        IntPointer pointer = new CudaPointer(AtomicAllocator.getInstance().getHostPointer(surfaceBuffer)).asIntPointer();
        AllocationPoint surfacePoint = AtomicAllocator.getInstance().getAllocationPoint(surfaceBuffer);
        // Number of header ints per aggregate (one count per parameter kind).
        int maxTypes = 5;
        int maxIntArrays = batch.getSample().maxIntArrays();
        int maxArraySize = batch.getSample().maxIntArraySize();
        // Section offsets. indexPos / intArraysPos are in int elements.
        int indexPos = maxTypes * (Batch.getBatchLimit() * 16);
        int intArraysPos = indexPos + batch.getSample().maxIndexArguments() * (Batch.getBatchLimit() * 16);
        // realPos is rescaled into units of the real element type: presumably /2 because a
        // double spans two ints, and *2 (below) because a half spans half an int.
        int realPos = (intArraysPos + maxIntArrays * maxArraySize * (Batch.getBatchLimit() * 16)) / (Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 2 : 1);
        if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            realPos *= 2;
        }
        // argsPos is rescaled into pointer-sized units: presumably /2 for float, /1 for double and
        // /4 for half, assuming 64-bit pointers. NOTE(review): confirm the pointer-size assumption.
        int argsPos = (realPos + batch.getSample().maxRealArguments() * (Batch.getBatchLimit() * 16)) / (Nd4j.dataType() == DataBuffer.Type.FLOAT ? 2 : 1);
        if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            argsPos /= 4;
        }
        int shapesPos = argsPos + batch.getSample().maxArguments() * (Batch.getBatchLimit() * 16);
        for (int i = 0; i < batch.getNumAggregates(); ++i) {
            int e;
            Aggregate op = (Aggregate)batch.getAggregates().get(i);
            // Header: parameter counts for aggregate i.
            int idx = i * maxTypes;
            pointer.put((long)idx, op.getArguments().size());
            pointer.put((long)(idx + 1), op.getShapes().size());
            pointer.put((long)(idx + 2), op.getIndexingArguments().size());
            pointer.put((long)(idx + 3), op.getRealArguments().size());
            pointer.put((long)(idx + 4), op.getIntArrayArguments().size());
            // Indexing arguments, strided by the sample's maxIndexArguments().
            for (int e2 = 0; e2 < op.getIndexingArguments().size(); ++e2) {
                idx = indexPos + i * batch.getSample().maxIndexArguments();
                pointer.put((long)(idx + e2), ((Integer)op.getIndexingArguments().get(e2)).intValue());
            }
            // Int-array arguments: each aggregate owns bsize ints, each array maxArraySize of them.
            // Null arrays are skipped, leaving their slots untouched.
            int bsize = maxIntArrays * maxArraySize;
            for (int e3 = 0; e3 < op.getIntArrayArguments().size(); ++e3) {
                int step = i * bsize + e3 * maxArraySize;
                if (op.getIntArrayArguments().get(e3) == null) continue;
                for (int x = 0; x < ((int[])op.getIntArrayArguments().get(e3)).length; ++x) {
                    idx = intArraysPos + step + x;
                    pointer.put((long)idx, ((int[])op.getIntArrayArguments().get(e3))[x]);
                }
            }
            // Real arguments, written through a pointer of the current real element type.
            // NOTE(review): the stride here is op.maxRealArguments(), while the section size and the
            // other strides use batch.getSample(); this only matches if every op agrees with the sample.
            if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
                FloatPointer realPtr = new FloatPointer((Pointer)pointer);
                for (e = 0; e < op.getRealArguments().size(); ++e) {
                    idx = realPos + i * op.maxRealArguments();
                    realPtr.put((long)(idx + e), ((Number)op.getRealArguments().get(e)).floatValue());
                }
            } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
                DoublePointer dPtr = new DoublePointer((Pointer)pointer);
                for (e = 0; e < op.getRealArguments().size(); ++e) {
                    idx = realPos + i * op.maxRealArguments();
                    dPtr.put((long)(idx + e), ((Number)op.getRealArguments().get(e)).doubleValue());
                }
            } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
                ShortPointer sPtr = new ShortPointer((Pointer)pointer);
                for (e = 0; e < op.getRealArguments().size(); ++e) {
                    idx = realPos + i * op.maxRealArguments();
                    sPtr.put((long)(idx + e), BaseDataBuffer.fromFloat((float)((Number)op.getRealArguments().get(e)).floatValue()));
                }
            }
            // Device pointers of array arguments and shape buffers. Null entries are skipped, and
            // each referenced buffer is marked as device-written.
            PointerPointer ptrPtr = new PointerPointer((Pointer)pointer);
            for (e = 0; e < op.getArguments().size(); ++e) {
                idx = argsPos + i * batch.getSample().maxArguments();
                if (op.getArguments().get(e) == null) continue;
                ptrPtr.put((long)(idx + e), AtomicAllocator.getInstance().getPointer((INDArray)op.getArguments().get(e), context));
                AtomicAllocator.getInstance().getAllocationPoint((INDArray)op.getArguments().get(e)).tickDeviceWrite();
            }
            for (e = 0; e < op.getShapes().size(); ++e) {
                idx = shapesPos + i * batch.getSample().maxShapes();
                if (op.getShapes().get(e) == null) continue;
                ptrPtr.put((long)(idx + e), AtomicAllocator.getInstance().getPointer((DataBuffer)op.getShapes().get(e), context));
                AtomicAllocator.getInstance().getAllocationPoint((DataBuffer)op.getShapes().get(e)).tickDeviceWrite();
            }
        }
        // Mark the surface as host-modified. Presumably this makes the getPointer(surfaceBuffer, ...)
        // calls below transfer it to the device; confirm against AllocationPoint semantics.
        surfacePoint.tickHostWrite();
        // Launch parameters: [0] null, [1] stream, [2] grid size capped by the configured maximum,
        // [3] threads per instance, [4] shared memory size.
        PointerPointer extraArgs = new PointerPointer(32L);
        extraArgs.put(0L, null);
        extraArgs.put(1L, (Pointer)context.getOldStream());
        extraArgs.put(2L, (Pointer)new CudaPointer(Math.min(batch.getNumAggregates(), CudaEnvironment.getInstance().getConfiguration().getMaximumGridSize())));
        extraArgs.put(3L, (Pointer)new CudaPointer(batch.getSample().getThreadsPerInstance()));
        extraArgs.put(4L, (Pointer)new CudaPointer(batch.getSample().getSharedMemorySize()));
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execAggregateBatchFloat(extraArgs, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(surfaceBuffer, context));
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execAggregateBatchDouble(extraArgs, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(surfaceBuffer, context));
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            nativeOps.execAggregateBatchHalf(extraArgs, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(surfaceBuffer, context));
        }
        // NOTE(review): a second host-write tick after the native call; intent is unclear from here
        // (the kernel is asynchronous on the old stream). Confirm whether this is required.
        surfacePoint.tickHostWrite();
    }

    /**
     * Executes a list of aggregates.
     *
     * <p>The list is partitioned into {@link Batch}es of at most 8192 aggregates, each batch is
     * executed via {@code exec(Batch)}, and the device context's old stream is then synchronized
     * once, so all batches have completed when this method returns.
     *
     * @param batch the aggregates to execute; an empty list is a no-op (no synchronization)
     */
    public void exec(List<Aggregate> batch) {
        if (batch.isEmpty()) {
            return;
        }
        // Typed result: iterating the previously raw List as Batch relied on an unchecked raw type.
        List<Batch<Aggregate>> batches = Batch.getBatches(batch, 8192);
        for (Batch<Aggregate> single : batches) {
            this.exec(single);
        }
        CudaContext context = (CudaContext)AtomicAllocator.getInstance().getDeviceContext().getContext();
        context.syncOldStream();
    }

    /**
     * Executes a single {@link Aggregate} op on the device.
     *
     * <p>Device addresses of the op's array arguments, shape buffers and int-array arguments are
     * gathered into three pointer tables (via {@link AllocationUtils#getPointersBuffer}). Null
     * entries become address {@code 0}. Indexing and real arguments are copied into freshly created
     * buffers. The native kernel is selected by {@link Nd4j#dataType()}; for any type other than
     * FLOAT, DOUBLE or HALF nothing is executed.
     *
     * <p>NOTE(review): the temporary buffers created here (pointer tables, int arrays, indexes,
     * reals) are referenced only by locals, and no stream synchronization happens before return.
     * Confirm the allocator keeps them alive until the asynchronous kernel finishes.
     *
     * @param op the aggregate to execute
     */
    public void exec(Aggregate op) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        int numArguments = op.getArguments().size();
        int numShapeArguments = op.getShapes().size();
        int numIndexArguments = op.getIndexingArguments().size();
        int numIntArrays = op.getIntArrayArguments().size();
        int numRealArguments = op.getRealArguments().size();
        CudaContext context = (CudaContext)allocator.getDeviceContext().getContext();

        // Launch parameters: [0] null, [1] stream, [2] grid size of 1, [3] threads, [4] shared memory.
        PointerPointer extraArgs = new PointerPointer(32L);
        extraArgs.put(0L, null);
        extraArgs.put(1L, (Pointer)context.getOldStream());
        extraArgs.put(2L, (Pointer)new CudaPointer(1L));
        extraArgs.put(3L, (Pointer)new CudaPointer(op.getThreadsPerInstance()));
        extraArgs.put(4L, (Pointer)new CudaPointer(op.getSharedMemorySize()));

        // Device addresses of array arguments; null entries keep the default 0L.
        long[] arguments = new long[numArguments];
        for (int x = 0; x < numArguments; ++x) {
            INDArray argument = (INDArray)op.getArguments().get(x);
            if (argument == null) {
                continue;
            }
            arguments[x] = allocator.getPointer(argument, context).address();
            allocator.getAllocationPoint(argument).tickDeviceWrite();
        }
        DataBuffer tempX = AllocationUtils.getPointersBuffer(arguments);
        PointerPointer xPtr = new PointerPointer(allocator.getPointer(tempX, context));

        // Device addresses of shape buffers; null entries keep the default 0L.
        long[] shapes = new long[numShapeArguments];
        for (int x = 0; x < numShapeArguments; ++x) {
            DataBuffer shape = (DataBuffer)op.getShapes().get(x);
            if (shape == null) {
                continue;
            }
            shapes[x] = allocator.getPointer(shape, context).address();
            allocator.getAllocationPoint(shape).tickDeviceWrite();
        }
        DataBuffer tempS = AllocationUtils.getPointersBuffer(shapes);
        PointerPointer sPtr = new PointerPointer(allocator.getPointer(tempS, context));

        // Each non-null int[] argument is copied into its own int buffer; null entries stay 0L.
        long[] ints = new long[numIntArrays];
        for (int x = 0; x < numIntArrays; ++x) {
            if (op.getIntArrayArguments().get(x) == null) {
                continue;
            }
            DataBuffer intBuf = Nd4j.getDataBufferFactory().createInt((int[])op.getIntArrayArguments().get(x));
            ints[x] = allocator.getPointer(intBuf, context).address();
        }
        DataBuffer tempI = AllocationUtils.getPointersBuffer(ints);
        PointerPointer iPtr = new PointerPointer(allocator.getPointer(tempI, context));

        int[] indexes = new int[numIndexArguments];
        for (int x = 0; x < numIndexArguments; ++x) {
            indexes[x] = (Integer)op.getIndexingArguments().get(x);
        }
        DataBuffer intBuffer = Nd4j.getDataBufferFactory().createInt(indexes);

        // Real arguments are collected as doubles; Nd4j.create builds the array in the default dtype.
        double[] reals = new double[numRealArguments];
        for (int x = 0; x < numRealArguments; ++x) {
            reals[x] = ((Number)op.getRealArguments().get(x)).doubleValue();
        }
        INDArray realsBuffer = Nd4j.create((double[])reals);

        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execAggregateFloat(extraArgs, op.opNum(), xPtr, numArguments, sPtr, numShapeArguments, (IntPointer)allocator.getPointer(intBuffer, context), numIndexArguments, iPtr, numIntArrays, (FloatPointer)allocator.getPointer(realsBuffer.data(), context), numRealArguments);
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execAggregateDouble(extraArgs, op.opNum(), xPtr, numArguments, sPtr, numShapeArguments, (IntPointer)allocator.getPointer(intBuffer, context), numIndexArguments, iPtr, numIntArrays, (DoublePointer)allocator.getPointer(realsBuffer.data(), context), numRealArguments);
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            nativeOps.execAggregateHalf(extraArgs, op.opNum(), xPtr, numArguments, sPtr, numShapeArguments, (IntPointer)allocator.getPointer(intBuffer, context), numIndexArguments, iPtr, numIntArrays, (ShortPointer)allocator.getPointer(realsBuffer.data(), context), numRealArguments);
        }
    }

    /**
     * Runs a random op using the global generator from {@link Nd4j#getRandom()}.
     *
     * @param op the random op to execute
     * @return whatever {@code exec(RandomOp, Random)} returns for that generator
     */
    public INDArray exec(RandomOp op) {
        Random defaultRng = Nd4j.getRandom();
        return this.exec(op, defaultRng);
    }

    /**
     * Runs a native random op, using {@code rng}'s native state pointer as the generator state.
     *
     * <p>The native overload is chosen by which of the op's arrays are present: x, y and z; x and z
     * only; otherwise z alone. The element type comes from {@link Nd4j#dataType()}; for types other
     * than FLOAT, DOUBLE or HALF no native call is made.
     *
     * @param op  the random op to execute
     * @param rng the generator; it must expose a state buffer
     * @return the op's z array
     * @throws IllegalStateException if {@code rng} has no state buffer
     */
    public INDArray exec(RandomOp op, Random rng) {
        long st = this.profilingHookIn((Op)op);
        if (rng.getStateBuffer() == null) {
            throw new IllegalStateException("You should use one of NativeRandom classes for NativeOperations execution");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        INDArray xArr = op.x();
        INDArray yArr = op.y();
        INDArray zArr = op.z();

        CudaContext context = allocator.getFlowController().prepareAction(zArr, xArr, yArr);
        // Extra pointers: z's host shape info, the old stream, and the device id pointer.
        PointerPointer extraZZ = this.extraz.get().put(new Pointer[]{
                AddressRetriever.retrieveHostPointer(zArr.shapeInfoDataBuffer()),
                context.getOldStream(),
                allocator.getDeviceIdPointer()});

        DataBuffer.Type dtype = Nd4j.dataType();
        boolean hasX = xArr != null;
        boolean hasZ = zArr != null;

        if (hasX && yArr != null && hasZ) {
            // x, y and z all present.
            if (dtype == DataBuffer.Type.FLOAT) {
                nativeOps.execRandomFloat(extraZZ, op.opNum(), rng.getStatePointer(),
                        (FloatPointer)allocator.getPointer(xArr, context),
                        (IntPointer)allocator.getPointer(xArr.shapeInfoDataBuffer(), context),
                        (FloatPointer)allocator.getPointer(yArr, context),
                        (IntPointer)allocator.getPointer(yArr.shapeInfoDataBuffer(), context),
                        (FloatPointer)allocator.getPointer(zArr, context),
                        (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context),
                        (FloatPointer)allocator.getPointer(op.extraArgsDataBuff(), context));
            } else if (dtype == DataBuffer.Type.DOUBLE) {
                nativeOps.execRandomDouble(extraZZ, op.opNum(), rng.getStatePointer(),
                        (DoublePointer)allocator.getPointer(xArr, context),
                        (IntPointer)allocator.getPointer(xArr.shapeInfoDataBuffer(), context),
                        (DoublePointer)allocator.getPointer(yArr, context),
                        (IntPointer)allocator.getPointer(yArr.shapeInfoDataBuffer(), context),
                        (DoublePointer)allocator.getPointer(zArr, context),
                        (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context),
                        (DoublePointer)allocator.getPointer(op.extraArgsDataBuff(), context));
            } else if (dtype == DataBuffer.Type.HALF) {
                nativeOps.execRandomHalf(extraZZ, op.opNum(), rng.getStatePointer(),
                        (ShortPointer)allocator.getPointer(xArr, context),
                        (IntPointer)allocator.getPointer(xArr.shapeInfoDataBuffer(), context),
                        (ShortPointer)allocator.getPointer(yArr, context),
                        (IntPointer)allocator.getPointer(yArr.shapeInfoDataBuffer(), context),
                        (ShortPointer)allocator.getPointer(zArr, context),
                        (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context),
                        (ShortPointer)allocator.getPointer(op.extraArgsDataBuff(), context));
            }
        } else if (hasX && hasZ) {
            // x and z only.
            if (dtype == DataBuffer.Type.FLOAT) {
                nativeOps.execRandomFloat(extraZZ, op.opNum(), rng.getStatePointer(),
                        (FloatPointer)allocator.getPointer(xArr, context),
                        (IntPointer)allocator.getPointer(xArr.shapeInfoDataBuffer(), context),
                        (FloatPointer)allocator.getPointer(zArr, context),
                        (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context),
                        (FloatPointer)allocator.getPointer(op.extraArgsDataBuff(), context));
            } else if (dtype == DataBuffer.Type.DOUBLE) {
                nativeOps.execRandomDouble(extraZZ, op.opNum(), rng.getStatePointer(),
                        (DoublePointer)allocator.getPointer(xArr, context),
                        (IntPointer)allocator.getPointer(xArr.shapeInfoDataBuffer(), context),
                        (DoublePointer)allocator.getPointer(zArr, context),
                        (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context),
                        (DoublePointer)allocator.getPointer(op.extraArgsDataBuff(), context));
            } else if (dtype == DataBuffer.Type.HALF) {
                nativeOps.execRandomHalf(extraZZ, op.opNum(), rng.getStatePointer(),
                        (ShortPointer)allocator.getPointer(xArr, context),
                        (IntPointer)allocator.getPointer(xArr.shapeInfoDataBuffer(), context),
                        (ShortPointer)allocator.getPointer(zArr, context),
                        (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context),
                        (ShortPointer)allocator.getPointer(op.extraArgsDataBuff(), context));
            }
        } else {
            // z only.
            if (dtype == DataBuffer.Type.FLOAT) {
                nativeOps.execRandomFloat(extraZZ, op.opNum(), rng.getStatePointer(),
                        (FloatPointer)allocator.getPointer(zArr, context),
                        (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context),
                        (FloatPointer)allocator.getPointer(op.extraArgsDataBuff(), context));
            } else if (dtype == DataBuffer.Type.DOUBLE) {
                nativeOps.execRandomDouble(extraZZ, op.opNum(), rng.getStatePointer(),
                        (DoublePointer)allocator.getPointer(zArr, context),
                        (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context),
                        (DoublePointer)allocator.getPointer(op.extraArgsDataBuff(), context));
            } else if (dtype == DataBuffer.Type.HALF) {
                nativeOps.execRandomHalf(extraZZ, op.opNum(), rng.getStatePointer(),
                        (ShortPointer)allocator.getPointer(zArr, context),
                        (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context),
                        (ShortPointer)allocator.getPointer(op.extraArgsDataBuff(), context));
            }
        }

        allocator.getFlowController().registerAction(context, zArr, xArr, yArr);
        this.profilingHookOut((Op)op, st);
        return zArr;
    }

    /**
     * Returns environment properties for this backend, computed on first call and cached.
     *
     * <p>Starts from the superclass properties and adds: {@code backend} = "CUDA",
     * {@code cuda.availableDevices} (device count), {@code cuda.devicesInformation} (one map per
     * device with name, free/total memory and major/minor version) and {@code blas.vendor}.
     * The device count is queried once, so the reported count always matches the list size.
     *
     * @return the cached {@link Properties} instance
     */
    public synchronized Properties getEnvironmentInformation() {
        if (this.properties == null) {
            Properties props = super.getEnvironmentInformation();
            int numDevices = nativeOps.getAvailableDevices();
            // Typed list (previously a raw ArrayList); not presized in case the native count is invalid.
            List<HashMap<String, Object>> devicesList = new ArrayList<HashMap<String, Object>>();
            for (int i = 0; i < numDevices; ++i) {
                HashMap<String, Object> deviceProps = new HashMap<String, Object>();
                CudaPointer devPtr = new CudaPointer(i);
                deviceProps.put("cuda.deviceName", nativeOps.getDeviceName((Pointer)devPtr));
                deviceProps.put("cuda.freeMemory", nativeOps.getDeviceFreeMemory((Pointer)devPtr));
                deviceProps.put("cuda.totalMemory", nativeOps.getDeviceTotalMemory((Pointer)devPtr));
                deviceProps.put("cuda.deviceMajor", Long.valueOf(nativeOps.getDeviceMajor((Pointer)devPtr)));
                deviceProps.put("cuda.deviceMinor", Long.valueOf(nativeOps.getDeviceMinor((Pointer)devPtr)));
                devicesList.add(deviceProps);
            }
            props.put("backend", "CUDA");
            props.put("cuda.availableDevices", (Object)numDevices);
            props.put("cuda.devicesInformation", devicesList);
            props.put("blas.vendor", Nd4jBlas.Vendor.CUBLAS.toString());
            this.properties = props;
        }
        return this.properties;
    }

    /**
     * Returns the TAD manager held by this executioner.
     *
     * @return the {@code tadManager} instance
     */
    public TADManager getTADManager() {
        final TADManager manager = tadManager;
        return manager;
    }
}

