/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer.wekaevaluationpostprocessor;

import adams.core.QuickInfoHelper;
import adams.core.option.OptionHandler;
import adams.data.statistics.StatUtils;
import adams.flow.transformer.wekaevaluationpostprocessor.AbstractNumericClassPostProcessor;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.List;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.Prediction;

/**
 * Post-processor for numeric-class evaluations that discards the worst
 * predictions. A prediction is treated as an outlier when its absolute error
 * |actual - predicted| is strictly larger than
 * {@code mean + multiplier * stdev}, where mean and stdev are computed over the
 * absolute errors of all predictions.
 */
public class RemoveWorstStdDev
extends AbstractNumericClassPostProcessor {
    private static final long serialVersionUID = -8126062783012759418L;

    /** The factor applied to the standard deviation of the errors when computing the outlier threshold. */
    protected double m_Multiplier;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Removes the worst predictions, which are considered outliers that detract from the actual model performance. All errors that are larger than 'mean + multiplier*stdev' are considered outliers. Mean and stdev are calculated on the absolute errors (|actual - predicted|).\nOnly works on numeric predictions.";
    }

    /**
     * Registers the "multiplier" option on top of the options of the superclass.
     * It is registered with 3.0 and 0.0 as further arguments, presumably the default
     * value and lower bound with no upper bound (null). Confirm against the option
     * manager's add(...) signature.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("multiplier", "multiplier", 3.0, 0.0, null);
    }

    /**
     * Sets the multiplier for the standard deviation. The value is only accepted
     * if the option manager considers it valid. Accepting it also resets the
     * post-processor.
     *
     * @param value the multiplier
     */
    public void setMultiplier(double value) {
        if (this.getOptionManager().isValid("multiplier", value)) {
            this.m_Multiplier = value;
            this.reset();
        }
    }

    /**
     * Returns the multiplier for the standard deviation.
     *
     * @return the multiplier
     */
    public double getMultiplier() {
        return this.m_Multiplier;
    }

    /**
     * Returns the tip text for the "multiplier" property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String multiplierTipText() {
        return "The multiplier for the standard deviation (mean + multiplier*stdev = threshold for outliers).";
    }

    /**
     * Returns a short summary of the current setup (the multiplier).
     *
     * @return the quick info
     */
    @Override
    public String getQuickInfo() {
        return QuickInfoHelper.toString(this, "multiplier", this.m_Multiplier, "multiplier: ");
    }

    /**
     * Computes the absolute error of every prediction and derives the outlier
     * threshold as {@code mean + multiplier * stdev} of those errors. It then
     * builds a new evaluation from the predictions whose error does not exceed
     * that threshold.
     *
     * @param eval the evaluation to post-process
     * @return a single-element list containing the reduced evaluation
     */
    @Override
    protected List<Evaluation> doPostProcess(Evaluation eval) {
        List<Evaluation> result = new ArrayList<>();

        // Absolute error of every prediction, in the order the evaluation stores them.
        TDoubleArrayList errors = new TDoubleArrayList();
        for (Prediction pred : eval.predictions()) {
            errors.add(Math.abs(pred.actual() - pred.predicted()));
        }

        double[] errorValues = errors.toArray();
        double mean = StatUtils.mean(errorValues);
        // The 'true' flag is passed unchanged from the original code. It presumably
        // selects the sample estimator; confirm against StatUtils.stddev.
        double stdev = StatUtils.stddev(errorValues, true);
        double threshold = mean + stdev * this.m_Multiplier;
        if (this.isLoggingEnabled()) {
            this.getLogger().info("mean: " + mean);
            this.getLogger().info("stdev: " + stdev);
            this.getLogger().info("threshold: " + threshold);
        }

        // Collect the indices of the predictions to retain. Only errors strictly
        // larger than the threshold count as outliers. Keeping '<=' matters when all
        // errors are identical (stdev == 0 -> threshold == mean): a strict '<' would
        // drop every single prediction in that case.
        TIntList indices = new TIntArrayList();
        for (int i = 0; i < errorValues.length; i++) {
            if (errorValues[i] <= threshold) {
                indices.add(i);
            }
        }

        // newEvaluation is provided by the superclass. It is assumed to build an
        // evaluation restricted to the given prediction indices.
        result.add(this.newEvaluation("-removed_worststdev_" + this.m_Multiplier, eval, indices));
        return result;
    }
}

