/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers;

import adams.core.base.BaseRegExp;
import adams.data.binning.BinnableInstances;
import adams.data.binning.operation.Grouping;
import adams.data.binning.operation.Wrapping;
import adams.data.splitgenerator.generic.core.Subset;
import adams.data.splitgenerator.generic.randomization.DefaultRandomization;
import adams.data.splitgenerator.generic.randomization.PassThrough;
import adams.data.splitgenerator.generic.randomization.Randomization;
import adams.data.splitgenerator.generic.randomsplit.SplitPair;
import adams.data.splitgenerator.generic.splitter.DefaultSplitter;
import adams.data.splitgenerator.generic.splitter.Splitter;
import adams.data.weka.WekaAttributeIndex;
import adams.flow.container.WekaTrainTestSetContainer;
import com.github.fracpete.javautils.struct.Struct2;
import gnu.trove.list.TIntList;
import java.util.Collection;
import java.util.List;
import weka.classifiers.AbstractSplitGenerator;
import weka.classifiers.RandomSplitGenerator;
import weka.core.Instances;
import weka.core.InstancesView;

/**
 * Generates a single percentage-based train/test split in which instances
 * belonging to the same group always end up in the same subset.
 *
 * <p>The group of an instance is derived from the attribute selected by
 * {@link #getIndex()}, using the regular expression {@link #getRegExp()} and
 * the replacement string {@link #getGroup()}. The split is randomized with
 * the configured seed unless {@link #getPreserveOrder()} is enabled.
 *
 * <p>The iterator yields exactly one container per initialization.
 */
