/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.Arrays;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.CopyOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JCudaExecutioner
extends DefaultOpExecutioner {
    // Device-side native operations binding, fetched once from NativeOpsHolder when the class is initialised.
    private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    // SLF4J logger bound to this class.
    private static Logger log = LoggerFactory.getLogger(JCudaExecutioner.class);
    // Shared manager queried through getTADOnlyShapeInfo(array, dimensions); the methods below read a
    // shape-info buffer (first) and an offsets buffer (second) from the pair it returns.
    private static TADManager tadManager = new DeviceTADManager();

    /**
     * Exposes the native-ops binding this executioner dispatches to.
     *
     * @return the class-wide {@link NativeOps} instance
     */
    public NativeOps getNativeOps() {
        return JCudaExecutioner.nativeOps;
    }

    /**
     * Executes a broadcast operation along the supplied dimensions through the native kernels.
     *
     * <p>The {@code dimension} array is sorted in place before it is used. The double-precision
     * kernel is selected when the data buffer of {@code op.x()} reports
     * {@code DataBuffer.Type.DOUBLE}; the float kernel is used otherwise.
     *
     * @param op        broadcast op supplying the x, y and z arrays and the op number
     * @param dimension dimensions handed to the TAD manager and to the native call
     * @return {@code op.z()} once the native call has been issued and registered with the allocator
     */
    public INDArray exec(BroadcastOp op, int ... dimension) {
        Arrays.sort(dimension);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext ctx = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());

        // Host-side shape descriptors of y and z are optional entries of the extras array.
        Pointer yHostShape = null;
        if (op.y() != null) {
            yHostShape = AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        }
        Pointer zHostShape = null;
        if (op.z() != null) {
            zHostShape = AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        }

        Pointer xData = allocator.getPointer(op.x(), ctx);
        Pointer yData = allocator.getPointer(op.y(), ctx);
        Pointer zData = allocator.getPointer(op.z(), ctx);
        Pointer xDevShape = allocator.getPointer(op.x().shapeInfoDataBuffer(), ctx);

        Pair tad = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        DataBuffer tadShape = (DataBuffer) tad.getFirst();
        Pointer tadHostShape = AddressRetriever.retrieveHostPointer(tadShape);
        Pointer tadDevShape = allocator.getPointer(tadShape, ctx);
        DataBuffer tadOffsets = (DataBuffer) tad.getSecond();
        Pointer tadDevOffsets = allocator.getPointer(tadOffsets, ctx);

        // Slot order: x host shape, stream, device id, allocation/reduction/scalar/special buffers,
        // y host shape, z host shape, TAD host shape, TAD device shape, TAD device offsets.
        Pointer[] extras = new Pointer[]{
            AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()),
            ctx.getOldStream(),
            allocator.getDeviceIdPointer(),
            ctx.getBufferAllocation(),
            ctx.getBufferReduction(),
            ctx.getBufferScalar(),
            ctx.getBufferSpecial(),
            yHostShape,
            zHostShape,
            tadHostShape,
            tadDevShape,
            tadDevOffsets
        };
        PointerPointer extraPointers = new PointerPointer(extras);
        Pointer dimPtr = allocator.getPointer(allocator.getConstantBuffer(dimension), ctx);

        boolean doublePrecision = op.x().data().dataType() == DataBuffer.Type.DOUBLE;
        Pointer yDevShape = allocator.getPointer(op.y().shapeInfoDataBuffer(), ctx);
        Pointer zDevShape = allocator.getPointer(op.z().shapeInfoDataBuffer(), ctx);
        if (doublePrecision) {
            nativeOps.execBroadcastDouble(extraPointers, op.opNum(), xData, xDevShape, yData, yDevShape, zData, zDevShape, dimPtr, dimension.length);
        } else {
            nativeOps.execBroadcastFloat(extraPointers, op.opNum(), xData, xDevShape, yData, yDevShape, zData, zDevShape, dimPtr, dimension.length);
        }

        allocator.registerAction(ctx, op.z(), op.x(), op.y());
        return op.z();
    }

    /**
     * Reduces {@code op.x()} along the given dimensions and returns the freshly allocated result.
     *
     * <p>Negative entries of {@code dimension} are shifted by the rank of {@code op.x()} and the
     * array is then sorted, both in place. When as many dimensions as the rank are supplied the
     * reduction collapses to a single {@code Integer.MAX_VALUE} dimension. The result array is
     * created here (zeros, or filled with {@code op.zeroDouble()} when that value lies outside
     * the +/-0.01f window) and installed as {@code op.z()}.
     *
     * <p>For a scalar result the value returned by the native call is written into the result array
     * on the host and also stored through {@code op.setFinalResult}. Otherwise the dimensional
     * kernel writes into {@code op.z()} and the action is registered with the allocator.
     *
     * @param op        accumulation to run; its z is replaced by the result array
     * @param dimension dimensions to reduce along; modified in place
     * @return the result array, or {@code op.noOp()} when x is a vector whose length already equals
     *         the element count of the result shape
     */
    public INDArray exec(Accumulation op, int ... dimension) {
        // Resolve negative axes BEFORE sorting: sorting first and shifting afterwards leaves a
        // mixed-sign input such as {-1, 0} unsorted ({rank-1, 0}).
        int rank = op.x().rank();
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < 0) {
                dimension[i] += rank;
            }
        }
        Arrays.sort(dimension);
        if (dimension.length == rank) {
            dimension = new int[]{Integer.MAX_VALUE};
        }

        // Always assigned on both paths (the whole-array case previously left retShape unset).
        int[] retShape = Shape.wholeArrayDimension(dimension)
                ? new int[]{1, 1}
                : ArrayUtil.removeIndex(op.x().shape(), dimension);
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
            return op.noOp();
        }

        double zero = op.zeroDouble();
        INDArray ret = zero > (double) -0.01f && zero < (double) 0.01f
                ? Nd4j.zeros(retShape)
                : Nd4j.valueArrayOf(retShape, zero);
        op.setZ(ret);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());

        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        DataBuffer tadShapeInfo = (DataBuffer) tadBuffers.getFirst();
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadShapeInfo);
        Pointer devTadShapeInfo = allocator.getPointer(tadShapeInfo, context);
        DataBuffer offsets = (DataBuffer) tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);

        Pointer x = allocator.getPointer(op.x(), context);
        Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);

        // Slot order: x host shape, stream, device id, allocation/reduction/scalar/special buffers,
        // y host shape, z host shape, TAD host shape, TAD device shape, TAD device offsets.
        PointerPointer extraPointers = new PointerPointer(new Pointer[]{
            AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()),
            context.getOldStream(),
            allocator.getDeviceIdPointer(),
            context.getBufferAllocation(),
            context.getBufferReduction(),
            context.getBufferScalar(),
            context.getBufferSpecial(),
            hostYShapeInfo,
            hostZShapeInfo,
            hostTadShapeInfo,
            devTadShapeInfo,
            devTadOffsets
        });

        boolean isVariance = op instanceof Variance;
        // In this method extra arguments are only forwarded for Variance ops.
        Pointer extraArgs = op.extraArgs() != null && isVariance
                ? allocator.getPointer(op.extraArgsDataBuff(), context)
                : null;
        Pointer dimensionPointer = allocator.getPointer(allocator.getConstantBuffer(dimension), context);
        boolean isDouble = op.x().data().dataType() == DataBuffer.Type.DOUBLE;

        if (ret.isScalar()) {
            // The scalar kernels hand the value back directly; it is stored on the host side.
            allocator.tickHostWrite(ret);
            if (isDouble) {
                double value;
                if (isVariance) {
                    // Honour the op's bias-correction flag, as the dimensional path below does.
                    value = nativeOps.execSummaryStatsScalarDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, ((Variance) op).isBiasCorrected());
                } else if (op.y() != null) {
                    Pointer y = allocator.getPointer(op.y(), context);
                    Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                    value = nativeOps.execReduce3ScalarDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo);
                } else {
                    value = nativeOps.execReduceScalarDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs);
                }
                ret.putScalar(0, value);
                op.setFinalResult(Double.valueOf(ret.getDouble(0)));
            } else {
                float value;
                if (isVariance) {
                    value = nativeOps.execSummaryStatsScalarFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, ((Variance) op).isBiasCorrected());
                } else if (op.y() != null) {
                    Pointer y = allocator.getPointer(op.y(), context);
                    Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                    value = nativeOps.execReduce3ScalarFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo);
                } else {
                    value = nativeOps.execReduceScalarFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs);
                }
                ret.putScalar(0, value);
                op.setFinalResult(Float.valueOf(ret.getFloat(0)));
            }
            return ret;
        }

        if (isVariance) {
            Pointer z = allocator.getPointer(op.z(), context);
            Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
            boolean biasCorrected = ((Variance) op).isBiasCorrected();
            if (isDouble) {
                nativeOps.execSummaryStatsDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length, biasCorrected);
            } else {
                nativeOps.execSummaryStatsFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length, biasCorrected);
            }
        } else if (op.y() != null) {
            Pointer y = allocator.getPointer(op.y(), context);
            Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
            Pointer z = allocator.getPointer(op.z(), context);
            Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
            if (isDouble) {
                nativeOps.execReduce3Double(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo, z, zShapeInfo, dimensionPointer, dimension.length);
            } else {
                nativeOps.execReduce3Float(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo, z, zShapeInfo, dimensionPointer, dimension.length);
            }
        } else {
            Pointer z = allocator.getPointer(op.z(), context);
            Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
            if (isDouble) {
                nativeOps.execReduceDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
            } else {
                nativeOps.execReduceFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
            }
        }
        allocator.registerAction(context, op.z(), op.x(), op.y());
        return ret;
    }

    /**
     * Runs an index accumulation along the given dimensions and returns the result array.
     *
     * <p>Negative entries of {@code dimension} are shifted by the rank of {@code op.x()} and the
     * array is then sorted, both in place. When as many dimensions as the rank are supplied the
     * reduction collapses to a single {@code Integer.MAX_VALUE} dimension. The result array is
     * created here (zeros, or filled with {@code op.zeroDouble()} when that value lies outside the
     * +/-0.01f window) and installed as {@code op.z()} before the native kernel is invoked.
     *
     * @param op        index accumulation to run; its z is replaced by the result array
     * @param dimension dimensions to reduce along; modified in place
     * @return {@code op.z()} after the native call, or {@code op.x()} unchanged when x is a vector
     *         whose length already equals the element count of the reduced shape
     */
    public INDArray exec(IndexAccumulation op, int ... dimension) {
        // Resolve negative axes BEFORE sorting: sorting first and shifting afterwards leaves a
        // mixed-sign input such as {-1, 0} unsorted ({rank-1, 0}).
        int rank = op.x().rank();
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < 0) {
                dimension[i] += rank;
            }
        }
        Arrays.sort(dimension);
        if (dimension.length == rank) {
            dimension = new int[]{Integer.MAX_VALUE};
        }

        // Always assigned on both paths (the whole-array case previously left retShape unset).
        int[] retShape = Shape.wholeArrayDimension(dimension)
                ? new int[]{1, 1}
                : ArrayUtil.removeIndex(op.x().shape(), dimension);

        // This check intentionally runs on the un-padded shape, as in the original ordering.
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
            return op.x();
        }
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }

        double zero = op.zeroDouble();
        INDArray ret = zero > (double) -0.01f && zero < (double) 0.01f
                ? Nd4j.zeros(retShape)
                : Nd4j.valueArrayOf(retShape, zero);
        op.setZ(ret);
        // The rank-collapse check is not repeated here: dimension has not changed since the first check.

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = allocator.getPointer(op.x(), context);
        Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer z = allocator.getPointer(op.z(), context);
        Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);

        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        DataBuffer tadShapeInfo = (DataBuffer) tadBuffers.getFirst();
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadShapeInfo);
        Pointer devTadShapeInfo = allocator.getPointer(tadShapeInfo, context);
        DataBuffer offsets = (DataBuffer) tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);

        // Slot order: x host shape, stream, device id, allocation/reduction/scalar/special buffers,
        // y host shape, z host shape, TAD host shape, TAD device shape, TAD device offsets.
        PointerPointer extraPointers = new PointerPointer(new Pointer[]{
            AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()),
            context.getOldStream(),
            allocator.getDeviceIdPointer(),
            context.getBufferAllocation(),
            context.getBufferReduction(),
            context.getBufferScalar(),
            context.getBufferSpecial(),
            hostYShapeInfo,
            hostZShapeInfo,
            hostTadShapeInfo,
            devTadShapeInfo,
            devTadOffsets
        });
        Pointer extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer dimensionPointer = allocator.getPointer(allocator.getConstantBuffer(dimension), context);

        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execIndexReduceDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
        } else {
            nativeOps.execIndexReduceFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
        }
        allocator.registerAction(context, op.z(), op.x(), op.y());
        return op.z();
    }

    /**
     * Sorts {@code dimension} in place and then delegates to the superclass implementation.
     *
     * @param op        op to execute
     * @param dimension dimensions to execute along; sorted in place
     * @return whatever the superclass {@code exec(op, dimension)} returns
     */
    public Op exec(Op op, int ... dimension) {
        Arrays.sort(dimension);
        Op executed = super.exec(op, dimension);
        return executed;
    }

    /**
     * Executes an op, either through the superclass implementation or through the native kernels.
     *
     * <p>The superclass path is taken when {@code op.x()} is an {@link IComplexNDArray}, when the
     * execution mode is {@code ExecutionMode.JAVA}, or when the op is a {@link CopyOp}. On that
     * path host data of x and y is synchronised first and z is ticked as host-written afterwards.
     * All other ops are dispatched to the matching private {@code invoke} overload by op type; an
     * op matching none of the handled types is returned without being executed.
     *
     * @param op op to execute
     * @return the same {@code op} instance on every path
     */
    public Op exec(Op op) {
        boolean hostFallback = op.x() instanceof IComplexNDArray
                || this.executionMode() == OpExecutioner.ExecutionMode.JAVA
                || op instanceof CopyOp;
        if (hostFallback) {
            AtomicAllocator allocator = AtomicAllocator.getInstance();
            // The z sync state is not touched beforehand; only the inputs are brought up to date.
            if (op.x() != null) {
                allocator.synchronizeHostData(op.x());
            }
            if (op.y() != null) {
                allocator.synchronizeHostData(op.y());
            }
            super.exec(op);
            if (op.z() != null) {
                allocator.tickHostWrite(op.z());
            }
            // Previously returned null here, unlike every other path of this method.
            return op;
        }

        if (op instanceof TransformOp) {
            this.invoke((TransformOp) op);
        } else if (op instanceof Accumulation) {
            this.invoke((Accumulation) op, null);
        } else if (op instanceof ScalarOp) {
            this.invoke((ScalarOp) op);
        } else if (op instanceof BroadcastOp) {
            this.invoke((BroadcastOp) op);
        } else if (op instanceof IndexAccumulation) {
            this.invoke((IndexAccumulation) op, null);
        }
        return op;
    }

    /**
     * Runs a transform op through the private {@code invoke(TransformOp)} overload and hands back its output.
     *
     * @param op transform op to execute
     * @return {@code op.z()} as read after the invocation
     */
    public INDArray execAndReturn(TransformOp op) {
        this.invoke(op);
        INDArray output = op.z();
        return output;
    }

    /**
     * Issues the native broadcast kernel for {@code op}, using the dimensions the op itself reports.
     *
     * <p>The double-precision kernel is used when the data buffer of {@code op.x()} reports
     * {@code DataBuffer.Type.DOUBLE}; the float kernel is used otherwise.
     *
     * @param op broadcast op supplying x, y, z, the op number and the broadcast dimensions
     * @return always {@code null}
     */
    private CudaContext invoke(BroadcastOp op) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext ctx = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer xData = allocator.getPointer(op.x(), ctx);
        Pointer xDevShape = allocator.getPointer(op.x().shapeInfoDataBuffer(), ctx);

        // Host-side shape descriptors of y and z are optional entries of the extras array.
        Pointer yHostShape = null;
        if (op.y() != null) {
            yHostShape = AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        }
        Pointer zHostShape = null;
        if (op.z() != null) {
            zHostShape = AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        }

        int[] dims = op.getDimension();
        Pair tad = tadManager.getTADOnlyShapeInfo(op.x(), dims);
        DataBuffer tadShape = (DataBuffer) tad.getFirst();
        Pointer tadHostShape = AddressRetriever.retrieveHostPointer(tadShape);
        Pointer tadDevShape = allocator.getPointer(tadShape, ctx);
        DataBuffer tadOffsets = (DataBuffer) tad.getSecond();
        Pointer tadDevOffsets = allocator.getPointer(tadOffsets, ctx);

        // Slot order: x host shape, stream, device id, allocation/reduction/scalar/special buffers,
        // y host shape, z host shape, TAD host shape, TAD device shape, TAD device offsets.
        Pointer[] extras = new Pointer[]{
            AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()),
            ctx.getOldStream(),
            allocator.getDeviceIdPointer(),
            ctx.getBufferAllocation(),
            ctx.getBufferReduction(),
            ctx.getBufferScalar(),
            ctx.getBufferSpecial(),
            yHostShape,
            zHostShape,
            tadHostShape,
            tadDevShape,
            tadDevOffsets
        };
        PointerPointer extraPointers = new PointerPointer(extras);

        Pointer yData = allocator.getPointer(op.y(), ctx);
        Pointer yDevShape = allocator.getPointer(op.y().shapeInfoDataBuffer(), ctx);
        Pointer zData = allocator.getPointer(op.z(), ctx);
        Pointer zDevShape = allocator.getPointer(op.z().shapeInfoDataBuffer(), ctx);
        Pointer dimPtr = allocator.getPointer(allocator.getConstantBuffer(dims), ctx);

        boolean doublePrecision = op.x().data().dataType() == DataBuffer.Type.DOUBLE;
        if (doublePrecision) {
            nativeOps.execBroadcastDouble(extraPointers, op.opNum(), xData, xDevShape, yData, yDevShape, zData, zDevShape, dimPtr, dims.length);
        } else {
            nativeOps.execBroadcastFloat(extraPointers, op.opNum(), xData, xDevShape, yData, yDevShape, zData, zDevShape, dimPtr, dims.length);
        }

        allocator.registerAction(ctx, op.z(), op.x(), op.y());
        return null;
    }

    /**
     * Issues the native index-reduce kernel for {@code op}.
     *
     * <p>The scalar kernel is used when {@code op.z()} is scalar, when {@code dimension} is
     * {@code null}, or when its first entry is {@code Integer.MAX_VALUE}; the returned value is
     * truncated to {@code int} and stored through {@code op.setFinalResult}. Otherwise
     * {@code dimension} is sorted in place and the dimensional kernel writes into {@code op.z()}.
     * The TAD lookup uses {@code dimension} as passed in, or {@code {0}} when it is {@code null}.
     *
     * @param op        index accumulation supplying x, z, the op number and optional extra arguments
     * @param dimension dimensions to reduce along, or {@code null}
     * @return always {@code null}
     */
    private CudaContext invoke(IndexAccumulation op, int[] dimension) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext ctx = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer xData = allocator.getPointer(op.x(), ctx);
        Pointer xDevShape = allocator.getPointer(op.x().shapeInfoDataBuffer(), ctx);

        Pointer extra = null;
        if (op.extraArgs() != null) {
            extra = allocator.getPointer(op.extraArgsDataBuff(), ctx);
        }
        Pointer yHostShape = null;
        if (op.y() != null) {
            yHostShape = AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        }
        Pointer zHostShape = null;
        if (op.z() != null) {
            zHostShape = AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        }

        int[] tadDims = dimension != null ? dimension : new int[]{0};
        Pair tad = tadManager.getTADOnlyShapeInfo(op.x(), tadDims);
        DataBuffer tadShape = (DataBuffer) tad.getFirst();
        Pointer tadHostShape = AddressRetriever.retrieveHostPointer(tadShape);
        Pointer tadDevShape = allocator.getPointer(tadShape, ctx);
        DataBuffer tadOffsets = (DataBuffer) tad.getSecond();
        Pointer tadDevOffsets = null;
        if (tadOffsets != null) {
            tadDevOffsets = allocator.getPointer(tadOffsets, ctx);
        }

        // Slot order: x host shape, stream, device id, allocation/reduction/scalar/special buffers,
        // y host shape, z host shape, TAD host shape, TAD device shape, TAD device offsets.
        Pointer[] extras = new Pointer[]{
            AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()),
            ctx.getOldStream(),
            allocator.getDeviceIdPointer(),
            ctx.getBufferAllocation(),
            ctx.getBufferReduction(),
            ctx.getBufferScalar(),
            ctx.getBufferSpecial(),
            yHostShape,
            zHostShape,
            tadHostShape,
            tadDevShape,
            tadDevOffsets
        };
        PointerPointer extraPointers = new PointerPointer(extras);

        boolean wholeArray = op.z().isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE;
        boolean doublePrecision = op.x().data().dataType() == DataBuffer.Type.DOUBLE;
        if (!wholeArray) {
            Arrays.sort(dimension);
            Pointer zData = allocator.getPointer(op.z(), ctx);
            Pointer zDevShape = allocator.getPointer(op.z().shapeInfoDataBuffer(), ctx);
            Pointer dimPtr = allocator.getPointer(allocator.getConstantBuffer(dimension), ctx);
            if (doublePrecision) {
                nativeOps.execIndexReduceDouble(extraPointers, op.opNum(), xData, xDevShape, extra, zData, zDevShape, dimPtr, dimension.length);
            } else {
                nativeOps.execIndexReduceFloat(extraPointers, op.opNum(), xData, xDevShape, extra, zData, zDevShape, dimPtr, dimension.length);
            }
        } else if (doublePrecision) {
            double index = nativeOps.execIndexReduceScalarDouble(extraPointers, op.opNum(), xData, xDevShape, extra);
            op.setFinalResult((int) index);
        } else {
            float index = nativeOps.execIndexReduceScalarFloat(extraPointers, op.opNum(), xData, xDevShape, extra);
            op.setFinalResult((int) index);
        }

        allocator.registerAction(ctx, op.z(), op.x(), op.y());
        return null;
    }

    /**
     * Issues the native reduction kernel for {@code op} and installs a freshly allocated result as
     * {@code op.z()}.
     *
     * <p>A {@code null} {@code dimension} is treated as a single {@code Integer.MAX_VALUE}
     * dimension; a non-null array is sorted in place. The result array is created (zeros, or
     * filled with {@code op.zeroDouble()} when that value lies outside the +/-0.01f window) and set
     * on the op <em>before</em> the allocator action is prepared, so that both the prepared action
     * and the host z shape info in the extras array describe the array the kernel writes to. This
     * is the same ordering {@code exec(Accumulation, int...)} uses.
     *
     * <p>For a scalar result the value returned by the native call is stored through
     * {@code op.setFinalResult}; it is not written into the result array here. Otherwise the
     * dimensional kernel writes into {@code op.z()}.
     *
     * @param op        accumulation to run; its z is replaced unless the method returns early
     * @param dimension dimensions to reduce along, or {@code null} for the whole array
     * @return the context of the executed action, or {@code null} when x is a vector whose length
     *         already equals the element count of the result shape (nothing is executed then)
     */
    private CudaContext invoke(Accumulation op, int[] dimension) {
        if (dimension == null) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        Arrays.sort(dimension);

        // Always assigned on both paths (the whole-array case previously left retShape unset).
        int[] retShape = Shape.wholeArrayDimension(dimension)
                ? new int[]{1, 1}
                : ArrayUtil.removeIndex(op.x().shape(), dimension);
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        // Bail out before any allocator action is prepared, so no prepared action is left unregistered.
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
            return null;
        }

        double zero = op.zeroDouble();
        INDArray ret = zero > (double) -0.01f && zero < (double) 0.01f
                ? Nd4j.zeros(retShape)
                : Nd4j.valueArrayOf(retShape, zero);
        op.setZ(ret);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());

        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        DataBuffer tadShapeInfo = (DataBuffer) tadBuffers.getFirst();
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadShapeInfo);
        Pointer devTadShapeInfo = allocator.getPointer(tadShapeInfo, context);
        DataBuffer offsets = (DataBuffer) tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);

        // Slot order: x host shape, stream, device id, allocation/reduction/scalar/special buffers,
        // y host shape, z host shape, TAD host shape, TAD device shape, TAD device offsets.
        PointerPointer extraPointers = new PointerPointer(new Pointer[]{
            AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()),
            context.getOldStream(),
            allocator.getDeviceIdPointer(),
            context.getBufferAllocation(),
            context.getBufferReduction(),
            context.getBufferScalar(),
            context.getBufferSpecial(),
            hostYShapeInfo,
            hostZShapeInfo,
            hostTadShapeInfo,
            devTadShapeInfo,
            devTadOffsets
        });
        Pointer x = allocator.getPointer(op.x(), context);
        Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(), context) : null;

        boolean isDouble = op.x().data().dataType() == DataBuffer.Type.DOUBLE;
        boolean isVariance = op instanceof Variance;

        if (op.z().isScalar()) {
            if (isDouble) {
                double result;
                if (isVariance) {
                    // Honour the op's bias-correction flag, as the dimensional path below does.
                    result = nativeOps.execSummaryStatsScalarDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, ((Variance) op).isBiasCorrected());
                } else if (op.y() != null) {
                    Pointer y = allocator.getPointer(op.y(), context);
                    Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                    result = nativeOps.execReduce3ScalarDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo);
                } else {
                    result = nativeOps.execReduceScalarDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs);
                }
                op.setFinalResult(Double.valueOf(result));
            } else {
                float result;
                if (isVariance) {
                    result = nativeOps.execSummaryStatsScalarFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, ((Variance) op).isBiasCorrected());
                } else if (op.y() != null) {
                    Pointer y = allocator.getPointer(op.y(), context);
                    Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                    result = nativeOps.execReduce3ScalarFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo);
                } else {
                    result = nativeOps.execReduceScalarFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs);
                }
                op.setFinalResult(Float.valueOf(result));
            }
        } else {
            Pointer result = allocator.getPointer(op.z(), context);
            Pointer resultShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
            Pointer dimensionPointer = allocator.getPointer(allocator.getConstantBuffer(dimension), context);
            // A second operand takes precedence over the Variance check on this path.
            if (op.y() != null) {
                Pointer y = allocator.getPointer(op.y(), context);
                Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                if (isDouble) {
                    nativeOps.execReduce3Double(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo, result, resultShapeInfo, dimensionPointer, dimension.length);
                } else {
                    nativeOps.execReduce3Float(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo, result, resultShapeInfo, dimensionPointer, dimension.length);
                }
            } else if (isVariance) {
                boolean biasCorrected = ((Variance) op).isBiasCorrected();
                if (isDouble) {
                    nativeOps.execSummaryStatsDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, result, resultShapeInfo, dimensionPointer, dimension.length, biasCorrected);
                } else {
                    nativeOps.execSummaryStatsFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, result, resultShapeInfo, dimensionPointer, dimension.length, biasCorrected);
                }
            } else if (isDouble) {
                nativeOps.execReduceDouble(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, result, resultShapeInfo, dimensionPointer, dimension.length);
            } else {
                nativeOps.execReduceFloat(extraPointers, op.opNum(), x, xShapeInfo, extraArgs, result, resultShapeInfo, dimensionPointer, dimension.length);
            }
        }
        allocator.registerAction(context, op.z(), op.x(), op.y());
        return context;
    }

    /**
     * Issues the native scalar kernel for {@code op}, writing into {@code op.z()}.
     *
     * <p>The scalar operand is passed as {@code double} when the data buffer of {@code op.x()}
     * reports {@code DataBuffer.Type.DOUBLE} and as {@code float} otherwise. No TAD information is
     * supplied: the last two entries of the extras array are {@code null}.
     *
     * @param op scalar op supplying x, z, the op number, the scalar and optional extra arguments
     * @return always {@code null}
     */
    private CudaContext invoke(ScalarOp op) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext ctx = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());

        // Host-side shape descriptors of y and z are optional entries of the extras array.
        Pointer yHostShape = null;
        if (op.y() != null) {
            yHostShape = AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        }
        Pointer zHostShape = null;
        if (op.z() != null) {
            zHostShape = AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        }

        Pointer xData = allocator.getPointer(op.x(), ctx);
        Pointer xDevShape = allocator.getPointer(op.x().shapeInfoDataBuffer(), ctx);
        Pointer extra = null;
        if (op.extraArgs() != null) {
            extra = allocator.getPointer(op.extraArgsDataBuff(), ctx);
        }
        Pointer zData = allocator.getPointer(op.z(), ctx);
        Pointer zDevShape = allocator.getPointer(op.z().shapeInfoDataBuffer(), ctx);

        // Slot order: x host shape, stream, device id, allocation/reduction/scalar/special buffers,
        // y host shape, z host shape, then two unused (null) slots.
        Pointer[] extras = new Pointer[]{
            AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()),
            ctx.getOldStream(),
            allocator.getDeviceIdPointer(),
            ctx.getBufferAllocation(),
            ctx.getBufferReduction(),
            ctx.getBufferScalar(),
            ctx.getBufferSpecial(),
            yHostShape,
            zHostShape,
            null,
            null
        };
        PointerPointer extraPointers = new PointerPointer(extras);

        boolean doublePrecision = op.x().data().dataType() == DataBuffer.Type.DOUBLE;
        if (doublePrecision) {
            nativeOps.execScalarDouble(extraPointers, op.opNum(), xData, xDevShape, zData, zDevShape, op.scalar().doubleValue(), extra);
        } else {
            nativeOps.execScalarFloat(extraPointers, op.opNum(), xData, xDevShape, zData, zDevShape, op.scalar().floatValue(), extra);
        }

        allocator.registerAction(ctx, op.z(), op.x(), op.y());
        return null;
    }

    /**
     * Launches a {@link TransformOp} through the native ops.
     *
     * <p>The op runs as a pairwise transform when {@code op.y()} is non-null and as a
     * plain x-to-z transform otherwise, in double or float precision depending on the
     * data type of {@code op.x()}. The stride-based native overload is used when the
     * element-wise strides are at least 1, the op is not {@code isExecSpecial()} and the
     * array orderings match; otherwise the shape-info overload is used.
     *
     * <p>Ops numbered 38 to 41 additionally receive TAD shape info and offsets. Op 41
     * also reads a dimension from {@code op.extraArgs()[1]} when extra args are present.
     *
     * @param op transform to run; {@code op.x()} and {@code op.z()} are dereferenced
     *           unconditionally, {@code op.y()} may be {@code null}
     * @return always {@code null}
     */
    private CudaContext invoke(TransformOp op) {
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        INDArray ret = null;
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer dimensionDevPointer = null;
        Pointer dimensionHostPointer = null;
        int[] dimension = null;

        // NOTE(review): op number 41 is a magic constant; presumably the one transform that
        // carries a dimension in extraArgs[1] — confirm against the native op table.
        if (op.opNum() == 41 && op.extraArgs() != null) {
            dimension = new int[]{(Integer) op.extraArgs()[1]};
            // Negative indices count from the end; shift them by the rank of x.
            for (int i = 0; i < dimension.length; ++i) {
                if (dimension[i] < 0) {
                    dimension[i] += op.x().rank();
                }
            }
            if (dimension.length == op.x().rank()) {
                dimension = new int[]{Integer.MAX_VALUE};
            }

            // Assign retShape on both paths: a whole-array reduction yields a 1x1 shape,
            // otherwise the shape of x with the reduced dimension removed.
            int[] retShape = Shape.wholeArrayDimension(dimension)
                    ? new int[]{1, 1}
                    : ArrayUtil.removeIndex(op.x().shape(), dimension);
            // Pad to two dimensions: a row vector when reducing along dimension 0, a column otherwise.
            if (retShape.length == 1) {
                retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
            } else if (retShape.length == 0) {
                retShape = new int[]{1, 1};
            }

            ret = Nd4j.zeros(retShape);
            // NOTE(review): this overwrites the "host" y shape slot with a pointer obtained via
            // getPointer(..., context) rather than retrieveHostPointer — confirm this is intended.
            hostYShapeInfo = AtomicAllocator.getInstance().getPointer(ret.shapeInfoDataBuffer(), context);
            DataBuffer dimensionBuffer = AtomicAllocator.getInstance().getConstantBuffer(dimension);
            dimensionDevPointer = AtomicAllocator.getInstance().getPointer(dimensionBuffer, context);
            dimensionHostPointer = AtomicAllocator.getInstance().getHostPointer(dimensionBuffer);
        }

        Pointer hostTadShapeInfo = null;
        Pointer devTadShapeInfo = null;
        Pointer hostMaxTadShapeInfo = null;
        Pointer devMaxTadShapeInfo = null;
        Pointer devTadOffsets = null;
        Pointer devMaxTadOffsets = null;
        if (op.opNum() >= 38 && op.opNum() <= 41) {
            DataBuffer offsets;
            if (op.opNum() != 41) {
                // Ops 38..40: TADs of x along dimension 0, plus a second ("max") set along dimension 1.
                Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), new int[]{0});
                Pair tadMaxBuffers = tadManager.getTADOnlyShapeInfo(op.x(), new int[]{1});
                hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) tadBuffers.getFirst());
                devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer) tadBuffers.getFirst(), context);
                hostMaxTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) tadMaxBuffers.getFirst());
                devMaxTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer) tadMaxBuffers.getFirst(), context);
                offsets = (DataBuffer) tadBuffers.getSecond();
                devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
                DataBuffer maxOffsets = (DataBuffer) tadMaxBuffers.getSecond();
                devMaxTadOffsets = maxOffsets == null ? null : AtomicAllocator.getInstance().getPointer(maxOffsets, context);
            } else {
                // Op 41: TADs of z along the dimension resolved above.
                // NOTE(review): dimension is still null here when op.extraArgs() is null —
                // confirm the TAD manager tolerates that or that callers always supply extra args.
                Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
                hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) tadBuffers.getFirst());
                devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer) tadBuffers.getFirst(), context);
                offsets = (DataBuffer) tadBuffers.getSecond();
                devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
            }
        }

        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        // Positional argument pack handed to the native side; entries may be null.
        PointerPointer xShapeInfoHostPointer = new PointerPointer(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, hostMaxTadShapeInfo, devMaxTadShapeInfo, devMaxTadOffsets, dimensionDevPointer, dimensionHostPointer});

        if (op.y() != null) {
            Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
            Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
            if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                if (op.x().elementWiseStride() >= 1 && op.y().elementWiseStride() >= 1 && !op.isExecSpecial() && op.x().ordering() == op.y().ordering() && op.x().ordering() == op.z().ordering()) {
                    nativeOps.execPairwiseTransformDouble(xShapeInfoHostPointer, op.opNum(), x, op.x().elementWiseStride(), y, op.y().elementWiseStride(), z, op.z().elementWiseStride(), extraArgs, op.n());
                } else {
                    nativeOps.execPairwiseTransformDouble(xShapeInfoHostPointer, op.opNum(), x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraArgs);
                }
            // NOTE(review): the float path additionally requires equal x/y strides, which the
            // double path above does not — confirm whether the difference is intentional.
            } else if (op.x().elementWiseStride() >= 1 && op.y().elementWiseStride() >= 1 && op.x().elementWiseStride() == op.y().elementWiseStride() && !op.isExecSpecial() && op.x().ordering() == op.y().ordering() && op.x().ordering() == op.z().ordering()) {
                nativeOps.execPairwiseTransformFloat(xShapeInfoHostPointer, op.opNum(), x, op.x().elementWiseStride(), y, op.y().elementWiseStride(), z, op.z().elementWiseStride(), extraArgs, op.n());
            } else {
                nativeOps.execPairwiseTransformFloat(xShapeInfoHostPointer, op.opNum(), x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraArgs);
            }
        } else if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && op.z().ordering() == op.x().ordering()) {
                nativeOps.execTransformDouble(xShapeInfoHostPointer, op.opNum(), x, op.x().elementWiseStride(), z, op.z().elementWiseStride(), extraArgs, op.n());
            } else {
                nativeOps.execTransformDouble(xShapeInfoHostPointer, op.opNum(), x, xShapeInfo, z, zShapeInfo, extraArgs);
            }
        } else if (op.x().elementWiseStride() >= 1 && !op.isExecSpecial() && op.z().ordering() == op.x().ordering()) {
            nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), x, op.x().elementWiseStride(), z, op.z().elementWiseStride(), extraArgs, op.n());
        } else {
            nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), x, xShapeInfo, z, zShapeInfo, extraArgs);
        }

        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        // NOTE(review): the prepared context is not returned here, unlike the sibling
        // invoke overload that ends with "return context" — confirm callers ignore the result.
        return null;
    }

    /**
     * Exposes the TAD manager held in this class's static {@code tadManager} field.
     *
     * @return the class-wide {@link TADManager} instance
     */
    public static TADManager getTadManager() {
        return JCudaExecutioner.tadManager;
    }
}

