package adams.flow.transformer;

import adams.core.EnumWithCustomDisplay;
import adams.core.Pausable;
import adams.core.QuickInfoHelper;
import adams.core.Randomizable;
import adams.core.ThreadLimiter;
import adams.core.classmanager.ClassManager;
import adams.core.logging.LoggingHelper;
import adams.core.option.AbstractOption;
import adams.core.option.OptionUtils;
import adams.data.weka.WekaLabelIndex;
import adams.event.FlowPauseStateEvent;
import adams.event.FlowPauseStateListener;
import adams.event.JobCompleteEvent;
import adams.event.JobCompleteListener;
import adams.flow.core.ActorUtils;
import adams.flow.core.CallableActorHelper;
import adams.flow.core.CallableActorReference;
import adams.flow.core.Compatibility;
import adams.flow.core.OutputProducer;
import adams.flow.core.PauseStateHandler;
import adams.flow.core.Token;
import adams.flow.standalone.JobRunnerSetup;
import adams.multiprocess.AbstractJob;
import adams.multiprocess.JobList;
import adams.multiprocess.JobRunner;
import adams.multiprocess.LocalJobRunner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.logging.Level;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.GridSearch;
import weka.classifiers.meta.MultiSearch;
import weka.classifiers.meta.multisearch.DefaultEvaluationMetrics;
import weka.classifiers.meta.multisearch.DefaultEvaluationWrapper;
import weka.classifiers.meta.multisearch.Performance;
import weka.classifiers.meta.multisearch.PerformanceComparator;
import weka.core.Instances;
import weka.core.setupgenerator.Point;
import weka.filters.supervised.attribute.SIMPLSMatrixFilterFromGeneticString;
import weka.filters.supervised.instance.RemoveOutliers;

/* loaded from: input_file:adams/flow/transformer/WekaClassifierRanker.class */
/**
 * Ranks the Weka classifiers that arrive as an array on the input.
 * <p>
 * Each classifier is evaluated on the training data obtained from a callable
 * source actor. If the number of folds is at least 2, it is cross-validated.
 * Otherwise it is trained on the training set and evaluated on the test set
 * obtained from a second callable actor. Evaluations run as jobs through a
 * {@link JobRunner}: a {@link LocalJobRunner}, or the one produced by the
 * closest {@link JobRunnerSetup}. The top-ranked classifiers are forwarded as an
 * array; all of them if the maximum is -1.
 */
public class WekaClassifierRanker extends AbstractTransformer implements Randomizable, Pausable, FlowPauseStateListener, ThreadLimiter {
    private static final long serialVersionUID = -3019442578354930841L;

    /** the callable actor providing the training set. */
    protected CallableActorReference m_Train;

    /** the callable actor providing the test set (only used if folds &lt; 2). */
    protected CallableActorReference m_Test;

    /** the maximum number of classifiers to forward (-1 = all). */
    protected int m_Max;

    /** the seed for the cross-validation. */
    protected long m_Seed;

    /** the number of cross-validation folds (&lt; 2 = use the test set). */
    protected int m_Folds;

    /** the measure used for ranking. */
    protected Measure m_Measure;

    /** the class label index for class-specific measures. */
    protected WekaLabelIndex m_ClassLabel;

    /** whether to output the best setup of GridSearch/MultiSearch instead of the optimizer. */
    protected boolean m_OutputBestSetup;

    /** the number of threads to use for the evaluations. */
    protected int m_NumThreads;

    /** for locating callable actors. */
    protected CallableActorHelper m_Helper;

    /** the closest job runner setup, if any (located in setUp). */
    protected transient JobRunnerSetup m_JobRunnerSetup;

    /** the job runner currently/last in use (needed for pause/resume/stop). */
    protected JobRunner m_JobRunner;

