package weka.classifiers;

import adams.core.Utils;
import adams.core.base.BaseRegExp;
import adams.data.binning.Binnable;
import adams.data.binning.BinnableGroup;
import adams.data.binning.BinnableInstances;
import adams.data.binning.algorithm.BinningAlgorithm;
import adams.data.binning.algorithm.BinningAlgorithmUser;
import adams.data.binning.algorithm.ManualBinning;
import adams.data.binning.operation.Bins;
import adams.data.binning.operation.Grouping;
import adams.data.binning.operation.Stratify;
import adams.data.binning.operation.Wrapping;
import adams.data.binning.postprocessing.MinBinSize;
import adams.data.splitgenerator.generic.core.Subset;
import adams.data.splitgenerator.generic.crossvalidation.CrossValidationGenerator;
import adams.data.splitgenerator.generic.crossvalidation.FoldPair;
import adams.data.splitgenerator.generic.randomization.DefaultRandomization;
import adams.data.weka.WekaAttributeIndex;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.gui.tools.wekainvestigator.tab.PartialLeastSquaresTab;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.InstancesView;
import weka.filters.supervised.instance.RemoveOutliers;
import weka.filters.unsupervised.attribute.EquiDistance;
import weka.filters.unsupervised.attribute.NominalToNumeric;

/**
 * Generates cross-validation fold pairs for datasets with a numeric class attribute.
 * <p>
 * Instances are grouped by applying the regular expression and group replacement string to the
 * values of the attribute selected via the index option. Only whole groups are distributed across
 * the folds, so all instances of a group end up in the same fold. The (optionally randomized)
 * groups are binned with the configured binning algorithm, the bins are post-processed with
 * {@link MinBinSize} (minimum size 2), and the binned groups are optionally stratified before the
 * fold pairs get generated.
 */
public class GroupedBinnedNumericClassCrossValidationFoldGenerator extends AbstractSplitGenerator implements CrossValidationFoldGenerator, BinningAlgorithmUser {
    private static final long serialVersionUID = -8387205583429213079L;

    /** the number of folds to generate. */
    protected int m_NumFolds;

    /** whether to stratify the binned groups. */
    protected boolean m_Stratify;

    /** the 1-based index of the next fold to return. */
    protected transient int m_CurrentFold;

    /** the template for the relation name of the generated datasets. */
    protected String m_RelationName;

    /** whether to randomize the groups before binning. */
    protected boolean m_Randomize;

    /** the attribute to extract the group from. */
    protected WekaAttributeIndex m_Index;

    /** the regular expression applied to the group attribute values. */
    protected BaseRegExp m_RegExp;

    /** the replacement string (eg '$1') that forms the group. */
    protected String m_Group;

    /** the binning algorithm to apply to the groups. */
    protected BinningAlgorithm m_Algorithm;

    /** the lazily generated fold pairs; null until generated or after the last fold was returned. */
    protected transient List<FoldPair<Binnable<Instance>>> m_FoldPairs;

    /**
     * Default constructor for use with the option handling.
     */
    public GroupedBinnedNumericClassCrossValidationFoldGenerator() {
    }

    /**
     * Initializes the generator with randomization enabled and the default relation name.
     *
     * @param instances the data to split (requires a numeric class attribute)
     * @param numFolds  the number of folds
     * @param seed      the seed for the randomization
     * @param stratify  whether to stratify the binned groups
     */
    public GroupedBinnedNumericClassCrossValidationFoldGenerator(Instances instances, int numFolds, long seed, boolean stratify) {
        this(instances, numFolds, seed, true, stratify, null);
    }

