/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.flow.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.context.ContextPack;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.pointers.cuda.cudaEvent_t;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.time.TimeProvider;
import org.nd4j.jita.allocator.time.providers.OperativeProvider;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link FlowController} that orders device work by attaching {@code cudaEvent_t}
 * markers to allocation points instead of synchronizing after every operation.
 *
 * <p>Each registered action creates one event. The event is stored as the write lane
 * of the result, appended to the read lanes of the operands (INDArray overload only),
 * and queued in a per-device, per-lane tail ({@link #eventsBarrier}) that is trimmed
 * by {@link #sweepTail()}.
 *
 * <p>NOTE(review): events are released with {@code synchronize()} followed by
 * {@code destroy()}; this class assumes both calls are safe on an event that was
 * already destroyed elsewhere - confirm against {@code cudaEvent_t}.
 */
public class AsynchronousFlowController
implements FlowController {
    /** Allocator supplied through {@link #init(Allocator)}; null until then. */
    private volatile Allocator allocator;
    private static final Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    private static final Logger log = LoggerFactory.getLogger(AsynchronousFlowController.class);
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    private transient TimeProvider timeProvider = new OperativeProvider();
    /** Number of prepared actions that found no operand with a pending write. */
    protected AtomicLong asyncHit = new AtomicLong(0L);
    /** Number of prepared actions that found at least one operand with a pending write. */
    protected AtomicLong asyncMiss = new AtomicLong(0L);
    /** Per-lane count of contexts handed out by {@link #prepareAction(INDArray, INDArray...)}. */
    protected Map<Integer, AtomicLong> lanesCounter = new ConcurrentHashMap<Integer, AtomicLong>();
    /** Number of calls to {@link #registerAction(CudaContext, INDArray, INDArray...)}. */
    private AtomicLong totalHits = new AtomicLong(0L);
    protected static final int MAX_EXECUTION_QUEUE = configuration.getCommandQueueLength();
    protected static final AtomicLong eventCounts = new AtomicLong(0L);
    /** Queued events, indexed by device id and then by lane id. */
    protected ArrayList<ArrayList<Queue<cudaEvent_t>>> eventsBarrier = new ArrayList<ArrayList<Queue<cudaEvent_t>>>();
    /** Device-clock value at the last event queued on each lane, indexed by device id and lane id. */
    protected ArrayList<ArrayList<AtomicLong>> laneClocks = new ArrayList<ArrayList<AtomicLong>>();
    /** Monotonic counter per device, advanced by {@link #fillTail} and {@link #sweepTail}. */
    protected ArrayList<AtomicLong> deviceClocks = new ArrayList<AtomicLong>();

    /**
     * Allocates one event queue and one lane clock for every (device, lane) pair
     * reported by the configuration at construction time.
     */
    public AsynchronousFlowController() {
        int numLanes = configuration.getCommandLanesNumber();
        int numDevices = configuration.getAvailableDevices().size();
        for (int d = 0; d < numDevices; ++d) {
            this.eventsBarrier.add(d, new ArrayList<Queue<cudaEvent_t>>());
            this.laneClocks.add(d, new ArrayList<AtomicLong>());
            this.deviceClocks.add(d, new AtomicLong(0L));
            for (int l = 0; l < numLanes; ++l) {
                this.eventsBarrier.get(d).add(l, new ConcurrentLinkedQueue<cudaEvent_t>());
                this.laneClocks.get(d).add(l, new AtomicLong(0L));
            }
        }
    }

    /**
     * Stores the allocator used by every other method of this controller.
     *
     * @param allocator allocator to query for device ids, contexts and allocation points
     */
    @Override
    public void init(Allocator allocator) {
        this.allocator = allocator;
    }

    /**
     * Brings the host copy of {@code point} up to date when it is stale.
     *
     * <p>Non-constant points first wait for their pending write. If the point lives on
     * the device and is still stale, a device-to-host copy is issued on the context's
     * special stream and committed. The host-read tick is recorded in either case.
     *
     * @param point allocation point to synchronize
     * @throws IllegalStateException if {@code memcpyAsync} returns 0
     */
    @Override
    public void synchronizeToHost(AllocationPoint point) {
        if (point.isActualOnHostSide()) {
            return;
        }
        if (!point.isConstant()) {
            this.waitTillFinished(point);
        }
        // Re-check staleness: it is evaluated again after the wait above, as in the original flow.
        if (point.getAllocationStatus() == AllocationStatus.DEVICE && !point.isActualOnHostSide()) {
            CudaContext context = (CudaContext)this.allocator.getDeviceContext().getContext();
            if (this.nativeOps.memcpyAsync(point.getHostPointer(), point.getDevicePointer(), AllocationUtils.getRequiredMemory(point.getShape()), CudaConstants.cudaMemcpyDeviceToHost, (Pointer)context.getSpecialStream()) == 0) {
                throw new IllegalStateException("MemcpyAsync D2H failed: [" + point.getDevicePointer().address() + "] -> [" + point.getHostPointer().address() + "]");
            }
            this.commitTransfer(context.getSpecialStream());
        }
        point.tickHostRead();
    }

    /**
     * Synchronizes on and destroys the write-lane event of {@code point}, if one is set.
     *
     * @param point allocation point whose pending write should be awaited
     */
    @Override
    public void waitTillFinished(AllocationPoint point) {
        cudaEvent_t event = point.getWriteLane();
        if (event != null) {
            release(event);
        }
    }

    /**
     * Waits for the pending write of {@code point} and then releases every event
     * queued in its read lane.
     *
     * @param point allocation point to release
     */
    @Override
    public void waitTillReleased(AllocationPoint point) {
        this.waitTillFinished(point);
        releaseReadEvents(point);
    }

    /**
     * Records a new event on the context's old stream and attaches it to the arrays
     * taking part in the operation.
     *
     * <p>The event becomes the write lane of {@code result} (when non-null) and is
     * appended to the read lane of every non-null operand. Every 25000 calls the
     * async hit ratio is logged at debug level.
     *
     * @param context  context the operation was issued on
     * @param result   array written by the operation, may be null
     * @param operands arrays read by the operation; null entries are skipped
     */
    @Override
    public void registerAction(CudaContext context, INDArray result, INDArray ... operands) {
        if (this.totalHits.incrementAndGet() % 25000L == 0L) {
            log.debug("AsyncHit ratio: [{}]", this.getAsyncHitRatio());
        }
        cudaEvent_t event = this.recordEvent(context);
        if (result != null) {
            this.setWriteLane(result, event);
            this.allocator.tickDeviceWrite(result);
        }
        for (INDArray operand : operands) {
            if (operand == null) {
                continue;
            }
            this.setReadLane(operand, event);
        }
        int deviceId = this.allocator.getDeviceId();
        this.fillTail(deviceId, event.getLaneId(), event);
    }

    /** Sets {@code event} as the write lane of the allocation point backing {@code array}. */
    protected void setWriteLane(INDArray array, cudaEvent_t event) {
        this.allocator.getAllocationPoint(array).setWriteLane(event);
    }

    /** Appends {@code event} to the read lane of the allocation point backing {@code array}. */
    protected void setReadLane(INDArray array, cudaEvent_t event) {
        this.allocator.getAllocationPoint(array).addReadLane(event);
    }

    /** Returns the read-lane queue of the allocation point backing {@code array}. */
    protected Queue<cudaEvent_t> getReadLanes(INDArray array) {
        return this.allocator.getAllocationPoint(array).getReadLane();
    }

    /** Returns the write-lane event of the allocation point backing {@code array}, possibly null. */
    protected cudaEvent_t getWriteLane(INDArray array) {
        return this.allocator.getAllocationPoint(array).getWriteLane();
    }

    /**
     * Reports the lane holding a live write event for {@code array}.
     *
     * @param array array to inspect, may be null
     * @return lane id of the write event, or -1 when the array is null, has no write
     *         event, or the event is already destroyed
     */
    protected int hasActiveWrite(INDArray array) {
        if (array == null) {
            return -1;
        }
        return laneOf(this.getWriteLane(array));
    }

    /**
     * Reports the lane holding a live write event for {@code point}.
     *
     * @param point allocation point to inspect, may be null
     * @return lane id of the write event, or -1 when the point is null, has no write
     *         event, or the event is already destroyed
     */
    protected int hasActiveWrite(AllocationPoint point) {
        if (point == null) {
            return -1;
        }
        return laneOf(point.getWriteLane());
    }

    /**
     * Tells whether the read lane of {@code point} contains at least one event that
     * is not destroyed.
     *
     * @param point allocation point to inspect, may be null
     * @return true when a non-destroyed read event exists
     */
    protected boolean hasActiveReads(AllocationPoint point) {
        if (point == null) {
            return false;
        }
        Queue<cudaEvent_t> events = point.getReadLane();
        if (events.isEmpty()) {
            return false;
        }
        // Iterate over a snapshot so the scan is not affected by concurrent polls.
        for (cudaEvent_t event : new ArrayList<cudaEvent_t>(events)) {
            if (event != null && !event.isDestroyed()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tells whether {@code array} has at least one non-destroyed read event.
     *
     * @param array array to inspect, may be null
     * @return false for a null array, otherwise the result for its allocation point
     */
    protected boolean hasActiveReads(INDArray array) {
        if (array == null) {
            return false;
        }
        return this.hasActiveReads(this.allocator.getAllocationPoint(array));
    }

    /**
     * Checks that every pending lane in {@code lanes} is the same lane.
     *
     * <p>Entries equal to -1 mean "no pending write" and are ignored. All slots are
     * inspected; for arrays whose pending entries sit in the first two slots the
     * result equals the former two-slot comparison.
     *
     * @param lanes lane ids, -1 for empty slots
     * @return true when no two non-empty slots name different lanes
     */
    protected boolean isMatchingLanes(int[] lanes) {
        int reference = -1;
        for (int lane : lanes) {
            if (lane == -1) {
                continue;
            }
            if (reference == -1) {
                reference = lane;
            } else if (lane != reference) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks that {@code zLane} equals one of the first two slots of {@code lanes}
     * and that all pending lanes agree with each other.
     *
     * @param zLane lane of the result's pending write, or -1
     * @param lanes pending lanes of the operands, -1 for empty slots
     * @return true when the result lane and the operand lanes are compatible
     */
    protected boolean isMatchingLanes(int zLane, int[] lanes) {
        boolean zLaneListed = (lanes.length > 0 && zLane == lanes[0])
                || (lanes.length > 1 && zLane == lanes[1]);
        return zLaneListed && this.isMatchingLanes(lanes);
    }

    /**
     * Releases every event queued in the read lane of {@code point}.
     *
     * @param point allocation point whose readers should be awaited
     */
    protected void synchronizeReadLanes(AllocationPoint point) {
        releaseReadEvents(point);
    }

    /**
     * Releases every event queued in the read lane of {@code array}; no-op for null.
     *
     * @param array array whose readers should be awaited, may be null
     */
    protected void synchronizeReadLanes(INDArray array) {
        if (array == null) {
            return;
        }
        this.synchronizeReadLanes(this.allocator.getAllocationPoint(array));
    }

    /**
     * Records a new event on the context's old stream, stores it as the write lane
     * of {@code result} and queues it on the current device's tail.
     *
     * <p>{@code operands} are not touched by this overload.
     *
     * @param context  context the operation was issued on
     * @param result   allocation point written by the operation, may be null
     * @param operands ignored
     */
    @Override
    public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint ... operands) {
        cudaEvent_t event = this.recordEvent(context);
        if (result != null) {
            result.setWriteLane(event);
        }
        int deviceId = this.allocator.getDeviceId();
        this.fillTail(deviceId, event.getLaneId(), event);
    }

    /**
     * Waits for outstanding readers of {@code result} and returns a context for a
     * lane chosen by the context pack.
     *
     * @param result   allocation point about to be written
     * @param operands ignored
     * @return context for the lane returned by {@code nextRandomLane()}
     */
    @Override
    public CudaContext prepareAction(AllocationPoint result, AllocationPoint ... operands) {
        if (this.hasActiveReads(result)) {
            this.synchronizeReadLanes(result);
        }
        ContextPack pack = this.allocator.getContextPool().acquireContextPackForDevice(this.allocator.getDeviceId());
        return pack.getContextForLane(pack.nextRandomLane());
    }

    /**
     * Returns the first non-negative lane id in {@code lanes}.
     *
     * @param lanes lane ids, negative for empty slots
     * @return first pending lane, or 0 when there is none
     */
    protected int pickFirstLane(int[] lanes) {
        for (int lane : lanes) {
            if (lane >= 0) {
                return lane;
            }
        }
        return 0;
    }

    /**
     * Chooses the lane for an operation writing {@code result} from {@code operands}
     * and returns that lane's context.
     *
     * <p>Selection rules:
     * <ul>
     *   <li>Outstanding reads of {@code result} are always awaited first.</li>
     *   <li>If {@code result} has a pending write, its lane is reused; operands are
     *       awaited unless their pending lanes are compatible with it.</li>
     *   <li>Otherwise, if any operand has a pending write, the lane of the last such
     *       operand is used and operands pending on any other lane are awaited.</li>
     *   <li>With nothing pending, the pack picks the lane.</li>
     * </ul>
     * The chosen context is stored on the allocation point of {@code result} and of
     * every non-null operand.
     *
     * @param result   array to be written, may be null
     * @param operands arrays to be read; null entries are skipped
     * @return context for the chosen lane, never null
     * @throws IllegalStateException if the pack has no context for the chosen lane
     */
    @Override
    public CudaContext prepareAction(INDArray result, INDArray ... operands) {
        ContextPack pack = this.allocator.getContextPool().acquireContextPackForDevice(this.allocator.getDeviceId());
        int zLane = this.hasActiveWrite(result);
        boolean zReads = this.hasActiveReads(result);

        // Slot i describes the i-th NON-NULL operand; -1 marks "no pending write".
        int[] pendingLanes = new int[operands.length + 1];
        Arrays.fill(pendingLanes, -1);
        int slot = 0;
        int holdersCount = 0;
        int lastLane = -1;
        for (INDArray operand : operands) {
            if (operand == null) {
                continue;
            }
            int lane = this.hasActiveWrite(operand);
            if (lane >= 0) {
                pendingLanes[slot] = lane;
                ++holdersCount;
                lastLane = lane;
            }
            ++slot;
        }

        int newLane;
        if (result != null && (zReads || zLane >= 0)) {
            if (zReads) {
                this.synchronizeReadLanes(result);
            }
            if (holdersCount > 0) {
                this.asyncMiss.incrementAndGet();
                boolean compatible = this.isMatchingLanes(zLane, pendingLanes);
                newLane = zLane >= 0 ? zLane : this.pickFirstLane(pendingLanes);
                if (!compatible) {
                    for (INDArray operand : operands) {
                        if (operand == null) {
                            continue;
                        }
                        this.waitTillFinished(this.allocator.getAllocationPoint(operand));
                    }
                }
            } else {
                this.asyncHit.incrementAndGet();
                newLane = zLane >= 0 ? zLane : pack.nextRandomLane();
            }
        } else if (holdersCount > 0) {
            this.asyncMiss.incrementAndGet();
            // Stay on the lane of the last pending operand; anything pending elsewhere must finish first.
            newLane = lastLane;
            this.waitForOperandsOffLane(newLane, pendingLanes, operands);
        } else {
            this.asyncHit.incrementAndGet();
            newLane = pack.nextRandomLane();
        }

        CudaContext context = pack.getContextForLane(newLane);
        // Fail before any allocation point or counter is updated with a missing context.
        if (context == null) {
            throw new IllegalStateException("Context shouldn't be null: " + newLane);
        }
        if (result != null) {
            this.allocator.getAllocationPoint(result).setCurrentContext(context);
        }
        for (INDArray operand : operands) {
            if (operand == null) {
                continue;
            }
            this.allocator.getAllocationPoint(operand).setCurrentContext(context);
        }
        // NOTE(review): get-then-put is not atomic; a concurrent first use of a lane can drop a count.
        AtomicLong laneCounter = this.lanesCounter.get(newLane);
        if (laneCounter == null) {
            laneCounter = new AtomicLong(0L);
            this.lanesCounter.put(newLane, laneCounter);
        }
        laneCounter.incrementAndGet();
        return context;
    }

    /**
     * Waits for every non-null operand whose pending write sits on a lane other than
     * {@code lane}. {@code pendingLanes} is indexed by non-null operand position.
     */
    private void waitForOperandsOffLane(int lane, int[] pendingLanes, INDArray[] operands) {
        int slot = 0;
        for (INDArray operand : operands) {
            if (operand == null) {
                continue;
            }
            int pending = pendingLanes[slot++];
            if (pending >= 0 && pending != lane) {
                this.waitTillFinished(this.allocator.getAllocationPoint(operand));
            }
        }
    }

    /** Creates an event tagged with the context's lane id and registers it on the context's old stream. */
    private cudaEvent_t recordEvent(CudaContext context) {
        cudaEvent_t event = new cudaEvent_t(this.nativeOps.createEvent());
        event.setLaneId(context.getLaneId());
        this.nativeOps.registerEvent((Pointer)event, (Pointer)context.getOldStream());
        return event;
    }

    /** Returns the lane id of {@code event}, or -1 when it is null or destroyed. */
    private static int laneOf(cudaEvent_t event) {
        if (event == null || event.isDestroyed()) {
            return -1;
        }
        return event.getLaneId();
    }

    /** Synchronizes on {@code event} and then destroys it. */
    private static void release(cudaEvent_t event) {
        event.synchronize();
        event.destroy();
    }

    /** Polls the read lane of {@code point} until it is empty, releasing each event. */
    private static void releaseReadEvents(AllocationPoint point) {
        cudaEvent_t event;
        while ((event = point.getReadLane().poll()) != null) {
            release(event);
        }
    }

    /**
     * Computes the percentage of prepared actions counted as async hits.
     *
     * @return hits * 100 / (hits + misses), or 0 when nothing was counted yet
     */
    private float getAsyncHitRatio() {
        long hits = this.asyncHit.get();
        long total = hits + this.asyncMiss.get();
        if (total == 0L) {
            return 0.0f;
        }
        return (float)(hits * 100L) / (float)total;
    }

    /**
     * Queues {@code event} on the given device lane and stamps the lane clock with
     * the advanced device clock.
     *
     * @param deviceId index into the per-device tables
     * @param lane     index into the per-lane tables of that device
     * @param event    event to queue
     */
    protected void fillTail(int deviceId, int lane, cudaEvent_t event) {
        this.eventsBarrier.get(deviceId).get(lane).add(event);
        long tick = this.deviceClocks.get(deviceId).incrementAndGet();
        this.laneClocks.get(deviceId).get(lane).set(tick);
    }

    /**
     * Releases at most one queued event per lane of the current device, then advances
     * the device clock.
     *
     * <p>A lane is skipped while its queue holds fewer than
     * {@link #MAX_EXECUTION_QUEUE} events and its clock is within
     * {@code MAX_EXECUTION_QUEUE} ticks of the device clock.
     */
    protected void sweepTail() {
        int deviceId = this.allocator.getDeviceId();
        long lastCommandId = this.deviceClocks.get(deviceId).get();
        for (int l = 0; l < configuration.getCommandLanesNumber(); ++l) {
            Queue<cudaEvent_t> queue = this.eventsBarrier.get(deviceId).get(l);
            boolean shortQueue = queue.size() < MAX_EXECUTION_QUEUE;
            if (shortQueue && this.laneClocks.get(deviceId).get(l).get() >= lastCommandId - (long)MAX_EXECUTION_QUEUE) {
                continue;
            }
            cudaEvent_t event = queue.poll();
            if (event == null || event.isDestroyed()) {
                continue;
            }
            release(event);
        }
        this.deviceClocks.get(deviceId).incrementAndGet();
    }

    /** Releases every queued event on every lane of the current device. */
    protected void cutTail() {
        int deviceId = this.allocator.getDeviceId();
        for (int l = 0; l < configuration.getCommandLanesNumber(); ++l) {
            Queue<cudaEvent_t> queue = this.eventsBarrier.get(deviceId).get(l);
            cudaEvent_t event;
            while ((event = queue.poll()) != null) {
                release(event);
            }
        }
    }

    /**
     * Trims the event tail of the current device and then synchronizes on the stream
     * that carried the transfer.
     *
     * @param streamUsed stream the transfer was issued on
     */
    @Override
    public void commitTransfer(cudaStream_t streamUsed) {
        this.sweepTail();
        streamUsed.synchronize();
    }
}

