/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.blas.Blas;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.CopyOp;
import org.nd4j.linalg.api.ops.impl.transforms.convolution.Pooling2D;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CudaExecutioner
extends DefaultOpExecutioner {
    private static final Logger log = LoggerFactory.getLogger(CudaExecutioner.class);
    /** Device-side native op bindings; static, so shared by every executioner instance. */
    protected static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Supplies TAD shape-info (first) and offset (second) buffers for device execution. */
    protected static TADManager tadManager = new DeviceTADManager();
    /**
     * Per-thread scratch array of native pointers handed to native ops as their "extras" argument.
     * Created lazily with 32 slots by the exec methods on first use in a thread.
     */
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();
    protected volatile transient Properties properties;
    /** Name of the most recent op run on the calling thread; only set while CUDA debug mode is on. */
    protected ThreadLocal<String> lastOp = new ThreadLocal<>();

    /**
     * Exposes the device-side native bindings this executioner dispatches to.
     *
     * @return the static {@link NativeOps} instance shared by all executioners
     */
    public NativeOps getNativeOps() {
        return CudaExecutioner.nativeOps;
    }

    /**
     * Reports the name of the last op recorded for the calling thread.
     *
     * @return the recorded op name, or {@code null} if nothing has been recorded on this thread
     *         (names are only recorded while CUDA debug mode is enabled)
     */
    public String getLastOp() {
        ThreadLocal<String> perThreadName = this.lastOp;
        return perThreadName.get();
    }

    /**
     * Executes a {@link BroadcastOp} through the native CUDA backend.
     *
     * <p>The x, y and z arrays of the op are resolved to device pointers, TAD shape-info/offset
     * buffers for x and z along {@code dimension} are prepared, and the type-specific native
     * broadcast kernel (DOUBLE, FLOAT, otherwise HALF) is invoked.
     *
     * @param op        the broadcast op; its z array receives the result
     * @param dimension target dimensions; sorted in place
     * @return {@code op.z()}
     * @throws ND4JIllegalStateException if a dimension is not below the rank of {@code op.x()}
     *         (other than the {@code Integer.MAX_VALUE} marker)
     */
    @Override
    public INDArray exec(BroadcastOp op, int ... dimension) {
        long started = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        CudaExecutioner.validateDataType((DataBuffer.Type)Nd4j.dataType(), (Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        Arrays.sort(dimension);
        INDArray xArr = op.x();
        for (int axis : dimension) {
            boolean outOfRange = axis >= xArr.rank() && axis != Integer.MAX_VALUE;
            if (outOfRange) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + xArr.rank() + "]");
            }
        }
        INDArray yArr = op.y();
        INDArray zArr = op.z();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(zArr, xArr, yArr);
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.name());
        }
        Pointer hostYShapeInfo = yArr == null ? null : AddressRetriever.retrieveHostPointer(yArr.shapeInfoDataBuffer());
        Pointer hostZShapeInfo = zArr == null ? null : AddressRetriever.retrieveHostPointer(zArr.shapeInfoDataBuffer());
        Pointer x = allocator.getPointer(xArr, context);
        Pointer y = allocator.getPointer(yArr, context);
        Pointer z = allocator.getPointer(zArr, context);
        Pointer xShapeInfo = allocator.getPointer(xArr.shapeInfoDataBuffer(), context);

        // TAD descriptors for x: shape info (host + device) and device offsets.
        Pair xTad = tadManager.getTADOnlyShapeInfo(xArr, dimension);
        DataBuffer xTadShape = (DataBuffer)xTad.getFirst();
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(xTadShape);
        Pointer devTadShapeInfo = allocator.getPointer(xTadShape, context);
        DataBuffer xTadOffsets = (DataBuffer)xTad.getSecond();
        Pointer devTadOffsets = allocator.getPointer(xTadOffsets, context);

        // TAD descriptors for z: device shape info and offsets only.
        Pair zTad = tadManager.getTADOnlyShapeInfo(zArr, dimension);
        Pointer devTadShapeInfoZ = allocator.getPointer((DataBuffer)zTad.getFirst(), context);
        Pointer devTadOffsetsZ = allocator.getPointer((DataBuffer)zTad.getSecond(), context);

        // Slot layout of the per-thread extras array consumed by the native broadcast call.
        Pointer[] extras = new Pointer[]{AddressRetriever.retrieveHostPointer(xArr.shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ};
        PointerPointer extraPointers = this.extraz.get().put(extras);
        Pointer dimensionPointer = allocator.getPointer(allocator.getConstantBuffer(dimension), context);

        DataBuffer.Type dataType = xArr.data().dataType();
        if (dataType == DataBuffer.Type.DOUBLE) {
            nativeOps.execBroadcastDouble(extraPointers, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)y, (IntPointer)allocator.getPointer(yArr.shapeInfoDataBuffer(), context), (DoublePointer)z, (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
        } else if (dataType == DataBuffer.Type.FLOAT) {
            nativeOps.execBroadcastFloat(extraPointers, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)y, (IntPointer)allocator.getPointer(yArr.shapeInfoDataBuffer(), context), (FloatPointer)z, (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
        } else {
            // Any other data type is dispatched to the half-precision kernel.
            nativeOps.execBroadcastHalf(extraPointers, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)y, (IntPointer)allocator.getPointer(yArr.shapeInfoDataBuffer(), context), (ShortPointer)z, (IntPointer)allocator.getPointer(zArr.shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
        }
        allocator.registerAction(context, zArr, xArr, yArr);
        this.profilingHookOut((Op)op, started);
        return zArr;
    }

    /**
     * Launches the native CUDA kernel for an {@link Accumulation} into an already-prepared {@code op.z()}.
     *
     * <p>Dispatch is by the data type of {@code op.x()} (DOUBLE, FLOAT, otherwise HALF) and by op kind:
     * {@link Variance} uses the summary-stats kernels, ops with a {@code y} operand use the Reduce3
     * kernels (the "All" variant for complex accumulations), and everything else uses the plain reduce
     * kernels. When {@code op.z()} is a scalar, the scalar kernel variant is used; its returned value is
     * assigned into {@code op.z()} and also stored via {@code op.setFinalResult}.
     *
     * @param op        reduction to run; {@code op.z()} is expected to be non-null and receives the result
     * @param dimension dimensions to reduce along; {@code Integer.MAX_VALUE} is accepted as a whole-array marker
     * @return {@code op.z()}
     * @throws ND4JIllegalStateException if a dimension is not below the rank of {@code op.x()}, or if x and y
     *         lengths differ for a non-complex pairwise reduction
     */
    protected INDArray naiveExec(Accumulation op, int ... dimension) {
        long st = this.profilingHookIn((Op)op);
        INDArray ret = op.z();
        CudaExecutioner.validateDataType((DataBuffer.Type)Nd4j.dataType(), (Op)op);
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < op.x().rank() || dimension[i] == Integer.MAX_VALUE) continue;
            throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.name());
        }
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        // TAD (tensor-along-dimension) shape info and offsets for x.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Fills slots 0-11 of the per-thread extras array: host x shape info, stream, device id,
        // context scratch buffers, host y/z shape info, and x TAD shape info/offsets.
        // NOTE(review): presumably PointerPointer.put(Pointer...) writes only these 12 slots, so when
        // op.y() is null, slots 12/13 keep whatever an earlier call on this thread stored there
        // (exec(BroadcastOp) fills them) — confirm the native reduce kernels ignore them.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        Pointer yDevTadOffsets = null;
        Pointer yDevTadShapeInfo = null;
        if (op.y() != null) {
            // y is split into TADs for whole-array reductions, or when its length differs from one
            // x TAD. Otherwise slot 12 gets y's own device shape info and slot 13 is cleared.
            if (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE || op.x().tensorAlongDimension(0, dimension).lengthLong() != op.y().lengthLong()) {
                if (!op.isComplexAccumulation() && op.x().lengthLong() != op.y().lengthLong()) {
                    throw new ND4JIllegalStateException("Op.X [" + op.x().lengthLong() + "] and Op.Y [" + op.y().lengthLong() + "] lengths should match");
                }
                Pair yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);
                yDevTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)yTadBuffers.getFirst(), context);
                DataBuffer yOffsets = (DataBuffer)yTadBuffers.getSecond();
                yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context);
                xShapeInfoHostPointer.put(12L, yDevTadShapeInfo);
                xShapeInfoHostPointer.put(13L, yDevTadOffsets);
            } else {
                xShapeInfoHostPointer.put(12L, AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context));
                xShapeInfoHostPointer.put(13L, null);
            }
        }
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (op instanceof Variance) {
                if (ret.isScalar()) {
                    double res = nativeOps.execSummaryStatsScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, ((Variance)op).isBiasCorrected());
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                    ret.assign((Number)res);
                    op.setFinalResult((Number)res);
                } else {
                    nativeOps.execSummaryStatsDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length, ((Variance)op).isBiasCorrected());
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                }
            } else if (op.y() != null) {
                if (op.isComplexAccumulation()) {
                    nativeOps.execReduce3AllDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (DoublePointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length, (IntPointer)devTadShapeInfo, (LongPointer)new LongPointerWrapper(devTadOffsets), (IntPointer)yDevTadShapeInfo, (LongPointer)new LongPointerWrapper(yDevTadOffsets));
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                } else if (ret.isScalar()) {
                    double res = nativeOps.execReduce3ScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context));
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                    ret.assign((Number)res);
                    op.setFinalResult((Number)res);
                } else {
                    nativeOps.execReduce3Double(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (DoublePointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                }
            } else if (ret.isScalar()) {
                double res = nativeOps.execReduceScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                ret.assign((Number)res);
                op.setFinalResult((Number)res);
            } else {
                nativeOps.execReduceDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            if (op instanceof Variance) {
                if (ret.isScalar()) {
                    float res = nativeOps.execSummaryStatsScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, ((Variance)op).isBiasCorrected());
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                    ret.assign((Number)Float.valueOf(res));
                    op.setFinalResult((Number)Float.valueOf(res));
                } else {
                    nativeOps.execSummaryStatsFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length, ((Variance)op).isBiasCorrected());
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                }
            } else if (op.y() != null) {
                if (op.isComplexAccumulation()) {
                    nativeOps.execReduce3AllFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (FloatPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length, (IntPointer)devTadShapeInfo, (LongPointer)new LongPointerWrapper(devTadOffsets), (IntPointer)yDevTadShapeInfo, (LongPointer)new LongPointerWrapper(yDevTadOffsets));
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                } else if (ret.isScalar()) {
                    float res = nativeOps.execReduce3ScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context));
                    // NOTE(review): unlike every other scalar branch, the result is assigned before
                    // registerAction here; confirm the ordering is intentional.
                    ret.assign((Number)Float.valueOf(res));
                    op.setFinalResult((Number)Float.valueOf(res));
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                } else {
                    nativeOps.execReduce3Float(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (FloatPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
                    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                }
            } else if (ret.isScalar()) {
                float res = nativeOps.execReduceScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                ret.assign((Number)Float.valueOf(res));
                op.setFinalResult((Number)Float.valueOf(res));
            } else {
                nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        // Any data type other than DOUBLE/FLOAT falls through to the half-precision kernels below.
        } else if (op instanceof Variance) {
            if (ret.isScalar()) {
                float res = nativeOps.execSummaryStatsScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, ((Variance)op).isBiasCorrected());
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                ret.assign((Number)Float.valueOf(res));
                op.setFinalResult((Number)Float.valueOf(res));
            } else {
                nativeOps.execSummaryStatsHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length, ((Variance)op).isBiasCorrected());
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        } else if (op.y() != null) {
            if (op.isComplexAccumulation()) {
                nativeOps.execReduce3AllHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (ShortPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length, (IntPointer)devTadShapeInfo, (LongPointer)new LongPointerWrapper(devTadOffsets), (IntPointer)yDevTadShapeInfo, (LongPointer)new LongPointerWrapper(yDevTadOffsets));
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            } else if (ret.isScalar()) {
                float res = nativeOps.execReduce3ScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context));
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
                ret.assign((Number)Float.valueOf(res));
                op.setFinalResult((Number)Float.valueOf(res));
            } else {
                nativeOps.execReduce3Half(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)AtomicAllocator.getInstance().getPointer(op.y(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (ShortPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
                AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            }
        } else if (ret.isScalar()) {
            float res = nativeOps.execReduceScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs);
            AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
            ret.assign((Number)Float.valueOf(res));
            op.setFinalResult((Number)Float.valueOf(res));
        } else {
            nativeOps.execReduceHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)AtomicAllocator.getInstance().getPointer(op.z(), context), (IntPointer)AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer)dimensionPointer, dimension.length);
            AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        }
        this.profilingHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Executes a reduction ({@link Accumulation}) over the given dimensions of {@code op.x()}.
     *
     * <p>Negative dimensions are interpreted relative to the rank of {@code op.x()}. Listing as many
     * dimensions as x has axes (or passing {@code Integer.MAX_VALUE}) reduces the whole array into a
     * 1x1 result. If {@code op.z()} is null or aliases {@code op.x()}, a result array is allocated and
     * set on the op; otherwise the supplied {@code op.z()} is length-checked and reset to the op's
     * initial value. The actual kernel launch is delegated to {@link #naiveExec}.
     *
     * @param op        the reduction to execute
     * @param dimension dimensions to reduce along; the caller's array is left untouched
     * @return {@code op.z()}, or {@code op.noOp()} when reducing a vector would not change it
     * @throws ND4JIllegalStateException if a dimension is outside {@code [-rank, rank)} (other than the
     *         {@code Integer.MAX_VALUE} marker), or if {@code op.z()} has the wrong length
     */
    @Override
    public INDArray exec(Accumulation op, int ... dimension) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        CudaExecutioner.validateDataType((DataBuffer.Type)Nd4j.dataType(), (Op)op);
        // Work on a private copy: sorting and negative-axis normalization must not leak into the
        // caller's array (a shared {-1} constant would otherwise be rewritten to a rank-specific axis).
        dimension = dimension.clone();
        int rank = op.x().rank();
        Arrays.sort(dimension);
        for (int d : dimension) {
            if (d >= rank && d != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + rank + "]");
            }
            if (d < -rank) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that is lower than -rank of op.X: [" + rank + "]");
            }
        }
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < 0) {
                dimension[i] += rank;
            }
        }
        // Normalizing negative axes can break the ordering (e.g. {-1, 0} -> {rank - 1, 0}); re-sort so
        // the dimension[0] test below and the TAD computation in naiveExec see ascending axes.
        Arrays.sort(dimension);
        if (dimension.length == rank) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        // Whole-array reductions yield a 1x1 result; otherwise the reduced axes are dropped from x's shape.
        int[] retShape = Shape.wholeArrayDimension(dimension) ? new int[]{1, 1} : ArrayUtil.removeIndex(op.x().shape(), dimension);
        // Keep results at least 2-D: a single remaining axis becomes a row or column vector.
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1L && op.y() == null) {
            return op.noOp();
        }
        if (op.z() == null || op.z() == op.x()) {
            INDArray ret = null;
            if (op.isComplexAccumulation()) {
                // Pairwise ("all") reductions produce one value per (x TAD, y TAD) combination.
                int xT = op.x().tensorssAlongDimension(dimension);
                int yT = op.y().tensorssAlongDimension(dimension);
                ret = Nd4j.create(xT, yT);
            } else if (Math.abs(op.zeroDouble()) <= Nd4j.EPS_THRESHOLD) {
                ret = Nd4j.zeros(retShape);
            } else if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                ret = Nd4j.valueArrayOf(retShape, (double)op.zeroDouble());
            } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
                ret = Nd4j.valueArrayOf(retShape, (double)op.zeroFloat());
            } else if (op.x().data().dataType() == DataBuffer.Type.HALF) {
                ret = Nd4j.valueArrayOf(retShape, (double)op.zeroHalf());
            }
            op.setZ(ret);
        } else {
            if (op.z().lengthLong() != ArrayUtil.prodLong(retShape)) {
                throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]");
            }
            // Reset the caller-supplied target to the op's initial value before accumulating into it.
            if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                op.z().assign((Number)op.zeroDouble());
            } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
                op.z().assign((Number)Float.valueOf(op.zeroFloat()));
            } else if (op.x().data().dataType() == DataBuffer.Type.HALF) {
                op.z().assign((Number)Float.valueOf(op.zeroHalf()));
            }
        }
        this.naiveExec(op, dimension);
        this.profilingHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Executes an index reduction ({@link IndexAccumulation}) over the given dimensions of {@code op.x()}.
     *
     * <p>Negative dimensions are interpreted relative to the rank of {@code op.x()}. Listing as many
     * dimensions as x has axes (or passing {@code Integer.MAX_VALUE}) reduces the whole array into a
     * 1x1 result. A fresh result array, initialized to the op's initial value, is always allocated and
     * set as {@code op.z()} before the type-specific native kernel (DOUBLE, FLOAT, otherwise HALF) runs.
     *
     * @param op        the index reduction to execute
     * @param dimension dimensions to reduce along; the caller's array is left untouched
     * @return {@code op.z()}; or {@code op.x()} itself for the vector short-circuit below
     * @throws ND4JIllegalStateException if a dimension is outside {@code [-rank, rank)} (other than the
     *         {@code Integer.MAX_VALUE} marker)
     */
    @Override
    public INDArray exec(IndexAccumulation op, int ... dimension) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        CudaExecutioner.validateDataType((DataBuffer.Type)Nd4j.dataType(), (Op)op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Work on a private copy: sorting and negative-axis normalization must not leak into the caller's array.
        dimension = dimension.clone();
        int rank = op.x().rank();
        Arrays.sort(dimension);
        for (int d : dimension) {
            if (d >= rank && d != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + rank + "]");
            }
            if (d < -rank) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that is lower than -rank of op.X: [" + rank + "]");
            }
        }
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] < 0) {
                dimension[i] += rank;
            }
        }
        // Normalizing negative axes can break the ordering (e.g. {-1, 0} -> {rank - 1, 0}); re-sort.
        Arrays.sort(dimension);
        if (dimension.length == rank) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        // Whole-array reductions yield a 1x1 result; otherwise the reduced axes are dropped from x's shape.
        int[] retShape = Shape.wholeArrayDimension(dimension) ? new int[]{1, 1} : ArrayUtil.removeIndex(op.x().shape(), dimension);
        // NOTE(review): for a vector whose result would have as many elements as the input, this returns
        // op.x() itself (values, not indices) and skips profilingHookOut. Kept to preserve existing
        // behaviour — confirm it is intended.
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) {
            return op.x();
        }
        // Keep results at least 2-D: a single remaining axis becomes a row or column vector.
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        INDArray ret = null;
        if (Math.abs(op.zeroDouble()) <= Nd4j.EPS_THRESHOLD) {
            ret = Nd4j.zeros(retShape);
        } else if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            ret = Nd4j.valueArrayOf(retShape, (double)op.zeroDouble());
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            ret = Nd4j.valueArrayOf(retShape, (double)op.zeroFloat());
        } else if (op.x().data().dataType() == DataBuffer.Type.HALF) {
            ret = Nd4j.valueArrayOf(retShape, (double)op.zeroHalf());
        }
        op.setZ(ret);
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.name());
        }
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
        // TAD (tensor-along-dimension) shape info and offsets for x.
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer)tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execIndexReduceDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)extraArgs, (DoublePointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, dimension.length);
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execIndexReduceFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)extraArgs, (FloatPointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, dimension.length);
        } else {
            nativeOps.execIndexReduceHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)extraArgs, (ShortPointer)z, (IntPointer)zShapeInfo, (IntPointer)dimensionPointer, dimension.length);
        }
        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op)op, st);
        return op.z();
    }

    /**
     * Executes {@code op} along the supplied dimensions once its operands are known to be
     * uncompressed.
     * <p>
     * NOTE: {@code dimension} is sorted in place (ascending), so the caller's array is reordered.
     *
     * @param op        operation to execute
     * @param dimension target dimensions; sorted in place before dispatch
     * @return the result of the parent executioner's {@code exec(Op, int...)}
     */
    public Op exec(Op op, int ... dimension) {
        checkForCompression(op);
        // normalize dimension order before handing off to the generic implementation
        Arrays.sort(dimension);
        final Op executed = super.exec(op, dimension);
        return executed;
    }

    /**
     * Executes {@code op} either on the host (via the parent executioner) or on the device.
     * <p>
     * The host path is taken for complex inputs, for {@code ExecutionMode.JAVA}, and for
     * {@link CopyOp}. In that case the inputs are synchronized to the host first, the output is
     * marked as host-written afterwards, and {@code null} is returned. Otherwise the op is routed
     * to the matching {@code invoke(...)} overload by its type and {@code op} itself is returned.
     *
     * @param op operation to execute
     * @return {@code null} on the host path, {@code op} otherwise
     */
    public Op exec(Op op) {
        checkForCompression(op);
        final boolean runOnHost = op.x() instanceof IComplexNDArray
                || executionMode() == OpExecutioner.ExecutionMode.JAVA
                || op instanceof CopyOp;
        if (runOnHost) {
            // the output's sync state does not matter here: it is overwritten on the host
            if (op.x() != null) {
                AtomicAllocator.getInstance().synchronizeHostData(op.x());
            }
            if (op.y() != null) {
                AtomicAllocator.getInstance().synchronizeHostData(op.y());
            }
            super.exec(op);
            if (op.z() != null) {
                AtomicAllocator.getInstance().tickHostWrite(op.z());
            }
            return null;
        }

        // device path: dispatch on the op family
        if (op instanceof TransformOp) {
            invoke((TransformOp) op);
        } else if (op instanceof Accumulation) {
            invoke((Accumulation) op, null);
        } else if (op instanceof ScalarOp) {
            invoke((ScalarOp) op);
        } else if (op instanceof BroadcastOp) {
            invoke((BroadcastOp) op);
        } else if (op instanceof IndexAccumulation) {
            invoke((IndexAccumulation) op, null);
        }
        return op;
    }

    /**
     * Runs a transform op on the device and hands back its output array.
     *
     * @param op transform to execute
     * @return {@code op.z()} after execution
     */
    public INDArray execAndReturn(TransformOp op) {
        checkForCompression((Op) op);
        invoke(op);
        final INDArray output = op.z();
        return output;
    }

    /**
     * Launches a broadcast op on the device.
     * <p>
     * Gathers device/host pointers for X, Y and Z, the TAD shape info and offsets of both
     * {@code op.x()} and {@code op.z()} along {@code op.getDimension()}, packs them into the
     * per-thread extra-pointer array, and calls the native {@code execBroadcast*} routine that
     * matches the data type of {@code op.x()} (DOUBLE, FLOAT, otherwise HALF).
     *
     * @param op broadcast op to execute
     * @return always {@code null}
     */
    protected CudaContext invoke(BroadcastOp op) {
        final long started = profilingHookIn((Op) op);
        checkForCompression((Op) op);
        CudaExecutioner.validateDataType((DataBuffer.Type) Nd4j.dataType(), (Op) op);
        if (extraz.get() == null) {
            extraz.set(new PointerPointer(32L));
        }
        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final CudaContext ctx = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            lastOp.set(op.name());
        }

        // X on the device, plus host-side shape info for Y and Z
        final Pointer xData = allocator.getPointer(op.x(), ctx);
        final Pointer xShape = allocator.getPointer(op.x().shapeInfoDataBuffer(), ctx);
        final Pointer yHostShape = op.y() != null ? AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()) : null;
        final Pointer zHostShape = op.z() != null ? AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()) : null;

        // TAD description of X along the broadcast dimensions
        final Pair xTad = tadManager.getTADOnlyShapeInfo(op.x(), op.getDimension());
        final DataBuffer xTadShape = (DataBuffer) xTad.getFirst();
        final Pointer xTadHostShape = AddressRetriever.retrieveHostPointer(xTadShape);
        final Pointer xTadDevShape = allocator.getPointer(xTadShape, ctx);
        final Pointer xTadDevOffsets = allocator.getPointer((DataBuffer) xTad.getSecond(), ctx);

        // TAD description of Z along the same dimensions
        final Pair zTad = tadManager.getTADOnlyShapeInfo(op.z(), op.getDimension());
        final Pointer zTadDevShape = allocator.getPointer((DataBuffer) zTad.getFirst(), ctx);
        final Pointer zTadDevOffsets = allocator.getPointer((DataBuffer) zTad.getSecond(), ctx);

        final PointerPointer extras = extraz.get().put(new Pointer[]{
                AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()),
                ctx.getOldStream(),
                allocator.getDeviceIdPointer(),
                ctx.getBufferAllocation(),
                ctx.getBufferReduction(),
                ctx.getBufferScalar(),
                ctx.getBufferSpecial(),
                yHostShape,
                zHostShape,
                xTadHostShape,
                xTadDevShape,
                xTadDevOffsets,
                zTadDevShape,
                zTadDevOffsets});

        final Pointer yData = allocator.getPointer(op.y(), ctx);
        final Pointer yShape = allocator.getPointer(op.y().shapeInfoDataBuffer(), ctx);
        final Pointer zData = allocator.getPointer(op.z(), ctx);
        final Pointer zShape = allocator.getPointer(op.z().shapeInfoDataBuffer(), ctx);
        final Pointer dims = allocator.getPointer(allocator.getConstantBuffer(op.getDimension()), ctx);

        final DataBuffer.Type type = op.x().data().dataType();
        final int dimCount = op.getDimension().length;
        if (type == DataBuffer.Type.DOUBLE) {
            nativeOps.execBroadcastDouble(extras, op.opNum(), (DoublePointer) xData, (IntPointer) xShape, (DoublePointer) yData, (IntPointer) yShape, (DoublePointer) zData, (IntPointer) zShape, (IntPointer) dims, dimCount);
        } else if (type == DataBuffer.Type.FLOAT) {
            nativeOps.execBroadcastFloat(extras, op.opNum(), (FloatPointer) xData, (IntPointer) xShape, (FloatPointer) yData, (IntPointer) yShape, (FloatPointer) zData, (IntPointer) zShape, (IntPointer) dims, dimCount);
        } else {
            // any other type is treated as HALF (16-bit values carried in ShortPointer)
            nativeOps.execBroadcastHalf(extras, op.opNum(), (ShortPointer) xData, (IntPointer) xShape, (ShortPointer) yData, (IntPointer) yShape, (ShortPointer) zData, (IntPointer) zShape, (IntPointer) dims, dimCount);
        }

        allocator.registerAction(ctx, op.z(), op.x(), op.y());
        profilingHookOut((Op) op, started);
        return null;
    }

    /**
     * Executes an index-reduction op on the device.
     * <p>
     * A whole-array reduction is performed when {@code op.z()} is a scalar, when
     * {@code dimension} is {@code null}, or when its first element is {@code Integer.MAX_VALUE}.
     * In that case the native scalar routine's result is stored through
     * {@code op.setFinalResult(int)} and assigned into {@code op.z()}. Otherwise the native
     * along-dimension routine is launched and writes directly into {@code op.z()}.
     *
     * @param op        index-reduction op to execute
     * @param dimension target dimensions, or {@code null} for a whole-array reduction;
     *                  when non-null it is sorted in place
     * @return always {@code null}
     * @throws ND4JIllegalStateException if a dimension is not below the rank of {@code op.x()}
     *                                   (the {@code Integer.MAX_VALUE} sentinel is allowed)
     */
    protected CudaContext invoke(IndexAccumulation op, int[] dimension) {
        long st = this.profilingHookIn((Op) op);
        this.checkForCompression((Op) op);
        CudaExecutioner.validateDataType((DataBuffer.Type) Nd4j.dataType(), (Op) op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.name());
        }
        // exec(Op) calls this with a null dimension, so validation must be null-safe
        if (dimension != null) {
            for (int d : dimension) {
                if (d >= op.x().rank() && d != Integer.MAX_VALUE) {
                    throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
                }
            }
        }
        // decide scalar vs. along-dimension on the caller's (unsorted) input, as before
        boolean scalarResult = op.z().isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE;
        if (dimension != null) {
            // sort before the TAD lookup so TAD shape info and the kernel see the same ordering
            Arrays.sort(dimension);
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(op.z().isScalar() ? null : op.z(), op.x(), op.y());
        Pointer x = allocator.getPointer(op.x(), context);
        Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());

        // TAD lookup needs some dimension even for whole-array reductions
        int[] fdimension = dimension == null ? new int[]{0} : dimension;
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), fdimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) tadBuffers.getFirst());
        Pointer devTadShapeInfo = allocator.getPointer((DataBuffer) tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer) tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});

        DataBuffer.Type dataType = op.x().data().dataType();
        if (scalarResult) {
            if (dataType == DataBuffer.Type.DOUBLE) {
                double result = nativeOps.execIndexReduceScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs);
                op.setFinalResult((int) result);
                op.z().assign((Number) result);
            } else if (dataType == DataBuffer.Type.FLOAT) {
                float result = nativeOps.execIndexReduceScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs);
                op.setFinalResult((int) result);
                op.z().assign((Number) Float.valueOf(result));
            } else {
                float result = nativeOps.execIndexReduceScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs);
                op.setFinalResult((int) result);
                op.z().assign((Number) Float.valueOf(result));
            }
        } else {
            Pointer z = allocator.getPointer(op.z(), context);
            Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
            Pointer dimensionPointer = allocator.getPointer(allocator.getConstantBuffer(dimension), context);
            if (dataType == DataBuffer.Type.DOUBLE) {
                nativeOps.execIndexReduceDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs, (DoublePointer) z, (IntPointer) zShapeInfo, (IntPointer) dimensionPointer, dimension.length);
            } else if (dataType == DataBuffer.Type.FLOAT) {
                nativeOps.execIndexReduceFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs, (FloatPointer) z, (IntPointer) zShapeInfo, (IntPointer) dimensionPointer, dimension.length);
            } else {
                nativeOps.execIndexReduceHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs, (ShortPointer) z, (IntPointer) zShapeInfo, (IntPointer) dimensionPointer, dimension.length);
            }
        }
        // only the along-dimension kernel writes op.z() on the device; the scalar path fills
        // op.z() through assign(...) above, so it registers no device-written result here
        allocator.registerAction(context, scalarResult ? null : op.z(), op.x(), op.y());
        this.profilingHookOut((Op) op, st);
        return null;
    }

    /**
     * Executes a reduction ({@link Accumulation}) on the device.
     * <p>
     * The result array is (re)allocated with the reduced shape and filled with the op's
     * "zero" value before the kernel runs. For a 1x1 result, the native scalar routine is called
     * and its value is stored via {@code op.setFinalResult}. Otherwise the along-dimension
     * routine writes into {@code op.z()}. Three op families are handled:
     * <ul>
     *   <li>{@link Variance}: the summary-stats routines, passing {@code isBiasCorrected()}</li>
     *   <li>{@code op.y() != null}: the reduce3 routines</li>
     *   <li>otherwise: the plain reduce routines</li>
     * </ul>
     *
     * @param op        reduction to execute
     * @param dimension dimensions to reduce along; {@code null} means the whole array
     *                  ({@code Integer.MAX_VALUE}); sorted in place
     * @return the CUDA context used, or {@code null} when the early vector shortcut is taken
     * @throws ND4JIllegalStateException if a dimension is not below the rank of {@code op.x()}
     */
    protected CudaContext invoke(Accumulation op, int[] dimension) {
        long st = this.profilingHookIn((Op) op);
        this.checkForCompression((Op) op);
        CudaExecutioner.validateDataType((DataBuffer.Type) Nd4j.dataType(), (Op) op);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (dimension == null) {
            dimension = new int[]{Integer.MAX_VALUE};
        }
        Arrays.sort(dimension);
        for (int i = 0; i < dimension.length; ++i) {
            if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
            }
        }
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.name());
        }
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) tadBuffers.getFirst());
        Pointer devTadShapeInfo = allocator.getPointer((DataBuffer) tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer) tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets});
        if (op.y() != null) {
            // slots 12/13 of the extra-pointer array carry op.y()'s TAD shape info and offsets
            Pair yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);
            Pointer yDevTadShapeInfo = allocator.getPointer((DataBuffer) yTadBuffers.getFirst(), context);
            DataBuffer yOffsets = (DataBuffer) yTadBuffers.getSecond();
            Pointer yDevTadOffsets = yOffsets == null ? null : allocator.getPointer(yOffsets, context);
            xShapeInfoHostPointer.put(12L, yDevTadShapeInfo);
            xShapeInfoHostPointer.put(13L, yDevTadOffsets);
        }
        Pointer x = allocator.getPointer(op.x(), context);
        Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(), context) : null;

        // shape of the reduction result: 1x1 for whole-array reductions, otherwise x's shape
        // with the reduced dimensions removed
        int[] retShape = Shape.wholeArrayDimension((int[]) dimension)
                ? new int[]{1, 1}
                : ArrayUtil.removeIndex((int[]) op.x().shape(), (int[]) dimension);
        // promote rank-1 / rank-0 shapes to the 2D row, column, or 1x1 form
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        // vector input whose result would be as long as the input: skip the kernel, op.z() is untouched
        // NOTE(review): this early return skips registerAction/profilingHookOut after prepareAction
        if (op.x().isVector() && op.x().length() == ArrayUtil.prod((int[]) retShape)) {
            return null;
        }

        DataBuffer.Type dataType = op.x().data().dataType();
        INDArray ret = null;
        if (Math.abs(op.zeroDouble()) <= Nd4j.EPS_THRESHOLD) {
            ret = Nd4j.zeros((int[]) retShape);
        } else if (dataType == DataBuffer.Type.DOUBLE) {
            ret = Nd4j.valueArrayOf((int[]) retShape, (double) op.zeroDouble());
        } else if (dataType == DataBuffer.Type.FLOAT) {
            ret = Nd4j.valueArrayOf((int[]) retShape, (double) op.zeroFloat());
        } else if (dataType == DataBuffer.Type.HALF) {
            ret = Nd4j.valueArrayOf((int[]) retShape, (double) op.zeroHalf());
        }
        op.setZ(ret);

        if (op.z().isScalar()) {
            if (dataType == DataBuffer.Type.DOUBLE) {
                if (op instanceof Variance) {
                    double result = nativeOps.execSummaryStatsScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs, ((Variance) op).isBiasCorrected());
                    op.setFinalResult((Number) result);
                } else if (op.y() != null) {
                    Pointer y = allocator.getPointer(op.y(), context);
                    Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                    double result = nativeOps.execReduce3ScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs, (DoublePointer) y, (IntPointer) yShapeInfo);
                    op.setFinalResult((Number) result);
                } else {
                    double result = nativeOps.execReduceScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs);
                    op.setFinalResult((Number) result);
                }
            } else if (dataType == DataBuffer.Type.FLOAT) {
                if (op instanceof Variance) {
                    float result = nativeOps.execSummaryStatsScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs, ((Variance) op).isBiasCorrected());
                    op.setFinalResult((Number) Float.valueOf(result));
                } else if (op.y() != null) {
                    Pointer y = allocator.getPointer(op.y(), context);
                    Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                    float result = nativeOps.execReduce3ScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs, (FloatPointer) y, (IntPointer) yShapeInfo);
                    op.setFinalResult((Number) Float.valueOf(result));
                } else {
                    float result = nativeOps.execReduceScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs);
                    op.setFinalResult((Number) Float.valueOf(result));
                }
            } else if (op instanceof Variance) {
                float result = nativeOps.execSummaryStatsScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs, ((Variance) op).isBiasCorrected());
                op.setFinalResult((Number) Float.valueOf(result));
            } else if (op.y() != null) {
                Pointer y = allocator.getPointer(op.y(), context);
                Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                float result = nativeOps.execReduce3ScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs, (ShortPointer) y, (IntPointer) yShapeInfo);
                op.setFinalResult((Number) Float.valueOf(result));
            } else {
                float result = nativeOps.execReduceScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs);
                op.setFinalResult((Number) Float.valueOf(result));
            }
        } else {
            Pointer result = allocator.getPointer(op.z(), context);
            Pointer resultShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
            Pointer dimensionPointer = allocator.getPointer(allocator.getConstantBuffer(dimension), context);
            if (dataType == DataBuffer.Type.DOUBLE) {
                if (op.y() != null) {
                    Pointer y = allocator.getPointer(op.y(), context);
                    Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                    nativeOps.execReduce3Double(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs, (DoublePointer) y, (IntPointer) yShapeInfo, (DoublePointer) result, (IntPointer) resultShapeInfo, (IntPointer) dimensionPointer, dimension.length);
                } else if (op instanceof Variance) {
                    nativeOps.execSummaryStatsDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs, (DoublePointer) result, (IntPointer) resultShapeInfo, (IntPointer) dimensionPointer, dimension.length, ((Variance) op).isBiasCorrected());
                } else {
                    nativeOps.execReduceDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) extraArgs, (DoublePointer) result, (IntPointer) resultShapeInfo, (IntPointer) dimensionPointer, dimension.length);
                }
            } else if (dataType == DataBuffer.Type.FLOAT) {
                if (op.y() != null) {
                    Pointer y = allocator.getPointer(op.y(), context);
                    Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                    nativeOps.execReduce3Float(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs, (FloatPointer) y, (IntPointer) yShapeInfo, (FloatPointer) result, (IntPointer) resultShapeInfo, (IntPointer) dimensionPointer, dimension.length);
                } else if (op instanceof Variance) {
                    nativeOps.execSummaryStatsFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs, (FloatPointer) result, (IntPointer) resultShapeInfo, (IntPointer) dimensionPointer, dimension.length, ((Variance) op).isBiasCorrected());
                } else {
                    nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) extraArgs, (FloatPointer) result, (IntPointer) resultShapeInfo, (IntPointer) dimensionPointer, dimension.length);
                }
            } else if (op.y() != null) {
                Pointer y = allocator.getPointer(op.y(), context);
                Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                nativeOps.execReduce3Half(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs, (ShortPointer) y, (IntPointer) yShapeInfo, (ShortPointer) result, (IntPointer) resultShapeInfo, (IntPointer) dimensionPointer, dimension.length);
            } else if (op instanceof Variance) {
                nativeOps.execSummaryStatsHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs, (ShortPointer) result, (IntPointer) resultShapeInfo, (IntPointer) dimensionPointer, dimension.length, ((Variance) op).isBiasCorrected());
            } else {
                nativeOps.execReduceHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) extraArgs, (ShortPointer) result, (IntPointer) resultShapeInfo, (IntPointer) dimensionPointer, dimension.length);
            }
        }
        allocator.registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op) op, st);
        return context;
    }

    /**
     * Executes a scalar op along the given dimensions on the device.
     * <p>
     * TAD shape info and offsets are computed for both {@code op.x()} and {@code op.z()} along
     * {@code dimension}, packed into the per-thread extra-pointer array, and passed to the
     * native {@code execScalar*} routine for the data type of {@code op.x()}. {@code op.y()} is
     * passed in the native call's scalar-operand position (presumably one scalar per TAD;
     * confirm against the native signature). Unsupported data types fall through without a
     * kernel launch.
     *
     * @param op        scalar op to execute
     * @param dimension target dimensions; sorted in place
     * @return always {@code null}
     */
    protected CudaContext intercept(ScalarOp op, int[] dimension) {
        long st = this.profilingHookIn((Op) op);
        // protected entry point: do not rely on a caller having initialized the per-thread buffer
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        Arrays.sort(dimension);
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.name());
        }
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
        Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
        Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);

        // TAD description of X along the target dimensions
        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer) tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer((DataBuffer) tadBuffers.getFirst(), context);
        DataBuffer offsets = (DataBuffer) tadBuffers.getSecond();
        Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);

        // TAD description of Z along the same dimensions
        Pair tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
        Pointer devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer((DataBuffer) tadBuffersZ.getFirst(), context);
        Pointer devTadOffsetsZ = AtomicAllocator.getInstance().getPointer((DataBuffer) tadBuffersZ.getSecond(), context);

        PointerPointer extraPointers = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ});
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
        if (op.x().data().dataType() == DataBuffer.Type.HALF) {
            nativeOps.execScalarHalf(extraPointers, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) z, (IntPointer) zShapeInfo, (ShortPointer) y, (ShortPointer) extraArgs, (IntPointer) dimensionPointer, dimension.length);
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execScalarFloat(extraPointers, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) z, (IntPointer) zShapeInfo, (FloatPointer) y, (FloatPointer) extraArgs, (IntPointer) dimensionPointer, dimension.length);
        } else if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execScalarDouble(extraPointers, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) z, (IntPointer) zShapeInfo, (DoublePointer) y, (DoublePointer) extraArgs, (IntPointer) dimensionPointer, dimension.length);
        }
        AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op) op, st);
        return null;
    }

    /**
     * Executes a scalar op on the device.
     * <p>
     * Ops with a non-null {@code getDimension()} are delegated to
     * {@code intercept(ScalarOp, int[])}, which profiles itself. All other ops apply
     * {@code op.scalar()} element-wise. When {@code op.x()} has a positive element-wise stride
     * and {@code op.z()} has the same ordering, the strided native overload is used. Otherwise
     * the shape-info overload is used.
     *
     * @param op scalar op to execute
     * @return always {@code null}
     * @throws ND4JIllegalStateException if {@code op.x()} and {@code op.z()} differ in length
     */
    protected CudaContext invoke(ScalarOp op) {
        this.checkForCompression((Op) op);
        CudaExecutioner.validateDataType((DataBuffer.Type) Nd4j.dataType(), (Op) op);
        if (op.x().length() != op.z().length()) {
            throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: [" + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != [" + Arrays.toString(op.z().shapeInfoDataBuffer().asInt()) + "]");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.name());
        }
        if (op.getDimension() != null) {
            // intercept() runs its own profilingHookIn/Out pair; opening one here too would
            // count the op twice and leave this hook unclosed
            this.intercept(op, op.getDimension());
            return null;
        }
        long st = this.profilingHookIn((Op) op);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer x = allocator.getPointer(op.x(), context);
        Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer z = allocator.getPointer(op.z(), context);
        Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, null, null});

        DataBuffer.Type dataType = op.x().data().dataType();
        // strided kernels are only valid when X has a usable element-wise stride and Z shares its ordering
        boolean strided = op.x().elementWiseStride() >= 1 && op.z().ordering() == op.x().ordering();
        if (dataType == DataBuffer.Type.DOUBLE) {
            if (strided) {
                nativeOps.execScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, op.x().elementWiseStride(), (DoublePointer) z, op.z().elementWiseStride(), op.scalar().doubleValue(), (DoublePointer) extraArgs, op.n());
            } else {
                nativeOps.execScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) z, (IntPointer) zShapeInfo, op.scalar().doubleValue(), (DoublePointer) extraArgs);
            }
        } else if (dataType == DataBuffer.Type.FLOAT) {
            if (strided) {
                nativeOps.execScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, op.x().elementWiseStride(), (FloatPointer) z, op.z().elementWiseStride(), op.scalar().floatValue(), (FloatPointer) extraArgs, op.n());
            } else {
                nativeOps.execScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) z, (IntPointer) zShapeInfo, op.scalar().floatValue(), (FloatPointer) extraArgs);
            }
        } else if (strided) {
            nativeOps.execScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, op.x().elementWiseStride(), (ShortPointer) z, op.z().elementWiseStride(), op.scalar().floatValue(), (ShortPointer) extraArgs, op.n());
        } else {
            nativeOps.execScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) z, (IntPointer) zShapeInfo, op.scalar().floatValue(), (ShortPointer) extraArgs);
        }
        allocator.registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op) op, st);
        return null;
    }

    /**
     * Runs a {@link TransformOp} on the CUDA backend.
     *
     * <p>When {@code op.y()} is non-null a pairwise transform is dispatched, otherwise a
     * single-input transform. For each data type, the element-wise-stride entry point of
     * {@code nativeOps} is used only when every involved array has a positive element-wise stride
     * and compatible ordering; otherwise the shape-info entry point is used.
     *
     * <p>Ops numbered 38-41 additionally receive TAD shape info and offsets from
     * {@code tadManager}. Op 41 also reads its dimensions from {@code extraArgs}
     * (layout {@code [count, d0, d1, ...]}) and allocates a result array for them.
     *
     * @param op the transform to execute; its output is written into {@code op.z()}
     * @return always {@code null}
     */
    protected CudaContext invoke(TransformOp op) {
        long st = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        CudaExecutioner.validateDataType((DataBuffer.Type)Nd4j.dataType(), (Op)op);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.name());
        }
        INDArray ret = null;
        Pointer x = allocator.getPointer(op.x(), context);
        Pointer xShapeInfo = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(), context) : null;
        Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
        Pointer dimensionDevPointer = null;
        Pointer dimensionHostPointer = null;
        Pointer retPointer = null;
        int[] dimension = null;
        if (op.opNum() == 41 && op.extraArgs() != null) {
            // extraArgs layout for this op: [dimensionCount, dim0, dim1, ...]
            dimension = new int[((Integer)op.extraArgs()[0]).intValue()];
            for (int i = 0; i < dimension.length; ++i) {
                int dim = (Integer)op.extraArgs()[i + 1];
                // negative dimensions are counted from the last axis
                dimension[i] = dim < 0 ? dim + op.x().rank() : dim;
            }
            if (dimension.length == op.x().rank()) {
                // every axis requested: treat as a whole-array operation
                dimension = new int[]{Integer.MAX_VALUE};
            }
            // a whole-array op yields a 1x1 result; otherwise drop the requested axes
            int[] retShape = Shape.wholeArrayDimension((int[])dimension)
                    ? new int[]{1, 1}
                    : ArrayUtil.removeIndex((int[])op.x().shape(), (int[])dimension);
            if (retShape.length == 1) {
                retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
            } else if (retShape.length == 0) {
                retShape = new int[]{1, 1};
            }
            ret = Nd4j.zeros((int[])retShape);
            // NOTE(review): ret's shape info goes into the hostYShapeInfo slot, but is obtained via
            // getPointer (not retrieveHostPointer) - confirm which address the native side expects.
            hostYShapeInfo = allocator.getPointer(ret.shapeInfoDataBuffer(), context);
            DataBuffer dimensionBuffer = allocator.getConstantBuffer(dimension);
            dimensionDevPointer = allocator.getPointer(dimensionBuffer, context);
            dimensionHostPointer = allocator.getHostPointer(dimensionBuffer);
            retPointer = allocator.getPointer(ret, context);
        }
        Pointer hostTadShapeInfo = null;
        Pointer devTadShapeInfo = null;
        Pointer hostMaxTadShapeInfo = null;
        Pointer devMaxTadShapeInfo = null;
        Pointer devTadOffsets = null;
        Pointer devMaxTadOffsets = null;
        if (op.opNum() >= 38 && op.opNum() <= 41) {
            if (op.opNum() != 41) {
                // ops 38-40: TADs of x along dimension 0 and along dimension 1
                Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), new int[]{0});
                Pair tadMaxBuffers = tadManager.getTADOnlyShapeInfo(op.x(), new int[]{1});
                hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
                devTadShapeInfo = allocator.getPointer((DataBuffer)tadBuffers.getFirst(), context);
                hostMaxTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadMaxBuffers.getFirst());
                devMaxTadShapeInfo = allocator.getPointer((DataBuffer)tadMaxBuffers.getFirst(), context);
                DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
                devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);
                DataBuffer maxOffsets = (DataBuffer)tadMaxBuffers.getSecond();
                devMaxTadOffsets = maxOffsets == null ? null : allocator.getPointer(maxOffsets, context);
            } else {
                // op 41: TADs of z along the dimensions parsed from extraArgs above
                Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
                hostTadShapeInfo = AddressRetriever.retrieveHostPointer((DataBuffer)tadBuffers.getFirst());
                devTadShapeInfo = allocator.getPointer((DataBuffer)tadBuffers.getFirst(), context);
                DataBuffer offsets = (DataBuffer)tadBuffers.getSecond();
                devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context);
            }
        }
        Pointer z = allocator.getPointer(op.z(), context);
        Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
        // Slot order is presumably read by index on the native side; do not reorder.
        PointerPointer xShapeInfoHostPointer = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, hostMaxTadShapeInfo, devMaxTadShapeInfo, devMaxTadOffsets, dimensionDevPointer, dimensionHostPointer, retPointer, new CudaPointer(dimension == null ? 0L : (long)dimension.length)});
        if (op.opNum() == 71) {
            // Pooling2D: slot 10 (devTadShapeInfo above) is replaced by the im2col shape
            this.extraz.get().put(10L, ((Pooling2D)op).getIm2colShape().addressPointer());
        }
        DataBuffer.Type dataType = op.x().data().dataType();
        if (op.y() != null) {
            Pointer y = allocator.getPointer(op.y(), context);
            Pointer yShapeInfo = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
            int xEWS = op.x().elementWiseStride();
            int yEWS = op.y().elementWiseStride();
            int zEWS = op.z().elementWiseStride();
            boolean xRow = op.x().isRowVector();
            boolean yRow = op.y().isRowVector();
            boolean zRow = op.z().isRowVector();
            boolean sameOrdering = op.x().ordering() == op.y().ordering() && op.x().ordering() == op.z().ordering();
            // all three are row vectors sharing one positive element-wise stride
            boolean uniformRowVectors = xEWS >= 1 && yEWS == xEWS && zEWS == xEWS && xRow && yRow && zRow;
            // The strided kernels index every operand by its EWS, so all three must be positive;
            // anything else falls back to the always-valid shape-info entry point.
            boolean allPositiveEws = xEWS >= 1 && yEWS >= 1 && zEWS >= 1;
            if (dataType == DataBuffer.Type.DOUBLE) {
                if (allPositiveEws && !op.isExecSpecial() && sameOrdering || uniformRowVectors) {
                    nativeOps.execPairwiseTransformDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, xEWS, (DoublePointer)y, yEWS, (DoublePointer)z, zEWS, (DoublePointer)extraArgs, op.n());
                } else {
                    nativeOps.execPairwiseTransformDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)y, (IntPointer)yShapeInfo, (DoublePointer)z, (IntPointer)zShapeInfo, (DoublePointer)extraArgs);
                }
            } else if (dataType == DataBuffer.Type.FLOAT) {
                if (allPositiveEws && xEWS == yEWS && !op.isExecSpecial() && sameOrdering || uniformRowVectors) {
                    nativeOps.execPairwiseTransformFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, xEWS, (FloatPointer)y, yEWS, (FloatPointer)z, zEWS, (FloatPointer)extraArgs, op.n());
                } else {
                    nativeOps.execPairwiseTransformFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)y, (IntPointer)yShapeInfo, (FloatPointer)z, (IntPointer)zShapeInfo, (FloatPointer)extraArgs);
                }
            } else if (allPositiveEws && xEWS == yEWS && !op.isExecSpecial() && sameOrdering || uniformRowVectors) {
                nativeOps.execPairwiseTransformHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, xEWS, (ShortPointer)y, yEWS, (ShortPointer)z, zEWS, (ShortPointer)extraArgs, op.n());
            } else {
                nativeOps.execPairwiseTransformHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)y, (IntPointer)yShapeInfo, (ShortPointer)z, (IntPointer)zShapeInfo, (ShortPointer)extraArgs);
            }
        } else {
            int xEWS = op.x().elementWiseStride();
            int zEWS = op.z().elementWiseStride();
            // strided path only when both strides are usable and the layouts agree
            boolean useStrided = xEWS >= 1 && zEWS >= 1 && !op.isExecSpecial() && op.z().ordering() == op.x().ordering();
            if (dataType == DataBuffer.Type.DOUBLE) {
                if (useStrided) {
                    nativeOps.execTransformDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, xEWS, (DoublePointer)z, zEWS, (DoublePointer)extraArgs, op.n());
                } else {
                    nativeOps.execTransformDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer)x, (IntPointer)xShapeInfo, (DoublePointer)z, (IntPointer)zShapeInfo, (DoublePointer)extraArgs);
                }
            } else if (dataType == DataBuffer.Type.FLOAT) {
                if (useStrided) {
                    nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, xEWS, (FloatPointer)z, zEWS, (FloatPointer)extraArgs, op.n());
                } else {
                    nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer)x, (IntPointer)xShapeInfo, (FloatPointer)z, (IntPointer)zShapeInfo, (FloatPointer)extraArgs);
                }
            } else if (useStrided) {
                nativeOps.execTransformHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, xEWS, (ShortPointer)z, zEWS, (ShortPointer)extraArgs, op.n());
            } else {
                nativeOps.execTransformHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer)x, (IntPointer)xShapeInfo, (ShortPointer)z, (IntPointer)zShapeInfo, (ShortPointer)extraArgs);
            }
        }
        allocator.registerAction(context, op.z(), op.x(), op.y());
        // Touched after the native call, presumably to keep these objects reachable until then.
        if (extraArgs != null) {
            extraArgs.address();
        }
        if (ret != null) {
            ret.elementWiseStride();
        }
        this.profilingHookOut((Op)op, st);
        return null;
    }

    /**
     * Allocates the int-typed parameter surface for {@code batch} and attaches it to the batch.
     *
     * @param batch batch whose sample aggregate dictates the required surface size
     * @return the newly created surface buffer (also registered via {@code setParamsSurface})
     */
    protected <T extends Aggregate> DataBuffer getBuffer(Batch<T> batch) {
        long surfaceLength = batch.getSample().getRequiredBatchMemorySize() * 4L;
        DataBuffer surface = Nd4j.getDataBufferFactory().createInt(surfaceLength, false);
        batch.setParamsSurface(surface);
        return surface;
    }

    /**
     * Packs every aggregate of {@code batch} into one parameter surface and launches the batched
     * aggregate kernel matching {@link Nd4j#dataType()}.
     *
     * <p>The surface comes from {@code getBuffer} (an int buffer) and is written through its host
     * pointer, reinterpreted as int, real (float/double/half) and pointer storage. Regions, in
     * order: per-aggregate counters, indexing arguments, int-array arguments, real arguments,
     * argument pointers, shape pointers. Each region offset is expressed in the element units of
     * the view used to write it, hence the unit conversions below.
     *
     * <p>No stream synchronization is performed in this method.
     *
     * @param batch aggregates to execute; all share the layout limits of {@code batch.getSample()}
     */
    public <T extends Aggregate> void exec(Batch<T> batch) {
        DataBuffer surfaceBuffer = this.getBuffer(batch);
        CudaContext context = (CudaContext)AtomicAllocator.getInstance().getDeviceContext().getContext();
        // int view over the surface's host memory
        IntPointer pointer = new CudaPointer(AtomicAllocator.getInstance().getHostPointer(surfaceBuffer)).asIntPointer();
        AllocationPoint surfacePoint = AtomicAllocator.getInstance().getAllocationPoint(surfaceBuffer);
        // five int counters per aggregate (written at idx..idx+4 in the loop below)
        int maxTypes = 5;
        int maxIntArrays = batch.getSample().maxIntArrays();
        int maxArraySize = batch.getSample().maxIntArraySize();
        // Every region is sized for Batch.getBatchLimit() * 16 aggregates.
        // NOTE(review): the origin of the factor 16 is not visible here - confirm against the native layout.
        int indexPos = maxTypes * (Batch.getBatchLimit() * 16);
        int intArraysPos = indexPos + batch.getSample().maxIndexArguments() * (Batch.getBatchLimit() * 16);
        // Convert from int-slot units to real-element units: a double spans two int slots (/2),
        // a half spans half an int slot (*2 below); float is the same width as int.
        int realPos = (intArraysPos + maxIntArrays * maxArraySize * (Batch.getBatchLimit() * 16)) / (Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 2 : 1);
        if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            realPos *= 2;
        }
        // Convert from real-element units to pointer-slot units, assuming 8-byte pointers:
        // float /2, half /4 (below), double unchanged.
        int argsPos = (realPos + batch.getSample().maxRealArguments() * (Batch.getBatchLimit() * 16)) / (Nd4j.dataType() == DataBuffer.Type.FLOAT ? 2 : 1);
        if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            argsPos /= 4;
        }
        // shape pointers follow the argument pointers, in the same pointer-slot units
        int shapesPos = argsPos + batch.getSample().maxArguments() * (Batch.getBatchLimit() * 16);
        for (int i = 0; i < batch.getNumAggregates(); ++i) {
            int e;
            Aggregate op = (Aggregate)batch.getAggregates().get(i);
            // per-aggregate counters: #arguments, #shapes, #indexing args, #real args, #int arrays
            int idx = i * maxTypes;
            pointer.put((long)idx, op.getArguments().size());
            pointer.put((long)(idx + 1), op.getShapes().size());
            pointer.put((long)(idx + 2), op.getIndexingArguments().size());
            pointer.put((long)(idx + 3), op.getRealArguments().size());
            pointer.put((long)(idx + 4), op.getIntArrayArguments().size());
            // indexing arguments: maxIndexArguments int slots per aggregate
            for (int e2 = 0; e2 < op.getIndexingArguments().size(); ++e2) {
                idx = indexPos + i * batch.getSample().maxIndexArguments();
                pointer.put((long)(idx + e2), ((Integer)op.getIndexingArguments().get(e2)).intValue());
            }
            // int-array arguments: maxIntArrays arrays of maxArraySize ints per aggregate; null arrays are skipped
            int bsize = maxIntArrays * maxArraySize;
            for (int e3 = 0; e3 < op.getIntArrayArguments().size(); ++e3) {
                int step = i * bsize + e3 * maxArraySize;
                if (op.getIntArrayArguments().get(e3) == null) continue;
                for (int x = 0; x < ((int[])op.getIntArrayArguments().get(e3)).length; ++x) {
                    idx = intArraysPos + step + x;
                    pointer.put((long)idx, ((int[])op.getIntArrayArguments().get(e3))[x]);
                }
            }
            // Real arguments, written through a view matching the global data type.
            // NOTE(review): uses op.maxRealArguments() whereas the other regions use
            // batch.getSample(); presumably identical for all aggregates in one batch.
            if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
                FloatPointer realPtr = new FloatPointer((Pointer)pointer);
                for (e = 0; e < op.getRealArguments().size(); ++e) {
                    idx = realPos + i * op.maxRealArguments();
                    realPtr.put((long)(idx + e), ((Number)op.getRealArguments().get(e)).floatValue());
                }
            } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
                DoublePointer dPtr = new DoublePointer((Pointer)pointer);
                for (e = 0; e < op.getRealArguments().size(); ++e) {
                    idx = realPos + i * op.maxRealArguments();
                    dPtr.put((long)(idx + e), ((Number)op.getRealArguments().get(e)).doubleValue());
                }
            } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
                ShortPointer sPtr = new ShortPointer((Pointer)pointer);
                for (e = 0; e < op.getRealArguments().size(); ++e) {
                    idx = realPos + i * op.maxRealArguments();
                    sPtr.put((long)(idx + e), BaseDataBuffer.fromFloat((float)((Number)op.getRealArguments().get(e)).floatValue()));
                }
            }
            // Pointer view over the same surface: device addresses of arguments and shapes.
            // Each referenced array/buffer gets tickDeviceWrite(); null entries are skipped.
            PointerPointer ptrPtr = new PointerPointer((Pointer)pointer);
            for (e = 0; e < op.getArguments().size(); ++e) {
                idx = argsPos + i * batch.getSample().maxArguments();
                if (op.getArguments().get(e) == null) continue;
                ptrPtr.put((long)(idx + e), AtomicAllocator.getInstance().getPointer((INDArray)op.getArguments().get(e), context));
                AtomicAllocator.getInstance().getAllocationPoint((INDArray)op.getArguments().get(e)).tickDeviceWrite();
            }
            for (e = 0; e < op.getShapes().size(); ++e) {
                idx = shapesPos + i * batch.getSample().maxShapes();
                if (op.getShapes().get(e) == null) continue;
                ptrPtr.put((long)(idx + e), AtomicAllocator.getInstance().getPointer((DataBuffer)op.getShapes().get(e), context));
                AtomicAllocator.getInstance().getAllocationPoint((DataBuffer)op.getShapes().get(e)).tickDeviceWrite();
            }
        }
        // The surface was filled via its host pointer; mark the host side as written
        // (presumably so the getPointer call below transfers it to the device - TODO confirm).
        surfacePoint.tickHostWrite();
        // extra-args slots: 0 null, 1 stream, 2 aggregate count capped by the configured
        // maximum grid size, 3 threads per instance, 4 shared memory size
        PointerPointer extraArgs = new PointerPointer(32L);
        extraArgs.put(0L, null);
        extraArgs.put(1L, (Pointer)context.getOldStream());
        extraArgs.put(2L, (Pointer)new CudaPointer(Math.min(batch.getNumAggregates(), CudaEnvironment.getInstance().getConfiguration().getMaximumGridSize())));
        extraArgs.put(3L, (Pointer)new CudaPointer(batch.getSample().getThreadsPerInstance()));
        extraArgs.put(4L, (Pointer)new CudaPointer(batch.getSample().getSharedMemorySize()));
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.execAggregateBatchFloat(extraArgs, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(surfaceBuffer, context));
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execAggregateBatchDouble(extraArgs, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(surfaceBuffer, context));
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            nativeOps.execAggregateBatchHalf(extraArgs, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), AtomicAllocator.getInstance().getPointer(surfaceBuffer, context));
        }
        // NOTE(review): second tickHostWrite after the launch; its intent is not visible here.
        surfacePoint.tickHostWrite();
    }

    /**
     * Executes a list of aggregates by partitioning it with {@code Batch.getBatches} (partition
     * size 8192), launching each batch, and then synchronizing the current context's stream.
     *
     * @param batch aggregates to execute; an empty list is a no-op
     */
    public void exec(List<Aggregate> batch) {
        if (batch.isEmpty()) {
            return;
        }
        // Typed list: iterating a raw List with a Batch loop variable does not compile.
        List<Batch<Aggregate>> batches = Batch.getBatches(batch, 8192);
        for (Batch<Aggregate> single : batches) {
            this.exec(single);
        }
        CudaContext context = (CudaContext)AtomicAllocator.getInstance().getDeviceContext().getContext();
        context.syncOldStream();
    }

    /**
     * Executes a single {@link Aggregate} on the device.
     *
     * <p>Device addresses of the op's array arguments, shape buffers and int-array arguments are
     * gathered into pointer buffers. Indexing arguments are copied into an int buffer, and real
     * arguments into an array created from a {@code double[]}. The native entry point is selected
     * by {@link Nd4j#dataType()}; for any other data type nothing is launched.
     *
     * @param op aggregate to execute; {@code null} argument/shape/int-array entries are passed as address 0
     */
    public void exec(Aggregate op) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        int numArguments = op.getArguments().size();
        int numShapeArguments = op.getShapes().size();
        int numIndexArguments = op.getIndexingArguments().size();
        int numIntArrays = op.getIntArrayArguments().size();
        int numRealArguments = op.getRealArguments().size();
        CudaContext context = (CudaContext)allocator.getDeviceContext().getContext();
        // extra-args slots: 0 null, 1 stream, 2 a single instance, 3 threads per instance, 4 shared memory size
        PointerPointer extraArgs = new PointerPointer(32L);
        extraArgs.put(0L, null);
        extraArgs.put(1L, (Pointer)context.getOldStream());
        extraArgs.put(2L, (Pointer)new CudaPointer(1L));
        extraArgs.put(3L, (Pointer)new CudaPointer(op.getThreadsPerInstance()));
        extraArgs.put(4L, (Pointer)new CudaPointer(op.getSharedMemorySize()));
        // device addresses of the array arguments (0 for null entries)
        long[] arguments = new long[numArguments];
        for (int x = 0; x < numArguments; ++x) {
            INDArray argument = (INDArray)op.getArguments().get(x);
            if (argument == null) {
                continue;
            }
            arguments[x] = allocator.getPointer(argument, context).address();
            allocator.getAllocationPoint(argument).tickDeviceWrite();
        }
        DataBuffer tempX = AllocationUtils.getPointersBuffer(arguments);
        PointerPointer xPtr = new PointerPointer(allocator.getPointer(tempX, context));
        // device addresses of the shape buffers (0 for null entries)
        long[] shapes = new long[numShapeArguments];
        for (int x = 0; x < numShapeArguments; ++x) {
            DataBuffer shape = (DataBuffer)op.getShapes().get(x);
            if (shape == null) {
                continue;
            }
            shapes[x] = allocator.getPointer(shape, context).address();
            allocator.getAllocationPoint(shape).tickDeviceWrite();
        }
        DataBuffer tempS = AllocationUtils.getPointersBuffer(shapes);
        PointerPointer sPtr = new PointerPointer(allocator.getPointer(tempS, context));
        // each int-array argument is copied into its own int buffer (0 for null entries)
        long[] ints = new long[numIntArrays];
        for (int x = 0; x < numIntArrays; ++x) {
            if (op.getIntArrayArguments().get(x) == null) {
                continue;
            }
            DataBuffer intBuf = Nd4j.getDataBufferFactory().createInt((int[])op.getIntArrayArguments().get(x));
            ints[x] = allocator.getPointer(intBuf, context).address();
        }
        DataBuffer tempI = AllocationUtils.getPointersBuffer(ints);
        PointerPointer iPtr = new PointerPointer(allocator.getPointer(tempI, context));
        int[] indexes = new int[numIndexArguments];
        for (int x = 0; x < numIndexArguments; ++x) {
            indexes[x] = (Integer)op.getIndexingArguments().get(x);
        }
        DataBuffer intBuffer = Nd4j.getDataBufferFactory().createInt(indexes);
        double[] reals = new double[numRealArguments];
        for (int x = 0; x < numRealArguments; ++x) {
            reals[x] = ((Number)op.getRealArguments().get(x)).doubleValue();
        }
        INDArray realsBuffer = Nd4j.create((double[])reals);
        DataBuffer.Type dataType = Nd4j.dataType();
        if (dataType == DataBuffer.Type.FLOAT) {
            nativeOps.execAggregateFloat(extraArgs, op.opNum(), xPtr, numArguments, sPtr, numShapeArguments, (IntPointer)allocator.getPointer(intBuffer, context), numIndexArguments, iPtr, numIntArrays, (FloatPointer)allocator.getPointer(realsBuffer.data(), context), numRealArguments);
        } else if (dataType == DataBuffer.Type.DOUBLE) {
            nativeOps.execAggregateDouble(extraArgs, op.opNum(), xPtr, numArguments, sPtr, numShapeArguments, (IntPointer)allocator.getPointer(intBuffer, context), numIndexArguments, iPtr, numIntArrays, (DoublePointer)allocator.getPointer(realsBuffer.data(), context), numRealArguments);
        } else if (dataType == DataBuffer.Type.HALF) {
            nativeOps.execAggregateHalf(extraArgs, op.opNum(), xPtr, numArguments, sPtr, numShapeArguments, (IntPointer)allocator.getPointer(intBuffer, context), numIndexArguments, iPtr, numIntArrays, (ShortPointer)allocator.getPointer(realsBuffer.data(), context), numRealArguments);
        }
    }

    /**
     * Executes a random op using the global random generator returned by {@link Nd4j#getRandom()}.
     *
     * @param op random op to execute
     * @return the op's {@code z()} array
     */
    public INDArray exec(RandomOp op) {
        Random defaultRandom = Nd4j.getRandom();
        return this.exec(op, defaultRandom);
    }

    /**
     * Executes a random op on the device with the given generator.
     *
     * <p>The native overload is selected by which of {@code x}, {@code y}, {@code z} are present:
     * all three, {@code x} and {@code z}, or {@code z} only. The element type is selected by
     * {@link Nd4j#dataType()}; for any type other than FLOAT, DOUBLE or HALF nothing is launched.
     *
     * @param op  random op to execute
     * @param rng generator; must expose a native state buffer
     * @return the op's {@code z()} array
     * @throws IllegalStateException if {@code rng} has no state buffer
     */
    public INDArray exec(RandomOp op, Random rng) {
        long started = this.profilingHookIn((Op)op);
        this.checkForCompression((Op)op);
        CudaExecutioner.validateDataType((DataBuffer.Type)Nd4j.dataType(), (Op)op);
        if (rng.getStateBuffer() == null) {
            throw new IllegalStateException("You should use one of NativeRandom classes for NativeOperations execution");
        }
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            this.lastOp.set(op.name());
        }
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());
        PointerPointer extraZZ = this.extraz.get().put(new Pointer[]{AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer()), context.getOldStream(), allocator.getDeviceIdPointer()});
        DataBuffer.Type dtype = Nd4j.dataType();
        boolean isFloat = dtype == DataBuffer.Type.FLOAT;
        boolean isDouble = dtype == DataBuffer.Type.DOUBLE;
        boolean isHalf = dtype == DataBuffer.Type.HALF;
        if (isFloat || isDouble || isHalf) {
            Pointer state = rng.getStatePointer();
            if (op.x() != null && op.y() != null && op.z() != null) {
                Pointer xBuf = allocator.getPointer(op.x(), context);
                Pointer xShape = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);
                Pointer yBuf = allocator.getPointer(op.y(), context);
                Pointer yShape = allocator.getPointer(op.y().shapeInfoDataBuffer(), context);
                Pointer zBuf = allocator.getPointer(op.z(), context);
                Pointer zShape = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
                Pointer extra = allocator.getPointer(op.extraArgsDataBuff(), context);
                if (isFloat) {
                    nativeOps.execRandomFloat(extraZZ, op.opNum(), state, (FloatPointer)xBuf, (IntPointer)xShape, (FloatPointer)yBuf, (IntPointer)yShape, (FloatPointer)zBuf, (IntPointer)zShape, (FloatPointer)extra);
                } else if (isDouble) {
                    nativeOps.execRandomDouble(extraZZ, op.opNum(), state, (DoublePointer)xBuf, (IntPointer)xShape, (DoublePointer)yBuf, (IntPointer)yShape, (DoublePointer)zBuf, (IntPointer)zShape, (DoublePointer)extra);
                } else {
                    nativeOps.execRandomHalf(extraZZ, op.opNum(), state, (ShortPointer)xBuf, (IntPointer)xShape, (ShortPointer)yBuf, (IntPointer)yShape, (ShortPointer)zBuf, (IntPointer)zShape, (ShortPointer)extra);
                }
            } else if (op.x() != null && op.z() != null) {
                Pointer xBuf = allocator.getPointer(op.x(), context);
                Pointer xShape = allocator.getPointer(op.x().shapeInfoDataBuffer(), context);
                Pointer zBuf = allocator.getPointer(op.z(), context);
                Pointer zShape = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
                Pointer extra = allocator.getPointer(op.extraArgsDataBuff(), context);
                if (isFloat) {
                    nativeOps.execRandomFloat(extraZZ, op.opNum(), state, (FloatPointer)xBuf, (IntPointer)xShape, (FloatPointer)zBuf, (IntPointer)zShape, (FloatPointer)extra);
                } else if (isDouble) {
                    nativeOps.execRandomDouble(extraZZ, op.opNum(), state, (DoublePointer)xBuf, (IntPointer)xShape, (DoublePointer)zBuf, (IntPointer)zShape, (DoublePointer)extra);
                } else {
                    nativeOps.execRandomHalf(extraZZ, op.opNum(), state, (ShortPointer)xBuf, (IntPointer)xShape, (ShortPointer)zBuf, (IntPointer)zShape, (ShortPointer)extra);
                }
            } else {
                Pointer zBuf = allocator.getPointer(op.z(), context);
                Pointer zShape = allocator.getPointer(op.z().shapeInfoDataBuffer(), context);
                Pointer extra = allocator.getPointer(op.extraArgsDataBuff(), context);
                if (isFloat) {
                    nativeOps.execRandomFloat(extraZZ, op.opNum(), state, (FloatPointer)zBuf, (IntPointer)zShape, (FloatPointer)extra);
                } else if (isDouble) {
                    nativeOps.execRandomDouble(extraZZ, op.opNum(), state, (DoublePointer)zBuf, (IntPointer)zShape, (DoublePointer)extra);
                } else {
                    nativeOps.execRandomHalf(extraZZ, op.opNum(), state, (ShortPointer)zBuf, (IntPointer)zShape, (ShortPointer)extra);
                }
            }
        }
        allocator.getFlowController().registerAction(context, op.z(), op.x(), op.y());
        this.profilingHookOut((Op)op, started);
        return op.z();
    }

    /**
     * Returns backend/environment properties, extended with per-device CUDA information.
     *
     * <p>The first call builds and caches the property set, including one map per available
     * device under {@code "cuda.devicesInformation"}. Later calls reuse the cached set and refresh
     * only each device's free/total memory and {@code "memory.free"}.
     *
     * @return the cached (and refreshed) properties
     */
    public synchronized Properties getEnvironmentInformation() {
        // queried once instead of on every loop iteration
        int deviceCount = nativeOps.getAvailableDevices();
        if (this.properties == null) {
            Properties props = super.getEnvironmentInformation();
            List<Map<String, Object>> devicesList = new ArrayList<Map<String, Object>>();
            for (int i = 0; i < deviceCount; ++i) {
                Map<String, Object> deviceProps = new HashMap<String, Object>();
                CudaPointer devPtr = new CudaPointer(i);
                deviceProps.put("cuda.deviceName", nativeOps.getDeviceName((Pointer)devPtr));
                deviceProps.put("cuda.freeMemory", nativeOps.getDeviceFreeMemory((Pointer)devPtr));
                deviceProps.put("cuda.totalMemory", nativeOps.getDeviceTotalMemory((Pointer)devPtr));
                deviceProps.put("cuda.deviceMajor", Long.valueOf(nativeOps.getDeviceMajor((Pointer)devPtr)));
                deviceProps.put("cuda.deviceMinor", Long.valueOf(nativeOps.getDeviceMinor((Pointer)devPtr)));
                devicesList.add(deviceProps);
            }
            props.put("backend", "CUDA");
            props.put("cuda.availableDevices", (Object)deviceCount);
            props.put("cuda.devicesInformation", devicesList);
            props.put("blas.vendor", Blas.Vendor.CUBLAS.toString());
            props.put("memory.free", (Object)(Pointer.maxBytes() - Pointer.totalBytes()));
            this.properties = props;
        } else {
            // the list was stored by the branch above, so this cast matches what was put
            @SuppressWarnings("unchecked")
            List<Map<String, Object>> devicesList = (List<Map<String, Object>>)this.properties.get("cuda.devicesInformation");
            for (int i = 0; i < deviceCount; ++i) {
                Map<String, Object> dev = devicesList.get(i);
                CudaPointer devPtr = new CudaPointer(i);
                dev.put("cuda.freeMemory", nativeOps.getDeviceFreeMemory((Pointer)devPtr));
                dev.put("cuda.totalMemory", nativeOps.getDeviceTotalMemory((Pointer)devPtr));
            }
            this.properties.put("cuda.devicesInformation", devicesList);
            this.properties.put("memory.free", (Object)(Pointer.maxBytes() - Pointer.totalBytes()));
        }
        return this.properties;
    }

    /**
     * Returns the TAD manager this executioner uses for tensor-along-dimension shape info.
     *
     * @return the shared {@code tadManager} instance
     */
    public TADManager getTADManager() {
        final TADManager shared = tadManager;
        return shared;
    }

    /**
     * Logs the base environment information, then one line per CUDA device with its name,
     * compute capability and total/free memory.
     */
    public void printEnvironmentInformation() {
        super.printEnvironmentInformation();
        Properties env = this.getEnvironmentInformation();
        // Typed view: iterating a raw List with a Map loop variable does not compile.
        @SuppressWarnings("unchecked")
        List<Map<String, Object>> devicesList = (List<Map<String, Object>>)env.get("cuda.devicesInformation");
        for (Map<String, Object> dev : devicesList) {
            // the message advertises total/free memory, so log both values
            log.info("Device name: [{}]; CC: [{}.{}]; Total/free memory: [{}]/[{}]", new Object[]{dev.get("cuda.deviceName"), dev.get("cuda.deviceMajor"), dev.get("cuda.deviceMinor"), dev.get("cuda.totalMemory"), dev.get("cuda.freeMemory")});
        }
    }

    /**
     * Synchronizes the stream of the current device context.
     */
    public void commit() {
        CudaContext current = (CudaContext)AtomicAllocator.getInstance().getDeviceContext().getContext();
        current.syncOldStream();
    }

    /**
     * Threshold-encodes {@code input} on the device using three native phases.
     * <ul>
     *   <li>P1 fills a per-block INT buffer; its slot 0 is read back as the match count.</li>
     *   <li>P2 turns that buffer into an offsets buffer.</li>
     *   <li>P3 writes the encoded entries.</li>
     * </ul>
     *
     * <p>Layout of the returned INT buffer:
     * [0] = match count (capped at {@code boundary}), [1] = original length of {@code input},
     * [2] = {@code Float.floatToIntBits((float) threshold)}, [3] = 0 (bitmapEncode writes 1 in
     * the same slot), and the remaining slots are filled by the P3 kernel.
     *
     * @param input     FLOAT, DOUBLE or HALF array to encode. Its buffer is marked as
     *                  device-written afterwards; presumably P3 updates it in place (TODO confirm).
     * @param threshold encoding threshold, passed to the native code as a float
     * @param boundary  optional cap on the number of encoded matches; {@code null} means no cap
     * @return the encoded array, or {@code null} when P1 reports fewer than 2 matches
     * @throws ND4JIllegalStateException if {@code input} is not FLOAT, DOUBLE or HALF
     */
    public INDArray thresholdEncode(INDArray input, double threshold, Integer boundary) {
        DataBuffer buffer = input.data();
        // Dispatch on the buffer's own type: the allocator hands out a pointer typed after the
        // buffer, so using the global Nd4j.dataType() could cast to the wrong pointer class.
        DataBuffer.Type dataType = buffer.dataType();
        if (dataType != DataBuffer.Type.FLOAT && dataType != DataBuffer.Type.DOUBLE && dataType != DataBuffer.Type.HALF) {
            throw new ND4JIllegalStateException("Unknown dataType " + dataType);
        }
        AtomicAllocator allocator = AtomicAllocator.getInstance();

        int numThreads = 1024;
        int numBlocks = (int) (buffer.length() / numThreads + (buffer.length() % numThreads == 0L ? 0 : 1));
        CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();

        // Zero-initialized; slot 0 is read back below as the total match count.
        DataBuffer blocksBuffer = allocateThresholdIntBuffer(numBlocks + 1, true);

        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        PointerPointer extras = this.extraz.get().put(1L, (Pointer) context.getOldStream());

        // Phase 1
        if (dataType == DataBuffer.Type.FLOAT) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1Float(extras, (FloatPointer) allocator.getPointer(buffer), buffer.length(), (IntPointer) allocator.getPointer(blocksBuffer), (float) threshold);
        } else if (dataType == DataBuffer.Type.DOUBLE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1Double(extras, (DoublePointer) allocator.getPointer(buffer), buffer.length(), (IntPointer) allocator.getPointer(blocksBuffer), (float) threshold);
        } else {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1Half(extras, (ShortPointer) allocator.getPointer(buffer), buffer.length(), (IntPointer) allocator.getPointer(blocksBuffer), (float) threshold);
        }
        allocator.getAllocationPoint(blocksBuffer).tickDeviceWrite();

        int numMatches = blocksBuffer.getInt(0L);
        // Kept from the original: fewer than 2 matches is treated as "nothing to encode".
        if (numMatches < 2) {
            return null;
        }
        if (boundary != null && numMatches > boundary) {
            numMatches = boundary;
            blocksBuffer.put(0L, numMatches);
        }

        // Header of the encoded result; the host copy is marked authoritative before writing it.
        DataBuffer encodedBuffer = allocateThresholdIntBuffer(4 + numMatches, false);
        allocator.getAllocationPoint(encodedBuffer).tickHostWrite();
        encodedBuffer.put(0L, numMatches);
        encodedBuffer.put(1L, (int) buffer.length());
        encodedBuffer.put(2L, Float.floatToIntBits((float) threshold));
        allocator.getAllocationPoint(encodedBuffer).tickHostWrite();
        encodedBuffer.put(3L, 0);

        // Count the scratch levels needed by the phase-2 prefix sum. This counts every iteration
        // whenever numBlocks > 1, so it may reserve one more slot than the allocation loop below
        // fills. The trailing slot stays 0; the original sizing is kept on purpose.
        int prefixThreads = 512;
        int numElts = numBlocks;
        int level = 0;
        do {
            int numPrefixBlocks = Math.max(1, (int) Math.ceil((float) numElts / (2.0f * (float) prefixThreads)));
            if (numBlocks > 1) {
                ++level;
            }
            numElts = numPrefixBlocks;
        } while (numElts > 1);

        long[] pointers = new long[level];
        // Device-side table of scratch-buffer addresses (8 bytes each), handed to the native code via extras[2].
        DataBuffer tempX = Nd4j.getMemoryManager().getCurrentWorkspace() == null
                ? Nd4j.getDataBufferFactory().createDouble((long) pointers.length, false)
                : Nd4j.getDataBufferFactory().createDouble((long) pointers.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());

        List<DataBuffer> buffers = new ArrayList<DataBuffer>();
        level = 0;
        numElts = numBlocks;
        do {
            int numPrefixBlocks = Math.max(1, (int) Math.ceil((float) numElts / (2.0f * (float) prefixThreads)));
            if (numPrefixBlocks > 1) {
                DataBuffer bf = allocateThresholdIntBuffer(numPrefixBlocks, false);
                buffers.add(bf);
                pointers[level++] = allocator.getPointer(bf).address();
            }
            numElts = numPrefixBlocks;
        } while (numElts > 1);

        allocator.memcpyBlocking(tempX, (Pointer) new LongPointer(pointers), pointers.length * 8, 0L);
        extras.put(2L, allocator.getPointer(tempX));

        // Phase 2
        DataBuffer offsetsBuffer = allocateThresholdIntBuffer(numBlocks, true);
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP2Int(extras, (IntPointer) allocator.getPointer(blocksBuffer), (long) numBlocks, (IntPointer) allocator.getPointer(offsetsBuffer));
        allocator.getAllocationPoint(offsetsBuffer).tickDeviceWrite();

        // Phase 3
        if (dataType == DataBuffer.Type.FLOAT) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP3Float(extras, (FloatPointer) allocator.getPointer(buffer), (IntPointer) allocator.getPointer(offsetsBuffer), buffer.length(), (IntPointer) allocator.getPointer(encodedBuffer));
        } else if (dataType == DataBuffer.Type.DOUBLE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP3Double(extras, (DoublePointer) allocator.getPointer(buffer), (IntPointer) allocator.getPointer(offsetsBuffer), buffer.length(), (IntPointer) allocator.getPointer(encodedBuffer));
        } else {
            NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP3Half(extras, (ShortPointer) allocator.getPointer(buffer), (IntPointer) allocator.getPointer(offsetsBuffer), buffer.length(), (IntPointer) allocator.getPointer(encodedBuffer));
        }
        allocator.getAllocationPoint(encodedBuffer).tickDeviceWrite();
        allocator.getAllocationPoint(buffer).tickDeviceWrite();

        // No-op references kept from the original; they appear intended to keep the scratch
        // objects reachable until after the native launches (TODO confirm intent).
        extras.address();
        tempX.address();
        buffers.getClass();

        return Nd4j.createArrayFromShapeBuffer(encodedBuffer, input.shapeInfoDataBuffer());
    }

    /**
     * Allocates an INT buffer for the threshold encoder. The buffer is placed inside the current
     * workspace when one is open, and outside any workspace otherwise.
     */
    private static DataBuffer allocateThresholdIntBuffer(long length, boolean initialize) {
        return Nd4j.getMemoryManager().getCurrentWorkspace() == null
                ? Nd4j.getDataBufferFactory().createInt(length, initialize)
                : Nd4j.getDataBufferFactory().createInt(length, initialize, Nd4j.getMemoryManager().getCurrentWorkspace());
    }

    /**
     * Threshold-encodes {@code input} without capping the number of encoded matches.
     *
     * @param input     array to encode
     * @param threshold encoding threshold
     * @return the encoded array, or {@code null} if the three-argument overload returns null
     */
    public INDArray thresholdEncode(INDArray input, double threshold) {
        final Integer noBoundary = null;
        return thresholdEncode(input, threshold, noBoundary);
    }

    /**
     * Decodes a threshold-encoded INT array into {@code target} using the native decoder.
     *
     * <p>The encoded header is read as: [0] = number of encoded entries, [1] = original length.
     *
     * @param encoded INT array in the layout produced by {@code thresholdEncode}
     * @param target  FLOAT, DOUBLE or HALF array whose length must equal the stored original length
     * @return {@code target}
     * @throws UnsupportedOperationException if {@code encoded} is not INT
     * @throws ND4JIllegalStateException     if the lengths do not match, the header overstates the
     *                                       buffer's capacity, or {@code target} has an unsupported type
     */
    public INDArray thresholdDecode(INDArray encoded, INDArray target) {
        DataBuffer buffer = encoded.data();
        if (buffer.dataType() != DataBuffer.Type.INT) {
            throw new UnsupportedOperationException("Encoded array should have INT dataType, but has " + buffer.dataType());
        }
        long compressedLength = buffer.getInt(0L);
        long originalLength = buffer.getInt(1L);
        if (target.lengthLong() != originalLength) {
            throw new ND4JIllegalStateException("originalLength [" + originalLength + "] stored in encoded array doesn't match target length [" + target.lengthLong() + "]");
        }
        // The encoder allocates 4 header slots plus one slot per entry. A header claiming more
        // entries than that would make the native decoder read past the encoded buffer.
        if (compressedLength < 0L || buffer.length() < 4L + compressedLength) {
            throw new ND4JIllegalStateException("compressedLength [" + compressedLength + "] stored in encoded array exceeds its capacity [" + (buffer.length() - 4L) + "]");
        }

        DataBuffer result = target.data();
        // Dispatch on the target's own type; the allocator returns a pointer typed after the buffer.
        DataBuffer.Type resultType = result.dataType();
        if (resultType != DataBuffer.Type.FLOAT && resultType != DataBuffer.Type.DOUBLE && resultType != DataBuffer.Type.HALF) {
            throw new ND4JIllegalStateException("Unknown dataType " + resultType);
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        PointerPointer extras = this.extraz.get().put(1L, (Pointer) context.getOldStream());

        // NOTE(review): decodes into target's whole data buffer; a view's offset is not applied here.
        if (resultType == DataBuffer.Type.FLOAT) {
            nativeOps.decodeThresholdFloat(extras, allocator.getPointer(buffer), compressedLength, (FloatPointer) allocator.getPointer(result));
        } else if (resultType == DataBuffer.Type.DOUBLE) {
            nativeOps.decodeThresholdDouble(extras, allocator.getPointer(buffer), compressedLength, (DoublePointer) allocator.getPointer(result));
        } else {
            nativeOps.decodeThresholdHalf(extras, allocator.getPointer(buffer), compressedLength, (ShortPointer) allocator.getPointer(result));
        }
        allocator.getAllocationPoint(result).tickDeviceWrite();
        return target;
    }

    /**
     * Bitmap-encodes {@code indArray} into the INT array {@code target}.
     *
     * <p>Header written to {@code target}: [0] and [1] = source length, [2] =
     * {@code Float.floatToIntBits((float) threshold)}, [3] = 1. The native encoder fills the rest.
     *
     * @param indArray  FLOAT, DOUBLE or HALF source array
     * @param target    INT array of exactly {@code indArray.length() / 16 + 5} elements
     * @param threshold encoding threshold, passed to the native code as a float
     * @return the value returned by the native bitmap encoder
     * @throws ND4JIllegalStateException on a wrongly sized or non-INT target, or an unsupported
     *                                   source type
     */
    public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
        final long sourceLength = indArray.lengthLong();
        final long expectedTargetLength = sourceLength / 16L + 5L;
        if (target.data().length() != expectedTargetLength) {
            throw new ND4JIllegalStateException("Length of target array should be " + expectedTargetLength);
        }
        if (target.data().dataType() != DataBuffer.Type.INT) {
            throw new ND4JIllegalStateException("Target array should have INT dataType");
        }

        final DataBuffer encoded = target.data();
        encoded.put(0L, (int) sourceLength);
        encoded.put(1L, (int) sourceLength);
        encoded.put(2L, Float.floatToIntBits((float) threshold));
        encoded.put(3L, 1);

        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final CudaContext context = allocator.getFlowController().prepareAction(indArray, new INDArray[0]);

        PointerPointer scratch = this.extraz.get();
        if (scratch == null) {
            scratch = new PointerPointer(32L);
            this.extraz.set(scratch);
        }
        final PointerPointer extras = scratch.put(new Pointer[] {allocator.getHostPointer(indArray), context.getOldStream(), context.getBufferScalar(), context.getBufferReduction()});

        final DataBuffer.Type sourceType = indArray.data().dataType();
        final long encodedCount;
        if (sourceType == DataBuffer.Type.FLOAT) {
            encodedCount = nativeOps.encodeBitmapFloat(extras, (FloatPointer) allocator.getPointer(indArray, context), sourceLength, (IntPointer) allocator.getPointer(encoded, context), (float) threshold);
        } else if (sourceType == DataBuffer.Type.DOUBLE) {
            encodedCount = nativeOps.encodeBitmapDouble(extras, (DoublePointer) allocator.getPointer(indArray, context), sourceLength, (IntPointer) allocator.getPointer(encoded, context), (float) threshold);
        } else if (sourceType == DataBuffer.Type.HALF) {
            encodedCount = nativeOps.encodeBitmapHalf(extras, (ShortPointer) allocator.getPointer(indArray, context), sourceLength, (IntPointer) allocator.getPointer(encoded, context), (float) threshold);
        } else {
            throw new ND4JIllegalStateException("Unknown dataType " + sourceType);
        }

        allocator.getFlowController().registerAction(context, indArray, new INDArray[0]);
        allocator.getAllocationPoint(encoded).tickDeviceWrite();
        return encodedCount;
    }

    /**
     * Decodes a bitmap-encoded INT array into {@code target}, dispatching on {@code target}'s
     * data type.
     *
     * @param encoded INT array holding the bitmap encoding. It must have at least
     *                {@code target.length() / 16 + 5} elements, the size bitmapEncode requires.
     * @param target  FLOAT, DOUBLE or HALF array that receives the decoded values
     * @return {@code target}
     * @throws ND4JIllegalStateException if {@code encoded} is not INT, is too short for
     *                                   {@code target}, or {@code target} has an unsupported type
     */
    public INDArray bitmapDecode(INDArray encoded, INDArray target) {
        DataBuffer encodedBuffer = encoded.data();
        if (encodedBuffer.dataType() != DataBuffer.Type.INT) {
            throw new ND4JIllegalStateException("Encoded array should have INT dataType");
        }
        // Mirrors the sizing rule bitmapEncode enforces. A shorter buffer would presumably make
        // the native decoder read past the end of the encoded data.
        long requiredLength = target.lengthLong() / 16L + 5L;
        if (encodedBuffer.length() < requiredLength) {
            throw new ND4JIllegalStateException("Length of encoded array should be at least " + requiredLength + ", but was " + encodedBuffer.length());
        }

        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(target, new INDArray[0]);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        PointerPointer extras = this.extraz.get().put(new Pointer[]{AtomicAllocator.getInstance().getHostPointer(target), context.getOldStream(), context.getBufferScalar(), context.getBufferReduction()});
        if (target.data().dataType() == DataBuffer.Type.FLOAT) {
            nativeOps.decodeBitmapFloat(extras, AtomicAllocator.getInstance().getPointer(encodedBuffer, context), target.lengthLong(), (FloatPointer) AtomicAllocator.getInstance().getPointer(target, context));
        } else if (target.data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.decodeBitmapDouble(extras, AtomicAllocator.getInstance().getPointer(encodedBuffer, context), target.lengthLong(), (DoublePointer) AtomicAllocator.getInstance().getPointer(target, context));
        } else if (target.data().dataType() == DataBuffer.Type.HALF) {
            nativeOps.decodeBitmapHalf(extras, AtomicAllocator.getInstance().getPointer(encodedBuffer, context), target.lengthLong(), (ShortPointer) AtomicAllocator.getInstance().getPointer(target, context));
        } else {
            throw new ND4JIllegalStateException("Unknown dataType " + target.data().dataType());
        }
        AtomicAllocator.getInstance().getFlowController().registerAction(context, target, new INDArray[0]);
        return target;
    }
}

