/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.convolutional;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.convolutional.Convolution;
import ai.djl.util.Preconditions;

/**
 * A three-dimensional convolution block whose expected input layout is {@code NCDHW}
 * (batch, channel, depth, height, width).
 *
 * <p>Blocks are created through {@link #builder()}. The static {@code conv3d} overloads apply
 * the convolution directly to arrays via {@link Convolution#convolution} without building a
 * block.
 */
public class Conv3d
extends Convolution {
    private static final LayoutType[] EXPECTED_LAYOUT = new LayoutType[]{LayoutType.BATCH, LayoutType.CHANNEL, LayoutType.DEPTH, LayoutType.HEIGHT, LayoutType.WIDTH};
    private static final String STRING_LAYOUT = "NCDHW";
    /** Rank required of the input and weight arrays. */
    private static final int NUM_DIMENSIONS = 5;
    /** Rank required of the stride, padding and dilation shapes (depth, height, width). */
    private static final int SPATIAL_DIMENSIONS = 3;

    Conv3d(Builder builder) {
        super(builder);
    }

    @Override
    protected LayoutType[] getExpectedLayout() {
        return EXPECTED_LAYOUT;
    }

    @Override
    protected String getStringLayout() {
        return STRING_LAYOUT;
    }

    @Override
    protected int numDimensions() {
        return NUM_DIMENSIONS;
    }

    /**
     * Applies a 3-D convolution without bias, using stride 1, padding 0, dilation 1 and a
     * single group.
     *
     * @param input the input array; must have rank 5
     * @param weight the convolution weights; must have rank 5
     * @return the result of {@link Convolution#convolution}
     */
    public static NDList conv3d(NDArray input, NDArray weight) {
        return Conv3d.conv3d(input, weight, null, unitShape(), zeroShape(), unitShape());
    }

    /**
     * Applies a 3-D convolution using stride 1, padding 0, dilation 1 and a single group.
     *
     * @param input the input array; must have rank 5
     * @param weight the convolution weights; must have rank 5
     * @param bias the bias array; passed through to {@link Convolution#convolution} as-is
     * @return the result of {@link Convolution#convolution}
     */
    public static NDList conv3d(NDArray input, NDArray weight, NDArray bias) {
        return Conv3d.conv3d(input, weight, bias, unitShape(), zeroShape(), unitShape());
    }

    /**
     * Applies a 3-D convolution using padding 0, dilation 1 and a single group.
     *
     * @param input the input array; must have rank 5
     * @param weight the convolution weights; must have rank 5
     * @param bias the bias array; passed through to {@link Convolution#convolution} as-is
     * @param stride the stride; must have exactly 3 dimensions
     * @return the result of {@link Convolution#convolution}
     */
    public static NDList conv3d(NDArray input, NDArray weight, NDArray bias, Shape stride) {
        return Conv3d.conv3d(input, weight, bias, stride, zeroShape(), unitShape());
    }

    /**
     * Applies a 3-D convolution using dilation 1 and a single group.
     *
     * @param input the input array; must have rank 5
     * @param weight the convolution weights; must have rank 5
     * @param bias the bias array; passed through to {@link Convolution#convolution} as-is
     * @param stride the stride; must have exactly 3 dimensions
     * @param padding the padding; must have exactly 3 dimensions
     * @return the result of {@link Convolution#convolution}
     */
    public static NDList conv3d(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding) {
        return Conv3d.conv3d(input, weight, bias, stride, padding, unitShape());
    }

    /**
     * Applies a 3-D convolution using a single group.
     *
     * @param input the input array; must have rank 5
     * @param weight the convolution weights; must have rank 5
     * @param bias the bias array; passed through to {@link Convolution#convolution} as-is
     * @param stride the stride; must have exactly 3 dimensions
     * @param padding the padding; must have exactly 3 dimensions
     * @param dilation the dilation; must have exactly 3 dimensions
     * @return the result of {@link Convolution#convolution}
     */
    public static NDList conv3d(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding, Shape dilation) {
        return Conv3d.conv3d(input, weight, bias, stride, padding, dilation, 1);
    }

    /**
     * Applies a 3-D convolution after validating the ranks of its arguments.
     *
     * <p>Validation goes through {@code Preconditions.checkArgument}. Presumably this throws
     * {@code IllegalArgumentException} on failure (TODO confirm against ai.djl.util).
     *
     * @param input the input array; must have rank 5
     * @param weight the convolution weights; must have rank 5
     * @param bias the bias array; passed through to {@link Convolution#convolution} as-is
     * @param stride the stride; must have exactly 3 dimensions
     * @param padding the padding; must have exactly 3 dimensions
     * @param dilation the dilation; must have exactly 3 dimensions
     * @param groups the number of groups
     * @return the result of {@link Convolution#convolution}
     */
    public static NDList conv3d(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding, Shape dilation, int groups) {
        Preconditions.checkArgument(
                input.getShape().dimension() == NUM_DIMENSIONS
                        && weight.getShape().dimension() == NUM_DIMENSIONS,
                "the shape of input or weight doesn't match the conv3d");
        Preconditions.checkArgument(
                stride.dimension() == SPATIAL_DIMENSIONS
                        && padding.dimension() == SPATIAL_DIMENSIONS
                        && dilation.dimension() == SPATIAL_DIMENSIONS,
                "the shape of stride or padding or dilation doesn't match the conv3d");
        return Convolution.convolution(input, weight, bias, stride, padding, dilation, groups);
    }

    /**
     * Creates a builder for {@link Conv3d}. The builder defaults to stride (1, 1, 1),
     * padding (0, 0, 0) and dilation (1, 1, 1).
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Returns a fresh (1, 1, 1) shape, the default stride and dilation. */
    private static Shape unitShape() {
        return new Shape(1L, 1L, 1L);
    }

    /** Returns a fresh (0, 0, 0) shape, the default padding. */
    private static Shape zeroShape() {
        return new Shape(0L, 0L, 0L);
    }

    /** Builder for {@link Conv3d} blocks. */
    public static final class Builder
    extends Convolution.ConvolutionBuilder<Builder> {
        Builder() {
            this.stride = unitShape();
            this.padding = zeroShape();
            this.dilation = unitShape();
        }

        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Validates the configured settings and builds the block.
         *
         * @return a new {@link Conv3d} configured from this builder
         */
        public Conv3d build() {
            this.validate();
            return new Conv3d(this);
        }
    }
}

