/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.evaluator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.evaluator.Evaluator;

/**
 * An {@link Evaluator} wrapper that applies a delegate evaluator to a single element of the label
 * and prediction {@link NDList}s.
 *
 * <p>On every evaluation or accumulator update, the element at {@code labelsIndex} is taken from
 * the labels and the element at {@code predictionsIndex} is taken from the predictions. Each
 * selected element is wrapped in a new one-element {@link NDList} and passed to the delegate. A
 * {@code null} index means the corresponding list is forwarded unchanged.
 *
 * <p>All accumulator bookkeeping is delegated to the wrapped evaluator. This evaluator reports the
 * wrapped evaluator's name.
 */
public class IndexEvaluator
extends Evaluator {
    private final Evaluator evaluator;
    private final Integer predictionsIndex;
    private final Integer labelsIndex;

    /**
     * Creates an evaluator that selects the same index from both the labels and the predictions.
     *
     * @param evaluator the evaluator to delegate to
     * @param index the index to select from both the labels and the predictions
     */
    public IndexEvaluator(Evaluator evaluator, int index) {
        this(evaluator, index, index);
    }

    /**
     * Creates an evaluator that selects separate indices from the labels and the predictions.
     *
     * @param evaluator the evaluator to delegate to; its name becomes this evaluator's name
     * @param predictionsIndex the index to select from the predictions, or {@code null} to forward
     *     the predictions unchanged
     * @param labelsIndex the index to select from the labels, or {@code null} to forward the
     *     labels unchanged
     * @throws NullPointerException if {@code evaluator} is {@code null}
     */
    public IndexEvaluator(Evaluator evaluator, Integer predictionsIndex, Integer labelsIndex) {
        super(evaluator.getName());
        this.evaluator = evaluator;
        this.predictionsIndex = predictionsIndex;
        this.labelsIndex = labelsIndex;
    }

    /** {@inheritDoc} Delegates to the wrapped evaluator using the selected elements. */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        return this.evaluator.evaluate(this.getLabels(labels), this.getPredictions(predictions));
    }

    /** {@inheritDoc} Delegates to the wrapped evaluator. */
    @Override
    public void addAccumulator(String key) {
        this.evaluator.addAccumulator(key);
    }

    /** {@inheritDoc} Delegates to the wrapped evaluator using the selected elements. */
    @Override
    public void updateAccumulator(String key, NDList labels, NDList predictions) {
        this.evaluator.updateAccumulator(key, this.getLabels(labels), this.getPredictions(predictions));
    }

    /** {@inheritDoc} Delegates to the wrapped evaluator using the selected elements. */
    @Override
    public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
        this.evaluator.updateAccumulators(keys, this.getLabels(labels), this.getPredictions(predictions));
    }

    /** {@inheritDoc} Delegates to the wrapped evaluator. */
    @Override
    public void resetAccumulator(String key) {
        this.evaluator.resetAccumulator(key);
    }

    /** {@inheritDoc} Delegates to the wrapped evaluator. */
    @Override
    public float getAccumulator(String key) {
        return this.evaluator.getAccumulator(key);
    }

    /** Returns the predictions restricted to {@link #predictionsIndex} (or unchanged if null). */
    private NDList getPredictions(NDList predictions) {
        return select(predictions, this.predictionsIndex);
    }

    /** Returns the labels restricted to {@link #labelsIndex} (or unchanged if null). */
    private NDList getLabels(NDList labels) {
        return select(labels, this.labelsIndex);
    }

    /**
     * Picks a single element out of {@code list}.
     *
     * @param list the list to select from
     * @param index the element to keep, or {@code null} to keep the whole list
     * @return {@code list} itself when {@code index} is {@code null}, otherwise a new one-element
     *     list holding {@code list.get(index)}
     */
    private static NDList select(NDList list, Integer index) {
        if (index == null) {
            // No index configured: forward the original list untouched.
            return list;
        }
        return new NDList(list.get(index.intValue()));
    }
}

