/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.memory.impl;

import java.util.ArrayList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.memory.impl.CudaCachingZeroProvider;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Memory provider that, in addition to the behaviour inherited from {@link CudaCachingZeroProvider},
 * keeps a per-device, per-{@link AllocationShape} cache of released device allocations and hands
 * them back out on subsequent device {@code malloc} calls for the same shape.
 *
 * <p>Only allocations strictly smaller than {@code MAX_GPU_ALLOCATION} bytes take part in the
 * cache. A device stops accepting new cache entries once its cached amount reaches
 * {@code MAX_GPU_CACHE}.
 */
public class CudaFullCachingProvider extends CudaCachingZeroProvider {
    /** Upper bound (exclusive) on the size, in bytes, of a device allocation that may be cached. */
    protected final long MAX_GPU_ALLOCATION;
    /** Per-device cached-amount threshold above which released buffers are freed instead of cached. */
    protected final long MAX_GPU_CACHE;
    /** deviceId -> (shape -> holder of cached device pointers for that shape). */
    protected volatile ConcurrentHashMap<Integer, ConcurrentHashMap<AllocationShape, CudaCachingZeroProvider.CacheHolder>> deviceCache;
    private static final Logger log = LoggerFactory.getLogger(CudaFullCachingProvider.class);

    public CudaFullCachingProvider() {
        this.MAX_GPU_ALLOCATION = this.configuration.getMaximumSingleDeviceAllocation();
        this.MAX_GPU_CACHE = this.configuration.getMaximumDeviceCache();
        this.deviceCache = new ConcurrentHashMap<>();
        this.init();
    }

    /**
     * Resets the per-device cached-amount counters: one zeroed {@link AtomicLong} per device
     * reported by the affinity manager, indexed by device id.
     */
    public void init() {
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        this.deviceCachedAmount = new ArrayList<>();
        for (int i = 0; i < numDevices; ++i) {
            this.deviceCachedAmount.add(new AtomicLong());
        }
    }

    /**
     * Allocates memory for {@code shape}, reusing a cached device pointer when possible.
     *
     * <p>For {@link AllocationStatus#DEVICE} requests smaller than {@code MAX_GPU_ALLOCATION}, a
     * pointer cached for the same shape on the current device is returned if one is available.
     * In that case the device's cached amount is reduced by the request size, {@code point} is
     * marked as DEVICE on that device, and the returned pair carries only the device pointer.
     * Every other case, including a cache miss, is delegated to {@code super.malloc}.
     *
     * @param shape    shape of the requested allocation
     * @param point    allocation point being populated
     * @param location requested memory location
     * @return the pointers for the allocation
     */
    @Override
    public PointersPair malloc(AllocationShape shape, AllocationPoint point, AllocationStatus location) {
        long reqMemory = AllocationUtils.getRequiredMemory(shape);
        if (location == AllocationStatus.DEVICE && reqMemory < this.MAX_GPU_ALLOCATION) {
            int deviceId = AtomicAllocator.getInstance().getDeviceId();
            this.ensureDeviceCacheHolder(deviceId, shape);
            CudaCachingZeroProvider.CacheHolder cache = this.deviceCache.get(deviceId).get(shape);
            Pointer pointer = cache == null ? null : cache.poll();
            if (pointer != null) {
                this.cacheDeviceHit.incrementAndGet();
                ((AtomicLong) this.deviceCachedAmount.get(deviceId)).addAndGet(-reqMemory);
                PointersPair pair = new PointersPair();
                pair.setDevicePointer(pointer);
                point.setAllocationStatus(AllocationStatus.DEVICE);
                point.setDeviceId(deviceId);
                return pair;
            }
            this.cacheDeviceMiss.incrementAndGet();
        }
        return super.malloc(shape, point, location);
    }

