/*
 * Decompiled with CFR 0.152.
 */
package adams.multiprocess;

import adams.core.Performance;
import adams.core.StatusMessageHandler;
import adams.core.Stoppable;
import adams.core.ThreadLimiter;
import adams.core.Utils;
import adams.core.logging.CustomLoggingLevelObject;
import adams.core.logging.LoggingSupporter;
import adams.core.option.OptionUtils;
import adams.data.weka.InstancesViewSupporter;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.flow.standalone.JobRunnerSetup;
import adams.multiprocess.Job;
import adams.multiprocess.JobList;
import adams.multiprocess.JobRunner;
import adams.multiprocess.LocalJobRunner;
import adams.multiprocess.WekaCrossValidationJob;
import java.util.Random;
import weka.classifiers.AggregateableEvaluation;
import weka.classifiers.Classifier;
import weka.classifiers.CrossValidationFoldGenerator;
import weka.classifiers.CrossValidationHelper;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
import weka.core.Instances;

/**
 * Cross-validates a Weka classifier on a dataset.
 * <p>
 * The folds are either evaluated one after the other in the thread calling
 * {@link #execute()}, or distributed as {@link WekaCrossValidationJob}s to a
 * {@link JobRunner} (when more than one thread is determined or when folds
 * are to be evaluated separately). Configure the instance via the setters,
 * call {@link #execute()} and then retrieve the results via
 * {@link #getEvaluation()}, {@link #getEvaluations()},
 * {@link #getOriginalIndices()} and {@link #getOutputBuffer()}.
 */
