/*
 * Decompiled with CFR 0.152.
 */
package adams.multiprocess;

import adams.core.MessageCollection;
import adams.core.ObjectCopyHelper;
import adams.core.Performance;
import adams.core.StatusMessageHandler;
import adams.core.Stoppable;
import adams.core.ThreadLimiter;
import adams.core.logging.CustomLoggingLevelObject;
import adams.core.logging.LoggingHelper;
import adams.core.logging.LoggingSupporter;
import adams.core.option.OptionUtils;
import adams.data.weka.InstancesViewSupporter;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.flow.core.Actor;
import adams.flow.core.FlowContextHandler;
import adams.flow.standalone.JobRunnerSetup;
import adams.multiprocess.Job;
import adams.multiprocess.JobList;
import adams.multiprocess.JobRunner;
import adams.multiprocess.LocalJobRunner;
import adams.multiprocess.WekaCrossValidationJob;
import weka.classifiers.AggregateEvaluations;
import weka.classifiers.Classifier;
import weka.classifiers.CrossValidationFoldGenerator;
import weka.classifiers.DefaultCrossValidationFoldGenerator;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
import weka.core.Instances;

/**
 * Cross-validates a Weka classifier on a dataset.
 * <p>
 * If exactly one thread is determined (see {@link #isSingleThreaded()}) and
 * folds are not to be kept separate, all folds are evaluated in the calling
 * thread and accumulated into a single {@link Evaluation}. Optional prediction
 * output is written to {@link #getOutputBuffer()}. Otherwise one
 * {@link WekaCrossValidationJob} per fold is handed to a {@link JobRunner}.
 * Per-fold evaluations and classifiers are then kept, and an aggregated
 * evaluation is computed.
 */
