/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.allocator.tad;

import java.util.Arrays;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class BasicTADManager
implements TADManager {
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    private static final Logger logger = LoggerFactory.getLogger(BasicTADManager.class);

    /**
     * Builds the tensor-along-dimension (TAD) shape information and the per-TAD offsets for
     * {@code array} along the given {@code dimension}s, by delegating to
     * {@code nativeOps.tadOnlyShapeInfo(...)} with host pointers to freshly allocated buffers.
     *
     * <p>If {@code dimension} is {@code null}, empty, or its first element is
     * {@link Integer#MAX_VALUE}, no native call is made. The array's own shape-info buffer is
     * returned, paired with a {@code null} offsets buffer.
     *
     * <p>Side effect: {@code dimension} is sorted in place, so callers observe the sorted order
     * after this call returns.
     *
     * @param array     the array whose TAD information is requested
     * @param dimension the dimensions spanned by each TAD; sorted in place by this method
     * @return a pair of (TAD shape-info buffer, offsets buffer). The offsets buffer has
     *         {@code array.length() / tadLength} entries, where {@code tadLength} is the product
     *         of the sizes of the requested dimensions. The second element is {@code null} in the
     *         whole-array case described above.
     * @throws IllegalArgumentException if any dimension lies outside
     *         {@code [0, array.shape().length)}
     */
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
        if (dimension == null || dimension.length == 0 || dimension[0] == Integer.MAX_VALUE) {
            return new Pair<DataBuffer, DataBuffer>(array.shapeInfoDataBuffer(), null);
        }

        // Sorted in place on purpose: the caller's array is left sorted, as before.
        Arrays.sort(dimension);

        int dimensionLength = dimension.length;
        int targetRank = array.rank();
        int tadLength = tadLength(array.shape(), dimension);
        int offsetLength = array.length() / tadLength;

        // Sized from the source array's rank (2 * rank + 4). Presumably this is the nd4j
        // shape-info layout (rank, shape, stride, plus trailing fields), and the full rank is an
        // upper bound for the TAD's rank -- TODO confirm against the native implementation.
        CudaIntDataBuffer outputBuffer = new CudaIntDataBuffer(targetRank * 2 + 4);
        CudaIntDataBuffer offsetsBuffer = new CudaIntDataBuffer(offsetLength);

        DataBuffer dimensionBuffer = AtomicAllocator.getInstance().getConstantBuffer(dimension);
        Pointer dimensionPointer = AtomicAllocator.getInstance().getHostPointer(dimensionBuffer);

        Pointer xShapeInfo = AddressRetriever.retrieveHostPointer(array.shapeInfoDataBuffer());
        Pointer targetPointer = AddressRetriever.retrieveHostPointer(outputBuffer);
        Pointer offsetsPointer = AddressRetriever.retrieveHostPointer(offsetsBuffer);

        this.nativeOps.tadOnlyShapeInfo(xShapeInfo, dimensionPointer, dimensionLength, targetPointer, offsetsPointer);

        // The native call wrote through host pointers. Record a host-side write on both buffers;
        // presumably this lets the allocator know the host copy is the most recent -- verify
        // against AllocationPoint semantics.
        AtomicAllocator.getInstance().getAllocationPoint(outputBuffer).tickHostWrite();
        AtomicAllocator.getInstance().getAllocationPoint(offsetsBuffer).tickHostWrite();

        return new Pair<DataBuffer, DataBuffer>(outputBuffer, offsetsBuffer);
    }

    /**
     * Returns the number of elements in one TAD: the product of {@code shape[d]} over every
     * {@code d} in {@code dimension}.
     *
     * @param shape     the shape of the source array
     * @param dimension the dimensions spanned by a TAD
     * @return the product of the selected extents
     * @throws IllegalArgumentException if any dimension lies outside {@code [0, shape.length)}
     */
    private static int tadLength(int[] shape, int[] dimension) {
        int length = 1;
        for (int d : dimension) {
            if (d < 0 || d >= shape.length) {
                throw new IllegalArgumentException("Dimension " + d
                        + " is out of range for an array of rank " + shape.length
                        + "; requested dimensions: " + Arrays.toString(dimension));
            }
            length *= shape[d];
        }
        return length;
    }
}

