package adams.ml.evaluation;

import adams.data.binning.Binnable;
import adams.data.binning.BinnableDataset;
import adams.data.splitgenerator.CrossValidationFoldGenerator;
import adams.data.splitgenerator.generic.crossvalidation.CrossValidationGenerator;
import adams.data.splitgenerator.generic.crossvalidation.FoldPair;
import adams.data.splitgenerator.generic.randomization.DefaultRandomization;
import adams.data.splitgenerator.generic.randomization.PassThrough;
import adams.data.splitgenerator.generic.stratification.DefaultStratification;
import adams.data.spreadsheet.DataRow;
import adams.flow.container.TrainTestSetContainer;
import adams.ml.data.Dataset;
import adams.ml.data.DatasetView;
import gnu.trove.list.array.TIntArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/* loaded from: input_file:adams/ml/evaluation/DefaultCrossValidationFoldGenerator.class */
/**
 * Generates cross-validation train/test fold pairs for a {@link Dataset}.
 * <p>
 * When the number of folds is less than 2, leave-one-out (LOO) is performed, i.e., one fold per row.
 * The data can optionally be randomized (using this generator's seed) and stratified. Stratification
 * is only applied for a non-numeric class attribute and when not every row ends up in its own fold.
 * <p>
 * The data must have a class attribute set, since the folds are generated from
 * {@link BinnableDataset#toBinnableUsingClass} using the first class attribute.
 */
public class DefaultCrossValidationFoldGenerator extends AbstractSplitGenerator implements CrossValidationFoldGenerator<Dataset, TrainTestSetContainer> {
    private static final long serialVersionUID = -8387205583429213079L;

    /** the requested number of folds (&lt; 2 means leave-one-out). */
    protected int m_NumFolds;

    /** the number of folds actually used for the current data; -1 until the iterator is initialized. */
    protected int m_ActualNumFolds;

    /** whether to stratify the folds. */
    protected boolean m_Stratify;

    /** the 1-based index of the next fold to return. */
    protected transient int m_CurrentFold;

    /** whether to randomize the data before splitting. */
    protected boolean m_Randomize;

    /** the generator that produces the fold pairs. */
    protected transient CrossValidationGenerator m_Generator;

    /** the lazily generated fold pairs; null until the first fold is requested. */
    protected transient List<FoldPair<Binnable<DataRow>>> m_FoldPairs;

    /**
     * Creates a generator. The data has to be supplied via {@link #setData(Dataset)}.
     */
    public DefaultCrossValidationFoldGenerator() {
    }

    /**
     * Creates a generator that always randomizes the data.
     *
     * @param data     the data to split
     * @param numFolds the number of folds, &lt; 2 for leave-one-out
     * @param seed     the seed for the randomization
     * @param stratify whether to stratify the folds
     */
    public DefaultCrossValidationFoldGenerator(Dataset data, int numFolds, long seed, boolean stratify) {
        this(data, numFolds, seed, true, stratify);
    }

