/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.concurrency;

import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.concurrency.BasicAffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CUDA-specific affinity manager that maps Java threads (by thread id) to CUDA device ids.
 *
 * <p>A thread without an explicit mapping is lazily assigned a device round-robin over
 * {@code configuration.getAvailableDevices()}. When a single GPU is forced, or no native devices
 * are reported, the first available device is used instead. Whenever the mapping is established
 * or changed for the calling thread, the native device is switched via {@code setDevice} so the
 * thread's CUDA device matches what this manager reports.
 *
 * <p>The mapping itself lives in a {@link ConcurrentHashMap}. The round-robin cursor and the
 * cached device count are only advanced/initialized while holding this instance's monitor.
 */
public class CudaAffinityManager extends BasicAffinityManager {
    private static final Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    private static final Logger logger = LoggerFactory.getLogger(CudaAffinityManager.class);
    // thread id -> device id; filled lazily by getDeviceForThread or explicitly by attachThreadToDevice
    private final Map<Long, Integer> affinityMap = new ConcurrentHashMap<Long, Integer>();
    // round-robin cursor into configuration.getAvailableDevices(); only advanced under synchronized(this)
    private final AtomicInteger devPtr = new AtomicInteger(0);
    // per-thread flag: true once setDevice has been issued for the calling thread's mapped device
    private final ThreadLocal<AtomicBoolean> affiliated = new ThreadLocal<AtomicBoolean>();
    // cached native device count; -1 means "not queried yet"
    private final AtomicInteger numberOfDevices = new AtomicInteger(-1);

    /**
     * Returns the device mapped to the calling thread.
     *
     * <p>If the thread has no mapping yet, a device is assigned. If the native device has not yet
     * been set for this thread, it is switched.
     *
     * @return device id of the calling thread
     */
    public Integer getDeviceForCurrentThread() {
        return this.getDeviceForThread(Thread.currentThread().getId());
    }

    /**
     * Returns the device mapped to the given thread, assigning one if needed.
     *
     * @param thread thread to look up
     * @return device id mapped to {@code thread}
     */
    public Integer getDeviceForThread(Thread thread) {
        return this.getDeviceForThread(thread.getId());
    }

    /**
     * Returns the device mapped to the given thread id, assigning one via {@link #getNextDevice(long)}
     * if none exists yet.
     *
     * <p>Only when {@code threadId} is the calling thread is the native device switched. The
     * per-thread affiliation flag is also only touched in that case. Queries about other threads
     * never disturb the caller's own state.
     *
     * @param threadId id of the thread to look up
     * @return device id mapped to {@code threadId}
     */
    public Integer getDeviceForThread(long threadId) {
        boolean isCurrentThread = threadId == Thread.currentThread().getId();
        Integer deviceId = this.affinityMap.get(threadId);
        if (deviceId == null) {
            Integer candidate = this.getNextDevice(threadId);
            // putIfAbsent (not computeIfAbsent): getNextDevice takes this object's monitor, and
            // holding a map bin lock while waiting for it could deadlock with synchronized callers
            Integer existing = this.affinityMap.putIfAbsent(threadId, candidate);
            deviceId = existing != null ? existing : candidate;
            if (isCurrentThread) {
                // a freshly established mapping always needs a native setDevice for this thread
                this.currentAffiliation().set(false);
            }
        }
        if (isCurrentThread) {
            AtomicBoolean flag = this.currentAffiliation();
            if (!flag.get()) {
                NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice((Pointer)new CudaPointer(deviceId.intValue()));
                flag.set(true);
            }
        }
        return deviceId;
    }

    /**
     * Maps the given thread to {@code deviceId}, replacing any previous mapping.
     *
     * @param thread   thread to map
     * @param deviceId target device id
     */
    public void attachThreadToDevice(Thread thread, Integer deviceId) {
        this.attachThreadToDevice(thread.getId(), deviceId);
    }

