/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.buffer;

import java.io.DataInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.Collection;
import lombok.NonNull;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class BaseCudaDataBuffer
extends BaseDataBuffer
implements JCudaBuffer {
    protected transient AllocationPoint allocationPoint;
    private static AtomicAllocator allocator = AtomicAllocator.getInstance();
    private static Logger log = LoggerFactory.getLogger(BaseCudaDataBuffer.class);
    protected DataBuffer.Type globalType = DataTypeUtil.getDtypeFromContext();

    public BaseCudaDataBuffer() {
    }

    public BaseCudaDataBuffer(Pointer pointer, Indexer indexer, long length) {
        super(pointer, indexer, length);
        if (!(pointer instanceof CudaPointer)) {
            this.pointer = new CudaPointer(pointer, length * (long)this.getElementSize(), 0L);
        }
        this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, this.elementSize), false);
        this.trackingPoint = this.allocationPoint.getObjectId();
    }

    public BaseCudaDataBuffer(float[] data, boolean copy) {
        this(data, copy, 0L);
    }

    public BaseCudaDataBuffer(float[] data, boolean copy, long offset) {
        this(data.length, 4);
        this.offset = offset;
        this.originalOffset = offset;
        this.length = (long)data.length - offset;
        this.underlyingLength = data.length;
        this.set(data, this.length, offset, offset);
    }

    public BaseCudaDataBuffer(double[] data, boolean copy) {
        this(data, copy, 0L);
    }

    public BaseCudaDataBuffer(double[] data, boolean copy, long offset) {
        this(data.length, 8);
        this.offset = offset;
        this.originalOffset = offset;
        this.length = (long)data.length - offset;
        this.underlyingLength = data.length;
        this.set(data, this.length, offset, offset);
    }

    public BaseCudaDataBuffer(int[] data, boolean copy) {
        this(data, copy, 0L);
    }

    public BaseCudaDataBuffer(int[] data, boolean copy, long offset) {
        this(data.length, 4);
        this.offset = offset;
        this.originalOffset = offset;
        this.length = (long)data.length - offset;
        this.underlyingLength = data.length;
        this.set(data, this.length, offset, offset);
    }

    public BaseCudaDataBuffer(long length, int elementSize, boolean initialize) {
        this.allocationMode = DataBuffer.AllocationMode.JAVACPP;
        this.initTypeAndSize();
        this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize), initialize);
        this.length = length;
        this.elementSize = elementSize;
        this.trackingPoint = this.allocationPoint.getObjectId();
        this.offset = 0L;
        this.originalOffset = 0L;
        if (this.dataType() == DataBuffer.Type.DOUBLE) {
            this.pointer = new CudaPointer(this.allocationPoint.getPointers().getHostPointer(), length, 0L).asDoublePointer();
            this.indexer = DoubleIndexer.create((DoublePointer)((DoublePointer)this.pointer));
        } else if (this.dataType() == DataBuffer.Type.FLOAT) {
            this.pointer = new CudaPointer(this.allocationPoint.getPointers().getHostPointer(), length, 0L).asFloatPointer();
            this.indexer = FloatIndexer.create((FloatPointer)((FloatPointer)this.pointer));
        } else if (this.dataType() == DataBuffer.Type.INT) {
            this.pointer = new CudaPointer(this.allocationPoint.getPointers().getHostPointer(), length, 0L).asIntPointer();
            this.indexer = IntIndexer.create((IntPointer)((IntPointer)this.pointer));
        }
        this.wrappedBuffer = this.pointer.asByteBuffer();
        if (this.wrappedBuffer == null) {
            throw new IllegalStateException("WrappedBuffer is NULL");
        }
    }

    public BaseCudaDataBuffer(long length, int elementSize) {
        this(length, elementSize, true);
    }

    public BaseCudaDataBuffer(long length, int elementSize, long offset) {
        this(length, elementSize);
        this.offset = offset;
        this.originalOffset = offset;
    }

    public BaseCudaDataBuffer(@NonNull DataBuffer underlyingBuffer, long length, long offset) {
        if (underlyingBuffer == null) {
            throw new NullPointerException("underlyingBuffer");
        }
        this.allocationMode = DataBuffer.AllocationMode.JAVACPP;
        this.initTypeAndSize();
        this.wrappedDataBuffer = underlyingBuffer;
        this.originalBuffer = underlyingBuffer.originalDataBuffer() == null ? underlyingBuffer : underlyingBuffer.originalDataBuffer();
        this.length = length;
        this.offset = offset;
        this.originalOffset = offset;
        this.trackingPoint = underlyingBuffer.getTrackingPoint();
        this.elementSize = underlyingBuffer.getElementSize();
        this.allocationPoint = ((BaseCudaDataBuffer)underlyingBuffer).allocationPoint;
        if (underlyingBuffer.dataType() == DataBuffer.Type.DOUBLE) {
            this.pointer = new CudaPointer(this.allocationPoint.getPointers().getHostPointer(), this.originalBuffer.length()).asDoublePointer();
            this.indexer = DoubleIndexer.create((DoublePointer)((DoublePointer)this.pointer));
        } else if (underlyingBuffer.dataType() == DataBuffer.Type.FLOAT) {
            this.pointer = new CudaPointer(this.allocationPoint.getPointers().getHostPointer(), this.originalBuffer.length()).asFloatPointer();
            this.indexer = FloatIndexer.create((FloatPointer)((FloatPointer)this.pointer));
        } else if (underlyingBuffer.dataType() == DataBuffer.Type.INT) {
            this.pointer = new CudaPointer(this.allocationPoint.getPointers().getHostPointer(), this.originalBuffer.length()).asIntPointer();
            this.indexer = IntIndexer.create((IntPointer)((IntPointer)this.pointer));
        }
        this.wrappedBuffer = this.pointer.asByteBuffer();
    }

    public BaseCudaDataBuffer(long length) {
        this(length, Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 8 : 4);
    }

    public BaseCudaDataBuffer(float[] data) {
        this((long)data.length, Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 8 : 4, false);
        this.set(data, (long)data.length, 0L, 0L);
    }

    public BaseCudaDataBuffer(int[] data) {
        this((long)data.length, Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 8 : 4, false);
        this.set(data, (long)data.length, 0L, 0L);
    }

    public BaseCudaDataBuffer(double[] data) {
        this((long)data.length, Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 8 : 4, false);
        this.set(data, (long)data.length, 0L, 0L);
    }

    public BaseCudaDataBuffer(byte[] data, long length) {
        this(ByteBuffer.wrap(data), length);
    }

    public BaseCudaDataBuffer(ByteBuffer buffer, long length) {
        this(buffer, length, 0L);
    }

    public BaseCudaDataBuffer(ByteBuffer buffer, long length, long offset) {
        this(length, Nd4j.dataType() == DataBuffer.Type.DOUBLE ? 8 : 4, offset);
        CudaPointer srcPtr = new CudaPointer(new Pointer((Buffer)buffer.order(ByteOrder.nativeOrder())));
        allocator.memcpyAsync(this, srcPtr, length * (long)this.elementSize, offset * (long)this.elementSize);
    }

    public long address() {
        return this.allocationPoint.getPointers().getHostPointer().address();
    }

    public void set(int[] data, long length, long srcOffset, long dstOffset) {
        if (this.dataType() == DataBuffer.Type.DOUBLE) {
            CudaPointer srcPtr = new CudaPointer(new DoublePointer(ArrayUtil.toDoubles((int[])data)).address() + dstOffset * (long)this.elementSize);
            allocator.memcpyAsync(this, srcPtr, length * (long)this.elementSize, dstOffset * (long)this.elementSize);
        } else if (this.dataType() == DataBuffer.Type.FLOAT) {
            CudaPointer srcPtr = new CudaPointer(new FloatPointer(ArrayUtil.toFloats((int[])data)).address() + dstOffset * (long)this.elementSize);
            allocator.memcpyAsync(this, srcPtr, length * (long)this.elementSize, dstOffset * (long)this.elementSize);
        } else if (this.dataType() == DataBuffer.Type.INT) {
            CudaPointer srcPtr = new CudaPointer(new IntPointer(data).address() + dstOffset * (long)this.elementSize);
            allocator.memcpyAsync(this, srcPtr, length * (long)this.elementSize, dstOffset * (long)this.elementSize);
        }
    }

    public void set(float[] data, long length, long srcOffset, long dstOffset) {
        if (this.dataType() == DataBuffer.Type.DOUBLE) {
            CudaPointer srcPtr = new CudaPointer(new DoublePointer(ArrayUtil.toDoubles((float[])data)).address() + dstOffset * (long)this.elementSize);
            allocator.memcpyAsync(this, srcPtr, length * (long)this.elementSize, dstOffset * (long)this.elementSize);
        } else if (this.dataType() == DataBuffer.Type.FLOAT) {
            CudaPointer srcPtr = new CudaPointer(new FloatPointer(data).address() + dstOffset * (long)this.elementSize);
            allocator.memcpyAsync(this, srcPtr, length * (long)this.elementSize, dstOffset * (long)this.elementSize);
        } else if (this.dataType() == DataBuffer.Type.INT) {
            CudaPointer srcPtr = new CudaPointer(new IntPointer(ArrayUtil.toInts((float[])data)).address() + dstOffset * (long)this.elementSize);
            allocator.memcpyAsync(this, srcPtr, length * (long)this.elementSize, dstOffset * (long)this.elementSize);
        }
    }

    public void set(double[] data, long length, long srcOffset, long dstOffset) {
        if (this.dataType() == DataBuffer.Type.DOUBLE) {
            CudaPointer srcPtr = new CudaPointer(new DoublePointer(data).address() + dstOffset * (long)this.elementSize);
            allocator.memcpyAsync(this, srcPtr, length * (long)this.elementSize, dstOffset * (long)this.elementSize);
        } else if (this.dataType() == DataBuffer.Type.FLOAT) {
            CudaPointer srcPtr = new CudaPointer(new FloatPointer(ArrayUtil.toFloats((double[])data)).address() + dstOffset * (long)this.elementSize);
            allocator.memcpyAsync(this, srcPtr, length * (long)this.elementSize, dstOffset * (long)this.elementSize);
        } else if (this.dataType() == DataBuffer.Type.INT) {
            CudaPointer srcPtr = new CudaPointer(new IntPointer(ArrayUtil.toInts((double[])data)).address() + dstOffset * (long)this.elementSize);
            allocator.memcpyAsync(this, srcPtr, length * (long)this.elementSize, dstOffset * (long)this.elementSize);
        }
    }

    public void setData(int[] data) {
        this.set(data, (long)data.length, 0L, 0L);
    }

    public void setData(float[] data) {
        this.set(data, (long)data.length, 0L, 0L);
    }

    public void setData(double[] data) {
        this.set(data, (long)data.length, 0L, 0L);
    }

    protected void setNioBuffer() {
        throw new UnsupportedOperationException("setNioBuffer() is not supported for CUDA backend");
    }

    public void copyAtStride(DataBuffer buf, long n, long stride, long yStride, long offset, long yOffset) {
        allocator.synchronizeHostData(this);
        allocator.synchronizeHostData(buf);
        super.copyAtStride(buf, n, stride, yStride, offset, yOffset);
    }

    public DataBuffer.AllocationMode allocationMode() {
        return this.allocationMode;
    }

    @Override
    public ByteBuffer getHostBuffer() {
        return this.wrappedBuffer;
    }

    @Override
    public Pointer getHostPointer() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Pointer getHostPointer(int offset) {
        throw new UnsupportedOperationException();
    }

    public void removeReferencing(String id) {
        this.referencing.remove(id);
    }

    public Collection<String> references() {
        return this.referencing;
    }

    public int getElementSize() {
        return this.elementSize;
    }

    public void addReferencing(String id) {
        this.referencing.add(id);
    }

    public void put(long i, IComplexNumber result) {
        throw new UnsupportedOperationException("ComplexNumbers are not supported yet");
    }

    @Deprecated
    public Pointer getHostPointer(INDArray arr, int stride, int offset, int length) {
        throw new UnsupportedOperationException("This method is deprecated");
    }

    @Deprecated
    public void set(Pointer pointer) {
        throw new UnsupportedOperationException("set(Pointer) is not supported");
    }

    public void put(long i, float element) {
        allocator.synchronizeHostData(this);
        allocator.tickHostWrite(this);
        super.put(i, element);
    }

    public void put(long i, double element) {
        allocator.synchronizeHostData(this);
        allocator.tickHostWrite(this);
        super.put(i, element);
    }

    public void put(long i, int element) {
        allocator.synchronizeHostData(this);
        allocator.tickHostWrite(this);
        super.put(i, element);
    }

    public IComplexFloat getComplexFloat(long i) {
        return Nd4j.createFloat((float)this.getFloat(i), (float)this.getFloat(i + 1L));
    }

    public IComplexDouble getComplexDouble(long i) {
        return Nd4j.createDouble((double)this.getDouble(i), (double)this.getDouble(i + 1L));
    }

    public IComplexNumber getComplex(long i) {
        return this.dataType() == DataBuffer.Type.FLOAT ? this.getComplexFloat(i) : this.getComplexDouble(i);
    }

    @Deprecated
    protected void set(long index, long length, Pointer from, long inc) {
        long offset = (long)this.getElementSize() * index;
        if (offset >= this.length() * (long)this.getElementSize()) {
            throw new IllegalArgumentException("Illegal offset " + offset + " with index of " + index + " and length " + this.length());
        }
        throw new UnsupportedOperationException("Deprecated set() call");
    }

    @Deprecated
    protected void set(long index, long length, Pointer from) {
        this.set(index, length, from, 1L);
    }

    public void assign(DataBuffer data) {
        allocator.memcpy(this, data);
    }

    @Deprecated
    protected void set(long index, Pointer from) {
        this.set(index, 1L, from);
    }

    public void flush() {
    }

    public void destroy() {
    }

    private void writeObject(ObjectOutputStream stream) throws IOException {
        stream.defaultWriteObject();
        this.write(stream);
    }

    private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
        this.doReadObject(stream);
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        int i = 0;
        while ((long)i < this.length()) {
            sb.append(this.getDouble(i));
            if ((long)i < this.length() - 1L) {
                sb.append(",");
            }
            ++i;
        }
        sb.append("]");
        return sb.toString();
    }

    public boolean sameUnderlyingData(DataBuffer buffer) {
        return buffer.getTrackingPoint() == this.getTrackingPoint();
    }

    public boolean equals(Object o) {
        if (o == null) {
            return false;
        }
        return this == o;
    }

    public void read(DataInputStream s) {
        try {
            s.readUTF();
            this.allocationMode = DataBuffer.AllocationMode.JAVACPP;
            this.length = s.readInt();
            DataBuffer.Type t = DataBuffer.Type.valueOf((String)s.readUTF());
            if (this.globalType == null && Nd4j.dataType() != null) {
                this.globalType = Nd4j.dataType();
            }
            if (t != this.globalType && t != DataBuffer.Type.INT) {
                log.warn("Loading a data stream with type different from what is set globally. Expect precision loss");
                if (this.globalType == DataBuffer.Type.INT) {
                    log.warn("Int to float/double widening UNSUPPORTED!!!");
                }
            }
            if (t == DataBuffer.Type.INT || this.globalType == DataBuffer.Type.INT) {
                this.elementSize = 4;
                this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(this.length, this.elementSize), false);
                this.trackingPoint = this.allocationPoint.getObjectId();
                this.pointer = new CudaPointer(this.allocationPoint.getPointers().getHostPointer(), this.length).asIntPointer();
                this.indexer = IntIndexer.create((IntPointer)((IntPointer)this.pointer));
                int[] array = new int[(int)this.length];
                int i = 0;
                while ((long)i < this.length()) {
                    if (t == DataBuffer.Type.INT) {
                        array[i] = s.readInt();
                    } else if (t == DataBuffer.Type.DOUBLE) {
                        array[i] = (int)s.readDouble();
                    } else if (t == DataBuffer.Type.FLOAT) {
                        array[i] = (int)s.readFloat();
                    }
                    ++i;
                }
                this.setData(array);
            } else if (this.globalType == DataBuffer.Type.DOUBLE) {
                this.elementSize = 8;
                this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(this.length, this.elementSize), false);
                this.trackingPoint = this.allocationPoint.getObjectId();
                this.pointer = new CudaPointer(this.allocationPoint.getPointers().getHostPointer(), this.length).asDoublePointer();
                this.indexer = DoubleIndexer.create((DoublePointer)((DoublePointer)this.pointer));
                double[] array = new double[(int)this.length];
                int i = 0;
                while ((long)i < this.length()) {
                    if (t == DataBuffer.Type.INT) {
                        array[i] = s.readInt();
                    } else if (t == DataBuffer.Type.DOUBLE) {
                        array[i] = s.readDouble();
                    } else if (t == DataBuffer.Type.FLOAT) {
                        array[i] = s.readFloat();
                    }
                    ++i;
                }
                this.setData(array);
            } else if (this.globalType == DataBuffer.Type.FLOAT) {
                this.elementSize = 4;
                this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(this.length, this.elementSize), false);
                this.trackingPoint = this.allocationPoint.getObjectId();
                this.pointer = new CudaPointer(this.allocationPoint.getPointers().getHostPointer(), this.length).asFloatPointer();
                this.indexer = FloatIndexer.create((FloatPointer)((FloatPointer)this.pointer));
                float[] array = new float[(int)this.length];
                int i = 0;
                while ((long)i < this.length()) {
                    if (t == DataBuffer.Type.INT) {
                        array[i] = s.readInt();
                    } else if (t == DataBuffer.Type.DOUBLE) {
                        array[i] = (float)s.readDouble();
                    } else if (t == DataBuffer.Type.FLOAT) {
                        array[i] = s.readFloat();
                    }
                    ++i;
                }
                this.setData(array);
            } else {
                throw new IllegalStateException("Unknown dataType: [" + t.toString() + "]");
            }
            this.wrappedBuffer = this.pointer.asByteBuffer();
            this.wrappedBuffer.order(ByteOrder.nativeOrder());
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        allocator.synchronizeHostData(this);
    }

    public byte[] asBytes() {
        allocator.synchronizeHostData(this);
        return super.asBytes();
    }

    public double[] asDouble() {
        allocator.synchronizeHostData(this);
        return super.asDouble();
    }

    public float[] asFloat() {
        allocator.synchronizeHostData(this);
        return super.asFloat();
    }

    public int[] asInt() {
        allocator.synchronizeHostData(this);
        return super.asInt();
    }

    public ByteBuffer asNio() {
        allocator.synchronizeHostData(this);
        return super.asNio();
    }

    public DoubleBuffer asNioDouble() {
        allocator.synchronizeHostData(this);
        return super.asNioDouble();
    }

    public FloatBuffer asNioFloat() {
        allocator.synchronizeHostData(this);
        return super.asNioFloat();
    }

    public IntBuffer asNioInt() {
        allocator.synchronizeHostData(this);
        return super.asNioInt();
    }

    public DataBuffer dup() {
        allocator.synchronizeHostData(this);
        DataBuffer buffer = this.create(this.length);
        allocator.memcpyBlocking(buffer, new CudaPointer(allocator.getHostPointer(this).address()), this.length * (long)this.elementSize, 0L);
        return buffer;
    }

    public Number getNumber(long i) {
        allocator.synchronizeHostData(this);
        return super.getNumber(i);
    }

    public double getDouble(long i) {
        allocator.synchronizeHostData(this);
        return super.getDouble(i);
    }

    public float getFloat(long i) {
        allocator.synchronizeHostData(this);
        return super.getFloat(i);
    }

    public int getInt(long ix) {
        allocator.synchronizeHostData(this);
        return super.getInt(ix);
    }

    public AllocationPoint getAllocationPoint() {
        return this.allocationPoint;
    }
}

