package adams.ml.dl4j.trainstopcriterion;

import adams.core.Index;
import adams.core.MessageCollection;
import adams.core.QuickInfoHelper;
import adams.flow.container.DL4JModelContainer;
import adams.ml.dl4j.EvaluationHelper;
import adams.ml.dl4j.EvaluationStatistic;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;

/* loaded from: input_file:adams/ml/dl4j/trainstopcriterion/NoImprovement.class */
/**
 * Train stop criterion that monitors one or more evaluation statistics and
 * signals that training should stop once none of the monitored statistics
 * changed compared to the previously recorded values.
 * <p>
 * Any change in value (in either direction) counts as a change; only exact
 * equality with the previously recorded value counts as "no improvement".
 */
public class NoImprovement extends AbstractTrainStopCriterion {
    private static final long serialVersionUID = 6975594226423139162L;

    /** the statistics to monitor. */
    protected EvaluationStatistic[] m_Statistics;

    /** the class label index used when retrieving classification statistics. */
    protected Index m_ClassIndex;

    /** the column index used when retrieving regression statistics. */
    protected Index m_RegressionColumns;

    /** the most recently recorded value per monitored statistic. */
    protected Map<EvaluationStatistic, Double> m_History;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Monitors one or more statistics, whether they improve at all over time.";
    }

    /**
     * Adds options to the internal list of options.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        m_OptionManager.add("statistic", "statistics", new EvaluationStatistic[]{EvaluationStatistic.ACCURACY});
        m_OptionManager.add("index", "classIndex", new Index("first"));
        m_OptionManager.add("regression-columns", "regressionColumns", new Index("last"));
    }

    /**
     * Sets the statistics to monitor.
     *
     * @param evaluationStatisticArr the statistics
     */
    public void setStatistics(EvaluationStatistic[] evaluationStatisticArr) {
        m_Statistics = evaluationStatisticArr;
        reset();
    }

    /**
     * Returns the statistics to monitor.
     *
     * @return the statistics
     */
    public EvaluationStatistic[] getStatistics() {
        return m_Statistics;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String statisticsTipText() {
        return "The statistics to monitor.";
    }

    /**
     * Sets the class label index used for classification statistics.
     *
     * @param index the index
     */
    public void setClassIndex(Index index) {
        m_ClassIndex = index;
        reset();
    }

    /**
     * Returns the class label index used for classification statistics.
     *
     * @return the index
     */
    public Index getClassIndex() {
        return m_ClassIndex;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String classIndexTipText() {
        return "The range of class label indices (eg used for AUC).";
    }

    /**
     * Sets the column index used for regression statistics.
     *
     * @param index the index
     */
    public void setRegressionColumns(Index index) {
        m_RegressionColumns = index;
        reset();
    }

    /**
     * Returns the column index used for regression statistics.
     *
     * @return the index
     */
    public Index getRegressionColumns() {
        return m_RegressionColumns;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String regressionColumnsTipText() {
        return "The range of columns to get regression statistics for.";
    }

    /**
     * Returns a short summary of the current setup.
     *
     * @return the summary
     */
    @Override
    public String getQuickInfo() {
        int numStats = m_Statistics.length;
        String result = QuickInfoHelper.toString(this, "statistics", numStats + " statistic" + (numStats == 1 ? "" : "s"));
        result += QuickInfoHelper.toString(this, "classIndex", m_ClassIndex, ", class labels: ");
        result += QuickInfoHelper.toString(this, "regressionColumns", m_RegressionColumns, ", reg cols: ");
        return result;
    }

    /**
     * Starts the monitoring, discarding any previously recorded values.
     */
    @Override
    public void start() {
        super.start();
        m_History = new HashMap<>();
    }

    /**
     * Returns whether a flow context is required.
     *
     * @return always false
     */
    @Override
    public boolean requiresFlowContext() {
        return false;
    }

    /**
     * Checks whether training should stop, i.e., whether none of the monitored
     * statistics changed since they were last recorded. Statistics that have
     * not been recorded yet, or that could not be obtained, count as changed.
     *
     * @param cont   the container to obtain the evaluation from
     * @param errors for collecting errors
     * @return true if training should stop
     */
    @Override
    protected boolean doCheckStopping(DL4JModelContainer cont, MessageCollection errors) {
        if (!cont.hasValue("Evaluation")) {
            return false;
        }
        // guard against start() not having been called
        if (m_History == null) {
            m_History = new HashMap<>();
        }

        Object eval = cont.getValue("Evaluation");
        boolean result = true;
        for (EvaluationStatistic stat : m_Statistics) {
            try {
                double value = getStatisticValue(eval, stat);
                Double previous = m_History.get(stat);
                // a statistic without a recorded value yet counts as a change;
                // previously this dereferenced null for every statistic after the first
                if (previous == null || previous.doubleValue() != value) {
                    result = false;
                }
                // once a change was detected, keep the history up to date
                if (!result) {
                    m_History.put(stat, value);
                }
            } catch (Exception e) {
                getLogger().log(Level.SEVERE, "Failed to obtain statistic " + stat + " from " + eval.getClass().getName() + " object!", e);
                result = false;
            }
        }

        if (isLoggingEnabled()) {
            if (result) {
                getLogger().info("No improvement: " + m_History);
            } else {
                getLogger().fine("Change: " + m_History);
            }
        }
        return result;
    }

    /**
     * Retrieves the value of a single statistic from the evaluation object,
     * updating the relevant index with the upper bound from the evaluation.
     *
     * @param eval the evaluation object obtained from the container
     * @param stat the statistic to retrieve
     * @return the value of the statistic
     * @throws Exception if the evaluation type is not supported or retrieval fails
     */
    private double getStatisticValue(Object eval, EvaluationStatistic stat) throws Exception {
        if (eval instanceof Evaluation) {
            Evaluation evaluation = (Evaluation) eval;
            m_ClassIndex.setMax(evaluation.falseNegatives().size());
            return EvaluationHelper.getValue(evaluation, stat, m_ClassIndex.getIntIndex());
        }
        if (eval instanceof RegressionEvaluation) {
            RegressionEvaluation regEval = (RegressionEvaluation) eval;
            m_RegressionColumns.setMax(regEval.numColumns());
            return EvaluationHelper.getValue(regEval, stat, m_RegressionColumns.getIntIndex());
        }
        throw new IllegalStateException("Unhandled evaluation class: " + eval.getClass().getName());
    }
}
