/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.EnumWithCustomDisplay;
import adams.core.Pausable;
import adams.core.QuickInfoHelper;
import adams.core.Randomizable;
import adams.core.Utils;
import adams.core.logging.LoggingHelper;
import adams.core.option.AbstractOption;
import adams.core.option.OptionHandler;
import adams.core.option.OptionUtils;
import adams.event.FlowPauseStateEvent;
import adams.event.FlowPauseStateListener;
import adams.event.JobCompleteEvent;
import adams.event.JobCompleteListener;
import adams.flow.core.AbstractActor;
import adams.flow.core.Actor;
import adams.flow.core.ActorUtils;
import adams.flow.core.CallableActorHelper;
import adams.flow.core.CallableActorReference;
import adams.flow.core.Compatibility;
import adams.flow.core.OutputProducer;
import adams.flow.core.PauseStateHandler;
import adams.flow.core.Token;
import adams.flow.transformer.AbstractTransformer;
import adams.multiprocess.Job;
import adams.multiprocess.JobList;
import adams.multiprocess.JobRunner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.meta.GridSearch;
import weka.classifiers.meta.MultiSearch;
import weka.classifiers.meta.multisearch.Performance;
import weka.classifiers.meta.multisearch.PerformanceComparator;
import weka.core.Instances;
import weka.core.setupgenerator.Point;

/**
 * Ranks the classifiers arriving on the input port using a quick evaluation on a single
 * dataset and forwards the top-ranked ones as a {@code Classifier[]}.
 * <p>
 * With {@code folds >= 2}, every classifier is cross-validated on the training set that is
 * produced by the callable source referenced via {@code train}. Otherwise a copy of the
 * classifier is built on that training set and evaluated on the test set produced by the
 * callable source referenced via {@code test}. The evaluations are executed as
 * {@link RankingJob}s by a {@link JobRunner}, which gets paused/resumed together with the flow.
 */
