/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers;

import adams.flow.container.WekaTrainTestSetContainer;
import gnu.trove.TIntCollection;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import java.util.NoSuchElementException;
import java.util.Random;
import weka.classifiers.AbstractSplitGenerator;
import weka.classifiers.CrossValidationHelper;
import weka.core.Instances;
import weka.core.InstancesView;

/**
 * Split generator that produces the train/test pairs of a k-fold cross-validation.
 * <p>
 * Each call to {@link #createNext()} yields the fold with the 1-based number
 * {@link #m_CurrentFold} and then advances the counter, until all
 * {@link #m_NumFolds} folds have been produced. When the base class's
 * {@code m_UseViews} flag is set, the splits are {@link InstancesView}s over the
 * data; otherwise they are created via {@code m_Data.trainCV(...)} /
 * {@code m_Data.testCV(...)}. In both cases the row indices of the split, taken
 * from {@code m_OriginalIndices}, are stored in the returned container.
 * <p>
 * The relation name of the generated datasets is built from a template that may
 * contain the placeholders {@link #PLACEHOLDER_ORIGINAL},
 * {@link #PLACEHOLDER_TYPE} and {@link #PLACEHOLDER_CURRENTFOLD}.
 */
public class CrossValidationFoldGenerator
extends AbstractSplitGenerator {
    private static final long serialVersionUID = -8387205583429213079L;

    /** Template placeholder that is replaced with the original relation name. */
    public static final String PLACEHOLDER_ORIGINAL = "@";

    /** Template placeholder that is replaced with "train" or "test". */
    public static final String PLACEHOLDER_TYPE = "$T";

    /** Template placeholder that is replaced with the 1-based number of the current fold. */
    public static final String PLACEHOLDER_CURRENTFOLD = "$N";

    /** The number of folds to generate. */
    protected int m_NumFolds;

    /** Whether stratification was requested. */
    protected boolean m_Stratify;

    /** The 1-based number of the fold that the next call to createNext() produces. */
    protected int m_CurrentFold;

    /** Template for the relation name of the generated datasets. */
    protected String m_RelationName;

    /** The value reported by {@link #canRandomize()}. */
    protected boolean m_Randomize;

    /** Random generator for shuffling index lists; re-created from the seed in {@link #initialize()}. */
    protected Random m_RandomIndices;

    /**
     * Creates a generator with randomization enabled and the default relation
     * name template ({@link #PLACEHOLDER_ORIGINAL}).
     *
     * @param data      the dataset to split; must have a class attribute set
     * @param numFolds  the number of folds; values below 2 select one fold per instance
     * @param seed      the seed for the random number generators
     * @param stratify  whether to stratify the folds
     * @throws IllegalArgumentException if no class attribute is set or there are fewer instances than folds
     */
    public CrossValidationFoldGenerator(Instances data, int numFolds, long seed, boolean stratify) {
        this(data, numFolds, seed, true, stratify, null);
    }

    /**
     * Creates a generator.
     *
     * @param data      the dataset to split; must have a class attribute set
     * @param numFolds  the number of folds; values below 2 select one fold per
     *                  instance (leave-one-out)
     * @param seed      the seed for the random number generators
     * @param randomize the value returned by {@link #canRandomize()}
     * @param stratify  whether to stratify the folds
     * @param relName   the relation name template; null or empty means
     *                  {@link #PLACEHOLDER_ORIGINAL}
     * @throws IllegalArgumentException if no class attribute is set or there are fewer instances than folds
     */
    public CrossValidationFoldGenerator(Instances data, int numFolds, long seed, boolean randomize, boolean stratify, String relName) {
        super(data, seed);
        if (data.classIndex() == -1) {
            throw new IllegalArgumentException("No class attribute set!");
        }
        this.m_NumFolds = numFolds < 2 ? data.numInstances() : numFolds;
        if (data.numInstances() < this.m_NumFolds) {
            throw new IllegalArgumentException("Cannot have less data than folds: required=" + this.m_NumFolds + ", provided=" + data.numInstances());
        }
        if (relName == null || relName.length() == 0) {
            relName = PLACEHOLDER_ORIGINAL;
        }
        this.m_Randomize = randomize;
        this.m_RelationName = relName;
        this.m_CurrentFold = 1;
        this.m_Stratify = stratify;
    }

    /**
     * Returns the number of folds.
     *
     * @return the effective number of folds
     */
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    /**
     * Returns whether stratification was requested.
     *
     * @return true if stratifying
     */
    public boolean getStratify() {
        return this.m_Stratify;
    }

    /**
     * Returns the relation name template.
     *
     * @return the template, never null or empty
     */
    public String getRelationName() {
        return this.m_RelationName;
    }

    /**
     * Returns whether randomization is enabled.
     *
     * @return the randomize flag supplied to the constructor
     */
    @Override
    protected boolean canRandomize() {
        return this.m_Randomize;
    }

    /**
     * Returns whether another fold can be generated.
     *
     * @return true while the current fold number does not exceed the number of folds
     */
    @Override
    protected boolean checkNext() {
        return this.m_CurrentFold <= this.m_NumFolds;
    }

    /**
     * Expands the relation name template for the current fold.
     *
     * @param train true for the training set, false for the test set
     * @return the expanded relation name
     */
    protected String createRelationName(boolean train) {
        String name = this.m_RelationName;
        StringBuilder result = new StringBuilder();
        int pos = 0;
        while (pos < name.length()) {
            if (name.startsWith(PLACEHOLDER_ORIGINAL, pos)) {
                result.append(this.m_Data.relationName());
                pos += PLACEHOLDER_ORIGINAL.length();
            } else if (name.startsWith(PLACEHOLDER_TYPE, pos)) {
                result.append(train ? "train" : "test");
                pos += PLACEHOLDER_TYPE.length();
            } else if (name.startsWith(PLACEHOLDER_CURRENTFOLD, pos)) {
                result.append(this.m_CurrentFold);
                pos += PLACEHOLDER_CURRENTFOLD.length();
            } else {
                // plain character, copied verbatim
                result.append(name.charAt(pos));
                pos++;
            }
        }
        return result.toString();
    }

    /**
     * Computes the index order used for the folds, delegating to
     * {@link CrossValidationHelper#crossValidationIndices}.
     *
     * @return the (unshuffled) cross-validation index order
     */
    @Override
    protected TIntList originalIndices() {
        TIntArrayList result = new TIntArrayList();
        result.add(CrossValidationHelper.crossValidationIndices(this.m_Data, this.m_NumFolds, new Random(this.m_Seed), this.m_Stratify && this.m_NumFolds < this.m_Data.numInstances()));
        // The shuffled copy is discarded on purpose: the call only advances the
        // state of m_RandomIndices. NOTE(review): presumably this keeps later
        // trainCV shuffles in step with the non-view path, where the data itself
        // is randomized first - confirm against AbstractSplitGenerator.
        TIntArrayList dummy = new TIntArrayList(result);
        this.randomize(dummy, this.m_RandomIndices);
        return result;
    }

    /**
     * Re-creates the index random generator from the seed, runs the base class
     * initialization, ensures m_Random exists, and, when not using views,
     * stratifies the data for a nominal class.
     */
    @Override
    protected void initialize() {
        this.m_RandomIndices = new Random(this.m_Seed);
        super.initialize();
        if (this.m_Random == null) {
            this.m_Random = new Random(this.m_Seed);
        }
        if (!this.m_UseViews && this.m_Stratify && this.m_Data.classAttribute().isNominal() && this.m_NumFolds < this.m_Data.numInstances()) {
            this.m_Data.stratify(this.m_NumFolds);
        }
    }

    /**
     * Validates the fold parameters and locates a fold's test range. The first
     * {@code numInstances % numFolds} folds receive one extra instance.
     *
     * @param numFolds the number of folds
     * @param numFold  the 0-based fold number
     * @return two elements: the start position of the test range and its length
     * @throws IllegalArgumentException if numFolds is below 2 or exceeds the number of instances
     */
    private int[] foldBounds(int numFolds, int numFold) {
        if (numFolds < 2) {
            throw new IllegalArgumentException("Number of folds must be at least 2!");
        }
        int numInst = this.m_Data.numInstances();
        if (numFolds > numInst) {
            throw new IllegalArgumentException("Can't have more folds than instances!");
        }
        int numInstForFold = numInst / numFolds;
        int remainder = numInst % numFolds;
        int offset;
        if (numFold < remainder) {
            ++numInstForFold;
            offset = numFold;
        } else {
            offset = remainder;
        }
        int first = numFold * (numInst / numFolds) + offset;
        return new int[]{first, numInstForFold};
    }

    /**
     * Returns the original indices of the training part of a fold: every
     * position outside the fold's test range, in order.
     *
     * @param numFolds the number of folds
     * @param numFold  the 0-based fold number
     * @return a new list with the training indices
     * @throws IllegalArgumentException if numFolds is below 2 or exceeds the number of instances
     */
    protected TIntList trainCV(int numFolds, int numFold) {
        int[] bounds = this.foldBounds(numFolds, numFold);
        int testStart = bounds[0];
        int testEnd = bounds[0] + bounds[1];
        int numInst = this.m_Data.numInstances();
        // build a fresh list rather than relying on subList returning a mutable copy
        TIntList train = new TIntArrayList(numInst - bounds[1]);
        train.add(this.m_OriginalIndices.subList(0, testStart).toArray());
        train.add(this.m_OriginalIndices.subList(testEnd, numInst).toArray());
        return train;
    }

    /**
     * Returns the training indices of a fold, shuffled with the given generator.
     *
     * @param numFolds the number of folds
     * @param numFold  the 0-based fold number
     * @param random   the random generator passed to randomize
     * @return the shuffled training indices
     */
    protected TIntList trainCV(int numFolds, int numFold, Random random) {
        TIntList train = this.trainCV(numFolds, numFold);
        this.randomize(train, random);
        return train;
    }

    /**
     * Returns the original indices of the test part of a fold.
     *
     * @param numFolds the number of folds
     * @param numFold  the 0-based fold number
     * @return the test indices
     * @throws IllegalArgumentException if numFolds is below 2 or exceeds the number of instances
     */
    protected TIntList testCV(int numFolds, int numFold) {
        int[] bounds = this.foldBounds(numFolds, numFold);
        return this.m_OriginalIndices.subList(bounds[0], bounds[0] + bounds[1]);
    }

    /**
     * Generates the next fold and advances the fold counter.
     *
     * @return the container with train/test sets, seed, fold number, number of folds and row indices
     * @throws NoSuchElementException if all folds have been generated
     */
    @Override
    protected WekaTrainTestSetContainer createNext() {
        if (this.m_CurrentFold > this.m_NumFolds) {
            throw new NoSuchElementException("No more folds available!");
        }
        int foldIndex = this.m_CurrentFold - 1;
        int[] trainRows = this.trainCV(this.m_NumFolds, foldIndex, this.m_RandomIndices).toArray();
        int[] testRows = this.testCV(this.m_NumFolds, foldIndex).toArray();
        // Declared as Instances: the non-view branch receives plain Instances
        // from trainCV/testCV, which cannot be assigned to InstancesView.
        Instances train;
        Instances test;
        if (this.m_UseViews) {
            train = new InstancesView(this.m_Data, trainRows);
            test = new InstancesView(this.m_Data, testRows);
        } else {
            train = this.m_Data.trainCV(this.m_NumFolds, foldIndex, this.m_Random);
            test = this.m_Data.testCV(this.m_NumFolds, foldIndex);
        }
        train.setRelationName(this.createRelationName(true));
        test.setRelationName(this.createRelationName(false));
        WekaTrainTestSetContainer result = new WekaTrainTestSetContainer(train, test, this.m_Seed, this.m_CurrentFold, this.m_NumFolds, trainRows, testRows);
        ++this.m_CurrentFold;
        return result;
    }

    /**
     * Returns a short description of the generator's setup.
     *
     * @return the description
     */
    @Override
    public String toString() {
        return super.toString() + ", numFolds=" + this.m_NumFolds + ", relName=" + this.m_RelationName;
    }
}

