package weka.filters.supervised.instance;

import adams.core.Performance;
import adams.core.option.OptionHandler;
import adams.data.spreadsheet.SpreadSheet;
import adams.data.spreadsheet.SpreadSheetColumnIndex;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.flow.control.removeoutliers.AbstractOutlierDetector;
import adams.flow.control.removeoutliers.Null;
import adams.flow.core.Token;
import adams.flow.transformer.WekaPredictionsToSpreadSheet;
import adams.multiprocess.JobList;
import adams.multiprocess.LocalJobRunner;
import adams.multiprocess.WekaCrossValidationJob;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.Vector;
import weka.classifiers.AggregateEvaluations;
import weka.classifiers.Classifier;
import weka.classifiers.CrossValidationHelper;
import weka.classifiers.DefaultCrossValidationFoldGenerator;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.LinearRegressionJ;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Randomizable;
import weka.core.WekaOptionUtils;
import weka.filters.SimpleBatchFilter;

/* loaded from: input_file:weka/filters/supervised/instance/RemoveOutliers.class */
/**
 * Batch filter that cross-validates the configured classifier on the incoming
 * data, hands the resulting actual-vs-predicted spreadsheet to the configured
 * outlier detector and outputs only the instances that were not flagged.
 * <p>
 * Only works on the full dataset, not instance by instance.
 * <p>
 * Note: the seed is only accessible via {@link #setSeed(int)} and
 * {@link #getSeed()}; it is not part of the options handled by
 * {@link #listOptions()}, {@link #setOptions(String[])} and
 * {@link #getOptions()}.
 */
public class RemoveOutliers extends SimpleBatchFilter implements Randomizable {
    private static final long serialVersionUID = -8292965351930853084L;

    /** Option name for the classifier. */
    public static final String CLASSIFIER = "classifier";

    /** Option name for the number of cross-validation folds. */
    public static final String NUM_FOLDS = "num-folds";

    /** Option name for the number of threads. */
    public static final String NUM_THREADS = "num-threads";

    /** Option name for the outlier detector. */
    public static final String DETECTOR = "detector";

    /** The classifier used for generating the actual vs predicted data. */
    protected Classifier m_Classifier = getDefaultClassifier();

    /** The seed value for the cross-validation. */
    protected int m_Seed = getDefaultSeed();

    /** The number of folds; -1 is replaced by the number of instances in {@link #process(Instances)}. */
    protected int m_NumFolds = getDefaultNumFolds();

    /** The outlier detector applied to the actual vs predicted data. */
    protected AbstractOutlierDetector m_Detector = getDefaultDetector();

    /** The number of threads to use for cross-validation. */
    protected int m_NumThreads = getDefaultNumThreads();

    /**
     * Returns a short description of the filter.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Cross-validates the specified classifier on the incoming data and applies the outlier detector to the actual vs predicted data to remove the outliers.\nNB: only works on full dataset, not instance by instance.";
    }

    /**
     * Returns the default classifier.
     *
     * @return a new {@link LinearRegressionJ} instance
     */
    protected Classifier getDefaultClassifier() {
        return new LinearRegressionJ();
    }

    /**
     * Sets the classifier to cross-validate and resets the filter.
     *
     * @param classifier the classifier to use
     */
    public void setClassifier(Classifier classifier) {
        this.m_Classifier = classifier;
        reset();
    }

    /**
     * Returns the classifier that gets cross-validated.
     *
     * @return the classifier
     */
    public Classifier getClassifier() {
        return this.m_Classifier;
    }

    /**
     * Returns the tip text for the classifier property.
     *
     * @return the tip text
     */
    public String classifierTipText() {
        return "The classifier to use for generating the actual vs predicted data.";
    }

    /**
     * Returns the default seed.
     *
     * @return 1
     */
    protected int getDefaultSeed() {
        return 1;
    }