    /**
     * The evaluation measures available for ranking. Each constant carries the
     * string shown to users and the integer code that is handed to Weka's
     * multisearch {@link Performance} and {@link PerformanceComparator} classes.
     */
    public enum Measure implements EnumWithCustomDisplay<Measure> {
        /** correlation coefficient. */
        CC("Correlation coefficient", 0),
        /** root mean squared error. */
        RMSE("Root mean squared error", 1),
        /** root relative squared error. */
        RRSE("Root relative squared error", 2),
        /** mean absolute error. */
        MAE("Mean absolute error", 3),
        /** relative absolute error. */
        RAE("Relative absolute error", 4),
        /** combination of CC, RRSE and RAE. */
        COMBINED("Combined: (1-abs(CC)) + RRSE + RAE", 5),
        /** accuracy. */
        ACC("Accuracy", 6),
        /** kappa statistic. */
        KAPPA("Kappa", 7);

        /** the human-readable string. */
        private final String m_Display;

        /** the raw enum name (used when storing options). */
        private final String m_Raw;

        /** the integer code passed on to the Weka multisearch classes. */
        private final int m_Measure;

        Measure(String display, int measure) {
            this.m_Display = display;
            this.m_Raw = name();
            this.m_Measure = measure;
        }

        /**
         * Returns the integer code of the measure.
         *
         * @return the code as used by {@link Performance}/{@link PerformanceComparator}
         */
        public int getMeasure() {
            return this.m_Measure;
        }

        /**
         * Returns the human-readable description.
         *
         * @return the display string
         */
        @Override
        public String toDisplay() {
            return this.m_Display;
        }

        /**
         * Returns the raw enum name.
         *
         * @return the raw string
         */
        @Override
        public String toRaw() {
            return this.m_Raw;
        }

        /**
         * Returns the display string.
         *
         * @return the display string
         */
        @Override
        public String toString() {
            return toDisplay();
        }

        /**
         * Parses either the raw name or the display string.
         *
         * @param str the string to parse
         * @return the measure, null if it could not be parsed
         */
        @Override
        public Measure parse(String str) {
            return valueOf((AbstractOption) null, str);
        }

        /**
         * Legacy alias of {@link #parse(String)}. The name was introduced by a decompiler.
         *
         * @param str the string to parse
         * @return the measure, null if it could not be parsed
         * @deprecated use {@link #parse(String)}
         */
        @Deprecated
        public Measure m70parse(String str) {
            return parse(str);
        }

        /**
         * Returns the raw string of the measure, for storing the option value.
         *
         * @param abstractOption the option (unused)
         * @param obj the measure to convert
         * @return the raw name
         */
        public static String toString(AbstractOption abstractOption, Object obj) {
            return ((Measure) obj).toRaw();
        }

        /**
         * Parses a string into a measure. The raw enum name is tried first, then
         * the display strings.
         *
         * @param abstractOption the option (unused)
         * @param str the string to parse; may be null
         * @return the measure, null if no match
         */
        public static Measure valueOf(AbstractOption abstractOption, String str) {
            try {
                return valueOf(str);
            } catch (Exception ignored) {
                // not a raw enum name (or null); fall back to the display strings below
            }
            for (Measure measure : values()) {
                if (measure.toDisplay().equals(str)) {
                    return measure;
                }
            }
            return null;
        }
    }

    /**
     * Job that evaluates a single classifier, either via cross-validation
     * (folds &gt;= 2) or train/test evaluation (folds &lt; 2), and records its
     * {@link Performance}.
     */
    public static class RankingJob extends AbstractJob {
        private static final long serialVersionUID = 6105881068149718863L;

        /** the classifier to evaluate. */
        protected Classifier m_Classifier;

        /** the position of the classifier in the input array. */
        protected int m_Index;

        /** the training data. */
        protected Instances m_Train;

        /** the test data (only used if folds &lt; 2). */
        protected Instances m_Test;

        /** the seed for the cross-validation. */
        protected long m_Seed;

        /** the number of folds. */
        protected int m_Folds;

        /** the measure to record. */
        protected Measure m_Measure;

        /** the class label index; NOTE: the same instance is shared by all jobs of one run. */
        protected WekaLabelIndex m_ClassLabel;

        /** the performance determined by the evaluation. */
        protected Performance m_Performance = null;

        /** additional error information (currently never set in this class). */
        protected String m_EvaluationError = "";

        /** whether to output the best setup of GridSearch/MultiSearch. */
        protected boolean m_OutputBestSetup;