public class WekaCrossValidationExecution
extends CustomLoggingLevelObject
implements Stoppable,
InstancesViewSupporter {
    private static final long serialVersionUID = 2021758441076652982L;
    /** the classifier template; a shallow copy is built for every fold. */
    protected Classifier m_Classifier = null;
    /** the dataset to cross-validate on. */
    protected Instances m_Data = null;
    /** the optional prediction output (replaced by a fresh copy in sequential mode). */
    protected AbstractOutput m_Output = null;
    /** the buffer that the prediction output writes to. */
    protected StringBuffer m_OutputBuffer;
    /** the number of folds, -1 for leave-one-out. */
    protected int m_Folds;
    /** whether to always evaluate the folds as separate jobs. */
    protected boolean m_SeparateFolds;
    /** the seed for randomizing the folds. */
    protected long m_Seed;
    /** whether the fold generator should use views. */
    protected boolean m_UseViews;
    /** whether to discard the predictions in the evaluations. */
    protected boolean m_DiscardPredictions;
    /** the requested number of threads. */
    protected int m_NumThreads;
    /** the number of threads determined during the last execution. */
    protected int m_ActualNumThreads;
    /** the setup for creating the job runner, null for a local job runner. */
    protected transient JobRunnerSetup m_JobRunnerSetup = null;
    /** the job runner currently in use; volatile since stopExecution() may run in another thread. */
    protected transient volatile JobRunner m_JobRunner = null;
    /** the overall evaluation of the last execution. */
    protected Evaluation m_Evaluation;
    /** the per-fold evaluations of the last (job-based) execution. */
    protected Evaluation[] m_Evaluations;
    /** the original indices computed during the last execution. */
    protected int[] m_OriginalIndices;
    /** whether the execution was stopped; volatile since it is set from another thread. */
    protected volatile boolean m_Stopped;
    /** the handler for status messages, can be null. */
    protected StatusMessageHandler m_StatusMessageHandler = null;
    /** passed on to the job runner when terminating it. */
    protected boolean m_WaitForJobs = true;

    /**
     * Sets the setup for creating the job runner.
     *
     * @param value the setup; null to use a {@link LocalJobRunner}
     */
    public void setJobRunnerSetup(JobRunnerSetup value) {
        this.m_JobRunnerSetup = value;
    }

    /**
     * Returns the setup for creating the job runner.
     *
     * @return the setup, null if a {@link LocalJobRunner} gets used
     */
    public JobRunnerSetup getJobRunnerSetup() {
        return this.m_JobRunnerSetup;
    }

    /**
     * Sets the flag that gets passed to {@code JobRunner.terminate(boolean)}
     * when the execution is stopped.
     *
     * @param value the flag
     */
    public void setWaitForJobs(boolean value) {
        this.m_WaitForJobs = value;
    }

    /**
     * Returns the flag that gets passed to {@code JobRunner.terminate(boolean)}
     * when the execution is stopped.
     *
     * @return the flag
     */
    public boolean getWaitForJobs() {
        return this.m_WaitForJobs;
    }

    /**
     * Sets the classifier template to evaluate.
     *
     * @param value the classifier
     */
    public void setClassifier(Classifier value) {
        this.m_Classifier = value;
    }

    /**
     * Returns the classifier template to evaluate.
     *
     * @return the classifier
     */
    public Classifier getClassifier() {
        return this.m_Classifier;
    }

    /**
     * Sets the dataset to cross-validate on.
     *
     * @param value the data
     */
    public void setData(Instances value) {
        this.m_Data = value;
    }

    /**
     * Returns the dataset to cross-validate on.
     *
     * @return the data
     */
    public Instances getData() {
        return this.m_Data;
    }

    /**
     * Sets the prediction output; only used when the folds are evaluated
     * sequentially.
     *
     * @param value the output, null for none
     */
    public void setOutput(AbstractOutput value) {
        this.m_Output = value;
    }

    /**
     * Returns the prediction output.
     *
     * @return the output, null if none
     */
    public AbstractOutput getOutput() {
        return this.m_Output;
    }

    /**
     * Sets the number of folds. Invalid values are logged and ignored.
     *
     * @param value the number of folds (&gt;= 2), or -1 for leave-one-out
     */
    public void setFolds(int value) {
        if (value == -1 || value >= 2) {
            this.m_Folds = value;
        } else {
            this.getLogger().severe("Number of folds must be >=2 or -1 for LOOCV, provided: " + value);
        }
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds, -1 for leave-one-out
     */
    public int getFolds() {
        return this.m_Folds;
    }

    /**
     * Sets whether the folds always get evaluated as separate jobs, even if
     * only a single thread is determined. Per-fold evaluations are only
     * available from job-based executions.
     *
     * @param value true for separate jobs
     */
    public void setSeparateFolds(boolean value) {
        this.m_SeparateFolds = value;
    }

    /**
     * Returns whether the folds always get evaluated as separate jobs.
     *
     * @return true for separate jobs
     */
    public boolean getSeparateFolds() {
        return this.m_SeparateFolds;
    }

    /**
     * Sets the seed for randomizing the folds.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        this.m_Seed = value;
    }

    /**
     * Returns the seed for randomizing the folds.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Sets whether the fold generator should use views.
     *
     * @param value true to use views
     */
    @Override
    public void setUseViews(boolean value) {
        this.m_UseViews = value;
    }

    /**
     * Returns whether the fold generator should use views.
     *
     * @return true if views are used
     */
    @Override
    public boolean getUseViews() {
        return this.m_UseViews;
    }

    /**
     * Sets whether to discard the predictions in the evaluations. When
     * enabled, no original indices get computed.
     *
     * @param value true to discard predictions
     */
    public void setDiscardPredictions(boolean value) {
        this.m_DiscardPredictions = value;
    }

    /**
     * Returns whether the predictions are discarded in the evaluations.
     *
     * @return true if predictions are discarded
     */
    public boolean getDiscardPredictions() {
        return this.m_DiscardPredictions;
    }

    /**
     * Sets the requested number of threads. The actual number is determined
     * via {@code Performance.determineNumThreads(int)} at execution time.
     *
     * @param value the number of threads
     */
    public void setNumThreads(int value) {
        this.m_NumThreads = value;
    }

    /**
     * Returns the requested number of threads.
     *
     * @return the number of threads
     */
    public int getNumThreads() {
        return this.m_NumThreads;
    }

    /**
     * Sets the handler for status messages.
     *
     * @param value the handler, null for none
     */
    public void setStatusMessageHandler(StatusMessageHandler value) {
        this.m_StatusMessageHandler = value;
    }

    /**
     * Returns the handler for status messages.
     *
     * @return the handler, null if none
     */
    public StatusMessageHandler getStatusMessageHandler() {
        return this.m_StatusMessageHandler;
    }

    /**
     * Creates a new output buffer. If a prediction output is set, it gets
     * replaced by a fresh copy (created from its command line) that writes
     * to the new buffer.
     *
     * @throws IllegalStateException if the output cannot be copied
     */
    protected void initOutputBuffer() {
        this.m_OutputBuffer = new StringBuffer();
        if (this.m_Output != null) {
            try {
                this.m_Output = (AbstractOutput)OptionUtils.forAnyCommandLine(AbstractOutput.class, OptionUtils.getCommandLine((Object)this.m_Output));
                this.m_Output.setBuffer(this.m_OutputBuffer);
            }
            catch (Exception e) {
                throw new IllegalStateException("Failed to create copy of output!", e);
            }
        }
    }

    /**
     * Returns the buffer the prediction output was written to during the
     * last sequential execution.
     *
     * @return the buffer, null if never initialized
     */
    public StringBuffer getOutputBuffer() {
        return this.m_OutputBuffer;
    }

    /**
     * Returns the overall evaluation of the last execution.
     *
     * @return the evaluation, null if the execution failed, was stopped, or
     *         has not run yet
     */
    public Evaluation getEvaluation() {
        return this.m_Evaluation;
    }

    /**
     * Returns the per-fold evaluations of the last execution.
     *
     * @return the evaluations; only available from job-based executions in
     *         which all folds succeeded, otherwise null
     */
    public Evaluation[] getEvaluations() {
        return this.m_Evaluations;
    }

    /**
     * Returns the indices computed by
     * {@code CrossValidationHelper.crossValidationIndices} during the last
     * execution.
     *
     * @return the indices, null if predictions were discarded or the
     *         execution failed before computing them
     */
    public int[] getOriginalIndices() {
        return this.m_OriginalIndices;
    }

    /**
     * Returns whether the last execution determined a single thread. This
     * uses the same thread-count criterion {@link #execute()} uses to pick
     * sequential evaluation; that choice additionally depends on
     * {@link #getSeparateFolds()}.
     *
     * @return true if a single thread was determined
     */
    public boolean isSingleThreaded() {
        return this.m_ActualNumThreads == 1;
    }

    /**
     * Performs the cross-validation with the current settings.
     *
     * @return null if successful, otherwise an error message
     */
    public String execute() {
        String result = null;
        int[] indices = null;
        this.m_Evaluation = null;
        this.m_Evaluations = null;
        // clear a previous stop request so the instance can be re-used
        this.m_Stopped = false;
        try {
            if (this.m_Classifier == null) {
                throw new IllegalStateException("Classifier '" + this.getClassifier() + "' not found!");
            }
            if (this.m_Data == null) {
                throw new IllegalStateException("No data provided!");
            }
            if (this.isLoggingEnabled()) {
                this.getLogger().info(OptionUtils.getCommandLine((Object)this.m_Classifier));
            }
            // -1 means leave-one-out: one fold per instance
            int folds = this.m_Folds == -1 ? this.m_Data.numInstances() : this.m_Folds;
            this.m_ActualNumThreads = Performance.determineNumThreads(this.m_NumThreads);
            if (!this.m_DiscardPredictions) {
                indices = CrossValidationHelper.crossValidationIndices(this.m_Data, folds, new Random(this.m_Seed));
            }
            CrossValidationFoldGenerator generator = new CrossValidationFoldGenerator(this.m_Data, folds, this.m_Seed, true);
            generator.setUseViews(this.m_UseViews);
            if (this.m_ActualNumThreads == 1 && !this.m_SeparateFolds) {
                this.crossValidateSequentially(generator, folds);
            } else {
                result = this.crossValidateWithJobs(generator);
            }
        }
        catch (Exception e) {
            result = Utils.handleException((LoggingSupporter)this, (String)"Failed to cross-validate classifier: ", (Throwable)e);
        }
        this.m_OriginalIndices = indices;
        return result;
    }

    /**
     * Evaluates all folds one after the other in the current thread,
     * accumulating them in a single evaluation.
     *
     * @param generator the fold generator
     * @param folds     the total number of folds (used in status messages)
     * @throws Exception if building or evaluating a classifier fails
     */
    private void crossValidateSequentially(CrossValidationFoldGenerator generator, int folds) throws Exception {
        this.initOutputBuffer();
        if (this.m_Output != null) {
            this.m_Output.setHeader(this.m_Data);
            this.m_Output.printHeader();
        }
        Evaluation eval = new Evaluation(this.m_Data);
        eval.setDiscardPredictions(this.m_DiscardPredictions);
        int current = 0;
        while (generator.hasNext() && !this.isStopped()) {
            // 1-based, matching the "Fold #" numbering of the job-based path
            ++current;
            if (this.m_StatusMessageHandler != null) {
                this.m_StatusMessageHandler.showStatus("Fold " + current + "/" + folds + ": '" + this.m_Data.relationName() + "' using " + OptionUtils.getCommandLine((Object)this.m_Classifier));
            }
            WekaTrainTestSetContainer cont = generator.next();
            Instances train = (Instances)cont.getValue("Train");
            Instances test = (Instances)cont.getValue("Test");
            // every fold trains its own copy of the template
            Classifier cls = (Classifier)OptionUtils.shallowCopy((Object)this.m_Classifier);
            cls.buildClassifier(train);
            eval.setPriors(train);
            eval.evaluateModel(cls, test, new Object[]{this.m_Output});
        }
        if (this.m_Output != null) {
            this.m_Output.printFooter();
        }
        if (!this.isStopped()) {
            this.m_Evaluation = eval;
        }
    }

    /**
     * Evaluates every fold as a separate job using the configured job runner
     * and aggregates the per-fold evaluations. Results are only published if
     * all folds evaluated successfully. The job list and runner are always
     * cleaned up.
     *
     * @param generator the fold generator
     * @return null if successful (or stopped), otherwise an error message
     * @throws Exception if creating or running the jobs fails
     */
    private String crossValidateWithJobs(CrossValidationFoldGenerator generator) throws Exception {
        String result = null;
        JobList list = new JobList();
        JobRunner runner = this.m_JobRunnerSetup == null ? new LocalJobRunner() : this.m_JobRunnerSetup.newInstance();
        this.m_JobRunner = runner;
        try {
            if (runner instanceof ThreadLimiter) {
                ((ThreadLimiter)runner).setNumThreads(this.m_NumThreads);
            }
            while (generator.hasNext()) {
                WekaTrainTestSetContainer cont = generator.next();
                WekaCrossValidationJob job = new WekaCrossValidationJob((Classifier)OptionUtils.shallowCopy((Object)this.m_Classifier), (Instances)cont.getValue("Train"), (Instances)cont.getValue("Test"), (Integer)cont.getValue("FoldNumber"), this.m_DiscardPredictions, this.m_StatusMessageHandler);
                list.add((Job)job);
            }
            runner.add(list);
            runner.start();
            runner.stop();
            if (!this.isStopped()) {
                int numJobs = runner.getJobs().size();
                AggregateableEvaluation evalAgg = new AggregateableEvaluation(this.m_Data);
                evalAgg.setDiscardPredictions(this.m_DiscardPredictions);
                Evaluation[] evals = new Evaluation[numJobs];
                for (int i = 0; i < numJobs; ++i) {
                    WekaCrossValidationJob job = (WekaCrossValidationJob)runner.getJobs().get(i);
                    Evaluation foldEval = job.getEvaluation();
                    if (foldEval == null) {
                        result = "Fold #" + (i + 1) + " failed to evaluate";
                        if (job.hasExecutionError()) {
                            result = result + ":\n" + job.getExecutionError();
                        } else {
                            result = result + "?";
                        }
                        break;
                    }
                    evalAgg.aggregate(foldEval);
                    evals[i] = foldEval;
                    job.cleanUp();
                }
                // do not expose partial aggregates alongside an error message
                if (result == null) {
                    this.m_Evaluations = evals;
                    this.m_Evaluation = evalAgg;
                }
            }
        }
        finally {
            list.cleanUp();
            runner.cleanUp();
            this.m_JobRunner = null;
        }
        return result;
    }

    /**
     * Returns whether the execution was stopped.
     *
     * @return true if stopped
     */
    public boolean isStopped() {
        return this.m_Stopped;
    }

    /**
     * Requests that the current execution stop. This sets the stopped flag,
     * which the sequential fold loop checks, and terminates the job runner
     * if one is active.
     */
    public void stopExecution() {
        this.m_Stopped = true;
        this.getLogger().severe("Execution stopped");
        // read once: execute() may null the field concurrently
        JobRunner runner = this.m_JobRunner;
        if (runner != null) {
            runner.terminate(this.m_WaitForJobs);
        }
    }
}