    /**
     * Sets the seed for the cross-validation and resets the filter.
     *
     * @param i the seed value
     */
    public void setSeed(int i) {
        this.m_Seed = i;
        reset();
    }

    /**
     * Returns the seed for the cross-validation.
     *
     * @return the seed value
     */
    public int getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for the seed property.
     *
     * @return the tip text
     */
    public String seedTipText() {
        return "The seed value for the cross-validation.";
    }

    /**
     * Returns the default number of folds.
     *
     * @return 10
     */
    protected int getDefaultNumFolds() {
        return 10;
    }

    /**
     * Sets the number of cross-validation folds and resets the filter.
     *
     * @param i the number of folds; -1 makes {@link #process(Instances)} use
     *          the number of instances instead
     */
    public void setNumFolds(int i) {
        this.m_NumFolds = i;
        reset();
    }

    /**
     * Returns the number of cross-validation folds.
     *
     * @return the number of folds
     */
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    /**
     * Returns the tip text for the number of folds property.
     *
     * @return the tip text
     */
    public String numFoldsTipText() {
        return "The number of folds to use in the cross-validation.";
    }

    /**
     * Returns the default number of threads.
     *
     * @return 1
     */
    protected int getDefaultNumThreads() {
        return 1;
    }

    /**
     * Sets the number of threads for cross-validation and resets the filter.
     *
     * @param i the number of threads
     */
    public void setNumThreads(int i) {
        this.m_NumThreads = i;
        reset();
    }

    /**
     * Returns the number of threads for cross-validation.
     *
     * @return the number of threads
     */
    public int getNumThreads() {
        return this.m_NumThreads;
    }

    /**
     * Returns the tip text for the number of threads property.
     *
     * @return the tip text
     */
    public String numThreadsTipText() {
        return "The number of threads to use for cross-validation; -1 = number of CPUs/cores; 0 or 1 = sequential execution.";
    }

    /**
     * Returns the default outlier detector.
     *
     * @return a new {@link Null} detector instance
     */
    protected AbstractOutlierDetector getDefaultDetector() {
        return new Null();
    }

    /**
     * Sets the outlier detector and resets the filter.
     *
     * @param abstractOutlierDetector the detector to use
     */
    public void setDetector(AbstractOutlierDetector abstractOutlierDetector) {
        this.m_Detector = abstractOutlierDetector;
        reset();
    }

    /**
     * Returns the outlier detector.
     *
     * @return the detector
     */
    public AbstractOutlierDetector getDetector() {
        return this.m_Detector;
    }

    /**
     * Returns the tip text for the detector property.
     *
     * @return the tip text
     */
    public String detectorTipText() {
        return "The outlier detector to use.";
    }

    /**
     * Lists the options of this filter followed by the options of the super class.
     *
     * @return the enumeration of options
     */
    public Enumeration listOptions() {
        Vector options = new Vector();
        WekaOptionUtils.addOption(options, classifierTipText(), "" + getDefaultClassifier(), CLASSIFIER);
        WekaOptionUtils.addOption(options, numFoldsTipText(), "" + getDefaultNumFolds(), NUM_FOLDS);
        WekaOptionUtils.addOption(options, numThreadsTipText(), "" + getDefaultNumThreads(), NUM_THREADS);
        // the detector takes an argument (see setOptions/getOptions), hence it
        // must be listed as an option with a default rather than as a flag
        WekaOptionUtils.addOption(options, detectorTipText(), getDefaultDetector().getClass().getName(), DETECTOR);
        WekaOptionUtils.add(options, super.listOptions());
        return WekaOptionUtils.toEnumeration(options);
    }