        /** the classifier setup to forward if this job gets ranked. */
        protected Classifier m_BestClassifier;

        /**
         * Initializes the job.
         *
         * @param classifier the classifier to evaluate
         * @param index the position of the classifier in the input array
         * @param train the training data
         * @param test the test data (only used if folds &lt; 2)
         * @param seed the seed for the cross-validation
         * @param folds the number of folds
         * @param measure the measure to record
         * @param classLabel the class label index
         * @param outputBestSetup whether to output the best setup of optimizers
         */
        public RankingJob(Classifier classifier, int index, Instances train, Instances test, long seed, int folds, Measure measure, WekaLabelIndex classLabel, boolean outputBestSetup) {
            this.m_Classifier = classifier;
            this.m_Index = index;
            this.m_Train = train;
            this.m_Test = test;
            this.m_Seed = seed;
            this.m_Folds = folds;
            this.m_Measure = measure;
            this.m_ClassLabel = classLabel;
            this.m_BestClassifier = (Classifier) ClassManager.getSingleton().deepCopy(classifier);
            this.m_OutputBestSetup = outputBestSetup;
        }

        public Classifier getClassifier() {
            return this.m_Classifier;
        }

        public int getIndex() {
            return this.m_Index;
        }

        public Instances getTrain() {
            return this.m_Train;
        }

        public Instances getTest() {
            return this.m_Test;
        }

        public long getSeed() {
            return this.m_Seed;
        }

        public int getFolds() {
            return this.m_Folds;
        }

        public Measure getMeasure() {
            return this.m_Measure;
        }

        public Performance getPerformance() {
            return this.m_Performance;
        }

        public Classifier getBestClassifier() {
            return this.m_BestClassifier;
        }

        public boolean getOutputBestSetup() {
            return this.m_OutputBestSetup;
        }

        /**
         * Checks that the classifier and the required data are present.
         *
         * @return null if OK, otherwise the error message
         */
        @Override
        protected String preProcessCheck() {
            if (this.m_Classifier == null) {
                return "No classifier set!";
            }
            if (this.m_Train == null) {
                return "No training data set!";
            }
            if ((this.m_Folds < 2) && (this.m_Test == null)) {
                return "No test data set!";
            }
            return null;
        }

        /**
         * Determines the classifier setup to forward. If requested (and train/test
         * evaluation is used), the best setup found by a trained GridSearch or
         * MultiSearch is returned instead of the optimizer itself.
         *
         * @param original the classifier setup as received on the input
         * @param trained the trained copy of the classifier
         * @return the setup to forward; falls back to {@code original} on failure
         */
        protected Classifier getBestClassifier(Classifier original, Classifier trained) {
            if (!this.m_OutputBestSetup || (this.m_Folds >= 2)) {
                return original;
            }
            Classifier best;
            try {
                if (trained instanceof GridSearch) {
                    best = (Classifier) OptionUtils.shallowCopy(((GridSearch) trained).getBestClassifier());
                } else if (trained instanceof MultiSearch) {
                    best = (Classifier) OptionUtils.shallowCopy(((MultiSearch) trained).getBestClassifier());
                } else {
                    return original;
                }
            } catch (Exception e) {
                getLogger().log(Level.SEVERE, "Failed to copy best '" + trained.getClass().getName() + "' classifier:", e);
                return original;
            }
            // never forward a null setup if the copy failed
            return (best == null) ? original : best;
        }

