/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.QuickInfoHelper;
import adams.core.Randomizable;
import adams.core.management.ProcessUtils;
import adams.core.option.OptionHandler;
import adams.core.option.OptionUtils;
import adams.flow.container.WekaEvaluationContainer;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.flow.core.AbstractActor;
import adams.flow.core.Token;
import adams.flow.provenance.ActorType;
import adams.flow.provenance.Provenance;
import adams.flow.provenance.ProvenanceContainer;
import adams.flow.provenance.ProvenanceInformation;
import adams.flow.provenance.ProvenanceSupporter;
import adams.flow.transformer.AbstractCallableWekaClassifierEvaluator;
import adams.multiprocess.Job;
import adams.multiprocess.JobList;
import adams.multiprocess.JobRunner;
import java.util.Random;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.AggregateableEvaluation;
import weka.classifiers.Classifier;
import weka.classifiers.CrossValidationFoldGenerator;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.output.prediction.Null;
import weka.core.Instances;

/**
 * Cross-validates a classifier, obtained from a callable 'Classifier' actor,
 * on the {@link Instances} received at the input.
 * <p>
 * With 0 or 1 threads the folds are evaluated sequentially via
 * {@link Evaluation#crossValidateModel}, and the configured prediction output
 * is used. Otherwise the folds are generated with a
 * {@link CrossValidationFoldGenerator}, evaluated as {@link CrossValidationJob}s
 * by a {@link JobRunner}, and aggregated into an {@link AggregateableEvaluation}.
 */