    /**
     * Parses classifier, number of folds, number of threads and detector from
     * the options and passes the options on to the super class afterwards.
     *
     * @param strArr the options to parse
     * @throws Exception if parsing fails
     */
    public void setOptions(String[] strArr) throws Exception {
        setClassifier((Classifier) WekaOptionUtils.parse(strArr, CLASSIFIER, getDefaultClassifier()));
        setNumFolds(WekaOptionUtils.parse(strArr, NUM_FOLDS, getDefaultNumFolds()));
        setNumThreads(WekaOptionUtils.parse(strArr, NUM_THREADS, getDefaultNumThreads()));
        setDetector((AbstractOutlierDetector) WekaOptionUtils.parse(strArr, DETECTOR, (OptionHandler) getDefaultDetector()));
        super.setOptions(strArr);
    }

    /**
     * Returns the current classifier, number of folds, number of threads and
     * detector as options, followed by the options of the super class.
     *
     * @return the options
     */
    public String[] getOptions() {
        List<String> options = new ArrayList<String>();
        WekaOptionUtils.add(options, CLASSIFIER, getClassifier());
        WekaOptionUtils.add(options, NUM_FOLDS, getNumFolds());
        WekaOptionUtils.add(options, NUM_THREADS, getNumThreads());
        WekaOptionUtils.add(options, DETECTOR, (OptionHandler) getDetector());
        WekaOptionUtils.add(options, super.getOptions());
        return WekaOptionUtils.toArray(options);
    }

    /**
     * Returns the capabilities of the configured classifier, with this filter
     * set as owner. Falls back to the capabilities of the super class if no
     * classifier is set.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Classifier classifier = getClassifier();
        if (classifier == null) {
            return super.getCapabilities();
        }
        // return the classifier's capabilities instead of discarding them
        Capabilities result = classifier.getCapabilities();
        result.setOwner(this);
        return result;
    }

    /**
     * Determines the output format: an empty dataset with the same header as
     * the input.
     *
     * @param instances the input format
     * @return the output format
     * @throws Exception declared by the super class
     */
    protected Instances determineOutputFormat(Instances instances) throws Exception {
        return new Instances(instances, 0);
    }

    /**
     * Cross-validates the classifier on the data, keeping the predictions.
     * If {@link Performance#determineNumThreads(int)} yields 1 for the
     * configured number of threads, Weka's own cross-validation is used;
     * otherwise one job per fold is executed through a {@link LocalJobRunner}
     * and the per-fold evaluations are aggregated.
     *
     * @param instances the data to cross-validate on
     * @param i the number of folds to use
     * @return the evaluation holding the predictions
     * @throws Exception if a fold fails to evaluate or the aggregation fails
     */
    protected Evaluation crossValidate(Instances instances, int i) throws Exception {
        if (Performance.determineNumThreads(this.m_NumThreads) == 1) {
            Evaluation evaluation = new Evaluation(instances);
            evaluation.setDiscardPredictions(false);
            evaluation.crossValidateModel(this.m_Classifier, instances, i, new Random(this.m_Seed), new Object[0]);
            return evaluation;
        }

        // one job per train/test pair
        DefaultCrossValidationFoldGenerator generator = new DefaultCrossValidationFoldGenerator(instances, i, this.m_Seed, true);
        LocalJobRunner runner = new LocalJobRunner();
        runner.setNumThreads(this.m_NumThreads);
        JobList jobs = new JobList();
        while (generator.hasNext()) {
            WekaTrainTestSetContainer cont = generator.next();
            jobs.add(new WekaCrossValidationJob(this.m_Classifier, (Instances) cont.getValue("Train"), (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST), ((Integer) cont.getValue(WekaTrainTestSetContainer.VALUE_FOLD_NUMBER)).intValue(), false));
        }
        runner.add(jobs);
        runner.start();
        runner.stop();

        // collect the per-fold evaluations, stopping at the first failed fold
        String error = null;
        AggregateEvaluations aggregate = new AggregateEvaluations();
        for (int n = 0; n < runner.getJobs().size(); n++) {
            WekaCrossValidationJob job = (WekaCrossValidationJob) runner.getJobs().get(n);
            if (job.getEvaluation() == null) {
                error = "Fold #" + (n + 1) + " failed to evaluate";
                if (job.hasExecutionError()) {
                    error = error + ":\n" + job.getExecutionError();
                } else {
                    error = error + "?";
                }
                // without leaving the loop here, the same job would be inspected forever
                break;
            }
            aggregate.add(job.getEvaluation());
            job.cleanUp();
        }

        // release jobs and runner on the failure path as well
        jobs.cleanUp();
        runner.cleanUp();
        if (error != null) {
            throw new Exception(error);
        }

        Evaluation aggregated = aggregate.aggregated();
        if (aggregated == null) {
            throw new IllegalStateException(aggregate.hasLastError() ? aggregate.getLastError() : "Failed to aggregate evaluations!");
        }
        return aggregated;
    }

