/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.shape;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.StridePermutation;
import org.nd4j.linalg.api.shape.loop.coordinatefunction.CoordinateFunction;
import org.nd4j.linalg.api.shape.loop.one.RawArrayIterationInformation1;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.ShapeOffsetResolution;
import org.nd4j.linalg.util.ArrayUtil;

public class Shape {
    /**
     * Not instantiable: every member of this class is a static helper.
     */
    private Shape() {
        // intentionally empty
    }

    /**
     * Replaces a single negative entry of {@code newShape} with its resolved (positive) size.
     *
     * <p>The resolved value is {@code |prod(newShape) / prod(entries >= 1)|}. For a {@code -1} entry
     * (and no zeros) this is {@code 1}; for {@code -k} it is {@code k}. The input array is never
     * modified; a copy is always returned.
     *
     * @param newShape requested shape, possibly containing one negative entry
     * @return a copy of {@code newShape} with the negative entry resolved
     * @throws IllegalArgumentException if more than one entry is negative
     */
    public static int[] resolveNegativeShapeIfNeccessary(int[] newShape) {
        // Count negatives up front. The previous version broke out of its loop after the first
        // negative, so its "only one" guard was unreachable and extra negatives were silently kept.
        int numberNegativesOnes = 0;
        int negativeIndex = -1;
        for (int i = 0; i < newShape.length; ++i) {
            if (newShape[i] < 0) {
                if (++numberNegativesOnes > 1) {
                    throw new IllegalArgumentException("Only one dimension can be negative ones");
                }
                negativeIndex = i;
            }
        }
        int[] shape = newShape.clone();
        if (negativeIndex < 0) {
            return shape;
        }
        int shapeLength = 1;
        int fullProduct = 1;
        for (int dim : newShape) {
            if (dim >= 1) {
                shapeLength *= dim;
            }
            fullProduct *= dim;
        }
        shape[negativeIndex] = Math.abs(fullProduct / shapeLength);
        return shape;
    }

    /**
     * Reports whether {@code dimension} selects the entire array.
     *
     * <p>That is the case when no dimensions are given ({@code null}), when the single sentinel
     * {@code Integer.MAX_VALUE} is given, or when there is one entry per axis of {@code shape}.
     *
     * @param shape array shape
     * @param dimension requested dimensions
     * @return {@code true} if the request covers the whole array
     */
    public static boolean isWholeArray(int[] shape, int ... dimension) {
        if (dimension == null) {
            return true;
        }
        boolean sentinel = dimension.length == 1 && dimension[0] == Integer.MAX_VALUE;
        return sentinel || dimension.length == shape.length;
    }

    /**
     * Computes the result shape of reducing {@code wholeShape} along {@code dimensions}.
     *
     * <p>A whole-array reduction yields {@code [1, 1]}. A single-axis reduction of a matrix keeps rank 2:
     * reducing dimension 0 collapses the rows ({@code [1, cols]}), and reducing dimension 1 collapses
     * the columns ({@code [rows, 1]}). Any other case removes the reduced axes.
     *
     * @param wholeShape shape of the input array
     * @param dimensions dimensions being reduced
     * @return the reduced shape
     */
    public static int[] getReducedShape(int[] wholeShape, int[] dimensions) {
        if (Shape.isWholeArray(wholeShape, dimensions)) {
            return new int[]{1, 1};
        }
        if (dimensions.length == 1 && wholeShape.length == 2) {
            int[] ret = new int[2];
            // The branches were previously swapped (dim 0 -> [rows, 1]), which contradicted the
            // generic removeIndex path below: that path maps dim 0 of [r, c] to [c].
            if (dimensions[0] == 0) {
                ret[0] = 1;
                ret[1] = wholeShape[1];
            } else if (dimensions[0] == 1) {
                ret[0] = wholeShape[0];
                ret[1] = 1;
            }
            return ret;
        }
        return ArrayUtil.removeIndex((int[])wholeShape, (int[])dimensions);
    }

    /**
     * Returns the shape of {@code left x right} for two rank-2 operands.
     *
     * @param left shape of the left operand, {@code [m, k]}
     * @param right shape of the right operand, {@code [k, n]}
     * @return {@code [m, n]}
     * @throws IllegalArgumentException if either operand is not rank 2, or the inner dimensions differ
     */
    public static int[] getMatrixMultiplyShape(int[] left, int[] right) {
        // Both operands must be matrices. The previous "&&" only rejected the case where BOTH were wrong.
        if (left.length != 2 || right.length != 2) {
            throw new IllegalArgumentException("Illegal shapes for matrix multiply. Must be of length 2");
        }
        if (left[1] != right[0]) {
            throw new IllegalArgumentException("Columns of left not equal to rows of right");
        }
        return new int[]{left[0], right[1]};
    }

    /**
     * Returns {@code arr} itself when the guard below holds; otherwise returns a newly created array with
     * the same shape, filled from {@code arr}.
     *
     * <p>The copy strategy depends on the input:
     * <ul>
     *   <li>row vectors: element by element, by linear index;</li>
     *   <li>complex arrays: slice by slice;</li>
     *   <li>everything else: via {@code assign}, keeping {@code arr.ordering()}.</li>
     * </ul>
     *
     * @param arr the array to normalise
     * @return {@code arr} unchanged, or a fresh copy
     */
    public static INDArray toOffsetZero(INDArray arr) {
        // Return as-is when the buffer/offset condition AND the ordering/stride condition both hold.
        // Buffer/offset condition: (offset < 1 and data length == length) OR (complex and data length == 2 * length).
        //   Note that the complex alternative does not check the offset.
        // Ordering/stride condition: ('f' and stride(-1) != elementStride) OR ('c' and stride(0) != elementStride).
        // NOTE(review): the intent of the stride test is not evident from this file; confirm before changing it.
        if ((arr.offset() < 1L && arr.data().length() == (long)arr.length() || arr instanceof IComplexNDArray && (long)(arr.length() * 2) == arr.data().length()) && (arr.ordering() == 'f' && arr.stride(-1) != arr.elementStride() || arr.ordering() == 'c' && arr.stride(0) != arr.elementStride())) {
            return arr;
        }
        if (arr.isRowVector()) {
            // Row vectors are copied element by element through their linear index.
            if (arr instanceof IComplexNDArray) {
                IComplexNDArray ret = Nd4j.createComplex(arr.shape());
                for (int i = 0; i < ret.length(); ++i) {
                    ret.putScalar(i, ((IComplexNDArray)arr).getComplex(i));
                }
                return ret;
            }
            INDArray ret = Nd4j.create(arr.shape());
            for (int i = 0; i < ret.length(); ++i) {
                ret.putScalar(i, arr.getDouble(i));
            }
            return ret;
        }
        if (arr instanceof IComplexNDArray) {
            // Complex arrays are copied one slice at a time.
            IComplexNDArray ret = Nd4j.createComplex(arr.shape());
            for (int i = 0; i < ret.slices(); ++i) {
                ret.putSlice(i, arr.slice(i));
            }
            return ret;
        }
        // General real case: same shape and ordering as the source, filled via assign.
        INDArray ret = Nd4j.create(arr.shape(), arr.ordering());
        ret.assign(arr);
        return ret;
    }

    /**
     * Copies {@code arr} into a new offset-zero array using the global default order ({@code Nd4j.order()}).
     *
     * @param arr source array
     * @return a fresh copy
     */
    public static INDArray toOffsetZeroCopy(INDArray arr) {
        final char defaultOrder = Nd4j.order().charValue();
        return Shape.toOffsetZeroCopyHelper(arr, defaultOrder, false);
    }

    /**
     * Copies {@code arr} into a new offset-zero array laid out in the requested order.
     *
     * @param arr source array
     * @param order requested output order ({@code 'a'} falls back to the default order for real arrays)
     * @return a fresh copy
     */
    public static INDArray toOffsetZeroCopy(INDArray arr, char order) {
        final boolean honourSourceOrder = false;
        return Shape.toOffsetZeroCopyHelper(arr, order, honourSourceOrder);
    }

    /**
     * Copies {@code arr} into a new offset-zero array, keeping the source ordering for real arrays.
     *
     * @param arr source array
     * @return a fresh copy
     */
    public static INDArray toOffsetZeroCopyAnyOrder(INDArray arr) {
        final char fallbackOrder = Nd4j.order().charValue();
        return Shape.toOffsetZeroCopyHelper(arr, fallbackOrder, true);
    }

    /**
     * Shared implementation of the {@code toOffsetZeroCopy*} family.
     *
     * <p>Complex arrays are always copied into a new array created with {@code order}: element-wise for
     * row vectors, slice-wise otherwise. Real arrays are assigned into an uninitialised array. That
     * array uses {@code arr.ordering()} when {@code anyOrder} is set and {@code order} otherwise; an
     * order of {@code 'a'} is replaced by the global default.
     *
     * @param arr source array
     * @param order requested output order
     * @param anyOrder whether to keep the source ordering (real arrays only)
     * @return a fresh copy of {@code arr}
     */
    private static INDArray toOffsetZeroCopyHelper(INDArray arr, char order, boolean anyOrder) {
        if (arr instanceof IComplexNDArray) {
            if (arr.isRowVector()) {
                IComplexNDArray ret = Nd4j.createComplex(arr.shape(), order);
                for (int i = 0; i < ret.length(); ++i) {
                    ret.putScalar(i, ((IComplexNDArray)arr).getComplex(i));
                }
                return ret;
            }
            IComplexNDArray ret = Nd4j.createComplex(arr.shape(), order);
            for (int i = 0; i < ret.slices(); ++i) {
                ret.putSlice(i, arr.slice(i));
            }
            return ret;
        }
        // The decompiler-introduced unused local "c" has been removed.
        char outOrder = anyOrder ? arr.ordering() : order;
        if (outOrder == 'a') {
            outOrder = Nd4j.order().charValue();
        }
        INDArray z = Nd4j.createUninitialized(arr.shape(), outOrder);
        z.assign(arr);
        return z;
    }

    /**
     * Reads the element at {@code indices} straight from the array's data buffer.
     *
     * @param arr array to read
     * @param indices one index per dimension
     * @return the element as a double
     */
    public static double getDouble(INDArray arr, int ... indices) {
        final long linearOffset = Shape.getOffset(arr.shapeInfo(), indices);
        final DataBuffer backing = arr.data();
        return backing.getDouble(linearOffset);
    }

