/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.Loss;
import ai.djl.util.Pair;
import java.util.List;

/**
 * A {@link Loss} built from several component losses whose values are summed.
 *
 * <p>Subclasses decide which slice of the labels and predictions each component sees by
 * implementing {@link #inputForComponent(int, NDList, NDList)}. Accumulator operations are
 * forwarded to every component, and {@link #getAccumulator(String)} reports their sum.
 *
 * <p>Note: the constructor does not initialize {@link #components}; subclasses must assign it
 * before any other method is called, otherwise these methods fail with a
 * {@code NullPointerException}.
 */
public abstract class AbstractCompositeLoss extends Loss {

    /** The component losses combined by this composite; assigned by subclasses. */
    protected List<Loss> components;

    /**
     * Creates a composite loss with the given name.
     *
     * @param name the display name of the loss
     */
    public AbstractCompositeLoss(String name) {
        super(name);
    }

    /**
     * Selects the inputs that the component at {@code componentIndex} should receive.
     *
     * @param componentIndex index into {@link #components} of the component being evaluated
     * @param labels the full labels passed to the composite loss
     * @param predictions the full predictions passed to the composite loss
     * @return a pair whose key is the labels and whose value is the predictions for that component
     */
    protected abstract Pair<NDList, NDList> inputForComponent(
            int componentIndex, NDList labels, NDList predictions);

    /**
     * Returns the component losses.
     *
     * @return the live (not copied) list held in {@link #components}
     */
    public List<Loss> getComponents() {
        return this.components;
    }

    /**
     * Evaluates every component on its own inputs and returns the sum of the results.
     *
     * @param labels the labels passed to the composite loss
     * @param predictions the predictions passed to the composite loss
     * @return the summed component losses
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        int count = this.components.size();
        NDArray[] lossComponents = new NDArray[count];
        for (int i = 0; i < count; ++i) {
            Pair<NDList, NDList> inputs = this.inputForComponent(i, labels, predictions);
            lossComponents[i] = this.components.get(i).evaluate(inputs.getKey(), inputs.getValue());
        }
        // NDArrays.add(NDArray...) rejects fewer than two operands, so a composite with a
        // single component would fail there. That component's loss already is the total.
        if (lossComponents.length == 1) {
            return lossComponents[0];
        }
        return NDArrays.add(lossComponents);
    }

    /**
     * Adds an accumulator with the given key to every component.
     *
     * @param key the accumulator key
     */
    @Override
    public void addAccumulator(String key) {
        for (Loss component : this.components) {
            component.addAccumulator(key);
        }
    }

    /**
     * Updates the given accumulators of every component, feeding each component the inputs
     * chosen by {@link #inputForComponent(int, NDList, NDList)}.
     *
     * @param keys the accumulator keys to update
     * @param labels the labels passed to the composite loss
     * @param predictions the predictions passed to the composite loss
     */
    @Override
    public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
        for (int i = 0; i < this.components.size(); ++i) {
            Pair<NDList, NDList> inputs = this.inputForComponent(i, labels, predictions);
            this.components.get(i).updateAccumulators(keys, inputs.getKey(), inputs.getValue());
        }
    }

    /**
     * Resets the accumulator with the given key in every component.
     *
     * @param key the accumulator key
     */
    @Override
    public void resetAccumulator(String key) {
        for (Loss component : this.components) {
            component.resetAccumulator(key);
        }
    }

    /**
     * Returns the sum of all components' accumulator values for the given key.
     *
     * @param key the accumulator key
     * @return the summed accumulator value (summed as {@code double}, then narrowed to float)
     */
    @Override
    public float getAccumulator(String key) {
        return (float) this.components.stream()
                .mapToDouble(component -> component.getAccumulator(key))
                .sum();
    }
}