    /**
     * Maps the given thread id to {@code deviceId}, replacing any previous mapping.
     *
     * <p>If {@code threadId} is the calling thread, the native device is switched immediately. This
     * keeps the value reported by {@link #getDeviceForCurrentThread()} in sync with the actual
     * CUDA device.
     *
     * @param threadId id of the thread to map
     * @param deviceId target device id
     */
    public void attachThreadToDevice(long threadId, Integer deviceId) {
        logger.debug("Manually mapping thread [{}] to device [{}], out of [{}] devices...", new Object[]{threadId, deviceId, configuration.getAvailableDevices().size()});
        this.affinityMap.put(threadId, deviceId);
        if (threadId == Thread.currentThread().getId()) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice((Pointer)new CudaPointer(deviceId.intValue()));
            this.currentAffiliation().set(true);
        }
    }

    /**
     * Picks the device for a thread that has no mapping yet.
     *
     * <p>Uses round-robin over {@code configuration.getAvailableDevices()} unless a single GPU is
     * forced or no native devices are reported. In those cases the first available device is
     * returned.
     *
     * @param threadId id of the thread being mapped (used for logging only)
     * @return chosen device id
     */
    protected Integer getNextDevice(long threadId) {
        Integer device;
        if (!configuration.isForcedSingleGPU() && this.getNumberOfDevices() > 0) {
            synchronized (this) {
                device = configuration.getAvailableDevices().get(this.devPtr.getAndIncrement());
                // wrap the cursor so the next thread starts over at the first available device
                if (this.devPtr.get() >= configuration.getAvailableDevices().size()) {
                    this.devPtr.set(0);
                }
                logger.debug("Mapping thread [{}] to device [{}], out of [{}] devices...", new Object[]{threadId, device, configuration.getAvailableDevices().size()});
            }
        } else {
            device = configuration.getAvailableDevices().get(0);
            logger.debug("Single device is forced, mapping to device [{}]", (Object)device);
        }
        return device;
    }

    /**
     * Returns the number of devices reported by the native backend.
     *
     * <p>The value is queried once and then cached. The initial query is performed while holding
     * this instance's monitor.
     *
     * @return native device count
     */
    public int getNumberOfDevices() {
        if (this.numberOfDevices.get() < 0) {
            synchronized (this) {
                // same sentinel as the outer check, so a cached count of 0 is not re-queried
                if (this.numberOfDevices.get() < 0) {
                    this.numberOfDevices.set(NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices());
                }
            }
        }
        return this.numberOfDevices.get();
    }

    /**
     * Relocates both the data buffer and the shape-info buffer of {@code array}.
     *
     * <p>The actual relocation is done by {@link #touch(DataBuffer)}. A {@code null} array is a
     * no-op.
     *
     * @param array array to touch; may be {@code null}
     */
    public void touch(INDArray array) {
        if (array == null) {
            return;
        }
        this.touch(array.data());
        this.touch(array.shapeInfoDataBuffer());
    }

    /**
     * Relocates a buffer.
     *
     * <p>Constant buffers go through the constant handler; all other buffers go through the memory
     * handler. A {@code null} buffer is a no-op.
     *
     * @param buffer buffer to touch; may be {@code null}
     */
    public void touch(DataBuffer buffer) {
        if (buffer == null) {
            return;
        }
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(buffer);
        if (point.isConstant()) {
            Nd4j.getConstantHandler().relocateConstantSpace(buffer);
        } else {
            AtomicAllocator.getInstance().getMemoryHandler().relocateObject(buffer);
        }
    }

    /**
     * Creates a copy of {@code array} on device {@code deviceId}.
     *
     * <p>The calling thread is temporarily attached to {@code deviceId}. The original device is
     * always restored before returning, including when an exception is thrown.
     *
     * @param deviceId target device id
     * @param array    array to replicate; must not be a view
     * @return the replicated array, or {@code null} if {@code array} is {@code null}
     * @throws UnsupportedOperationException if {@code array} is a view
     */
    public synchronized INDArray replicateToDevice(Integer deviceId, INDArray array) {
        if (array == null) {
            return null;
        }
        if (array.isView()) {
            throw new UnsupportedOperationException("It's impossible to replicate View");
        }
        int[] shape = array.shape();
        int[] stride = array.stride();
        int elementWiseStride = array.elementWiseStride();
        char ordering = array.ordering();
        // called for its side effect only (result unused); presumably ensures the source data is
        // current on its device before the copy -- NOTE(review): confirm against AtomicAllocator
        AtomicAllocator.getInstance().getPointer(array, (CudaContext)AtomicAllocator.getInstance().getDeviceContext().getContext());
        int currentDeviceId = this.getDeviceForCurrentThread();
        try {
            NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice((Pointer)new CudaPointer(deviceId.intValue()));
            this.attachThreadToDevice(Thread.currentThread().getId(), deviceId);
            DataBuffer newDataBuffer = this.replicateToDevice(deviceId, array.data());
            DataBuffer newShapeBuffer = Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, elementWiseStride, ordering);
            return Nd4j.createArrayFromShapeBuffer(newDataBuffer, newShapeBuffer);
        } finally {
            // always hand the thread back to its original device, even if replication failed
            this.attachThreadToDevice(Thread.currentThread().getId(), (Integer)currentDeviceId);
            NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice((Pointer)new CudaPointer(currentDeviceId));
        }
    }

    /**
     * Copies {@code buffer} into a newly created buffer allocated while device {@code deviceId} is
     * active.
     *
     * <p>If the allocator's current device differs from {@code deviceId}, the calling thread is
     * switched for the duration of the copy. It is switched back afterwards, including when an
     * exception is thrown.
     *
     * @param deviceId target device id
     * @param buffer   buffer to replicate
     * @return the new buffer, or {@code null} if {@code buffer} is {@code null}
     */
    public DataBuffer replicateToDevice(Integer deviceId, DataBuffer buffer) {
        if (buffer == null) {
            return null;
        }
        int currentDeviceId = AtomicAllocator.getInstance().getDeviceId();
        boolean switchDevice = currentDeviceId != deviceId.intValue();
        try {
            if (switchDevice) {
                NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice((Pointer)new CudaPointer(deviceId.intValue()));
                Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread().getId(), deviceId);
            }
            DataBuffer dstBuffer = Nd4j.createBuffer((long)buffer.length(), (boolean)false);
            AtomicAllocator.getInstance().memcpy(dstBuffer, buffer);
            return dstBuffer;
        } finally {
            if (switchDevice) {
                NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice((Pointer)new CudaPointer(currentDeviceId));
                Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread().getId(), Integer.valueOf(currentDeviceId));
            }
        }
    }

    /**
     * Marks where the freshest copy of {@code array}'s data lives.
     *
     * <p>Any other {@code location} value (including {@code null}) is ignored.
     *
     * @param array    array whose allocation point is tagged
     * @param location HOST, DEVICE or EVERYWHERE
     */
    public void tagLocation(INDArray array, AffinityManager.Location location) {
        if (isTaggable(location)) {
            tagPoint(AtomicAllocator.getInstance().getAllocationPoint(array), location);
        }
    }

    /**
     * Marks where the freshest copy of {@code buffer}'s data lives.
     *
     * <p>Any other {@code location} value (including {@code null}) is ignored.
     *
     * @param buffer   buffer whose allocation point is tagged
     * @param location HOST, DEVICE or EVERYWHERE
     */
    public void tagLocation(DataBuffer buffer, AffinityManager.Location location) {
        if (isTaggable(location)) {
            tagPoint(AtomicAllocator.getInstance().getAllocationPoint(buffer), location);
        }
    }

    /** Returns true for the locations tagLocation acts on; lookups are skipped for anything else. */
    private static boolean isTaggable(AffinityManager.Location location) {
        return location == AffinityManager.Location.HOST
                || location == AffinityManager.Location.DEVICE
                || location == AffinityManager.Location.EVERYWHERE;
    }

    /** Applies the write/read ticks for {@code location} to an already-resolved allocation point. */
    private static void tagPoint(AllocationPoint point, AffinityManager.Location location) {
        if (location == AffinityManager.Location.HOST) {
            point.tickHostWrite();
        } else if (location == AffinityManager.Location.DEVICE) {
            point.tickDeviceWrite();
        } else if (location == AffinityManager.Location.EVERYWHERE) {
            point.tickDeviceWrite();
            point.tickHostRead();
        }
    }

    /** Returns the calling thread's affiliation flag, creating it (unset) on first use. */
    private AtomicBoolean currentAffiliation() {
        AtomicBoolean flag = this.affiliated.get();
        if (flag == null) {
            flag = new AtomicBoolean(false);
            this.affiliated.set(flag);
        }
        return flag;
    }
}

