/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.flow.impl;

import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.concurrency.EventsProvider;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.JCublasNDArray;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link FlowController} that makes host/device transfers blocking and records, per
 * {@link AllocationPoint}, the last read and write events on the context's stream so that later
 * operations can wait on them.
 *
 * <p>Transfers are issued with {@code memcpyAsync} on the context's special stream and then passed
 * to {@link #commitTransfer(cudaStream_t)}, which in this implementation synchronises that stream
 * before returning.
 */
public class SynchronousFlowController implements FlowController {
    private static final Logger log = LoggerFactory.getLogger(SynchronousFlowController.class);
    private volatile Allocator allocator;
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    protected Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    protected EventsProvider eventsProvider = new EventsProvider();

    /**
     * Binds this controller to the allocator whose contexts and allocation points it manages.
     *
     * @param allocator allocator used by every subsequent call
     */
    @Override
    public void init(Allocator allocator) {
        this.allocator = allocator;
    }

    /**
     * Makes the host copy of {@code point} current, copying it back from the device when needed.
     * Non-constant buffers first wait for their last recorded write event.
     *
     * @param point allocation to synchronise
     * @throws IllegalStateException if the native device-to-host copy reports failure (returns 0)
     */
    @Override
    public void synchronizeToHost(AllocationPoint point) {
        if (point.isActualOnHostSide()) {
            return;
        }
        CudaContext context = (CudaContext)this.allocator.getDeviceContext().getContext();
        if (!point.isConstant()) {
            this.waitTillFinished(point);
        }
        // Re-checked after the wait: the host copy may have become current in the meantime
        // (NOTE(review): presumably via another thread — confirm).
        if (point.getAllocationStatus() == AllocationStatus.DEVICE && !point.isActualOnHostSide()) {
            if (this.nativeOps.memcpyAsync(point.getHostPointer(), point.getDevicePointer(), AllocationUtils.getRequiredMemory(point.getShape()), CudaConstants.cudaMemcpyDeviceToHost, (Pointer)context.getSpecialStream()) == 0) {
                throw new IllegalStateException("MemcpyAsync failed: " + point.getShape());
            }
            this.commitTransfer(context.getSpecialStream());
        }
        point.tickHostRead();
    }

    /**
     * Copies the host data of {@code point} to the device when the device copy is stale and the
     * allocation lives on the device. Constant allocations are ignored.
     *
     * @param point allocation to synchronise
     * @throws IllegalStateException if the native host-to-device copy reports failure (returns 0)
     */
    @Override
    public void synchronizeToDevice(AllocationPoint point) {
        if (point.isConstant()) {
            return;
        }
        if (!point.isActualOnDeviceSide() && point.getAllocationStatus() == AllocationStatus.DEVICE) {
            CudaContext context = (CudaContext)this.allocator.getDeviceContext().getContext();
            if (this.nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(), AllocationUtils.getRequiredMemory(point.getShape()), CudaConstants.cudaMemcpyHostToDevice, (Pointer)context.getSpecialStream()) == 0) {
                throw new IllegalStateException("MemcpyAsync failed: " + point.getShape());
            }
            this.commitTransfer(context.getSpecialStream());
            point.tickDeviceRead();
        }
    }

    /**
     * Blocks until the last recorded write event of {@code point} (if any) has completed.
     *
     * @param point allocation to wait on
     */
    @Override
    public void waitTillFinished(AllocationPoint point) {
        if (point.getLastWriteEvent() != null) {
            point.getLastWriteEvent().synchronize();
        }
    }

    /**
     * Prepares every non-null operand for an operation that writes all of them: decompresses it,
     * locks its data allocation point, relocates data/shape buffers owned by another device,
     * applies delayed-memory promotion and binds the current context.
     *
     * <p>Unlike {@link #prepareAction(INDArray, INDArray...)}, data relocation here does not take
     * cross-device/P2P settings into account.
     *
     * @param operands arrays to prepare; {@code null} entries are skipped
     * @return the context the operation should run on
     */
    @Override
    public CudaContext prepareActionAllWrite(INDArray ... operands) {
        CudaContext context = (CudaContext)this.allocator.getDeviceContext().getContext();
        int cId = this.allocator.getDeviceId();
        for (INDArray operand : operands) {
            if (operand == null) {
                continue;
            }
            Nd4j.getCompressor().autoDecompress(operand);
            AllocationPoint pointData = this.allocator.getAllocationPoint(operand);
            AllocationPoint pointShape = this.allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
            pointData.acquireLock();
            if (isOnOtherDevice(pointData, cId)) {
                this.allocator.getMemoryHandler().relocateObject(owningBuffer(operand));
            }
            this.relocateShapeIfForeign(operand, pointShape, cId);
            this.prepareDelayedMemory(operand);
            this.allocator.getAllocationPoint(operand).setCurrentContext(context);
        }
        return context;
    }

    /**
     * Prepares {@code result} and {@code operands} for an operation on the current device: each
     * non-null array is decompressed, its data allocation point is locked, its data is relocated
     * when owned by another device that cannot be accessed directly, its shape buffer is relocated
     * when owned by another device, and the current context is bound to it.
     *
     * <p>Locks taken here are expected to be released by the matching {@code registerAction}.
     *
     * @param result   array written by the operation, may be {@code null}
     * @param operands arrays read by the operation; {@code null} entries are skipped
     * @return the context the operation should run on
     */
    @Override
    public CudaContext prepareAction(INDArray result, INDArray ... operands) {
        CudaContext context = (CudaContext)this.allocator.getDeviceContext().getContext();
        int cId = this.allocator.getDeviceId();
        if (result != null) {
            Nd4j.getCompressor().autoDecompress(result);
            // For the result, delayed-memory promotion runs before the allocation points are read.
            this.prepareDelayedMemory(result);
            AllocationPoint pointData = this.allocator.getAllocationPoint(result);
            AllocationPoint pointShape = this.allocator.getAllocationPoint(result.shapeInfoDataBuffer());
            pointData.acquireLock();
            if (needsRelocationForAccess(pointData, cId)) {
                this.allocator.getMemoryHandler().relocateObject(owningBuffer(result));
            }
            this.relocateShapeIfForeign(result, pointShape, cId);
            this.allocator.getAllocationPoint(result).setCurrentContext(context);
        }
        for (INDArray operand : operands) {
            if (operand == null) {
                continue;
            }
            Nd4j.getCompressor().autoDecompress(operand);
            AllocationPoint pointData = this.allocator.getAllocationPoint(operand);
            AllocationPoint pointShape = this.allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
            pointData.acquireLock();
            if (needsRelocationForAccess(pointData, cId)) {
                this.allocator.getMemoryHandler().relocateObject(owningBuffer(operand));
            }
            this.relocateShapeIfForeign(operand, pointShape, cId);
            this.prepareDelayedMemory(operand);
            this.allocator.getAllocationPoint(operand).setCurrentContext(context);
        }
        return context;
    }

    /**
     * Blocks until both the last write event and the last read event of {@code point} (if any)
     * have completed.
     *
     * @param point allocation to wait on
     */
    @Override
    public void waitTillReleased(AllocationPoint point) {
        this.waitTillFinished(point);
        if (point.getLastReadEvent() != null) {
            point.getLastReadEvent().synchronize();
        }
    }

    /**
     * Records a write event for {@code result} and read events for {@code operands} on the
     * context's stream, then releases the locks taken by
     * {@link #prepareAction(AllocationPoint, AllocationPoint...)}.
     *
     * @param context  context the operation ran on
     * @param result   written allocation, may be {@code null}
     * @param operands read allocations; {@code null} entries are skipped
     */
    @Override
    public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint ... operands) {
        if (result != null) {
            this.recordWrite(context, result);
            result.releaseLock();
        }
        for (AllocationPoint operand : operands) {
            if (operand == null) {
                continue;
            }
            this.recordRead(context, operand);
            operand.releaseLock();
        }
    }

    /**
     * Marks every non-null operand as written on the device, records a write event for it and
     * releases the lock taken by {@link #prepareActionAllWrite(INDArray...)}.
     *
     * @param context  context the operation ran on
     * @param operands written arrays; {@code null} entries are skipped
     */
    @Override
    public void registerActionAllWrite(CudaContext context, INDArray ... operands) {
        for (INDArray operand : operands) {
            if (operand == null) {
                continue;
            }
            AllocationPoint pointOperand = this.allocator.getAllocationPoint(operand);
            pointOperand.tickDeviceWrite();
            this.recordWrite(context, pointOperand);
            pointOperand.releaseLock();
        }
    }

    /**
     * Marks {@code result} as written on the device and records its write event, then records
     * read events for {@code operands}, releasing the locks taken by
     * {@link #prepareAction(INDArray, INDArray...)}.
     *
     * <p>Operands are processed even when {@code result} is {@code null}, because
     * {@code prepareAction} locks them in that case too. Each operand's read event is published
     * before its lock is released.
     *
     * @param context  context the operation ran on
     * @param result   written array, may be {@code null}
     * @param operands read arrays; {@code null} entries are skipped
     */
    @Override
    public void registerAction(CudaContext context, INDArray result, INDArray ... operands) {
        if (result != null) {
            AllocationPoint point = this.allocator.getAllocationPoint(result);
            point.tickDeviceWrite();
            this.recordWrite(context, point);
            point.releaseLock();
        }
        for (INDArray operand : operands) {
            if (operand == null) {
                continue;
            }
            AllocationPoint pointOperand = this.allocator.getAllocationPoint(operand);
            this.recordRead(context, pointOperand);
            pointOperand.releaseLock();
        }
    }

    /**
     * Locks the given allocation points and binds the current context to them.
     *
     * @param result   written allocation, may be {@code null}
     * @param operands read allocations; {@code null} entries are skipped
     * @return the context the operation should run on
     */
    @Override
    public CudaContext prepareAction(AllocationPoint result, AllocationPoint ... operands) {
        CudaContext context = (CudaContext)this.allocator.getDeviceContext().getContext();
        if (result != null) {
            result.acquireLock();
            result.setCurrentContext(context);
        }
        for (AllocationPoint operand : operands) {
            if (operand == null) {
                continue;
            }
            operand.acquireLock();
            operand.setCurrentContext(context);
        }
        return context;
    }

    /**
     * Completes a transfer issued on {@code streamUsed} by synchronising that stream.
     *
     * @param streamUsed stream the transfer was enqueued on
     */
    @Override
    public void commitTransfer(cudaStream_t streamUsed) {
        streamUsed.synchronize();
    }

    /**
     * Under the {@code DELAYED} memory model, promotes the array's data buffer when it is not yet
     * on the device, and moves a host-resident shape buffer into constant space.
     *
     * @param array array whose buffers may need promotion
     */
    protected void prepareDelayedMemory(INDArray array) {
        if (this.configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
            // The data decision must look at the array's own data allocation, not at the shape buffer.
            AllocationPoint pointData = this.allocator.getAllocationPoint(array);
            AllocationPoint pointShape = this.allocator.getAllocationPoint(array.shapeInfoDataBuffer());
            if (pointData.getAllocationStatus() != AllocationStatus.DEVICE) {
                this.prepareDelayedMemory(array.data());
            }
            if (pointShape.getAllocationStatus() == AllocationStatus.HOST) {
                DataBuffer oShape = array.shapeInfoDataBuffer();
                DataBuffer nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape);
                // Identity comparison: if no relocated copy was produced, move the buffer itself.
                if (nShape == oShape) {
                    Nd4j.getConstantHandler().moveToConstantSpace(nShape);
                }
                ((JCublasNDArray)array).setShapeInfoDataBuffer(nShape);
            }
        }
    }

    /**
     * Asks the memory handler to promote {@code buffer}.
     *
     * @param buffer buffer to promote
     */
    protected void prepareDelayedMemory(DataBuffer buffer) {
        this.allocator.getMemoryHandler().promoteObject(buffer);
    }

    /** @return the provider that supplies and recycles the events recorded by this controller */
    @Override
    public EventsProvider getEventsProvider() {
        return this.eventsProvider;
    }

    /**
     * Replaces the last write event of {@code point} with a fresh event registered on the
     * context's stream. The previous event is handed back to the provider via {@code storeEvent}
     * (presumably for reuse).
     */
    private void recordWrite(CudaContext context, AllocationPoint point) {
        this.eventsProvider.storeEvent(point.getLastWriteEvent());
        point.setLastWriteEvent(this.eventsProvider.getEvent());
        point.getLastWriteEvent().register(context.getOldStream());
    }

    /**
     * Replaces the last read event of {@code point} with a fresh event registered on the context's
     * stream. The previous event is handed back to the provider via {@code storeEvent}.
     */
    private void recordRead(CudaContext context, AllocationPoint point) {
        this.eventsProvider.storeEvent(point.getLastReadEvent());
        point.setLastReadEvent(this.eventsProvider.getEvent());
        point.getLastReadEvent().register(context.getOldStream());
    }

    /** Returns true when {@code point} has a non-negative device id different from {@code deviceId}. */
    private static boolean isOnOtherDevice(AllocationPoint point, int deviceId) {
        int owner = point.getDeviceId();
        return owner != deviceId && owner >= 0;
    }

    /**
     * Returns true when {@code pointData} is on another device and direct access is not possible,
     * i.e. cross-device access is disabled in the configuration or P2P is unavailable.
     */
    private static boolean needsRelocationForAccess(AllocationPoint pointData, int deviceId) {
        if (!isOnOtherDevice(pointData, deviceId)) {
            return false;
        }
        return !(CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed()
                && NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable());
    }

    /**
     * Returns the array's original data buffer when it has one, otherwise its own data buffer
     * (NOTE(review): presumably the original buffer is set for views).
     */
    private static DataBuffer owningBuffer(INDArray array) {
        DataBuffer data = array.data();
        DataBuffer original = data.originalDataBuffer();
        return original == null ? data : original;
    }

    /** Relocates the array's shape buffer into constant space when it is owned by another device. */
    private void relocateShapeIfForeign(INDArray array, AllocationPoint pointShape, int deviceId) {
        if (isOnOtherDevice(pointShape, deviceId)) {
            ((JCublasNDArray)array).setShapeInfoDataBuffer(Nd4j.getConstantHandler().relocateConstantSpace(array.shapeInfoDataBuffer()));
        }
    }
}