    /**
     * Invokes {@code coordinateFunction} once for every coordinate of {@code arr}.
     *
     * @param arr array whose shape defines the coordinate space
     * @param coordinateFunction callback receiving each coordinate
     */
    public static void iterate(INDArray arr, CoordinateFunction coordinateFunction) {
        final int rank = arr.rank();
        final int[] extents = arr.shape();
        Shape.iterate(0, rank, extents, new int[rank], coordinateFunction);
    }

    /**
     * Invokes {@code coordinateFunction} for paired coordinates of {@code arr} and {@code arr2}.
     *
     * @param arr first array
     * @param arr2 second array
     * @param coordinateFunction callback receiving both coordinates
     */
    public static void iterate(INDArray arr, INDArray arr2, CoordinateFunction coordinateFunction) {
        final int firstRank = arr.rank();
        final int[] firstShape = arr.shape();
        final int secondRank = arr2.rank();
        final int[] secondShape = arr2.shape();
        Shape.iterate(0, firstRank, firstShape, new int[firstRank], 0, secondRank, secondShape, new int[secondRank], coordinateFunction);
    }

    /**
     * Recursively enumerates paired coordinates of two index spaces, one dimension of each per level.
     *
     * <p>At each level every combination {@code (i, j)} with {@code i < size[dimension]} and
     * {@code j < size2[dimension2]} is written into {@code res}/{@code res2} before recursing. Once
     * {@code dimension >= n} or {@code dimension2 >= n2}, {@code func} is called with the buffers.
     *
     * @param dimension current level in the first space
     * @param n number of levels of the first space
     * @param size extents of the first space
     * @param res coordinate buffer for the first space (mutated)
     * @param dimension2 current level in the second space
     * @param n2 number of levels of the second space
     * @param size2 extents of the second space
     * @param res2 coordinate buffer for the second space (mutated)
     * @param func callback receiving both coordinate buffers
     */
    public static void iterate(int dimension, int n, int[] size, int[] res, int dimension2, int n2, int[] size2, int[] res2, CoordinateFunction func) {
        if (dimension >= n || dimension2 >= n2) {
            func.process(res, res2);
            return;
        }
        // Bounds-check both levels before indexing. The equal-length branch used to read
        // size2[dimension2] before testing dimension2 < size2.length. With that check ordered
        // correctly, the equal- and different-length branches behave identically, so they are merged.
        if (dimension >= size.length || dimension2 >= size2.length) {
            return;
        }
        for (int i = 0; i < size[dimension]; ++i) {
            for (int j = 0; j < size2[dimension2]; ++j) {
                res[dimension] = i;
                res2[dimension2] = j;
                Shape.iterate(dimension + 1, n, size, res, dimension2 + 1, n2, size2, res2, func);
            }
        }
    }

    /**
     * Recursively enumerates every coordinate below {@code size}, writing it into {@code res}.
     *
     * @param dimension current level
     * @param n total number of levels
     * @param size extents per level
     * @param res coordinate buffer (mutated)
     * @param func callback invoked with {@code res} once all levels are fixed
     */
    public static void iterate(int dimension, int n, int[] size, int[] res, CoordinateFunction func) {
        if (dimension < n) {
            for (int coord = 0; coord < size[dimension]; coord++) {
                res[dimension] = coord;
                Shape.iterate(dimension + 1, n, size, res, func);
            }
            return;
        }
        func.process(new int[][]{res});
    }

    /**
     * Computes the linear buffer offset of {@code indices} given explicit shape and stride.
     *
     * <p>Unit-length dimensions contribute nothing.
     *
     * @param baseOffset offset added to the result
     * @param shape extents per dimension
     * @param stride strides per dimension
     * @param indices one index per dimension
     * @return {@code baseOffset + sum(indices[i] * stride[i])} over non-unit dimensions
     * @throws IllegalArgumentException on length mismatch or an index {@code >=} its extent
     */
    public static long getOffset(long baseOffset, int[] shape, int[] stride, int ... indices) {
        if (shape.length != stride.length || indices.length != shape.length) {
            throw new IllegalArgumentException("Indexes, shape, and stride must be the same length");
        }
        long offset = baseOffset;
        for (int i = 0; i < shape.length; ++i) {
            if (indices[i] >= shape[i]) {
                throw new IllegalArgumentException(String.format("Index [%d] must not be >= shape[%d]=%d.", i, i, shape[i]));
            }
            if (shape[i] == 1) {
                continue;
            }
            // Widen before multiplying: an int product overflows on large arrays.
            offset += (long) indices[i] * stride[i];
        }
        return offset;
    }

    /**
     * Computes the linear offset of {@code indices} using the shape/stride stored in {@code shapeInformation}.
     *
     * @param shapeInformation shape-info buffer
     * @param indices one index per dimension
     * @return the offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException on rank mismatch or an index {@code >=} its extent
     */
    public static long getOffset(IntBuffer shapeInformation, int ... indices) {
        int rank = Shape.rank(shapeInformation);
        if (indices.length != rank) {
            throw new IllegalArgumentException("Indexes must be same length as array rank");
        }
        long offset = 0L;
        for (int i = 0; i < rank; ++i) {
            int size_dimi = Shape.size(shapeInformation, i);
            if (indices[i] >= size_dimi) {
                throw new IllegalArgumentException(String.format("Index [%d] must not be >= shape[%d]=%d.", i, i, size_dimi));
            }
            if (size_dimi == 1) {
                continue;
            }
            // Widen before multiplying to avoid int overflow.
            offset += (long) indices[i] * Shape.stride(shapeInformation, i);
        }
        return offset;
    }

    /**
     * Computes the linear offset of {@code indices} using the shape/stride stored in {@code shapeInformation}.
     *
     * @param shapeInformation shape-info buffer
     * @param indices one index per dimension
     * @return the offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException on rank mismatch or an index {@code >=} its extent
     */
    public static long getOffset(DataBuffer shapeInformation, int ... indices) {
        int rank = Shape.rank(shapeInformation);
        if (indices.length != rank) {
            throw new IllegalArgumentException("Indexes must be same length as array rank");
        }
        long offset = 0L;
        for (int i = 0; i < rank; ++i) {
            int size_dimi = Shape.size(shapeInformation, i);
            // ">=" (was ">"): an index equal to the extent is already out of range. This matches the
            // message and the IntBuffer overload.
            if (indices[i] >= size_dimi) {
                throw new IllegalArgumentException(String.format("Index [%d] must not be >= shape[%d]=%d.", i, i, size_dimi));
            }
            if (size_dimi == 1) {
                continue;
            }
            offset += (long) indices[i] * Shape.stride(shapeInformation, i);
        }
        return offset;
    }

    /**
     * Computes the linear offset of {@code indices} using the shape/stride stored in {@code shapeInformation}.
     *
     * @param shapeInformation shape-info array
     * @param indices one index per dimension
     * @return the offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException on rank mismatch or an index {@code >=} its extent
     */
    public static long getOffset(int[] shapeInformation, int ... indices) {
        int rank = Shape.rank(shapeInformation);
        if (indices.length != rank) {
            throw new IllegalArgumentException("Indexes must be same length as array rank");
        }
        long offset = 0L;
        for (int i = 0; i < rank; ++i) {
            int size_dimi = Shape.size(shapeInformation, i);
            // ">=" (was ">"): an index equal to the extent is already out of range.
            if (indices[i] >= size_dimi) {
                throw new IllegalArgumentException(String.format("Index [%d] must not be >= shape[%d]=%d.", i, i, size_dimi));
            }
            if (size_dimi == 1) {
                continue;
            }
            offset += (long) indices[i] * Shape.stride(shapeInformation, i);
        }
        return offset;
    }

    /**
     * Rank-checked 2-D offset lookup; delegates to {@code getOffsetUnsafe} once rank 2 is confirmed.
     *
     * @param shapeInformation shape-info buffer
     * @param row row index
     * @param col column index
     * @return linear offset
     * @throws IllegalArgumentException if the rank is not 2
     */
    public static long getOffset(DataBuffer shapeInformation, int row, int col) {
        final int actualRank = Shape.rank(shapeInformation);
        if (actualRank == 2) {
            return Shape.getOffsetUnsafe(shapeInformation, row, col);
        }
        throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 2 (rank is: " + actualRank + ")");
    }

    /**
     * 2-D offset lookup without a rank check; the caller guarantees rank 2.
     *
     * @param shapeInformation shape-info buffer
     * @param row row index
     * @param col column index
     * @return linear offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException if an index is out of range
     */
    public static long getOffsetUnsafe(DataBuffer shapeInformation, int row, int col) {
        long offset = 0L;
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        if (row >= size_0 || col >= size_1) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + row + "," + col + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying to avoid int overflow.
        if (size_0 != 1) {
            offset += (long) row * Shape.strideUnsafe(shapeInformation, 0, 2);
        }
        if (size_1 != 1) {
            offset += (long) col * Shape.strideUnsafe(shapeInformation, 1, 2);
        }
        return offset;
    }

    /**
     * 2-D offset lookup on a shape-info array without a rank check; the caller guarantees rank 2.
     *
     * @param shapeInformation shape-info array
     * @param row row index
     * @param col column index
     * @return linear offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException if an index is out of range
     */
    public static long getOffsetUnsafe(int[] shapeInformation, int row, int col) {
        long offset = 0L;
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        if (row >= size_0 || col >= size_1) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + row + "," + col + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying to avoid int overflow.
        if (size_0 != 1) {
            offset += (long) row * Shape.strideUnsafe(shapeInformation, 0, 2);
        }
        if (size_1 != 1) {
            offset += (long) col * Shape.strideUnsafe(shapeInformation, 1, 2);
        }
        return offset;
    }

    /**
     * Rank-checked 2-D offset lookup on an {@link IntBuffer} shape-info.
     *
     * @param shapeInformation shape-info buffer
     * @param row row index
     * @param col column index
     * @return linear offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException if the rank is not 2 or an index is out of range
     */
    public static long getOffset(IntBuffer shapeInformation, int row, int col) {
        int rank = Shape.rank(shapeInformation);
        if (rank != 2) {
            throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 2 (rank is: " + rank + ")");
        }
        long offset = 0L;
        int size_0 = Shape.size(shapeInformation, 0);
        int size_1 = Shape.size(shapeInformation, 1);
        if (row >= size_0 || col >= size_1) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + row + "," + col + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying to avoid int overflow.
        if (size_0 != 1) {
            offset += (long) row * Shape.stride(shapeInformation, 0);
        }
        if (size_1 != 1) {
            offset += (long) col * Shape.stride(shapeInformation, 1);
        }
        return offset;
    }

