package weka.classifiers;

import adams.core.Utils;
import adams.data.binning.Bin;
import adams.data.binning.Binnable;
import adams.data.binning.BinnableInstances;
import adams.data.binning.algorithm.BinningAlgorithm;
import adams.data.binning.algorithm.BinningAlgorithmUser;
import adams.data.binning.algorithm.ManualBinning;
import adams.data.binning.operation.Bins;
import adams.data.binning.operation.Randomize;
import adams.data.binning.operation.Wrapping;
import adams.data.binning.postprocessing.MinBinSize;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.gui.tools.wekainvestigator.tab.PartialLeastSquaresTab;
import java.util.ArrayList;
import java.util.List;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.InstancesView;

/* loaded from: input_file:weka/classifiers/BinnedNumericClassRandomSplitGenerator.class */
/**
 * Generates a single random train/test split for datasets with a numeric class attribute.
 * The instances are grouped into bins by the configured {@link BinningAlgorithm}; the training
 * set is then drawn round-robin from the bins so that train and test end up with a similar
 * class distribution. Every bin keeps at least one instance for the test set.
 */
public class BinnedNumericClassRandomSplitGenerator extends AbstractSplitGenerator implements RandomSplitGenerator, BinningAlgorithmUser {
    private static final long serialVersionUID = -4813006743965500489L;

    /**
     * Minimum bin size enforced via {@link MinBinSize} post-processing; with at least two
     * instances a bin can contribute to both the train and the test set.
     */
    private static final int MIN_BIN_SIZE = 2;

    /** The fraction of instances (0-1) to use for training. */
    protected double m_Percentage;

    /** Whether to skip randomization and keep the original instance order in the split. */
    protected boolean m_PreserveOrder;

    /** The algorithm used for binning the instances by class value. */
    protected BinningAlgorithm m_Algorithm;

    /** Whether the (single) split has already been generated by the current iterator. */
    protected boolean m_Generated;

    /** Default constructor; data, seed and percentage must be set via the setters. */
    public BinnedNumericClassRandomSplitGenerator() {
    }

    /**
     * Initializes a randomized split.
     *
     * @param instances the data to split (numeric class required)
     * @param j the seed for randomization
     * @param d the training percentage (0-1)
     */
    public BinnedNumericClassRandomSplitGenerator(Instances instances, long j, double d) {
        setData(instances);
        setSeed(j);
        setPercentage(d);
        setPreserveOrder(false);
    }

    /**
     * Initializes an order-preserving (non-randomized) split.
     *
     * @param instances the data to split (numeric class required)
     * @param d the training percentage (0-1)
     */
    public BinnedNumericClassRandomSplitGenerator(Instances instances, double d) {
        setData(instances);
        setSeed(-1L);
        setPercentage(d);
        setPreserveOrder(true);
    }

    /**
     * Initializes the split generator.
     *
     * @param instances the data to split (numeric class required)
     * @param j the seed for randomization
     * @param d the training percentage (0-1)
     * @param z whether to preserve the order (no randomization)
     */
    public BinnedNumericClassRandomSplitGenerator(Instances instances, long j, double d, boolean z) {
        setData(instances);
        setSeed(j);
        setPercentage(d);
        setPreserveOrder(z);
    }