        /**
         * Evaluates the classifier and records its performance.
         *
         * @throws Exception if the evaluation fails
         */
        @Override
        protected void process() throws Exception {
            Evaluation evaluation = new Evaluation(this.m_Train);
            evaluation.setDiscardPredictions(true);
            if (this.m_Folds >= 2) {
                evaluation.crossValidateModel(this.m_Classifier, this.m_Train, this.m_Folds, new Random(this.m_Seed));
            } else {
                Classifier model = (Classifier) OptionUtils.shallowCopy(this.m_Classifier);
                if (model == null) {
                    throw new IllegalStateException("Failed to copy classifier: " + OptionUtils.getCommandLine(this.m_Classifier));
                }
                model.buildClassifier(this.m_Train);
                evaluation.evaluateModel(model, this.m_Test);
                this.m_BestClassifier = getBestClassifier(this.m_Classifier, model);
            }

            // the same WekaLabelIndex instance is handed to every job of a run and the
            // jobs may execute concurrently, so setData/getIntIndex must not interleave
            int classIndex;
            synchronized (this.m_ClassLabel) {
                this.m_ClassLabel.setData(this.m_Train.classAttribute());
                classIndex = this.m_ClassLabel.getIntIndex();
            }

            this.m_Performance = new Performance(
                new Point(new Integer[]{this.m_Index}),
                new DefaultEvaluationWrapper(evaluation, new DefaultEvaluationMetrics()),
                this.m_Measure.getMeasure(),
                classIndex,
                this.m_Classifier);
        }

        /**
         * Checks that a performance was established.
         *
         * @return null if OK, otherwise the error message
         */
        @Override
        protected String postProcessCheck() {
            return (this.m_Performance == null) ? "No performance established!" : null;
        }

        /**
         * Releases the data and classifier references (the best setup is kept for output).
         */
        @Override
        public void cleanUp() {
            super.cleanUp();
            this.m_Classifier = null;
            this.m_Performance = null;
            this.m_Train = null;
            this.m_Test = null;
        }

        @Override
        protected String getAdditionalErrorInformation() {
            return this.m_EvaluationError;
        }

        /**
         * Describes the job. The description is still safe after {@link #cleanUp()}.
         *
         * @return the description
         */
        @Override
        public String toString() {
            String data = (this.m_Train == null) ? "null" : this.m_Train.relationName();
            String classifier = (this.m_Classifier == null) ? "null" : OptionUtils.getCommandLine(this.m_Classifier);
            return "data:" + data + ", "
                + "classifier: " + classifier + ", "
                + "seed: " + this.m_Seed + ", "
                + "folds: " + this.m_Folds + ", "
                + "measure: " + this.m_Measure;
        }
    }