    /**
     * Rank-checked 3-D offset lookup on an {@link IntBuffer} shape-info.
     *
     * @param shapeInformation shape-info buffer
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @return linear offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException if the rank is not 3 or an index is out of range
     */
    public static long getOffset(IntBuffer shapeInformation, int dim0, int dim1, int dim2) {
        int rank = Shape.rank(shapeInformation);
        if (rank != 3) {
            throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 3 (rank is: " + rank + ")");
        }
        long offset = 0L;
        int size_0 = Shape.size(shapeInformation, 0);
        int size_1 = Shape.size(shapeInformation, 1);
        int size_2 = Shape.size(shapeInformation, 2);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying to avoid int overflow.
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.stride(shapeInformation, 0);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.stride(shapeInformation, 1);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.stride(shapeInformation, 2);
        }
        return offset;
    }

    /**
     * Rank-checked 3-D offset lookup; delegates to {@code getOffsetUnsafe} once rank 3 is confirmed.
     *
     * @param shapeInformation shape-info buffer
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @return linear offset
     * @throws IllegalArgumentException if the rank is not 3
     */
    public static long getOffset(DataBuffer shapeInformation, int dim0, int dim1, int dim2) {
        final int actualRank = Shape.rank(shapeInformation);
        if (actualRank == 3) {
            return Shape.getOffsetUnsafe(shapeInformation, dim0, dim1, dim2);
        }
        throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 3 (rank is: " + actualRank + ")");
    }

    /**
     * 3-D offset lookup without a rank check; the caller guarantees rank 3.
     *
     * @param shapeInformation shape-info buffer
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @return linear offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException if an index is out of range
     */
    public static long getOffsetUnsafe(DataBuffer shapeInformation, int dim0, int dim1, int dim2) {
        long offset = 0L;
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        int size_2 = Shape.sizeUnsafe(shapeInformation, 2);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying to avoid int overflow.
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.strideUnsafe(shapeInformation, 0, 3);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.strideUnsafe(shapeInformation, 1, 3);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.strideUnsafe(shapeInformation, 2, 3);
        }
        return offset;
    }

    /**
     * 3-D offset lookup on a shape-info array without a rank check; the caller guarantees rank 3.
     *
     * @param shapeInformation shape-info array
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @return linear offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException if an index is out of range
     */
    public static long getOffsetUnsafe(int[] shapeInformation, int dim0, int dim1, int dim2) {
        // long accumulator (was int): the sum overflowed before being widened on return.
        long offset = 0L;
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        int size_2 = Shape.sizeUnsafe(shapeInformation, 2);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2) {
            // Report the logical shape, like the sibling overloads, rather than the raw shape-info array.
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.strideUnsafe(shapeInformation, 0, 3);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.strideUnsafe(shapeInformation, 1, 3);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.strideUnsafe(shapeInformation, 2, 3);
        }
        return offset;
    }

    /**
     * Rank-checked 4-D offset lookup on an {@link IntBuffer} shape-info.
     *
     * @param shapeInformation shape-info buffer
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @param dim3 index along dimension 3
     * @return linear offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException if the rank is not 4 or an index is out of range
     */
    public static long getOffset(IntBuffer shapeInformation, int dim0, int dim1, int dim2, int dim3) {
        int rank = Shape.rank(shapeInformation);
        if (rank != 4) {
            throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 4 (rank is: " + rank + ")");
        }
        long offset = 0L;
        int size_0 = Shape.size(shapeInformation, 0);
        int size_1 = Shape.size(shapeInformation, 1);
        int size_2 = Shape.size(shapeInformation, 2);
        int size_3 = Shape.size(shapeInformation, 3);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2 || dim3 >= size_3) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "," + dim3 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying to avoid int overflow.
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.stride(shapeInformation, 0);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.stride(shapeInformation, 1);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.stride(shapeInformation, 2);
        }
        if (size_3 != 1) {
            offset += (long) dim3 * Shape.stride(shapeInformation, 3);
        }
        return offset;
    }

    /**
     * Rank-checked 4-D offset lookup; delegates to {@code getOffsetUnsafe} once rank 4 is confirmed.
     *
     * @param shapeInformation shape-info buffer
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @param dim3 index along dimension 3
     * @return linear offset
     * @throws IllegalArgumentException if the rank is not 4
     */
    public static long getOffset(DataBuffer shapeInformation, int dim0, int dim1, int dim2, int dim3) {
        final int actualRank = Shape.rank(shapeInformation);
        if (actualRank == 4) {
            return Shape.getOffsetUnsafe(shapeInformation, dim0, dim1, dim2, dim3);
        }
        throw new IllegalArgumentException("Cannot use this getOffset method on arrays of rank != 4 (rank is: " + actualRank + ")");
    }

    /**
     * 4-D offset lookup without a rank check; the caller guarantees rank 4.
     *
     * @param shapeInformation shape-info buffer
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @param dim3 index along dimension 3
     * @return linear offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException if an index is out of range
     */
    public static long getOffsetUnsafe(DataBuffer shapeInformation, int dim0, int dim1, int dim2, int dim3) {
        long offset = 0L;
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        int size_2 = Shape.sizeUnsafe(shapeInformation, 2);
        int size_3 = Shape.sizeUnsafe(shapeInformation, 3);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2 || dim3 >= size_3) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "," + dim3 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying to avoid int overflow.
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.strideUnsafe(shapeInformation, 0, 4);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.strideUnsafe(shapeInformation, 1, 4);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.strideUnsafe(shapeInformation, 2, 4);
        }
        if (size_3 != 1) {
            offset += (long) dim3 * Shape.strideUnsafe(shapeInformation, 3, 4);
        }
        return offset;
    }

    /**
     * 4-D offset lookup on a shape-info array without a rank check; the caller guarantees rank 4.
     *
     * @param shapeInformation shape-info array
     * @param dim0 index along dimension 0
     * @param dim1 index along dimension 1
     * @param dim2 index along dimension 2
     * @param dim3 index along dimension 3
     * @return linear offset (unit-length dimensions contribute nothing)
     * @throws IllegalArgumentException if an index is out of range
     */
    public static long getOffsetUnsafe(int[] shapeInformation, int dim0, int dim1, int dim2, int dim3) {
        long offset = 0L;
        int size_0 = Shape.sizeUnsafe(shapeInformation, 0);
        int size_1 = Shape.sizeUnsafe(shapeInformation, 1);
        int size_2 = Shape.sizeUnsafe(shapeInformation, 2);
        int size_3 = Shape.sizeUnsafe(shapeInformation, 3);
        if (dim0 >= size_0 || dim1 >= size_1 || dim2 >= size_2 || dim3 >= size_3) {
            throw new IllegalArgumentException("Invalid indices: cannot get [" + dim0 + "," + dim1 + "," + dim2 + "," + dim3 + "] from a " + Arrays.toString(Shape.shape(shapeInformation)) + " NDArray");
        }
        // Widen before multiplying to avoid int overflow.
        if (size_0 != 1) {
            offset += (long) dim0 * Shape.strideUnsafe(shapeInformation, 0, 4);
        }
        if (size_1 != 1) {
            offset += (long) dim1 * Shape.strideUnsafe(shapeInformation, 1, 4);
        }
        if (size_2 != 1) {
            offset += (long) dim2 * Shape.strideUnsafe(shapeInformation, 2, 4);
        }
        if (size_3 != 1) {
            offset += (long) dim3 * Shape.strideUnsafe(shapeInformation, 3, 4);
        }
        return offset;
    }

    /**
     * Gathers {@code shape[axes[k]]} into slot {@code k} of an array of length {@code shape.length}.
     *
     * <p>Slots beyond {@code axes.length} are left as zero.
     *
     * @param axes axis order to gather
     * @param shape source extents
     * @return the gathered extents
     */
    public static int[] sizeForAxes(int[] axes, int[] shape) {
        final int[] sizes = new int[shape.length];
        int slot = 0;
        for (int axis : axes) {
            sizes[slot++] = shape[axis];
        }
        return sizes;
    }

    /**
     * Reports whether a rank-1 or rank-2 shape-info describes a vector (one dimension spans the full length).
     *
     * @param shapeInfo shape-info buffer
     * @return {@code true} for vectors
     */
    public static boolean isVector(IntBuffer shapeInfo) {
        final int rank = Shape.rank(shapeInfo);
        if (rank < 1 || rank > 2) {
            return false;
        }
        final int totalLength = Shape.length(shapeInfo);
        final IntBuffer dims = Shape.shapeOf(shapeInfo);
        if (dims.get(0) == totalLength) {
            return true;
        }
        return dims.get(1) == totalLength;
    }

    /**
     * Reports whether a rank-1 or rank-2 shape-info describes a vector (one dimension spans the full length).
     *
     * @param shapeInfo shape-info buffer
     * @return {@code true} for vectors
     */
    public static boolean isVector(DataBuffer shapeInfo) {
        final int rank = Shape.rank(shapeInfo);
        if (rank < 1 || rank > 2) {
            return false;
        }
        final int totalLength = Shape.length(shapeInfo);
        final DataBuffer dims = Shape.shapeOf(shapeInfo);
        if (dims.getInt(0L) == totalLength) {
            return true;
        }
        return dims.getInt(1L) == totalLength;
    }

    /**
     * Reports whether a rank-1 or rank-2 shape is a vector (one dimension spans the full length).
     *
     * @param shape array extents
     * @return {@code true} for vectors
     */
    public static boolean isVector(int[] shape) {
        final boolean rankOk = shape.length >= 1 && shape.length <= 2;
        if (!rankOk) {
            return false;
        }
        final long totalLength = ArrayUtil.prodLong((int[])shape);
        if ((long)shape[0] == totalLength) {
            return true;
        }
        return (long)shape[1] == totalLength;
    }

    /**
     * Reports whether the shape-info is rank 2 and not a vector.
     *
     * @param shapeInfo shape-info buffer
     * @return {@code true} for proper matrices
     */
    public static boolean isMatrix(IntBuffer shapeInfo) {
        return Shape.rank(shapeInfo) == 2 && !Shape.isVector(shapeInfo);
    }

    /**
     * Reports whether the shape-info is rank 2 and not a vector.
     *
     * @param shapeInfo shape-info buffer
     * @return {@code true} for proper matrices
     */
    public static boolean isMatrix(DataBuffer shapeInfo) {
        return Shape.rank(shapeInfo) == 2 && !Shape.isVector(shapeInfo);
    }

    /**
     * Reports whether {@code shape} is rank 2 and not a vector.
     *
     * @param shape array extents
     * @return {@code true} for proper matrices
     */
    public static boolean isMatrix(int[] shape) {
        return shape.length == 2 && !Shape.isVector(shape);
    }

    /**
     * Drops all unit dimensions, except that a column-vector shape is returned unchanged.
     *
     * @param shape array extents
     * @return {@code shape} itself for column vectors, otherwise a new array without the 1s
     */
    public static int[] squeeze(int[] shape) {
        if (!Shape.isColumnVectorShape(shape)) {
            ArrayList<Integer> nonUnit = new ArrayList<Integer>();
            for (int extent : shape) {
                if (extent != 1) {
                    nonUnit.add(extent);
                }
            }
            return ArrayUtil.toArray(nonUnit);
        }
        return shape;
    }

    /**
     * Compares two shapes modulo unit dimensions.
     *
     * <p>Two column vectors must match exactly. Two row vectors, or any other pair, are compared after
     * squeezing. In the general case a scalar ({@code []}) also equals {@code [1]}.
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return {@code true} if the shapes are considered equal
     */
    public static boolean shapeEquals(int[] shape1, int[] shape2) {
        final boolean bothColumns = Shape.isColumnVectorShape(shape1) && Shape.isColumnVectorShape(shape2);
        if (bothColumns) {
            return Arrays.equals(shape1, shape2);
        }
        final boolean bothRows = Shape.isRowVectorShape(shape1) && Shape.isRowVectorShape(shape2);
        if (bothRows) {
            return Arrays.equals(Shape.squeeze(shape1), Shape.squeeze(shape2));
        }
        final int[] squeezedFirst = Shape.squeeze(shape1);
        final int[] squeezedSecond = Shape.squeeze(shape2);
        return Shape.scalarEquals(squeezedFirst, squeezedSecond) || Arrays.equals(squeezedFirst, squeezedSecond);
    }

    /**
     * Reports whether one shape is the empty (rank-0) shape and the other is {@code [1]}.
     *
     * @param shape1 first shape
     * @param shape2 second shape
     * @return {@code true} for the {@code []} / {@code [1]} pairing in either order
     */
    public static boolean scalarEquals(int[] shape1, int[] shape2) {
        final boolean emptyThenUnit = shape1.length == 0 && shape2.length == 1 && shape2[0] == 1;
        if (emptyThenUnit) {
            return true;
        }
        final boolean unitThenEmpty = shape2.length == 0 && shape1.length == 1 && shape1[0] == 1;
        return unitThenEmpty;
    }

    /**
     * Reports whether the shape-info is rank 1, or rank 2 with a leading dimension of 1.
     *
     * @param shapeInfo shape-info buffer
     * @return {@code true} for row-vector shapes
     */
    public static boolean isRowVectorShape(DataBuffer shapeInfo) {
        final int rank = Shape.rank(shapeInfo);
        final DataBuffer dims = Shape.shapeOf(shapeInfo);
        if (rank == 1) {
            return true;
        }
        return rank == 2 && dims.getInt(0L) == 1;
    }

    /**
     * Reports whether the shape-info is rank 1, or rank 2 with a leading dimension of 1.
     *
     * @param shapeInfo shape-info buffer
     * @return {@code true} for row-vector shapes
     */
    public static boolean isRowVectorShape(IntBuffer shapeInfo) {
        final int rank = Shape.rank(shapeInfo);
        final IntBuffer dims = Shape.shapeOf(shapeInfo);
        if (rank == 1) {
            return true;
        }
        return rank == 2 && dims.get(0) == 1;
    }

    /**
     * Reports whether {@code shape} is rank 1, or rank 2 with a leading dimension of 1.
     *
     * @param shape array extents
     * @return {@code true} for row-vector shapes
     */
    public static boolean isRowVectorShape(int[] shape) {
        if (shape.length == 1) {
            return true;
        }
        return shape.length == 2 && shape[0] == 1;
    }

    /**
     * Reports whether {@code shape} is rank 2 with a trailing dimension of 1.
     *
     * @param shape array extents
     * @return {@code true} for column-vector shapes
     */
    public static boolean isColumnVectorShape(int[] shape) {
        if (shape.length != 2) {
            return false;
        }
        return shape[1] == 1;
    }

    /**
     * Reports whether {@code shape} is rank 2 with a trailing dimension of 1.
     *
     * @param shape array extents
     * @return {@code true} for column-vector shapes
     */
    public static boolean isColumnVectorShape(long[] shape) {
        if (shape.length != 2) {
            return false;
        }
        return shape[1] == 1L;
    }

    /**
     * Builds single-array raw iteration information from {@code dst}'s offset, data, strides, rank and shape.
     *
     * @param dst array to iterate
     * @return the result of {@code computeOut()} on the built information
     */
    public static RawArrayIterationInformation1 prepareRawArrayIter(INDArray dst) {
        return RawArrayIterationInformation1.builder()
                .aOffset(dst.offset())
                .a(dst.data())
                .aStrides(dst.stride())
                .nDim(dst.rank())
                .shape(dst.shape())
                .build()
                .computeOut();
    }

    /**
     * Creates the stride permutation for {@code strides} and sorts it in natural order.
     *
     * @param strides array strides
     * @return the sorted permutation entries
     */
    public static StridePermutation[] createSortedStrides(int[] strides) {
        // Typed local (the decompiled Object[] cannot be returned as StridePermutation[] without a cast).
        StridePermutation[] perm = StridePermutation.create(strides);
        Arrays.sort(perm);
        return perm;
    }

    /**
     * Promotes a rank-1 shape {@code [n]} to {@code [1, n]}; shapes of rank 2 or more are returned unchanged.
     *
     * @param shape array extents
     * @return {@code shape} itself, or a new row-vector shape
     */
    public static int[] ensureAtMinRowVector(int ... shape) {
        if (shape.length < 2) {
            return new int[]{1, shape[0]};
        }
        return shape;
    }

    /**
     * Returns the product of {@code shape} over the given {@code dimensions}, i.e. the length of one TAD.
     *
     * @param shape array extents
     * @param dimensions dimensions spanned by the TAD
     * @return the product of the selected extents ({@code 1} when none are given)
     */
    public static long getTADLength(int[] shape, int ... dimensions) {
        // long accumulator: the former int one overflowed before being widened on return.
        long tadLength = 1L;
        for (int dimension : dimensions) {
            tadLength *= shape[dimension];
        }
        return tadLength;
    }

    /**
     * Computes the element-wise stride of an array by attempting a no-copy reshape to {@code [1, length]}.
     *
     * <p>Unit dimensions are dropped first. The remaining axes are grouped against the target shape, and
     * every group must be contiguous for the requested ordering; otherwise {@code -1} is returned.
     *
     * @param shape array extents
     * @param stride array strides
     * @param isFOrder {@code true} to check Fortran contiguity, {@code false} for C contiguity
     * @return the stride of the flattened view, or {@code -1} if no single stride exists
     * @throws IllegalArgumentException if the computed stride does not fit in an int
     */
    public static int elementWiseStride(int[] shape, int[] stride, boolean isFOrder) {
        int nk;
        int ni;
        int oi;
        int[] olddims = shape.clone();
        int[] oldstrides = stride.clone();
        int oldnd = 0;
        // The target view is always [1, length], so the scratch arrays need exactly that rank.
        // Sizing them by the input rank (as before) threw ArrayIndexOutOfBounds for rank-1 input.
        int newShapeRank = 2;
        long[] newStrides = new long[newShapeRank];
        long[] newShape = new long[newShapeRank];
        long length = 1L;
        for (int extent : shape) {
            length *= extent;
        }
        newShape[0] = 1L;
        newShape[1] = length;
        // Drop unit dimensions: they never constrain contiguity.
        for (oi = 0; oi < shape.length; ++oi) {
            if (shape[oi] == 1) {
                continue;
            }
            olddims[oldnd] = shape[oi];
            oldstrides[oldnd] = stride[oi];
            ++oldnd;
        }
        long np = 1L;
        for (ni = 0; ni < newShapeRank; ++ni) {
            np *= newShape[ni];
        }
        long op = 1L;
        for (oi = 0; oi < oldnd; ++oi) {
            op *= olddims[oi];
        }
        if (np != op || np == 0L) {
            return -1;
        }
        oi = 0;
        int oj = 1;
        ni = 0;
        int nj = 1;
        while (ni < newShapeRank && oi < oldnd) {
            np = newShape[ni];
            op = olddims[oi];
            // Grow the smaller side until both groups cover the same number of elements.
            while (np != op) {
                if (np < op) {
                    np *= newShape[nj++];
                } else {
                    op *= olddims[oj++];
                }
            }
            // Every old axis folded into this group must be contiguous with its neighbour.
            for (int ok = oi; ok < oj - 1; ++ok) {
                boolean broken = isFOrder
                        ? oldstrides[ok + 1] != olddims[ok] * oldstrides[ok]
                        : oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + 1];
                if (broken) {
                    return -1;
                }
            }
            if (isFOrder) {
                newStrides[ni] = oldstrides[oi];
                for (nk = ni + 1; nk < nj; ++nk) {
                    newStrides[nk] = newStrides[nk - 1] * newShape[nk - 1];
                }
            } else {
                newStrides[nj - 1] = oldstrides[oj - 1];
                for (nk = nj - 1; nk > ni; --nk) {
                    newStrides[nk - 1] = newStrides[nk] * newShape[nk];
                }
            }
            ni = nj++;
            oi = oj++;
        }
        // Any trailing target dimensions not covered above get the last computed stride.
        long last_stride = ni >= 1 ? newStrides[ni - 1] : (long)stride[shape.length - 1];
        if (isFOrder && ni >= 1) {
            last_stride *= newShape[ni - 1];
        }
        for (nk = ni; nk < newShapeRank; ++nk) {
            newStrides[nk] = last_stride;
        }
        if (newStrides[newShapeRank - 1] >= Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Element size can not be >= Integer.MAX_VALUE");
        }
        return (int)newStrides[newShapeRank - 1];
    }

    /**
     * Attempts to reshape {@code arr} to {@code newShape} without copying, by computing new strides over the
     * same data buffer and offset.
     *
     * <p>Unit dimensions of {@code arr} are dropped. The remaining axes are grouped against the target
     * dimensions, and each group must be contiguous for the requested ordering. NOTE(review): this looks
     * like the same grouping algorithm as NumPy's no-copy reshape attempt; confirm before relying on that.
     *
     * @param arr source array (its data buffer and offset are reused)
     * @param newShape requested shape
     * @param isFOrder {@code true} for Fortran ordering, {@code false} for C ordering
     * @return a view with the new shape, or {@code null} if the element counts differ, are zero, or a
     *         group of source axes is not contiguous
     */
    public static INDArray newShapeNoCopy(INDArray arr, int[] newShape, boolean isFOrder) {
        int nk;
        int ni;
        int oi;
        int[] olddims = ArrayUtil.copy((int[])arr.shape());
        int[] oldstrides = ArrayUtil.copy((int[])arr.stride());
        int[] newStrides = new int[newShape.length];
        int oldnd = 0;
        // Compact away unit dimensions (and their strides) of the source.
        for (oi = 0; oi < arr.rank(); ++oi) {
            if (arr.size(oi) == 1) continue;
            olddims[oldnd] = arr.size(oi);
            oldstrides[oldnd] = arr.stride(oi);
            ++oldnd;
        }
        // Element counts must match and be non-zero.
        int np = 1;
        for (ni = 0; ni < newShape.length; ++ni) {
            np *= newShape[ni];
        }
        int op = 1;
        for (oi = 0; oi < oldnd; ++oi) {
            op *= olddims[oi];
        }
        if (np != op) {
            return null;
        }
        if (np == 0) {
            return null;
        }
        oi = 0;
        int oj = 1;
        ni = 0;
        int nj = 1;
        while (ni < newShape.length && oi < oldnd) {
            np = newShape[ni];
            op = olddims[oi];
            // Extend the smaller side until new dims [ni, nj) and old dims [oi, oj) cover the same count.
            while (np != op) {
                if (np < op) {
                    np *= newShape[nj++];
                    continue;
                }
                op *= olddims[oj++];
            }
            // The old axes merged into this group must be contiguous for the chosen ordering.
            for (int ok = oi; ok < oj - 1; ++ok) {
                if (!(isFOrder ? oldstrides[ok + 1] != olddims[ok] * oldstrides[ok] : oldstrides[ok] != olddims[ok + 1] * oldstrides[ok + 1])) continue;
                return null;
            }
            // Derive strides for the new dims of this group from the group's innermost old stride.
            if (isFOrder) {
                newStrides[ni] = oldstrides[oi];
                for (nk = ni + 1; nk < nj; ++nk) {
                    newStrides[nk] = newStrides[nk - 1] * newShape[nk - 1];
                }
            } else {
                newStrides[nj - 1] = oldstrides[oj - 1];
                for (nk = nj - 1; nk > ni; --nk) {
                    newStrides[nk - 1] = newStrides[nk] * newShape[nk];
                }
            }
            ni = nj++;
            oi = oj++;
        }
        // Remaining (trailing) new dims get the last computed stride, or the element stride if none was computed.
        int last_stride = ni >= 1 ? newStrides[ni - 1] : arr.elementStride();
        if (isFOrder && ni >= 1) {
            last_stride *= newShape[ni - 1];
        }
        for (nk = ni; nk < newShape.length; ++nk) {
            newStrides[nk] = last_stride;
        }
        // Wrap the original buffer and offset with the new shape/strides (no data is copied).
        if (arr instanceof IComplexNDArray) {
            return Nd4j.createComplex(arr.data(), newShape, newStrides, arr.offset());
        }
        INDArray ret = Nd4j.create(arr.data(), newShape, newStrides, arr.offset(), isFOrder ? (char)'f' : 'c');
        return ret;
    }

    /**
     * Reports whether the strides describe a C-contiguous or a Fortran-contiguous layout.
     *
     * <p>The C check starts from stride 1 at the last axis. The Fortran check starts from
     * {@code elementStride} at the first axis. Both stop accumulating at a zero-length dimension.
     *
     * @param shape array extents
     * @param stride array strides
     * @param elementStride initial stride for the Fortran check
     * @return {@code true} if either layout matches
     */
    public static boolean cOrFortranOrder(int[] shape, int[] stride, int elementStride) {
        boolean cContiguous = true;
        int expected = 1;
        for (int axis = shape.length - 1; axis >= 0; axis--) {
            final int extent = shape[axis];
            if (stride[axis] != expected) {
                cContiguous = false;
                break;
            }
            if (extent == 0) {
                break;
            }
            expected *= extent;
        }
        boolean fContiguous = true;
        expected = elementStride;
        for (int axis = 0; axis < shape.length; axis++) {
            final int extent = shape[axis];
            if (stride[axis] != expected) {
                fContiguous = false;
            }
            if (extent == 0) {
                break;
            }
            expected *= extent;
        }
        return cContiguous || fContiguous;
    }

    /**
     * Infers the memory order from shape and strides.
     *
     * @param shape array extents
     * @param stride array strides
     * @param elementStride initial stride for the Fortran check
     * @return {@code 'a'} if both C and Fortran contiguous, {@code 'f'} if only Fortran contiguous,
     *         otherwise {@code 'c'}
     */
    public static char getOrder(int[] shape, int[] stride, int elementStride) {
        boolean cContiguous = true;
        int expected = 1;
        for (int axis = shape.length - 1; axis >= 0; axis--) {
            final int extent = shape[axis];
            if (stride[axis] != expected) {
                cContiguous = false;
                break;
            }
            if (extent == 0) {
                break;
            }
            expected *= extent;
        }
        boolean fContiguous = true;
        expected = elementStride;
        for (int axis = 0; axis < shape.length; axis++) {
            final int extent = shape[axis];
            if (stride[axis] != expected) {
                fContiguous = false;
            }
            if (extent == 0) {
                break;
            }
            expected *= extent;
        }
        if (fContiguous) {
            return cContiguous ? 'a' : 'f';
        }
        return 'c';
    }

    /**
     * Infers the memory order of an array from its shape, strides and element stride.
     *
     * @param arr the array to inspect
     * @return the result of {@link #getOrder(int[], int[], int)} for the array's layout
     */
    public static char getOrder(INDArray arr) {
        int[] dims = arr.shape();
        int[] strides = arr.stride();
        int step = arr.elementStride();
        return Shape.getOrder(dims, strides, step);
    }

    /**
     * Converts per-dimension coordinates to a linear index in column-major order: axis 0
     * varies fastest. This is the inverse of {@code ind2sub}.
     *
     * <p>The running multiplier is kept in a {@code long}. An {@code int} multiplier would
     * overflow once the product of the leading dimensions, times a coordinate, exceeds
     * {@link Integer#MAX_VALUE}, and that happens well before the {@code long} result saturates.
     *
     * @param shape   the array dimensions
     * @param indices one coordinate per dimension; must be at least as long as {@code shape}
     * @return the linear (Fortran-order) index
     */
    public static long sub2Ind(int[] shape, int[] indices) {
        long index = 0L;
        long shift = 1L;
        for (int i = 0; i < shape.length; ++i) {
            index += shift * indices[i];
            shift *= shape[i];
        }
        return index;
    }

    /**
     * Converts a linear index to per-dimension coordinates in column-major order: the last
     * dimension is resolved first, using the largest block size.
     *
     * @param shape      the array dimensions
     * @param index      the linear index to decompose
     * @param numIndices the total number of elements (the product of {@code shape})
     * @return one coordinate per dimension
     * @throws IllegalArgumentException if a computed coordinate is {@code >= Integer.MAX_VALUE}
     */
    public static int[] ind2sub(int[] shape, long index, long numIndices) {
        int[] coords = new int[shape.length];
        long remaining = index;
        long blockSize = numIndices;
        for (int axis = shape.length - 1; axis >= 0; --axis) {
            blockSize /= (long) shape[axis];
            long coord = remaining / blockSize;
            if (coord >= Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Dimension can not be >= Integer.MAX_VALUE");
            }
            coords[axis] = (int) coord;
            remaining %= blockSize;
        }
        return coords;
    }

    /**
     * Converts a linear index to column-major coordinates, with the element count derived
     * from {@code shape}.
     *
     * @param shape the array dimensions
     * @param index the linear index to decompose
     * @return one coordinate per dimension
     */
    public static int[] ind2sub(int[] shape, long index) {
        long elementCount = ArrayUtil.prodLong(shape);
        return Shape.ind2sub(shape, index, elementCount);
    }

    /**
     * Converts a linear index to column-major coordinates within the given array's shape.
     *
     * @param arr   the array whose shape defines the coordinate space
     * @param index the linear index to decompose
     * @return one coordinate per dimension of {@code arr}
     */
    public static int[] ind2sub(INDArray arr, long index) {
        int[] dims = arr.shape();
        long elementCount = ArrayUtil.prodLong(dims);
        return Shape.ind2sub(dims, index, elementCount);
    }

    /**
     * Converts a linear index to per-dimension coordinates in row-major (C) order: the first
     * dimension is resolved first, using the largest block size.
     *
     * @param shape      the array dimensions
     * @param index      the linear index to decompose
     * @param numIndices the total number of elements (the product of {@code shape})
     * @return one coordinate per dimension
     * @throws IllegalArgumentException if a computed coordinate is {@code >= Integer.MAX_VALUE}
     */
    public static int[] ind2subC(int[] shape, long index, long numIndices) {
        int[] coords = new int[shape.length];
        long remaining = index;
        long blockSize = numIndices;
        int axis = 0;
        while (axis < shape.length) {
            blockSize /= (long) shape[axis];
            long coord = remaining / blockSize;
            if (coord >= Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Dimension can not be >= Integer.MAX_VALUE");
            }
            coords[axis] = (int) coord;
            remaining %= blockSize;
            axis++;
        }
        return coords;
    }

    /**
     * Converts a linear index to row-major coordinates, with the element count derived from
     * {@code shape}.
     *
     * @param shape the array dimensions
     * @param index the linear index to decompose
     * @return one coordinate per dimension
     */
    public static int[] ind2subC(int[] shape, long index) {
        long elementCount = ArrayUtil.prodLong(shape);
        return Shape.ind2subC(shape, index, elementCount);
    }

    /**
     * Converts a linear index to row-major coordinates within the given array's shape.
     *
     * @param arr   the array whose shape defines the coordinate space
     * @param index the linear index to decompose
     * @return one coordinate per dimension of {@code arr}
     */
    public static int[] ind2subC(INDArray arr, long index) {
        int[] dims = arr.shape();
        long elementCount = ArrayUtil.prodLong(dims);
        return Shape.ind2subC(dims, index, elementCount);
    }

    /**
     * Resolves the buffer offset of a single element, given one point index per dimension.
     *
     * @param arr     the array to address
     * @param indexes one coordinate per dimension, wrapped via {@link #toIndexes(int[])}
     * @return the offset reported by {@code ShapeOffsetResolution} after resolving the indexes
     */
    public static long offsetFor(INDArray arr, int[] indexes) {
        ShapeOffsetResolution resolver = new ShapeOffsetResolution(arr);
        INDArrayIndex[] pointIndexes = Shape.toIndexes(indexes);
        resolver.exec(pointIndexes);
        return resolver.getOffset();
    }

    /**
     * Asserts that every entry of {@code shape} is strictly less than the matching entry of
     * {@code lessThan}.
     *
     * @param shape    the values to check
     * @param lessThan the exclusive upper bounds, one per entry of {@code shape}
     * @throws IllegalArgumentException if the two arrays differ in length
     * @throws IllegalStateException    at the first entry that is not strictly smaller
     */
    public static void assertShapeLessThan(int[] shape, int[] lessThan) {
        if (shape.length != lessThan.length) {
            throw new IllegalArgumentException("Shape length must be == less than length");
        }
        int i = 0;
        while (i < shape.length) {
            if (shape[i] >= lessThan[i]) {
                throw new IllegalStateException("Shape[" + i + "] should be less than lessThan[" + i + "]");
            }
            i++;
        }
    }

    /**
     * Wraps each integer coordinate in an {@code NDArrayIndex}.
     *
     * @param indices the coordinates to wrap
     * @return an array of the same length, element {@code k} wrapping {@code indices[k]}
     */
    public static INDArrayIndex[] toIndexes(int[] indices) {
        int count = indices.length;
        INDArrayIndex[] wrapped = new INDArrayIndex[count];
        for (int k = 0; k < count; k++) {
            int coordinate = indices[k];
            wrapped[k] = new NDArrayIndex(coordinate);
        }
        return wrapped;
    }

    /**
     * Drops the leading stride when {@code strides} is longer than {@code newLength}.
     *
     * <p>Only one leading entry is removed, however large the length difference is.
     * {@code indexes} is not read.
     *
     * @param strides   the current strides
     * @param newLength the target rank
     * @param indexes   unused
     * @return {@code strides} itself if it is no longer than {@code newLength}; otherwise a new
     *         array holding {@code strides[1..]}
     */
    public static int[] newStrides(int[] strides, int newLength, INDArrayIndex[] indexes) {
        if (strides.length <= newLength) {
            return strides;
        }
        int[] trimmed = new int[strides.length - 1];
        System.arraycopy(strides, 1, trimmed, 0, trimmed.length);
        return trimmed;
    }

    /**
     * Checks that strides strictly decrease for a 'c'-ordered array, or strictly increase for an
     * 'f'-ordered one.
     *
     * <p>A vector whose first two strides are both 1 is always accepted, and so is order 'a'.
     * NOTE(review): the vector shortcut reads {@code strides[1]}, so it assumes vectors carry at
     * least two strides; confirm.
     *
     * @param array the array to inspect
     * @return {@code true} if the strides are monotonic in the direction implied by the ordering
     * @throws RuntimeException if the ordering is not 'c', 'f' or 'a'
     */
    public static boolean strideDescendingCAscendingF(INDArray array) {
        int[] strides = array.stride();
        if (array.isVector() && strides[0] == 1 && strides[1] == 1) {
            return true;
        }
        char order = array.ordering();
        switch (order) {
            case 'c':
                for (int k = 1; k < strides.length; ++k) {
                    if (strides[k - 1] <= strides[k]) {
                        return false;
                    }
                }
                return true;
            case 'f':
                for (int k = 1; k < strides.length; ++k) {
                    if (strides[k - 1] >= strides[k]) {
                        return false;
                    }
                }
                return true;
            case 'a':
                return true;
            default:
                throw new RuntimeException("Invalid order: not c or f (is: " + order + ")");
        }
    }

    /**
     * Computes the number of elements described by a shape-information buffer: the product of
     * its {@code rank} shape entries.
     *
     * @param buffer the shape-information buffer (rank at slot 0, shape starting at slot 1)
     * @return the product of the shape entries, or 1 for rank 0
     */
    public static int length(IntBuffer buffer) {
        IntBuffer dims = Shape.shapeOf(buffer);
        int rank = Shape.rank(buffer);
        int product = 1;
        for (int d = 0; d < rank; d++) {
            product *= dims.get(d);
        }
        return product;
    }

    /**
     * Computes the number of elements described by a shape-information buffer: the product of
     * its {@code rank} shape entries.
     *
     * @param buffer the shape-information buffer (rank at slot 0, shape starting at slot 1)
     * @return the product of the shape entries, or 1 for rank 0
     */
    public static int length(DataBuffer buffer) {
        DataBuffer dims = Shape.shapeOf(buffer);
        int rank = Shape.rank(buffer);
        int product = 1;
        for (int d = 0; d < rank; d++) {
            product *= dims.getInt((long) d);
        }
        return product;
    }

    /**
     * Reads the rank stored in slot 0 of a shape-information buffer.
     *
     * @param buffer the shape-information buffer
     * @return the rank
     */
    public static int rank(DataBuffer buffer) {
        final long rankSlot = 0L;
        return buffer.getInt(rankSlot);
    }

    /**
     * Reads the rank stored in slot 0 of a shape-information buffer.
     *
     * <p>Side effect: resets the buffer's position to 0.
     *
     * @param buffer the shape-information buffer
     * @return the rank
     */
    public static int rank(IntBuffer buffer) {
        buffer.position(0);
        return buffer.get(0);
    }

    /**
     * Reads the rank stored in slot 0 of a shape-information array.
     *
     * @param buffer the shape-information array
     * @return the rank
     */
    public static int rank(int[] buffer) {
        final int rankSlot = 0;
        return buffer[rankSlot];
    }

    /**
     * Reads the size of one dimension from a shape-information buffer.
     *
     * <p>Negative dimensions are rejected. Without that check, {@code -1} would silently return
     * the rank stored in slot 0.
     *
     * @param buffer    the shape-information buffer (rank at slot 0, shape starting at slot 1)
     * @param dimension the zero-based dimension, in {@code [0, rank)}
     * @return the size of {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is outside {@code [0, rank)}
     */
    public static int size(IntBuffer buffer, int dimension) {
        int rank = Shape.rank(buffer);
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer.get(1 + dimension);
    }

    /**
     * Reads the size of one dimension from a shape-information buffer.
     *
     * <p>Negative dimensions are rejected. Without that check, {@code -1} would silently return
     * the rank stored in slot 0.
     *
     * @param buffer    the shape-information buffer (rank at slot 0, shape starting at slot 1)
     * @param dimension the zero-based dimension, in {@code [0, rank)}
     * @return the size of {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is outside {@code [0, rank)}
     */
    public static int size(DataBuffer buffer, int dimension) {
        int rank = Shape.rank(buffer);
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer.getInt((long)(1 + dimension));
    }

    /**
     * Reads the size of one dimension from a shape-information array.
     *
     * <p>Negative dimensions are rejected. Without that check, {@code -1} would silently return
     * the rank stored in slot 0.
     *
     * @param buffer    the shape-information array (rank at slot 0, shape starting at slot 1)
     * @param dimension the zero-based dimension, in {@code [0, rank)}
     * @return the size of {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is outside {@code [0, rank)}
     */
    public static int size(int[] buffer, int dimension) {
        // Slot 0 holds the rank (same value as Shape.rank(int[])).
        int rank = buffer[0];
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer[1 + dimension];
    }

    /**
     * Reads the size of one dimension without validating {@code dimension} against the rank.
     *
     * @param buffer    the shape-information buffer (shape starting at slot 1)
     * @param dimension the zero-based dimension; the caller guarantees it is in range
     * @return the value at slot {@code 1 + dimension}
     */
    public static int sizeUnsafe(DataBuffer buffer, int dimension) {
        int slot = dimension + 1;
        return buffer.getInt((long) slot);
    }

    /**
     * Reads the size of one dimension without validating {@code dimension} against the rank.
     *
     * @param buffer    the shape-information array (shape starting at slot 1)
     * @param dimension the zero-based dimension; the caller guarantees it is in range
     * @return the value at slot {@code 1 + dimension}
     */
    public static int sizeUnsafe(int[] buffer, int dimension) {
        int slot = dimension + 1;
        return buffer[slot];
    }

    /**
     * Copies the shape entries (slots {@code 1..rank}) out of a shape-information buffer.
     *
     * <p>Side effect: {@link #rank(IntBuffer)} resets the buffer's position to 0.
     *
     * @param buffer the shape-information buffer
     * @return a new array of length {@code rank} holding the dimensions
     */
    public static int[] shape(IntBuffer buffer) {
        int[] dims = new int[Shape.rank(buffer)];
        int d = 0;
        while (d < dims.length) {
            dims[d] = buffer.get(d + 1);
            d++;
        }
        return dims;
    }

    /**
     * Copies the shape entries (slots {@code 1..rank}) out of a shape-information buffer.
     *
     * @param buffer the shape-information buffer
     * @return a new array of length {@code rank} holding the dimensions
     */
    public static int[] shape(DataBuffer buffer) {
        int[] dims = new int[Shape.rank(buffer)];
        int d = 0;
        while (d < dims.length) {
            dims[d] = buffer.getInt((long) (d + 1));
            d++;
        }
        return dims;
    }

    /**
     * Copies the shape entries (slots {@code 1..rank}) out of a shape-information array.
     *
     * @param buffer the shape-information array
     * @return a new array of length {@code rank} holding the dimensions
     */
    public static int[] shape(int[] buffer) {
        int[] dims = new int[Shape.rank(buffer)];
        System.arraycopy(buffer, 1, dims, 0, dims.length);
        return dims;
    }

    /**
     * Reads the stride of one dimension from a shape-information buffer.
     *
     * <p>Negative dimensions are rejected. Without that check, {@code -1} would silently return
     * the last shape entry instead of a stride.
     *
     * @param buffer    the shape-information buffer (strides start at slot {@code 1 + rank})
     * @param dimension the zero-based dimension, in {@code [0, rank)}
     * @return the stride of {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is outside {@code [0, rank)}
     */
    public static int stride(IntBuffer buffer, int dimension) {
        int rank = Shape.rank(buffer);
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer.get(1 + rank + dimension);
    }

    /**
     * Reads the stride of one dimension from a shape-information buffer.
     *
     * <p>Negative dimensions are rejected. Without that check, {@code -1} would silently return
     * the last shape entry instead of a stride.
     *
     * @param buffer    the shape-information buffer (strides start at slot {@code 1 + rank})
     * @param dimension the zero-based dimension, in {@code [0, rank)}
     * @return the stride of {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is outside {@code [0, rank)}
     */
    public static int stride(DataBuffer buffer, int dimension) {
        int rank = Shape.rank(buffer);
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer.getInt((long)(1 + rank + dimension));
    }

    /**
     * Reads the stride of one dimension from a shape-information array.
     *
     * <p>Negative dimensions are rejected. Without that check, {@code -1} would silently return
     * the last shape entry instead of a stride.
     *
     * @param buffer    the shape-information array (strides start at slot {@code 1 + rank})
     * @param dimension the zero-based dimension, in {@code [0, rank)}
     * @return the stride of {@code dimension}
     * @throws IllegalArgumentException if {@code dimension} is outside {@code [0, rank)}
     */
    public static int stride(int[] buffer, int dimension) {
        // Slot 0 holds the rank (same value as Shape.rank(int[])).
        int rank = buffer[0];
        if (dimension < 0 || dimension >= rank) {
            throw new IllegalArgumentException("Invalid dimension " + dimension + " for rank " + rank + " array");
        }
        return buffer[1 + rank + dimension];
    }

    /**
     * Copies the stride entries out of a shape-information buffer into a new array.
     *
     * @param buffer the shape-information buffer
     * @return a new array of length {@code rank} holding the strides
     */
    public static int[] strideArr(DataBuffer buffer) {
        int[] strides = new int[Shape.rank(buffer)];
        DataBuffer strideView = Shape.stride(buffer);
        int d = 0;
        while (d < strides.length) {
            strides[d] = strideView.getInt((long) d);
            d++;
        }
        return strides;
    }

    /**
     * Reads the stride of one dimension using a caller-supplied rank, without validation.
     *
     * @param buffer    the shape-information buffer
     * @param dimension the zero-based dimension; the caller guarantees it is in range
     * @param rank      the buffer's rank, supplied by the caller
     * @return the value at slot {@code 1 + rank + dimension}
     */
    public static int strideUnsafe(DataBuffer buffer, int dimension, int rank) {
        int slot = rank + dimension + 1;
        return buffer.getInt((long) slot);
    }

    /**
     * Reads the stride of one dimension using a caller-supplied rank, without validation.
     *
     * @param buffer    the shape-information array
     * @param dimension the zero-based dimension; the caller guarantees it is in range
     * @param rank      the array's rank, supplied by the caller
     * @return the value at slot {@code 1 + rank + dimension}
     */
    public static int strideUnsafe(int[] buffer, int dimension, int rank) {
        int slot = rank + dimension + 1;
        return buffer[slot];
    }

    /**
     * Returns the total slot count of a shape-information buffer for the given rank.
     *
     * <p>The layout is one rank slot, {@code rank} shape slots, {@code rank} stride slots, then
     * offset, element-wise stride and order: {@code 2 * rank + 4} in total.
     *
     * @param rank the number of dimensions
     * @return {@code 2 * rank + 4}
     */
    public static int shapeInfoLength(int rank) {
        final int fixedSlots = 4;
        return 2 * rank + fixedSlots;
    }

    /**
     * Returns a view of the buffer that starts at the first stride slot ({@code 1 + rank}).
     *
     * <p>Side effect: leaves the source buffer's position at {@code 1 + rank}. The returned
     * slice shares content with {@code buffer}.
     *
     * @param buffer the shape-information buffer
     * @return a slice beginning at the stride entries
     */
    public static IntBuffer stride(IntBuffer buffer) {
        int strideStart = 1 + Shape.rank(buffer);
        buffer.position(strideStart);
        return buffer.slice();
    }

    /**
     * Creates a buffer view over the stride entries of a shape-information buffer.
     *
     * @param buffer the shape-information buffer
     * @return the result of {@code Nd4j.createBuffer(buffer, 1 + rank, rank)}
     */
    public static DataBuffer stride(DataBuffer buffer) {
        int rank = Shape.rank(buffer);
        long strideStart = (long) (1 + rank);
        long strideCount = (long) rank;
        return Nd4j.createBuffer(buffer, strideStart, strideCount);
    }

    /**
     * Copies the stride entries (slots {@code 1 + rank .. 2 * rank}) out of a
     * shape-information array.
     *
     * @param buffer the shape-information array
     * @return a new array of length {@code rank} holding the strides
     */
    public static int[] stride(int[] buffer) {
        int rank = Shape.rank(buffer);
        int[] strides = new int[rank];
        System.arraycopy(buffer, 1 + rank, strides, 0, rank);
        return strides;
    }

    /**
     * Creates a buffer view over the shape entries (slots {@code 1..rank}).
     *
     * @param buffer the shape-information buffer
     * @return the result of {@code Nd4j.createBuffer(buffer, 1, rank)}
     */
    public static DataBuffer shapeOf(DataBuffer buffer) {
        long shapeCount = (long) buffer.getInt(0L);
        return Nd4j.createBuffer(buffer, 1L, shapeCount);
    }

    /**
     * Returns a view of the buffer that starts at slot 1, where the shape entries begin.
     *
     * <p>Side effect: leaves the source buffer's position at 1. The slice extends to the
     * buffer's limit, so it also covers the stride and trailing slots.
     *
     * @param buffer the shape-information buffer
     * @return a slice beginning at the shape entries
     */
    public static IntBuffer shapeOf(IntBuffer buffer) {
        final int shapeStart = 1;
        buffer.position(shapeStart);
        return buffer.slice();
    }

    /**
     * Formats the array's shape information by delegating to the {@code shapeToString}
     * overload that matches the type returned by {@code arr.shapeInfo()}.
     *
     * @param arr the array to describe
     * @return the formatted shape information
     */
    public static String shapeToString(INDArray arr) {
        return Shape.shapeToString(arr.shapeInfo());
    }

    /**
     * Formats rank, offset, order, shape and strides from a shape-information buffer into a
     * human-readable string.
     *
     * <p>Side effect: the helpers it calls reposition {@code buffer}; after the last call
     * ({@code Shape.order}, via {@code Shape.rank}) the position is 0.
     *
     * @param buffer the shape-information buffer
     * @return the formatted description
     */
    public static String shapeToString(IntBuffer buffer) {
        // Keep this call order: shapeOf, rank and stride each move the buffer's position.
        IntBuffer dims = Shape.shapeOf(buffer);
        int rank = Shape.rank(buffer);
        IntBuffer strides = Shape.stride(buffer);

        StringBuilder text = new StringBuilder();
        text.append("Rank: ").append(rank).append(",");
        text.append("Offset: ").append(Shape.offset(buffer)).append("\n");
        text.append(" Order: ").append(Shape.order(buffer));
        text.append(" Shape: [");
        for (int d = 0; d < rank; ++d) {
            if (d > 0) {
                text.append(",");
            }
            text.append(dims.get(d));
        }
        text.append("], ");
        text.append(" stride: [");
        for (int d = 0; d < rank; ++d) {
            if (d > 0) {
                text.append(",");
            }
            text.append(strides.get(d));
        }
        text.append("]");
        return text.toString();
    }

    /**
     * Reads the offset slot, which is the third slot from the end
     * ({@code shapeInfoLength(rank) - 3}).
     *
     * @param buffer the shape-information buffer
     * @return the stored offset
     * @deprecated marked deprecated in this class; kept for existing callers.
     */
    @Deprecated
    public static int offset(DataBuffer buffer) {
        int rank = Shape.rank(buffer);
        int offsetSlot = Shape.shapeInfoLength(rank) - 3;
        return buffer.getInt((long) offsetSlot);
    }

    /**
     * Reads the offset slot, which is the third slot from the end
     * ({@code shapeInfoLength(rank) - 3}).
     *
     * @param buffer the shape-information array
     * @return the stored offset
     * @deprecated marked deprecated in this class; kept for existing callers.
     */
    @Deprecated
    public static int offset(int[] buffer) {
        int rank = Shape.rank(buffer);
        int offsetSlot = Shape.shapeInfoLength(rank) - 3;
        return buffer[offsetSlot];
    }

    /**
     * Reads the offset slot, which is the third slot from the end
     * ({@code shapeInfoLength(rank) - 3}).
     *
     * <p>Side effect: {@link #rank(IntBuffer)} resets the buffer's position to 0.
     *
     * @param buffer the shape-information buffer
     * @return the stored offset
     */
    public static int offset(IntBuffer buffer) {
        int rank = Shape.rank(buffer);
        int offsetSlot = Shape.shapeInfoLength(rank) - 3;
        return buffer.get(offsetSlot);
    }

    /**
     * Reads the element-wise stride slot, which is the second slot from the end
     * ({@code shapeInfoLength(rank) - 2}).
     *
     * @param buffer the shape-information buffer
     * @return the stored element-wise stride
     */
    public static int elementWiseStride(DataBuffer buffer) {
        int rank = buffer.getInt(0L);
        int ewsSlot = Shape.shapeInfoLength(rank) - 2;
        return buffer.getInt((long) ewsSlot);
    }

    /**
     * Reads the element-wise stride slot, which is the second slot from the end
     * ({@code shapeInfoLength(rank) - 2}).
     *
     * <p>Uses absolute reads only, so the buffer's position is left unchanged.
     *
     * @param buffer the shape-information buffer
     * @return the stored element-wise stride
     */
    public static int elementWiseStride(IntBuffer buffer) {
        int rank = buffer.get(0);
        int ewsSlot = Shape.shapeInfoLength(rank) - 2;
        return buffer.get(ewsSlot);
    }

    /**
     * Writes the element-wise stride slot ({@code shapeInfoLength(rank) - 2}) in place.
     *
     * <p>Uses absolute get/put only, so the buffer's position is left unchanged.
     *
     * @param buffer            the shape-information buffer to modify
     * @param elementWiseStride the value to store
     */
    public static void setElementWiseStride(IntBuffer buffer, int elementWiseStride) {
        int rank = buffer.get(0);
        int ewsSlot = Shape.shapeInfoLength(rank) - 2;
        buffer.put(ewsSlot, elementWiseStride);
    }

    /**
     * Writes the element-wise stride slot ({@code shapeInfoLength(rank) - 2}) in place.
     *
     * @param buffer            the shape-information buffer to modify
     * @param elementWiseStride the value to store
     */
    public static void setElementWiseStride(DataBuffer buffer, int elementWiseStride) {
        int rank = Shape.rank(buffer);
        int ewsSlot = Shape.shapeInfoLength(rank) - 2;
        buffer.put((long) ewsSlot, elementWiseStride);
    }

    /**
     * Renders every slot of a shape-information buffer as {@code "[ rank, s1, ..., order]"}.
     *
     * <p>The slot count is {@code rank * 2 + 4}. Reads are absolute, so the buffer's position
     * is left unchanged.
     *
     * @param buffer the shape-information buffer
     * @return the comma-separated slot values
     */
    public static String bufferToString(IntBuffer buffer) {
        int rank = buffer.get(0);
        int slotCount = rank * 2 + 4;
        StringBuilder out = new StringBuilder("[ ");
        out.append(rank).append(", ");
        for (int p = 1; p < slotCount; ++p) {
            if (p > 1) {
                out.append(", ");
            }
            out.append(buffer.get(p));
        }
        out.append("]");
        return out.toString();
    }

    /**
     * Reads the order character stored in the last slot ({@code shapeInfoLength(rank) - 1}).
     *
     * <p>Side effect: {@link #rank(IntBuffer)} resets the buffer's position to 0.
     *
     * @param buffer the shape-information buffer
     * @return the stored order, cast to {@code char}
     */
    public static char order(IntBuffer buffer) {
        int lastSlot = Shape.shapeInfoLength(Shape.rank(buffer)) - 1;
        int raw = buffer.get(lastSlot);
        return (char) raw;
    }

    /**
     * Reads the order character stored in the last slot ({@code shapeInfoLength(rank) - 1}).
     *
     * @param buffer the shape-information buffer
     * @return the stored order, cast to {@code char}
     */
    public static char order(DataBuffer buffer) {
        int lastSlot = Shape.shapeInfoLength(Shape.rank(buffer)) - 1;
        int raw = buffer.getInt((long) lastSlot);
        return (char) raw;
    }

    /**
     * Reads the order character stored in the last slot ({@code shapeInfoLength(rank) - 1}).
     *
     * @param buffer the shape-information array
     * @return the stored order, cast to {@code char}
     */
    public static char order(int[] buffer) {
        int lastSlot = Shape.shapeInfoLength(Shape.rank(buffer)) - 1;
        int raw = buffer[lastSlot];
        return (char) raw;
    }

    /**
     * Writes {@code order} into the last slot and then always throws.
     *
     * <p>The write happens before the exception, so the buffer is modified even though the
     * call fails. NOTE(review): the unconditional throw appears to exist to flag remaining
     * callers; confirm before relying on this method.
     *
     * @param buffer the shape-information buffer to modify
     * @param order  the order character to store
     * @throws RuntimeException always, with message {@code "setOrder called"}
     * @deprecated marked deprecated in this class; every call throws.
     */
    @Deprecated
    public static void setOrder(IntBuffer buffer, char order) {
        int lastSlot = Shape.shapeInfoLength(Shape.rank(buffer)) - 1;
        buffer.put(lastSlot, order);
        throw new RuntimeException("setOrder called");
    }

    /**
     * Builds a constant shape-information buffer with the layout
     * {@code [rank, shape..., stride..., offset, elementWiseStride, order]}.
     *
     * @param shape             the dimensions
     * @param stride            the strides; must have the same length as {@code shape}
     * @param offset            the offset; must fit in an {@code int} because the buffer is int-based
     * @param elementWiseStride the element-wise stride to record
     * @param order             the order character to record
     * @return a buffer from {@code Nd4j.createBufferDetached}, marked constant via
     *         {@code setConstant(true)}
     * @throws IllegalStateException    if {@code shape} and {@code stride} differ in length
     * @throws IllegalArgumentException if {@code offset} is outside the {@code int} range
     */
    public static DataBuffer createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
        if (shape.length != stride.length) {
            throw new IllegalStateException("Shape and stride must be the same length");
        }
        // The slots are ints; reject offsets that an (int) cast would silently truncate.
        if (offset < Integer.MIN_VALUE || offset > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Offset " + offset + " does not fit in an int shape information buffer");
        }
        int rank = shape.length;
        int[] shapeBuffer = new int[Shape.shapeInfoLength(rank)];
        shapeBuffer[0] = rank;
        System.arraycopy(shape, 0, shapeBuffer, 1, rank);
        System.arraycopy(stride, 0, shapeBuffer, 1 + rank, rank);
        // The three trailing slots follow the strides.
        int tail = 1 + 2 * rank;
        shapeBuffer[tail] = (int) offset;
        shapeBuffer[tail + 1] = elementWiseStride;
        shapeBuffer[tail + 2] = order;
        DataBuffer ret = Nd4j.createBufferDetached(shapeBuffer);
        ret.setConstant(true);
        return ret;
    }

    /**
     * Copies the given ints into a new direct {@link IntBuffer} in native byte order.
     *
     * <p>Writes are absolute, so the returned buffer's position is 0.
     *
     * @param arr the values to copy
     * @return a direct buffer whose capacity equals {@code arr.length}
     */
    public static IntBuffer toBuffer(int ... arr) {
        int byteCount = arr.length * (Integer.SIZE / Byte.SIZE);
        IntBuffer ints = ByteBuffer.allocateDirect(byteCount).order(ByteOrder.nativeOrder()).asIntBuffer();
        int slot = 0;
        for (int value : arr) {
            ints.put(slot++, value);
        }
        return ints;
    }

    /**
     * Joins every slot of the buffer, up to its capacity, with commas and no spaces.
     *
     * <p>Reads are absolute, so the position is left unchanged.
     *
     * @param buffer the buffer to render
     * @return the comma-separated values, or an empty string for a zero-capacity buffer
     */
    public static String toString(IntBuffer buffer) {
        StringBuilder joined = new StringBuilder();
        int count = buffer.capacity();
        for (int i = 0; i < count; ++i) {
            if (i > 0) {
                joined.append(",");
            }
            joined.append(buffer.get(i));
        }
        return joined.toString();
    }

    /**
     * Returns the buffer's own {@code toString()} rendering.
     *
     * @param buffer the buffer to render; must not be {@code null}
     * @return {@code buffer.toString()}
     */
    public static String toString(DataBuffer buffer) {
        String rendered = buffer.toString();
        return rendered;
    }

    /**
     * Reports whether the dimension list is the single sentinel {@link Integer#MAX_VALUE}.
     *
     * @param arr the dimension list
     * @return {@code true} only if {@code arr} has exactly one element equal to
     *         {@code Integer.MAX_VALUE}
     */
    public static boolean wholeArrayDimension(int ... arr) {
        if (arr.length != 1) {
            return false;
        }
        return arr[0] == Integer.MAX_VALUE;
    }

    /**
     * Checks that the first {@code arr.length} ints of {@code other} match {@code arr}.
     *
     * <p>The lengths are not compared; {@code other} may hold extra trailing elements.
     *
     * @param arr   the expected values
     * @param other the buffer to compare against
     * @return {@code true} if every element of {@code arr} matches the same slot in {@code other}
     */
    public static boolean contentEquals(int[] arr, DataBuffer other) {
        int i = 0;
        while (i < arr.length) {
            if (arr[i] != other.getInt((long) i)) {
                return false;
            }
            i++;
        }
        return true;
    }

    /**
     * Checks that the first {@code arr.length} ints of {@code other} match {@code arr}.
     *
     * <p>Side effect: the relative reads move {@code other}'s position to just past the last
     * slot compared. The lengths are not compared.
     *
     * @param arr   the expected values
     * @param other the buffer to compare against
     * @return {@code true} if every element of {@code arr} matches the same slot in {@code other}
     */
    public static boolean contentEquals(int[] arr, IntBuffer other) {
        int i = 0;
        while (i < arr.length) {
            other.position(i);
            int actual = other.get();
            if (actual != arr[i]) {
                return false;
            }
            i++;
        }
        return true;
    }

    /**
     * Reports whether the array occupies its buffer contiguously.
     *
     * <p>Returns {@code true} immediately if the array's length equals its data buffer's
     * length. Otherwise it compares the actual strides with the strides freshly computed for the
     * array's order. NOTE(review): for order 'a' the expected strides are {@code {1, 1}}, which
     * assumes a rank-2 array; confirm.
     *
     * @param in the array to inspect
     * @return {@code true} if the lengths match or the strides equal the expected strides
     * @throws RuntimeException if the ordering is not 'c', 'f' or 'a'
     */
    public static boolean isContiguousInBuffer(INDArray in) {
        int length = in.length();
        long bufferLength = in.data().length();
        if (bufferLength == (long) length) {
            return true;
        }
        char order = in.ordering();
        int[] shape = in.shape();
        int[] expectedStrides;
        switch (order) {
            case 'f':
                expectedStrides = ArrayUtil.calcStridesFortran(shape);
                break;
            case 'c':
                expectedStrides = ArrayUtil.calcStrides(shape);
                break;
            case 'a':
                expectedStrides = new int[]{1, 1};
                break;
            default:
                throw new RuntimeException("Invalid order: not c or f (is: " + order + ")");
        }
        return Arrays.equals(in.stride(), expectedStrides);
    }

    /**
     * Returns the matrix unchanged if its strides are already dense for its order. Otherwise
     * it returns the copy produced by {@code toOffsetZeroCopyAnyOrder}.
     *
     * <p>Dense here means strides {@code {cols, 1}} for 'c' order and {@code {1, rows}} for 'f'
     * order. Any other ordering is returned unchanged.
     *
     * @param input a rank-2 array
     * @return {@code input} itself, or a copy with a dense layout
     * @throws IllegalArgumentException if {@code input} is not rank 2
     */
    public static INDArray toMmulCompatible(INDArray input) {
        if (input.rank() != 2) {
            throw new IllegalArgumentException("Input must be rank 2 (matrix)");
        }
        char ordering = input.ordering();
        boolean needsCopy;
        if (ordering == 'c') {
            needsCopy = input.stride(0) != input.size(1) || input.stride(1) != 1;
        } else if (ordering == 'f') {
            needsCopy = input.stride(0) != 1 || input.stride(1) != input.size(0);
        } else {
            needsCopy = false;
        }
        return needsCopy ? Shape.toOffsetZeroCopyAnyOrder(input) : input;
    }
}