    /**
     * Releases {@code point}. Eligible device buffers are put into the per-shape cache instead of
     * being freed.
     *
     * <ul>
     *   <li>Non-DEVICE allocations are delegated to {@code super.free}.</li>
     *   <li>Constant DEVICE allocations are ignored (nothing is freed or cached).</li>
     *   <li>DEVICE allocations of at least {@code MAX_GPU_ALLOCATION} bytes, or on a device whose
     *       cached amount has reached {@code MAX_GPU_CACHE}, are delegated to {@code super.free}.</li>
     *   <li>All remaining DEVICE allocations have their device address cached for reuse.</li>
     * </ul>
     *
     * <p>The size limit uses {@code >=} so that it matches the {@code <} test in {@link #malloc}.
     * With the old {@code >} test, a buffer of exactly {@code MAX_GPU_ALLOCATION} bytes was cached
     * but could never be reused. The cache-budget check and the subsequent {@code put} are not
     * atomic, so concurrent frees may overshoot {@code MAX_GPU_CACHE} slightly.
     *
     * @param point allocation to release
     */
    @Override
    public void free(AllocationPoint point) {
        if (point.getAllocationStatus() != AllocationStatus.DEVICE) {
            super.free(point);
            return;
        }
        if (point.isConstant()) {
            return;
        }
        AllocationShape shape = point.getShape();
        int deviceId = point.getDeviceId();
        long reqMemory = AllocationUtils.getRequiredMemory(shape);
        if (reqMemory >= this.MAX_GPU_ALLOCATION
                || ((AtomicLong) this.deviceCachedAmount.get(deviceId)).get() >= this.MAX_GPU_CACHE) {
            super.free(point);
            return;
        }
        this.ensureDeviceCacheHolder(deviceId, shape);
        CudaCachingZeroProvider.CacheHolder cache = this.deviceCache.get(deviceId).get(shape);
        cache.put(new CudaPointer(point.getDevicePointer().address()));
    }

    /**
     * Makes sure {@code deviceCache} has a shape map for {@code deviceId} and a
     * {@link CudaCachingZeroProvider.CacheHolder} for {@code shape} inside it.
     *
     * <p>Missing entries are created under {@code singleLock}, with a re-check after the lock is
     * taken, so at most one holder is created per (device, shape). The lock is released only if
     * it was actually acquired.
     *
     * @param deviceId device whose cache is being prepared
     * @param shape    shape key within that device's cache
     * @throws RuntimeException if acquiring {@code singleLock} fails; on interruption the thread's
     *                          interrupt flag is restored before throwing
     */
    protected void ensureDeviceCacheHolder(Integer deviceId, AllocationShape shape) {
        if (!this.deviceCache.containsKey(deviceId)) {
            this.acquireSingleLock();
            try {
                if (!this.deviceCache.containsKey(deviceId)) {
                    this.deviceCache.put(deviceId,
                            new ConcurrentHashMap<AllocationShape, CudaCachingZeroProvider.CacheHolder>());
                }
            } finally {
                this.singleLock.release();
            }
        }
        // Entries are never removed from deviceCache, so this map stays valid after the lock is released.
        ConcurrentHashMap<AllocationShape, CudaCachingZeroProvider.CacheHolder> shapeCaches = this.deviceCache.get(deviceId);
        if (!shapeCaches.containsKey(shape)) {
            this.acquireSingleLock();
            try {
                if (!shapeCaches.containsKey(shape)) {
                    shapeCaches.put(shape, new CudaCachingZeroProvider.CacheHolder(shape,
                            (AtomicLong) this.deviceCachedAmount.get(deviceId)));
                }
            } finally {
                this.singleLock.release();
            }
        }
    }

    /**
     * Acquires {@code singleLock}. Any failure is rethrown unchecked; on interruption the thread's
     * interrupt flag is restored first.
     */
    private void acquireSingleLock() {
        try {
            this.singleLock.acquire();
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Frees every cached device pointer on every device via {@code freeDevice}, resets each
     * device's cached amount to zero, and then runs {@code super.purgeCache()}.
     */
    @Override
    public synchronized void purgeCache() {
        for (Integer device : this.deviceCache.keySet()) {
            ConcurrentHashMap<AllocationShape, CudaCachingZeroProvider.CacheHolder> shapeCaches = this.deviceCache.get(device);
            for (CudaCachingZeroProvider.CacheHolder cache : shapeCaches.values()) {
                Pointer ptr;
                while ((ptr = cache.poll()) != null) {
                    this.freeDevice(ptr, device);
                }
            }
            ((AtomicLong) this.deviceCachedAmount.get(device)).set(0L);
        }
        super.purgeCache();
    }
}