    /**
     * Initializes the generator.
     *
     * @param instances the data to split (requires a numeric class attribute)
     * @param numFolds  the number of folds
     * @param seed      the seed for the randomization
     * @param randomize whether to randomize the groups before binning
     * @param stratify  whether to stratify the binned groups
     * @param relName   the relation name template, null/empty for the default
     */
    public GroupedBinnedNumericClassCrossValidationFoldGenerator(Instances instances, int numFolds, long seed, boolean randomize, boolean stratify, String relName) {
        setData(instances);
        setSeed(seed);
        setNumFolds(numFolds);
        setRelationName(relName);
        setStratify(stratify);
        setRandomize(randomize);
    }

    /**
     * Returns a string describing the generator.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Generates cross-validation fold pairs. Uses binning algorithm to obtain similar class distributions. Groups instances according to the grouping expression.";
    }

    /**
     * Adds the options of this generator.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        // NOTE(review): the option-name constants below (RemoveOutliers.NUM_FOLDS, NominalToNumeric.INDEX,
        // EquiDistance.REGEXP, PartialLeastSquaresTab.KEY_ALGORITHM) look like decompiler artefacts, i.e.
        // string literals that happened to match constants in unrelated classes; confirm their values
        // before replacing them with literals.
        m_OptionManager.add(RemoveOutliers.NUM_FOLDS, "numFolds", 10, 2, (Number) null);
        m_OptionManager.add("relation-name", "relationName", CrossValidationHelper.PLACEHOLDER_ORIGINAL);
        m_OptionManager.add("randomize", "randomize", true);
        m_OptionManager.add("stratify", "stratify", true);
        m_OptionManager.add(NominalToNumeric.INDEX, NominalToNumeric.INDEX, new WekaAttributeIndex("first"));
        m_OptionManager.add(EquiDistance.REGEXP, "regExp", new BaseRegExp(".*"));
        m_OptionManager.add("group", "group", "$0");
        m_OptionManager.add(PartialLeastSquaresTab.KEY_ALGORITHM, PartialLeastSquaresTab.KEY_ALGORITHM, new ManualBinning());
    }

    /**
     * Resets the iteration state: starts again at fold 1 and discards any generated fold pairs.
     */
    @Override
    public void reset() {
        super.reset();
        m_CurrentFold = 1;
        m_FoldPairs = null;
    }

    /**
     * Sets the data to split.
     *
     * @param instances the data, may be null
     * @throws IllegalArgumentException if no class attribute is set or the class is not numeric
     */
    @Override
    public void setData(Instances instances) {
        super.setData(instances);
        if (m_Data == null) {
            return;
        }
        if (m_Data.classIndex() == -1) {
            throw new IllegalArgumentException("No class attribute set!");
        }
        if (!m_Data.classAttribute().isNumeric()) {
            throw new IllegalArgumentException("Class attribute is not numeric!");
        }
    }

    @Override
    public void setNumFolds(int numFolds) {
        m_NumFolds = numFolds;
        reset();
    }

    @Override
    public int getNumFolds() {
        return m_NumFolds;
    }

    public String numFoldsTipText() {
        return "The number of folds (minimum: 2).";
    }

    /**
     * Returns the number of folds actually generated, which is identical to the configured number.
     *
     * @return the number of folds
     */
    @Override
    public int getActualNumFolds() {
        return m_NumFolds;
    }

    @Override
    public void setRandomize(boolean randomize) {
        m_Randomize = randomize;
        reset();
    }

    @Override
    public boolean getRandomize() {
        return m_Randomize;
    }

    public String randomizeTipText() {
        return "If enabled, the data is randomized first.";
    }

    @Override
    public void setStratify(boolean stratify) {
        m_Stratify = stratify;
        reset();
    }

    @Override
    public boolean getStratify() {
        return m_Stratify;
    }

    public String stratifyTipText() {
        return "If enabled, the binned groups get stratified across the folds.";
    }

    @Override
    public void setRelationName(String relName) {
        m_RelationName = relName;
        reset();
    }

    @Override
    public String getRelationName() {
        return m_RelationName;
    }

    public String relationNameTipText() {
        return CrossValidationHelper.relationNameTemplateTipText();
    }

