/*
 * Decompiled with CFR 0.152.
 */
package adams.data.splitgenerator.generic.crossvalidation;

import adams.core.logging.CustomLoggingLevelObject;
import adams.data.binning.Binnable;
import adams.data.binning.operation.Wrapping;
import adams.data.splitgenerator.generic.core.Subset;
import adams.data.splitgenerator.generic.crossvalidation.FoldPair;
import adams.data.splitgenerator.generic.randomization.DefaultRandomization;
import adams.data.splitgenerator.generic.randomization.Randomization;
import adams.data.splitgenerator.generic.stratification.DefaultStratification;
import adams.data.splitgenerator.generic.stratification.Stratification;
import gnu.trove.list.TIntList;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Generates cross-validation fold pairs (train/test subsets) from a list of
 * {@link Binnable} items.
 * <p>
 * {@link #generate(List)} first attaches a temporary index to the items via
 * {@link Wrapping#addTmpIndex}, then randomizes the data with the configured
 * {@link Randomization} and stratifies it with the configured {@link Stratification}.
 * For every fold it records the temporary indices of the train and test items before
 * the temporary index is removed again.
 */
public class CrossValidationGenerator
extends CustomLoggingLevelObject {
    private static final long serialVersionUID = 6906260013695977045L;

    /** The number of folds; values below 2 select leave-one-out (one fold per item). */
    protected int m_NumFolds = 10;

    /** Randomizes the full dataset and each generated training set. */
    protected Randomization m_Randomization = new DefaultRandomization();

    /** Stratifies the randomized data before it is split into folds. */
    protected Stratification m_Stratification = new DefaultStratification();

    /**
     * Creates a generator with 10 folds, {@link DefaultRandomization} and
     * {@link DefaultStratification}.
     */
    public CrossValidationGenerator() {
        this.reset();
    }

    /**
     * Resets the randomization and stratification schemes.
     */
    public void reset() {
        this.m_Randomization.reset();
        this.m_Stratification.reset();
    }

    /**
     * Sets the number of folds and resets the generator.
     *
     * @param value the number of folds; values below 2 select leave-one-out
     */
    public void setNumFolds(int value) {
        this.m_NumFolds = value;
        this.reset();
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds; values below 2 mean leave-one-out
     */
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    /**
     * Sets the randomization scheme and resets the generator.
     *
     * @param value the scheme to use, not null
     * @throws NullPointerException if {@code value} is null (the current scheme is kept)
     */
    public void setRandomization(Randomization value) {
        this.m_Randomization = Objects.requireNonNull(value, "randomization must not be null");
        this.reset();
    }

    /**
     * Returns the randomization scheme.
     *
     * @return the scheme in use
     */
    public Randomization getRandomization() {
        return this.m_Randomization;
    }

    /**
     * Sets the stratification scheme and resets the generator.
     *
     * @param value the scheme to use, not null
     * @throws NullPointerException if {@code value} is null (the current scheme is kept)
     */
    public void setStratification(Stratification value) {
        this.m_Stratification = Objects.requireNonNull(value, "stratification must not be null");
        this.reset();
    }

    /**
     * Returns the stratification scheme.
     *
     * @return the scheme in use
     */
    public Stratification getStratification() {
        return this.m_Stratification;
    }

    /**
     * Generates one train/test pair per fold.
     *
     * @param data the items to split
     * @param <T> the payload type of the binnable items
     * @return the fold pairs, indexed from 0
     * @throws IllegalArgumentException if the effective number of folds is below 2 or
     *                                  exceeds the number of items
     */
    public <T> List<FoldPair<Binnable<T>>> generate(List<Binnable<T>> data) {
        // fewer than 2 folds requested -> leave-one-out
        int folds = this.m_NumFolds < 2 ? data.size() : this.m_NumFolds;
        List<FoldPair<Binnable<T>>> result = new ArrayList<FoldPair<Binnable<T>>>(Math.max(folds, 0));
        data = Wrapping.addTmpIndex(data);
        data = this.m_Randomization.randomize(data);
        data = this.m_Stratification.stratify(data, folds);
        for (int i = 0; i < folds; ++i) {
            List<Binnable<T>> trainData = trainCV(data, folds, i, this.m_Randomization);
            // collect the temporary indices now, they get removed once all folds are built
            TIntList trainIndices = Wrapping.getTmpIndices(trainData);
            Subset train = new Subset(trainData, trainIndices);
            List<Binnable<T>> testData = testCV(data, folds, i);
            TIntList testIndices = Wrapping.getTmpIndices(testData);
            Subset test = new Subset(testData, testIndices);
            result.add(new FoldPair(i, train, test));
        }
        Wrapping.removeTmpIndex(data);
        return result;
    }

    /**
     * Validates the fold parameters and locates fold {@code numFold} within a list of
     * {@code numItems} items. Folds are contiguous; the first {@code numItems % numFolds}
     * folds receive one extra item each.
     *
     * @param numItems the total number of items
     * @param numFolds the number of folds
     * @param numFold  the 0-based fold index
     * @return a two-element array: start index of the fold, number of items in the fold
     * @throws IllegalArgumentException if {@code numFolds < 2}, {@code numFolds > numItems}
     *                                  or {@code numFold} is not in {@code [0, numFolds)}
     */
    private static int[] foldRange(int numItems, int numFolds, int numFold) {
        if (numFolds < 2) {
            throw new IllegalArgumentException("Number of folds must be at least 2!");
        }
        if (numFolds > numItems) {
            throw new IllegalArgumentException("Can't have more folds than instances!");
        }
        if (numFold < 0 || numFold >= numFolds) {
            throw new IllegalArgumentException(
                "Fold index must be between 0 and " + (numFolds - 1) + ", provided: " + numFold);
        }
        int baseSize = numItems / numFolds;
        int remainder = numItems % numFolds;
        int size = numFold < remainder ? baseSize + 1 : baseSize;
        // every preceding fold that received an extra item shifts the start by one
        int first = numFold * baseSize + Math.min(numFold, remainder);
        return new int[]{first, size};
    }

    /**
     * Returns the training set for the given fold: all items except the contiguous block
     * that forms fold {@code numFold}, in their original order.
     *
     * @param data     the (already randomized/stratified) items
     * @param numFolds the number of folds, at least 2 and at most {@code data.size()}
     * @param numFold  the 0-based fold index
     * @param <T>      the payload type of the binnable items
     * @return a new list with the training items
     * @throws IllegalArgumentException if the fold parameters are invalid
     */
    public static <T> List<Binnable<T>> trainCV(List<Binnable<T>> data, int numFolds, int numFold) {
        int[] range = foldRange(data.size(), numFolds, numFold);
        int first = range[0];
        int end = first + range[1];
        List<Binnable<T>> result = new ArrayList<Binnable<T>>(data.size() - range[1]);
        result.addAll(data.subList(0, first));
        result.addAll(data.subList(end, data.size()));
        return result;
    }

    /**
     * Returns the randomized training set for the given fold.
     *
     * @param data     the (already randomized/stratified) items
     * @param numFolds the number of folds, at least 2 and at most {@code data.size()}
     * @param numFold  the 0-based fold index
     * @param random   the randomization to apply to the training set
     * @param <T>      the payload type of the binnable items
     * @return the randomized training items
     * @throws IllegalArgumentException if the fold parameters are invalid
     */
    public static <T> List<Binnable<T>> trainCV(List<Binnable<T>> data, int numFolds, int numFold, Randomization random) {
        // use the returned list, as generate() does: the randomized result is the return value,
        // so discarding it may leave the training set unshuffled
        return random.randomize(trainCV(data, numFolds, numFold));
    }

    /**
     * Returns the test set for the given fold: the contiguous block of items that forms
     * fold {@code numFold}.
     *
     * @param data     the (already randomized/stratified) items
     * @param numFolds the number of folds, at least 2 and at most {@code data.size()}
     * @param numFold  the 0-based fold index
     * @param <T>      the payload type of the binnable items
     * @return a new list with the test items
     * @throws IllegalArgumentException if the fold parameters are invalid
     */
    public static <T> List<Binnable<T>> testCV(List<Binnable<T>> data, int numFolds, int numFold) {
        int[] range = foldRange(data.size(), numFolds, numFold);
        int first = range[0];
        return new ArrayList<Binnable<T>>(data.subList(first, first + range[1]));
    }
}