public class WekaCrossValidationEvaluator
extends AbstractCallableWekaClassifierEvaluator
implements Randomizable,
ProvenanceSupporter {
    private static final long serialVersionUID = -3019442578354930841L;
    /** the number of folds; -1 means leave-one-out (one fold per instance). */
    protected int m_Folds;
    /** the seed used for randomizing the folds. */
    protected long m_Seed;
    /** the user-configured number of threads (-1 = #cores, 0/1 = sequential). */
    protected int m_NumThreads;
    /** the number of threads used in the last execution; 0 means sequential. */
    protected int m_ActualNumThreads;

    @Override
    public String globalInfo() {
        return "Cross-validates a classifier on an incoming dataset. The classifier setup being used in the evaluation is a callable 'Classifier' actor.";
    }

    /**
     * Adds the options for seed, folds and number of threads.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("seed", "seed", (Object)1L);
        this.m_OptionManager.add("folds", "folds", (Object)10, (Number)-1, null);
        this.m_OptionManager.add("num-threads", "numThreads", (Object)1, (Number)-1, null);
    }

    /**
     * Returns a short summary of folds, seed and threading setup.
     *
     * @return the quick info
     */
    @Override
    public String getQuickInfo() {
        String result = super.getQuickInfo();
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"folds", (Object)this.m_Folds, (String)", folds: ");
        result = result + QuickInfoHelper.toString((OptionHandler)this, (String)"seed", (Object)this.m_Seed, (String)", seed: ");
        String variable = QuickInfoHelper.getVariable((OptionHandler)this, (String)"numThreads");
        if (variable != null) {
            result = result + ", threads: " + variable;
        } else if (this.m_NumThreads == 0 || this.m_NumThreads == 1) {
            result = result + ", sequential";
        } else {
            result = result + ", parallel/threads: ";
            result = this.m_NumThreads == -1 ? result + "#cores" : result + this.m_NumThreads;
        }
        return result;
    }

    @Override
    public String classifierTipText() {
        return "The callable classifier actor to cross-validate on the input data.";
    }

    @Override
    public String outputTipText() {
        return "The class for generating prediction output; if 'Null' is used, then an Evaluation object is forwarded instead of a String; not used when using parallel execution.";
    }

    /**
     * Sets the number of folds. Values other than -1 or >= 2 are rejected and
     * only logged; the current value is kept in that case.
     *
     * @param value the number of folds, -1 for LOOCV
     */
    public void setFolds(int value) {
        if (value == -1 || value >= 2) {
            this.m_Folds = value;
            this.reset();
        } else {
            this.getLogger().severe("Number of folds must be >=2 or -1 for LOOCV, provided: " + value);
        }
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds, -1 for LOOCV
     */
    public int getFolds() {
        return this.m_Folds;
    }

    public String foldsTipText() {
        return "The number of folds to use in the cross-validation; use -1 for leave-one-out cross-validation (LOOCV).";
    }

    /**
     * Sets the seed used for randomizing the folds.
     *
     * @param value the seed
     */
    @Override
    public void setSeed(long value) {
        this.m_Seed = value;
        this.reset();
    }

    /**
     * Returns the seed used for randomizing the folds.
     *
     * @return the seed
     */
    @Override
    public long getSeed() {
        return this.m_Seed;
    }

    @Override
    public String seedTipText() {
        return "The seed value for the cross-validation (used for randomization).";
    }

    /**
     * Sets the number of threads to use.
     *
     * @param value -1 = #cores, 0 or 1 = sequential, otherwise the thread count
     */
    public void setNumThreads(int value) {
        this.m_NumThreads = value;
        this.reset();
    }

    /**
     * Returns the number of threads to use.
     *
     * @return -1 = #cores, 0 or 1 = sequential, otherwise the thread count
     */
    public int getNumThreads() {
        return this.m_NumThreads;
    }

    public String numThreadsTipText() {
        return "The number of threads to use for evaluating the folds; -1 = number of CPUs/cores; 0 or 1 = sequential execution.";
    }

    /**
     * Returns the accepted input types.
     *
     * @return {@link Instances}
     */
    @Override
    public Class[] accepts() {
        return new Class[]{Instances.class};
    }

    @Override
    public String setUp() {
        return super.setUp();
    }

    /**
     * Performs the cross-validation on the input dataset.
     * On success the output token holds either the prediction output string
     * (sequential mode with a non-Null output) or a
     * {@link WekaEvaluationContainer}. On failure no output token is produced.
     *
     * @return null if successful, otherwise the error message
     */
    @Override
    protected String doExecute() {
        String result = null;
        try {
            Classifier cls = this.getClassifierInstance();
            if (cls == null) {
                throw new IllegalStateException("Classifier '" + this.getClassifier() + "' not found!");
            }
            Instances data = (Instances)this.m_InputToken.getPayload();
            int folds = this.m_Folds == -1 ? data.numInstances() : this.m_Folds;
            this.m_ActualNumThreads = this.determineNumThreads(folds);
            if (this.m_ActualNumThreads == 0) {
                this.crossValidateSequentially(cls, data, folds);
            } else {
                result = this.crossValidateInParallel(cls, data, folds);
            }
        }
        catch (Exception e) {
            this.m_OutputToken = null;
            result = this.handleException("Failed to cross-validate classifier: ", e);
        }
        if (this.m_OutputToken != null) {
            this.updateProvenance((ProvenanceContainer)this.m_OutputToken);
        }
        return result;
    }

    /**
     * Determines how many threads to use for the given number of folds.
     * Using more threads than there are folds would leave threads idle.
     *
     * @param folds the actual number of folds
     * @return 0 for sequential execution, otherwise the thread count
     */
    private int determineNumThreads(int folds) {
        if (this.m_NumThreads == -1) {
            return Math.min(ProcessUtils.getAvailableProcessors(), folds);
        }
        if (this.m_NumThreads > 1) {
            return Math.min(this.m_NumThreads, folds);
        }
        return 0;
    }

    /**
     * Runs Weka's built-in cross-validation in the current thread and sets
     * the output token.
     *
     * @param cls the classifier template
     * @param data the dataset
     * @param folds the actual number of folds
     * @throws Exception if the evaluation fails
     */
    private void crossValidateSequentially(Classifier cls, Instances data, int folds) throws Exception {
        this.initOutputBuffer();
        this.m_Output.setHeader(data);
        Evaluation eval = new Evaluation(data);
        eval.setDiscardPredictions(this.m_DiscardPredictions);
        eval.crossValidateModel(cls, data, folds, new Random(this.m_Seed), new Object[]{this.m_Output});
        // a 'Null' output generator means: forward the Evaluation, not the text
        this.m_OutputToken = this.m_Output instanceof Null
            ? new Token((Object)new WekaEvaluationContainer(eval))
            : new Token((Object)this.m_Output.getBuffer().toString());
    }

    /**
     * Evaluates each fold as a separate job and aggregates the results.
     * The output token is only set when every fold was evaluated successfully.
     *
     * @param cls the classifier template (each job makes its own copy)
     * @param data the dataset
     * @param folds the actual number of folds
     * @return null if successful, otherwise the error message of the first failed fold
     * @throws Exception if fold generation or aggregation fails
     */
    private String crossValidateInParallel(Classifier cls, Instances data, int folds) throws Exception {
        CrossValidationFoldGenerator generator = new CrossValidationFoldGenerator(data, folds, this.m_Seed, true);
        JobRunner runner = new JobRunner(this.m_ActualNumThreads);
        JobList list = new JobList();
        while (generator.hasNext()) {
            WekaTrainTestSetContainer cont = generator.next();
            CrossValidationJob job = new CrossValidationJob(cls, (Instances)cont.getValue("Train"), (Instances)cont.getValue("Test"), (Integer)cont.getValue("FoldNumber"), this.m_DiscardPredictions);
            list.add((Job)job);
        }
        runner.add(list);
        runner.start();
        runner.stop();
        String result = null;
        AggregateableEvaluation evalAgg = new AggregateableEvaluation(data);
        for (int i = 0; i < list.size(); ++i) {
            CrossValidationJob job = (CrossValidationJob)list.get(i);
            if (job.getEvaluation() == null) {
                result = "Fold #" + (i + 1) + " failed to evaluate";
                result = job.hasExecutionError() ? result + ":\n" + job.getExecutionError() : result + "?";
                break;
            }
            evalAgg.aggregate(job.getEvaluation());
            job.cleanUp();
        }
        list.cleanUp();
        // never forward a partially aggregated evaluation
        this.m_OutputToken = result == null ? new Token((Object)new WekaEvaluationContainer((Evaluation)evalAgg)) : null;
        return result;
    }

    /**
     * Updates the provenance of the given container if provenance is enabled.
     * Any provenance of the input token is copied over first, then an
     * EVALUATOR entry is added. The output class is taken from the current
     * output token, so that token must be non-null when this is called.
     *
     * @param cont the container to update
     */
    @Override
    public void updateProvenance(ProvenanceContainer cont) {
        if (Provenance.getSingleton().isEnabled()) {
            if (this.m_InputToken.hasProvenance()) {
                cont.setProvenance(this.m_InputToken.getProvenance().getClone());
            }
            cont.addProvenance(new ProvenanceInformation(ActorType.EVALUATOR, this.m_InputToken.getPayload().getClass(), (AbstractActor)this, this.m_OutputToken.getPayload().getClass()));
        }
    }

    /**
     * Job that trains a copy of a classifier on one training fold and
     * evaluates it on the corresponding test fold.
     */
    public static class CrossValidationJob
    extends Job {
        private static final long serialVersionUID = -9085803857529039559L;
        /** the private copy of the classifier, null if copying failed. */
        protected Classifier m_Classifier;
        /** the fold number. */
        protected int m_Fold;
        /** the training data. */
        protected Instances m_Train;
        /** the test data. */
        protected Instances m_Test;
        /** whether to discard the predictions in the evaluation. */
        protected boolean m_DiscardPredictions;
        /** the evaluation result, null until successfully processed. */
        protected Evaluation m_Evaluation;
        /** the reason why copying the classifier failed, if it did. */
        protected String m_CopyError;

        /**
         * Initializes the job with a copy of the classifier.
         *
         * @param classifier the classifier template to copy
         * @param train the training fold
         * @param test the test fold
         * @param fold the fold number
         * @param discardPred whether to discard predictions
         */
        public CrossValidationJob(Classifier classifier, Instances train, Instances test, int fold, boolean discardPred) {
            try {
                this.m_Classifier = AbstractClassifier.makeCopy((Classifier)classifier);
            }
            catch (Exception e) {
                // not rethrown: preProcessCheck() reports the failure when the job runs
                this.m_Classifier = null;
                this.m_CopyError = e.toString();
            }
            this.m_Train = train;
            this.m_Test = test;
            this.m_Fold = fold;
            this.m_DiscardPredictions = discardPred;
        }

        public Instances getTrain() {
            return this.m_Train;
        }

        public Instances getTest() {
            return this.m_Test;
        }

        public int getFold() {
            return this.m_Fold;
        }

        public boolean getDiscardPredictions() {
            return this.m_DiscardPredictions;
        }

        /**
         * Returns the evaluation.
         *
         * @return the evaluation, null if not (successfully) processed
         */
        public Evaluation getEvaluation() {
            return this.m_Evaluation;
        }

        @Override
        protected String preProcessCheck() {
            if (this.m_Classifier == null) {
                return this.m_CopyError == null
                    ? "No classifier set/failed to copy!"
                    : "No classifier set/failed to copy! " + this.m_CopyError;
            }
            if (this.m_Train == null) {
                return "No training set!";
            }
            if (this.m_Test == null) {
                return "No test set!";
            }
            return null;
        }

        /**
         * Trains the classifier on the training fold and evaluates it on the test fold.
         *
         * @throws Exception if training or evaluation fails
         */
        @Override
        protected void process() throws Exception {
            this.m_Classifier.buildClassifier(this.m_Train);
            this.m_Evaluation = new Evaluation(this.m_Train);
            this.m_Evaluation.setDiscardPredictions(this.m_DiscardPredictions);
            this.m_Evaluation.evaluateModel(this.m_Classifier, this.m_Test, new Object[0]);
        }

        @Override
        protected String postProcessCheck() {
            if (this.m_Evaluation == null) {
                return "Failed to evaluate?";
            }
            return null;
        }

        /**
         * Releases the datasets and the evaluation.
         */
        @Override
        public void cleanUp() {
            super.cleanUp();
            this.m_Train = null;
            this.m_Test = null;
            this.m_Evaluation = null;
        }

        @Override
        public String toString() {
            String cmdline = this.m_Classifier == null ? "null" : OptionUtils.getCommandLine((Object)this.m_Classifier);
            return "classifier=" + cmdline + ", fold=" + this.m_Fold;
        }
    }
}