    public void setIndex(WekaAttributeIndex index) {
        m_Index = index;
        reset();
    }

    public WekaAttributeIndex getIndex() {
        return m_Index;
    }

    public String indexTipText() {
        return "The index of the attribute to extract the group from (using the regular expression and group replacement).";
    }

    public void setRegExp(BaseRegExp regExp) {
        m_RegExp = regExp;
        reset();
    }

    public BaseRegExp getRegExp() {
        return m_RegExp;
    }

    public String regExpTipText() {
        return "The regular expression for identifying the group (eg '^(.*)-([0-9]+)-(.*)$').";
    }

    public void setGroup(String group) {
        m_Group = group;
        reset();
    }

    public String getGroup() {
        return m_Group;
    }

    public String groupTipText() {
        return "The replacement string to use as group (eg '$2').";
    }

    public void setAlgorithm(BinningAlgorithm algorithm) {
        m_Algorithm = algorithm;
        reset();
    }

    public BinningAlgorithm getAlgorithm() {
        return m_Algorithm;
    }

    public String algorithmTipText() {
        return "The binning algorithm to apply to the data.";
    }

    @Override
    protected boolean canRandomize() {
        return m_Randomize;
    }

    @Override
    protected boolean checkNext() {
        return m_CurrentFold <= m_NumFolds;
    }

    /**
     * Validates the setup before iterating.
     *
     * @throws IllegalStateException    if no data has been provided
     * @throws IllegalArgumentException if there are fewer instances than folds
     */
    @Override
    protected void doInitializeIterator() {
        if (m_Data == null) {
            throw new IllegalStateException("No data provided!");
        }
        if (m_Data.numInstances() < m_NumFolds) {
            throw new IllegalArgumentException("Cannot have less data than folds: required=" + m_NumFolds + ", provided=" + m_Data.numInstances());
        }
        if (m_RelationName == null || m_RelationName.isEmpty()) {
            m_RelationName = CrossValidationHelper.PLACEHOLDER_ORIGINAL;
        }
    }

    /**
     * Returns the next train/test fold pair. The fold pairs get generated on the first call.
     *
     * @return the container with train and test set
     * @throws NoSuchElementException if all folds have already been returned
     * @throws IllegalStateException  if the fold pairs cannot be generated
     */
    @Override
    protected WekaTrainTestSetContainer createNext() {
        if (m_CurrentFold > m_NumFolds) {
            throw new NoSuchElementException("No more folds available!");
        }
        if (m_FoldPairs == null) {
            generateFoldPairs();
        }

        FoldPair<Binnable<Instance>> pair = m_FoldPairs.get(m_CurrentFold - 1);
        int[] trainIndices = pair.getTrain().getOriginalIndices().toArray();
        int[] testIndices = pair.getTest().getOriginalIndices().toArray();
        Instances train;
        Instances test;
        if (m_UseViews) {
            train = new InstancesView(m_Data, trainIndices);
            test = new InstancesView(m_Data, testIndices);
        } else {
            train = BinnableInstances.toInstances(pair.getTrain().getData());
            test = BinnableInstances.toInstances(pair.getTest().getData());
        }
        train.setRelationName(CrossValidationHelper.createRelationName(m_Data.relationName(), m_RelationName, m_CurrentFold, true));
        test.setRelationName(CrossValidationHelper.createRelationName(m_Data.relationName(), m_RelationName, m_CurrentFold, false));
        WekaTrainTestSetContainer result = new WekaTrainTestSetContainer(train, test, Long.valueOf(m_Seed), Integer.valueOf(m_CurrentFold), Integer.valueOf(m_NumFolds), trainIndices, testIndices);

        m_CurrentFold++;
        // release the fold pairs once the last fold has been handed out
        if (m_CurrentFold > m_NumFolds) {
            m_FoldPairs = null;
        }
        return result;
    }

