package adams.multiprocess;

import adams.core.Performance;
import adams.core.StatusMessageHandler;
import adams.core.Stoppable;
import adams.core.ThreadLimiter;
import adams.core.Utils;
import adams.core.logging.CustomLoggingLevelObject;
import adams.core.option.OptionUtils;
import adams.data.weka.InstancesViewSupporter;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.flow.standalone.JobRunnerSetup;
import java.util.Random;
import weka.classifiers.AggregateableEvaluation;
import weka.classifiers.Classifier;
import weka.classifiers.CrossValidationFoldGenerator;
import weka.classifiers.CrossValidationHelper;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
import weka.core.Instances;

/* loaded from: input_file:adams/multiprocess/WekaCrossValidationExecution.class */
/**
 * Cross-validates a Weka classifier on a dataset.
 * <p>
 * With a single (actual) thread and {@link #getSeparateFolds()} disabled, all folds are evaluated
 * in the calling thread using one {@link Evaluation}, optionally printing predictions to
 * {@link #getOutput()}. Otherwise every fold is wrapped in a {@link WekaCrossValidationJob},
 * executed by a {@link JobRunner}, and the per-fold evaluations are aggregated afterwards.
 */
public class WekaCrossValidationExecution extends CustomLoggingLevelObject implements Stoppable, InstancesViewSupporter {
    private static final long serialVersionUID = 2021758441076652982L;

    /** receives the prediction output; only (re)created in the sequential path. */
    protected StringBuffer m_OutputBuffer;

    /** the number of folds; -1 means leave-one-out (one fold per instance). */
    protected int m_Folds;

    /** whether to always evaluate folds as separate jobs via the job runner. */
    protected boolean m_SeparateFolds;

    /** the seed for randomizing the data before splitting. */
    protected long m_Seed;

    /** whether the fold generator should use views instead of copies of the data. */
    protected boolean m_UseViews;

    /** whether to discard predictions (no original indices are computed then). */
    protected boolean m_DiscardPredictions;

    /** the requested number of threads, resolved via Performance.determineNumThreads. */
    protected int m_NumThreads;

    /** the actual number of threads determined at the start of {@link #execute()}. */
    protected int m_ActualNumThreads;

    /** the overall (possibly aggregated) evaluation of the last execution. */
    protected Evaluation m_Evaluation;

    /** the per-fold evaluations; only populated by the job-runner path. */
    protected Evaluation[] m_Evaluations;

    /** the original row indices of the predictions, null if predictions were discarded. */
    protected int[] m_OriginalIndices;

    /** whether execution was stopped; volatile since stopExecution() may run on another thread. */
    protected volatile boolean m_Stopped;

    /** the classifier to cross-validate; shallow copies are trained per fold. */
    protected Classifier m_Classifier = null;

    /** the dataset to cross-validate on. */
    protected Instances m_Data = null;

    /** the optional prediction output (sequential path only). */
    protected AbstractOutput m_Output = null;

    /** the job runner in use while the job-runner path executes, otherwise null. */
    protected transient JobRunner m_JobRunner = null;

    /** optional setup for creating the job runner; a LocalJobRunner is used if null. */
    protected transient JobRunnerSetup m_JobRunnerSetup = null;

    /** optional handler for status messages. */
    protected StatusMessageHandler m_StatusMessageHandler = null;

    /** whether terminating the job runner should wait for running jobs. */
    protected boolean m_WaitForJobs = true;

    /**
     * Sets the setup for creating the job runner.
     *
     * @param jobRunnerSetup the setup, null to use a LocalJobRunner
     */
    public void setJobRunnerSetup(JobRunnerSetup jobRunnerSetup) {
        this.m_JobRunnerSetup = jobRunnerSetup;
    }

    /**
     * Returns the setup for creating the job runner.
     *
     * @return the setup, may be null
     */
    public JobRunnerSetup getJobRunnerSetup() {
        return this.m_JobRunnerSetup;
    }

    /**
     * Sets whether terminating the job runner waits for running jobs.
     *
     * @param z true to wait
     */
    public void setWaitForJobs(boolean z) {
        this.m_WaitForJobs = z;
    }

    /**
     * Returns whether terminating the job runner waits for running jobs.
     *
     * @return true if waiting
     */
    public boolean getWaitForJobs() {
        return this.m_WaitForJobs;
    }

