/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.jcublas.blas;

import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cusolver;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.linalg.api.blas.impl.BaseLapack;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class JcublasLapack
extends BaseLapack {
    private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    private Allocator allocator = AtomicAllocator.getInstance();
    private static Logger logger = LoggerFactory.getLogger(JcublasLapack.class);

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
        if (Nd4j.dataType() != DataBuffer.Type.FLOAT) {
            logger.warn("FLOAT getrf called in DOUBLE environment");
        }
        if (A.ordering() == 'c') {
            logger.warn("GPU requires arrays to be ordering ='f' (A)");
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        CudaContext ctx = (CudaContext)this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t handle = ctx.getSolverHandle();
        cusolver.cusolverDnContext solverDn = new cusolver.cusolverDnContext((Pointer)handle);
        cusolverDnHandle_t cusolverDnHandle_t2 = handle;
        synchronized (cusolverDnHandle_t2) {
            int result = cusolver.cusolverDnSetStream((cusolver.cusolverDnContext)new cusolver.cusolverDnContext((Pointer)handle), (cuda.CUstream_st)new cuda.CUstream_st((Pointer)ctx.getOldStream()));
            if (result != 0) {
                throw new IllegalStateException("solverSetStream failed");
            }
            CublasPointer xAPointer = new CublasPointer(A, ctx);
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int stat = cusolver.cusolverDnSgetrf_bufferSize((cusolver.cusolverDnContext)solverDn, (int)M, (int)N, (FloatPointer)((FloatPointer)xAPointer.getDevicePointer()), (int)M, (IntPointer)((IntPointer)worksizeBuffer.addressPointer()));
            if (stat != 0) {
                throw new IllegalStateException("cusolverDnSgetrf_bufferSize failed with code: " + stat);
            }
            int worksize = worksizeBuffer.getInt(0L);
            Workspace workspace = new Workspace(worksize * Nd4j.sizeOfDataType());
            stat = cusolver.cusolverDnSgetrf((cusolver.cusolverDnContext)solverDn, (int)M, (int)N, (FloatPointer)((FloatPointer)xAPointer.getDevicePointer()), (int)M, (FloatPointer)new CudaPointer(workspace).asFloatPointer(), (IntPointer)new CudaPointer(this.allocator.getPointer(IPIV, ctx)).asIntPointer(), (IntPointer)new CudaPointer(this.allocator.getPointer(INFO, ctx)).asIntPointer());
            if (stat != 0) {
                throw new IllegalStateException("cusolverDnSgetrf failed with code: " + stat);
            }
        }
        this.allocator.registerAction(ctx, A, new INDArray[0]);
        this.allocator.registerAction(ctx, INFO, new INDArray[0]);
        this.allocator.registerAction(ctx, IPIV, new INDArray[0]);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void dgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
        if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
            logger.warn("FLOAT getrf called in FLOAT environment");
        }
        if (A.ordering() == 'c') {
            logger.warn("GPU requires arrays to be ordering ='f' (A)");
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        CudaContext ctx = (CudaContext)this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t handle = ctx.getSolverHandle();
        cusolver.cusolverDnContext solverDn = new cusolver.cusolverDnContext((Pointer)handle);
        cusolverDnHandle_t cusolverDnHandle_t2 = handle;
        synchronized (cusolverDnHandle_t2) {
            int result = cusolver.cusolverDnSetStream((cusolver.cusolverDnContext)new cusolver.cusolverDnContext((Pointer)handle), (cuda.CUstream_st)new cuda.CUstream_st((Pointer)ctx.getOldStream()));
            if (result != 0) {
                throw new IllegalStateException("solverSetStream failed");
            }
            CublasPointer xAPointer = new CublasPointer(A, ctx);
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int stat = cusolver.cusolverDnDgetrf_bufferSize((cusolver.cusolverDnContext)solverDn, (int)M, (int)N, (DoublePointer)((DoublePointer)xAPointer.getDevicePointer()), (int)M, (IntPointer)((IntPointer)worksizeBuffer.addressPointer()));
            if (stat != 0) {
                throw new IllegalStateException("cusolverDnDgetrf_bufferSize failed with code: " + stat);
            }
            int worksize = worksizeBuffer.getInt(0L);
            Workspace workspace = new Workspace(worksize * Nd4j.sizeOfDataType());
            stat = cusolver.cusolverDnDgetrf((cusolver.cusolverDnContext)solverDn, (int)M, (int)N, (DoublePointer)((DoublePointer)xAPointer.getDevicePointer()), (int)M, (DoublePointer)new CudaPointer(workspace).asDoublePointer(), (IntPointer)new CudaPointer(this.allocator.getPointer(IPIV, ctx)).asIntPointer(), (IntPointer)new CudaPointer(this.allocator.getPointer(INFO, ctx)).asIntPointer());
            if (stat != 0) {
                throw new IllegalStateException("cusolverDnSgetrf failed with code: " + stat);
            }
        }
        this.allocator.registerAction(ctx, A, new INDArray[0]);
        this.allocator.registerAction(ctx, INFO, new INDArray[0]);
        this.allocator.registerAction(ctx, IPIV, new INDArray[0]);
    }

    public void getri(int N, INDArray A, int lda, int[] IPIV, INDArray WORK, int lwork, int INFO) {
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void sgesvd(byte jobu, byte jobvt, int M, int N, INDArray A, INDArray S, INDArray U, INDArray VT, INDArray INFO) {
        if (Nd4j.dataType() != DataBuffer.Type.FLOAT) {
            logger.warn("FLOAT gesvd called in DOUBLE environment");
        }
        if (A.ordering() == 'c') {
            logger.warn("GPU requires arrays to be ordering ='f' (A)");
        }
        if (U != null && U.ordering() == 'c') {
            logger.warn("GPU requires arrays to be ordering ='f' (U)");
        }
        if (VT != null && VT.ordering() == 'c') {
            logger.warn("GPU requires arrays to be ordering ='f' (VT)");
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        CudaContext ctx = (CudaContext)this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t handle = ctx.getSolverHandle();
        cusolver.cusolverDnContext solverDn = new cusolver.cusolverDnContext((Pointer)handle);
        cusolverDnHandle_t cusolverDnHandle_t2 = handle;
        synchronized (cusolverDnHandle_t2) {
            int result = cusolver.cusolverDnSetStream((cusolver.cusolverDnContext)new cusolver.cusolverDnContext((Pointer)handle), (cuda.CUstream_st)new cuda.CUstream_st((Pointer)ctx.getOldStream()));
            if (result != 0) {
                throw new IllegalStateException("solverSetStream failed");
            }
            CublasPointer xAPointer = new CublasPointer(A, ctx);
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int stat = cusolver.cusolverDnSgesvd_bufferSize((cusolver.cusolverDnContext)solverDn, (int)M, (int)N, (IntPointer)((IntPointer)worksizeBuffer.addressPointer()));
            if (stat != 0) {
                throw new IllegalStateException("cusolverDnSgesvd_bufferSize failed with code: " + stat);
            }
            int worksize = worksizeBuffer.getInt(0L);
            Workspace workspace = new Workspace(worksize * Nd4j.sizeOfDataType());
            DataBuffer rwork = Nd4j.getDataBufferFactory().createFloat((long)((M < N ? M : N) - 1));
            stat = cusolver.cusolverDnSgesvd((cusolver.cusolverDnContext)solverDn, (byte)jobu, (byte)jobvt, (int)M, (int)N, (FloatPointer)((FloatPointer)xAPointer.getDevicePointer()), (int)M, (FloatPointer)new CudaPointer(this.allocator.getPointer(S, ctx)).asFloatPointer(), (FloatPointer)(U == null ? null : new CudaPointer(this.allocator.getPointer(U, ctx)).asFloatPointer()), (int)M, (FloatPointer)(VT == null ? null : new CudaPointer(this.allocator.getPointer(VT, ctx)).asFloatPointer()), (int)N, (FloatPointer)new CudaPointer(workspace).asFloatPointer(), (int)worksize, (FloatPointer)new CudaPointer(this.allocator.getPointer(rwork, ctx)).asFloatPointer(), (IntPointer)new CudaPointer(this.allocator.getPointer(INFO, ctx)).asIntPointer());
            if (stat != 0) {
                throw new IllegalStateException("cusolverDnSgesvd failed with code: " + stat);
            }
        }
        this.allocator.registerAction(ctx, INFO, new INDArray[0]);
        this.allocator.registerAction(ctx, S, new INDArray[0]);
        if (U != null) {
            this.allocator.registerAction(ctx, U, new INDArray[0]);
        }
        if (VT != null) {
            this.allocator.registerAction(ctx, VT, new INDArray[0]);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void dgesvd(byte jobu, byte jobvt, int M, int N, INDArray A, INDArray S, INDArray U, INDArray VT, INDArray INFO) {
        if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
            logger.warn("DOUBLE gesvd called in FLOAT environment");
        }
        if (A.ordering() == 'c') {
            logger.warn("GPU requires arrays to be in fortran - ordering ='f' (A)");
        }
        if (U != null && U.ordering() == 'c') {
            logger.warn("GPU requires arrays to be in fortran - ordering ='f' (U)");
        }
        if (VT != null && VT.ordering() == 'c') {
            logger.warn("GPU requires arrays to be in fortran - ordering ='f' (VT)");
        }
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner)Nd4j.getExecutioner()).flushQueue();
        }
        CudaContext ctx = (CudaContext)this.allocator.getDeviceContext().getContext();
        cusolverDnHandle_t handle = ctx.getSolverHandle();
        cusolver.cusolverDnContext solverDn = new cusolver.cusolverDnContext((Pointer)handle);
        cusolverDnHandle_t cusolverDnHandle_t2 = handle;
        synchronized (cusolverDnHandle_t2) {
            int result = cusolver.cusolverDnSetStream((cusolver.cusolverDnContext)new cusolver.cusolverDnContext((Pointer)handle), (cuda.CUstream_st)new cuda.CUstream_st((Pointer)ctx.getOldStream()));
            if (result != 0) {
                throw new IllegalStateException("solverSetStream failed");
            }
            CublasPointer xAPointer = new CublasPointer(A, ctx);
            DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1L);
            int stat = cusolver.cusolverDnSgesvd_bufferSize((cusolver.cusolverDnContext)solverDn, (int)M, (int)N, (IntPointer)((IntPointer)worksizeBuffer.addressPointer()));
            if (stat != 0) {
                throw new IllegalStateException("cusolverDnSgesvd_bufferSize failed with code: " + stat);
            }
            int worksize = worksizeBuffer.getInt(0L);
            Workspace workspace = new Workspace(worksize * Nd4j.sizeOfDataType());
            DataBuffer rwork = Nd4j.getDataBufferFactory().createDouble((long)((M < N ? M : N) - 1));
            stat = cusolver.cusolverDnDgesvd((cusolver.cusolverDnContext)solverDn, (byte)jobu, (byte)jobvt, (int)M, (int)N, (DoublePointer)((DoublePointer)xAPointer.getDevicePointer()), (int)M, (DoublePointer)new CudaPointer(this.allocator.getPointer(S, ctx)).asDoublePointer(), (DoublePointer)(U == null ? null : new CudaPointer(this.allocator.getPointer(U, ctx)).asDoublePointer()), (int)M, (DoublePointer)(VT == null ? null : new CudaPointer(this.allocator.getPointer(VT, ctx)).asDoublePointer()), (int)N, (DoublePointer)new CudaPointer(workspace).asDoublePointer(), (int)worksize, (DoublePointer)new CudaPointer(this.allocator.getPointer(rwork, ctx)).asDoublePointer(), (IntPointer)new CudaPointer(this.allocator.getPointer(INFO, ctx)).asIntPointer());
            if (stat != 0) {
                throw new IllegalStateException("cusolverDnDgesvd failed with code: " + stat);
            }
        }
        this.allocator.registerAction(ctx, INFO, new INDArray[0]);
        this.allocator.registerAction(ctx, S, new INDArray[0]);
        this.allocator.registerAction(ctx, A, new INDArray[0]);
        if (U != null) {
            this.allocator.registerAction(ctx, U, new INDArray[0]);
        }
        if (VT != null) {
            this.allocator.registerAction(ctx, VT, new INDArray[0]);
        }
    }

    static class Workspace
    extends Pointer {
        public Workspace(long size) {
            super(NativeOpsHolder.getInstance().getDeviceNativeOps().mallocDevice(size, null, 0));
            this.deallocator(new Pointer.Deallocator(){

                public void deallocate() {
                    NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice((Pointer)Workspace.this, null);
                }
            });
        }
    }
}

