/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.forward.standard;

import deepboof.forward.SpatialBatchNorm;
import deepboof.impl.forward.standard.FunctionBatchNorm_F32;
import deepboof.tensors.Tensor_F32;

/**
 * {@link Tensor_F32} implementation of spatial batch normalization.
 *
 * <p>Each channel is normalized with its own parameters. Without gamma/beta the output is
 * {@code (in - mean) * inv_stdev_eps}. With gamma/beta it is
 * {@code (in - mean) * (gamma * inv_stdev_eps) + beta}.</p>
 *
 * <p>Parameters are laid out as a {@code (C, 4)} tensor with one row per channel holding
 * {@code mean, inv_stdev_eps, gamma, beta}. When gamma/beta are not required, the tensor is
 * {@code (C, 2)} and holds {@code mean, inv_stdev_eps}.</p>
 */
public class SpatialBatchNorm_F32
extends FunctionBatchNorm_F32
implements SpatialBatchNorm<Tensor_F32> {
    /**
     * Creates the function.
     *
     * @param requiresGammaBeta whether each channel also carries a gamma (scale) and beta (offset)
     *                          parameter; forwarded to {@link FunctionBatchNorm_F32}
     */
    public SpatialBatchNorm_F32(boolean requiresGammaBeta) {
        super(requiresGammaBeta);
    }

    /**
     * Validates the per-sample input shape and sizes the output and parameter tensors.
     *
     * <p>The output shape is a copy of the input shape. A single parameter tensor of shape
     * {@code (shapeInput[0], 4)} is registered when gamma/beta are required, otherwise
     * {@code (shapeInput[0], 2)}.</p>
     *
     * @throws IllegalArgumentException if the input shape does not have exactly 3 dimensions
     */
    @Override
    public void _initialize() {
        if (this.shapeInput.length != 3) {
            throw new IllegalArgumentException("Expected 3 DOF in a spatial shape (C,W,H)");
        }
        // int[].clone() already returns int[]; no cast required
        this.shapeOutput = this.shapeInput.clone();
        // One row per channel: mean, inv_stdev_eps [, gamma, beta]
        int[] paramShape = new int[]{this.shapeInput[0], this.requiresGammaBeta ? 4 : 2};
        this.shapeParameters.add(paramShape);
        this.params.reshape(paramShape);
    }

    /**
     * Normalizes every element of {@code input} and writes the result into {@code output}.
     *
     * <p>Dimension 1 is treated as the channel. Dimensions 2 and 3 are processed as one run of
     * {@code length(2) * length(3)} consecutive elements per channel. The channel loop is
     * repeated {@code miniBatchSize} times with the same parameters.
     * NOTE(review): assumes both tensors are densely packed from {@code startIndex} with the
     * mini-batch along dimension 0 — confirm against the caller.</p>
     *
     * @param input  tensor to normalize, read starting at {@code input.startIndex}
     * @param output destination tensor, written starting at {@code output.startIndex}
     */
    @Override
    public void _forward(Tensor_F32 input, Tensor_F32 output) {
        final int numChannels = input.length(1);
        final int channelSize = input.length(2) * input.length(3);

        int indexIn = input.startIndex;
        int indexOut = output.startIndex;

        if (this.hasGammaBeta()) {
            for (int batch = 0; batch < this.miniBatchSize; ++batch) {
                // Parameters are shared by every sample, so restart at the first channel's row
                int indexP = this.params.startIndex;
                for (int channel = 0; channel < numChannels; ++channel) {
                    final float mean = this.params.d[indexP++];
                    final float inv_stdev_eps = this.params.d[indexP++];
                    final float gamma = this.params.d[indexP++];
                    final float beta = this.params.d[indexP++];
                    // Constant for the whole channel: compute once instead of once per element
                    final float scale = gamma * inv_stdev_eps;
                    final int end = indexIn + channelSize;
                    while (indexIn < end) {
                        output.d[indexOut++] = (input.d[indexIn++] - mean) * scale + beta;
                    }
                }
            }
        } else {
            for (int batch = 0; batch < this.miniBatchSize; ++batch) {
                // Parameters are shared by every sample, so restart at the first channel's row
                int indexP = this.params.startIndex;
                for (int channel = 0; channel < numChannels; ++channel) {
                    final float mean = this.params.d[indexP++];
                    final float inv_stdev_eps = this.params.d[indexP++];
                    final int end = indexIn + channelSize;
                    while (indexIn < end) {
                        output.d[indexOut++] = (input.d[indexIn++] - mean) * inv_stdev_eps;
                    }
                }
            }
        }
    }
}

