/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.concurrency;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.concurrency.BasicAffinityManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Affinity manager that maps each Java thread (by thread id) to a single CUDA device id.
 *
 * <p>A thread that has not been attached explicitly via
 * {@link #attachThreadToDevice(long, Integer)} is assigned a device on its first lookup:
 * round-robin over {@code configuration.getAvailableDevices()}, or always the first available
 * device when {@code configuration.isForcedSingleGPU()} is set. Assignments are kept in a
 * concurrent map keyed by thread id; this class never removes them.
 */
public class CudaAffinityManager
extends BasicAffinityManager {
    private static final Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    private static final Logger logger = LoggerFactory.getLogger(CudaAffinityManager.class);

    /** Thread id -> device id. A concurrent map so lookups and updates from many threads are safe. */
    private final ConcurrentMap<Long, Integer> affinityMap = new ConcurrentHashMap<Long, Integer>();

    /** Position in the available-device list of the next device to hand out; updated under {@code this}. */
    private final AtomicInteger devPtr = new AtomicInteger(0);

    /**
     * Returns the device assigned to the calling thread, assigning one first if necessary.
     *
     * @return device id for {@link Thread#currentThread()}
     */
    public Integer getDeviceForCurrentThread() {
        return this.getDeviceForThread(Thread.currentThread().getId());
    }

    /**
     * Returns the device assigned to {@code thread}, assigning one first if necessary.
     *
     * @param thread thread whose id is used as the lookup key
     * @return device id for the thread
     */
    public Integer getDeviceForThread(Thread thread) {
        return this.getDeviceForThread(thread.getId());
    }

    /**
     * Returns the device assigned to {@code threadId}, assigning one via
     * {@link #getNextDevice(long)} if the id has no assignment yet.
     *
     * @param threadId id of the thread to look up
     * @return device id for the thread
     */
    public Integer getDeviceForThread(long threadId) {
        Integer deviceId = this.affinityMap.get(threadId);
        if (deviceId != null) {
            return deviceId;
        }
        // Concurrent callers may both reach this point for the same id. putIfAbsent ensures they
        // all observe one winning assignment (including a concurrent manual attach) instead of
        // silently overwriting each other.
        Integer candidate = this.getNextDevice(threadId);
        Integer existing = this.affinityMap.putIfAbsent(threadId, candidate);
        return existing != null ? existing : candidate;
    }

    /**
     * Manually pins {@code thread} to {@code deviceId}, replacing any previous assignment.
     *
     * @param thread   thread whose id is used as the key
     * @param deviceId device to assign; not validated against the available devices
     */
    public void attachThreadToDevice(Thread thread, Integer deviceId) {
        this.attachThreadToDevice(thread.getId(), deviceId);
    }

    /**
     * Manually pins {@code threadId} to {@code deviceId}, replacing any previous assignment.
     *
     * @param threadId id of the thread to pin
     * @param deviceId device to assign; not validated against the available devices
     */
    public void attachThreadToDevice(long threadId, Integer deviceId) {
        int deviceCount = configuration.getAvailableDevices().size();
        logger.debug("Manually mapping thread [{}] to device [{}], out of [{}] devices...", new Object[]{threadId, deviceId, deviceCount});
        this.affinityMap.put(threadId, deviceId);
    }

    /**
     * Picks the device for a thread that has no assignment yet.
     *
     * <p>When a single GPU is forced, this always returns the first available device. Otherwise
     * devices are handed out round-robin in the order returned by
     * {@code configuration.getAvailableDevices()}.
     *
     * @param threadId id of the thread being assigned (used for logging only)
     * @return the chosen device id
     * @throws IllegalStateException if no devices are available
     */
    protected Integer getNextDevice(long threadId) {
        List<Integer> devices = new ArrayList<Integer>(configuration.getAvailableDevices());
        if (devices.isEmpty()) {
            throw new IllegalStateException("No CUDA devices available to assign to thread [" + threadId + "]");
        }
        if (configuration.isForcedSingleGPU()) {
            Integer device = devices.get(0);
            logger.debug("Single device is forced, mapping to device [{}]", (Object)device);
            return device;
        }
        synchronized (this) {
            // Take the index modulo the current size: the device list is re-read on every call,
            // so a pointer left over from a longer list must not index past its end.
            int index = this.devPtr.get() % devices.size();
            Integer device = devices.get(index);
            this.devPtr.set((index + 1) % devices.size());
            logger.debug("Mapping thread [{}] to device [{}], out of [{}] devices...", new Object[]{threadId, device, devices.size()});
            return device;
        }
    }
}

