package weka.classifiers;

import adams.core.ObjectCopyHelper;
import adams.core.Utils;
import adams.core.option.OptionUtils;
import adams.data.binning.Binnable;
import adams.data.binning.BinnableInstances;
import adams.data.binning.algorithm.BinningAlgorithm;
import adams.data.binning.algorithm.ManualBinning;
import adams.data.binning.operation.Bins;
import adams.data.statistics.StatUtils;
import adams.flow.container.WekaTrainTestSetContainer;
import adams.gui.tools.wekainvestigator.tab.PartialLeastSquaresTab;
import com.github.fracpete.javautils.Enumerate;
import com.github.fracpete.javautils.enumerate.Enumerated;
import java.util.ArrayList;
import java.util.List;
import weka.core.AttributeStats;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:weka/classifiers/BestBinnedNumericClassRandomSplitGenerator.class */
/**
 * Random train/test split generator for datasets with a numeric class that evaluates several
 * binning algorithms and returns the split produced by the best one.
 * <p>
 * For every configured algorithm a {@link BinnedNumericClassRandomSplitGenerator} generates one
 * split. The class-value distributions of its train and test sets are compared against the
 * distribution of the full dataset via the correlation coefficient. The split with the highest
 * sum of train and test correlations is returned. Only a single split is generated per iteration.
 */
public class BestBinnedNumericClassRandomSplitGenerator extends AbstractSplitGenerator implements RandomSplitGenerator {
    private static final long serialVersionUID = -3836027382933579890L;

    /** the fraction of the data to use for training (0-1). */
    protected double m_Percentage;

    /** whether to preserve the order of the data in the split. */
    protected boolean m_PreserveOrder;

    /** the binning algorithms to pick the best one from. */
    protected BinningAlgorithm[] m_Algorithms;

    /** the number of bins used when comparing class distributions. */
    protected int m_NumEvaluationBins;

    /** whether the single split has already been generated for the current iteration. */
    protected boolean m_Generated;

    /** the binning used for computing class distributions; set up in {@link #createNext()}. */
    protected ManualBinning m_Manual;