    /**
     * Groups, randomizes (optional), bins, stratifies (optional) and splits the data into fold pairs.
     * The fields {@code m_FoldPairs} and {@code m_OriginalIndices} are only updated when generation
     * succeeds, so a failure never leaves a partially populated fold list behind.
     *
     * @throws IllegalStateException if the group attribute cannot be located or generation fails
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void generateFoldPairs() {
        m_Index.setData(m_Data);
        int groupIndex = m_Index.getIntIndex();
        if (groupIndex == -1) {
            throw new IllegalStateException("Failed to locate group attribute: " + m_Index.getIndex());
        }

        List<FoldPair<Binnable<Instance>>> pairs = new ArrayList<>();
        TIntArrayList originalIndices = new TIntArrayList();
        try {
            List groups = Wrapping.wrap(
                Grouping.groupAsList(
                    Wrapping.addTmpIndex(BinnableInstances.toBinnableUsingIndex(m_Data)),
                    new BinnableInstances.StringAttributeGroupExtractor(groupIndex, m_RegExp.getValue(), m_Group)),
                new Wrapping.IndexedBinValueExtractor());

            // the same randomization instance is used for shuffling the groups and the training folds
            DefaultRandomization randomization = new DefaultRandomization();
            randomization.setSeed(m_Seed);
            if (canRandomize()) {
                groups = randomization.randomize(groups);
            }

            List bins = m_Algorithm.generateBins(groups);
            if (isLoggingEnabled()) {
                getLogger().info("Bins: " + Utils.arrayToString(Bins.binSizes(bins)));
            }
            MinBinSize minBinSize = new MinBinSize();
            minBinSize.setMinSize(2);
            List processedBins = minBinSize.postProcessBins(bins);
            if (isLoggingEnabled()) {
                getLogger().info("Bins after post-processing: " + Utils.arrayToString(Bins.binSizes(processedBins)));
            }

            List binned = Bins.flatten(Bins.useBinIndex(processedBins));
            if (getStratify()) {
                binned = Stratify.stratify(binned, m_NumFolds);
            }

            for (int fold = 0; fold < m_NumFolds; fold++) {
                List trainData = ungroupAll(CrossValidationGenerator.trainCV(binned, m_NumFolds, fold, randomization));
                Subset trainSubset = new Subset(trainData, Wrapping.getTmpIndices(trainData));
                List testData = ungroupAll(CrossValidationGenerator.testCV(binned, m_NumFolds, fold));
                Subset testSubset = new Subset(testData, Wrapping.getTmpIndices(testData));
                pairs.add(new FoldPair<>(fold, trainSubset, testSubset));
            }

            // the concatenated test indices describe the order of the cross-validation predictions
            for (FoldPair<Binnable<Instance>> pair : pairs) {
                originalIndices.addAll(pair.getTest().getOriginalIndices());
            }
        } catch (Exception e) {
            throw new IllegalStateException("Failed to create binnable groups!", e);
        }

        m_FoldPairs = pairs;
        m_OriginalIndices = originalIndices;
    }

    /**
     * Expands the binnable groups back into their individual elements.
     *
     * @param groups the list of binnables whose payloads are {@link BinnableGroup}s
     * @return the concatenated elements of all groups
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static List ungroupAll(List groups) {
        List result = new ArrayList();
        for (Object group : groups) {
            result.addAll(Grouping.ungroup((BinnableGroup) ((Binnable) group).getPayload()));
        }
        return result;
    }

    /**
     * Returns the original row indices in the order the test folds were generated.
     *
     * @return the indices
     * @throws IllegalStateException if no folds have been generated yet
     */
    @Override
    public int[] crossValidationIndices() {
        if (m_OriginalIndices == null) {
            throw new IllegalStateException("No folds generated yet!");
        }
        return m_OriginalIndices.toArray();
    }

    @Override
    public String toString() {
        return super.toString() + ", numFolds=" + m_NumFolds + ", stratify=" + m_Stratify + ", relName=" + m_RelationName;
    }
}