    /**
     * Sets the classifier to cross-validate.
     *
     * @param classifier the classifier
     */
    public void setClassifier(Classifier classifier) {
        this.m_Classifier = classifier;
    }

    /**
     * Returns the classifier to cross-validate.
     *
     * @return the classifier, may be null
     */
    public Classifier getClassifier() {
        return this.m_Classifier;
    }

    /**
     * Sets the dataset to cross-validate on.
     *
     * @param instances the data
     */
    public void setData(Instances instances) {
        this.m_Data = instances;
    }

    /**
     * Returns the dataset to cross-validate on.
     *
     * @return the data, may be null
     */
    public Instances getData() {
        return this.m_Data;
    }

    /**
     * Sets the prediction output. A copy of it is used during execution.
     *
     * @param abstractOutput the output, null for none
     */
    public void setOutput(AbstractOutput abstractOutput) {
        this.m_Output = abstractOutput;
    }

    /**
     * Returns the prediction output.
     *
     * @return the output, may be null
     */
    public AbstractOutput getOutput() {
        return this.m_Output;
    }

    /**
     * Sets the number of folds. Invalid values are logged and ignored, keeping the previous value.
     *
     * @param i the number of folds, must be >= 2 or -1 for leave-one-out
     */
    public void setFolds(int i) {
        if (i == -1 || i >= 2) {
            this.m_Folds = i;
        } else {
            getLogger().severe("Number of folds must be >=2 or -1 for LOOCV, provided: " + i);
        }
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds, -1 for leave-one-out
     */
    public int getFolds() {
        return this.m_Folds;
    }

    /**
     * Sets whether folds are always evaluated as separate jobs via the job runner.
     *
     * @param z true for separate jobs
     */
    public void setSeparateFolds(boolean z) {
        this.m_SeparateFolds = z;
    }

    /**
     * Returns whether folds are always evaluated as separate jobs via the job runner.
     *
     * @return true for separate jobs
     */
    public boolean getSeparateFolds() {
        return this.m_SeparateFolds;
    }

    /**
     * Sets the seed used for randomizing the data.
     *
     * @param j the seed
     */
    public void setSeed(long j) {
        this.m_Seed = j;
    }

    /**
     * Returns the seed used for randomizing the data.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    @Override
    public void setUseViews(boolean z) {
        this.m_UseViews = z;
    }

    @Override
    public boolean getUseViews() {
        return this.m_UseViews;
    }

    /**
     * Sets whether to discard predictions.
     *
     * @param z true to discard
     */
    public void setDiscardPredictions(boolean z) {
        this.m_DiscardPredictions = z;
    }

    /**
     * Returns whether predictions get discarded.
     *
     * @return true if discarded
     */
    public boolean getDiscardPredictions() {
        return this.m_DiscardPredictions;
    }

    /**
     * Sets the requested number of threads (resolved via Performance.determineNumThreads).
     *
     * @param i the number of threads
     */
    public void setNumThreads(int i) {
        this.m_NumThreads = i;
    }

    /**
     * Returns the requested number of threads.
     *
     * @return the number of threads
     */
    public int getNumThreads() {
        return this.m_NumThreads;
    }

    /**
     * Sets the handler for status messages.
     *
     * @param statusMessageHandler the handler, null for none
     */
    public void setStatusMessageHandler(StatusMessageHandler statusMessageHandler) {
        this.m_StatusMessageHandler = statusMessageHandler;
    }

    /**
     * Returns the handler for status messages.
     *
     * @return the handler, may be null
     */
    public StatusMessageHandler getStatusMessageHandler() {
        return this.m_StatusMessageHandler;
    }

    /**
     * Creates a fresh output buffer. If an output is set, it is replaced by a copy (created from
     * its command line) that writes into the new buffer.
     *
     * @throws IllegalStateException if the output could not be copied
     */
    protected void initOutputBuffer() {
        this.m_OutputBuffer = new StringBuffer();
        if (this.m_Output != null) {
            try {
                this.m_Output = (AbstractOutput) OptionUtils.forAnyCommandLine(AbstractOutput.class, OptionUtils.getCommandLine(this.m_Output));
                this.m_Output.setBuffer(this.m_OutputBuffer);
            } catch (Exception e) {
                throw new IllegalStateException("Failed to create copy of output!", e);
            }
        }
    }

    /**
     * Returns the buffer with the prediction output. Only (re)filled by the sequential path.
     *
     * @return the buffer, may be null
     */
    public StringBuffer getOutputBuffer() {
        return this.m_OutputBuffer;
    }

    /**
     * Returns the overall evaluation of the last execution.
     *
     * @return the evaluation, null if not available (e.g. stopped or failed)
     */
    public Evaluation getEvaluation() {
        return this.m_Evaluation;
    }

    /**
     * Returns the per-fold evaluations. Only populated by the job-runner path.
     *
     * @return the evaluations, may be null; entries after a failed fold stay null
     */
    public Evaluation[] getEvaluations() {
        return this.m_Evaluations;
    }

    /**
     * Returns the original indices of the rows, in prediction order.
     *
     * @return the indices, null if predictions were discarded
     */
    public int[] getOriginalIndices() {
        return this.m_OriginalIndices;
    }

    /**
     * Returns whether the last execution determined a single thread. This uses the same thread
     * test as {@link #execute()} when picking the sequential path.
     *
     * @return true if exactly one thread was determined
     */
    public boolean isSingleThreaded() {
        return this.m_ActualNumThreads == 1;
    }

    /**
     * Performs the cross-validation.
     *
     * @return null if successful, otherwise an error message
     */
    public String execute() {
        String result = null;
        int[] indices = null;

        this.m_Evaluation = null;
        this.m_Evaluations = null;

        try {
            if (this.m_Classifier == null) {
                throw new IllegalStateException("Classifier '" + getClassifier() + "' not found!");
            }
            if (this.m_Data == null) {
                throw new IllegalStateException("No data provided!");
            }
            if (isLoggingEnabled()) {
                getLogger().info(OptionUtils.getCommandLine(this.m_Classifier));
            }

            int folds = determineActualFolds();
            this.m_ActualNumThreads = Performance.determineNumThreads(this.m_NumThreads);
            if (!this.m_DiscardPredictions) {
                indices = CrossValidationHelper.crossValidationIndices(this.m_Data, folds, new Random(this.m_Seed));
            }

            CrossValidationFoldGenerator generator = new CrossValidationFoldGenerator(this.m_Data, folds, this.m_Seed, true);
            generator.setUseViews(this.m_UseViews);

            if (this.m_ActualNumThreads == 1 && !this.m_SeparateFolds) {
                crossValidateSequentially(generator, folds);
            } else {
                result = crossValidateWithJobRunner(generator);
            }
        } catch (Exception e) {
            result = Utils.handleException(this, "Failed to cross-validate classifier: ", e);
        }

        this.m_OriginalIndices = indices;
        return result;
    }

    /**
     * Resolves the configured fold count; -1 (leave-one-out) becomes the number of instances.
     *
     * @return the number of folds to use
     * @throws IllegalStateException if fewer than two folds would result
     */
    private int determineActualFolds() {
        int folds = (this.m_Folds == -1) ? this.m_Data.numInstances() : this.m_Folds;
        if (folds < 2) {
            throw new IllegalStateException("Need at least 2 folds, got " + folds + " (folds setting: " + this.m_Folds + ")");
        }
        return folds;
    }

    /**
     * Evaluates all folds in the calling thread with a single Evaluation, printing predictions
     * to the (copied) output if one is set. The evaluation is only stored if not stopped.
     *
     * @param generator the fold generator
     * @param folds the total number of folds, used for status messages
     * @throws Exception if training or evaluation fails
     */
    private void crossValidateSequentially(CrossValidationFoldGenerator generator, int folds) throws Exception {
        initOutputBuffer();
        if (this.m_Output != null) {
            this.m_Output.setHeader(this.m_Data);
            this.m_Output.printHeader();
        }

        Evaluation evaluation = new Evaluation(this.m_Data);
        evaluation.setDiscardPredictions(this.m_DiscardPredictions);
        int fold = 0;
        while (generator.hasNext() && !isStopped()) {
            fold++;
            if (this.m_StatusMessageHandler != null) {
                this.m_StatusMessageHandler.showStatus("Fold " + fold + "/" + folds + ": '" + this.m_Data.relationName() + "' using " + OptionUtils.getCommandLine(this.m_Classifier));
            }
            WekaTrainTestSetContainer cont = generator.next();
            Instances train = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN);
            Instances test = (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST);
            // train a fresh copy per fold so the configured classifier stays untouched
            Classifier classifier = (Classifier) OptionUtils.shallowCopy(this.m_Classifier);
            classifier.buildClassifier(train);
            evaluation.setPriors(train);
            evaluation.evaluateModel(classifier, test, new Object[]{this.m_Output});
        }

        if (this.m_Output != null) {
            this.m_Output.printFooter();
        }
        if (!isStopped()) {
            this.m_Evaluation = evaluation;
        }
    }

