/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers;

import adams.core.ObjectCopyHelper;
import adams.core.Utils;
import adams.core.option.OptionUtils;
import adams.data.binning.Binnable;
import adams.data.binning.BinnableInstances;
import adams.data.binning.algorithm.BinningAlgorithm;
import adams.data.binning.algorithm.ManualBinning;
import adams.data.binning.operation.Bins;
import adams.data.statistics.StatUtils;
import adams.flow.container.WekaTrainTestSetContainer;
import com.github.fracpete.javautils.Enumerate;
import com.github.fracpete.javautils.enumerate.Enumerated;
import java.util.ArrayList;
import java.util.List;
import weka.classifiers.AbstractSplitGenerator;
import weka.classifiers.BinnedNumericClassRandomSplitGenerator;
import weka.classifiers.RandomSplitGenerator;
import weka.core.AttributeStats;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Random split generator that tries every configured binning algorithm and
 * returns the train/test split produced by the best one.
 *
 * <p>For each algorithm a {@link BinnedNumericClassRandomSplitGenerator} is
 * run on the data. The class values of the full data, the train set and the
 * test set are binned into {@link #getNumEvaluationBins()} bins between the
 * class attribute's min and max; the per-bin means are normalized to 0-1 and
 * compared against the overall distribution via the correlation coefficient.
 * The algorithm with the highest sum of train and test coefficients wins.
 *
 * <p>The generator yields exactly one split per initialization.
 */
public class BestBinnedNumericClassRandomSplitGenerator
extends AbstractSplitGenerator
implements RandomSplitGenerator {
    private static final long serialVersionUID = -3836027382933579890L;

    /** The percentage to use for training (0-1). */
    protected double m_Percentage;

    /** Whether the order in the data is preserved in the split. */
    protected boolean m_PreserveOrder;

    /** The candidate binning algorithms to pick the best one from. */
    protected BinningAlgorithm[] m_Algorithms;

    /** The number of bins used for determining the class distribution during evaluation. */
    protected int m_NumEvaluationBins;

    /** Whether the single split has already been handed out. */
    protected boolean m_Generated;

    /** The fixed min/max binning used for evaluation; configured in {@link #createNext()}. */
    protected ManualBinning m_Manual;

    /**
     * Returns a string describing the object.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Picks the best binning algorithm from the provided ones.\nIn order to do this, the class distributions from generated train and test splits are compared against the overall class distribution. For comparison, the class values are binned using the specified number of bins, using a fixed max/min. How well the class distributions align is determined by computing the correlation coefficient (CC). The binning algorithm with the highest sum of CCs for train and test is then picked.";
    }

    /**
     * Adds the options of this generator to the option manager.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("percentage", "percentage", (Object)0.66, (Number)0.0, (Number)1.0);
        this.m_OptionManager.add("preserve-order", "preserveOrder", (Object)false);
        this.m_OptionManager.add("algorithm", "algorithms", (Object)new BinningAlgorithm[0]);
        this.m_OptionManager.add("num-evaluation-bins", "numEvaluationBins", (Object)20, (Number)1, null);
    }

    /**
     * Sets the training percentage; values rejected by the option manager are ignored.
     *
     * @param value the percentage (0-1)
     */
    @Override
    public void setPercentage(double value) {
        if (this.getOptionManager().isValid("percentage", (Number)value)) {
            this.m_Percentage = value;
            this.reset();
        }
    }

    /**
     * Returns the training percentage.
     *
     * @return the percentage (0-1)
     */
    @Override
    public double getPercentage() {
        return this.m_Percentage;
    }

    /**
     * Returns the tip text for the percentage option.
     *
     * @return the tip text
     */
    public String percentageTipText() {
        return "The percentage to use for training (0-1).";
    }

    /**
     * Sets whether to preserve the order of the data in the split.
     *
     * @param value true to preserve the order
     */
    @Override
    public void setPreserveOrder(boolean value) {
        this.m_PreserveOrder = value;
        this.reset();
    }

    /**
     * Returns whether the order of the data is preserved in the split.
     *
     * @return true if the order is preserved
     */
    @Override
    public boolean getPreserveOrder() {
        return this.m_PreserveOrder;
    }

    /**
     * Returns the tip text for the preserve-order option.
     *
     * @return the tip text
     */
    public String preserveOrderTipText() {
        return "If enabled, the order in the data is preserved in the split.";
    }

    /**
     * Sets the binning algorithms to choose from.
     *
     * @param value the candidate algorithms
     */
    public void setAlgorithms(BinningAlgorithm[] value) {
        this.m_Algorithms = value;
        this.reset();
    }

    /**
     * Returns the binning algorithms to choose from.
     *
     * @return the candidate algorithms
     */
    public BinningAlgorithm[] getAlgorithms() {
        return this.m_Algorithms;
    }

    /**
     * Returns the tip text for the algorithms option.
     *
     * @return the tip text
     */
    public String algorithmsTipText() {
        return "The binning algorithms to pick the best one from.";
    }

    /**
     * Sets the number of evaluation bins; values rejected by the option manager are ignored.
     *
     * @param value the number of bins
     */
    public void setNumEvaluationBins(int value) {
        if (this.getOptionManager().isValid("numEvaluationBins", (Number)value)) {
            this.m_NumEvaluationBins = value;
            this.reset();
        }
    }

    /**
     * Returns the number of evaluation bins.
     *
     * @return the number of bins
     */
    public int getNumEvaluationBins() {
        return this.m_NumEvaluationBins;
    }

    /**
     * Returns the tip text for the evaluation bins option.
     *
     * @return the tip text
     */
    public String numEvaluationBinsTipText() {
        return "The number of bins for determining the class distribution during evaluation.";
    }

    /**
     * Randomization is only possible when the order is not preserved.
     *
     * @return true if the data may be randomized
     */
    @Override
    protected boolean canRandomize() {
        return !this.m_PreserveOrder;
    }

    /**
     * Validates the setup and re-arms the generator so that one split can be produced.
     *
     * @throws IllegalStateException if no data or no binning algorithms are available
     */
    @Override
    protected void doInitializeIterator() {
        if (this.m_Data == null) {
            throw new IllegalStateException("No data available!");
        }
        // a null array (e.g. via setAlgorithms(null)) is reported like an empty one
        if (this.m_Algorithms == null || this.m_Algorithms.length == 0) {
            throw new IllegalStateException("No binning algorithms specified!");
        }
        this.m_Generated = false;
    }

    /**
     * Checks whether the single split is still outstanding.
     *
     * @return true if the split has not been generated yet
     */
    @Override
    protected boolean checkNext() {
        return !this.m_Generated;
    }

    /**
     * Bins the items with the evaluation binning, takes the mean of each bin
     * and normalizes the resulting values to the range 0-1.
     *
     * @param binnable the items to bin
     * @param emptyBinValue the value to use for bins without items
     * @return the normalized per-bin summary
     * @throws IllegalStateException if the evaluation binning has not been set up yet
     */
    protected double[] calcDistribution(List<Binnable<Instance>> binnable, double emptyBinValue) {
        if (this.m_Manual == null) {
            throw new IllegalStateException("Evaluation binning not initialized!");
        }
        // m_Manual is a raw type, hence the raw list and the casts below
        List bins = this.m_Manual.generateBins(binnable);
        double[] result = Bins.summarizeBinnableValues((List)bins, (Bins.SummaryType)Bins.SummaryType.MEAN, (double)emptyBinValue);
        return StatUtils.normalizeRange(result, 0.0, 1.0);
    }

    /**
     * Evaluates all binning algorithms and returns the split generated with the best one.
     *
     * @return the train/test split of the best binning algorithm
     * @throws IllegalStateException if the class attribute is unset or not numeric,
     *         or if a distribution cannot be computed
     */
    @Override
    protected WekaTrainTestSetContainer createNext() {
        this.m_Generated = true;

        int classIndex = this.m_Data.classIndex();
        if (classIndex < 0) {
            throw new IllegalStateException("No class attribute set!");
        }
        AttributeStats stats = this.m_Data.attributeStats(classIndex);
        if (stats.numericStats == null) {
            throw new IllegalStateException("Class attribute is not numeric!");
        }
        double min = stats.numericStats.min;
        double max = stats.numericStats.max;

        // fixed min/max so that overall, train and test use identical bin boundaries
        this.m_Manual = new ManualBinning();
        this.m_Manual.setNumBins(this.m_NumEvaluationBins);
        this.m_Manual.setUseFixedMinMax(true);
        this.m_Manual.setManualMin(min);
        this.m_Manual.setManualMax(max);

        double[] distOverall = this.binnedDistribution(
            this.m_Data, min, "Total distribution: ", "Failed to create binnable Instances!");

        // one candidate split per algorithm, compared against the overall distribution
        int numAlgorithms = this.m_Algorithms.length;
        double[] ccTrain = new double[numAlgorithms];
        double[] ccTest = new double[numAlgorithms];
        double[] ccSum = new double[numAlgorithms];
        for (int i = 0; i < numAlgorithms; ++i) {
            WekaTrainTestSetContainer cont = this.newSplitGenerator(this.m_Algorithms[i]).next();
            double[] distTrain = this.binnedDistribution(
                (Instances)cont.getValue("Train", Instances.class), min,
                "train distribution #" + i + ": ",
                "Failed to create binnable Instances (train #" + i + ")!");
            double[] distTest = this.binnedDistribution(
                (Instances)cont.getValue("Test", Instances.class), min,
                "test distribution #" + i + ": ",
                "Failed to create binnable Instances (test #" + i + ")!");
            ccTrain[i] = StatUtils.correlationCoefficient(distOverall, distTrain);
            ccTest[i] = StatUtils.correlationCoefficient(distOverall, distTest);
            ccSum[i] = ccTrain[i] + ccTest[i];
        }
        if (this.isLoggingEnabled()) {
            this.getLogger().info("CC train: " + Utils.arrayToString(ccTrain));
            this.getLogger().info("CC test: " + Utils.arrayToString(ccTest));
            this.getLogger().info("CC sum: " + Utils.arrayToString(ccSum));
        }

        int best = StatUtils.maxIndex(ccSum);
        if (this.isLoggingEnabled()) {
            this.getLogger().info("Best: #" + best + ", " + OptionUtils.getCommandLine(this.m_Algorithms[best]));
        }

        // the winning split is generated afresh with the same configuration
        return this.newSplitGenerator(this.m_Algorithms[best]).next();
    }

    /**
     * Creates a split generator for the data that uses a copy of the given binning algorithm.
     *
     * @param algorithm the binning algorithm to copy into the generator
     * @return the configured generator
     */
    private BinnedNumericClassRandomSplitGenerator newSplitGenerator(BinningAlgorithm algorithm) {
        BinnedNumericClassRandomSplitGenerator generator = new BinnedNumericClassRandomSplitGenerator();
        generator.setAlgorithm((BinningAlgorithm)ObjectCopyHelper.copyObject(algorithm));
        generator.setPercentage(this.m_Percentage);
        generator.setPreserveOrder(this.m_PreserveOrder);
        generator.setData(this.m_Data);
        return generator;
    }

    /**
     * Computes the normalized class distribution of a dataset and logs it if logging is enabled.
     *
     * @param data the dataset to bin on its class values
     * @param emptyBinValue the value to use for bins without items
     * @param logPrefix the text preceding the distribution in the log message
     * @param failureMessage the message of the exception thrown on failure
     * @return the normalized per-bin summary
     * @throws IllegalStateException if binning or summarizing fails
     */
    private double[] binnedDistribution(Instances data, double emptyBinValue, String logPrefix, String failureMessage) {
        try {
            List<Binnable<Instance>> binnable = BinnableInstances.toBinnableUsingClass(data);
            double[] dist = this.calcDistribution(binnable, emptyBinValue);
            if (this.isLoggingEnabled()) {
                this.getLogger().info(logPrefix + Utils.arrayToString(dist));
            }
            return dist;
        }
        catch (Exception e) {
            throw new IllegalStateException(failureMessage, e);
        }
    }
}