    /**
     * Returns a description of this split generator.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Picks the best binning algorithm from the provided ones.\nIn order to do this, the class distributions from generated train and test splits are compared against the overall class distribution. For comparison, the class values are binned using the specified number of bins, using a fixed max/min. How well the class distributions align is determined by computing the correlation coefficient (CC). The binning algorithm with the highest sum of CCs for train and test is then picked.";
    }

    /**
     * Adds the options of this generator to the option manager.
     */
    @Override // weka.classifiers.AbstractSplitGenerator
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("percentage", "percentage", Double.valueOf(0.66d), Double.valueOf(0.0d), Double.valueOf(1.0d));
        this.m_OptionManager.add("preserve-order", "preserveOrder", false);
        // NOTE(review): the command-line flag is taken from a GUI constant. It is presumably the
        // plain string "algorithm"; confirm the value and consider a literal to drop the GUI dependency.
        this.m_OptionManager.add(PartialLeastSquaresTab.KEY_ALGORITHM, "algorithms", new BinningAlgorithm[0]);
        this.m_OptionManager.add("num-evaluation-bins", "numEvaluationBins", 20, 1, (Number) null);
    }

    /**
     * Sets the fraction of the data to use for training. Values rejected by the option manager are ignored.
     *
     * @param d the fraction (0-1)
     */
    @Override // weka.classifiers.RandomSplitGenerator
    public void setPercentage(double d) {
        if (getOptionManager().isValid("percentage", Double.valueOf(d))) {
            this.m_Percentage = d;
            reset();
        }
    }

    /**
     * Returns the fraction of the data to use for training.
     *
     * @return the fraction (0-1)
     */
    @Override // weka.classifiers.RandomSplitGenerator
    public double getPercentage() {
        return this.m_Percentage;
    }

    /**
     * Returns the tip text for the percentage option.
     *
     * @return the tip text
     */
    public String percentageTipText() {
        return "The percentage to use for training (0-1).";
    }

    /**
     * Sets whether the order of the data is preserved in the split.
     *
     * @param z true to preserve the order
     */
    @Override // weka.classifiers.RandomSplitGenerator
    public void setPreserveOrder(boolean z) {
        this.m_PreserveOrder = z;
        reset();
    }

    /**
     * Returns whether the order of the data is preserved in the split.
     *
     * @return true if the order is preserved
     */
    @Override // weka.classifiers.RandomSplitGenerator
    public boolean getPreserveOrder() {
        return this.m_PreserveOrder;
    }

    /**
     * Returns the tip text for the preserve-order option.
     *
     * @return the tip text
     */
    public String preserveOrderTipText() {
        return "If enabled, the order in the data is preserved in the split.";
    }

    /**
     * Sets the binning algorithms to pick the best one from.
     *
     * @param binningAlgorithmArr the algorithms
     */
    public void setAlgorithms(BinningAlgorithm[] binningAlgorithmArr) {
        this.m_Algorithms = binningAlgorithmArr;
        reset();
    }

    /**
     * Returns the binning algorithms to pick the best one from.
     *
     * @return the algorithms
     */
    public BinningAlgorithm[] getAlgorithms() {
        return this.m_Algorithms;
    }

    /**
     * Returns the tip text for the algorithms option.
     *
     * @return the tip text
     */
    public String algorithmsTipText() {
        return "The binning algorithms to pick the best one from.";
    }

    /**
     * Sets the number of bins used when comparing class distributions.
     * Values rejected by the option manager are ignored.
     *
     * @param i the number of bins (at least 1)
     */
    public void setNumEvaluationBins(int i) {
        if (getOptionManager().isValid("numEvaluationBins", Integer.valueOf(i))) {
            this.m_NumEvaluationBins = i;
            reset();
        }
    }

    /**
     * Returns the number of bins used when comparing class distributions.
     *
     * @return the number of bins
     */
    public int getNumEvaluationBins() {
        return this.m_NumEvaluationBins;
    }

    /**
     * Returns the tip text for the number of evaluation bins option.
     *
     * @return the tip text
     */
    public String numEvaluationBinsTipText() {
        return "The number of bins for determining the class distribution during evaluation.";
    }

    /**
     * Randomization only makes sense when the order of the data does not have to be preserved.
     *
     * @return true if the data may be randomized
     */
    @Override // weka.classifiers.AbstractSplitGenerator
    protected boolean canRandomize() {
        return !this.m_PreserveOrder;
    }

    /**
     * Validates the setup and prepares the iterator for generating a single split.
     *
     * @throws IllegalStateException if there is no data, no binning algorithm, no class attribute,
     *                               or the class attribute is not numeric
     */
    @Override // weka.classifiers.AbstractSplitGenerator
    protected void doInitializeIterator() {
        if (this.m_Data == null) {
            throw new IllegalStateException("No data available!");
        }
        if (this.m_Algorithms == null || this.m_Algorithms.length == 0) {
            throw new IllegalStateException("No binning algorithms specified!");
        }
        // createNext() needs numeric statistics of the class attribute
        if (this.m_Data.classIndex() < 0) {
            throw new IllegalStateException("No class attribute set!");
        }
        if (!this.m_Data.classAttribute().isNumeric()) {
            throw new IllegalStateException("Class attribute is not numeric!");
        }
        this.m_Generated = false;
    }

    /**
     * Returns whether another split is available; only a single split is ever generated.
     *
     * @return true if the split has not been generated yet
     */
    @Override // weka.classifiers.AbstractSplitGenerator
    protected boolean checkNext() {
        return !this.m_Generated;
    }

    /**
     * Bins the given class values with {@link #m_Manual} and summarizes each bin via its mean.
     * The summary is then normalized to the range [0, 1].
     *
     * @param list the binnable class values
     * @param d    passed through to {@code Bins.summarizeBinnableValues}; callers supply the class minimum
     *             (presumably the value used for empty bins - confirm)
     * @return the normalized distribution, one value per bin
     */
    protected double[] calcDistribution(List<Binnable<Instance>> list, double d) {
        return StatUtils.normalizeRange(Bins.summarizeBinnableValues(this.m_Manual.generateBins(list), Bins.SummaryType.MEAN, d), 0.0d, 1.0d);
    }

    /**
     * Generates one split per binning algorithm and scores it. Returns the split whose train and
     * test class distributions correlate best with the overall class distribution. The returned
     * split is exactly the one that was evaluated.
     *
     * @return the best split
     * @throws IllegalStateException if generating a split or a distribution fails
     */
    @Override // weka.classifiers.AbstractSplitGenerator
    protected WekaTrainTestSetContainer createNext() {
        this.m_Generated = true;

        // fixed min/max, so that all distributions share the same bin boundaries
        AttributeStats stats = this.m_Data.attributeStats(this.m_Data.classIndex());
        double min = stats.numericStats.min;
        double max = stats.numericStats.max;
        this.m_Manual = new ManualBinning();
        this.m_Manual.setNumBins(this.m_NumEvaluationBins);
        this.m_Manual.setUseFixedMinMax(true);
        this.m_Manual.setManualMin(min);
        this.m_Manual.setManualMax(max);

        double[] totalDist = distributionOf(this.m_Data, min, "Failed to create binnable Instances!");
        if (isLoggingEnabled()) {
            getLogger().info("Total distribution: " + Utils.arrayToString(totalDist));
        }

        int numAlgorithms = this.m_Algorithms.length;
        double[] ccTrain = new double[numAlgorithms];
        double[] ccTest = new double[numAlgorithms];
        double[] ccSum = new double[numAlgorithms];
        int bestIndex = -1;
        WekaTrainTestSetContainer bestSplit = null;

        for (int i = 0; i < numAlgorithms; i++) {
            WekaTrainTestSetContainer split = generateSplit(i);

            double[] trainDist = distributionOf(
                (Instances) split.getValue(WekaTrainTestSetContainer.VALUE_TRAIN, Instances.class),
                min, "Failed to create binnable Instances (train #" + i + ")!");
            if (isLoggingEnabled()) {
                getLogger().info("train distribution #" + i + ": " + Utils.arrayToString(trainDist));
            }

            double[] testDist = distributionOf(
                (Instances) split.getValue(WekaTrainTestSetContainer.VALUE_TEST, Instances.class),
                min, "Failed to create binnable Instances (test #" + i + ")!");
            if (isLoggingEnabled()) {
                getLogger().info("test distribution #" + i + ": " + Utils.arrayToString(testDist));
            }

            ccTrain[i] = StatUtils.correlationCoefficient(totalDist, trainDist);
            ccTest[i] = StatUtils.correlationCoefficient(totalDist, testDist);
            ccSum[i] = ccTrain[i] + ccTest[i];

            // keep the evaluated split itself rather than regenerating it later;
            // on ties the first algorithm wins
            if (bestIndex < 0 || isBetter(ccSum[i], ccSum[bestIndex])) {
                bestIndex = i;
                bestSplit = split;
            }
        }

        if (isLoggingEnabled()) {
            getLogger().info("CC train: " + Utils.arrayToString(ccTrain));
            getLogger().info("CC test: " + Utils.arrayToString(ccTest));
            getLogger().info("CC sum: " + Utils.arrayToString(ccSum));
            getLogger().info("Best: #" + bestIndex + ", " + OptionUtils.getCommandLine(this.m_Algorithms[bestIndex]));
        }

        return bestSplit;
    }

    /**
     * Generates a single split using the binning algorithm at the given index.
     *
     * @param index the index of the algorithm in {@link #m_Algorithms}
     * @return the generated split
     * @throws IllegalStateException if the split cannot be generated
     */
    private WekaTrainTestSetContainer generateSplit(int index) {
        try {
            BinnedNumericClassRandomSplitGenerator generator = new BinnedNumericClassRandomSplitGenerator();
            // copy the algorithm so the configured instance is never modified by the sub-generator
            generator.setAlgorithm((BinningAlgorithm) ObjectCopyHelper.copyObject(this.m_Algorithms[index]));
            generator.setPercentage(this.m_Percentage);
            generator.setPreserveOrder(this.m_PreserveOrder);
            // NOTE(review): only algorithm, percentage, preserve-order and data are forwarded.
            // Confirm whether base-class settings (e.g. a seed) should be passed on as well.
            generator.setData(this.m_Data);
            // mo155next() is presumably the decompiler-renamed next() of the sub-generator
            return generator.mo155next();
        } catch (Exception e) {
            throw new IllegalStateException("Failed to generate split using binning algorithm #" + index + "!", e);
        }
    }

    /**
     * Computes the normalized class distribution of the given data via {@link #calcDistribution}.
     *
     * @param data     the data to compute the distribution for
     * @param min      the class minimum, passed on to {@link #calcDistribution}
     * @param errorMsg the message for the exception thrown on failure
     * @return the distribution
     * @throws IllegalStateException if the data cannot be turned into binnable values or binned
     */
    private double[] distributionOf(Instances data, double min, String errorMsg) {
        try {
            return calcDistribution(BinnableInstances.toBinnableUsingClass(data), min);
        } catch (Exception e) {
            throw new IllegalStateException(errorMsg, e);
        }
    }

    /**
     * Decides whether a candidate score beats the current best. A NaN score, e.g. from a
     * degenerate distribution, never beats a real score, while any real score beats NaN.
     *
     * @param candidate the candidate score
     * @param current   the current best score
     * @return true if the candidate is strictly better
     */
    private static boolean isBetter(double candidate, double current) {
        if (Double.isNaN(candidate)) {
            return false;
        }
        return Double.isNaN(current) || candidate > current;
    }
}