    /**
     * Evaluates each fold as a separate job via the job runner and aggregates the results
     * (unless stopped). The runner is always cleaned up and reset to null afterwards.
     *
     * @param generator the fold generator
     * @return null if successful, otherwise an error message about the first failed fold
     * @throws Exception if setting up, running or aggregating the jobs fails
     */
    private String crossValidateWithJobRunner(CrossValidationFoldGenerator generator) throws Exception {
        String result = null;

        if (this.m_JobRunnerSetup == null) {
            this.m_JobRunner = new LocalJobRunner();
        } else {
            this.m_JobRunner = this.m_JobRunnerSetup.newInstance();
        }
        if (this.m_JobRunner instanceof ThreadLimiter) {
            ((ThreadLimiter) this.m_JobRunner).setNumThreads(this.m_NumThreads);
        }

        JobList jobList = new JobList();
        try {
            while (generator.hasNext()) {
                WekaTrainTestSetContainer cont = generator.next();
                jobList.add(new WekaCrossValidationJob(
                    (Classifier) OptionUtils.shallowCopy(this.m_Classifier),
                    (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TRAIN),
                    (Instances) cont.getValue(WekaTrainTestSetContainer.VALUE_TEST),
                    ((Integer) cont.getValue(WekaTrainTestSetContainer.VALUE_FOLD_NUMBER)).intValue(),
                    this.m_DiscardPredictions,
                    this.m_StatusMessageHandler));
            }
            this.m_JobRunner.add(jobList);
            this.m_JobRunner.start();
            this.m_JobRunner.stop();

            if (!isStopped()) {
                result = aggregateJobResults();
            }
        } finally {
            jobList.cleanUp();
            this.m_JobRunner.cleanUp();
            this.m_JobRunner = null;
        }

        return result;
    }