    /**
     * Returns a short description of the generator.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Generates random splits of datasets with numeric classes. Uses a binning algorithm to obtain similar distribution in splits. Order can be preserved.";
    }

    /** Registers the percentage, preserve-order and binning-algorithm options. */
    @Override // weka.classifiers.AbstractSplitGenerator
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("percentage", "percentage", Double.valueOf(0.66d), Double.valueOf(0.0d), Double.valueOf(1.0d));
        this.m_OptionManager.add("preserve-order", "preserveOrder", false);
        // option/property name must match the bean property of get/setAlgorithm
        this.m_OptionManager.add("algorithm", "algorithm", new ManualBinning());
    }

    /**
     * Sets the training percentage; values rejected by the option manager are ignored.
     *
     * @param d the percentage (0-1)
     */
    @Override // weka.classifiers.RandomSplitGenerator
    public void setPercentage(double d) {
        if (getOptionManager().isValid("percentage", Double.valueOf(d))) {
            this.m_Percentage = d;
            reset();
        }
    }

    /**
     * Returns the training percentage.
     *
     * @return the percentage (0-1)
     */
    @Override // weka.classifiers.RandomSplitGenerator
    public double getPercentage() {
        return this.m_Percentage;
    }

    /**
     * Returns the tip text for the percentage property.
     *
     * @return the tip text
     */
    public String percentageTipText() {
        return "The percentage to use for training (0-1).";
    }

    /**
     * Sets whether to preserve the order (disables randomization).
     *
     * @param z true to preserve the order
     */
    @Override // weka.classifiers.RandomSplitGenerator
    public void setPreserveOrder(boolean z) {
        this.m_PreserveOrder = z;
        reset();
    }

    /**
     * Returns whether the order is preserved.
     *
     * @return true if the order is preserved
     */
    @Override // weka.classifiers.RandomSplitGenerator
    public boolean getPreserveOrder() {
        return this.m_PreserveOrder;
    }

    /**
     * Returns the tip text for the preserveOrder property.
     *
     * @return the tip text
     */
    public String preserveOrderTipText() {
        return "If enabled, the order in the data is preserved in the split.";
    }

    /**
     * Sets the binning algorithm to use.
     *
     * @param binningAlgorithm the algorithm
     */
    public void setAlgorithm(BinningAlgorithm binningAlgorithm) {
        this.m_Algorithm = binningAlgorithm;
        reset();
    }

    /**
     * Returns the binning algorithm in use.
     *
     * @return the algorithm
     */
    public BinningAlgorithm getAlgorithm() {
        return this.m_Algorithm;
    }

    /**
     * Returns the tip text for the algorithm property.
     *
     * @return the tip text
     */
    public String algorithmTipText() {
        return "The binning algorithm to apply to the data.";
    }

    /**
     * Sets the data to split. The data is validated BEFORE it is stored, so a rejected
     * dataset does not remain attached to the generator.
     *
     * @param instances the data, may be null
     * @throws IllegalArgumentException if no class attribute is set or the class is not numeric
     */
    @Override // weka.classifiers.AbstractSplitGenerator, weka.classifiers.SplitGenerator
    public void setData(Instances instances) {
        if (instances != null) {
            if (instances.classIndex() == -1) {
                throw new IllegalArgumentException("No class attribute set!");
            }
            if (!instances.classAttribute().isNumeric()) {
                throw new IllegalArgumentException("Class attribute is not numeric!");
            }
        }
        super.setData(instances);
    }

    /**
     * Returns whether the data may be randomized.
     *
     * @return true unless the order is to be preserved
     */
    @Override // weka.classifiers.AbstractSplitGenerator
    protected boolean canRandomize() {
        return !this.m_PreserveOrder;
    }

    /**
     * Prepares the iterator for generating the single split.
     *
     * @throws IllegalStateException if no data has been set
     */
    @Override // weka.classifiers.AbstractSplitGenerator
    protected void doInitializeIterator() {
        if (this.m_Data == null) {
            throw new IllegalStateException("No data available!");
        }
        this.m_Generated = false;
    }

    /**
     * Returns whether another split is available (only one split is ever generated).
     *
     * @return true if the split hasn't been generated yet
     */
    @Override // weka.classifiers.AbstractSplitGenerator
    protected boolean checkNext() {
        return !this.m_Generated;
    }

    /**
     * Generates the train/test split.
     *
     * @return the container with train/test data, seed and the original indices
     * @throws IllegalStateException if binning or split generation fails (wraps the cause)
     */
    @Override // weka.classifiers.AbstractSplitGenerator
    protected WekaTrainTestSetContainer createNext() {
        this.m_Generated = true;
        try {
            List<Bin<Instance>> bins = binData();

            int maxTotal = (int) Math.round(this.m_Data.size() * this.m_Percentage);
            int[] maxPerBin = new int[bins.size()];
            for (int b = 0; b < bins.size(); b++) {
                int binSize = bins.get(b).size();
                // "size - 1": always leave at least one instance of each bin for the test set
                maxPerBin[b] = Math.min(binSize - 1, (int) Math.round(binSize * this.m_Percentage));
            }
            if (isLoggingEnabled()) {
                getLogger().info("max total: " + maxTotal);
                getLogger().info("max per bin: " + Utils.arrayToString(maxPerBin));
            }

            int[] usedPerBin = new int[bins.size()];
            List<Binnable<Instance>> train = drawTrain(bins, maxPerBin, maxTotal, usedPerBin);
            if (isLoggingEnabled()) {
                getLogger().info("train total: " + train.size());
                getLogger().info("train per bin: " + Utils.arrayToString(usedPerBin));
            }

            // everything not drawn for training goes into the test set
            List<Binnable<Instance>> test = new ArrayList<>();
            for (int b = 0; b < bins.size(); b++) {
                Bin<Instance> bin = bins.get(b);
                test.addAll(bin.get().subList(usedPerBin[b], bin.size()));
            }

            int[] trainIndices = Wrapping.getTmpIndices(train).toArray();
            int[] testIndices = Wrapping.getTmpIndices(test).toArray();
            if (this.m_PreserveOrder) {
                // round-robin drawing interleaves the bins; restore the original data order
                train = restoreOriginalOrder(train, trainIndices);
                test = restoreOriginalOrder(test, testIndices);
            }

            Instances trainData;
            Instances testData;
            if (this.m_UseViews) {
                trainData = new InstancesView(this.m_Data, trainIndices);
                testData = new InstancesView(this.m_Data, testIndices);
            } else {
                trainData = BinnableInstances.toInstances(train);
                testData = BinnableInstances.toInstances(test);
            }
            return new WekaTrainTestSetContainer(trainData, testData, Long.valueOf(this.m_Seed), null, null, trainIndices, testIndices);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to create binnable Instances!", e);
        }
    }

    /**
     * Wraps the data as binnable objects (tagged with their original index), optionally
     * randomizes them, bins them and post-processes the bins to a minimum size.
     */
    private List<Bin<Instance>> binData() throws Exception {
        List<Binnable<Instance>> binnable = BinnableInstances.toBinnableUsingClass(this.m_Data);
        Wrapping.addTmpIndex(binnable);
        if (canRandomize()) {
            Randomize.randomizeData(binnable, this.m_Seed);
        }
        List<Bin<Instance>> bins = this.m_Algorithm.generateBins(binnable);
        if (isLoggingEnabled()) {
            getLogger().info("Bins: " + Utils.arrayToString(Bins.binSizes(bins)));
        }
        MinBinSize minBinSize = new MinBinSize();
        minBinSize.setMinSize(MIN_BIN_SIZE);
        bins = minBinSize.postProcessBins(bins);
        if (isLoggingEnabled()) {
            getLogger().info("Bins after postprocessing: " + Utils.arrayToString(Bins.binSizes(bins)));
        }
        return bins;
    }

    /**
     * Draws training instances round-robin from the front of each bin until either the
     * overall total is reached or every bin has hit its cap. Updates usedPerBin in place.
     */
    private static List<Binnable<Instance>> drawTrain(List<Bin<Instance>> bins, int[] maxPerBin, int maxTotal, int[] usedPerBin) {
        List<Binnable<Instance>> result = new ArrayList<>();
        boolean added = true;
        while (added && result.size() < maxTotal) {
            added = false;
            for (int b = 0; b < bins.size() && result.size() < maxTotal; b++) {
                if (usedPerBin[b] < maxPerBin[b]) {
                    result.add(bins.get(b).get().get(usedPerBin[b]));
                    usedPerBin[b]++;
                    added = true;
                }
            }
        }
        return result;
    }

    /**
     * Reorders the subset by ascending original index and sorts the (parallel, unique)
     * indices array in place accordingly; runs in O(max index) via a slot table.
     */
    private static List<Binnable<Instance>> restoreOriginalOrder(List<Binnable<Instance>> subset, int[] indices) {
        int maxIndex = -1;
        for (int index : indices) {
            maxIndex = Math.max(maxIndex, index);
        }
        // slot[idx] = position in subset + 1; 0 marks "not part of this subset"
        int[] slot = new int[maxIndex + 1];
        for (int k = 0; k < indices.length; k++) {
            slot[indices[k]] = k + 1;
        }
        List<Binnable<Instance>> ordered = new ArrayList<>(subset.size());
        int n = 0;
        for (int idx = 0; idx < slot.length; idx++) {
            if (slot[idx] > 0) {
                ordered.add(subset.get(slot[idx] - 1));
                indices[n++] = idx;
            }
        }
        return ordered;
    }

    /**
     * Returns a short description of the setup.
     *
     * @return the description
     */
    @Override // weka.classifiers.AbstractSplitGenerator, weka.classifiers.SplitGenerator
    public String toString() {
        return super.toString() + ", percentage=" + this.m_Percentage;
    }
}
