/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.supervised.instance;

import adams.core.Performance;
import adams.core.option.OptionHandler;
import adams.data.spreadsheet.SpreadSheet;
import adams.data.spreadsheet.SpreadSheetColumnIndex;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.flow.control.removeoutliers.AbstractOutlierDetector;
import adams.flow.control.removeoutliers.Null;
import adams.flow.core.Token;
import adams.flow.transformer.WekaPredictionsToSpreadSheet;
import adams.multiprocess.Job;
import adams.multiprocess.JobList;
import adams.multiprocess.LocalJobRunner;
import adams.multiprocess.WekaCrossValidationJob;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Set;
import java.util.Vector;
import weka.classifiers.AggregateEvaluations;
import weka.classifiers.Classifier;
import weka.classifiers.CrossValidationHelper;
import weka.classifiers.DefaultCrossValidationFoldGenerator;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegressionJ;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Randomizable;
import weka.core.WekaOptionUtils;
import weka.filters.SimpleBatchFilter;

/**
 * Batch filter that cross-validates a classifier on the incoming data and
 * hands the resulting actual-vs-predicted table to an outlier detector; the
 * instances flagged by the detector are dropped from the output.
 *
 * Only works on a full dataset, not instance by instance.
 */
public class RemoveOutliers
extends SimpleBatchFilter
implements Randomizable {
    private static final long serialVersionUID = -8292965351930853084L;

    /** Option name for the classifier. */
    public static final String CLASSIFIER = "classifier";

    /** Option name for the number of cross-validation folds. */
    public static final String NUM_FOLDS = "num-folds";

    /** Option name for the number of threads. */
    public static final String NUM_THREADS = "num-threads";

    /** Option name for the outlier detector. */
    public static final String DETECTOR = "detector";

    /** The classifier that gets cross-validated. */
    protected Classifier m_Classifier = this.getDefaultClassifier();

    /** The seed for the cross-validation (not exposed via the option methods of this class). */
    protected int m_Seed = this.getDefaultSeed();

    /** The number of folds; -1 means one fold per instance (see {@link #process(Instances)}). */
    protected int m_NumFolds = this.getDefaultNumFolds();

    /** The detector applied to the actual-vs-predicted data. */
    protected AbstractOutlierDetector m_Detector = this.getDefaultDetector();

    /** The number of threads to use for the cross-validation. */
    protected int m_NumThreads = this.getDefaultNumThreads();

    /**
     * Returns a short description of this filter.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Cross-validates the specified classifier on the incoming data and applies the outlier detector to the actual vs predicted data to remove the outliers.\nNB: only works on full dataset, not instance by instance.";
    }

    /**
     * Returns the default classifier.
     *
     * @return a new {@link LinearRegressionJ} instance
     */
    protected Classifier getDefaultClassifier() {
        return new LinearRegressionJ();
    }

    /**
     * Sets the classifier to cross-validate and resets the filter.
     *
     * @param value the classifier
     */
    public void setClassifier(Classifier value) {
        this.m_Classifier = value;
        this.reset();
    }

    /**
     * Returns the classifier to cross-validate.
     *
     * @return the classifier
     */
    public Classifier getClassifier() {
        return this.m_Classifier;
    }

    /**
     * Returns the tip text for the classifier property.
     *
     * @return the tip text
     */
    public String classifierTipText() {
        return "The classifier to use for generating the actual vs predicted data.";
    }

    /**
     * Returns the default seed.
     *
     * @return 1
     */
    protected int getDefaultSeed() {
        return 1;
    }

    /**
     * Sets the seed for the cross-validation and resets the filter.
     *
     * @param value the seed
     */
    public void setSeed(int value) {
        this.m_Seed = value;
        this.reset();
    }

    /**
     * Returns the seed for the cross-validation.
     *
     * @return the seed
     */
    public int getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for the seed property.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "The seed value for the cross-validation.";
    }

    /**
     * Returns the default number of folds.
     *
     * @return 10
     */
    protected int getDefaultNumFolds() {
        return 10;
    }

    /**
     * Sets the number of folds and resets the filter.
     *
     * @param value the number of folds; -1 is replaced by the number of
     *              instances when the data is processed
     */
    public void setNumFolds(int value) {
        this.m_NumFolds = value;
        this.reset();
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds
     */
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    /**
     * Returns the tip text for the number-of-folds property.
     *
     * @return the tip text
     */
    public String numFoldsTipText() {
        return "The number of folds to use in the cross-validation.";
    }

    /**
     * Returns the default number of threads.
     *
     * @return 1
     */
    protected int getDefaultNumThreads() {
        return 1;
    }

    /**
     * Sets the number of threads and resets the filter.
     *
     * @param value the number of threads
     */
    public void setNumThreads(int value) {
        this.m_NumThreads = value;
        this.reset();
    }

    /**
     * Returns the number of threads.
     *
     * @return the number of threads
     */
    public int getNumThreads() {
        return this.m_NumThreads;
    }

    /**
     * Returns the tip text for the number-of-threads property.
     *
     * @return the tip text
     */
    public String numThreadsTipText() {
        return "The number of threads to use for cross-validation; -1 = number of CPUs/cores; 0 or 1 = sequential execution.";
    }

    /**
     * Returns the default outlier detector.
     *
     * @return a new {@link Null} detector
     */
    protected AbstractOutlierDetector getDefaultDetector() {
        return new Null();
    }

    /**
     * Sets the outlier detector and resets the filter.
     *
     * @param value the detector
     */
    public void setDetector(AbstractOutlierDetector value) {
        this.m_Detector = value;
        this.reset();
    }

    /**
     * Returns the outlier detector.
     *
     * @return the detector
     */
    public AbstractOutlierDetector getDetector() {
        return this.m_Detector;
    }

    /**
     * Returns the tip text for the detector property.
     *
     * @return the tip text
     */
    public String detectorTipText() {
        return "The outlier detector to use.";
    }

    /**
     * Lists the options of this filter followed by those of the superclass.
     *
     * @return the enumeration of all options
     */
    public Enumeration listOptions() {
        Vector result = new Vector();
        WekaOptionUtils.addOption(result, this.classifierTipText(), "" + this.getDefaultClassifier(), CLASSIFIER);
        WekaOptionUtils.addOption(result, this.numFoldsTipText(), "" + this.getDefaultNumFolds(), NUM_FOLDS);
        WekaOptionUtils.addOption(result, this.numThreadsTipText(), "" + this.getDefaultNumThreads(), NUM_THREADS);
        // the detector takes an argument (see setOptions/getOptions), so it
        // has to be listed as an option with a default rather than as a flag
        WekaOptionUtils.addOption(result, this.detectorTipText(), "" + this.getDefaultDetector(), DETECTOR);
        WekaOptionUtils.add(result, super.listOptions());
        return WekaOptionUtils.toEnumeration(result);
    }

    /**
     * Parses the options for classifier, folds, threads and detector and
     * passes the options array on to the superclass afterwards. Options
     * that are not supplied fall back to the respective default.
     *
     * @param options the options to parse
     * @throws Exception if parsing fails
     */
    public void setOptions(String[] options) throws Exception {
        this.setClassifier((Classifier)WekaOptionUtils.parse(options, CLASSIFIER, (weka.core.OptionHandler)this.getDefaultClassifier()));
        this.setNumFolds(WekaOptionUtils.parse(options, NUM_FOLDS, this.getDefaultNumFolds()));
        this.setNumThreads(WekaOptionUtils.parse(options, NUM_THREADS, this.getDefaultNumThreads()));
        this.setDetector((AbstractOutlierDetector)WekaOptionUtils.parse(options, DETECTOR, (OptionHandler)this.getDefaultDetector()));
        super.setOptions(options);
    }

    /**
     * Returns the current option settings, including the ones of the
     * superclass.
     *
     * @return the options as array
     */
    public String[] getOptions() {
        ArrayList<String> result = new ArrayList<String>();
        WekaOptionUtils.add(result, CLASSIFIER, (weka.core.OptionHandler)this.getClassifier());
        WekaOptionUtils.add(result, NUM_FOLDS, this.getNumFolds());
        WekaOptionUtils.add(result, NUM_THREADS, this.getNumThreads());
        WekaOptionUtils.add(result, DETECTOR, (OptionHandler)this.getDetector());
        WekaOptionUtils.add(result, super.getOptions());
        return WekaOptionUtils.toArray(result);
    }

    /**
     * Returns the capabilities of the configured classifier, with this
     * filter set as their owner.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities result = this.getClassifier().getCapabilities();
        result.setOwner((CapabilitiesHandler)this);
        // return the classifier-derived capabilities; returning
        // super.getCapabilities() would silently discard them
        return result;
    }

    /**
     * Determines the output format: an empty dataset with the same header
     * as the input.
     *
     * @param inputFormat the input format
     * @return the output format
     * @throws Exception declared by the superclass contract
     */
    protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
        return new Instances(inputFormat, 0);
    }

    /**
     * Cross-validates the classifier on the data. Runs Weka's built-in
     * cross-validation when a single thread is determined, otherwise
     * evaluates one job per fold via a {@link LocalJobRunner} and aggregates
     * the per-fold evaluations.
     *
     * @param data the data to cross-validate on
     * @param folds the number of folds
     * @return the evaluation, with predictions retained in the single-threaded case
     * @throws Exception if a fold fails to evaluate
     * @throws IllegalStateException if the fold evaluations cannot be aggregated
     */
    protected Evaluation crossValidate(Instances data, int folds) throws Exception {
        int numThreads = Performance.determineNumThreads((int)this.m_NumThreads);
        if (numThreads == 1) {
            Evaluation eval = new Evaluation(data);
            // predictions are required for the actual-vs-predicted sheet
            eval.setDiscardPredictions(false);
            eval.crossValidateModel(this.m_Classifier, data, folds, new Random(this.m_Seed), new Object[0]);
            return eval;
        }

        // one job per train/test pair
        DefaultCrossValidationFoldGenerator generator = new DefaultCrossValidationFoldGenerator(data, folds, this.m_Seed, true);
        LocalJobRunner jobRunner = new LocalJobRunner();
        jobRunner.setNumThreads(this.m_NumThreads);
        JobList list = new JobList();
        while (generator.hasNext()) {
            WekaTrainTestSetContainer cont = generator.next();
            WekaCrossValidationJob job = new WekaCrossValidationJob(this.m_Classifier, (Instances)cont.getValue("Train"), (Instances)cont.getValue("Test"), (Integer)cont.getValue("FoldNumber"), false);
            list.add((Job)job);
        }

        String msg = null;
        AggregateEvaluations evalAgg = new AggregateEvaluations();
        try {
            jobRunner.add(list);
            jobRunner.start();
            jobRunner.stop();

            // collect the per-fold evaluations, stopping at the first failure
            for (int i = 0; i < jobRunner.getJobs().size(); ++i) {
                WekaCrossValidationJob job = (WekaCrossValidationJob)((Object)jobRunner.getJobs().get(i));
                if (job.getEvaluation() == null) {
                    msg = "Fold #" + (i + 1) + " failed to evaluate";
                    if (job.hasExecutionError()) {
                        msg = msg + ":\n" + job.getExecutionError();
                    } else {
                        msg = msg + "?";
                    }
                    break;
                }
                evalAgg.add(job.getEvaluation());
                job.cleanUp();
            }
        }
        finally {
            // release jobs and runner on the error path as well
            list.cleanUp();
            jobRunner.cleanUp();
        }
        if (msg != null) {
            throw new Exception(msg);
        }

        Evaluation result = evalAgg.aggregated();
        // a null result signals that the aggregation failed
        if (result == null) {
            throw new IllegalStateException(evalAgg.hasLastError() ? evalAgg.getLastError() : "Failed to aggregate evaluations!");
        }
        return result;
    }

    /**
     * Converts the predictions stored in the evaluation into a spreadsheet
     * by running a {@link WekaPredictionsToSpreadSheet} transformer on it.
     *
     * @param eval the evaluation to convert
     * @return the payload of the transformer's output token
     * @throws Exception if the transformer fails to set up or execute, or
     *                   does not produce any output
     */
    protected SpreadSheet evaluationToSpreadSheet(Evaluation eval) throws Exception {
        WekaPredictionsToSpreadSheet conv = new WekaPredictionsToSpreadSheet();
        String msg = conv.setUp();
        if (msg != null) {
            throw new Exception("Failed to convert predictions to spreadsheet (setUp): " + msg);
        }
        conv.input(new Token((Object)eval));
        msg = conv.execute();
        if (msg != null) {
            throw new Exception("Failed to convert predictions to spreadsheet (execute): " + msg);
        }
        if (!conv.hasPendingOutput()) {
            throw new Exception("No output data generated from predictions!");
        }
        Token token = conv.output();
        return (SpreadSheet)token.getPayload();
    }

    /**
     * Cross-validates the classifier on the data, lets the detector flag
     * outliers in the actual-vs-predicted sheet and returns a copy of the
     * data without the flagged instances.
     *
     * @param data the data to process
     * @return the data without the outliers, or null if no sheet could be
     *         obtained from the evaluation
     * @throws Exception if the cross-validation, the sheet conversion or the
     *                   outlier detection fails
     */
    protected Instances process(Instances data) throws Exception {
        // -1 folds: one fold per instance
        int folds = this.m_NumFolds;
        if (folds == -1) {
            folds = data.numInstances();
        }

        Evaluation eval;
        try {
            eval = this.crossValidate(data, folds);
        }
        catch (Exception e) {
            throw new Exception("Failed to cross-validate!", e);
        }

        SpreadSheet sheet = this.evaluationToSpreadSheet(eval);
        if (sheet == null) {
            return null;
        }

        Set outliers = this.m_Detector.detect(sheet, new SpreadSheetColumnIndex("Actual"), new SpreadSheetColumnIndex("Predicted"));
        if (outliers == null) {
            throw new Exception("Failed to detect outliers!");
        }
        if (this.getDebug()) {
            ArrayList sorted = new ArrayList(outliers);
            Collections.sort(sorted);
            System.err.println(this.getClass().getName() + ": Outliers (0-based index): " + sorted);
        }

        // NOTE(review): assumes the detector reports positions in the
        // prediction order and that crossValidationIndices maps position i
        // back to the index of the original instance -- confirm against
        // CrossValidationHelper.
        int[] indices = CrossValidationHelper.crossValidationIndices(data, folds, new Random(this.m_Seed));
        Instances result = new Instances(data, data.numInstances() - outliers.size());
        for (int i = 0; i < indices.length; ++i) {
            if (outliers.contains(i)) {
                continue;
            }
            result.add((Instance)data.instance(indices[i]).copy());
        }
        return result;
    }
}