    /**
     * Converts the predictions of the evaluation into a spreadsheet using a
     * {@link WekaPredictionsToSpreadSheet} transformer.
     *
     * @param evaluation the evaluation with the predictions
     * @return the payload of the token generated by the transformer
     * @throws Exception if setting up or executing the transformer fails or
     *                   no output is generated
     */
    protected SpreadSheet evaluationToSpreadSheet(Evaluation evaluation) throws Exception {
        WekaPredictionsToSpreadSheet converter = new WekaPredictionsToSpreadSheet();
        String msg = converter.setUp();
        if (msg != null) {
            throw new Exception("Failed to convert predictions to spreadsheet (setUp): " + msg);
        }
        converter.input(new Token(evaluation));
        msg = converter.execute();
        if (msg != null) {
            throw new Exception("Failed to convert predictions to spreadsheet (execute): " + msg);
        }
        if (!converter.hasPendingOutput()) {
            throw new Exception("No output data generated from predictions!");
        }
        return (SpreadSheet) converter.output().getPayload();
    }

    /**
     * Cross-validates the classifier on the data, applies the outlier detector
     * to the "Actual" and "Predicted" columns of the predictions and returns
     * the instances that were not flagged as outliers.
     * <p>
     * The detected indices are matched against the positions in the array
     * returned by {@link CrossValidationHelper#crossValidationIndices(Instances, int, Random)};
     * the instances are added to the output in the order of that array.
     *
     * @param instances the data to process
     * @return the data without the outliers
     * @throws Exception if cross-validation, conversion or detection fails;
     *                   the original exception is attached as cause
     */
    protected Instances process(Instances instances) throws Exception {
        int numFolds = this.m_NumFolds;
        if (numFolds == -1) {
            // -1: as many folds as there are instances
            numFolds = instances.numInstances();
        }
        try {
            SpreadSheet sheet = evaluationToSpreadSheet(crossValidate(instances, numFolds));
            if (sheet == null) {
                // fail with a clear message rather than handing null back to the caller
                throw new IllegalStateException("No spreadsheet generated from predictions!");
            }
            Set<Integer> outliers = this.m_Detector.detect(sheet, new SpreadSheetColumnIndex("Actual"), new SpreadSheetColumnIndex("Predicted"));
            if (outliers == null) {
                throw new Exception("Failed to detect outliers!");
            }
            if (getDebug()) {
                List<Integer> sorted = new ArrayList<Integer>(outliers);
                Collections.sort(sorted);
                System.err.println(getClass().getName() + ": Outliers (0-based index): " + sorted);
            }
            int[] indices = CrossValidationHelper.crossValidationIndices(instances, numFolds, new Random(this.m_Seed));
            Instances result = new Instances(instances, instances.numInstances() - outliers.size());
            for (int n = 0; n < indices.length; n++) {
                if (outliers.contains(Integer.valueOf(n))) {
                    continue;
                }
                result.add((Instance) instances.instance(indices[n]).copy());
            }
            return result;
        } catch (Exception e) {
            throw new Exception("Failed to cross-validate!", e);
        }
    }
}
