package adams.flow.transformer.wekaevaluationpostprocessor;

import adams.core.QuickInfoHelper;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.List;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.Prediction;

/* loaded from: input_file:adams/flow/transformer/wekaevaluationpostprocessor/RemoveWorst.class */
/**
 * Evaluation post-processor that removes the worst predictions, i.e., those with the largest
 * absolute error |actual - predicted|, treating them as outliers that detract from the actual
 * model performance. The fraction of predictions to remove is set via the "percent" option.
 * Only meaningful for numeric class attributes.
 */
public class RemoveWorst extends AbstractNumericClassPostProcessor {
    private static final long serialVersionUID = -8126062783012759418L;

    /** The fraction (0-1) of predictions with the largest absolute error to remove. */
    protected double m_Percent;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Removes the worst predictions, which are considered outliers that detract from the actual model performance.\nOnly works on numeric predictions.";
    }

    /**
     * Adds the "percent" option: default 0.01, with 0.0 and 1.0 passed to the option manager
     * as its bounds.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("percent", "percent", Double.valueOf(0.01d), Double.valueOf(0.0d), Double.valueOf(1.0d));
    }

    /**
     * Sets the fraction of worst predictions to remove. Values rejected by the option
     * manager's validity check are ignored and leave the current setting unchanged.
     *
     * @param d the fraction to remove (0-1)
     */
    public void setPercent(double d) {
        if (getOptionManager().isValid("percent", Double.valueOf(d))) {
            this.m_Percent = d;
            reset();
        }
    }

    /**
     * Returns the fraction of worst predictions to remove.
     *
     * @return the fraction (0-1)
     */
    public double getPercent() {
        return this.m_Percent;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI or for listing the
     *     options
     */
    public String percentTipText() {
        return "The percentage of worst predictions to remove (0-1).";
    }

    /**
     * Returns a quick info about the object, which can be displayed in the GUI.
     *
     * @return the info
     */
    @Override // adams.flow.transformer.wekaevaluationpostprocessor.AbstractWekaEvaluationPostProcessor
    public String getQuickInfo() {
        return QuickInfoHelper.toString(this, "percent", Double.valueOf(this.m_Percent), "percent: ");
    }

    /**
     * Removes the worst predictions from the evaluation.
     * <p>
     * Predictions are ranked by absolute error (|actual - predicted|) and the
     * {@code round((1 - percent) * n)} best ones are kept, the rest are dropped. Ranking uses a
     * stable sort, so ties are broken by original prediction order and exactly that many
     * predictions are retained even when several share the same error. Retained predictions
     * keep their original relative order. An empty prediction list yields an evaluation built
     * from an empty index list.
     *
     * @param evaluation the evaluation to process
     * @return a list containing the single evaluation built from the retained predictions
     */
    @Override // adams.flow.transformer.wekaevaluationpostprocessor.AbstractWekaEvaluationPostProcessor
    protected List<Evaluation> doPostProcess(Evaluation evaluation) {
        List<Evaluation> result = new ArrayList<>();
        List<Prediction> predictions = evaluation.predictions();
        int size = predictions.size();

        // number of predictions to retain; clamp in case m_Percent is outside 0-1
        int keep = (int) Math.round((1.0d - this.m_Percent) * size);
        keep = Math.max(0, Math.min(keep, size));

        // absolute error per prediction (computed once, reused by the sort below)
        double[] errors = new double[size];
        List<Integer> order = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            Prediction prediction = predictions.get(i);
            errors[i] = Math.abs(prediction.actual() - prediction.predicted());
            order.add(i);
        }
        // ascending by error; List.sort is stable, so ties keep their original order.
        // Double.compare ranks NaN above every number, so NaN errors are removed first.
        order.sort((a, b) -> Double.compare(errors[a], errors[b]));

        if (isLoggingEnabled()) {
            if (keep < size) {
                getLogger().info("threshold: " + errors[order.get(keep)]);
            }
            getLogger().info("keeping " + keep + " of " + size + " predictions");
        }

        // flag the retained indices, then collect them in original order
        boolean[] retained = new boolean[size];
        for (int i = 0; i < keep; i++) {
            retained[order.get(i)] = true;
        }
        TIntArrayList indices = new TIntArrayList();
        for (int i = 0; i < size; i++) {
            if (retained[i]) {
                indices.add(i);
            }
        }

        result.add(newEvaluation("-removed_worst_" + this.m_Percent, evaluation, indices));
        return result;
    }
}
