/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer.indexedsplitsrunsgenerator;

import adams.core.MessageCollection;
import adams.core.OptionalRandomizable;
import adams.core.QuickInfoHelper;
import adams.core.option.OptionHandler;
import adams.data.binning.Binnable;
import adams.data.binning.BinnableInstances;
import adams.data.indexedsplits.IndexedSplit;
import adams.data.indexedsplits.IndexedSplits;
import adams.data.indexedsplits.IndexedSplitsRun;
import adams.data.indexedsplits.IndexedSplitsRuns;
import adams.data.indexedsplits.SplitIndices;
import adams.data.splitgenerator.generic.crossvalidation.CrossValidationGenerator;
import adams.data.splitgenerator.generic.crossvalidation.FoldPair;
import adams.data.splitgenerator.generic.randomization.DefaultRandomization;
import adams.data.splitgenerator.generic.randomization.PassThrough;
import adams.data.splitgenerator.generic.randomization.Randomization;
import adams.data.splitgenerator.generic.stratification.DefaultStratification;
import adams.data.splitgenerator.generic.stratification.Stratification;
import adams.flow.transformer.indexedsplitsrunsgenerator.AbstractInstancesIndexedSplitsRunsGenerator;
import java.util.List;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Split generator that generates folds for cross-validation for Instances objects.
 * <p>
 * Produces a single run (index 0) whose splits each hold a "train" and a "test"
 * set of row indices into the original dataset. The data can optionally be
 * randomized first, and the folds can be stratified when the class attribute is
 * nominal.
 */
