package adams.ml.dl4j.trainstopcriterion;

import adams.core.Index;
import adams.core.MessageCollection;
import adams.core.QuickInfoHelper;
import adams.flow.container.DL4JModelContainer;
import adams.ml.dl4j.EvaluationHelper;
import adams.ml.dl4j.EvaluationStatistic;
import java.util.logging.Level;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;

/* loaded from: input_file:adams/ml/dl4j/trainstopcriterion/Statistic.class */
/**
 * Train stop criterion that compares a single evaluation statistic against a
 * threshold.
 *
 * <p>The statistic is read from the {@code "Evaluation"} value of the model
 * container handed to {@link #doCheckStopping(DL4JModelContainer, MessageCollection)}.
 * Both {@link Evaluation} and {@link RegressionEvaluation} objects are handled;
 * which one is present decides whether the class index or the regression
 * column index is used to pick the value.
 */
public class Statistic extends AbstractTrainStopCriterion {
    private static final long serialVersionUID = 6975594226423139162L;

    /** The statistic to monitor. */
    protected EvaluationStatistic m_Statistic;

    /** The class label index, used when the evaluation is an {@link Evaluation}. */
    protected Index m_ClassIndex;

    /** The column index, used when the evaluation is a {@link RegressionEvaluation}. */
    protected Index m_RegressionColumns;

    /** The threshold the statistic is compared against. */
    protected double m_Threshold;

    /** How the statistic is compared against the threshold. */
    protected ThresholdCheck m_ThresholdCheck;

    /**
     * The comparison applied as {@code statistic <op> threshold}; stopping is
     * signalled when the comparison holds.
     */
    public enum ThresholdCheck {
        GREATER_OR_EQUAL,
        GREATER,
        LESS,
        LESS_OR_EQUAL
    }

    /**
     * Returns a short description of this criterion.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Monitors a statistic, whether it goes below or above a threshold.";
    }

    /**
     * Registers the options of this criterion on top of the inherited ones:
     * statistic, class index, regression columns, threshold and threshold check.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("statistic", "statistic", EvaluationStatistic.ACCURACY);
        this.m_OptionManager.add("index", "classIndex", new Index("first"));
        this.m_OptionManager.add("regression-columns", "regressionColumns", new Index("last"));
        this.m_OptionManager.add("threshold", "threshold", Double.valueOf(0.9d));
        this.m_OptionManager.add("threshold-check", "thresholdCheck", ThresholdCheck.GREATER_OR_EQUAL);
    }

    /**
     * Sets the statistic to monitor and calls {@code reset()}.
     *
     * @param evaluationStatistic the statistic
     */
    public void setStatistic(EvaluationStatistic evaluationStatistic) {
        this.m_Statistic = evaluationStatistic;
        reset();
    }

    /**
     * Returns the statistic being monitored.
     *
     * @return the statistic
     */
    public EvaluationStatistic getStatistic() {
        return this.m_Statistic;
    }

    /**
     * Returns the help text for the statistic option.
     *
     * @return the help text
     */
    public String statisticTipText() {
        // The option is an EvaluationStatistic, not a variable name.
        return "The statistic to monitor.";
    }

    /**
     * Sets the class label index and calls {@code reset()}.
     *
     * @param index the index
     */
    public void setClassIndex(Index index) {
        this.m_ClassIndex = index;
        reset();
    }

    /**
     * Returns the class label index.
     *
     * @return the index
     */
    public Index getClassIndex() {
        return this.m_ClassIndex;
    }

    /**
     * Returns the help text for the class index option.
     *
     * @return the help text
     */
    public String classIndexTipText() {
        return "The range of class label indices (eg used for AUC).";
    }

    /**
     * Sets the regression column index and calls {@code reset()}.
     *
     * @param index the index
     */
    public void setRegressionColumns(Index index) {
        this.m_RegressionColumns = index;
        reset();
    }

    /**
     * Returns the regression column index.
     *
     * @return the index
     */
    public Index getRegressionColumns() {
        return this.m_RegressionColumns;
    }

    /**
     * Returns the help text for the regression columns option.
     *
     * @return the help text
     */
    public String regressionColumnsTipText() {
        return "The range of columns to get regression statistics for.";
    }

    /**
     * Sets the threshold and calls {@code reset()}.
     *
     * @param d the threshold
     */
    public void setThreshold(double d) {
        this.m_Threshold = d;
        reset();
    }

    /**
     * Returns the threshold.
     *
     * @return the threshold
     */
    public double getThreshold() {
        return this.m_Threshold;
    }

    /**
     * Returns the help text for the threshold option.
     *
     * @return the help text
     */
    public String thresholdTipText() {
        return "The threshold for the statistic.";
    }

