package adams.flow.transformer.indexedsplitsrunsgenerator;

import adams.core.MessageCollection;
import adams.core.OptionalRandomizable;
import adams.core.QuickInfoHelper;
import adams.data.binning.BinnableInstances;
import adams.data.indexedsplits.IndexedSplit;
import adams.data.indexedsplits.IndexedSplits;
import adams.data.indexedsplits.IndexedSplitsRun;
import adams.data.indexedsplits.IndexedSplitsRuns;
import adams.data.indexedsplits.SplitIndices;
import adams.data.splitgenerator.generic.crossvalidation.CrossValidationGenerator;
import adams.data.splitgenerator.generic.crossvalidation.FoldPair;
import adams.data.splitgenerator.generic.randomization.DefaultRandomization;
import adams.data.splitgenerator.generic.randomization.PassThrough;
import adams.data.splitgenerator.generic.stratification.DefaultStratification;
import java.util.List;
import weka.core.Instances;
import weka.filters.supervised.instance.RemoveOutliers;

/* loaded from: input_file:adams/flow/transformer/indexedsplitsrunsgenerator/InstancesCrossValidationFoldGenerator.class */
/**
 * Split generator that generates cross-validation folds for {@link Instances} objects.
 * <p>
 * The result is a single run (run index 0) that contains one {@link IndexedSplit} per fold.
 * Each split holds a "train" and a "test" set of original row indices. The data can
 * optionally be randomized first. Folds get stratified when stratification is enabled, the
 * class attribute is nominal and the generator is not in leave-one-out mode.
 */
public class InstancesCrossValidationFoldGenerator extends AbstractInstancesIndexedSplitsRunsGenerator implements OptionalRandomizable {
    private static final long serialVersionUID = -845552507613381226L;

    /** the number of folds; values below 2 mean leave-one-out (LOO). */
    protected int m_NumFolds;

    /** whether to randomize the data before generating the folds. */
    protected boolean m_Randomize;

    /** the seed value for the randomization. */
    protected long m_Seed;

