/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers;

import adams.data.binning.Binnable;
import adams.data.binning.BinnableGroup;
import adams.data.binning.BinnableInstances;
import adams.data.binning.operation.Grouping;
import adams.data.binning.operation.Wrapping;
import adams.data.splitgenerator.generic.core.Subset;
import adams.data.splitgenerator.generic.crossvalidation.CrossValidationGenerator;
import adams.data.splitgenerator.generic.crossvalidation.FoldPair;
import adams.data.splitgenerator.generic.randomization.DefaultRandomization;
import adams.data.splitgenerator.generic.randomization.PassThrough;
import adams.data.splitgenerator.generic.randomization.Randomization;
import adams.data.splitgenerator.generic.stratification.Stratification;
import adams.flow.container.WekaTrainTestSetContainer;
import com.github.fracpete.javautils.struct.Struct2;
import gnu.trove.TIntCollection;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import java.util.Collection;
import java.util.List;
import java.util.NoSuchElementException;
import weka.classifiers.AbstractSplitGenerator;
import weka.classifiers.CrossValidationFoldGenerator;
import weka.classifiers.CrossValidationHelper;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.InstancesView;

public class GroupedCrossValidationFoldGeneratorUsingNumericClassValues
extends AbstractSplitGenerator
implements CrossValidationFoldGenerator {
    private static final long serialVersionUID = -6949071991599401776L;
    protected int m_NumFolds;
    protected int m_ActualNumFolds;
    protected int m_CurrentFold;
    protected String m_RelationName;
    protected boolean m_Randomize;
    protected transient CrossValidationGenerator m_Generator;
    protected transient List<Binnable<BinnableGroup<Instance>>> m_BinnableGroups;
    protected transient List<FoldPair<Binnable<BinnableGroup<Instance>>>> m_FoldPairs;

    public GroupedCrossValidationFoldGeneratorUsingNumericClassValues() {
    }

    public GroupedCrossValidationFoldGeneratorUsingNumericClassValues(Instances data, int numFolds, long seed, boolean randomize) {
        this.setData(data);
        this.setSeed(seed);
        this.setNumFolds(numFolds);
        this.setRandomize(randomize);
    }

    public String globalInfo() {
        return "Generates cross-validation fold pairs. Leave-one-out is performed when specified folds <2.\nUses the string representation of the numeric class values as grouping.";
    }

    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("num-folds", "numFolds", (Object)10);
        this.m_OptionManager.add("relation-name", "relationName", (Object)"@");
        this.m_OptionManager.add("randomize", "randomize", (Object)true);
    }

    @Override
    protected void reset() {
        super.reset();
        this.m_CurrentFold = 1;
        this.m_ActualNumFolds = -1;
        this.m_FoldPairs = null;
        this.m_BinnableGroups = null;
    }

    @Override
    public void setData(Instances value) {
        super.setData(value);
        if (this.m_Data != null) {
            if (this.m_Data.classIndex() == -1) {
                throw new IllegalArgumentException("No class attribute set!");
            }
            if (!this.m_Data.classAttribute().isNumeric()) {
                throw new IllegalArgumentException("No numeric class attribute!");
            }
        }
    }

    @Override
    public void setNumFolds(int value) {
        this.m_NumFolds = value;
        this.reset();
    }

    @Override
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    public String numFoldsTipText() {
        return "The number of folds; use <2 for leave one out (LOO).";
    }

    @Override
    public int getActualNumFolds() {
        return this.m_ActualNumFolds;
    }

    @Override
    public void setRandomize(boolean value) {
        this.m_Randomize = value;
        this.reset();
    }

    @Override
    public boolean getRandomize() {
        return this.m_Randomize;
    }

    public String randomizeTipText() {
        return "If enabled, the data is randomized first.";
    }

    @Override
    public void setRelationName(String value) {
        this.m_RelationName = value;
        this.reset();
    }

    @Override
    public String getRelationName() {
        return this.m_RelationName;
    }

    public String relationNameTipText() {
        return CrossValidationHelper.relationNameTemplateTipText();
    }

    @Override
    public void setStratify(boolean value) {
    }

    @Override
    public boolean getStratify() {
        return false;
    }

    @Override
    protected boolean canRandomize() {
        return this.m_Randomize;
    }

    @Override
    protected boolean checkNext() {
        return this.m_CurrentFold <= this.m_ActualNumFolds;
    }

    @Override
    protected void doInitializeIterator() {
        DefaultRandomization rand;
        if (this.m_Data == null) {
            throw new IllegalStateException("No data provided!");
        }
        try {
            List binnableInst = BinnableInstances.toBinnableUsingIndex(this.m_Data);
            binnableInst = Wrapping.addTmpIndex(binnableInst);
            List groupedInst = Grouping.groupAsList((List)binnableInst, (Grouping.GroupExtractor)new BinnableInstances.NumericClassGroupExtractor());
            this.m_BinnableGroups = Wrapping.wrap((Collection)groupedInst, (Wrapping.BinValueExtractor)new BinnableInstances.GroupedClassValueBinValueExtractor());
        }
        catch (Exception e) {
            throw new IllegalStateException("Failed to create binnable Instances!", e);
        }
        this.m_ActualNumFolds = this.m_NumFolds < 2 ? this.m_BinnableGroups.size() : this.m_NumFolds;
        if (this.m_BinnableGroups.size() < this.m_ActualNumFolds) {
            throw new IllegalArgumentException("Cannot have less data than (grouped) folds: required=" + this.m_ActualNumFolds + ", provided=" + this.m_BinnableGroups.size());
        }
        this.m_Generator = new CrossValidationGenerator();
        this.m_Generator.setNumFolds(this.m_NumFolds);
        if (this.canRandomize()) {
            rand = new DefaultRandomization();
            rand.setSeed(this.m_Seed);
            rand.setLoggingLevel(this.m_LoggingLevel);
            this.m_Generator.setRandomization((Randomization)rand);
        } else {
            rand = new PassThrough();
            rand.setLoggingLevel(this.m_LoggingLevel);
            this.m_Generator.setRandomization((Randomization)rand);
        }
        adams.data.splitgenerator.generic.stratification.PassThrough strat = new adams.data.splitgenerator.generic.stratification.PassThrough();
        strat.setLoggingLevel(this.m_LoggingLevel);
        this.m_Generator.setStratification((Stratification)strat);
        if (this.m_RelationName == null || this.m_RelationName.isEmpty()) {
            this.m_RelationName = "@";
        }
    }

    @Override
    protected WekaTrainTestSetContainer createNext() {
        Instances testSet;
        Instances trainSet;
        if (this.m_CurrentFold > this.m_ActualNumFolds) {
            throw new NoSuchElementException("No more folds available!");
        }
        if (this.m_FoldPairs == null) {
            this.m_FoldPairs = this.m_Generator.generate(this.m_BinnableGroups);
            this.m_OriginalIndices = new TIntArrayList();
            for (FoldPair<Binnable<BinnableGroup<Instance>>> pair : this.m_FoldPairs) {
                this.m_OriginalIndices.addAll((TIntCollection)Subset.extractIndicesAndBinnable((Subset)pair.getTest()).value1);
            }
        }
        FoldPair<Binnable<BinnableGroup<Instance>>> foldPair = this.m_FoldPairs.get(this.m_CurrentFold - 1);
        Struct2 subsetTrain = Subset.extractIndicesAndBinnable((Subset)foldPair.getTrain());
        Struct2 subsetTest = Subset.extractIndicesAndBinnable((Subset)foldPair.getTest());
        int[] trainRows = ((TIntList)subsetTrain.value1).toArray();
        int[] testRows = ((TIntList)subsetTest.value1).toArray();
        if (this.m_UseViews) {
            trainSet = new InstancesView(this.m_Data, trainRows);
            testSet = new InstancesView(this.m_Data, testRows);
        } else {
            trainSet = BinnableInstances.toInstances((List)subsetTrain.value2);
            testSet = BinnableInstances.toInstances((List)subsetTest.value2);
        }
        trainSet.setRelationName(CrossValidationHelper.createRelationName(this.m_Data.relationName(), this.m_RelationName, this.m_CurrentFold, true));
        testSet.setRelationName(CrossValidationHelper.createRelationName(this.m_Data.relationName(), this.m_RelationName, this.m_CurrentFold, false));
        WekaTrainTestSetContainer result = new WekaTrainTestSetContainer(trainSet, testSet, this.m_Seed, this.m_CurrentFold, this.m_NumFolds, trainRows, testRows);
        ++this.m_CurrentFold;
        if (this.m_CurrentFold > this.m_ActualNumFolds) {
            this.m_FoldPairs = null;
        }
        return result;
    }

    @Override
    public int[] crossValidationIndices() {
        return this.m_OriginalIndices.toArray();
    }

    @Override
    public String toString() {
        return super.toString() + ", numFolds=" + this.m_NumFolds + ", randomize=" + this.m_Randomize + ", relName=" + this.m_RelationName;
    }
}

