/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers;

import adams.core.StoppableWithFeedback;
import adams.core.Utils;
import adams.core.base.BaseRegExp;
import adams.core.base.BaseString;
import adams.data.binning.BinnableGroup;
import adams.data.binning.BinnableInstances;
import adams.data.binning.operation.Grouping;
import adams.data.binning.operation.Wrapping;
import adams.data.weka.WekaAttributeIndex;
import adams.flow.container.WekaTrainTestSetContainer;
import com.github.fracpete.javautils.struct.Struct2;
import java.util.ArrayList;
import java.util.List;
import weka.classifiers.AbstractSplitGenerator;
import weka.classifiers.SplitGenerator;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Split generator that builds train/test pairs from groups extracted via
 * regular expressions, one attribute index/regexp/group triple per level.
 *
 * <p>The three option arrays (indices, regexps, groups) are kept at the same
 * length: setting one of them pads or trims the other two via
 * {@code Utils.adjustArray}.</p>
 *
 * <p>All containers are generated up front in {@link #doInitializeIterator()}
 * and then handed out one at a time by {@link #createNext()}.</p>
 */
public class MultiLevelSplitGenerator
extends AbstractSplitGenerator
implements SplitGenerator,
StoppableWithFeedback {
    private static final long serialVersionUID = -4813006743965500489L;

    /** The attribute indices, one per level. */
    protected WekaAttributeIndex[] m_Indices;

    /** The regular expressions used for extracting the groups, one per level. */
    protected BaseRegExp[] m_RegExps;

    /** The group expressions, one per level. */
    protected BaseString[] m_Groups;

    /** Whether warnings about unmatched groups are suppressed. */
    protected boolean m_Silent;

    /** The pre-computed containers that have not been handed out yet. */
    protected List<WekaTrainTestSetContainer> m_Containers;

    /** Whether the generation was stopped. */
    protected boolean m_Stopped;

    /**
     * Returns a string describing the object.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Generates splits based on groups extracted via regular expressions.\nEach attribute index/regexp/group represents a level.\nAt each level, the data gets split into groups according to the level's regexp/group, making up train and test sets.";
    }

    /**
     * Adds the options of this generator. The inherited "seed" and "useViews"
     * options are removed before the level-related options are registered.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        m_OptionManager.removeByProperty("seed");
        m_OptionManager.removeByProperty("useViews");
        m_OptionManager.add("index", "indices", (Object) new WekaAttributeIndex[]{new WekaAttributeIndex("first")});
        m_OptionManager.add("regexp", "regExps", (Object) new BaseRegExp[]{new BaseRegExp(".*")});
        m_OptionManager.add("group", "groups", (Object) new BaseString[]{new BaseString("$0")});
        m_OptionManager.add("silent", "silent", (Object) false);
    }

    /**
     * Initializes the members; creates the (empty) container list.
     */
    @Override
    protected void initialize() {
        super.initialize();
        m_Containers = new ArrayList<WekaTrainTestSetContainer>();
    }

    /**
     * Resets the generator; discards any containers that were already generated.
     */
    @Override
    protected void reset() {
        super.reset();
        m_Containers.clear();
    }

    /**
     * Sets the attribute indices (one per level). The regexp and group arrays
     * get adjusted to the same length, using ".*" and "$0" as fill values.
     *
     * @param value the indices
     */
    public void setIndices(WekaAttributeIndex[] value) {
        m_Indices = value;
        m_RegExps = (BaseRegExp[]) Utils.adjustArray((Object) m_RegExps, m_Indices.length, (Object) new BaseRegExp(".*"));
        m_Groups = (BaseString[]) Utils.adjustArray((Object) m_Groups, m_Indices.length, (Object) new BaseString("$0"));
        reset();
    }

    /**
     * Returns the attribute indices (one per level).
     *
     * @return the indices
     */
    public WekaAttributeIndex[] getIndices() {
        return m_Indices;
    }

    /**
     * Returns the tip text for the indices property.
     *
     * @return the tip text
     */
    public String indicesTipText() {
        return "The attribute indices to work on.";
    }

    /**
     * Sets the regular expressions (one per level). The index and group arrays
     * get adjusted to the same length, using "first" and "$0" as fill values.
     *
     * @param value the regular expressions
     */
    public void setRegExps(BaseRegExp[] value) {
        m_RegExps = value;
        m_Indices = (WekaAttributeIndex[]) Utils.adjustArray((Object) m_Indices, m_RegExps.length, (Object) new WekaAttributeIndex("first"));
        m_Groups = (BaseString[]) Utils.adjustArray((Object) m_Groups, m_RegExps.length, (Object) new BaseString("$0"));
        reset();
    }

    /**
     * Returns the regular expressions (one per level).
     *
     * @return the regular expressions
     */
    public BaseRegExp[] getRegExps() {
        return m_RegExps;
    }

    /**
     * Returns the tip text for the regExps property.
     *
     * @return the tip text
     */
    public String regExpsTipText() {
        return "The regular expressions to use for extracting the groups.";
    }

    /**
     * Sets the group expressions (one per level). The index and regexp arrays
     * get adjusted to the same length, using "first" and ".*" as fill values.
     *
     * @param value the groups
     */
    public void setGroups(BaseString[] value) {
        m_Groups = value;
        m_Indices = (WekaAttributeIndex[]) Utils.adjustArray((Object) m_Indices, m_Groups.length, (Object) new WekaAttributeIndex("first"));
        m_RegExps = (BaseRegExp[]) Utils.adjustArray((Object) m_RegExps, m_Groups.length, (Object) new BaseRegExp(".*"));
        reset();
    }

    /**
     * Returns the group expressions (one per level).
     *
     * @return the groups
     */
    public BaseString[] getGroups() {
        return m_Groups;
    }

    /**
     * Returns the tip text for the groups property.
     *
     * @return the tip text
     */
    public String groupsTipText() {
        return "The groups to generate.";
    }

    /**
     * Sets whether to suppress the warning that {@link #match(List, List, int)}
     * logs when a group has no counterpart.
     *
     * @param value true to suppress
     */
    public void setSilent(boolean value) {
        m_Silent = value;
        reset();
    }

    /**
     * Returns whether warnings are suppressed.
     *
     * @return true if suppressed
     */
    public boolean getSilent() {
        return m_Silent;
    }

    /**
     * Returns the tip text for the silent property.
     *
     * @return the tip text
     */
    public String silentTipText() {
        return "If enabled, error messages are suppressed.";
    }

    /**
     * This generator does not support randomization.
     *
     * @return always false
     */
    @Override
    protected boolean canRandomize() {
        return false;
    }

    /**
     * Splits the data into one dataset per group, with the groups determined by
     * a {@code BinnableInstances.StringAttributeGroupExtractor} built from the
     * given index, regexp and group.
     *
     * @param data the data to group
     * @param index the 0-based attribute index handed to the extractor
     * @param regexp the regular expression handed to the extractor
     * @param group the group expression handed to the extractor
     * @return one dataset per group
     * @throws IllegalStateException if wrapping or grouping the data fails
     */
    // The binning helpers are called through raw types here, hence the suppression.
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected List<Instances> generateGroups(Instances data, int index, String regexp, String group) {
        List grouped;
        try {
            List wrapped = BinnableInstances.toBinnableUsingIndex(data);
            wrapped = Wrapping.addTmpIndex(wrapped);
            grouped = Grouping.groupAsList(wrapped, (Grouping.GroupExtractor) new BinnableInstances.StringAttributeGroupExtractor(index, regexp, group));
        }
        catch (Exception e) {
            throw new IllegalStateException("Failed to create binnable Instances!", e);
        }

        List<Instances> result = new ArrayList<Instances>();
        // A raw List yields Object elements, so each one is cast explicitly.
        for (Object item : grouped) {
            List ungrouped = Grouping.ungroup((BinnableGroup) item);
            result.add(BinnableInstances.toInstances(ungrouped));
        }
        return result;
    }

    /**
     * Builds a dataset from the groups: either the group at the given position
     * or everything except that group.
     *
     * @param groups the datasets, one per group
     * @param index the position of the group of interest
     * @param invert if false, a new dataset is created from the group at
     *        {@code index}; if true, copies of the instances of all other
     *        groups are collected in a new dataset that is created from the
     *        group at {@code index} and the combined size of the other groups
     * @return the generated dataset
     */
    protected Instances subset(List<Instances> groups, int index, boolean invert) {
        if (!invert) {
            return new Instances(groups.get(index));
        }

        // Determine the required capacity up front.
        int capacity = 0;
        for (int i = 0; i < groups.size(); ++i) {
            if (i != index) {
                capacity += groups.get(i).numInstances();
            }
        }

        Instances result = new Instances(groups.get(index), capacity);
        for (int i = 0; i < groups.size(); ++i) {
            if (i == index) {
                continue;
            }
            Instances other = groups.get(i);
            for (int n = 0; n < other.numInstances(); ++n) {
                result.add((Instance) other.instance(n).copy());
            }
        }
        return result;
    }

    /**
     * Generates one split per group: {@code value1} holds all other groups
     * combined, {@code value2} holds the group itself.
     *
     * @param data the data to split
     * @param index the 0-based attribute index used for grouping
     * @param regexp the regular expression used for grouping
     * @param group the group expression used for grouping
     * @return the splits; an empty list if the generator was stopped
     */
    protected List<Struct2<Instances, Instances>> generateSplits(Instances data, int index, String regexp, String group) {
        List<Struct2<Instances, Instances>> result = new ArrayList<Struct2<Instances, Instances>>();
        List<Instances> grouped = generateGroups(data, index, regexp, group);
        for (int i = 0; i < grouped.size(); ++i) {
            if (isStopped()) {
                // Partial results are discarded when stopped.
                return new ArrayList<Struct2<Instances, Instances>>();
            }
            result.add(new Struct2<Instances, Instances>(subset(grouped, i, true), subset(grouped, i, false)));
        }
        return result;
    }

    /**
     * Pairs up the two lists of splits. Each split is identified by the string
     * value (at the given attribute index) of the first instance of its
     * {@code value2} dataset. For every entry in {@code trainSplits} that has
     * an entry with the same ID in {@code testSplits}, a pair consisting of the
     * matching test split's {@code value1} and the train split's {@code value2}
     * is added to the result. Entries without a match are skipped; a warning
     * gets logged unless the silent flag is set.
     *
     * @param trainSplits the splits generated from the train data
     * @param testSplits the splits generated from the test data
     * @param index the 0-based attribute index to read the ID from
     * @return the matched pairs
     */
    protected List<Struct2<Instances, Instances>> match(List<Struct2<Instances, Instances>> trainSplits, List<Struct2<Instances, Instances>> testSplits, int index) {
        // NOTE(review): the ID is the raw attribute string of the first instance,
        // not the regexp-extracted group - confirm this is intended when the
        // regexp only captures part of the value.
        List<String> trainIDs = new ArrayList<String>();
        for (Struct2<Instances, Instances> split : trainSplits) {
            trainIDs.add(split.value2.instance(0).stringValue(index));
        }
        List<String> testIDs = new ArrayList<String>();
        for (Struct2<Instances, Instances> split : testSplits) {
            testIDs.add(split.value2.instance(0).stringValue(index));
        }

        List<Struct2<Instances, Instances>> result = new ArrayList<Struct2<Instances, Instances>>();
        for (int trainIndex = 0; trainIndex < trainIDs.size(); ++trainIndex) {
            String id = trainIDs.get(trainIndex);
            int testIndex = testIDs.indexOf(id);
            if (testIndex > -1) {
                // NOTE(review): combines value1 of the test-side split with value2
                // of the train-side split - verify this pairing against the
                // intended train/test semantics.
                result.add(new Struct2<Instances, Instances>(testSplits.get(testIndex).value1, trainSplits.get(trainIndex).value2));
            } else if (!m_Silent) {
                getLogger().warning("No matching test data found for '" + id + "' (att index #" + (index + 1) + ")!");
            }
        }
        return result;
    }

    /**
     * Generates all train/test containers. The first level splits the full
     * dataset; each further level re-splits both halves of every split from the
     * previous level and keeps the pairs returned by
     * {@link #match(List, List, int)}. Nothing is stored if the generator got
     * stopped.
     */
    protected void generateContainers() {
        m_Containers.clear();

        List<Struct2<Instances, Instances>> splits = generateSplits(m_Data, m_Indices[0].getIntIndex(), m_RegExps[0].getValue(), m_Groups[0].getValue());
        for (int i = 1; i < m_Indices.length && !isStopped(); ++i) {
            int attIndex = m_Indices[i].getIntIndex();
            String regexp = m_RegExps[i].getValue();
            String group = m_Groups[i].getValue();

            // A fresh list per level is required: reusing a single list would make
            // it identical to 'splits' after the first pass, so clearing it would
            // wipe the previous level's results before they are processed.
            List<Struct2<Instances, Instances>> collected = new ArrayList<Struct2<Instances, Instances>>();
            for (Struct2<Instances, Instances> split : splits) {
                List<Struct2<Instances, Instances>> trainSplits = generateSplits(split.value1, attIndex, regexp, group);
                List<Struct2<Instances, Instances>> testSplits = generateSplits(split.value2, attIndex, regexp, group);
                collected.addAll(match(trainSplits, testSplits, attIndex));
            }
            splits = collected;
        }

        if (!isStopped()) {
            for (Struct2<Instances, Instances> split : splits) {
                m_Containers.add(new WekaTrainTestSetContainer(split.value1, split.value2));
            }
        }
    }

    /**
     * Initializes the iterator: clears the stopped flag, validates the setup,
     * hands the dataset to the attribute indices and generates all containers.
     *
     * @throws IllegalStateException if no data is set or no level is defined
     */
    @Override
    protected void doInitializeIterator() {
        m_Stopped = false;
        if (m_Data == null) {
            throw new IllegalStateException("No data available!");
        }
        if (m_Indices.length == 0) {
            throw new IllegalStateException("At least one level of index/regexp/group required!");
        }
        for (WekaAttributeIndex index : m_Indices) {
            index.setData(m_Data);
        }
        generateContainers();
    }

    /**
     * Checks whether another container is available.
     *
     * @return true if not stopped and at least one container is left
     */
    @Override
    protected boolean checkNext() {
        return !isStopped() && !m_Containers.isEmpty();
    }

    /**
     * Removes and returns the next pre-computed container.
     *
     * @return the next container
     */
    @Override
    protected WekaTrainTestSetContainer createNext() {
        return m_Containers.remove(0);
    }

    /**
     * Flags the generator as stopped.
     */
    public void stopExecution() {
        m_Stopped = true;
    }

    /**
     * Returns whether the generator was stopped.
     *
     * @return true if stopped
     */
    public boolean isStopped() {
        return m_Stopped;
    }

    /**
     * Returns a short description of the setup, appending the indices, regexps
     * and groups to the superclass' description.
     *
     * @return the description
     */
    @Override
    public String toString() {
        return super.toString() + ", indices=" + Utils.arrayToString((Object) m_Indices) + ", regexps=" + Utils.arrayToString((Object) m_RegExps) + ", groups=" + Utils.arrayToString((Object) m_Groups);
    }
}

