/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.flow.impl;

import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.JCublasNDArray;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link FlowController} that orders device work by explicitly synchronizing CUDA streams
 * (see {@link #commitTransfer(cudaStream_t)}, {@link #waitTillFinished(AllocationPoint)} and
 * {@link #registerAction(CudaContext, AllocationPoint, AllocationPoint...)}) instead of
 * tracking pending operations per allocation point.
 *
 * <p>NOTE(review): thread-safety is not documented; apart from the {@code volatile} allocator
 * reference, no synchronization is visible in this class.
 */
public class SynchronousFlowController implements FlowController {
    private static final Logger log = LoggerFactory.getLogger(SynchronousFlowController.class);

    private volatile Allocator allocator;
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    protected Configuration configuration = CudaEnvironment.getInstance().getConfiguration();

    /**
     * Stores the allocator used to obtain device contexts and allocation points.
     *
     * @param allocator allocator backing this controller
     */
    @Override
    public void init(Allocator allocator) {
        this.allocator = allocator;
    }

    /**
     * Brings the host copy of {@code point} up to date, copying it back from the device if needed.
     *
     * <p>If the host side is already actual, nothing happens. Otherwise non-constant points first
     * go through {@link #waitTillFinished(AllocationPoint)}. If the point's status is
     * {@link AllocationStatus#DEVICE} and the host side is still not actual, the data is copied
     * device-to-host on the current context's special stream, and that stream is then
     * synchronized. Finally the host read is recorded via {@code tickHostRead()}.
     *
     * @param point allocation point whose host copy should be refreshed
     * @throws IllegalStateException if the native {@code memcpyAsync} call returns 0
     */
    @Override
    public void synchronizeToHost(AllocationPoint point) {
        CudaContext context = (CudaContext) this.allocator.getDeviceContext().getContext();
        if (!point.isActualOnHostSide()) {
            if (!point.isConstant()) {
                this.waitTillFinished(point);
            }
            // isActualOnHostSide() is deliberately re-checked after the optional wait above.
            if (point.getAllocationStatus() == AllocationStatus.DEVICE && !point.isActualOnHostSide()) {
                if (this.nativeOps.memcpyAsync(
                        point.getHostPointer(),
                        point.getDevicePointer(),
                        AllocationUtils.getRequiredMemory(point.getShape()),
                        CudaConstants.cudaMemcpyDeviceToHost,
                        (Pointer) context.getSpecialStream()) == 0) {
                    throw new IllegalStateException("MemcpyAsync failed");
                }
                this.commitTransfer(context.getSpecialStream());
            }
            point.tickHostRead();
        }
    }

    /**
     * Synchronizes the "old" stream of the current thread's device context.
     *
     * <p>{@code point} is not inspected; the wait covers whatever is queued on that context.
     *
     * @param point ignored by this implementation
     */
    @Override
    public void waitTillFinished(AllocationPoint point) {
        CudaContext context = (CudaContext) this.allocator.getDeviceContext().getContext();
        context.syncOldStream();
    }

    /**
     * Records a device write on the allocation point of {@code result}.
     *
     * <p>Does nothing when {@code result} is {@code null}. {@code context} and {@code operands}
     * are not used by this implementation.
     *
     * @param context context the action ran on (unused)
     * @param result array written by the action, may be {@code null}
     * @param operands arrays read by the action (unused)
     */
    @Override
    public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
        if (result == null) {
            return;
        }
        AllocationPoint point = this.allocator.getAllocationPoint(result);
        point.tickDeviceWrite();
    }

    /**
     * Prepares {@code result} and {@code operands} for an action on the current device.
     *
     * <p>For every non-null array, the data and shape-info buffers are relocated if they belong to
     * a different (non-negative) device id. Delayed-memory promotion is applied via
     * {@link #prepareDelayedMemory(INDArray)}, and the array's allocation point is bound to the
     * returned context.
     *
     * @param result array the action writes to, may be {@code null}
     * @param operands arrays the action reads; {@code null} elements are skipped
     * @return the current thread's device context
     */
    @Override
    public CudaContext prepareAction(INDArray result, INDArray... operands) {
        CudaContext context = (CudaContext) this.allocator.getDeviceContext().getContext();
        int cId = this.allocator.getDeviceId();
        if (result != null) {
            // For the result, delayed memory is resolved before the device check.
            this.prepareDelayedMemory(result);
            this.relocateToDevice(result, cId);
            this.allocator.getAllocationPoint(result).setCurrentContext(context);
        }
        for (INDArray operand : operands) {
            if (operand == null) {
                continue;
            }
            // NOTE(review): operands relocate first and resolve delayed memory second, the reverse of
            // the result branch above; kept as-is, confirm whether the asymmetry is intentional.
            this.relocateToDevice(operand, cId);
            this.prepareDelayedMemory(operand);
            this.allocator.getAllocationPoint(operand).setCurrentContext(context);
        }
        return context;
    }

    /**
     * Relocates the data buffer and/or shape-info buffer of {@code array} when the corresponding
     * allocation point reports a non-negative device id different from {@code deviceId}.
     *
     * <p>Both allocation points are looked up before any relocation happens.
     *
     * @param array array whose buffers may need to move
     * @param deviceId id of the device the upcoming action runs on
     */
    private void relocateToDevice(INDArray array, int deviceId) {
        AllocationPoint pointData = this.allocator.getAllocationPoint(array.data());
        AllocationPoint pointShape = this.allocator.getAllocationPoint(array.shapeInfoDataBuffer());
        if (pointData.getDeviceId() != deviceId && pointData.getDeviceId() >= 0) {
            this.allocator.getMemoryHandler().relocateObject(array.data());
        }
        if (pointShape.getDeviceId() != deviceId && pointShape.getDeviceId() >= 0) {
            ((JCublasNDArray) array).setShapeInfoDataBuffer(
                    Nd4j.getConstantHandler().relocateConstantSpace(array.shapeInfoDataBuffer()));
        }
    }

    /**
     * Delegates to {@link #waitTillFinished(AllocationPoint)}.
     *
     * @param point allocation point to wait for
     */
    @Override
    public void waitTillReleased(AllocationPoint point) {
        this.waitTillFinished(point);
    }

    /**
     * Synchronizes the "old" stream of {@code context} once the action has been issued.
     *
     * @param context context the action was issued on
     * @param result written allocation point (unused)
     * @param operands read allocation points (unused)
     */
    @Override
    public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint... operands) {
        context.syncOldStream();
    }

    /**
     * Returns the current thread's device context without relocating or promoting anything.
     *
     * @param result written allocation point (unused)
     * @param operands read allocation points (unused)
     * @return the current device context
     */
    @Override
    public CudaContext prepareAction(AllocationPoint result, AllocationPoint... operands) {
        return (CudaContext) this.allocator.getDeviceContext().getContext();
    }

    /**
     * Commits a transfer by calling {@code synchronize()} on the stream it was issued on.
     *
     * @param streamUsed stream that carried the transfer
     */
    @Override
    public void commitTransfer(cudaStream_t streamUsed) {
        streamUsed.synchronize();
    }

    /**
     * Under the {@link Configuration.MemoryModel#DELAYED} memory model, promotes the buffers of
     * {@code array} that are not yet device-resident. Other memory models are a no-op.
     *
     * <p>The data buffer is passed to {@link #prepareDelayedMemory(DataBuffer)} when its own
     * status is not {@link AllocationStatus#DEVICE}. A shape-info buffer whose status is
     * {@link AllocationStatus#HOST} is relocated to constant space and re-attached to the array.
     *
     * @param array array whose buffers may need promotion; must be a {@link JCublasNDArray} when
     *     its shape buffer is host-only (it is cast unconditionally in that case)
     */
    protected void prepareDelayedMemory(INDArray array) {
        if (this.configuration.getMemoryModel() != Configuration.MemoryModel.DELAYED) {
            return;
        }
        // The data buffer's own allocation point decides whether it needs promotion; the shape
        // buffer's status says nothing about where the data lives.
        AllocationPoint pointData = this.allocator.getAllocationPoint(array.data());
        AllocationPoint pointShape = this.allocator.getAllocationPoint(array.shapeInfoDataBuffer());
        if (pointData.getAllocationStatus() != AllocationStatus.DEVICE) {
            this.prepareDelayedMemory(array.data());
        }
        if (pointShape.getAllocationStatus() == AllocationStatus.HOST) {
            DataBuffer oldShape = array.shapeInfoDataBuffer();
            DataBuffer newShape = Nd4j.getConstantHandler().relocateConstantSpace(oldShape);
            // Same instance returned: explicitly move it into constant space as well.
            if (newShape == oldShape) {
                Nd4j.getConstantHandler().moveToConstantSpace(newShape);
            }
            ((JCublasNDArray) array).setShapeInfoDataBuffer(newShape);
        }
    }

    /**
     * Asks the allocator's memory handler to promote {@code buffer}.
     *
     * @param buffer buffer to promote
     */
    protected void prepareDelayedMemory(DataBuffer buffer) {
        this.allocator.getMemoryHandler().promoteObject(buffer);
    }
}

