/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.allocator.tad;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import org.apache.commons.math3.util.Pair;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.tad.BasicTADManager;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.TadDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link BasicTADManager} that keeps one cache per device of the buffer pairs returned by
 * {@link BasicTADManager#getTADOnlyShapeInfo(INDArray, int[])}, keyed by {@link TadDescriptor}.
 *
 * <p>Newly created buffers are passed to {@code AtomicAllocator.moveToConstant} before being
 * cached. The first buffer is moved only when it is not the array's own shape-info buffer. The
 * second buffer is moved only when it is non-null.
 *
 * <p>Each per-device map is a {@link ConcurrentHashMap}. Concurrent callers that miss on the same
 * descriptor may each build buffers; the last {@code put} wins, and each caller gets valid buffers.
 */
public class DeviceTADManager
extends BasicTADManager {
    /**
     * One cache map per device, indexed by the device id from {@code AtomicAllocator.getDeviceId()}.
     * {@link #purgeBuffers()} replaces the whole list. The field is volatile so that other threads
     * see the fully built replacement list once it is published.
     */
    protected volatile List<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> tadCache;
    private static final Logger logger = LoggerFactory.getLogger(DeviceTADManager.class);
    // Not read in this class; kept because its initializer calls CudaEnvironment.getInstance().
    private final Configuration configuration = CudaEnvironment.getInstance().getConfiguration();

    /**
     * Creates one empty cache map for every device reported by the affinity manager.
     */
    public DeviceTADManager() {
        this.tadCache = newDeviceCaches(false);
    }

    /**
     * Drops every cached entry on every device, then delegates to the superclass purge.
     *
     * <p>The replacement list is fully built before it is assigned. Concurrent readers therefore
     * see either the old list or the complete new one, never an empty or partially filled list.
     */
    @Override
    public void purgeBuffers() {
        logger.info("Purging TAD buffers...");
        this.tadCache = newDeviceCaches(true);
        super.purgeBuffers();
    }

    /**
     * Returns the cached buffer pair for {@code array}/{@code dimension} on the current device.
     * On a miss, it builds the pair via the superclass, moves the buffers to constant memory and
     * caches the result.
     *
     * @param array the array whose shape is described
     * @param dimension the dimensions passed through to the superclass and the descriptor
     * @return the buffer pair for this descriptor on the current device
     */
    @Override
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        TadDescriptor descriptor = new TadDescriptor(array, dimension);

        // Resolve the device map exactly once. purgeBuffers() may swap tadCache at any time.
        // Re-reading the field after put() could then look in a fresh, empty map and return null.
        Map<TadDescriptor, Pair<DataBuffer, DataBuffer>> deviceCache = this.tadCache.get(deviceId);

        Pair<DataBuffer, DataBuffer> cached = deviceCache.get(descriptor);
        if (cached != null) {
            return cached;
        }

        Pair<DataBuffer, DataBuffer> buffers = super.getTADOnlyShapeInfo(array, dimension);
        // Identity check: the array's own shape-info buffer is not moved here.
        if (buffers.getFirst() != array.shapeInfoDataBuffer()) {
            AtomicAllocator.getInstance().moveToConstant(buffers.getFirst());
        }
        if (buffers.getSecond() != null) {
            AtomicAllocator.getInstance().moveToConstant(buffers.getSecond());
        }
        deviceCache.put(descriptor, buffers);
        return buffers;
    }

    /**
     * Builds a new list holding one empty concurrent map per device.
     *
     * @param logEachDevice whether to log a "Resetting device" line for every device
     * @return a fully populated list, safe to publish with a single field write
     */
    private static List<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> newDeviceCaches(boolean logEachDevice) {
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        List<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> caches =
                new ArrayList<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>>(numDevices);
        for (int i = 0; i < numDevices; ++i) {
            if (logEachDevice) {
                logger.info("Resetting device: [{}]", i);
            }
            caches.add(new ConcurrentHashMap<TadDescriptor, Pair<DataBuffer, DataBuffer>>());
        }
        return caches;
    }
}

