/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer.wekaevaluationpostprocessor;

import adams.core.QuickInfoHelper;
import adams.core.option.OptionHandler;
import adams.flow.transformer.wekaevaluationpostprocessor.AbsolutePredictionErrorComparator;
import adams.flow.transformer.wekaevaluationpostprocessor.AbstractNumericClassPostProcessor;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.List;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.Prediction;

/**
 * Evaluation post-processor that removes the worst predictions (those with the
 * largest absolute error between actual and predicted value), which are
 * considered outliers that detract from the actual model performance.
 * Only works on numeric predictions.
 */
public class RemoveWorst
extends AbstractNumericClassPostProcessor {
    private static final long serialVersionUID = -8126062783012759418L;

    /** the fraction (0-1) of worst predictions to remove. */
    protected double m_Percent;

    /**
     * Returns a string describing the object.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Removes the worst predictions, which are considered outliers that detract from the actual model performance.\nOnly works on numeric predictions.";
    }

    /**
     * Adds the "percent" option: default 0.01, allowed range 0.0 to 1.0.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("percent", "percent", (Object) 0.01, (Number) 0.0, (Number) 1.0);
    }

    /**
     * Sets the fraction of worst predictions to remove. The value is only
     * accepted (and the post-processor reset) if the option manager considers
     * it valid for the "percent" option.
     *
     * @param value the fraction (0-1)
     */
    public void setPercent(double value) {
        if (this.getOptionManager().isValid("percent", (Number) value)) {
            this.m_Percent = value;
            this.reset();
        }
    }

    /**
     * Returns the fraction of worst predictions to remove.
     *
     * @return the fraction (0-1)
     */
    public double getPercent() {
        return this.m_Percent;
    }

    /**
     * Returns the tip text for the "percent" property.
     *
     * @return the tip text
     */
    public String percentTipText() {
        return "The percentage of worst predictions to remove (0-1).";
    }

    /**
     * Returns a short summary of the current setup.
     *
     * @return the quick info, built from the "percent" option
     */
    @Override
    public String getQuickInfo() {
        return QuickInfoHelper.toString((OptionHandler) this, "percent", (Object) this.m_Percent, "percent: ");
    }

    /**
     * Returns the absolute difference between actual and predicted value.
     *
     * @param pred the prediction to inspect
     * @return the absolute error
     */
    private static double absoluteError(Prediction pred) {
        return Math.abs(pred.actual() - pred.predicted());
    }

    /**
     * Removes the worst predictions from the evaluation.
     * <p>
     * Predictions are sorted by absolute error. The error at position
     * {@code round((1 - percent) * n)} becomes the threshold, and only
     * predictions with an error strictly below that threshold are kept.
     * Predictions whose error ties with the threshold are therefore all
     * dropped, so slightly more than the requested fraction may be removed.
     * If the rounded number of predictions to remove is zero (e.g. percent 0,
     * or no predictions at all), every prediction is kept.
     *
     * @param eval the evaluation to post-process
     * @return a single-element list with the filtered evaluation
     */
    @Override
    protected List<Evaluation> doPostProcess(Evaluation eval) {
        List<Evaluation> result = new ArrayList<Evaluation>();
        // hoisted: the original re-fetched the list on every loop iteration
        List<Prediction> predictions = eval.predictions();
        TIntList indices = new TIntArrayList();

        List<Prediction> sorted = new ArrayList<Prediction>(predictions);
        // assumes the comparator orders by ascending absolute error -- TODO confirm
        sorted.sort(new AbsolutePredictionErrorComparator());
        int cutoff = (int) Math.round((1.0 - this.m_Percent) * (double) sorted.size());

        if (cutoff >= sorted.size()) {
            // Nothing to remove after rounding. Previously the index was clamped
            // to n-1, which still dropped the worst prediction(s) for percent == 0,
            // and crashed with get(-1) when there were no predictions.
            if (this.isLoggingEnabled()) {
                this.getLogger().info("no predictions to remove");
            }
            for (int i = 0; i < predictions.size(); ++i) {
                indices.add(i);
            }
        }
        else {
            double threshold = absoluteError(sorted.get(cutoff));
            if (this.isLoggingEnabled()) {
                this.getLogger().info("threshold: " + threshold);
            }
            for (int i = 0; i < predictions.size(); ++i) {
                if (absoluteError(predictions.get(i)) < threshold) {
                    indices.add(i);
                }
            }
        }

        result.add(this.newEvaluation("-removed_worst_" + this.m_Percent, eval, indices));
        return result;
    }
}

