package adams.gui.tools.wekamultiexperimenter.experiment;

import adams.core.ThreadLimiter;
import adams.core.option.OptionUtils;
import adams.data.spreadsheet.DefaultSpreadSheet;
import adams.data.spreadsheet.SpreadSheet;
import adams.gui.tools.wekamultiexperimenter.analysis.DefaultAnalysisPanel;
import adams.gui.tools.wekamultiexperimenter.experiment.AbstractExperiment;
import adams.multiprocess.WekaCrossValidationExecution;
import weka.classifiers.Classifier;
import weka.classifiers.CrossValidationFoldGenerator;
import weka.classifiers.DefaultCrossValidationFoldGenerator;
import weka.core.Instances;

/* loaded from: input_file:adams/gui/tools/wekamultiexperimenter/experiment/CrossValidationExperiment.class */
/**
 * Experiment that performs cross-validation on each classifier/dataset combination.
 * Every run of a combination is executed as a {@link CrossValidationExperimentJob}.
 */
public class CrossValidationExperiment extends AbstractExperiment {
    private static final long serialVersionUID = -4147644361063132314L;

    /** The number of folds (option "folds", default 10, lower bound -1). */
    protected int m_Folds;

    /** The scheme used for generating the folds (option "generator"). */
    protected CrossValidationFoldGenerator m_Generator;

    /**
     * Job that cross-validates a single classifier on a single dataset for one run
     * and appends the per-fold metrics to the owning experiment's results.
     */
    public static class CrossValidationExperimentJob extends AbstractExperiment.AbstractExperimentJob<CrossValidationExperiment> {
        private static final long serialVersionUID = -4979824848864995696L;

        /** The cross-validation being executed; null until {@link #evaluate()} starts. */
        protected WekaCrossValidationExecution m_CrossValidation;

        /**
         * Creates the job.
         *
         * @param owner      the experiment this job belongs to
         * @param run        the run number, also used as seed for the cross-validation
         * @param classifier the classifier to evaluate
         * @param data       the dataset to evaluate on
         */
        public CrossValidationExperimentJob(CrossValidationExperiment owner, int run, Classifier classifier, Instances data) {
            super(owner, run, classifier, data);
        }

        /**
         * Runs the cross-validation and, on success, appends one metrics row per
         * evaluated fold (tagged with the 0-based fold index) to the owner's results.
         * On failure, the error message is logged via the owner.
         */
        @Override
        protected void evaluate() {
            CrossValidationExperiment owner = (CrossValidationExperiment) this.m_Owner;
            owner.log("Run " + this.m_Run + " [start]: " + this.m_Data.relationName() + " on " + AbstractExperiment.shortenCommandLine(this.m_Classifier));

            boolean singleCombination = owner.getDatasets().length == 1 && owner.getClassifiers().length == 1;

            WekaCrossValidationExecution crossValidation = new WekaCrossValidationExecution();
            this.m_CrossValidation = crossValidation;
            crossValidation.setClassifier(this.m_Classifier);
            crossValidation.setData(this.m_Data);
            crossValidation.setFolds(owner.getFolds());
            // each job gets its own copy of the generator
            crossValidation.setGenerator((CrossValidationFoldGenerator) OptionUtils.shallowCopy(owner.getGenerator()));
            crossValidation.setSeed(this.m_Run);
            crossValidation.setDiscardPredictions(false);
            // Only a single classifier/dataset combination lets the cross-validation use the
            // job runner's thread count; otherwise it runs single-threaded (presumably because
            // parallelism then comes from running several jobs at once).
            crossValidation.setNumThreads(1);
            if (singleCombination && (owner.getJobRunner() instanceof ThreadLimiter)) {
                crossValidation.setNumThreads(((ThreadLimiter) owner.getJobRunner()).getNumThreads());
            }
            crossValidation.setSeparateFolds(true);
            crossValidation.setStatusMessageHandler(owner.getStatusMessageHandler());
            crossValidation.setWaitForJobs(false);

            String msg = crossValidation.execute();
            if (msg == null) {
                SpreadSheet results = new DefaultSpreadSheet();
                // Iterate over the evaluations actually produced rather than getFolds():
                // the "folds" option may be -1, which would otherwise yield no rows at all.
                for (int fold = 0; fold < crossValidation.getEvaluations().length; fold++) {
                    addMetrics(results, this.m_Run, this.m_Classifier, this.m_Data, crossValidation.getEvaluations()[fold]);
                    addMetric(results, DefaultAnalysisPanel.KEY_FOLD, Integer.valueOf(fold));
                }
                owner.appendResults(results);
            }
            else {
                // previously the error message was silently discarded
                owner.log("Run " + this.m_Run + " [error]: " + this.m_Data.relationName() + " on " + AbstractExperiment.shortenCommandLine(this.m_Classifier) + ": " + msg);
            }

            owner.log("Run " + this.m_Run + " [end]: " + this.m_Data.relationName() + " on " + AbstractExperiment.shortenCommandLine(this.m_Classifier));
        }

        /**
         * Stops the job and, if one has been started, the underlying cross-validation.
         */
        @Override
        public void stopExecution() {
            super.stopExecution();
            if (this.m_CrossValidation != null) {
                this.m_CrossValidation.stopExecution();
            }
        }
    }

    /**
     * Returns a short description of this experiment.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Performs cross-validation on each classifier/dataset combination.";
    }

    /**
     * Adds the "folds" (default 10, lower bound -1) and "generator" options.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("folds", "folds", 10, -1, (Number) null);
        this.m_OptionManager.add("generator", "generator", new DefaultCrossValidationFoldGenerator());
    }

    /**
     * Sets the number of folds and resets the experiment.
     *
     * @param folds the number of folds
     */
    public void setFolds(int folds) {
        this.m_Folds = folds;
        reset();
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds
     */
    public int getFolds() {
        return this.m_Folds;
    }

    /**
     * Returns the tip text for the "folds" option.
     *
     * @return the tip text
     */
    public String foldsTipText() {
        return "The number of folds to perform.";
    }

    /**
     * Sets the fold generation scheme and resets the experiment.
     *
     * @param generator the scheme
     */
    public void setGenerator(CrossValidationFoldGenerator generator) {
        this.m_Generator = generator;
        reset();
    }

    /**
     * Returns the fold generation scheme.
     *
     * @return the scheme
     */
    public CrossValidationFoldGenerator getGenerator() {
        return this.m_Generator;
    }

    /**
     * Returns the tip text for the "generator" option.
     *
     * @return the tip text
     */
    public String generatorTipText() {
        return "The scheme to use for generating the folds; the actor options take precedence over the scheme's ones.";
    }

    /**
     * Checks whether the given entries cover all folds.
     * NOTE(review): with folds == -1 this never returns true; the actual fold count
     * would depend on the dataset, which is not available here - confirm intended.
     *
     * @param iArr the entries to check
     * @return true if the number of entries equals the configured number of folds
     */
    @Override
    protected boolean isComplete(int[] iArr) {
        return iArr.length == this.m_Folds;
    }

    /**
     * Creates the job for evaluating the classifier on the dataset for the given run.
     *
     * @param run        the run number
     * @param classifier the classifier to evaluate
     * @param data       the dataset to use
     * @return the job
     */
    @Override
    public CrossValidationExperimentJob evaluate(int run, Classifier classifier, Instances data) {
        return new CrossValidationExperimentJob(this, run, classifier, data);
    }
}
