/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import adams.core.Utils;
import adams.core.base.BaseKeyValuePair;
import adams.data.statistics.StatUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.meta.Vote;
import weka.classifiers.rules.ZeroR;
import weka.core.AttributeStats;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.ModelOutputHandler;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.filters.Filter;
import weka.filters.supervised.instance.Resample;

public class VotedImbalance
extends RandomizableSingleClassifierEnhancer
implements ModelOutputHandler {
    private static final long serialVersionUID = -7637300579884789439L;
    protected Classifier[] m_Classifiers;
    protected int m_NumExecutionSlots = 1;
    protected int m_CombinationRule = 1;
    protected int m_NumBalanced = 1;
    protected BaseKeyValuePair[] m_Thresholds = new BaseKeyValuePair[0];
    protected int m_ActualNumBalanced;
    protected double m_Bias = 0.0;
    protected boolean m_NoReplacement = false;
    protected transient ThreadPoolExecutor m_ExecutorPool;
    protected int m_Completed;
    protected int m_Failed;
    protected Instances m_Data;
    protected Instances m_Header;
    protected ZeroR m_BackupModel;
    protected Classifier m_Ensemble;
    protected double m_SamplePercentage;
    protected boolean m_SuppressModelOutput = false;

    public String globalInfo() {
        return "Generates an ensemble using the following approach:\n- do x times:\n  * create new dataset, resampled with specified bias\n  * build base classifier with it\nIf no classifier gets built at all, use ZeroR as backup model, built on the full dataset.\nAt prediction time, the Vote meta-classifier (using the pre-built classifiers) is used to determining the class probabilities or regression value.\nInstead of just using a fixed number of resampled models, you can also specify thresholds (= probability that the minority class does not meet) with associated number of resampled models to use.";
    }

    public Enumeration listOptions() {
        Vector<Object> result = new Vector<Object>();
        result.addElement(new Option("\tNumber of execution slots.\n\t(default: 1 - i.e. no parallelism)", "num-slots", 1, "-num-slots <num>"));
        result.addElement(new Option("\tThe combination rule to use\n\t(default: AVG)", "combination-rule", 1, "-combination-rule " + Tag.toOptionList((Tag[])Vote.TAGS_RULES)));
        result.addElement(new Option("\tNumber of balanced datasets (= number of classifiers) to create.\n\t(default: 1)", "num-balanced", 1, "-num-balanced <num>"));
        result.addElement(new Option("\tThresholds for number of resampled models (probability=#models); blank-separated list.\n\t(default: none)", "thresholds", 1, "-thresholds <prob=# [prob=# [...]]>"));
        result.addElement(new Option("\tNumber of balanced datasets (= number of classifiers) to create.\n\t(default: 1)", "num-balanced", 1, "-num-balanced <num>"));
        result.addElement(new Option("\tBias factor towards uniform class distribution.\n\t0 = distribution in input data -- 1 = uniform distribution.\n\t(default 0)", "B", 1, "-B <num>"));
        result.addElement(new Option("\tDisables replacement of instances\n\t(default: with replacement)", "no-replacement", 0, "-no-replacement"));
        result.addElement(new Option("\tSuppress model output\n\t(default: no)", "suppress-model-output", 0, "-suppress-model-output"));
        Enumeration enm = super.listOptions();
        while (enm.hasMoreElements()) {
            result.addElement(enm.nextElement());
        }
        return result.elements();
    }

    public void setOptions(String[] options) throws Exception {
        String tmpStr = weka.core.Utils.getOption((String)"num-slots", (String[])options);
        if (!tmpStr.isEmpty()) {
            this.setNumExecutionSlots(Integer.parseInt(tmpStr));
        } else {
            this.setNumExecutionSlots(1);
        }
        tmpStr = weka.core.Utils.getOption((String)"combination-rule", (String[])options);
        if (!tmpStr.isEmpty()) {
            this.setCombinationRule(new SelectedTag(tmpStr, Vote.TAGS_RULES));
        } else {
            this.setCombinationRule(new SelectedTag(1, Vote.TAGS_RULES));
        }
        tmpStr = weka.core.Utils.getOption((String)"num-balanced", (String[])options);
        if (tmpStr.length() != 0) {
            this.setNumBalanced(Integer.parseInt(tmpStr));
        } else {
            this.setNumBalanced(1);
        }
        tmpStr = weka.core.Utils.getOption((String)"thresholds", (String[])options);
        if (!tmpStr.isEmpty()) {
            this.setThresholds(tmpStr);
        } else {
            this.setThresholds("");
        }
        tmpStr = weka.core.Utils.getOption((char)'B', (String[])options);
        if (!tmpStr.isEmpty()) {
            this.setBias(Double.parseDouble(tmpStr));
        } else {
            this.setBias(0.0);
        }
        this.setNoReplacement(weka.core.Utils.getFlag((String)"no-replacement", (String[])options));
        this.setSuppressModelOutput(weka.core.Utils.getFlag((String)"suppress-model-output", (String[])options));
        super.setOptions(options);
    }

    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        result.add("-num-slots");
        result.add("" + this.getNumExecutionSlots());
        result.add("-combination-rule");
        result.add("" + this.getCombinationRule());
        result.add("-num-balanced");
        result.add("" + this.getNumBalanced());
        if (!this.getThresholds().isEmpty()) {
            result.add("-thresholds");
            result.add(this.getThresholds());
        }
        result.add("-B");
        result.add("" + this.getBias());
        if (this.getNoReplacement()) {
            result.add("-no-replacement");
        }
        if (this.getSuppressModelOutput()) {
            result.add("-suppress-model-output");
        }
        result.addAll(Arrays.asList(super.getOptions()));
        return result.toArray(new String[result.size()]);
    }

    public void setClassifier(Classifier newClassifier) {
        Capabilities cap = newClassifier.getCapabilities();
        if (cap.handles(Capabilities.Capability.BINARY_CLASS)) {
            super.setClassifier(newClassifier);
        } else {
            System.err.println("Classifier must at least handle binary class!");
        }
    }

    public void setNumExecutionSlots(int value) {
        if (value >= 1) {
            this.m_NumExecutionSlots = value;
        } else {
            System.err.println("Number of execution slots must be >= 1");
        }
    }

    public int getNumExecutionSlots() {
        return this.m_NumExecutionSlots;
    }

    public String numExecutionSlotsTipText() {
        return "The number of execution slots (threads) to use for constructing the ensemble.";
    }

    public void setCombinationRule(SelectedTag value) {
        if (value.getTags() == Vote.TAGS_RULES) {
            this.m_CombinationRule = value.getSelectedTag().getID();
        }
    }

    public SelectedTag getCombinationRule() {
        return new SelectedTag(this.m_CombinationRule, Vote.TAGS_RULES);
    }

    public String combinationRuleTipText() {
        return "The combination rule used.";
    }

    public void setNumBalanced(int value) {
        if (value >= 1) {
            this.m_NumBalanced = value;
        } else {
            System.err.println("Number of datasets must be >= 1, provided: " + value);
        }
    }

    public int getNumBalanced() {
        return this.m_NumBalanced;
    }

    public String numBalancedTipText() {
        return "The number of balanced datasets to generate (= #classifiers).";
    }

    public void setThresholds(String value) {
        if (value.trim().isEmpty()) {
            this.m_Thresholds = new BaseKeyValuePair[0];
            return;
        }
        try {
            ArrayList<BaseKeyValuePair> pairs = new ArrayList<BaseKeyValuePair>();
            String[] parts = weka.core.Utils.splitOptions((String)value);
            for (int i = 0; i < parts.length; ++i) {
                BaseKeyValuePair pair = new BaseKeyValuePair();
                if (!pair.isValid(parts[i])) continue;
                pair.setValue(parts[i]);
                if (!Utils.isDouble((String)pair.getPairKey()) || Double.parseDouble(pair.getPairKey()) < 0.0 || Double.parseDouble(pair.getPairKey()) > 1.0) {
                    System.err.println("Key #" + (i + 1) + " is not a valid probability (0-1): " + pair.getPairKey());
                    continue;
                }
                if (!Utils.isInteger((String)pair.getPairValue()) || Integer.parseInt(pair.getPairValue()) < 1) {
                    System.err.println("Value #" + (i + 1) + " is not a valid model amount: " + pair.getPairValue());
                    continue;
                }
                pairs.add(pair);
            }
            this.m_Thresholds = pairs.toArray(new BaseKeyValuePair[pairs.size()]);
        }
        catch (Exception e) {
            System.err.println("Invalid threshold specs: " + value);
        }
    }

    public String getThresholds() {
        String[] result = new String[this.m_Thresholds.length];
        for (int i = 0; i < this.m_Thresholds.length; ++i) {
            result[i] = this.m_Thresholds[i].getValue();
        }
        return weka.core.Utils.joinOptions((String[])result);
    }

    public String thresholdsTipText() {
        return "The blank-separated list of probability thresholds for the minority class with their associated number of resampled models; e.g.: '0.5=1 0.3=3 0.1=5 0.05=10 0.01=25'.";
    }

    public void setBias(double value) {
        if (value >= 0.0 && value <= 1.0) {
            this.m_Bias = value;
        } else {
            System.err.println("Bias must be 0 <= x <= 1, provided: " + value);
        }
    }

    public double getBias() {
        return this.m_Bias;
    }

    public String biasTipText() {
        return "Whether to use bias towards a uniform class. A value of 0 leaves the class distribution as-is, a value of 1 ensures the class distribution is uniform in the output data.";
    }

    public void setNoReplacement(boolean value) {
        this.m_NoReplacement = value;
    }

    public boolean getNoReplacement() {
        return this.m_NoReplacement;
    }

    public String noReplacementTipText() {
        return "Disables the replacement of instances.";
    }

    @Override
    public void setSuppressModelOutput(boolean value) {
        this.m_SuppressModelOutput = value;
    }

    @Override
    public boolean getSuppressModelOutput() {
        return this.m_SuppressModelOutput;
    }

    @Override
    public String suppressModelOutputTipText() {
        return "If enabled, suppresses any large model output.";
    }

    protected void startExecutorPool() {
        if (this.m_ExecutorPool != null) {
            this.m_ExecutorPool.shutdownNow();
        }
        this.m_ExecutorPool = new ThreadPoolExecutor(this.m_NumExecutionSlots, this.m_NumExecutionSlots, 120L, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>());
    }

    private synchronized void block(boolean wait) {
        if (wait) {
            try {
                this.wait();
            }
            catch (InterruptedException interruptedException) {}
        } else {
            this.notifyAll();
        }
    }

    protected Filter getFilter(int index, int seed) throws Exception {
        Resample resample = new Resample();
        resample.setBiasToUniformClass(this.m_Bias);
        resample.setNoReplacement(this.m_NoReplacement);
        resample.setRandomSeed(seed);
        resample.setSampleSizePercent(this.m_SamplePercentage);
        return resample;
    }

    protected Instances getTrainingSet(int index, int seed) throws Exception {
        Filter filter = this.getFilter(index, seed);
        filter.setInputFormat(this.m_Data);
        Instances result = Filter.useFilter((Instances)this.m_Data, (Filter)filter);
        return result;
    }

    protected synchronized void completedClassifier(int index, boolean success) {
        if (!success) {
            ++this.m_Failed;
            if (this.m_Debug) {
                System.err.println("Building of classifier " + index + " failed!");
            }
        } else {
            ++this.m_Completed;
        }
        if (this.m_Completed + this.m_Failed == this.m_Classifiers.length) {
            if (this.m_Failed > 0 && this.m_Debug) {
                System.err.println("Problem building classifiers - some iterations failed.");
            }
            this.m_ExecutorPool.shutdown();
            this.block(false);
            this.m_Data = null;
        }
    }

    protected synchronized void buildClassifiers() throws Exception {
        Random rand = new Random(this.m_Seed);
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            int index = i;
            int seed = rand.nextInt();
            if (this.getDebug()) {
                System.out.print("Training classifier (" + (i + 1) + ")");
            }
            this.m_ExecutorPool.execute(() -> {
                try {
                    Instances train = this.getTrainingSet(index, seed);
                    this.m_Classifiers[index].buildClassifier(train);
                    this.completedClassifier(index, true);
                }
                catch (Exception ex) {
                    System.err.println("Classifier #" + (index + 1) + " failed with:");
                    ex.printStackTrace();
                    this.completedClassifier(index, false);
                }
            });
        }
        if (this.m_Completed + this.m_Failed < this.m_Classifiers.length) {
            this.block(true);
        }
    }

    protected Classifier constructEnsemble() {
        Object result;
        ArrayList<Classifier> classifiers = new ArrayList<Classifier>();
        for (Classifier cls : this.m_Classifiers) {
            if (cls == null) continue;
            classifiers.add(cls);
        }
        if (classifiers.size() > 1) {
            result = new Vote();
            result.setCombinationRule(this.getCombinationRule());
            result.setClassifiers(classifiers.toArray(new Classifier[classifiers.size()]));
        } else {
            result = classifiers.size() == 1 ? (Classifier)classifiers.get(0) : this.m_BackupModel;
        }
        return result;
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        for (Capabilities.Capability cap : Capabilities.Capability.values()) {
            result.enableDependency(cap);
        }
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setOwner((CapabilitiesHandler)this);
        return result;
    }

    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        this.m_Data = new Instances(data);
        this.m_Data.deleteWithMissingClass();
        this.m_Header = new Instances(this.m_Data, 0);
        this.m_ActualNumBalanced = this.m_NumBalanced;
        if (this.m_Thresholds.length > 0) {
            int i;
            double minorityClass = 1.0;
            AttributeStats stats = data.attributeStats(data.classIndex());
            double total = StatUtils.sum((int[])stats.nominalCounts);
            for (i = 0; i < stats.nominalCounts.length; ++i) {
                if (stats.nominalCounts[i] == 0) continue;
                minorityClass = Math.min((double)stats.nominalCounts[i] / total, minorityClass);
            }
            if (this.getDebug()) {
                System.out.println("Minority class probability: " + minorityClass);
            }
            double lastThreshold = 1.0;
            for (i = 0; i < this.m_Thresholds.length; ++i) {
                double threshold = Double.parseDouble(this.m_Thresholds[i].getPairKey());
                if (!(threshold > minorityClass) || !(threshold < lastThreshold)) continue;
                lastThreshold = threshold;
                this.m_ActualNumBalanced = Integer.parseInt(this.m_Thresholds[i].getPairValue());
            }
            if (this.getDebug()) {
                System.out.println("Actual # of resampled models: " + this.m_ActualNumBalanced);
            }
        }
        if (data.numInstances() < this.m_ActualNumBalanced) {
            System.err.println("WARNING: generating more balanced datasets than rows in input dataset (" + this.m_ActualNumBalanced + " > " + data.numInstances() + ")");
        }
        if (this.m_Classifier == null) {
            throw new Exception("A base classifier has not been specified!");
        }
        this.m_Classifiers = AbstractClassifier.makeCopies((Classifier)this.m_Classifier, (int)this.m_ActualNumBalanced);
        this.startExecutorPool();
        this.m_Completed = 0;
        this.m_Failed = 0;
        int smallest = StatUtils.min((int[])data.attributeStats((int)data.classIndex()).nominalCounts);
        this.m_SamplePercentage = 100.0 / (double)data.numInstances() * (double)(smallest * data.classAttribute().numValues());
        if (this.getDebug()) {
            System.out.println("Sample percentage: " + this.m_SamplePercentage);
        }
        this.m_BackupModel = new ZeroR();
        this.m_BackupModel.buildClassifier(this.m_Data);
        this.buildClassifiers();
        this.m_Ensemble = this.constructEnsemble();
        this.m_Classifiers = null;
        this.m_BackupModel = null;
    }

    public double classifyInstance(Instance instance) throws Exception {
        return this.m_Ensemble.classifyInstance(instance);
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        return this.m_Ensemble.distributionForInstance(instance);
    }

    public String toString() {
        StringBuilder result = new StringBuilder();
        if (this.m_Ensemble == null) {
            result.append("No model built yet!");
        } else if (!this.m_SuppressModelOutput) {
            result.append("--> Model\n");
            result.append(this.m_Ensemble.toString());
            result.append("\n");
        } else {
            result.append("Model suppressed\n");
        }
        return result.toString();
    }

    public String getRevision() {
        return RevisionUtils.extract((String)"$Revision: 15194 $");
    }

    public static void main(String[] args) {
        VotedImbalance.runClassifier((Classifier)new VotedImbalance(), (String[])args);
    }
}

