/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.constant;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.constant.ConstantProtector;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.cache.ArrayDescriptor;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaHalfDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ProtectedCudaConstantHandler
implements ConstantHandler {
    private static ProtectedCudaConstantHandler ourInstance = new ProtectedCudaConstantHandler();
    protected Map<Integer, AtomicLong> constantOffsets = new HashMap<Integer, AtomicLong>();
    protected Map<Integer, Semaphore> deviceLocks = new ConcurrentHashMap<Integer, Semaphore>();
    protected Map<Integer, Map<ArrayDescriptor, DataBuffer>> buffersCache = new HashMap<Integer, Map<ArrayDescriptor, DataBuffer>>();
    protected Map<Integer, Pointer> deviceAddresses = new HashMap<Integer, Pointer>();
    private Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    protected FlowController flowController;
    protected static final ConstantProtector protector = ConstantProtector.getInstance();
    private static Logger logger = LoggerFactory.getLogger(ProtectedCudaConstantHandler.class);
    private static final int MAX_CONSTANT_LENGTH = 49152;
    private static final int MAX_BUFFER_LENGTH = 272;
    protected Semaphore lock = new Semaphore(1);
    private boolean resetHappened = false;

    public static ProtectedCudaConstantHandler getInstance() {
        return ourInstance;
    }

    private ProtectedCudaConstantHandler() {
    }

    public void purgeConstants() {
        this.buffersCache = new HashMap<Integer, Map<ArrayDescriptor, DataBuffer>>();
        protector.purgeProtector();
        this.resetHappened = true;
        logger.info("Resetting Constants...");
        for (Integer device : this.constantOffsets.keySet()) {
            this.constantOffsets.get(device).set(0L);
            this.buffersCache.put(device, new ConcurrentHashMap());
        }
    }

    protected int amountOfEntries(int deviceId) {
        this.ensureMaps(deviceId);
        return this.buffersCache.get(0).size();
    }

    public synchronized long moveToConstantSpace(DataBuffer dataBuffer) {
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        this.ensureMaps(deviceId);
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
        long requiredMemoryBytes = AllocationUtils.getRequiredMemory(point.getShape());
        long currentOffset = this.constantOffsets.get(deviceId).get();
        CudaContext context = (CudaContext)AtomicAllocator.getInstance().getDeviceContext().getContext();
        if (currentOffset + requiredMemoryBytes >= 49152L || requiredMemoryBytes > 272L) {
            if (point.getAllocationStatus() == AllocationStatus.HOST && this.configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
                AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
            }
            this.nativeOps.memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), requiredMemoryBytes, 1, (Pointer)context.getSpecialStream());
            this.flowController.commitTransfer(context.getSpecialStream());
            point.setConstant(true);
            point.tickDeviceWrite();
            point.tickHostRead();
            point.setDeviceId(deviceId);
            protector.persistDataBuffer(dataBuffer);
            return 0L;
        }
        long bytes = requiredMemoryBytes;
        if (dataBuffer.dataType() == DataBuffer.Type.HALF) {
            if (bytes % 4L != 0L) {
                bytes += 2L;
            }
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            long div = bytes / 4L;
            if (div % 2L != 0L) {
                bytes += 4L;
            }
            div = currentOffset / 4L;
            while (div % 2L != 0L) {
                currentOffset = this.constantOffsets.get(deviceId).addAndGet(4L);
                div = currentOffset / 4L;
                if (currentOffset <= 49152L) continue;
            }
        }
        if ((currentOffset = this.constantOffsets.get(deviceId).getAndAdd(bytes)) >= 49152L) {
            if (point.getAllocationStatus() == AllocationStatus.HOST && this.configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
                AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
            }
            this.nativeOps.memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), requiredMemoryBytes, 1, (Pointer)context.getSpecialStream());
            this.flowController.commitTransfer(context.getSpecialStream());
            point.setConstant(true);
            point.tickDeviceWrite();
            point.tickHostRead();
            point.setDeviceId(deviceId);
            protector.persistDataBuffer(dataBuffer);
            return 0L;
        }
        this.nativeOps.memcpyConstantAsync(currentOffset, point.getPointers().getHostPointer(), requiredMemoryBytes, 1, (Pointer)context.getSpecialStream());
        this.flowController.commitTransfer(context.getSpecialStream());
        long cAddr = this.deviceAddresses.get(deviceId).address() + currentOffset;
        point.setAllocationStatus(AllocationStatus.CONSTANT);
        point.getPointers().setDevicePointer(new CudaPointer(cAddr));
        point.setConstant(true);
        point.tickDeviceWrite();
        point.setDeviceId(deviceId);
        point.tickHostRead();
        protector.persistDataBuffer(dataBuffer);
        return cAddr;
    }

    public DataBuffer relocateConstantSpace(DataBuffer dataBuffer) {
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        this.ensureMaps(deviceId);
        if (dataBuffer instanceof CudaIntDataBuffer) {
            int[] data = dataBuffer.asInt();
            return this.getConstantBuffer(data);
        }
        if (dataBuffer instanceof CudaFloatDataBuffer) {
            float[] data = dataBuffer.asFloat();
            return this.getConstantBuffer(data);
        }
        if (dataBuffer instanceof CudaDoubleDataBuffer) {
            double[] data = dataBuffer.asDouble();
            return this.getConstantBuffer(data);
        }
        if (dataBuffer instanceof CudaHalfDataBuffer) {
            float[] data = dataBuffer.asFloat();
            return this.getConstantBuffer(data);
        }
        throw new IllegalStateException("Unknown CudaDataBuffer type");
    }

    private void ensureMaps(Integer deviceId) {
        if (!this.buffersCache.containsKey(deviceId)) {
            if (this.flowController == null) {
                this.flowController = AtomicAllocator.getInstance().getFlowController();
            }
            try {
                this.lock.acquire();
                if (!this.buffersCache.containsKey(deviceId)) {
                    this.buffersCache.put(deviceId, new ConcurrentHashMap());
                    this.constantOffsets.put(deviceId, new AtomicLong(0L));
                    this.deviceLocks.put(deviceId, new Semaphore(1));
                    Pointer cAddr = this.nativeOps.getConstantSpace();
                    this.deviceAddresses.put(deviceId, cAddr);
                }
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            finally {
                this.lock.release();
            }
        }
    }

    public DataBuffer getConstantBuffer(int[] array) {
        ArrayDescriptor descriptor = new ArrayDescriptor(array);
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        this.ensureMaps(deviceId);
        if (!this.buffersCache.get(deviceId).containsKey(descriptor)) {
            DataBuffer buffer = Nd4j.createBuffer((int[])array);
            buffer.setConstant(true);
            this.moveToConstantSpace(buffer);
            this.buffersCache.get(deviceId).put(descriptor, buffer);
            return buffer;
        }
        return this.buffersCache.get(deviceId).get(descriptor);
    }

    public DataBuffer getConstantBuffer(float[] array) {
        ArrayDescriptor descriptor = new ArrayDescriptor(array);
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        this.ensureMaps(deviceId);
        if (!this.buffersCache.get(deviceId).containsKey(descriptor)) {
            DataBuffer buffer = Nd4j.createBuffer((float[])array);
            buffer.setConstant(true);
            this.moveToConstantSpace(buffer);
            this.buffersCache.get(deviceId).put(descriptor, buffer);
            return buffer;
        }
        return this.buffersCache.get(deviceId).get(descriptor);
    }

    public DataBuffer getConstantBuffer(double[] array) {
        ArrayDescriptor descriptor = new ArrayDescriptor(array);
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        this.ensureMaps(deviceId);
        if (!this.buffersCache.get(deviceId).containsKey(descriptor)) {
            DataBuffer buffer = Nd4j.createBuffer((double[])array);
            buffer.setConstant(true);
            this.moveToConstantSpace(buffer);
            this.buffersCache.get(deviceId).put(descriptor, buffer);
            return buffer;
        }
        return this.buffersCache.get(deviceId).get(descriptor);
    }
}

