/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers;

import adams.core.Utils;
import adams.core.base.BaseRegExp;
import adams.data.binning.Bin;
import adams.data.binning.Binnable;
import adams.data.binning.BinnableGroup;
import adams.data.binning.BinnableInstances;
import adams.data.binning.algorithm.BinningAlgorithm;
import adams.data.binning.algorithm.BinningAlgorithmUser;
import adams.data.binning.algorithm.ManualBinning;
import adams.data.binning.operation.Bins;
import adams.data.binning.operation.Grouping;
import adams.data.binning.operation.Randomize;
import adams.data.binning.operation.Wrapping;
import adams.data.binning.postprocessing.MinBinSize;
import adams.data.weka.WekaAttributeIndex;
import adams.flow.container.WekaTrainTestSetContainer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import weka.classifiers.AbstractSplitGenerator;
import weka.classifiers.RandomSplitGenerator;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.InstancesView;

/**
 * Generates a single random train/test split for datasets with a numeric
 * class attribute.
 * <p>
 * Instances are first collected into groups (the group ID is extracted from
 * the attribute at {@link #getIndex()} using {@link #getRegExp()} and
 * {@link #getGroup()}), the groups are then binned with the configured
 * {@link BinningAlgorithm}, and finally whole groups are drawn bin by bin
 * into the training set until the requested percentage is reached. All
 * remaining groups end up in the test set, so a group is never divided
 * between train and test.
 */