    /**
     * Sets how the statistic is compared against the threshold and calls
     * {@code reset()}.
     *
     * @param thresholdCheck the type of comparison
     */
    public void setThresholdCheck(ThresholdCheck thresholdCheck) {
        this.m_ThresholdCheck = thresholdCheck;
        reset();
    }

    /**
     * Returns how the statistic is compared against the threshold.
     *
     * @return the type of comparison
     */
    public ThresholdCheck getThresholdCheck() {
        return this.m_ThresholdCheck;
    }

    /**
     * Returns the help text for the threshold check option.
     *
     * @return the help text
     */
    public String thresholdCheckTipText() {
        return "Determines how to interpret the threshold.";
    }

    /**
     * Builds a one-line summary from statistic, threshold check, threshold,
     * class index and regression columns (in that order).
     *
     * @return the summary
     */
    @Override
    public String getQuickInfo() {
        return QuickInfoHelper.toString(this, "statistic", this.m_Statistic)
            + QuickInfoHelper.toString(this, "thresholdCheck", this.m_ThresholdCheck, " ")
            + QuickInfoHelper.toString(this, "threshold", Double.valueOf(this.m_Threshold), " ")
            + QuickInfoHelper.toString(this, "classIndex", this.m_ClassIndex, ", class labels: ")
            + QuickInfoHelper.toString(this, "regressionColumns", this.m_RegressionColumns, ", reg cols: ");
    }

    /**
     * This criterion does not need a flow context.
     *
     * @return always false
     */
    @Override
    public boolean requiresFlowContext() {
        return false;
    }

    /**
     * Checks whether the monitored statistic satisfies the threshold check.
     *
     * <p>Returns {@code false} when the container holds no {@code "Evaluation"}
     * value. Any exception raised while obtaining or comparing the statistic
     * (including an unsupported evaluation class) is logged at SEVERE level and
     * results in {@code false} rather than being propagated.
     *
     * @param dL4JModelContainer the container to read the evaluation from
     * @param messageCollection  not used by this criterion
     * @return true if the threshold check holds for the current statistic value
     */
    @Override
    protected boolean doCheckStopping(DL4JModelContainer dL4JModelContainer, MessageCollection messageCollection) {
        if (!dL4JModelContainer.hasValue("Evaluation")) {
            return false;
        }
        Object value = dL4JModelContainer.getValue("Evaluation");
        if (value == null) {
            // Treat a null evaluation like a missing one; this also keeps the
            // value.getClass() calls below (incl. the catch handler) NPE-free.
            return false;
        }

        // Kept outside the try so the value is still available for the info log.
        double current = Double.NaN;
        boolean stop;
        try {
            current = obtainStatistic(value);
            stop = meetsThreshold(current);
        } catch (Exception e) {
            getLogger().log(Level.SEVERE, "Failed to obtain statistic " + this.m_Statistic + " from " + value.getClass().getName() + " object!", e);
            stop = false;
        }

        if (stop && isLoggingEnabled()) {
            getLogger().info(this.m_Statistic + ": " + current + " " + this.m_ThresholdCheck + " " + this.m_Threshold);
        }
        return stop;
    }

    /**
     * Extracts the monitored statistic from a non-null evaluation object.
     *
     * @param value the evaluation object
     * @return the statistic value as returned by {@link EvaluationHelper}
     * @throws IllegalStateException if the evaluation class is not supported
     */
    private double obtainStatistic(Object value) {
        if (value instanceof Evaluation) {
            Evaluation classification = (Evaluation) value;
            // The index needs its upper bound before it can be resolved.
            this.m_ClassIndex.setMax(classification.falseNegatives().size());
            return EvaluationHelper.getValue(classification, this.m_Statistic, this.m_ClassIndex.getIntIndex());
        }
        if (value instanceof RegressionEvaluation) {
            RegressionEvaluation regression = (RegressionEvaluation) value;
            this.m_RegressionColumns.setMax(regression.numColumns());
            return EvaluationHelper.getValue(regression, this.m_Statistic, this.m_RegressionColumns.getIntIndex());
        }
        throw new IllegalStateException("Unhandled evaluation class: " + value.getClass().getName());
    }

    /**
     * Applies the configured threshold check to a statistic value. A NaN value
     * never satisfies any of the checks.
     *
     * @param current the statistic value
     * @return true if {@code current <check> threshold} holds
     * @throws IllegalStateException if the threshold check type is not handled
     */
    private boolean meetsThreshold(double current) {
        switch (this.m_ThresholdCheck) {
            case GREATER:
                return current > this.m_Threshold;
            case GREATER_OR_EQUAL:
                return current >= this.m_Threshold;
            case LESS:
                return current < this.m_Threshold;
            case LESS_OR_EQUAL:
                return current <= this.m_Threshold;
            default:
                throw new IllegalStateException("Unhandled threshold check: " + this.m_ThresholdCheck);
        }
    }
}
