package weka.classifiers.meta;

import adams.core.base.BaseKeyValuePair;
import adams.data.statistics.StatUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.rules.ZeroR;
import weka.core.AttributeStats;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.ModelOutputHandler;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.supervised.instance.Resample;

/* loaded from: input_file:weka/classifiers/meta/VotedImbalance.class */
/**
 * Meta-classifier that trains several copies of the base classifier, each on a
 * resampled version of the training data, and delegates predictions to the
 * resulting ensemble.
 *
 * <p>The ensemble is a {@code Vote} when more than one model was built, the single
 * model when exactly one was built, and a {@link ZeroR} trained on the full data
 * when none could be built. The number of models is either fixed
 * ({@code -num-balanced}) or derived from the minority class probability via
 * {@code -thresholds}.
 */
public class VotedImbalance extends RandomizableSingleClassifierEnhancer implements ModelOutputHandler {
    private static final long serialVersionUID = -7637300579884789439L;

    /** Copies of the base classifier while training; failed ones are nulled, reset to null after the build. */
    protected Classifier[] m_Classifiers;

    /** Number of models built in the current run (fixed number or taken from a matching threshold). */
    protected int m_ActualNumBalanced;

    /** Thread pool executing the build jobs; not serialized. */
    protected transient ThreadPoolExecutor m_ExecutorPool;

    /** Number of build jobs that finished successfully. */
    protected int m_Completed;

    /** Number of build jobs that failed. */
    protected int m_Failed;

    /** Training data without missing class values; released once all build jobs have reported back. */
    protected Instances m_Data;

    /** Empty copy (structure only) of the training data. */
    protected Instances m_Header;

    /** ZeroR model used when no base classifier could be built; only set during the build. */
    protected ZeroR m_BackupModel;

    /** The model used for making predictions. */
    protected Classifier m_Ensemble;

    /** Sample size (in percent of the training data) handed to the Resample filter. */
    protected double m_SamplePercentage;

    /** Number of threads used for building the models. */
    protected int m_NumExecutionSlots = 1;

    /** ID of the combination rule (a tag of {@code Vote.TAGS_RULES}); the option help names AVG as default. */
    protected int m_CombinationRule = 1;

    /** Number of resampled datasets/models to create when no threshold applies. */
    protected int m_NumBalanced = 1;

    /** Threshold pairs (probability=#models). */
    protected BaseKeyValuePair[] m_Thresholds = new BaseKeyValuePair[0];

    /** Bias towards a uniform class distribution (0-1). */
    protected double m_Bias = 0.0d;

    /** Whether resampling is done without replacement. */
    protected boolean m_NoReplacement = false;

    /** Whether the model output in {@link #toString()} is suppressed. */
    protected boolean m_SuppressModelOutput = false;