    /**
     * Aggregates the evaluations of the finished jobs into m_Evaluation and stores the per-fold
     * evaluations in m_Evaluations. Aggregation stops at the first job without an evaluation.
     *
     * @return null if all folds evaluated, otherwise an error message for the first failed fold
     * @throws Exception if the aggregated evaluation cannot be created
     */
    private String aggregateJobResults() throws Exception {
        String result = null;
        int numJobs = this.m_JobRunner.getJobs().size();
        AggregateableEvaluation aggregated = new AggregateableEvaluation(this.m_Data);
        aggregated.setDiscardPredictions(this.m_DiscardPredictions);
        this.m_Evaluations = new Evaluation[numJobs];

        for (int i = 0; i < numJobs; i++) {
            WekaCrossValidationJob job = (WekaCrossValidationJob) this.m_JobRunner.getJobs().get(i);
            if (job.getEvaluation() == null) {
                String msg = "Fold #" + (i + 1) + " failed to evaluate";
                result = job.hasExecutionError() ? msg + ":\n" + job.getExecutionError() : msg + "?";
                // without this break the loop would never advance past the failed fold
                break;
            }
            aggregated.aggregate(job.getEvaluation());
            this.m_Evaluations[i] = job.getEvaluation();
            job.cleanUp();
        }

        this.m_Evaluation = aggregated;
        return result;
    }

    /**
     * Returns whether execution was stopped. The flag is never reset, so a stopped instance
     * stays stopped.
     *
     * @return true if stopped
     */
    public boolean isStopped() {
        return this.m_Stopped;
    }

    /**
     * Stops the execution: flags this instance as stopped and terminates the job runner, if one
     * is currently active.
     */
    public void stopExecution() {
        this.m_Stopped = true;
        getLogger().severe("Execution stopped");
        // read once: execute() may reset the field to null concurrently
        JobRunner runner = this.m_JobRunner;
        if (runner != null) {
            runner.terminate(this.m_WaitForJobs);
        }
    }
}