public class GroupedRandomSplitGenerator
        extends AbstractSplitGenerator
        implements RandomSplitGenerator {

    private static final long serialVersionUID = -4813006743965500489L;

    /** The percentage of the data to use for training (0-1). */
    protected double m_Percentage;

    /** Whether the order of the data is preserved (i.e. no randomization). */
    protected boolean m_PreserveOrder;

    /** Whether the single split has already been handed out. */
    protected boolean m_Generated;

    /** The attribute from which the group is extracted. */
    protected WekaAttributeIndex m_Index;

    /** The regular expression applied to the attribute value. */
    protected BaseRegExp m_RegExp;

    /** The replacement string that yields the group (eg '$2'). */
    protected String m_Group;

    /** The generic generator that splits the wrapped groups. */
    protected adams.data.splitgenerator.generic.randomsplit.RandomSplitGenerator m_Generator;

    /**
     * Default constructor; all options keep the state set up by the superclass.
     */
    public GroupedRandomSplitGenerator() {
    }

    /**
     * Convenience constructor that configures every option in one go.
     *
     * @param data          the dataset to split
     * @param seed          the seed used when randomizing
     * @param percentage    the percentage to use for training (0-1)
     * @param preserveOrder whether to keep the order of the data
     * @param index         the attribute to extract the group from
     * @param regExp        the regular expression identifying the group
     * @param group         the replacement string to use as group
     */
    public GroupedRandomSplitGenerator(Instances data, long seed, double percentage, boolean preserveOrder, WekaAttributeIndex index, BaseRegExp regExp, String group) {
        this.setData(data);
        this.setSeed(seed);
        this.setPercentage(percentage);
        this.setPreserveOrder(preserveOrder);
        this.setIndex(index);
        this.setRegExp(regExp);
        this.setGroup(group);
    }

    /**
     * Returns a short description of this generator.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Performs a percentage split, either randomized or with the order preserved.\nEnsures that groups of instances stay together, determined via a regular expression (eg '^(.*)-([0-9]+)-(.*)$') and a group replacement string (eg '$2').";
    }

    /**
     * Registers the options of this generator on top of the inherited ones.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("percentage", "percentage", 0.66, 0.0, 1.0);
        this.m_OptionManager.add("preserve-order", "preserveOrder", false);
        this.m_OptionManager.add("index", "index", new WekaAttributeIndex("first"));
        this.m_OptionManager.add("regexp", "regExp", new BaseRegExp(".*"));
        this.m_OptionManager.add("group", "group", "$0");
    }

    /**
     * Sets the training percentage. Values rejected by the option manager's
     * validity check are ignored and leave the current state untouched.
     *
     * @param value the percentage (0-1)
     */
    @Override
    public void setPercentage(double value) {
        if (this.getOptionManager().isValid("percentage", (Number) value)) {
            this.m_Percentage = value;
            this.reset();
        }
    }

    /**
     * Returns the training percentage.
     *
     * @return the percentage (0-1)
     */
    @Override
    public double getPercentage() {
        return this.m_Percentage;
    }

    /**
     * Returns the help text for the percentage option.
     *
     * @return the tip text
     */
    public String percentageTipText() {
        return "The percentage to use for training (0-1).";
    }

    /**
     * Sets whether the order of the data is preserved.
     *
     * @param value true to disable randomization
     */
    @Override
    public void setPreserveOrder(boolean value) {
        this.m_PreserveOrder = value;
        this.reset();
    }

    /**
     * Returns whether the order of the data is preserved.
     *
     * @return true if randomization is disabled
     */
    @Override
    public boolean getPreserveOrder() {
        return this.m_PreserveOrder;
    }

    /**
     * Returns the help text for the preserve-order option.
     *
     * @return the tip text
     */
    public String preserveOrderTipText() {
        return "If enabled, the order in the data is preserved in the split.";
    }

    /**
     * Sets the attribute from which the group is extracted.
     *
     * @param value the attribute index
     */
    public void setIndex(WekaAttributeIndex value) {
        this.m_Index = value;
        this.reset();
    }

    /**
     * Returns the attribute from which the group is extracted.
     *
     * @return the attribute index
     */
    public WekaAttributeIndex getIndex() {
        return this.m_Index;
    }

    /**
     * Returns the help text for the index option.
     *
     * @return the tip text
     */
    public String indexTipText() {
        // Previously duplicated the percentage tip text; the index selects the
        // attribute handed to the group extractor in createNext().
        return "The index of the attribute to use for determining the group.";
    }

    /**
     * Sets the regular expression used for identifying the group.
     *
     * @param value the regular expression
     */
    public void setRegExp(BaseRegExp value) {
        this.m_RegExp = value;
        this.reset();
    }

    /**
     * Returns the regular expression used for identifying the group.
     *
     * @return the regular expression
     */
    public BaseRegExp getRegExp() {
        return this.m_RegExp;
    }

    /**
     * Returns the help text for the regexp option.
     *
     * @return the tip text
     */
    public String regExpTipText() {
        return "The regular expression for identifying the group (eg '^(.*)-([0-9]+)-(.*)$').";
    }

    /**
     * Sets the replacement string that yields the group.
     *
     * @param value the replacement string
     */
    public void setGroup(String value) {
        this.m_Group = value;
        this.reset();
    }

    /**
     * Returns the replacement string that yields the group.
     *
     * @return the replacement string
     */
    public String getGroup() {
        return this.m_Group;
    }

    /**
     * Returns the help text for the group option.
     *
     * @return the tip text
     */
    public String groupTipText() {
        return "The replacement string to use as group (eg '$2').";
    }

    /**
     * Randomization is only possible when the order need not be preserved.
     *
     * @return true if the data may be randomized
     */
    @Override
    protected boolean canRandomize() {
        return !this.m_PreserveOrder;
    }

    /**
     * Sets up the underlying generic split generator and marks the split as
     * not yet generated.
     *
     * @throws IllegalStateException if no data has been set
     */
    @Override
    protected void doInitializeIterator() {
        if (this.m_Data == null) {
            throw new IllegalStateException("No data available!");
        }
        this.m_Generator = new adams.data.splitgenerator.generic.randomsplit.RandomSplitGenerator();
        // Each branch uses its own correctly typed local; the decompiled code
        // shared one DefaultRandomization variable for both implementations.
        if (this.canRandomize()) {
            DefaultRandomization seeded = new DefaultRandomization();
            seeded.setSeed(this.m_Seed);
            this.m_Generator.setRandomization(seeded);
        } else {
            PassThrough unchanged = new PassThrough();
            this.m_Generator.setRandomization(unchanged);
        }
        DefaultSplitter splitter = new DefaultSplitter();
        splitter.setPercentage(this.m_Percentage);
        this.m_Generator.setSplitter(splitter);
        this.m_Generated = false;
    }

    /**
     * Only a single split is produced per initialization.
     *
     * @return true if the split has not been generated yet
     */
    @Override
    protected boolean checkNext() {
        return !this.m_Generated;
    }

    /**
     * Groups the instances, splits the groups and assembles the train/test
     * container from the rows of the groups in each subset.
     *
     * @return the container with train set, test set, seed and row indices
     * @throws IllegalStateException if wrapping or grouping the instances fails
     */
    @Override
    protected WekaTrainTestSetContainer createNext() {
        this.m_Generated = true;

        // Raw collection types are kept as-is: the element types are declared
        // in the binning library and are not visible from this file.
        List binnableGroups;
        try {
            this.m_Index.setData(this.m_Data);
            List binnableInst = BinnableInstances.toBinnableUsingIndex(this.m_Data);
            // The temporary index is what IndexedBinValueExtractor below reads
            // when the groups get wrapped (presumably the original row position
            // -- TODO confirm against the Wrapping API).
            binnableInst = Wrapping.addTmpIndex(binnableInst);
            List groupedInst = Grouping.groupAsList(binnableInst, (Grouping.GroupExtractor) new BinnableInstances.StringAttributeGroupExtractor(this.m_Index.getIntIndex(), this.m_RegExp.getValue(), this.m_Group));
            binnableGroups = Wrapping.wrap((Collection) groupedInst, (Wrapping.BinValueExtractor) new Wrapping.IndexedBinValueExtractor());
        }
        catch (Exception e) {
            throw new IllegalStateException("Failed to create binnable Instances!", e);
        }

        // Split whole groups rather than individual instances.
        SplitPair splitGroups = this.m_Generator.generate(binnableGroups);
        Struct2 subsetTrain = Subset.extractIndicesAndBinnable((Subset) splitGroups.getTrain());
        Struct2 subsetTest = Subset.extractIndicesAndBinnable((Subset) splitGroups.getTest());
        int[] trainRows = ((TIntList) subsetTrain.value1).toArray();
        int[] testRows = ((TIntList) subsetTest.value1).toArray();

        Instances trainSet;
        Instances testSet;
        if (this.m_UseViews) {
            // Views reference rows of the original data via the row indices.
            trainSet = new InstancesView(this.m_Data, trainRows);
            testSet = new InstancesView(this.m_Data, testRows);
        } else {
            trainSet = BinnableInstances.toInstances((List) subsetTrain.value2);
            testSet = BinnableInstances.toInstances((List) subsetTest.value2);
        }
        return new WekaTrainTestSetContainer(trainSet, testSet, this.m_Seed, null, null, trainRows, testRows);
    }

    /**
     * Returns the superclass description extended by the grouping options.
     *
     * @return the textual representation
     */
    @Override
    public String toString() {
        return super.toString() + ", index=" + this.m_Index + ", regexp=" + this.m_RegExp + ", group=" + this.m_Group;
    }
}

