/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.listener;

import ai.djl.Device;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDList;
import ai.djl.training.Trainer;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.listener.TrainingListenerAdapter;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A {@link TrainingListener} that updates every {@link Evaluator} attached to a {@link Trainer}
 * and publishes their values.
 *
 * <p>When training begins, each evaluator gets four accumulators: {@link #TRAIN_EPOCH},
 * {@link #TRAIN_PROGRESS}, {@link #TRAIN_ALL} and {@link #VALIDATE_EPOCH}. Training batches
 * update the three training accumulators; validation batches update {@link #VALIDATE_EPOCH}.
 * If the trainer exposes a non-null {@link Metrics}, values are also recorded there under the
 * names produced by {@link #metricName(Evaluator, String)}.
 */
public class EvaluatorTrainingListener extends TrainingListenerAdapter {

    /** Accumulator updated on every training batch and reset in {@link #onEpoch(Trainer)}. */
    public static final String TRAIN_EPOCH = "train/epoch";

    /**
     * Accumulator updated on every training batch and reset in {@link #onEpoch(Trainer)}. Its
     * value is written to the metrics every {@code progressUpdateFrequency} training batches.
     */
    public static final String TRAIN_PROGRESS = "train/progress";

    /** Accumulator reset before each training batch, so it reflects only the latest batch. */
    public static final String TRAIN_ALL = "train/all";

    /** Accumulator updated on every validation batch and reset in {@link #onEpoch(Trainer)}. */
    public static final String VALIDATE_EPOCH = "validate/epoch";

    /** Number of training batches between two {@link #TRAIN_PROGRESS} metric records. */
    private final int progressUpdateFrequency;

    /** Training batches counted since the last progress record or the last epoch end. */
    private int progressCounter;

    /**
     * Values captured in the most recent {@link #onEpoch(Trainer)}. Keys are {@code
     * "train_<name>"} and {@code "validate_<name>"}, plus {@code "train_loss"} / {@code
     * "validate_loss"} for the trainer's loss evaluator.
     */
    private final Map<String, Float> latestEvaluations;

    /** Constructs a listener that records progress metrics every 5 training batches. */
    public EvaluatorTrainingListener() {
        this(5);
    }

    /**
     * Constructs a listener with a custom progress-update frequency.
     *
     * @param progressUpdateFrequency number of training batches between two {@link
     *     #TRAIN_PROGRESS} metric records; must be positive
     * @throws IllegalArgumentException if {@code progressUpdateFrequency} is not positive
     */
    public EvaluatorTrainingListener(int progressUpdateFrequency) {
        // The counter in onTrainingBatch is incremented before it is compared, so a
        // non-positive frequency would (short of int overflow) never match and progress
        // metrics would silently never be recorded. Fail fast instead.
        if (progressUpdateFrequency <= 0) {
            throw new IllegalArgumentException(
                    "progressUpdateFrequency must be positive, got " + progressUpdateFrequency);
        }
        this.progressUpdateFrequency = progressUpdateFrequency;
        this.progressCounter = 0;
        this.latestEvaluations = new ConcurrentHashMap<>();
    }

    /**
     * Publishes the epoch values of every evaluator, then resets all of their accumulators.
     *
     * <p>For each evaluator, the {@link #TRAIN_EPOCH} and {@link #VALIDATE_EPOCH} values are
     * stored in {@link #getLatestEvaluations()} and, if the trainer has metrics, recorded there.
     * The evaluator that is the trainer's loss is additionally stored as {@code "train_loss"}
     * and {@code "validate_loss"}.
     *
     * @param trainer the trainer whose evaluators are read and reset
     */
    @Override
    public void onEpoch(Trainer trainer) {
        Metrics metrics = trainer.getMetrics();
        for (Evaluator evaluator : trainer.getEvaluators()) {
            float trainValue = evaluator.getAccumulator(TRAIN_EPOCH);
            float validateValue = evaluator.getAccumulator(VALIDATE_EPOCH);
            if (metrics != null) {
                metrics.addMetric(metricName(evaluator, TRAIN_EPOCH), Float.valueOf(trainValue));
                metrics.addMetric(
                        metricName(evaluator, VALIDATE_EPOCH), Float.valueOf(validateValue));
            }
            latestEvaluations.put("train_" + evaluator.getName(), trainValue);
            latestEvaluations.put("validate_" + evaluator.getName(), validateValue);
            // Identity check on purpose: only the exact loss instance used by the trainer
            // gets the well-known loss keys.
            if (evaluator == trainer.getLoss()) {
                latestEvaluations.put("train_loss", trainValue);
                latestEvaluations.put("validate_loss", validateValue);
            }
        }
        for (Evaluator evaluator : trainer.getEvaluators()) {
            evaluator.resetAccumulator(TRAIN_EPOCH);
            evaluator.resetAccumulator(TRAIN_PROGRESS);
            evaluator.resetAccumulator(TRAIN_ALL);
            evaluator.resetAccumulator(VALIDATE_EPOCH);
        }
        progressCounter = 0;
    }

    /**
     * Updates the training accumulators with one batch and records metrics.
     *
     * <p>If the trainer has metrics, the {@link #TRAIN_ALL} value is recorded on every batch,
     * and the {@link #TRAIN_PROGRESS} value every {@code progressUpdateFrequency} batches. The
     * progress counter only advances when metrics are present.
     *
     * @param trainer the trainer whose evaluators are updated
     * @param batchData the labels and predictions of the batch
     */
    @Override
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        // TRAIN_ALL reflects only the current batch, so clear it before updating.
        for (Evaluator evaluator : trainer.getEvaluators()) {
            evaluator.resetAccumulator(TRAIN_ALL);
        }
        updateEvaluators(trainer, batchData, new String[] {TRAIN_EPOCH, TRAIN_PROGRESS, TRAIN_ALL});

        Metrics metrics = trainer.getMetrics();
        if (metrics == null) {
            return;
        }
        recordAccumulator(trainer, metrics, TRAIN_ALL);
        progressCounter++;
        if (progressCounter == progressUpdateFrequency) {
            recordAccumulator(trainer, metrics, TRAIN_PROGRESS);
            progressCounter = 0;
        }
    }

    /**
     * Updates the {@link #VALIDATE_EPOCH} accumulator with one validation batch.
     *
     * @param trainer the trainer whose evaluators are updated
     * @param batchData the labels and predictions of the batch
     */
    @Override
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        updateEvaluators(trainer, batchData, new String[] {VALIDATE_EPOCH});
    }

    /**
     * Feeds the labels and predictions of every device in the batch to every evaluator.
     *
     * <p>Devices are taken from the labels map. A device missing from the predictions map yields
     * {@code null} predictions, as before.
     */
    private void updateEvaluators(
            Trainer trainer, TrainingListener.BatchData batchData, String[] accumulators) {
        // Loop-invariant: fetch both per-device maps once instead of once per evaluator.
        Map<Device, NDList> labelsByDevice = batchData.getLabels();
        Map<Device, NDList> predictionsByDevice = batchData.getPredictions();
        for (Evaluator evaluator : trainer.getEvaluators()) {
            for (Map.Entry<Device, NDList> entry : labelsByDevice.entrySet()) {
                NDList predictions = predictionsByDevice.get(entry.getKey());
                evaluator.updateAccumulators(accumulators, entry.getValue(), predictions);
            }
        }
    }

    /** Records the current value of {@code accumulator} for every evaluator into {@code metrics}. */
    private static void recordAccumulator(Trainer trainer, Metrics metrics, String accumulator) {
        for (Evaluator evaluator : trainer.getEvaluators()) {
            String key = metricName(evaluator, accumulator);
            float value = evaluator.getAccumulator(accumulator);
            metrics.addMetric(key, Float.valueOf(value));
        }
    }

    /**
     * Registers the four accumulators used by this listener on every evaluator.
     *
     * @param trainer the trainer whose evaluators receive the accumulators
     */
    @Override
    public void onTrainingBegin(Trainer trainer) {
        for (Evaluator evaluator : trainer.getEvaluators()) {
            evaluator.addAccumulator(TRAIN_EPOCH);
            evaluator.addAccumulator(TRAIN_PROGRESS);
            evaluator.addAccumulator(TRAIN_ALL);
            evaluator.addAccumulator(VALIDATE_EPOCH);
        }
    }

    /**
     * Returns the metric name under which an evaluator's value for {@code stage} is recorded.
     *
     * @param evaluator the evaluator whose name is used as the suffix
     * @param stage one of {@link #TRAIN_EPOCH}, {@link #TRAIN_PROGRESS}, {@link #TRAIN_ALL} or
     *     {@link #VALIDATE_EPOCH}
     * @return the metric name, for example {@code "train_epoch_" + evaluator.getName()}
     * @throws NullPointerException if {@code stage} is null
     * @throws IllegalArgumentException if {@code stage} is not one of the known stages
     */
    public static String metricName(Evaluator evaluator, String stage) {
        Objects.requireNonNull(stage, "stage");
        switch (stage) {
            case TRAIN_EPOCH:
                return "train_epoch_" + evaluator.getName();
            case TRAIN_PROGRESS:
                return "train_progress_" + evaluator.getName();
            case TRAIN_ALL:
                return "train_all_" + evaluator.getName();
            case VALIDATE_EPOCH:
                return "validate_epoch_" + evaluator.getName();
            default:
                throw new IllegalArgumentException("Invalid metric stage: " + stage);
        }
    }

    /**
     * Returns the evaluator values captured at the end of the most recent epoch.
     *
     * <p>The result is a read-only live view. It reflects values stored by later epochs, but
     * callers cannot modify the listener's internal state through it.
     *
     * @return an unmodifiable view of the latest evaluations (empty before the first epoch ends)
     */
    public Map<String, Float> getLatestEvaluations() {
        return Collections.unmodifiableMap(latestEvaluations);
    }
}