    /**
     * Returns a description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Generates an ensemble using the following approach:\n- do x times:\n  * create new dataset, resampled with specified bias\n  * build base classifier with it\nIf no classifier gets built at all, use ZeroR as backup model, built on the full dataset.\nAt prediction time, the Vote meta-classifier (using the pre-built classifiers) is used to determining the class probabilities or regression value.\nInstead of just using a fixed number of resampled models, you can also specify thresholds (= probability that the minority class does not meet) with associated number of resampled models to use.";
    }

    /**
     * Lists the options of this classifier followed by the options of the superclass.
     *
     * @return an enumeration of all available options
     */
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tNumber of execution slots.\n\t(default: 1 - i.e. no parallelism)", "num-slots", 1, "-num-slots <num>"));
        options.addElement(new Option("\tThe combination rule to use\n\t(default: AVG)", "combination-rule", 1, "-combination-rule " + Tag.toOptionList(Vote.TAGS_RULES)));
        // listed exactly once (the original registered this option twice)
        options.addElement(new Option("\tNumber of balanced datasets (= number of classifiers) to create.\n\t(default: 1)", "num-balanced", 1, "-num-balanced <num>"));
        options.addElement(new Option("\tThresholds for number of resampled models (probability=#models); blank-separated list.\n\t(default: none)", "thresholds", 1, "-thresholds <prob=# [prob=# [...]]>"));
        options.addElement(new Option("\tBias factor towards uniform class distribution.\n\t0 = distribution in input data -- 1 = uniform distribution.\n\t(default 0)", "B", 1, "-B <num>"));
        options.addElement(new Option("\tDisables replacement of instances\n\t(default: with replacement)", "no-replacement", 0, "-no-replacement"));
        options.addElement(new Option("\tSuppress model output\n\t(default: no)", "suppress-model-output", 0, "-suppress-model-output"));
        Enumeration<Option> parentOptions = super.listOptions();
        while (parentOptions.hasMoreElements()) {
            options.addElement(parentOptions.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the given options; options that are absent are reset to their defaults.
     * Remaining options are handed to the superclass.
     *
     * @param strArr the options to parse
     * @throws Exception if an option value cannot be parsed
     */
    public void setOptions(String[] strArr) throws Exception {
        String slots = Utils.getOption("num-slots", strArr);
        setNumExecutionSlots(slots.isEmpty() ? 1 : Integer.parseInt(slots));

        String rule = Utils.getOption("combination-rule", strArr);
        if (rule.isEmpty()) {
            setCombinationRule(new SelectedTag(1, Vote.TAGS_RULES));
        } else {
            setCombinationRule(new SelectedTag(rule, Vote.TAGS_RULES));
        }

        String numBalanced = Utils.getOption("num-balanced", strArr);
        setNumBalanced(numBalanced.isEmpty() ? 1 : Integer.parseInt(numBalanced));

        // an absent option yields an empty string, which clears the thresholds
        setThresholds(Utils.getOption("thresholds", strArr));

        String bias = Utils.getOption('B', strArr);
        setBias(bias.isEmpty() ? 0.0d : Double.parseDouble(bias));

        setNoReplacement(Utils.getFlag("no-replacement", strArr));
        setSuppressModelOutput(Utils.getFlag("suppress-model-output", strArr));
        super.setOptions(strArr);
    }

    /**
     * Returns the current option settings, including the ones of the superclass.
     *
     * @return the options as string array
     */
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-num-slots");
        options.add("" + getNumExecutionSlots());
        options.add("-combination-rule");
        options.add("" + getCombinationRule());
        options.add("-num-balanced");
        options.add("" + getNumBalanced());
        if (!getThresholds().isEmpty()) {
            options.add("-thresholds");
            options.add(getThresholds());
        }
        options.add("-B");
        options.add("" + getBias());
        if (getNoReplacement()) {
            options.add("-no-replacement");
        }
        if (getSuppressModelOutput()) {
            options.add("-suppress-model-output");
        }
        options.addAll(Arrays.asList(super.getOptions()));
        return options.toArray(new String[0]);
    }

    /**
     * Sets the base classifier. Classifiers that cannot handle a binary class are
     * rejected with a message on stderr and the current classifier is kept.
     *
     * @param classifier the base classifier to use
     */
    public void setClassifier(Classifier classifier) {
        if (classifier.getCapabilities().handles(Capabilities.Capability.BINARY_CLASS)) {
            super.setClassifier(classifier);
        } else {
            System.err.println("Classifier must at least handle binary class!");
        }
    }

    /**
     * Sets the number of threads used for building the models. Values below 1 are
     * rejected with a message on stderr.
     *
     * @param i the number of execution slots, at least 1
     */
    public void setNumExecutionSlots(int i) {
        if (i >= 1) {
            this.m_NumExecutionSlots = i;
        } else {
            System.err.println("Number of execution slots must be >= 1");
        }
    }

    /**
     * Returns the number of threads used for building the models.
     *
     * @return the number of execution slots
     */
    public int getNumExecutionSlots() {
        return this.m_NumExecutionSlots;
    }

    /**
     * Returns the tip text for the number of execution slots.
     *
     * @return the tip text
     */
    public String numExecutionSlotsTipText() {
        return "The number of execution slots (threads) to use for constructing the ensemble.";
    }

    /**
     * Sets the combination rule; tags that do not belong to {@code Vote.TAGS_RULES}
     * are ignored.
     *
     * @param selectedTag the rule to use
     */
    public void setCombinationRule(SelectedTag selectedTag) {
        if (selectedTag.getTags() == Vote.TAGS_RULES) {
            this.m_CombinationRule = selectedTag.getSelectedTag().getID();
        }
    }

    /**
     * Returns the combination rule.
     *
     * @return the rule as tag of {@code Vote.TAGS_RULES}
     */
    public SelectedTag getCombinationRule() {
        return new SelectedTag(this.m_CombinationRule, Vote.TAGS_RULES);
    }

    /**
     * Returns the tip text for the combination rule.
     *
     * @return the tip text
     */
    public String combinationRuleTipText() {
        return "The combination rule used.";
    }

    /**
     * Sets the number of resampled datasets (= models) to create. Values below 1 are
     * rejected with a message on stderr.
     *
     * @param i the number of datasets, at least 1
     */
    public void setNumBalanced(int i) {
        if (i >= 1) {
            this.m_NumBalanced = i;
        } else {
            System.err.println("Number of datasets must be >= 1, provided: " + i);
        }
    }

    /**
     * Returns the number of resampled datasets (= models) to create.
     *
     * @return the number of datasets
     */
    public int getNumBalanced() {
        return this.m_NumBalanced;
    }

    /**
     * Returns the tip text for the number of balanced datasets.
     *
     * @return the tip text
     */
    public String numBalancedTipText() {
        return "The number of balanced datasets to generate (= #classifiers).";
    }

    /**
     * Sets the thresholds from a blank-separated list of {@code probability=#models}
     * pairs. Entries that are malformed, whose key is not a probability in [0, 1] or
     * whose value is not an integer of at least 1 are reported on stderr and skipped.
     * If the list cannot be split at all, the current thresholds are kept.
     *
     * @param str the threshold list; {@code null} or blank clears the thresholds
     */
    public void setThresholds(String str) {
        if (str == null || str.trim().isEmpty()) {
            this.m_Thresholds = new BaseKeyValuePair[0];
            return;
        }
        try {
            ArrayList<BaseKeyValuePair> pairs = new ArrayList<BaseKeyValuePair>();
            String[] specs = Utils.splitOptions(str);
            for (int i = 0; i < specs.length; i++) {
                BaseKeyValuePair pair = new BaseKeyValuePair();
                if (!pair.isValid(specs[i])) {
                    // previously dropped without any feedback
                    System.err.println("Threshold #" + (i + 1) + " is not a valid key=value pair: " + specs[i]);
                    continue;
                }
                pair.setValue(specs[i]);
                String key = pair.getPairKey();
                String value = pair.getPairValue();
                if (!adams.core.Utils.isDouble(key) || Double.parseDouble(key) < 0.0d || Double.parseDouble(key) > 1.0d) {
                    System.err.println("Key #" + (i + 1) + " is not a valid probability (0-1): " + key);
                } else if (!adams.core.Utils.isInteger(value) || Integer.parseInt(value) < 1) {
                    System.err.println("Value #" + (i + 1) + " is not a valid model amount: " + value);
                } else {
                    pairs.add(pair);
                }
            }
            this.m_Thresholds = pairs.toArray(new BaseKeyValuePair[0]);
        } catch (Exception e) {
            System.err.println("Invalid threshold specs: " + str);
            e.printStackTrace();
        }
    }

    /**
     * Returns the thresholds as a single option string.
     *
     * @return the joined threshold pairs, empty if none are set
     */
    public String getThresholds() {
        String[] values = new String[this.m_Thresholds.length];
        for (int i = 0; i < this.m_Thresholds.length; i++) {
            values[i] = this.m_Thresholds[i].getValue();
        }
        return Utils.joinOptions(values);
    }

    /**
     * Returns the tip text for the thresholds.
     *
     * @return the tip text
     */
    public String thresholdsTipText() {
        return "The blank-separated list of probability thresholds for the minority class with their associated number of resampled models; e.g.: '0.5=1 0.3=3 0.1=5 0.05=10 0.01=25'.";
    }

    /**
     * Sets the bias towards a uniform class distribution. Values outside [0, 1] are
     * rejected with a message on stderr.
     *
     * @param d the bias, 0 &lt;= d &lt;= 1
     */
    public void setBias(double d) {
        if (d < 0.0d || d > 1.0d) {
            System.err.println("Bias must be 0 <= x <= 1, provided: " + d);
        } else {
            this.m_Bias = d;
        }
    }

    /**
     * Returns the bias towards a uniform class distribution.
     *
     * @return the bias
     */
    public double getBias() {
        return this.m_Bias;
    }

    /**
     * Returns the tip text for the bias.
     *
     * @return the tip text
     */
    public String biasTipText() {
        return "Whether to use bias towards a uniform class. A value of 0 leaves the class distribution as-is, a value of 1 ensures the class distribution is uniform in the output data.";
    }

    /**
     * Sets whether instances are sampled without replacement.
     *
     * @param z true to disable replacement
     */
    public void setNoReplacement(boolean z) {
        this.m_NoReplacement = z;
    }

    /**
     * Returns whether instances are sampled without replacement.
     *
     * @return true if replacement is disabled
     */
    public boolean getNoReplacement() {
        return this.m_NoReplacement;
    }

    /**
     * Returns the tip text for the no-replacement flag.
     *
     * @return the tip text
     */
    public String noReplacementTipText() {
        return "Disables the replacement of instances.";
    }

    /**
     * Sets whether the model output is suppressed in {@link #toString()}.
     *
     * @param z true to suppress the model output
     */
    @Override // weka.core.ModelOutputHandler
    public void setSuppressModelOutput(boolean z) {
        this.m_SuppressModelOutput = z;
    }

    /**
     * Returns whether the model output is suppressed in {@link #toString()}.
     *
     * @return true if the model output is suppressed
     */
    @Override // weka.core.ModelOutputHandler
    public boolean getSuppressModelOutput() {
        return this.m_SuppressModelOutput;
    }

    /**
     * Returns the tip text for the suppress-model-output flag.
     *
     * @return the tip text
     */
    @Override // weka.core.ModelOutputHandler
    public String suppressModelOutputTipText() {
        return "If enabled, suppresses any large model output.";
    }

    /**
     * Creates a fresh thread pool with as many threads as execution slots, shutting
     * down a previously created pool first.
     */
    protected void startExecutorPool() {
        if (this.m_ExecutorPool != null) {
            this.m_ExecutorPool.shutdownNow();
        }
        this.m_ExecutorPool = new ThreadPoolExecutor(this.m_NumExecutionSlots, this.m_NumExecutionSlots, 120L, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>());
    }

    /**
     * Creates the Resample filter used for generating a training set.
     *
     * @param i the index of the classifier the filter is for (not used here)
     * @param i2 the seed for the filter
     * @return the configured filter
     * @throws Exception declared for overriding implementations
     */
    protected Filter getFilter(int i, int i2) throws Exception {
        Resample resample = new Resample();
        resample.setBiasToUniformClass(this.m_Bias);
        resample.setNoReplacement(this.m_NoReplacement);
        resample.setRandomSeed(i2);
        resample.setSampleSizePercent(this.m_SamplePercentage);
        return resample;
    }

    /**
     * Generates the training set for a classifier by filtering the training data.
     *
     * @param i the index of the classifier
     * @param i2 the seed for the filter
     * @return the resampled training data
     * @throws Exception if filtering fails
     */
    protected Instances getTrainingSet(int i, int i2) throws Exception {
        Filter filter = getFilter(i, i2);
        filter.setInputFormat(this.m_Data);
        return Filter.useFilter(this.m_Data, filter);
    }

    /**
     * Records the outcome of a build job. A failed classifier is removed from the
     * array so that it cannot end up in the ensemble. Once all jobs have reported
     * back, the pool is shut down, the waiting build thread is woken up and the
     * training data is released.
     *
     * @param i the index of the classifier
     * @param z true if the classifier was built successfully
     */
    protected synchronized void completedClassifier(int i, boolean z) {
        if (z) {
            this.m_Completed++;
        } else {
            this.m_Failed++;
            // makeCopies never yields null entries, so without this the failed model
            // would be added to the ensemble and the backup model could never be used
            this.m_Classifiers[i] = null;
            if (this.m_Debug) {
                System.err.println("Building of classifier " + i + " failed!");
            }
        }
        if (this.m_Completed + this.m_Failed == this.m_Classifiers.length) {
            if (this.m_Failed > 0 && this.m_Debug) {
                System.err.println("Problem building classifiers - some iterations failed.");
            }
            this.m_ExecutorPool.shutdown();
            notifyAll();
            this.m_Data = null;
        }
    }

    /**
     * Submits one build job per classifier to the thread pool and waits until all of
     * them have reported back via {@link #completedClassifier(int, boolean)}.
     *
     * @throws Exception if the waiting thread gets interrupted; the pool is stopped
     *         and the interrupt flag restored in that case
     */
    protected synchronized void buildClassifiers() throws Exception {
        Random random = new Random(this.m_Seed);
        for (int i = 0; i < this.m_Classifiers.length; i++) {
            final int index = i;
            final int seed = random.nextInt();
            if (getDebug()) {
                System.out.print("Training classifier (" + (i + 1) + ")");
            }
            this.m_ExecutorPool.execute(() -> {
                try {
                    this.m_Classifiers[index].buildClassifier(getTrainingSet(index, seed));
                    completedClassifier(index, true);
                } catch (Exception e) {
                    System.err.println("Classifier #" + (index + 1) + " failed with:");
                    e.printStackTrace();
                    completedClassifier(index, false);
                }
            });
        }
        try {
            // loop guards against spurious wakeups; wait() releases the monitor so the
            // workers can enter completedClassifier
            while (this.m_Completed + this.m_Failed < this.m_Classifiers.length) {
                wait();
            }
        } catch (InterruptedException e) {
            // do not assemble an ensemble from models that are still being trained
            this.m_ExecutorPool.shutdownNow();
            Thread.currentThread().interrupt();
            throw e;
        }
    }

    /**
     * Assembles the model used for predictions from the successfully built
     * classifiers.
     *
     * @return a Vote over the built classifiers if there is more than one, the single
     *         classifier if there is exactly one, otherwise the backup model
     */
    protected Classifier constructEnsemble() {
        ArrayList<Classifier> built = new ArrayList<Classifier>();
        for (Classifier classifier : this.m_Classifiers) {
            if (classifier != null) {
                built.add(classifier);
            }
        }
        if (built.size() > 1) {
            Vote vote = new Vote();
            vote.setCombinationRule(getCombinationRule());
            vote.setClassifiers(built.toArray(new Classifier[0]));
            return vote;
        }
        if (built.size() == 1) {
            return built.get(0);
        }
        return this.m_BackupModel;
    }

    /**
     * Returns the capabilities of this classifier: those of the superclass, with the
     * class capabilities restricted to nominal classes and missing class values.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAllClasses();
        for (Capabilities.Capability capability : Capabilities.Capability.values()) {
            capabilities.enableDependency(capability);
        }
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.setOwner(this);
        return capabilities;
    }

    /**
     * Builds the ensemble: determines the number of models (fixed or via thresholds),
     * computes the sample percentage, trains the models on resampled data and
     * assembles the final model.
     *
     * @param instances the training data
     * @throws Exception if the data cannot be handled, no base classifier is set or
     *         building fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        if (this.m_Classifier == null) {
            throw new Exception("A base classifier has not been specified!");
        }
        this.m_Data = new Instances(instances);
        this.m_Data.deleteWithMissingClass();
        this.m_Header = new Instances(this.m_Data, 0);

        // class statistics of the data the models are actually trained on
        AttributeStats classStats = this.m_Data.attributeStats(this.m_Data.classIndex());

        this.m_ActualNumBalanced = this.m_NumBalanced;
        if (this.m_Thresholds.length > 0) {
            double minorityProb = 1.0d;
            double total = StatUtils.sum(classStats.nominalCounts);
            for (int i = 0; i < classStats.nominalCounts.length; i++) {
                if (classStats.nominalCounts[i] != 0) {
                    minorityProb = Math.min(classStats.nominalCounts[i] / total, minorityProb);
                }
            }
            if (getDebug()) {
                System.out.println("Minority class probability: " + minorityProb);
            }
            // pick the smallest threshold that the minority class probability stays below
            double bestThreshold = 1.0d;
            for (int i = 0; i < this.m_Thresholds.length; i++) {
                double threshold = Double.parseDouble(this.m_Thresholds[i].getPairKey());
                if (threshold > minorityProb && threshold < bestThreshold) {
                    bestThreshold = threshold;
                    this.m_ActualNumBalanced = Integer.parseInt(this.m_Thresholds[i].getPairValue());
                }
            }
            if (getDebug()) {
                System.out.println("Actual # of resampled models: " + this.m_ActualNumBalanced);
            }
        }
        if (instances.numInstances() < this.m_ActualNumBalanced) {
            System.err.println("WARNING: generating more balanced datasets than rows in input dataset (" + this.m_ActualNumBalanced + " > " + instances.numInstances() + ")");
        }
        this.m_Classifiers = AbstractClassifier.makeCopies(this.m_Classifier, this.m_ActualNumBalanced);
        startExecutorPool();
        this.m_Completed = 0;
        this.m_Failed = 0;
        // The percentage is applied by the Resample filter to m_Data, so it has to be
        // relative to the number of rows in m_Data (rows with missing class removed),
        // not to the unfiltered input.
        // NOTE(review): the minimum is taken over all class labels; presumably a label
        // without any instances yields 0 here - confirm whether that is intended.
        this.m_SamplePercentage = (100.0d / this.m_Data.numInstances()) * StatUtils.min(classStats.nominalCounts) * this.m_Data.classAttribute().numValues();
        if (getDebug()) {
            System.out.println("Sample percentage: " + this.m_SamplePercentage);
        }
        this.m_BackupModel = new ZeroR();
        this.m_BackupModel.buildClassifier(this.m_Data);
        buildClassifiers();
        this.m_Ensemble = constructEnsemble();
        this.m_Classifiers = null;
        this.m_BackupModel = null;
    }

    /**
     * Classifies the instance using the ensemble.
     *
     * @param instance the instance to classify
     * @return the prediction of the ensemble
     * @throws Exception if the ensemble fails to make a prediction
     */
    public double classifyInstance(Instance instance) throws Exception {
        return this.m_Ensemble.classifyInstance(instance);
    }

    /**
     * Returns the class distribution for the instance using the ensemble.
     *
     * @param instance the instance to get the distribution for
     * @return the class distribution of the ensemble
     * @throws Exception if the ensemble fails to make a prediction
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        return this.m_Ensemble.distributionForInstance(instance);
    }

    /**
     * Returns a description of the model, or a placeholder if no model has been built
     * or the output is suppressed.
     *
     * @return the model description
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (this.m_Ensemble == null) {
            sb.append("No model built yet!");
        } else if (this.m_SuppressModelOutput) {
            sb.append("Model suppressed\n");
        } else {
            sb.append("--> Model\n");
            sb.append(this.m_Ensemble.toString());
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 15194 $");
    }

    /**
     * Runs the classifier with the given command-line options.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new VotedImbalance(), strArr);
    }
}
