package adams.data.splitgenerator.generic.crossvalidation;

import adams.core.logging.CustomLoggingLevelObject;
import adams.data.binning.Binnable;
import adams.data.binning.operation.Wrapping;
import adams.data.splitgenerator.generic.core.Subset;
import adams.data.splitgenerator.generic.randomization.DefaultRandomization;
import adams.data.splitgenerator.generic.randomization.Randomization;
import adams.data.splitgenerator.generic.stratification.DefaultStratification;
import adams.data.splitgenerator.generic.stratification.Stratification;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:adams/data/splitgenerator/generic/crossvalidation/CrossValidationGenerator.class */
/**
 * Generates train/test fold pairs for cross-validation over lists of
 * {@link Binnable} objects, using the configured randomization and
 * stratification schemes.
 * <p>
 * Fold boundaries: with {@code n} elements and {@code k} folds, every fold
 * holds {@code n / k} elements, and the first {@code n % k} folds hold one
 * additional element. Folds are contiguous ranges of the (stratified) list.
 */
public class CrossValidationGenerator extends CustomLoggingLevelObject {
    private static final long serialVersionUID = 6906260013695977045L;

    /** The number of folds; values below 2 select leave-one-out in {@link #generate(List)}. */
    protected int m_NumFolds = 10;

    /** The randomization scheme applied to the full data and to each training set. */
    protected Randomization m_Randomization = new DefaultRandomization();

    /** The stratification scheme applied to the randomized data. */
    protected Stratification m_Stratification = new DefaultStratification();

    /**
     * Initializes the generator (10 folds, default randomization and
     * stratification) and resets both schemes.
     */
    public CrossValidationGenerator() {
        reset();
    }

    /**
     * Resets the randomization and stratification schemes.
     */
    public void reset() {
        m_Randomization.reset();
        m_Stratification.reset();
    }