public class WekaCrossValidationExecution
extends CustomLoggingLevelObject
implements Stoppable,
InstancesViewSupporter,
ThreadLimiter,
FlowContextHandler {
    private static final long serialVersionUID = 2021758441076652982L;
    /** the classifier template; a shallow copy is trained per fold */
    protected Classifier m_Classifier = null;
    /** the full dataset to cross-validate on */
    protected Instances m_Data = null;
    /** optional prediction output template (only used in serial mode) */
    protected AbstractOutput m_Output = null;
    /** buffer receiving the prediction output in serial mode */
    protected StringBuffer m_OutputBuffer;
    /** number of folds; values below 2 are stored as -1 (see setFolds) */
    protected int m_Folds;
    /** whether to force one job per fold, even when single-threaded */
    protected boolean m_SeparateFolds;
    /** seed handed to the fold generator */
    protected long m_Seed;
    /** whether the fold generator should use views instead of copies */
    protected boolean m_UseViews;
    /** template for the fold generator; a shallow copy is used per run */
    protected CrossValidationFoldGenerator m_Generator = new DefaultCrossValidationFoldGenerator();
    /** whether to discard predictions (not allowed in parallel mode) */
    protected boolean m_DiscardPredictions;
    /** the requested number of threads */
    protected int m_NumThreads;
    /** the number of threads determined by Performance.determineNumThreads */
    protected int m_ActualNumThreads;
    /** optional setup used to instantiate the job runner (takes precedence) */
    protected transient JobRunnerSetup m_JobRunnerSetup = null;
    /** optional job runner template, copied per run */
    protected transient JobRunner m_JobRunner = null;
    /** the runner of the current parallel run; read by stopExecution from another thread */
    protected transient volatile JobRunner m_ActualJobRunner = null;
    /** the (aggregated) evaluation of the last run, null if unavailable */
    protected Evaluation m_Evaluation;
    /** the per-fold evaluations (parallel mode only) */
    protected Evaluation[] m_Evaluations;
    /** the per-fold classifiers (parallel mode only) */
    protected Classifier[] m_Classifiers;
    /** the cross-validation indices of the predictions, if not discarded */
    protected int[] m_OriginalIndices;
    /** set by stopExecution, possibly from another thread */
    protected volatile boolean m_Stopped;
    /** optional receiver for progress messages */
    protected StatusMessageHandler m_StatusMessageHandler = null;
    /** whether terminating the job runner waits for running jobs */
    protected boolean m_WaitForJobs = true;
    /** the flow context handed to classifiers and jobs */
    protected transient Actor m_FlowContext = null;

    /**
     * Sets the flow context.
     *
     * @param value the actor
     */
    @Override
    public void setFlowContext(Actor value) {
        this.m_FlowContext = value;
    }

    /**
     * Returns the flow context.
     *
     * @return the actor, can be null
     */
    @Override
    public Actor getFlowContext() {
        return this.m_FlowContext;
    }

    /**
     * Sets the setup used to create the job runner for parallel execution.
     *
     * @param value the setup, null to fall back to the job runner or a local one
     */
    public void setJobRunnerSetup(JobRunnerSetup value) {
        this.m_JobRunnerSetup = value;
    }

    /**
     * Returns the setup used to create the job runner.
     *
     * @return the setup, can be null
     */
    public JobRunnerSetup getJobRunnerSetup() {
        return this.m_JobRunnerSetup;
    }

    /**
     * Sets the job runner template, copied for each parallel run.
     *
     * @param value the runner, can be null
     */
    public void setJobRunner(JobRunner value) {
        this.m_JobRunner = value;
    }

    /**
     * Returns the job runner template.
     *
     * @return the runner, can be null
     */
    public JobRunner getJobRunner() {
        return this.m_JobRunner;
    }

    /**
     * Sets whether terminating the job runner waits for running jobs.
     *
     * @param value true to wait
     */
    public void setWaitForJobs(boolean value) {
        this.m_WaitForJobs = value;
    }

    /**
     * Returns whether terminating the job runner waits for running jobs.
     *
     * @return true if waiting
     */
    public boolean getWaitForJobs() {
        return this.m_WaitForJobs;
    }

    /**
     * Sets the classifier template to cross-validate.
     *
     * @param value the classifier
     */
    public void setClassifier(Classifier value) {
        this.m_Classifier = value;
    }

    /**
     * Returns the classifier template.
     *
     * @return the classifier, can be null
     */
    public Classifier getClassifier() {
        return this.m_Classifier;
    }

    /**
     * Returns the per-fold classifiers of the last parallel run.
     *
     * @return the classifiers, null if not available
     */
    public Classifier[] getClassifiers() {
        return this.m_Classifiers;
    }

    /**
     * Sets the dataset to cross-validate on.
     *
     * @param value the data
     */
    public void setData(Instances value) {
        this.m_Data = value;
    }

    /**
     * Returns the dataset.
     *
     * @return the data, can be null
     */
    public Instances getData() {
        return this.m_Data;
    }

    /**
     * Sets the prediction output template (used in serial mode only).
     *
     * @param value the output, null for none
     */
    public void setOutput(AbstractOutput value) {
        this.m_Output = value;
    }

    /**
     * Returns the prediction output.
     *
     * @return the output, can be null
     */
    public AbstractOutput getOutput() {
        return this.m_Output;
    }

    /**
     * Sets the number of folds. Values below 2 are stored as -1
     * (presumably interpreted by the generator as leave-one-out; verify).
     *
     * @param value the number of folds
     */
    public void setFolds(int value) {
        this.m_Folds = value < 2 ? -1 : value;
    }

    /**
     * Returns the number of folds.
     *
     * @return the folds, -1 if a value below 2 was set
     */
    public int getFolds() {
        return this.m_Folds;
    }

    /**
     * Sets whether to evaluate each fold as a separate job even when
     * single-threaded.
     *
     * @param value true for separate folds
     */
    public void setSeparateFolds(boolean value) {
        this.m_SeparateFolds = value;
    }

    /**
     * Returns whether folds are evaluated as separate jobs.
     *
     * @return true for separate folds
     */
    public boolean getSeparateFolds() {
        return this.m_SeparateFolds;
    }

    /**
     * Sets the seed for the fold generator.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        this.m_Seed = value;
    }

    /**
     * Returns the seed for the fold generator.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    @Override
    public void setUseViews(boolean value) {
        this.m_UseViews = value;
    }

    @Override
    public boolean getUseViews() {
        return this.m_UseViews;
    }

    /**
     * Sets the fold generator template.
     *
     * @param value the generator
     */
    public void setGenerator(CrossValidationFoldGenerator value) {
        this.m_Generator = value;
    }

    /**
     * Returns the fold generator template.
     *
     * @return the generator
     */
    public CrossValidationFoldGenerator getGenerator() {
        return this.m_Generator;
    }

    /**
     * Sets whether to discard predictions. Not supported in parallel mode,
     * where execute() reports an error instead.
     *
     * @param value true to discard
     */
    public void setDiscardPredictions(boolean value) {
        this.m_DiscardPredictions = value;
    }

    /**
     * Returns whether predictions are discarded.
     *
     * @return true if discarded
     */
    public boolean getDiscardPredictions() {
        return this.m_DiscardPredictions;
    }

    /**
     * Sets the requested number of threads.
     *
     * @param value the number of threads, as understood by Performance.determineNumThreads
     */
    @Override
    public void setNumThreads(int value) {
        this.m_NumThreads = value;
    }

    /**
     * Returns the requested number of threads.
     *
     * @return the number of threads
     */
    @Override
    public int getNumThreads() {
        return this.m_NumThreads;
    }

    /**
     * Sets the handler receiving progress messages.
     *
     * @param value the handler, null for none
     */
    public void setStatusMessageHandler(StatusMessageHandler value) {
        this.m_StatusMessageHandler = value;
    }

    /**
     * Returns the handler receiving progress messages.
     *
     * @return the handler, can be null
     */
    public StatusMessageHandler getStatusMessageHandler() {
        return this.m_StatusMessageHandler;
    }

    /**
     * Creates a fresh output buffer. If an output is set, replaces it with a
     * copy (via its command-line) that writes into the new buffer.
     *
     * @throws IllegalStateException if the output cannot be copied
     */
    protected void initOutputBuffer() {
        this.m_OutputBuffer = new StringBuffer();
        if (this.m_Output == null) {
            return;
        }
        try {
            this.m_Output = (AbstractOutput) OptionUtils.forAnyCommandLine(AbstractOutput.class, OptionUtils.getCommandLine(this.m_Output));
            this.m_Output.setBuffer(this.m_OutputBuffer);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to create copy of output!", e);
        }
    }

    /**
     * Returns the buffer with the prediction output of the last serial run.
     *
     * @return the buffer, can be null
     */
    public StringBuffer getOutputBuffer() {
        return this.m_OutputBuffer;
    }

    /**
     * Returns the evaluation of the last run. In parallel mode this is the
     * aggregate of all fold evaluations.
     *
     * @return the evaluation, null if failed, stopped or not run
     */
    public Evaluation getEvaluation() {
        return this.m_Evaluation;
    }

    /**
     * Returns the per-fold evaluations of the last parallel run.
     *
     * @return the evaluations, null if not available
     */
    public Evaluation[] getEvaluations() {
        return this.m_Evaluations;
    }

    /**
     * Returns the cross-validation indices of the last run.
     *
     * @return the indices, null if predictions were discarded or the run failed
     */
    public int[] getOriginalIndices() {
        return this.m_OriginalIndices;
    }

    /**
     * Returns whether the determined number of threads is one. This is the
     * same criterion execute() uses, together with the separate-folds flag,
     * to pick the serial path.
     *
     * @return true if single-threaded
     */
    public boolean isSingleThreaded() {
        return this.m_ActualNumThreads == 1;
    }

    /**
     * Performs the cross-validation.
     *
     * @return null if successful, otherwise the error message(s)
     */
    public String execute() {
        MessageCollection result = new MessageCollection();
        int[] indices = null;
        // a previous stopExecution() must not affect this run
        this.m_Stopped = false;
        this.m_Evaluation = null;
        this.m_Evaluations = null;
        this.m_Classifiers = null;
        try {
            if (this.m_Classifier == null) {
                throw new IllegalStateException("No classifier set!");
            }
            if (this.m_Data == null) {
                throw new IllegalStateException("No data set!");
            }
            if (this.isLoggingEnabled()) {
                this.getLogger().info(OptionUtils.getCommandLine(this.m_Classifier));
            }
            this.m_ActualNumThreads = Performance.determineNumThreads(this.m_NumThreads);
            CrossValidationFoldGenerator generator = this.createFoldGenerator();
            if (this.isSingleThreaded() && !this.m_SeparateFolds) {
                this.crossValidateSerially(generator);
            } else {
                this.crossValidateInParallel(generator, result);
            }
            if (!this.m_DiscardPredictions) {
                indices = generator.crossValidationIndices();
            }
        } catch (Exception e) {
            result.add(LoggingHelper.handleException((LoggingSupporter) this, "Failed to cross-validate classifier: ", e));
        }
        this.m_OriginalIndices = indices;
        return result.isEmpty() ? null : result.toString();
    }

    /**
     * Creates and initializes a copy of the fold generator template.
     * Stratification is always enabled.
     *
     * @return the initialized generator
     * @throws Exception if copying or initialization fails
     */
    private CrossValidationFoldGenerator createFoldGenerator() throws Exception {
        CrossValidationFoldGenerator generator = (CrossValidationFoldGenerator) OptionUtils.shallowCopy(this.m_Generator);
        generator.setData(this.m_Data);
        generator.setNumFolds(this.m_Folds);
        generator.setSeed(this.m_Seed);
        generator.setStratify(true);
        generator.setUseViews(this.m_UseViews);
        generator.initializeIterator();
        return generator;
    }

    /**
     * Evaluates all folds in the calling thread, accumulating into a single
     * Evaluation. m_Evaluation is only set if the run was not stopped.
     *
     * @param generator the initialized fold generator
     * @throws Exception if training or evaluation fails
     */
    private void crossValidateSerially(CrossValidationFoldGenerator generator) throws Exception {
        int folds = generator.getActualNumFolds();
        this.initOutputBuffer();
        if (this.m_Output != null) {
            this.m_Output.setHeader(this.m_Data);
            this.m_Output.printHeader();
        }
        Evaluation eval = new Evaluation(this.m_Data);
        eval.setDiscardPredictions(this.m_DiscardPredictions);
        int current = 0;
        while (generator.hasNext() && !this.isStopped()) {
            // 1-based fold number for the status message
            ++current;
            if (this.m_StatusMessageHandler != null) {
                this.m_StatusMessageHandler.showStatus("Fold " + current + "/" + folds + ": '" + this.m_Data.relationName() + "' using " + OptionUtils.getCommandLine(this.m_Classifier));
            }
            WekaTrainTestSetContainer cont = generator.next();
            Instances train = (Instances) cont.getValue("Train");
            Instances test = (Instances) cont.getValue("Test");
            Classifier cls = (Classifier) OptionUtils.shallowCopy(this.m_Classifier);
            if (cls instanceof FlowContextHandler) {
                ((FlowContextHandler) cls).setFlowContext(this.m_FlowContext);
            }
            cls.buildClassifier(train);
            // priors are taken from the current fold's training data
            eval.setPriors(train);
            eval.evaluateModel(cls, test, new Object[]{this.m_Output});
        }
        if (this.m_Output != null) {
            this.m_Output.printFooter();
        }
        if (!this.isStopped()) {
            this.m_Evaluation = eval;
        }
    }

    /**
     * Evaluates each fold as a separate job on a job runner and aggregates the
     * results. The runner is always cleaned up and m_ActualJobRunner is always
     * reset, even on failure.
     *
     * @param generator the initialized fold generator
     * @param result collects error messages
     * @throws Exception if predictions are to be discarded, or jobs cannot be set up
     */
    private void crossValidateInParallel(CrossValidationFoldGenerator generator, MessageCollection result) throws Exception {
        if (this.m_DiscardPredictions) {
            throw new IllegalStateException("Cannot discard predictions in parallel mode, as they are used for aggregating the statistics!");
        }
        JobRunner runner = this.createJobRunner();
        this.m_ActualJobRunner = runner;
        JobList list = new JobList();
        try {
            while (generator.hasNext()) {
                WekaTrainTestSetContainer cont = generator.next();
                WekaCrossValidationJob job = new WekaCrossValidationJob(
                    (Classifier) OptionUtils.shallowCopy(this.m_Classifier),
                    (Instances) cont.getValue("Train"),
                    (Instances) cont.getValue("Test"),
                    (Integer) cont.getValue("FoldNumber"),
                    this.m_DiscardPredictions,
                    this.m_StatusMessageHandler);
                job.setFlowContext(this.m_FlowContext);
                list.add(job);
            }
            // don't start any work if a stop was requested while setting up
            if (this.isStopped()) {
                return;
            }
            runner.add(list);
            runner.start();
            runner.stop();
            if (!this.isStopped()) {
                this.collectFoldResults(runner, result);
            }
        } finally {
            this.m_ActualJobRunner = null;
            this.cleanUpJobs(list, runner);
        }
    }

    /**
     * Instantiates the job runner: from the setup if set, else a copy of the
     * runner template, else a LocalJobRunner. Only the local fallback gets the
     * number of threads applied.
     *
     * @return the runner
     */
    private JobRunner createJobRunner() {
        if (this.m_JobRunnerSetup != null) {
            return this.m_JobRunnerSetup.newInstance();
        }
        if (this.m_JobRunner != null) {
            return (JobRunner) ObjectCopyHelper.copyObject(this.m_JobRunner);
        }
        JobRunner runner = new LocalJobRunner();
        if (runner instanceof ThreadLimiter) {
            ((ThreadLimiter) runner).setNumThreads(this.m_NumThreads);
        }
        return runner;
    }

    /**
     * Gathers per-fold evaluations and classifiers from the finished jobs and
     * aggregates the evaluations. If any fold failed, an error is recorded and
     * no aggregate is produced, so m_Evaluation never reflects partial results.
     *
     * @param runner the finished runner
     * @param result collects error messages
     * @throws Exception if adding an evaluation to the aggregate fails
     */
    private void collectFoldResults(JobRunner runner, MessageCollection result) throws Exception {
        int numJobs = runner.getJobs().size();
        AggregateEvaluations evalAgg = new AggregateEvaluations();
        this.m_Evaluations = new Evaluation[numJobs];
        this.m_Classifiers = new Classifier[numJobs];
        for (int i = 0; i < numJobs; ++i) {
            WekaCrossValidationJob job = (WekaCrossValidationJob) runner.getJobs().get(i);
            Evaluation foldEval = job.getEvaluation();
            if (foldEval == null) {
                result.add("Fold #" + (i + 1) + " failed to evaluate: " + (job.hasExecutionError() ? job.getExecutionError() : "?"));
                return;
            }
            evalAgg.add(foldEval);
            this.m_Evaluations[i] = foldEval;
            this.m_Classifiers[i] = job.getClassifier();
            job.cleanUp();
        }
        this.m_Evaluation = evalAgg.aggregated();
        if (this.m_Evaluation == null) {
            if (evalAgg.hasLastError()) {
                result.add(evalAgg.getLastError());
            } else {
                result.add("Failed to aggregate evaluations!");
            }
        }
    }

    /**
     * Cleans up the job list and the runner. Failures are logged rather than
     * thrown, so they cannot mask an exception already propagating.
     *
     * @param list the jobs
     * @param runner the runner
     */
    private void cleanUpJobs(JobList list, JobRunner runner) {
        try {
            list.cleanUp();
            runner.cleanUp();
        } catch (Exception e) {
            LoggingHelper.handleException((LoggingSupporter) this, "Failed to clean up cross-validation jobs: ", e);
        }
    }

    /**
     * Returns whether stopExecution() was called during the current run.
     *
     * @return true if stopped
     */
    public boolean isStopped() {
        return this.m_Stopped;
    }

    /**
     * Requests the current run to stop and terminates the active job runner,
     * if any. May be called from a different thread than execute().
     */
    @Override
    public void stopExecution() {
        this.m_Stopped = true;
        this.getLogger().severe("Execution stopped");
        // read once: execute() may reset the field concurrently
        JobRunner runner = this.m_ActualJobRunner;
        if (runner != null) {
            runner.terminate(this.m_WaitForJobs);
        }
    }
}

