/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.allocator.context.impl;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import org.apache.commons.lang3.RandomUtils;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.cublas;
import org.bytedeco.javacpp.cusolver;
import org.nd4j.jita.allocator.context.ContextPack;
import org.nd4j.jita.allocator.context.ContextPool;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.CUcontext;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ContextPool} implementation that hands out one {@link CudaContext} per calling thread.
 *
 * <p>Contexts are created lazily on the first request made by a thread and are cached by thread
 * id. The cuBLAS and cuSolver handles are created together with the first context of a device
 * and are shared by every later context created for that device.
 *
 * <p>Creation of contexts is serialized through {@link #lock}; lookups of an already cached
 * context do not take the lock.
 */
public class BasicContextPool
implements ContextPool {
    private static final Logger log = LoggerFactory.getLogger(BasicContextPool.class);

    /** Upper bound on the number of contexts created per device before existing ones are reused. */
    protected static final int MAX_STREAMS_PER_DEVICE = Integer.MAX_VALUE - 1;

    /** Driver contexts keyed by device id; nothing in this class populates this map. */
    protected volatile Map<Integer, CUcontext> cuPool = new ConcurrentHashMap<Integer, CUcontext>();

    /** Shared cuBLAS handle per device id. */
    protected volatile Map<Integer, cublasHandle_t> cublasPool = new ConcurrentHashMap<Integer, cublasHandle_t>();

    /** Shared cuSolver handle per device id. */
    protected volatile Map<Integer, cusolverDnHandle_t> solverPool = new ConcurrentHashMap<Integer, cusolverDnHandle_t>();

    /** Context assigned to each thread, keyed by thread id. */
    protected volatile Map<Long, CudaContext> contextsPool = new ConcurrentHashMap<Long, CudaContext>();

    /** Contexts created for each device, keyed by device id and then by creation index. */
    protected volatile Map<Integer, Map<Integer, CudaContext>> contextsForDevices = new ConcurrentHashMap<Integer, Map<Integer, CudaContext>>();

    /** Single-permit semaphore that serializes context creation. */
    protected Semaphore lock = new Semaphore(1);

    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

    /**
     * Tells whether a context has already been assigned to the given thread.
     *
     * @param threadId id of the thread to look up
     * @return {@code true} if a context is cached for that thread id
     */
    public boolean containsContextForThread(long threadId) {
        return this.contextsPool.containsKey(threadId);
    }

    /**
     * Convenience alias for {@link #acquireContextForDevice(Integer)}.
     *
     * @param deviceId id of the device the context is requested for
     * @return the context assigned to the calling thread
     */
    public CudaContext getContextForDevice(Integer deviceId) {
        return this.acquireContextForDevice(deviceId);
    }

    /**
     * Returns the context assigned to the calling thread, creating and registering one for
     * {@code deviceId} if the thread has none yet.
     *
     * <p>Note that the cache is keyed by thread id only: once a thread has a context, that same
     * context is returned regardless of the {@code deviceId} passed in later calls.
     *
     * @param deviceId id of the device to create the context on when none is cached
     * @return the context assigned to the calling thread
     * @throws RuntimeException if the thread is interrupted while waiting for the pool lock, or
     *         if anything fails while the context is being created (the failure is the cause)
     */
    @Override
    public CudaContext acquireContextForDevice(Integer deviceId) {
        Long threadId = Thread.currentThread().getId();
        CudaContext cached = this.contextsPool.get(threadId);
        if (cached != null) {
            return cached;
        }

        // Acquire OUTSIDE the try/finally below: if acquire() is interrupted no permit is held,
        // and releasing anyway would add a second permit and break mutual exclusion for good.
        try {
            this.lock.acquire();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }

        try {
            Map<Integer, CudaContext> deviceContexts = this.contextsForDevices.get(deviceId);
            if (deviceContexts == null) {
                deviceContexts = new ConcurrentHashMap<Integer, CudaContext>();
                this.contextsForDevices.put(deviceId, deviceContexts);
            }

            CudaContext context;
            if (deviceContexts.size() < MAX_STREAMS_PER_DEVICE) {
                log.debug("Creating new context...");
                context = this.createNewStream(deviceId);
                this.getDeviceBuffers(context, deviceId);
                this.attachBlasAndSolverHandles(context, deviceId, deviceContexts.isEmpty());
                context.syncOldStream();
                this.contextsPool.put(threadId, context);
                // Contexts are indexed by creation order: 0, 1, 2, ...
                deviceContexts.put(deviceContexts.size(), context);
            } else {
                // Per-device limit reached: share one of the existing contexts, picked at random.
                Integer rand = RandomUtils.nextInt(0, MAX_STREAMS_PER_DEVICE);
                log.debug("Reusing context: " + rand);
                this.nativeOps.setDevice((Pointer)new CudaPointer(deviceId.intValue()));
                context = deviceContexts.get(rand);
                this.contextsPool.put(threadId, context);
            }
            return context;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        finally {
            this.lock.release();
        }
    }

    /**
     * Wires the cuBLAS and cuSolver handles into a freshly created context.
     *
     * <p>For the first context of a device new handles (and dedicated streams) are created and
     * published in {@link #cublasPool} / {@link #solverPool}; for later contexts the pooled
     * handles are reused and no dedicated streams are set.
     */
    private void attachBlasAndSolverHandles(CudaContext context, Integer deviceId, boolean firstContextForDevice) {
        if (firstContextForDevice) {
            log.debug("Creating new cuBLAS handle for device [{}]...", deviceId);
            cudaStream_t cublasStream = this.createNewStream(deviceId).getOldStream();
            cublasHandle_t blasHandle = this.createNewCublasHandle(cublasStream);
            context.setHandle(blasHandle);
            context.setCublasStream(cublasStream);
            this.cublasPool.put(deviceId, blasHandle);

            log.debug("Creating new cuSolver handle for device [{}]...", deviceId);
            cudaStream_t solverStream = this.createNewStream(deviceId).getOldStream();
            cusolverDnHandle_t solverHandle = this.createNewSolverHandle(solverStream);
            context.setSolverHandle(solverHandle);
            context.setSolverStream(solverStream);
            this.solverPool.put(deviceId, solverHandle);
        } else {
            log.debug("Reusing blas here...");
            context.setHandle(this.cublasPool.get(deviceId));
            log.debug("Reusing solver here...");
            context.setSolverHandle(this.solverPool.get(deviceId));
        }
    }

    /**
     * Selects the given device and creates a new context with its stream initialized.
     *
     * @param deviceId id of the device to make current before creating the context
     * @return a new context on which {@code initOldStream()} has been called
     */
    protected CudaContext createNewStream(Integer deviceId) {
        log.debug("Creating new stream for thread: [{}], device: [{}]...", Thread.currentThread().getId(), deviceId);
        this.nativeOps.setDevice((Pointer)new CudaPointer(deviceId.intValue()));
        CudaContext context = new CudaContext();
        context.initOldStream();
        return context;
    }

    /**
     * Creates a new cuBLAS handle.
     *
     * @return wrapper around the newly created cuBLAS context
     * @throws IllegalStateException if {@code cublasCreate_v2} returns a non-zero status
     */
    protected cublasHandle_t createNewCublasHandle() {
        cublas.cublasContext pointer = new cublas.cublasContext();
        int result = cublas.cublasCreate_v2(pointer);
        if (result != 0) {
            throw new IllegalStateException("Can't create new cuBLAS handle! cuBLAS errorCode: [" + result + "]");
        }
        return new cublasHandle_t((Pointer)pointer);
    }

    /**
     * Creates a new cuBLAS handle. The {@code stream} argument is currently ignored: the handle
     * is not bound to it here.
     *
     * @param stream unused
     * @return wrapper around the newly created cuBLAS context
     */
    protected cublasHandle_t createNewCublasHandle(cudaStream_t stream) {
        return this.createNewCublasHandle();
    }

    /**
     * Creates a new cuSolver (dense) handle.
     *
     * @return wrapper around the newly created cuSolver context
     * @throws IllegalStateException if {@code cusolverDnCreate} returns a non-zero status
     */
    protected cusolverDnHandle_t createNewSolverHandle() {
        cusolver.cusolverDnContext pointer = new cusolver.cusolverDnContext();
        int result = cusolver.cusolverDnCreate(pointer);
        if (result != 0) {
            throw new IllegalStateException("Can't create new cuSolver handle! cusolverDn errorCode: [" + result + "] from cusolverDnCreate()");
        }
        return new cusolverDnHandle_t((Pointer)pointer);
    }

    /**
     * Creates a new cuSolver handle. The {@code stream} argument is currently ignored: the
     * handle is not bound to it here.
     *
     * @param stream unused
     * @return wrapper around the newly created cuSolver context
     */
    protected cusolverDnHandle_t createNewSolverHandle(cudaStream_t stream) {
        return this.createNewSolverHandle();
    }

    /**
     * Placeholder for driver-context creation; this implementation creates nothing.
     *
     * @param deviceId ignored
     * @return always {@code null}
     */
    protected CUcontext createNewContext(Integer deviceId) {
        return null;
    }

    /**
     * No-op in this implementation.
     *
     * @param deviceId ignored
     */
    public synchronized void resetPool(int deviceId) {
    }

    /**
     * Looks up the driver context registered for a device.
     *
     * @param deviceId id of the device
     * @return the entry of {@link #cuPool} for that device, or {@code null} if there is none
     */
    public CUcontext getCuContextForDevice(Integer deviceId) {
        return this.cuPool.get(deviceId);
    }

    /**
     * Allocates the scratch buffers of a context (reduction, allocation, scalar and special) and
     * attaches them to it.
     *
     * <p>NOTE(review): buffers allocated before a failing allocation are not freed on the error
     * path.
     *
     * @param context context that receives the buffers; its stream is used for the memsets
     * @param deviceId id of the device the device-side buffers are allocated on
     * @throws IllegalStateException if any of the allocations returns {@code null}
     */
    protected void getDeviceBuffers(CudaContext context, int deviceId) {
        NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        final int sizeOf = 8;
        final long reductionBytes = 16385 * sizeOf * 2;

        Pointer reductionPointer = nativeOps.mallocDevice(reductionBytes, (Pointer)new CudaPointer(deviceId), 0);
        if (reductionPointer == null) {
            throw new IllegalStateException("Can't allocate [DEVICE] reduction buffer memory!");
        }
        nativeOps.memsetAsync(reductionPointer, 0, reductionBytes, 0, (Pointer)context.getOldStream());
        context.syncOldStream();

        Pointer allocationPointer = nativeOps.mallocDevice(0x100000L, (Pointer)new CudaPointer(deviceId), 0);
        if (allocationPointer == null) {
            throw new IllegalStateException("Can't allocate [DEVICE] allocation buffer memory!");
        }

        Pointer scalarPointer = nativeOps.mallocHost((long)sizeOf, 0);
        if (scalarPointer == null) {
            throw new IllegalStateException("Can't allocate [HOST] scalar buffer memory!");
        }

        context.setBufferScalar(scalarPointer);
        context.setBufferAllocation(allocationPointer);
        context.setBufferReduction(reductionPointer);

        Pointer specialPointer = nativeOps.mallocDevice((long)(0x100000 * sizeOf), (Pointer)new CudaPointer(deviceId), 0);
        if (specialPointer == null) {
            throw new IllegalStateException("Can't allocate [DEVICE] special buffer memory!");
        }
        // NOTE(review): only the first 65536 * sizeOf bytes of the 0x100000 * sizeOf byte buffer
        // are zeroed, and this memset is not synchronized here - confirm this is intended.
        nativeOps.memsetAsync(specialPointer, 0, (long)(65536 * sizeOf), 0, (Pointer)context.getOldStream());
        context.setBufferSpecial(specialPointer);
    }

    /**
     * Acquires the calling thread's context and wraps it in a {@link ContextPack}.
     *
     * @param deviceId id of the device the context is requested for
     * @return a new pack built from {@link #acquireContextForDevice(Integer)}
     */
    @Override
    public ContextPack acquireContextPackForDevice(Integer deviceId) {
        return new ContextPack(this.acquireContextForDevice(deviceId));
    }
}