    /**
     * Sets the number of folds and resets the generator.
     *
     * @param numFolds the number of folds; values below 2 mean leave-one-out
     */
    public void setNumFolds(int numFolds) {
        m_NumFolds = numFolds;
        reset();
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds; values below 2 mean leave-one-out
     */
    public int getNumFolds() {
        return m_NumFolds;
    }

    /**
     * Sets the randomization scheme and resets the generator.
     *
     * @param randomization the scheme to use, not null
     * @throws NullPointerException if {@code randomization} is null
     */
    public void setRandomization(Randomization randomization) {
        // validate before assigning, so a null argument cannot leave the generator unusable
        m_Randomization = Objects.requireNonNull(randomization, "randomization");
        reset();
    }

    /**
     * Returns the randomization scheme.
     *
     * @return the scheme in use
     */
    public Randomization getRandomization() {
        return m_Randomization;
    }

    /**
     * Sets the stratification scheme and resets the generator.
     *
     * @param stratification the scheme to use, not null
     * @throws NullPointerException if {@code stratification} is null
     */
    public void setStratification(Stratification stratification) {
        // validate before assigning, so a null argument cannot leave the generator unusable
        m_Stratification = Objects.requireNonNull(stratification, "stratification");
        reset();
    }

    /**
     * Returns the stratification scheme.
     *
     * @return the scheme in use
     */
    public Stratification getStratification() {
        return m_Stratification;
    }

    /**
     * Generates one train/test pair per fold.
     * <p>
     * If the number of folds is below 2, leave-one-out is performed, i.e. the
     * number of folds equals the number of elements in {@code list}. The data
     * is wrapped with temporary indices, randomized and stratified. Each
     * {@link Subset} receives the temporary indices of its elements (presumably
     * the original positions - confirm against {@code Wrapping}). The temporary
     * indices are removed again after all folds have been generated.
     *
     * @param list the data to split
     * @param <T> the payload type of the binnables
     * @return the fold pairs, ordered by fold index
     * @throws IllegalArgumentException (from {@link #trainCV(List, int, int)})
     *         if the effective number of folds is below 2 or exceeds the data size
     */
    public <T> List<FoldPair<Binnable<T>>> generate(List<Binnable<T>> list) {
        List<FoldPair<Binnable<T>>> result = new ArrayList<>();

        // leave-one-out when fewer than 2 folds were requested
        int numFolds = (m_NumFolds < 2) ? list.size() : m_NumFolds;

        List<Binnable<T>> data = m_Stratification.stratify(
            m_Randomization.randomize(Wrapping.addTmpIndex(list)), numFolds);

        for (int fold = 0; fold < numFolds; fold++) {
            List<Binnable<T>> train = trainCV(data, numFolds, fold, m_Randomization);
            Subset<Binnable<T>> trainSubset = new Subset<>(train, Wrapping.getTmpIndices(train));
            List<Binnable<T>> test = testCV(data, numFolds, fold);
            Subset<Binnable<T>> testSubset = new Subset<>(test, Wrapping.getTmpIndices(test));
            result.add(new FoldPair<>(fold, trainSubset, testSubset));
        }

        Wrapping.removeTmpIndex(data);

        return result;
    }

    /**
     * Returns the training set for the given fold: all elements of
     * {@code list} except the contiguous range that forms the test fold.
     * The relative order of the elements is kept.
     *
     * @param list the data to split
     * @param numFolds the number of folds, at least 2 and at most {@code list.size()}
     * @param numFold the 0-based fold index, in {@code [0, numFolds)}
     * @param <T> the payload type of the binnables
     * @return a new list with the training elements
     * @throws IllegalArgumentException if any of the fold arguments is invalid
     */
    public static <T> List<Binnable<T>> trainCV(List<Binnable<T>> list, int numFolds, int numFold) {
        checkFoldArguments(list, numFolds, numFold);
        int size = list.size();
        int first = foldStart(size, numFolds, numFold);
        int end = first + foldSize(size, numFolds, numFold);  // exclusive end of the test fold
        List<Binnable<T>> result = new ArrayList<>(size - (end - first));
        result.addAll(list.subList(0, first));
        result.addAll(list.subList(end, size));
        return result;
    }

    /**
     * Returns the training set for the given fold, randomized with the supplied
     * scheme.
     *
     * @param list the data to split
     * @param numFolds the number of folds, at least 2 and at most {@code list.size()}
     * @param numFold the 0-based fold index, in {@code [0, numFolds)}
     * @param randomization the scheme used for randomizing the training set
     * @param <T> the payload type of the binnables
     * @return the randomized training elements
     * @throws IllegalArgumentException if any of the fold arguments is invalid
     */
    public static <T> List<Binnable<T>> trainCV(List<Binnable<T>> list, int numFolds, int numFold, Randomization randomization) {
        List<Binnable<T>> train = trainCV(list, numFolds, numFold);
        // use the returned list (as generate() does) instead of relying on in-place shuffling
        return randomization.randomize(train);
    }

    /**
     * Returns the test set for the given fold: the contiguous range of
     * {@code list} that forms that fold.
     *
     * @param list the data to split
     * @param numFolds the number of folds, at least 2 and at most {@code list.size()}
     * @param numFold the 0-based fold index, in {@code [0, numFolds)}
     * @param <T> the payload type of the binnables
     * @return a new list with the test elements
     * @throws IllegalArgumentException if any of the fold arguments is invalid
     */
    public static <T> List<Binnable<T>> testCV(List<Binnable<T>> list, int numFolds, int numFold) {
        checkFoldArguments(list, numFolds, numFold);
        int size = list.size();
        int first = foldStart(size, numFolds, numFold);
        return new ArrayList<>(list.subList(first, first + foldSize(size, numFolds, numFold)));
    }

    /** Validates the number of folds against the data size and the fold index against the number of folds. */
    private static void checkFoldArguments(List<?> list, int numFolds, int numFold) {
        if (numFolds < 2) {
            throw new IllegalArgumentException("Number of folds must be at least 2!");
        }
        if (numFolds > list.size()) {
            throw new IllegalArgumentException("Can't have more folds than instances!");
        }
        // previously unchecked: out-of-range indices silently produced wrong training sets
        if ((numFold < 0) || (numFold >= numFolds)) {
            throw new IllegalArgumentException(
                "Fold index must be between 0 and " + (numFolds - 1) + ", provided: " + numFold);
        }
    }

    /** Returns the number of elements in the given fold; the first {@code size % numFolds} folds get one extra. */
    private static int foldSize(int size, int numFolds, int numFold) {
        return (size / numFolds) + ((numFold < size % numFolds) ? 1 : 0);
    }

    /** Returns the index of the first element of the given fold within the data. */
    private static int foldStart(int size, int numFolds, int numFold) {
        // each earlier fold that received an extra element shifts the start by one
        return numFold * (size / numFolds) + Math.min(numFold, size % numFolds);
    }
}