    /** whether to stratify the folds (only applied for nominal class attributes). */
    protected boolean m_Stratify;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Split generator that generates folds for cross-validation for Instances objects.";
    }

    /**
     * Adds the options: number of folds, seed, randomize and stratify.
     */
    @Override // adams.flow.transformer.indexedsplitsrunsgenerator.AbstractInstancesIndexedSplitsRunsGenerator
    public void defineOptions() {
        super.defineOptions();
        // NOTE(review): RemoveOutliers.NUM_FOLDS looks like a string constant substituted in by the
        // decompiler and is only used as the option flag here; consider replacing it with the
        // literal flag once its value has been confirmed.
        this.m_OptionManager.add(RemoveOutliers.NUM_FOLDS, "numFolds", 10);
        this.m_OptionManager.add("seed", "seed", 1L);
        this.m_OptionManager.add("randomize", "randomize", true);
        this.m_OptionManager.add("stratify", "stratify", true);
    }

    /**
     * Sets the number of folds.
     *
     * @param value the number of folds; values below 2 mean leave-one-out
     */
    public void setNumFolds(int value) {
        this.m_NumFolds = value;
        reset();
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds; values below 2 mean leave-one-out
     */
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String numFoldsTipText() {
        return "The number of folds; use <2 for leave one out (LOO).";
    }

    /**
     * Sets whether to randomize the data before generating the folds.
     *
     * @param value true if to randomize
     */
    public void setRandomize(boolean value) {
        this.m_Randomize = value;
        reset();
    }

    /**
     * Returns whether the data gets randomized before generating the folds.
     *
     * @return true if randomized
     */
    public boolean getRandomize() {
        return this.m_Randomize;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String randomizeTipText() {
        return "If enabled, the data is randomized first.";
    }

    /**
     * Sets the seed value for the randomization.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        this.m_Seed = value;
        reset();
    }

    /**
     * Returns the seed value for the randomization.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String seedTipText() {
        return "The seed value for the random number generator.";
    }

    /**
     * Sets whether to stratify the folds (nominal class attributes only).
     *
     * @param value true if to stratify
     */
    public void setStratify(boolean value) {
        this.m_Stratify = value;
        reset();
    }

    /**
     * Returns whether the folds get stratified (nominal class attributes only).
     *
     * @return true if stratified
     */
    public boolean getStratify() {
        return this.m_Stratify;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String stratifyTipText() {
        return "If enabled, the folds get stratified in case of a nominal class attribute.";
    }

    /**
     * Returns a quick info about the object, which can be displayed in the GUI.
     * The seed is only listed when randomization is enabled.
     *
     * @return the info, may be null
     */
    @Override // adams.flow.transformer.indexedsplitsrunsgenerator.AbstractInstancesIndexedSplitsRunsGenerator
    public String getQuickInfo() {
        String str = super.getQuickInfo() + QuickInfoHelper.toString(this, "numFolds", Integer.valueOf(this.m_NumFolds), ", folds: ");
        if (this.m_Randomize) {
            str = str + QuickInfoHelper.toString(this, "seed", Long.valueOf(this.m_Seed), ", seed: ");
        }
        return str + QuickInfoHelper.toString(this, "stratify", this.m_Stratify ? "stratified" : "not stratified", ", ");
    }

    /**
     * Performs checks on the data. In addition to the superclass checks, a class attribute
     * must be set.
     *
     * @param obj the data to check (cast to {@link Instances})
     * @return null if checks passed, otherwise the error message
     */
    protected String check(Object obj) {
        String check = super.check(obj);
        if (check == null && ((Instances) obj).classIndex() == -1) {
            check = "No class attribute defined!";
        }
        return check;
    }

    /**
     * Generates the cross-validation folds.
     *
     * @param obj               the data to split (cast to {@link Instances})
     * @param messageCollection for collecting errors
     * @return the generated runs; null if there are fewer instances than folds
     * @throws IllegalStateException if converting the data or generating the folds fails
     */
    protected IndexedSplitsRuns doGenerate(Object obj, MessageCollection messageCollection) {
        Instances instances = (Instances) obj;
        // leave-one-out (< 2 folds) means one fold per instance
        int numFolds = (this.m_NumFolds < 2) ? instances.numInstances() : this.m_NumFolds;
        if (instances.numInstances() < numFolds) {
            messageCollection.add("Cannot have less data than folds: required=" + numFolds + ", provided=" + instances.numInstances());
            return null;
        }

        CrossValidationGenerator generator = new CrossValidationGenerator();
        // the raw setting is passed on; the generator is given the LOO (< 2) value as-is
        generator.setNumFolds(this.m_NumFolds);
        configureRandomization(generator);
        configureStratification(generator, instances, numFolds);

        // only the conversion/generation is guarded, so that failures while assembling
        // the splits are not mislabelled as conversion errors
        List<?> folds;
        try {
            folds = generator.generate(BinnableInstances.toBinnableUsingClass(instances));
        } catch (Exception e) {
            throw new IllegalStateException("Failed to create binnable Instances!", e);
        }

        return toIndexedSplitsRuns(folds);
    }

    /**
     * Sets either a seeded {@link DefaultRandomization} or a pass-through randomization on the
     * generator, depending on the randomize flag.
     *
     * @param generator the generator to configure
     */
    private void configureRandomization(CrossValidationGenerator generator) {
        if (this.m_Randomize) {
            DefaultRandomization randomization = new DefaultRandomization();
            randomization.setSeed(this.m_Seed);
            randomization.setLoggingLevel(this.m_LoggingLevel);
            generator.setRandomization(randomization);
        } else {
            PassThrough randomization = new PassThrough();
            randomization.setLoggingLevel(this.m_LoggingLevel);
            generator.setRandomization(randomization);
        }
    }

    /**
     * Sets the stratification on the generator. {@link DefaultStratification} is used only if
     * stratification is enabled, the class attribute is nominal and the generator is not in
     * leave-one-out mode. Otherwise a pass-through stratification is used.
     *
     * @param generator the generator to configure
     * @param instances the data to split (class attribute must be set)
     * @param numFolds  the effective number of folds
     */
    private void configureStratification(CrossValidationGenerator generator, Instances instances, int numFolds) {
        // numFolds == numInstances only in leave-one-out mode, where stratification is pointless
        if (this.m_Stratify && instances.classAttribute().isNominal() && numFolds < instances.numInstances()) {
            DefaultStratification stratification = new DefaultStratification();
            stratification.setLoggingLevel(this.m_LoggingLevel);
            generator.setStratification(stratification);
        } else {
            adams.data.splitgenerator.generic.stratification.PassThrough stratification = new adams.data.splitgenerator.generic.stratification.PassThrough();
            stratification.setLoggingLevel(this.m_LoggingLevel);
            generator.setStratification(stratification);
        }
    }

    /**
     * Turns the generated fold pairs into a single run (index 0). The run holds one split per
     * fold, with "train" and "test" original indices.
     *
     * @param folds the fold pairs as returned by the generator
     * @return the runs
     */
    private IndexedSplitsRuns toIndexedSplitsRuns(List<?> folds) {
        IndexedSplitsRuns runs = new IndexedSplitsRuns();
        IndexedSplits splits = new IndexedSplits();
        runs.add(new IndexedSplitsRun(0, splits));
        // iterate over the folds actually produced, not the expected count
        for (int i = 0; i < folds.size(); i++) {
            FoldPair fold = (FoldPair) folds.get(i);
            int[] train = fold.getTrain().getOriginalIndices().toArray();
            int[] test = fold.getTest().getOriginalIndices().toArray();
            IndexedSplit split = new IndexedSplit(i);
            split.add(new SplitIndices("train", train));
            split.add(new SplitIndices("test", test));
            splits.add(split);
        }
        return runs;
    }
}