public class WekaClassifierRanker
extends AbstractTransformer
implements Randomizable,
Pausable,
FlowPauseStateListener {
    private static final long serialVersionUID = -3019442578354930841L;

    /** the callable source providing the training set. */
    protected CallableActorReference m_Train;

    /** the callable source providing the test set (only used if folds &lt; 2). */
    protected CallableActorReference m_Test;

    /** the maximum number of classifiers to forward (-1 for all). */
    protected int m_Max;

    /** the seed for the cross-validation. */
    protected long m_Seed;

    /** the number of cross-validation folds (&lt; 2 means train/test evaluation). */
    protected int m_Folds;

    /** the measure used for ranking. */
    protected Measure m_Measure;

    /** whether to output the best setup found by GridSearch/MultiSearch (folds &lt; 2 only). */
    protected boolean m_OutputBestSetup;

    /** the number of threads to use (-1 for one per core/cpu). */
    protected int m_NumThreads;

    /** for locating the callable actors. */
    protected CallableActorHelper m_Helper;

    /** the job runner executing the evaluations of the current token. */
    protected JobRunner<RankingJob> m_JobRunner;

    /**
     * Returns a string describing the actor.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Performs a quick evaluation using cross-validation on a single dataset (or evaluation on a separate test set if the number of folds is less than 2) to rank the classifiers received on the input and forwarding the x best ones. Further evaluation can be performed using the Experimenter.";
    }

    /**
     * Adds the options of this actor to the option manager.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("max", "max", (Object)3, (Number)-1, null);
        this.m_OptionManager.add("seed", "seed", (Object)1L);
        this.m_OptionManager.add("folds", "folds", (Object)10, (Number)1, null);
        this.m_OptionManager.add("measure", "measure", (Object)Measure.CC);
        this.m_OptionManager.add("train", "train", (Object)new CallableActorReference("train"));
        this.m_OptionManager.add("test", "test", (Object)new CallableActorReference("test"));
        this.m_OptionManager.add("output-best", "outputBestSetup", (Object)false);
        this.m_OptionManager.add("num-threads", "numThreads", (Object)-1, (Number)-1, null);
    }

    /**
     * Initializes the members; creates the helper for locating callable actors.
     */
    protected void initialize() {
        super.initialize();
        this.m_Helper = new CallableActorHelper();
    }

    /**
     * Returns a short summary of the current setup, for display in the flow editor.
     *
     * @return the summary
     */
    public String getQuickInfo() {
        String variable = QuickInfoHelper.getVariable((OptionHandler)this, (String)"max");
        String result = variable != null ? variable + " best" : (this.m_Max < 1 ? "all" : this.m_Max + " best");
        String value = QuickInfoHelper.toString((OptionHandler)this, (String)"folds", (Object)(this.m_Folds >= 2 ? Integer.valueOf(this.m_Folds) : null), (String)", ");
        if (value != null) {
            result = result + value + " folds";
        }
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"train", (Object)this.m_Train, (String)", training data: ");
        // the test set is only relevant for train/test evaluation (or when folds is variable-driven)
        if (QuickInfoHelper.hasVariable((OptionHandler)this, (String)"folds") || this.m_Folds < 2) {
            result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"test", (Object)this.m_Test, (String)", test data: ");
        }
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"numThreads", (Object)(this.m_NumThreads < 1 ? "#cores" : Integer.valueOf(this.m_NumThreads)), (String)", threads: ");
        return result;
    }

    /**
     * Sets the maximum number of top-ranked classifiers to forward.
     *
     * @param value the maximum (&gt; 0), or -1 for all; other values are rejected and logged
     */
    public void setMax(int value) {
        if (value > 0 || value == -1) {
            this.m_Max = value;
            this.reset();
        } else {
            this.getLogger().severe("Maximum number must be >0 or -1 for 'all', provided: " + value);
        }
    }

    /**
     * Returns the maximum number of top-ranked classifiers to forward.
     *
     * @return the maximum, -1 for all
     */
    public int getMax() {
        return this.m_Max;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String maxTipText() {
        return "The maximum number of top-ranked classifiers to forward; use -1 to forward all of them (ranked array).";
    }

    /**
     * Sets the seed for the cross-validation.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        this.m_Seed = value;
        this.reset();
    }

    /**
     * Returns the seed for the cross-validation.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String seedTipText() {
        return "The seed value to use in the cross-validation.";
    }

    /**
     * Sets the number of folds.
     *
     * @param value the folds (&gt;= 1); values &lt; 2 select train/test evaluation
     */
    public void setFolds(int value) {
        if (value >= 1) {
            this.m_Folds = value;
            this.reset();
        } else {
            this.getLogger().severe("Number of folds must be >=1, provided: " + value);
        }
    }

    /**
     * Returns the number of folds.
     *
     * @return the folds
     */
    public int getFolds() {
        return this.m_Folds;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String foldsTipText() {
        return "The number of folds to use in cross-validation; if less than 2, the classifiers get evaluated on the separate test set instead.";
    }

    /**
     * Sets the measure used for ranking.
     *
     * @param value the measure
     */
    public void setMeasure(Measure value) {
        this.m_Measure = value;
        this.reset();
    }

    /**
     * Returns the measure used for ranking.
     *
     * @return the measure
     */
    public Measure getMeasure() {
        return this.m_Measure;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String measureTipText() {
        return "The measure used for ranking the classifiers.";
    }

    /**
     * Sets the callable source providing the training set.
     *
     * @param value the reference
     */
    public void setTrain(CallableActorReference value) {
        this.m_Train = value;
        this.reset();
    }

    /**
     * Returns the callable source providing the training set.
     *
     * @return the reference
     */
    public CallableActorReference getTrain() {
        return this.m_Train;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String trainTipText() {
        return "The name of the callable actor that is used for obtaining the training set.";
    }

    /**
     * Sets the callable source providing the test set (only used if folds &lt; 2).
     *
     * @param value the reference
     */
    public void setTest(CallableActorReference value) {
        this.m_Test = value;
        this.reset();
    }

    /**
     * Returns the callable source providing the test set.
     *
     * @return the reference
     */
    public CallableActorReference getTest() {
        return this.m_Test;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String testTipText() {
        return "The name of the callable actor that is used for obtaining the test set (only if folds <2).";
    }

    /**
     * Sets whether to output the best setup found by GridSearch/MultiSearch rather than
     * the optimizer itself. Only has an effect with train/test evaluation (folds &lt; 2).
     *
     * @param value true if to output the best setup
     */
    public void setOutputBestSetup(boolean value) {
        this.m_OutputBestSetup = value;
        this.reset();
    }

    /**
     * Returns whether to output the best setup found by GridSearch/MultiSearch.
     *
     * @return true if to output the best setup
     */
    public boolean getOutputBestSetup() {
        return this.m_OutputBestSetup;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String outputBestSetupTipText() {
        return "If true, then for optimizers like GridSearch and MultiSearch the best setup that was found will be output instead of the optimizer setup (only applies to evaluations on a separate test set, i.e., folds < 2).";
    }

    /**
     * Sets the number of threads for evaluating the classifiers in parallel.
     *
     * @param value the number of threads (&gt;= 1), or -1 for one per core/cpu;
     *              other values are rejected and logged
     */
    public void setNumThreads(int value) {
        if (value >= -1) {
            this.m_NumThreads = value;
            this.reset();
        } else {
            // consistent with setMax/setFolds: report instead of silently ignoring
            this.getLogger().severe("Number of threads must be >=-1, provided: " + value);
        }
    }

    /**
     * Returns the number of threads for evaluating the classifiers in parallel.
     *
     * @return the number of threads, -1 for one per core/cpu
     */
    public int getNumThreads() {
        return this.m_NumThreads;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String numThreadsTipText() {
        return "The number of threads to use for evaluating the classifiers in parallel (-1 means one for each core/cpu).";
    }

    /**
     * Initializes the actor. It checks that the training source (and, for folds &lt; 2, the
     * test source) exists, is a source, and generates {@link Instances}. If the flow's root
     * handles pause states, it also registers this actor as a pause listener.
     *
     * @return null if everything is fine, otherwise an error message
     */
    public String setUp() {
        String result = super.setUp();
        if (result == null) {
            result = this.checkDataSource(this.m_Train, "training");
        }
        if (result == null && this.m_Folds < 2) {
            result = this.checkDataSource(this.m_Test, "test");
        }
        if (result == null && this.getRoot() instanceof PauseStateHandler) {
            ((PauseStateHandler)this.getRoot()).getPauseStateManager().addListener((FlowPauseStateListener)this);
        }
        return result;
    }

    /**
     * Checks that the referenced callable actor exists, is a source and produces datasets.
     *
     * @param ref  the callable actor reference
     * @param what "training" or "test", used in the error messages
     * @return null if valid, otherwise an error message
     */
    private String checkDataSource(CallableActorReference ref, String what) {
        AbstractActor actor = this.m_Helper.findCallableActorRecursive((AbstractActor)this, ref);
        if (actor == null) {
            return "Callable actor '" + ref + "' providing the " + what + " set not found!";
        }
        if (!ActorUtils.isSource((Actor)actor)) {
            return "Callable actor '" + ref + "' (" + what + " set) is not a source!";
        }
        if (!new Compatibility().isCompatible(((OutputProducer)actor).generates(), new Class[]{Instances.class})) {
            return "Callable actor '" + ref + "' (" + what + " set) does not generate " + Instances.class.getName() + "!";
        }
        return null;
    }

    /**
     * Returns the accepted input types.
     *
     * @return array of classifiers
     */
    public Class[] accepts() {
        return new Class[]{Classifier[].class};
    }

    /**
     * Returns the generated output types.
     *
     * @return array of classifiers (ranked, best first)
     */
    public Class[] generates() {
        return new Class[]{Classifier[].class};
    }

    /**
     * Executes the referenced callable source and stores the dataset it outputs in
     * {@code data[0]}.
     *
     * @param ref  the callable source to execute
     * @param what "training" or "test", used in the error messages
     * @param data single-element holder that receives the dataset on success
     * @return null if successful, otherwise an error message
     */
    private String fetchData(CallableActorReference ref, String what, Instances[] data) {
        AbstractActor source = this.m_Helper.findCallableActorRecursive((AbstractActor)this, ref);
        if (source == null) {
            return "Callable actor '" + ref + "' providing the " + what + " set not found!";
        }
        String result = source.execute();
        if (result != null) {
            return result;
        }
        Token token = ((OutputProducer)source).output();
        if (token == null || token.getPayload() == null) {
            return "Failed to obtain " + what + " data from '" + ref + "'!";
        }
        data[0] = (Instances)token.getPayload();
        return null;
    }

    /**
     * Evaluates all incoming classifiers in parallel and forwards the (up to {@code max})
     * best ones, best first. If some classifiers could not be evaluated, their messages are
     * returned as the error, but the ranking of the remaining ones is still output.
     *
     * @return null if everything is fine, otherwise an error message
     */
    protected String doExecute() {
        String result = null;
        JobList jobs = null;
        try {
            Classifier[] cls = (Classifier[])this.m_InputToken.getPayload();
            boolean debug = LoggingHelper.isAtLeast(this.getLogger(), Level.FINE);

            // obtain the data; abort straight away if that fails, since no evaluation
            // could succeed without it (and a test-set error must not mask a training-set one)
            Instances[] data = new Instances[1];
            result = this.fetchData(this.m_Train, "training", data);
            if (result != null) {
                return result;
            }
            Instances train = data[0];
            Instances test = null;
            if (this.m_Folds < 2) {
                data[0] = null;
                result = this.fetchData(this.m_Test, "test", data);
                if (result != null) {
                    return result;
                }
                test = data[0];
            }

            // one job per classifier; the job index links each result back to cls[]
            jobs = new JobList();
            for (int i = 0; i < cls.length; ++i) {
                jobs.add((Job)new RankingJob(cls[i], i, train, test, this.m_Seed, this.m_Folds, this.m_Measure, this.m_OutputBestSetup));
            }
            this.m_JobRunner = new JobRunner<RankingJob>(this.m_NumThreads);
            this.m_JobRunner.addJobCompleteListener(new JobCompleteListener(){
                private static final long serialVersionUID = 4773790554588513879L;

                public void jobCompleted(JobCompleteEvent e) {
                    if (WekaClassifierRanker.this.isLoggingEnabled()) {
                        System.out.print(".");
                    }
                }
            });
            this.m_JobRunner.add(jobs);
            this.m_JobRunner.start();
            this.m_JobRunner.stop();
            if (this.isStopped()) {
                return result;
            }

            // collect the performances, reporting classifiers that could not be evaluated
            if (debug) {
                this.getLogger().fine("\nEvaluations:");
            }
            ArrayList<Performance> ranking = new ArrayList<Performance>();
            for (int i = 0; i < jobs.size(); ++i) {
                RankingJob job = (RankingJob)jobs.get(i);
                Performance perf = job.getPerformance();
                if (perf != null) {
                    ranking.add(perf);
                    if (debug) {
                        this.getLogger().fine((i + 1) + ". " + this.m_Measure.toRaw() + "=" + perf.getPerformance() + ": " + OptionUtils.getCommandLine((Object)cls[i]));
                    }
                    continue;
                }
                String msg = (i + 1) + ". no evaluation: " + OptionUtils.getCommandLine((Object)cls[i]);
                this.getLogger().severe(msg);
                result = (result == null) ? msg : result + "\n\n" + msg;
                if (job.hasExecutionError()) {
                    this.getLogger().severe(job.getExecutionError());
                    result = result + "\n" + job.getExecutionError();
                }
            }

            // the best performances are taken from the end of the sorted list
            Collections.sort(ranking, new PerformanceComparator(this.m_Measure.getMeasure()));
            if (debug) {
                this.getLogger().fine("\nChosen classifiers (ranked):");
            }
            ArrayList<Classifier> ranked = new ArrayList<Classifier>();
            for (int i = ranking.size() - 1; i >= 0; --i) {
                // max < 1 (i.e., -1) means: forward all of them
                if (this.m_Max > 0 && ranked.size() >= this.m_Max) {
                    break;
                }
                Performance perf = ranking.get(i);
                int index = (Integer)perf.getValues().getValue(0);
                ranked.add(((RankingJob)jobs.get(index)).getBestClassifier());
                if (debug) {
                    this.getLogger().fine(ranked.size() + ". " + OptionUtils.getCommandLine((Object)ranked.get(ranked.size() - 1)) + "/" + this.m_Measure.toRaw() + ": " + perf.getPerformance());
                }
            }
            this.m_OutputToken = new Token((Object)ranked.toArray(new Classifier[ranked.size()]));
        }
        catch (Exception e) {
            this.m_OutputToken = null;
            result = this.handleException("Failed to rank: ", e);
        }
        finally {
            // release the datasets/classifiers held by the jobs, also when an error occurred
            if (jobs != null) {
                for (int i = 0; i < jobs.size(); ++i) {
                    ((RankingJob)jobs.get(i)).cleanUp();
                }
            }
        }
        return result;
    }

    /**
     * Pauses or resumes the evaluations along with the flow.
     *
     * @param e the pause state event
     */
    public void flowPauseStateChanged(FlowPauseStateEvent e) {
        if (e.getType() == FlowPauseStateEvent.Type.PAUSED) {
            this.pauseExecution();
        } else {
            this.resumeExecution();
        }
    }

    /**
     * Pauses the job runner, if one is present.
     */
    public void pauseExecution() {
        if (this.m_JobRunner != null) {
            this.m_JobRunner.pauseExecution();
        }
    }

    /**
     * Returns whether the job runner is currently paused.
     *
     * @return true if a job runner is present and paused
     */
    public boolean isPaused() {
        if (this.m_JobRunner != null) {
            return this.m_JobRunner.isPaused();
        }
        return false;
    }

    /**
     * Resumes the job runner, if one is present.
     */
    public void resumeExecution() {
        if (this.m_JobRunner != null) {
            this.m_JobRunner.resumeExecution();
        }
    }

    /**
     * Stops the actor and terminates any running evaluations.
     */
    public void stopExecution() {
        super.stopExecution();
        if (this.m_JobRunner != null) {
            this.m_JobRunner.terminate();
        }
    }

    /**
     * Cleans up after the execution has finished; releases the job runner and helper.
     */
    public void cleanUp() {
        super.cleanUp();
        this.m_JobRunner = null;
        this.m_Helper = null;
    }

    /**
     * The measures available for ranking. The numeric code is what gets handed to
     * {@link Performance} and {@link PerformanceComparator}.
     */
    public static enum Measure implements EnumWithCustomDisplay<Measure>
    {
        CC("Correlation coefficient", 0),
        RMSE("Root mean squared error", 1),
        RRSE("Root relative squared error", 2),
        MAE("Mean absolute error", 3),
        RAE("Relative absolute error", 4),
        COMBINED("Combined: (1-abs(CC)) + RRSE + RAE", 5),
        ACC("Accuracy", 6),
        KAPPA("Kappa", 7);

        /** the display string. */
        private String m_Display;

        /** the raw (constant) name, used for serializing the option. */
        private String m_Raw;

        /** the numeric measure code. */
        private int m_Measure;

        /**
         * Initializes the constant.
         *
         * @param display the display string
         * @param measure the numeric measure code
         */
        private Measure(String display, int measure) {
            this.m_Display = display;
            this.m_Raw = this.name();
            this.m_Measure = measure;
        }

        /**
         * Returns the numeric measure code.
         *
         * @return the code
         */
        public int getMeasure() {
            return this.m_Measure;
        }

        /**
         * Returns the display string.
         *
         * @return the display string
         */
        public String toDisplay() {
            return this.m_Display;
        }

        /**
         * Returns the raw (constant) name.
         *
         * @return the raw name
         */
        public String toRaw() {
            return this.m_Raw;
        }

        /**
         * Returns the display string.
         *
         * @return the display string
         */
        public String toString() {
            return this.toDisplay();
        }

        /**
         * Parses the string (raw or display form).
         *
         * @param s the string to parse
         * @return the measure, null if not recognized
         */
        public Measure parse(String s) {
            return Measure.valueOf((AbstractOption)null, s);
        }

        /**
         * Turns the measure into its raw string form.
         *
         * @param option the option (ignored)
         * @param object the measure to convert
         * @return the raw name
         */
        public static String toString(AbstractOption option, Object object) {
            return ((Measure)object).toRaw();
        }

        /**
         * Parses a raw name or a display string.
         *
         * @param option the option (ignored)
         * @param str    the string to parse
         * @return the measure, null if not recognized
         */
        public static Measure valueOf(AbstractOption option, String str) {
            Measure result = null;
            try {
                result = Measure.valueOf(str);
            }
            catch (Exception ignored) {
                // not a raw name, try the display strings below
            }
            if (result == null) {
                for (Measure dt : Measure.values()) {
                    if (dt.toDisplay().equals(str)) {
                        result = dt;
                        break;
                    }
                }
            }
            return result;
        }
    }

    /**
     * Job that evaluates a single classifier and records its performance, keyed by the
     * classifier's index in the input array.
     */
    public static class RankingJob
    extends Job {
        private static final long serialVersionUID = 6105881068149718863L;

        /** the classifier (template) to evaluate. */
        protected Classifier m_Classifier;

        /** the index of the classifier in the input array. */
        protected int m_Index;

        /** the training set. */
        protected Instances m_Train;

        /** the test set (only used if folds &lt; 2). */
        protected Instances m_Test;

        /** the seed for the cross-validation. */
        protected long m_Seed;

        /** the number of folds (&lt; 2 means train/test evaluation). */
        protected int m_Folds;

        /** the measure to record. */
        protected Measure m_Measure;

        /** the performance, null until successfully evaluated. */
        protected Performance m_Performance;

        /** additional error information; never assigned beyond its initial empty value here. */
        protected String m_EvaluationError;

        /** whether to extract the best setup from GridSearch/MultiSearch. */
        protected boolean m_OutputBestSetup;

        /** the classifier to forward (deep copy of the template, or the best setup). */
        protected Classifier m_BestClassifier;

        /**
         * Initializes the job.
         *
         * @param cls     the classifier to evaluate
         * @param index   the index of the classifier in the input array
         * @param train   the training set
         * @param test    the test set (only used if folds &lt; 2)
         * @param seed    the seed for the cross-validation
         * @param folds   the number of folds (&lt; 2 for train/test evaluation)
         * @param measure the measure to record
         * @param best    whether to extract the best setup from GridSearch/MultiSearch
         */
        public RankingJob(Classifier cls, int index, Instances train, Instances test, long seed, int folds, Measure measure, boolean best) {
            this.m_Classifier = cls;
            this.m_Index = index;
            this.m_Train = train;
            this.m_Test = test;
            this.m_Seed = seed;
            this.m_Folds = folds;
            this.m_Measure = measure;
            this.m_Performance = null;
            this.m_EvaluationError = "";
            this.m_BestClassifier = (Classifier)Utils.deepCopy((Object)cls);
            this.m_OutputBestSetup = best;
        }

        /** @return the classifier (template) being evaluated */
        public Classifier getClassifier() {
            return this.m_Classifier;
        }

        /** @return the index of the classifier in the input array */
        public int getIndex() {
            return this.m_Index;
        }

        /** @return the training set */
        public Instances getTrain() {
            return this.m_Train;
        }

        /** @return the test set */
        public Instances getTest() {
            return this.m_Test;
        }

        /** @return the seed for the cross-validation */
        public long getSeed() {
            return this.m_Seed;
        }

        /** @return the number of folds */
        public int getFolds() {
            return this.m_Folds;
        }

        /** @return the measure to record */
        public Measure getMeasure() {
            return this.m_Measure;
        }

        /** @return the performance, null if not (successfully) evaluated */
        public Performance getPerformance() {
            return this.m_Performance;
        }

        /** @return the classifier to forward */
        public Classifier getBestClassifier() {
            return this.m_BestClassifier;
        }

        /** @return whether to extract the best setup from GridSearch/MultiSearch */
        public boolean getOutputBestSetup() {
            return this.m_OutputBestSetup;
        }

        /**
         * Checks that classifier and data are available before executing.
         *
         * @return null if OK, otherwise an error message
         */
        protected String preProcessCheck() {
            if (this.m_Classifier == null) {
                return "No classifier set!";
            }
            if (this.m_Train == null) {
                return "No training data set!";
            }
            if (this.m_Folds < 2 && this.m_Test == null) {
                return "No test data set!";
            }
            return null;
        }

        /**
         * Determines the classifier to forward. When outputting the best setup with
         * train/test evaluation, a trained GridSearch is turned into a FilteredClassifier
         * with the best classifier/filter, and a trained MultiSearch into a copy of its
         * best classifier. Otherwise the template is returned.
         *
         * @param template the classifier template
         * @param trained  the trained copy of the template
         * @return the classifier to forward
         */
        protected Classifier getBestClassifier(Classifier template, Classifier trained) {
            Classifier result = template;
            if (this.m_OutputBestSetup && this.m_Folds < 2) {
                if (trained instanceof GridSearch) {
                    result = new FilteredClassifier();
                    ((FilteredClassifier)result).setClassifier(((GridSearch)trained).getBestClassifier());
                    ((FilteredClassifier)result).setFilter(((GridSearch)trained).getBestFilter());
                } else if (trained instanceof MultiSearch) {
                    try {
                        result = AbstractClassifier.makeCopy((Classifier)((MultiSearch)trained).getBestClassifier());
                    }
                    catch (Exception e) {
                        this.getLogger().log(Level.SEVERE, "Failed to copy best MultiSearch classifier:", e);
                        result = template;
                    }
                }
            }
            return result;
        }

        /**
         * Evaluates the classifier: cross-validation on the training set if folds &gt;= 2,
         * otherwise a copy is built on the training set and evaluated on the test set.
         * Stores the resulting performance, keyed by the job index.
         *
         * @throws Exception if the evaluation fails
         */
        protected void process() throws Exception {
            Evaluation eval = new Evaluation(this.m_Train);
            // only the statistics are needed for ranking
            eval.setDiscardPredictions(true);
            if (this.m_Folds >= 2) {
                eval.crossValidateModel(this.m_Classifier, this.m_Train, this.m_Folds, new Random(this.m_Seed), new Object[0]);
            } else {
                Classifier cls = AbstractClassifier.makeCopy((Classifier)this.m_Classifier);
                cls.buildClassifier(this.m_Train);
                eval.evaluateModel(cls, this.m_Test, new Object[0]);
                this.m_BestClassifier = this.getBestClassifier(this.m_Classifier, cls);
            }
            this.m_Performance = new Performance(new Point((Object[])new Integer[]{this.m_Index}), eval, this.m_Measure.getMeasure());
        }

        /**
         * Checks that a performance was established.
         *
         * @return null if OK, otherwise an error message
         */
        protected String postProcessCheck() {
            if (this.m_Performance == null) {
                return "No performance established!";
            }
            return null;
        }

        /**
         * Releases classifier, performance and data (the best classifier is kept).
         */
        public void cleanUp() {
            super.cleanUp();
            this.m_Classifier = null;
            this.m_Performance = null;
            this.m_Train = null;
            this.m_Test = null;
        }

        /**
         * Returns additional error information.
         *
         * @return the information (empty string if none)
         */
        protected String getAdditionalErrorInformation() {
            return this.m_EvaluationError;
        }

        /**
         * Returns a short description of the job; safe to call after {@link #cleanUp()}.
         *
         * @return the description
         */
        public String toString() {
            String relation = (this.m_Train == null) ? "-" : this.m_Train.relationName();
            String result = "data:" + relation + ", ";
            result = result + "classifier: " + OptionUtils.getCommandLine((Object)this.m_Classifier) + ", ";
            result = result + "seed: " + this.m_Seed + ", ";
            result = result + "folds: " + this.m_Folds + ", ";
            result = result + "measure: " + this.m_Measure;
            return result;
        }
    }
}