    /**
     * Creates a generator.
     * <p>
     * NOTE(review): {@link #setData(Dataset)} is invoked before {@code stratify} is applied, so its
     * class-attribute check uses the stratify setting in effect at that point.
     *
     * @param data      the data to split
     * @param numFolds  the number of folds, &lt; 2 for leave-one-out
     * @param seed      the seed for the randomization
     * @param randomize whether to randomize the data first
     * @param stratify  whether to stratify the folds
     */
    public DefaultCrossValidationFoldGenerator(Dataset data, int numFolds, long seed, boolean randomize, boolean stratify) {
        setData(data);
        setSeed(seed);
        setNumFolds(numFolds);
        setStratify(stratify);
        setRandomize(randomize);
    }

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Generates cross-validation fold pairs. Leave-one-out is performed when specified folds <2.";
    }

    /**
     * Adds the options for number of folds, randomization and stratification.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("num-folds", "numFolds", 10);
        this.m_OptionManager.add("randomize", "randomize", true);
        this.m_OptionManager.add("stratify", "stratify", true);
    }

    /**
     * Resets the iteration state: next fold is the first one, fold count and fold pairs are discarded.
     */
    @Override
    public void reset() {
        super.reset();
        this.m_CurrentFold = 1;
        this.m_ActualNumFolds = -1;
        this.m_FoldPairs = null;
    }

    /**
     * Sets the data to generate the folds from.
     *
     * @param dataset the data, may be null
     * @throws IllegalArgumentException if stratification is enabled but no class attribute is set
     */
    @Override
    public void setData(Dataset dataset) {
        super.setData(dataset);
        if (this.m_Data != null && getStratify() && this.m_Data.getClassAttributeIndices().length == 0) {
            throw new IllegalArgumentException("No class attribute set!");
        }
    }

    /**
     * Sets the number of folds; also resets the iteration state.
     *
     * @param numFolds the number of folds, &lt; 2 for leave-one-out
     */
    @Override
    public void setNumFolds(int numFolds) {
        this.m_NumFolds = numFolds;
        reset();
    }

    /**
     * Returns the requested number of folds.
     *
     * @return the number of folds, &lt; 2 for leave-one-out
     */
    @Override
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String numFoldsTipText() {
        return "The number of folds; use <2 for leave one out (LOO).";
    }

    /**
     * Returns the number of folds actually used (e.g., the row count for leave-one-out).
     *
     * @return the actual number of folds, -1 if the iterator has not been initialized yet
     */
    @Override
    public int getActualNumFolds() {
        return this.m_ActualNumFolds;
    }

    /**
     * Sets whether to randomize the data; also resets the iteration state.
     *
     * @param randomize true to randomize
     */
    @Override
    public void setRandomize(boolean randomize) {
        this.m_Randomize = randomize;
        reset();
    }

    /**
     * Returns whether the data gets randomized.
     *
     * @return true if randomized
     */
    @Override
    public boolean getRandomize() {
        return this.m_Randomize;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String randomizeTipText() {
        return "If enabled, the data is randomized first.";
    }

    /**
     * Sets whether to stratify the folds; also resets the iteration state.
     *
     * @param stratify true to stratify
     */
    public void setStratify(boolean stratify) {
        this.m_Stratify = stratify;
        reset();
    }

    /**
     * Returns whether the folds get stratified.
     *
     * @return true if stratified
     */
    public boolean getStratify() {
        return this.m_Stratify;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String stratifyTipText() {
        return "If enabled, the folds get stratified in case of a nominal class attribute.";
    }

    /**
     * Returns whether randomization is enabled.
     *
     * @return true if the data gets randomized
     */
    @Override
    protected boolean canRandomize() {
        return this.m_Randomize;
    }

    /**
     * Checks whether another fold is available.
     *
     * @return true if more folds can be generated
     */
    @Override
    protected boolean checkNext() {
        return this.m_CurrentFold <= this.m_ActualNumFolds;
    }

    /**
     * Determines the actual number of folds and configures the fold generator
     * (randomization and stratification).
     *
     * @throws IllegalStateException    if no data was provided or no class attribute is set
     * @throws IllegalArgumentException if there are fewer rows than folds
     */
    @Override
    protected void doInitializeIterator() {
        if (this.m_Data == null) {
            throw new IllegalStateException("No data provided!");
        }
        int numRows = this.m_Data.getRowCount();
        // fewer than 2 folds means leave-one-out: one fold per row
        this.m_ActualNumFolds = (this.m_NumFolds < 2) ? numRows : this.m_NumFolds;
        if (numRows < this.m_ActualNumFolds) {
            throw new IllegalArgumentException("Cannot have less data than folds: required=" + this.m_ActualNumFolds + ", provided=" + numRows);
        }
        int classIndex = determineClassIndex();

        this.m_Generator = new CrossValidationGenerator();
        // use the resolved fold count, so the generator yields exactly the folds createNext() iterates over
        this.m_Generator.setNumFolds(this.m_ActualNumFolds);

        if (canRandomize()) {
            DefaultRandomization randomization = new DefaultRandomization();
            randomization.setSeed(this.m_Seed);
            randomization.setLoggingLevel(this.m_LoggingLevel);
            this.m_Generator.setRandomization(randomization);
        } else {
            PassThrough randomization = new PassThrough();
            randomization.setLoggingLevel(this.m_LoggingLevel);
            this.m_Generator.setRandomization(randomization);
        }

        boolean nominalClass = !this.m_Data.isNumeric(classIndex);
        // stratification is pointless when every row is its own fold
        if (this.m_Stratify && nominalClass && this.m_ActualNumFolds < numRows) {
            DefaultStratification stratification = new DefaultStratification();
            stratification.setLoggingLevel(this.m_LoggingLevel);
            this.m_Generator.setStratification(stratification);
        } else {
            adams.data.splitgenerator.generic.stratification.PassThrough stratification = new adams.data.splitgenerator.generic.stratification.PassThrough();
            stratification.setLoggingLevel(this.m_LoggingLevel);
            this.m_Generator.setStratification(stratification);
        }
    }

    /**
     * Returns the next train/test fold. All fold pairs are generated on the first call;
     * they are discarded again once the last fold has been returned.
     *
     * @return the container with train/test sets, seed, fold number, fold count and original row indices
     * @throws NoSuchElementException if no more folds are available
     * @throws IllegalStateException  if no class attribute is set or generating the folds fails
     */
    @Override
    protected TrainTestSetContainer createNext() {
        if (this.m_CurrentFold > this.m_ActualNumFolds) {
            throw new NoSuchElementException("No more folds available!");
        }
        if (this.m_FoldPairs == null) {
            generateFoldPairs();
        }

        FoldPair<Binnable<DataRow>> pair = this.m_FoldPairs.get(this.m_CurrentFold - 1);
        int[] trainRows = pair.getTrain().getOriginalIndices().toArray();
        int[] testRows = pair.getTest().getOriginalIndices().toArray();
        Dataset train;
        Dataset test;
        if (this.m_UseViews) {
            train = new DatasetView(this.m_Data, trainRows, null);
            test = new DatasetView(this.m_Data, testRows, null);
        } else {
            train = BinnableDataset.toDataset(pair.getTrain().getData());
            test = BinnableDataset.toDataset(pair.getTest().getData());
        }
        TrainTestSetContainer result = new TrainTestSetContainer(train, test, Long.valueOf(this.m_Seed), Integer.valueOf(this.m_CurrentFold), Integer.valueOf(this.m_ActualNumFolds), trainRows, testRows);

        this.m_CurrentFold++;
        // release the fold pairs once the last fold has been handed out
        if (this.m_CurrentFold > this.m_ActualNumFolds) {
            this.m_FoldPairs = null;
        }
        return result;
    }

    /**
     * Generates all fold pairs and collects the original row indices of the test sets in fold order.
     * The fields are only updated once generation succeeded, so a failure leaves no partial state behind.
     *
     * @throws IllegalStateException if no class attribute is set or the generation fails
     */
    private void generateFoldPairs() {
        int classIndex = determineClassIndex();
        List<FoldPair<Binnable<DataRow>>> pairs;
        TIntArrayList indices = new TIntArrayList();
        try {
            pairs = this.m_Generator.generate(BinnableDataset.toBinnableUsingClass(this.m_Data, classIndex));
            for (FoldPair<Binnable<DataRow>> pair : pairs) {
                indices.addAll(pair.getTest().getOriginalIndices());
            }
        } catch (Exception e) {
            throw new IllegalStateException("Failed to create binnable Dataset!", e);
        }
        this.m_FoldPairs = pairs;
        this.m_OriginalIndices = indices;
    }

    /**
     * Returns the index of the first class attribute of the current data.
     *
     * @return the class attribute index
     * @throws IllegalStateException if no class attribute is set
     */
    private int determineClassIndex() {
        int[] classIndices = this.m_Data.getClassAttributeIndices();
        if (classIndices.length == 0) {
            throw new IllegalStateException("No class attribute set!");
        }
        return classIndices[0];
    }

    /**
     * Returns the original row indices of all test sets, concatenated in fold order.
     *
     * @return the indices
     * @throws IllegalStateException if no folds have been generated yet
     */
    @Override
    public int[] crossValidationIndices() {
        if (this.m_OriginalIndices == null) {
            throw new IllegalStateException("No folds generated yet!");
        }
        return this.m_OriginalIndices.toArray();
    }

    /**
     * Returns a short description of the generator's setup.
     *
     * @return the description
     */
    @Override
    public String toString() {
        return super.toString() + ", numFolds=" + this.m_NumFolds + ", stratify=" + this.m_Stratify;
    }
}
