/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.Loss;

/**
 * Weighted squared-error (L2) loss between a label and a prediction.
 *
 * <p>Each element contributes {@code weight * (label - prediction)^2}; the per-element values are
 * then reduced with {@link NDArray#mean()}. With the default weight of {@code 0.5f} this is half
 * of the mean squared error.
 */
public class L2Loss extends Loss {

    /** Factor applied to every squared difference; fixed at construction time. */
    private final float weight;

    /** Creates an L2 loss named {@code "L2Loss"} with a weight of {@code 0.5f}. */
    public L2Loss() {
        this("L2Loss");
    }

    /**
     * Creates an L2 loss with the given name and a weight of {@code 0.5f}.
     *
     * @param name the name passed to the {@link Loss} superclass
     */
    public L2Loss(String name) {
        this(name, 0.5f);
    }

    /**
     * Creates an L2 loss with the given name and weight.
     *
     * @param name the name passed to the {@link Loss} superclass
     * @param weight the factor each squared difference is multiplied by
     */
    public L2Loss(String name, float weight) {
        super(name);
        this.weight = weight;
    }

    /**
     * Computes the weighted mean squared difference between label and prediction.
     *
     * <p>The single label array is reshaped to the prediction's shape before the subtraction, so
     * the two are compared element by element.
     *
     * @param label list expected to hold a single label array
     * @param prediction list expected to hold a single prediction array
     * @return the mean of {@code weight * (label - prediction)^2}
     */
    @Override
    public NDArray evaluate(NDList label, NDList prediction) {
        NDArray pred = prediction.singletonOrThrow();
        // Align the label with the prediction so sub() pairs elements one-to-one.
        NDArray labelReshaped = label.singletonOrThrow().reshape(pred.getShape());
        // The float weight autoboxes to the Number overload of mul().
        NDArray loss = labelReshaped.sub(pred).square().mul(weight);
        return loss.mean();
    }
}