public class InstancesCrossValidationFoldGenerator
extends AbstractInstancesIndexedSplitsRunsGenerator
implements OptionalRandomizable {
    private static final long serialVersionUID = -845552507613381226L;

    /** the number of folds; values below 2 request leave-one-out (LOO). */
    protected int m_NumFolds;

    /** whether to randomize the data before generating the folds. */
    protected boolean m_Randomize;

    /** the seed for the random number generator used when randomizing. */
    protected long m_Seed;

    /** whether to stratify the folds (only applied for a nominal class attribute). */
    protected boolean m_Stratify;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    @Override
    public String globalInfo() {
        return "Split generator that generates folds for cross-validation for Instances objects.";
    }

    /**
     * Adds the options: num-folds (default 10), seed (default 1),
     * randomize (default true) and stratify (default true).
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("num-folds", "numFolds", 10);
        this.m_OptionManager.add("seed", "seed", 1L);
        this.m_OptionManager.add("randomize", "randomize", true);
        this.m_OptionManager.add("stratify", "stratify", true);
    }

    /**
     * Sets the number of folds.
     *
     * @param value the number of folds; values below 2 mean leave-one-out
     */
    public void setNumFolds(int value) {
        this.m_NumFolds = value;
        this.reset();
    }

    /**
     * Returns the number of folds.
     *
     * @return the number of folds; values below 2 mean leave-one-out
     */
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String numFoldsTipText() {
        return "The number of folds; use <2 for leave one out (LOO).";
    }

    /**
     * Sets whether to randomize the data before generating the folds.
     *
     * @param value true to randomize
     */
    public void setRandomize(boolean value) {
        this.m_Randomize = value;
        this.reset();
    }

    /**
     * Returns whether the data gets randomized before generating the folds.
     *
     * @return true if randomized
     */
    public boolean getRandomize() {
        return this.m_Randomize;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String randomizeTipText() {
        return "If enabled, the data is randomized first.";
    }

    /**
     * Sets the seed value for the randomization.
     *
     * @param value the seed
     */
    public void setSeed(long value) {
        this.m_Seed = value;
        this.reset();
    }

    /**
     * Returns the seed value for the randomization.
     *
     * @return the seed
     */
    public long getSeed() {
        return this.m_Seed;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String seedTipText() {
        return "The seed value for the random number generator.";
    }

    /**
     * Sets whether to stratify the folds (nominal class attributes only).
     *
     * @param value true to stratify
     */
    public void setStratify(boolean value) {
        this.m_Stratify = value;
        this.reset();
    }

    /**
     * Returns whether the folds get stratified (nominal class attributes only).
     *
     * @return true if stratified
     */
    public boolean getStratify() {
        return this.m_Stratify;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String stratifyTipText() {
        return "If enabled, the folds get stratified in case of a nominal class attribute.";
    }

    /**
     * Returns a short summary of the current setup: number of folds, the seed
     * (only when randomizing) and whether stratification is enabled.
     *
     * @return the quick info
     */
    @Override
    public String getQuickInfo() {
        String result = super.getQuickInfo();
        result += QuickInfoHelper.toString(this, "numFolds", this.m_NumFolds, ", folds: ");
        if (this.m_Randomize) {
            result += QuickInfoHelper.toString(this, "seed", this.m_Seed, ", seed: ");
        }
        result += QuickInfoHelper.toString(this, "stratify", this.m_Stratify ? "stratified" : "not stratified", ", ");
        return result;
    }

    /**
     * Checks the data on top of the superclass checks: a class attribute must be set.
     * NOTE(review): the cast below assumes super.check rejects non-Instances data.
     *
     * @param data the data to check
     * @return null if the data is fine, otherwise the error message
     */
    @Override
    protected String check(Object data) {
        String result = super.check(data);
        if (result != null) {
            return result;
        }
        Instances inst = (Instances) data;
        if (inst.classIndex() == -1) {
            return "No class attribute defined!";
        }
        return null;
    }

    /**
     * Generates the cross-validation folds for the given Instances.
     * <p>
     * Produces a single run (index 0). Each split in that run is one fold, holding
     * "train" and "test" index sets that refer to rows of the original data.
     *
     * @param data the Instances to split
     * @param errors for collecting errors
     * @return the generated runs, or null if the data has fewer rows than folds
     * @throws IllegalStateException if the data cannot be turned into binnable form
     */
    protected IndexedSplitsRuns doGenerate(Object data, MessageCollection errors) {
        Instances instances = (Instances) data;
        int numInstances = instances.numInstances();
        // values below 2 request leave-one-out, i.e., one fold per instance
        int actualNumFolds = this.m_NumFolds < 2 ? numInstances : this.m_NumFolds;
        if (numInstances < actualNumFolds) {
            errors.add("Cannot have less data than folds: required=" + actualNumFolds + ", provided=" + numInstances);
            return null;
        }

        CrossValidationGenerator generator = this.createGenerator(instances, actualNumFolds);

        List<Binnable<Instance>> binnableInst;
        try {
            binnableInst = BinnableInstances.toBinnableUsingClass(instances);
        }
        catch (Exception e) {
            throw new IllegalStateException("Failed to create binnable Instances!", e);
        }

        // iterate over what the generator actually produced, so the loop can never
        // index past the end of the list
        List<?> foldPairs = generator.generate(binnableInst);
        IndexedSplits indexedSplits = new IndexedSplits();
        for (int fold = 0; fold < foldPairs.size(); ++fold) {
            FoldPair foldPair = (FoldPair) foldPairs.get(fold);
            int[] trainRows = foldPair.getTrain().getOriginalIndices().toArray();
            int[] testRows = foldPair.getTest().getOriginalIndices().toArray();
            IndexedSplit indexedSplit = new IndexedSplit(fold);
            indexedSplit.add(new SplitIndices("train", trainRows));
            indexedSplit.add(new SplitIndices("test", testRows));
            indexedSplits.add(indexedSplit);
        }

        IndexedSplitsRuns result = new IndexedSplitsRuns();
        result.add(new IndexedSplitsRun(0, indexedSplits));
        return result;
    }

    /**
     * Creates a cross-validation generator configured with the resolved fold count,
     * the randomization and stratification settings, and this object's logging level.
     *
     * @param instances the data the folds will be generated for
     * @param numFolds the resolved number of folds (already mapped from LOO)
     * @return the configured generator
     */
    private CrossValidationGenerator createGenerator(Instances instances, int numFolds) {
        CrossValidationGenerator generator = new CrossValidationGenerator();
        // pass the resolved count so the generator and the result assembly agree,
        // regardless of how the generator itself interprets values below 2
        generator.setNumFolds(numFolds);

        if (this.m_Randomize) {
            DefaultRandomization rand = new DefaultRandomization();
            rand.setSeed(this.m_Seed);
            rand.setLoggingLevel(this.m_LoggingLevel);
            generator.setRandomization(rand);
        }
        else {
            PassThrough rand = new PassThrough();
            rand.setLoggingLevel(this.m_LoggingLevel);
            generator.setRandomization(rand);
        }

        // stratify only for a nominal class, and not when every fold would hold a
        // single instance (e.g., leave-one-out)
        boolean stratify = this.m_Stratify
            && instances.classAttribute().isNominal()
            && numFolds < instances.numInstances();
        if (stratify) {
            DefaultStratification strat = new DefaultStratification();
            strat.setLoggingLevel(this.m_LoggingLevel);
            generator.setStratification(strat);
        }
        else {
            adams.data.splitgenerator.generic.stratification.PassThrough strat =
                new adams.data.splitgenerator.generic.stratification.PassThrough();
            strat.setLoggingLevel(this.m_LoggingLevel);
            generator.setStratification(strat);
        }

        return generator;
    }
}