    @Override
    public String globalInfo() {
        return "Performs a quick evaluation using cross-validation on a single dataset (or evaluation on a separate test set if the number of folds is less than 2) to rank the classifiers received on the input and forwarding the x best ones. Further evaluation can be performed using the Experimenter.";
    }

    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("max", "max", 3, -1, (Number) null);
        this.m_OptionManager.add("seed", "seed", 1L);
        this.m_OptionManager.add("folds", "folds", 10, 1, (Number) null);
        this.m_OptionManager.add("measure", "measure", Measure.CC);
        this.m_OptionManager.add("class-label", "classLabel", new WekaLabelIndex("first"));
        this.m_OptionManager.add("train", "train", new CallableActorReference("train"));
        this.m_OptionManager.add("test", "test", new CallableActorReference("test"));
        this.m_OptionManager.add("output-best", "outputBestSetup", false);
        // NOTE(review): flag constant borrowed from RemoveOutliers (presumably "num-threads");
        // kept as-is so the stored option flag cannot change
        this.m_OptionManager.add(RemoveOutliers.NUM_THREADS, "numThreads", 0);
    }

    @Override
    protected void initialize() {
        super.initialize();
        this.m_Helper = new CallableActorHelper();
    }

    /**
     * Appends {@code text} to {@code sb} unless it is null.
     */
    private static void appendIfPresent(StringBuilder sb, String text) {
        if (text != null) {
            sb.append(text);
        }
    }

    @Override
    public String getQuickInfo() {
        StringBuilder result = new StringBuilder();

        String maxVariable = QuickInfoHelper.getVariable(this, "max");
        if (maxVariable != null) {
            result.append(maxVariable).append(" best");
        } else if (this.m_Max < 1) {
            result.append("all");
        } else {
            result.append(this.m_Max).append(" best");
        }

        String folds = QuickInfoHelper.toString(this, "folds", (this.m_Folds >= 2) ? Integer.valueOf(this.m_Folds) : null, ", ");
        if (folds != null) {
            result.append(folds).append(" folds");
        }
        appendIfPresent(result, QuickInfoHelper.toString(this, "train", this.m_Train, ", training data: "));
        if (QuickInfoHelper.hasVariable(this, "folds") || (this.m_Folds < 2)) {
            appendIfPresent(result, QuickInfoHelper.toString(this, "test", this.m_Test, ", test data: "));
        }
        appendIfPresent(result, QuickInfoHelper.toString(this, "numThreads", adams.core.Performance.getNumThreadsQuickInfo(this.m_NumThreads), ", "));

        return result.toString();
    }

    /**
     * Sets the maximum number of classifiers to forward.
     *
     * @param i the maximum (&gt;0), or -1 for all
     */
    public void setMax(int i) {
        if (i <= 0 && i != -1) {
            getLogger().severe("Maximum number must be >0 or -1 for 'all', provided: " + i);
        } else {
            this.m_Max = i;
            reset();
        }
    }

    public int getMax() {
        return this.m_Max;
    }

    public String maxTipText() {
        return "The maximum number of top-ranked classifiers to forward; use -1 to forward all of them (ranked array).";
    }

    public void setSeed(long j) {
        this.m_Seed = j;
        reset();
    }

    public long getSeed() {
        return this.m_Seed;
    }

    public String seedTipText() {
        return "The seed value to use in the cross-validation.";
    }

    /**
     * Sets the number of folds; values below 2 switch to train/test evaluation.
     *
     * @param i the number of folds (&gt;=1)
     */
    public void setFolds(int i) {
        if (i < 1) {
            getLogger().severe("Number of folds must be >=1, provided: " + i);
        } else {
            this.m_Folds = i;
            reset();
        }
    }

    public int getFolds() {
        return this.m_Folds;
    }

    public String foldsTipText() {
        return "The number of folds to use in cross-validation.";
    }

    public void setMeasure(Measure measure) {
        this.m_Measure = measure;
        reset();
    }

    public Measure getMeasure() {
        return this.m_Measure;
    }

    public String measureTipText() {
        return "The measure used for ranking the classifiers.";
    }

    public void setClassLabel(WekaLabelIndex wekaLabelIndex) {
        this.m_ClassLabel = wekaLabelIndex;
        reset();
    }

    public WekaLabelIndex getClassLabel() {
        return this.m_ClassLabel;
    }

    public String classLabelTipText() {
        return "The class label index to use in case of class-specific measures.";
    }

    public void setTrain(CallableActorReference callableActorReference) {
        this.m_Train = callableActorReference;
        reset();
    }

    public CallableActorReference getTrain() {
        return this.m_Train;
    }

    public String trainTipText() {
        return "The name of the callable actor that is used for obtaining the training set.";
    }

    public void setTest(CallableActorReference callableActorReference) {
        this.m_Test = callableActorReference;
        reset();
    }

    public CallableActorReference getTest() {
        return this.m_Test;
    }

    public String testTipText() {
        return "The name of the callable actor that is used for obtaining the test set (only if folds <2).";
    }

    public void setOutputBestSetup(boolean z) {
        this.m_OutputBestSetup = z;
        reset();
    }

    public boolean getOutputBestSetup() {
        return this.m_OutputBestSetup;
    }

    public String outputBestSetupTipText() {
        return "If true, then for optimizers like GridSearch and MultiSearch the best setup that was found will be output instead of the optimizer setup.";
    }

    /**
     * Sets the number of threads.
     *
     * @param i the number of threads (&gt;= -1); invalid values are logged and ignored
     */
    public void setNumThreads(int i) {
        if (i < -1) {
            getLogger().severe("Number of threads must be >= -1, provided: " + i);
        } else {
            this.m_NumThreads = i;
            reset();
        }
    }

    public int getNumThreads() {
        return this.m_NumThreads;
    }

    public String numThreadsTipText() {
        return "The number of threads to use for evaluating the classifiers in parallel (-1 means one for each core/cpu).";
    }

    /**
     * Checks that the referenced callable actor exists, is a source, and
     * generates {@link Instances}.
     *
     * @param ref the callable actor reference
     * @param desc the description used in error messages (e.g. "training set")
     * @return null if OK, otherwise the error message
     */
    private String checkDataSource(CallableActorReference ref, String desc) {
        OutputProducer actor = this.m_Helper.findCallableActorRecursive(this, ref);
        if (actor == null) {
            return "Callable actor '" + ref + "' providing the " + desc + " not found!";
        }
        if (!ActorUtils.isSource(actor)) {
            return "Callable actor '" + ref + "' (" + desc + ") is not a source!";
        }
        if (!new Compatibility().isCompatible(actor.generates(), new Class[]{Instances.class})) {
            return "Callable actor '" + ref + "' (" + desc + ") does not generate " + Instances.class.getName() + "!";
        }
        return null;
    }

    /**
     * Validates the data sources. Registers for pause-state events and locates
     * the closest job runner setup.
     *
     * @return null if OK, otherwise the error message
     */
    @Override
    public String setUp() {
        String result = super.setUp();
        if (result == null) {
            result = checkDataSource(this.m_Train, "training set");
        }
        if ((result == null) && (this.m_Folds < 2)) {
            result = checkDataSource(this.m_Test, "test set");
        }
        if ((result == null) && (getRoot() instanceof PauseStateHandler)) {
            PauseStateHandler handler = (PauseStateHandler) getRoot();
            if (handler.getPauseStateManager() != null) {
                handler.getPauseStateManager().addListener(this);
            }
        }
        if (result == null) {
            this.m_JobRunnerSetup = ActorUtils.findClosestType(this, JobRunnerSetup.class);
        }
        return result;
    }

    @Override
    public Class[] accepts() {
        return new Class[]{Classifier[].class};
    }

    @Override
    public Class[] generates() {
        return new Class[]{Classifier[].class};
    }

    /**
     * Evaluates all incoming classifiers and forwards the top-ranked ones.
     *
     * @return null if everything worked, otherwise the error message
     */
    @Override
    protected String doExecute() {
        String result = null;
        JobList<RankingJob> jobs = null;
        JobRunner runner = null;

        try {
            Classifier[] classifiers = (Classifier[]) this.m_InputToken.getPayload();

            // training data
            Instances train = null;
            OutputProducer trainActor = this.m_Helper.findCallableActorRecursive(this, this.m_Train);
            result = trainActor.execute();
            if (result == null) {
                Token token = trainActor.output();
                train = (token == null) ? null : (Instances) token.getPayload();
                if (train == null) {
                    result = "Failed to obtain training data from '" + this.m_Train + "'!";
                }
            }

            // test data (only for train/test evaluation); skipped after a training-data
            // failure so that error is neither overwritten nor followed by pointless jobs
            Instances test = null;
            if ((result == null) && (this.m_Folds < 2)) {
                OutputProducer testActor = this.m_Helper.findCallableActorRecursive(this, this.m_Test);
                result = testActor.execute();
                if (result == null) {
                    Token token = testActor.output();
                    test = (token == null) ? null : (Instances) token.getPayload();
                    if (test == null) {
                        result = "Failed to obtain test data from '" + this.m_Test + "'!";
                    }
                }
            }
            if (result != null) {
                return result;
            }

            jobs = new JobList<>();
            for (int i = 0; i < classifiers.length; i++) {
                jobs.add(new RankingJob(classifiers[i], i, train, test, this.m_Seed, this.m_Folds, this.m_Measure, this.m_ClassLabel, this.m_OutputBestSetup));
            }

            if (this.m_JobRunnerSetup == null) {
                runner = new LocalJobRunner();
            } else {
                runner = this.m_JobRunnerSetup.newInstance();
            }
            this.m_JobRunner = runner;
            runner.setFlowContext(this);
            if (runner instanceof ThreadLimiter) {
                ((ThreadLimiter) runner).setNumThreads(this.m_NumThreads);
            }
            runner.addJobCompleteListener(new JobCompleteListener() {
                private static final long serialVersionUID = 4773790554588513879L;

                @Override
                public void jobCompleted(JobCompleteEvent jobCompleteEvent) {
                    // progress indicator: one dot per completed job
                    if (WekaClassifierRanker.this.isLoggingEnabled()) {
                        System.out.print(".");
                    }
                }
            });
            runner.add(jobs);
            runner.start();
            runner.stop();

            if (!isStopped()) {
                result = collectResults(runner, classifiers);
            }
        } catch (Exception e) {
            this.m_OutputToken = null;
            result = handleException("Failed to rank: ", e);
        } finally {
            if (jobs != null) {
                for (RankingJob job : jobs) {
                    job.cleanUp();
                }
            }
            if (runner != null) {
                runner.cleanUp();
            }
        }
        return result;
    }

    /**
     * Gathers the performances of the finished jobs, ranks them and sets the
     * output token.
     *
     * @param runner the runner that executed the jobs
     * @param classifiers the classifiers received on the input
     * @return null if all jobs produced an evaluation, otherwise the combined error message
     */
    private String collectResults(JobRunner runner, Classifier[] classifiers) {
        String result = null;
        boolean fine = LoggingHelper.isAtLeast(getLogger(), Level.FINE);

        if (fine) {
            getLogger().fine("\nEvaluations:");
        }

        // use the jobs as returned by the runner (for performance AND best setup),
        // indexed by their position in the input array
        RankingJob[] finished = new RankingJob[classifiers.length];
        ArrayList<Performance> performances = new ArrayList<>();
        for (int i = 0; i < runner.getJobs().size(); i++) {
            RankingJob job = (RankingJob) runner.getJobs().get(i);
            int index = job.getIndex();
            if (job.getPerformance() != null) {
                finished[index] = job;
                performances.add(job.getPerformance());
                if (fine) {
                    getLogger().fine((index + 1) + ". " + this.m_Measure.toRaw() + "=" + job.getPerformance().getPerformance() + ": " + OptionUtils.getCommandLine(classifiers[index]));
                }
            } else {
                String msg = (index + 1) + ". no evaluation: " + OptionUtils.getCommandLine(classifiers[index]);
                getLogger().severe(msg);
                result = (result == null ? "" : result + "\n\n") + msg;
                if (job.hasExecutionError()) {
                    getLogger().severe(job.getExecutionError());
                    result = result + "\n" + job.getExecutionError();
                }
            }
        }

        Collections.sort(performances, new PerformanceComparator(this.m_Measure.getMeasure(), new DefaultEvaluationMetrics()));
        if (fine) {
            getLogger().fine("\nChosen classifiers (ranked):");
        }

        // consume the sorted list from its end, as the best performance is expected last;
        // m_Max < 1 (i.e. -1) means "forward all"
        ArrayList<Classifier> ranked = new ArrayList<>();
        for (int i = performances.size() - 1; i >= 0; i--) {
            if ((this.m_Max > 0) && (ranked.size() >= this.m_Max)) {
                break;
            }
            Performance performance = performances.get(i);
            int index = (Integer) performance.getValues().getValue(0);
            Classifier chosen = finished[index].getBestClassifier();
            ranked.add(chosen);
            if (fine) {
                getLogger().fine(ranked.size() + ". " + OptionUtils.getCommandLine(chosen) + "/" + this.m_Measure.toRaw() + ": " + performance.getPerformance());
            }
        }
        this.m_OutputToken = new Token(ranked.toArray(new Classifier[0]));

        return result;
    }

    @Override
    public void flowPauseStateChanged(FlowPauseStateEvent flowPauseStateEvent) {
        if (flowPauseStateEvent.getType() == FlowPauseStateEvent.Type.PAUSED) {
            pauseExecution();
        } else {
            resumeExecution();
        }
    }

    @Override
    public void pauseExecution() {
        if (this.m_JobRunner != null) {
            this.m_JobRunner.pauseExecution();
        }
    }

    @Override
    public boolean isPaused() {
        return (this.m_JobRunner != null) && this.m_JobRunner.isPaused();
    }

    @Override
    public void resumeExecution() {
        if (this.m_JobRunner != null) {
            this.m_JobRunner.resumeExecution();
        }
    }

    @Override
    public void stopExecution() {
        super.stopExecution();
        if (this.m_JobRunner != null) {
            this.m_JobRunner.terminate();
        }
    }

    @Override
    public void cleanUp() {
        super.cleanUp();
        this.m_JobRunner = null;
        this.m_Helper = null;
    }
}