public class GroupedBinnedNumericClassRandomSplitGenerator
extends AbstractSplitGenerator
implements RandomSplitGenerator,
BinningAlgorithmUser {
    private static final long serialVersionUID = -4813006743965500489L;

    /** The fraction (0-1) of the data to use for training. */
    protected double m_Percentage;

    /** Whether the order of the data is preserved (i.e. no randomization). */
    protected boolean m_PreserveOrder;

    /** The index of the attribute that the group is extracted from. */
    protected WekaAttributeIndex m_Index;

    /** The regular expression applied to the attribute value to identify the group. */
    protected BaseRegExp m_RegExp;

    /** The replacement string (e.g. "$0") that forms the group ID. */
    protected String m_Group;

    /** The binning algorithm applied to the groups. */
    protected BinningAlgorithm m_Algorithm;

    /** Whether the (single) split has already been generated by the iterator. */
    protected boolean m_Generated;

    /**
     * Default constructor; data and options have to be supplied via the setters.
     */
    public GroupedBinnedNumericClassRandomSplitGenerator() {
    }

    /**
     * Initializes the generator for a randomized split (order is not preserved).
     *
     * @param data the dataset to split
     * @param seed the seed value to use for randomization
     * @param percentage the fraction (0-1) of the data to use for training
     */
    public GroupedBinnedNumericClassRandomSplitGenerator(Instances data, long seed, double percentage) {
        this.setData(data);
        this.setSeed(seed);
        this.setPercentage(percentage);
        this.setPreserveOrder(false);
    }

    /**
     * Initializes the generator for an order-preserving split; the seed is set to -1.
     *
     * @param data the dataset to split
     * @param percentage the fraction (0-1) of the data to use for training
     */
    public GroupedBinnedNumericClassRandomSplitGenerator(Instances data, double percentage) {
        this.setData(data);
        this.setSeed(-1L);
        this.setPercentage(percentage);
        this.setPreserveOrder(true);
    }

    /**
     * Initializes the generator with explicit control over randomization.
     *
     * @param data the dataset to split
     * @param seed the seed value to use for randomization
     * @param percentage the fraction (0-1) of the data to use for training
     * @param preserveOrder whether to preserve the order of the data
     */
    public GroupedBinnedNumericClassRandomSplitGenerator(Instances data, long seed, double percentage, boolean preserveOrder) {
        this.setData(data);
        this.setSeed(seed);
        this.setPercentage(percentage);
        this.setPreserveOrder(preserveOrder);
    }

    /**
     * Returns a short description of this generator.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Generates random splits of datasets with numeric classes. Uses a binning algorithm to obtain similar distribution in splits. Order can be preserved. Groups instances according to the grouping expression.";
    }

    /**
     * Registers the options of this generator (in addition to the inherited
     * ones) together with their default values.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("percentage", "percentage", (Object)0.66, (Number)0.0, (Number)1.0);
        this.m_OptionManager.add("preserve-order", "preserveOrder", (Object)false);
        this.m_OptionManager.add("index", "index", (Object)new WekaAttributeIndex("first"));
        this.m_OptionManager.add("regexp", "regExp", (Object)new BaseRegExp(".*"));
        this.m_OptionManager.add("group", "group", (Object)"$0");
        this.m_OptionManager.add("algorithm", "algorithm", (Object)new ManualBinning());
    }

    /**
     * Sets the training percentage. Values that the option manager does not
     * accept as valid for the "percentage" option are silently ignored.
     *
     * @param value the fraction (0-1) of the data to use for training
     */
    @Override
    public void setPercentage(double value) {
        if (this.getOptionManager().isValid("percentage", (Number)value)) {
            this.m_Percentage = value;
            this.reset();
        }
    }

    /**
     * Returns the training percentage.
     *
     * @return the fraction (0-1) of the data used for training
     */
    @Override
    public double getPercentage() {
        return this.m_Percentage;
    }

    /**
     * Returns the help text for the percentage option.
     *
     * @return the tip text
     */
    public String percentageTipText() {
        return "The percentage to use for training (0-1).";
    }

    /**
     * Sets whether the order of the data is preserved.
     *
     * @param value true to preserve the order (disables randomization)
     */
    @Override
    public void setPreserveOrder(boolean value) {
        this.m_PreserveOrder = value;
        this.reset();
    }

    /**
     * Returns whether the order of the data is preserved.
     *
     * @return true if the order is preserved
     */
    @Override
    public boolean getPreserveOrder() {
        return this.m_PreserveOrder;
    }

    /**
     * Returns the help text for the preserve-order option.
     *
     * @return the tip text
     */
    public String preserveOrderTipText() {
        return "If enabled, the order in the data is preserved in the split.";
    }

    /**
     * Sets the index of the attribute that the group is extracted from.
     *
     * @param value the attribute index
     */
    public void setIndex(WekaAttributeIndex value) {
        this.m_Index = value;
        this.reset();
    }

    /**
     * Returns the index of the attribute that the group is extracted from.
     *
     * @return the attribute index
     */
    public WekaAttributeIndex getIndex() {
        return this.m_Index;
    }

    /**
     * Returns the help text for the index option.
     *
     * @return the tip text
     */
    public String indexTipText() {
        // Previously returned the percentage help text (copy/paste error).
        return "The index of the attribute to determine the group from.";
    }

    /**
     * Sets the regular expression used for identifying the group.
     *
     * @param value the regular expression
     */
    public void setRegExp(BaseRegExp value) {
        this.m_RegExp = value;
        this.reset();
    }

    /**
     * Returns the regular expression used for identifying the group.
     *
     * @return the regular expression
     */
    public BaseRegExp getRegExp() {
        return this.m_RegExp;
    }

    /**
     * Returns the help text for the regexp option.
     *
     * @return the tip text
     */
    public String regExpTipText() {
        return "The regular expression for identifying the group (eg '^(.*)-([0-9]+)-(.*)$').";
    }

    /**
     * Sets the replacement string that forms the group ID.
     *
     * @param value the replacement string, e.g. "$2"
     */
    public void setGroup(String value) {
        this.m_Group = value;
        this.reset();
    }

    /**
     * Returns the replacement string that forms the group ID.
     *
     * @return the replacement string
     */
    public String getGroup() {
        return this.m_Group;
    }

    /**
     * Returns the help text for the group option.
     *
     * @return the tip text
     */
    public String groupTipText() {
        return "The replacement string to use as group (eg '$2').";
    }

    /**
     * Sets the binning algorithm to apply to the groups.
     *
     * @param value the algorithm
     */
    public void setAlgorithm(BinningAlgorithm value) {
        this.m_Algorithm = value;
        this.reset();
    }

    /**
     * Returns the binning algorithm applied to the groups.
     *
     * @return the algorithm
     */
    public BinningAlgorithm getAlgorithm() {
        return this.m_Algorithm;
    }

    /**
     * Returns the help text for the algorithm option.
     *
     * @return the tip text
     */
    public String algorithmTipText() {
        return "The binning algorithm to apply to the data.";
    }

    /**
     * Sets the dataset to split. A non-null dataset must have a numeric class
     * attribute set.
     *
     * @param value the dataset, may be null
     * @throws IllegalArgumentException if no class attribute is set or the
     *         class attribute is not numeric
     */
    @Override
    public void setData(Instances value) {
        super.setData(value);
        if (this.m_Data != null) {
            if (this.m_Data.classIndex() == -1) {
                throw new IllegalArgumentException("No class attribute set!");
            }
            if (!this.m_Data.classAttribute().isNumeric()) {
                throw new IllegalArgumentException("Class attribute is not numeric!");
            }
        }
    }

    /**
     * Returns whether randomization is allowed.
     *
     * @return true if the order does not have to be preserved
     */
    @Override
    protected boolean canRandomize() {
        return !this.m_PreserveOrder;
    }

    /**
     * Prepares the iterator for generating the split.
     *
     * @throws IllegalStateException if no data has been set
     */
    @Override
    protected void doInitializeIterator() {
        if (this.m_Data == null) {
            throw new IllegalStateException("No data available!");
        }
        this.m_Generated = false;
    }

    /**
     * Returns whether another split is available; only a single split is
     * ever generated per initialization.
     *
     * @return true if the split has not been generated yet
     */
    @Override
    protected boolean checkNext() {
        return !this.m_Generated;
    }

    /**
     * Generates the train/test split.
     * <p>
     * Steps: wrap the instances and tag them with their row index, group them
     * by the extracted group ID, optionally randomize the groups, bin the
     * groups, then fill the training set round-robin across the bins until the
     * overall target size is reached. Every bin keeps at least one group for
     * the test set.
     *
     * @return the container holding train and test set plus the row indices
     * @throws IllegalStateException if the binnable groups cannot be created
     */
    @Override
    protected WekaTrainTestSetContainer createNext() {
        Instances testSet;
        Instances trainSet;
        int i;
        List binnableGroups;
        this.m_Generated = true;
        this.m_Index.setData(this.m_Data);
        try {
            List binnableInst = BinnableInstances.toBinnableUsingIndex(this.m_Data);
            // remember the original row of each instance so the row indices can be reported later
            binnableInst = Wrapping.addTmpIndex(binnableInst);
            List groupedInst = Grouping.groupAsList((List)binnableInst, (Grouping.GroupExtractor)new BinnableInstances.StringAttributeGroupExtractor(this.m_Index.getIntIndex(), this.m_RegExp.getValue(), this.m_Group));
            binnableGroups = Wrapping.wrap((Collection)groupedInst, (Wrapping.BinValueExtractor)new Wrapping.IndexedBinValueExtractor());
        }
        catch (Exception e) {
            throw new IllegalStateException("Failed to create binnable groups!", e);
        }
        if (this.canRandomize()) {
            Randomize.randomizeData((List)binnableGroups, (long)this.m_Seed);
        }
        List binGroups = this.m_Algorithm.generateBins(binnableGroups);
        if (this.isLoggingEnabled()) {
            this.getLogger().info("Bins: " + Utils.arrayToString((Object)Bins.binSizes((List)binGroups)));
        }
        // NOTE(review): presumably enforces at least 2 groups per bin so that
        // each bin can feed both train and test - confirm against MinBinSize.
        MinBinSize minBinSize = new MinBinSize();
        minBinSize.setMinSize(2);
        binGroups = minBinSize.postProcessBins(binGroups);
        if (this.isLoggingEnabled()) {
            this.getLogger().info("Bins after postprocessing: " + Utils.arrayToString((Object)Bins.binSizes((List)binGroups)));
        }

        // overall and per-bin caps for the number of groups in the training set;
        // the per-bin cap is at most size-1, leaving one group per bin for testing
        int maxTotal = (int)Math.round((double)Bins.totalSize((List)binGroups) * this.m_Percentage);
        int[] maxPerBin = new int[binGroups.size()];
        for (i = 0; i < binGroups.size(); ++i) {
            maxPerBin[i] = Math.min(((Bin)binGroups.get(i)).size() - 1, (int)Math.round((double)((Bin)binGroups.get(i)).size() * this.m_Percentage));
        }
        if (this.isLoggingEnabled()) {
            this.getLogger().info("max total: " + maxTotal);
            this.getLogger().info("max per bin: " + Utils.arrayToString((Object)maxPerBin));
        }

        // round-robin over the bins, taking one group per bin per pass
        ArrayList<Binnable> groupedTrain = new ArrayList<Binnable>();
        int[] trainPerBin = new int[maxPerBin.length];
        int trainTotal = 0;
        while (trainTotal < maxTotal) {
            boolean added = false;
            for (i = 0; i < trainPerBin.length; ++i) {
                if (trainPerBin[i] >= maxPerBin[i] || trainTotal >= maxTotal) continue;
                groupedTrain.add((Binnable)((Bin)binGroups.get(i)).get().get(trainPerBin[i]));
                ++trainTotal;
                trainPerBin[i]++;
                added = true;
            }
            // The per-bin caps can sum to less than maxTotal (e.g. percentage
            // close to 1 with small bins); without this exit the loop would
            // spin forever once every bin has reached its cap.
            if (!added) {
                break;
            }
        }
        if (this.isLoggingEnabled()) {
            this.getLogger().info("train total: " + trainTotal);
            this.getLogger().info("train per bin: " + Utils.arrayToString((Object)trainPerBin));
        }

        // unpack the selected groups back into individual instances
        ArrayList<Binnable<Instance>> binnableTrain = new ArrayList<Binnable<Instance>>();
        for (Binnable grouped : groupedTrain) {
            binnableTrain.addAll(Grouping.ungroup((BinnableGroup)((BinnableGroup)grouped.getPayload())));
        }

        // everything not taken for training goes into the test set
        ArrayList<Binnable> groupedTest = new ArrayList<Binnable>();
        for (i = 0; i < trainPerBin.length; ++i) {
            for (int n = trainPerBin[i]; n < ((Bin)binGroups.get(i)).size(); ++n) {
                groupedTest.add((Binnable)((Bin)binGroups.get(i)).get().get(n));
            }
        }
        ArrayList<Binnable<Instance>> binnableTest = new ArrayList<Binnable<Instance>>();
        for (Binnable grouped : groupedTest) {
            binnableTest.addAll(Grouping.ungroup((BinnableGroup)((BinnableGroup)grouped.getPayload())));
        }

        int[] trainRows = Wrapping.getTmpIndices(binnableTrain).toArray();
        int[] testRows = Wrapping.getTmpIndices(binnableTest).toArray();
        if (this.m_UseViews) {
            trainSet = new InstancesView(this.m_Data, trainRows);
            testSet = new InstancesView(this.m_Data, testRows);
        } else {
            trainSet = BinnableInstances.toInstances(binnableTrain);
            testSet = BinnableInstances.toInstances(binnableTest);
        }
        WekaTrainTestSetContainer result = new WekaTrainTestSetContainer(trainSet, testSet, this.m_Seed, null, null, trainRows, testRows);
        return result;
    }

    /**
     * Returns a short description of the generator's setup.
     *
     * @return the inherited description with the percentage appended
     */
    @Override
    public String toString() {
        return super.toString() + ", percentage=" + this.m_Percentage;
    }
}

