package weka.classifiers.meta;

import JSci.maths.wavelet.IllegalScalingException;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.rules.ZeroR;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.MultiFilter;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.unsupervised.instance.RemoveInstancesWithMissingValue;

/* loaded from: input_file:weka/classifiers/meta/SubsetEnsemble.class */
/**
 * Meta-classifier that builds one copy of the base classifier per non-class
 * attribute and combines the per-attribute models at prediction time.
 *
 * <p>Training ({@link #buildClassifier(Instances)}) submits one task per
 * non-class attribute to a thread pool. Each task checks whether the subset
 * consisting of that attribute, the class and optionally some randomly chosen
 * extra attributes still contains instances once rows with missing values
 * are removed; if so, a {@code FilteredClassifier} wrapping the base
 * classifier and an attribute-selecting {@link Remove} filter is trained,
 * otherwise the slot is left {@code null}. A {@link ZeroR} model is always
 * built as a fallback.
 *
 * <p>Prediction ({@link #constructEnsemble(Instance)}) picks the models whose
 * attribute is present in the instance and combines them with {@code Vote}.
 */
public class SubsetEnsemble extends RandomizableSingleClassifierEnhancer {
    private static final long serialVersionUID = -7637300579884789439L;

    /** One slot per non-class attribute; {@code null} where no model was built. */
    protected Classifier[] m_Classifiers;

    /** Number of threads used for training (always &gt;= 1). */
    protected int m_NumExecutionSlots = 1;

    /** ID of the {@code Vote} combination rule (default 1, listed as AVG in the options). */
    protected int m_CombinationRule = 1;

    /** Number of randomly chosen attributes added to each subset (always &gt;= 0). */
    protected int m_NumRandomFeatures = 0;

    /** Pool executing the training tasks; recreated for every build. */
    protected transient ThreadPoolExecutor m_ExecutorPool;

    /** Number of training tasks that finished without an exception. */
    protected int m_Completed;

    /** Number of training tasks that ended with an exception. */
    protected int m_Failed;

    /** Training data; only set while a build is in progress, {@code null} afterwards. */
    protected Instances m_Data;

    /** Empty copy of the training data, kept to map classifier slots to attributes. */
    protected Instances m_Header;

    /** Fallback model used when no per-attribute classifier applies. */
    protected ZeroR m_BackupModel;

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Generates an ensemble using the following approach:\n"
            + "- for each attribute apart from class attribute do:\n"
            + "  * create new dataset with only this feature and the class attribute\n"
            + "  * remove all instances that contain a missing value\n"
            + "  * if no instances left in subset, don't build a classifier for this feature\n"
            + "  * if at least 1 instance is left in subset, build base classifier with it\n"
            + "If no classifier gets built at all, use ZeroR as backup model, built on the full dataset.\n"
            + "In addition to the default feature for a subset, a number of random features can be added to the subset before the classifier is trained.\n"
            + "At prediction time, the Vote meta-classifier (using the pre-built classifiers) is used to determing the class probabilities or regression value.";
    }

    /**
     * Lists the options of this classifier followed by those of the superclass.
     *
     * @return an enumeration of {@link Option} objects
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tNumber of execution slots.\n\t(default: 1 - i.e. no parallelism)", "num-slots", 1, "-num-slots <num>"));
        options.addElement(new Option("\tThe combination rule to use\n\t(default: AVG)", "combination-rule", 1, "-combination-rule " + Tag.toOptionList(Vote.TAGS_RULES)));
        options.addElement(new Option("\tNumber of random features to use in addition.\n\t(default: 0)", "num-random", 1, "-num-random <num>"));
        Enumeration inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            options.addElement((Option) inherited.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses {@code -num-slots}, {@code -combination-rule} and
     * {@code -num-random}, falling back to the defaults (1, rule ID 1, 0) for
     * absent options, then hands the remaining options to the superclass.
     *
     * @param strArr the option array to parse
     * @throws Exception if an option value cannot be parsed or is rejected
     */
    public void setOptions(String[] strArr) throws Exception {
        String slots = Utils.getOption("num-slots", strArr);
        if (slots.length() != 0) {
            setNumExecutionSlots(Integer.parseInt(slots));
        } else {
            setNumExecutionSlots(1);
        }

        String rule = Utils.getOption("combination-rule", strArr);
        if (rule.length() != 0) {
            setCombinationRule(new SelectedTag(rule, Vote.TAGS_RULES));
        } else {
            setCombinationRule(new SelectedTag(1, Vote.TAGS_RULES));
        }

        String numRandom = Utils.getOption("num-random", strArr);
        if (numRandom.length() != 0) {
            setNumRandomFeatures(Integer.parseInt(numRandom));
        } else {
            setNumRandomFeatures(0);
        }

        super.setOptions(strArr);
    }

    /**
     * Returns the current settings as an option array.
     *
     * @return this classifier's options followed by the superclass options
     */
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-num-slots");
        options.add("" + getNumExecutionSlots());
        options.add("-combination-rule");
        options.add("" + getCombinationRule());
        options.add("-num-random");
        options.add("" + getNumRandomFeatures());
        options.addAll(Arrays.asList(super.getOptions()));
        return options.toArray(new String[options.size()]);
    }

    /**
     * Sets the number of training threads. Values below 1 are rejected with a
     * message on stderr and leave the current setting untouched.
     *
     * @param i the number of execution slots, must be &gt;= 1
     */
    public void setNumExecutionSlots(int i) {
        if (i >= 1) {
            this.m_NumExecutionSlots = i;
        } else {
            System.err.println("Number of execution slots must be >= 1");
        }
    }

    /**
     * @return the number of training threads
     */
    public int getNumExecutionSlots() {
        return this.m_NumExecutionSlots;
    }

    /**
     * @return the tip text for the execution-slots property
     */
    public String numExecutionSlotsTipText() {
        return "The number of execution slots (threads) to use for constructing the ensemble.";
    }

    /**
     * Sets the combination rule. Tags that do not belong to
     * {@code Vote.TAGS_RULES} are silently ignored.
     *
     * @param selectedTag the rule to use
     */
    public void setCombinationRule(SelectedTag selectedTag) {
        if (selectedTag.getTags() == Vote.TAGS_RULES) {
            this.m_CombinationRule = selectedTag.getSelectedTag().getID();
        }
    }

    /**
     * @return the current combination rule as a tag of {@code Vote.TAGS_RULES}
     */
    public SelectedTag getCombinationRule() {
        return new SelectedTag(this.m_CombinationRule, Vote.TAGS_RULES);
    }

    /**
     * @return the tip text for the combination-rule property
     */
    public String combinationRuleTipText() {
        return "The combination rule used.";
    }

    /**
     * Sets the number of additional random attributes per subset. Negative
     * values are rejected with a message on stderr and leave the current
     * setting untouched.
     *
     * @param i the number of additional random features, must be &gt;= 0
     */
    public void setNumRandomFeatures(int i) {
        if (i >= 0) {
            this.m_NumRandomFeatures = i;
        } else {
            System.err.println("Number of additional random features must be >= 0");
        }
    }

    /**
     * @return the number of additional random attributes per subset
     */
    public int getNumRandomFeatures() {
        return this.m_NumRandomFeatures;
    }

    /**
     * @return the tip text for the random-features property
     */
    public String numRandomFeaturesTipText() {
        return "The number of additional random features to use.";
    }

    /**
     * Creates a fresh fixed-size thread pool with {@link #m_NumExecutionSlots}
     * threads, cancelling any pool left over from a previous build.
     */
    protected void startExecutorPool() {
        if (this.m_ExecutorPool != null) {
            this.m_ExecutorPool.shutdownNow();
        }
        this.m_ExecutorPool = new ThreadPoolExecutor(this.m_NumExecutionSlots, this.m_NumExecutionSlots, 120L, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>());
    }

    /**
     * Either waits on this object's monitor ({@code z == true}) or wakes all
     * threads waiting on it ({@code z == false}).
     *
     * @param z {@code true} to wait, {@code false} to notify
     */
    private synchronized void block(boolean z) {
        if (!z) {
            notifyAll();
        } else {
            try {
                wait();
            } catch (InterruptedException e) {
                // Keep the interruption visible to the caller instead of dropping it.
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Submits one training task per classifier slot and waits until every
     * task has reported back through {@link #completedClassifier(int, boolean)}.
     *
     * <p>Each task receives its own seed drawn from a generator initialised
     * with the inherited {@code m_Seed}. A task that throws leaves its slot
     * {@code null}, so an untrained base classifier never ends up in the
     * ensemble.
     *
     * @throws Exception declared for subclasses; not thrown by this implementation
     */
    protected synchronized void buildClassifiers() throws Exception {
        Random seeds = new Random(this.m_Seed);
        for (int i = 0; i < this.m_Classifiers.length; i++) {
            final int index = i;
            final int seed = seeds.nextInt();
            if (this.m_Debug) {
                System.out.print("Training classifier (" + (i + 1) + ")");
            }
            this.m_ExecutorPool.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        // The row-filtered subset only decides whether a model is
                        // worth building; the model itself is trained through a
                        // FilteredClassifier that applies the attribute filter alone.
                        if (getTrainingSet(index, seed).numInstances() > 0) {
                            FilteredClassifier wrapped = new FilteredClassifier();
                            wrapped.setFilter(getFilter(index, seed, false));
                            wrapped.setClassifier(SubsetEnsemble.this.m_Classifiers[index]);
                            wrapped.buildClassifier(SubsetEnsemble.this.m_Data);
                            SubsetEnsemble.this.m_Classifiers[index] = wrapped;
                        } else {
                            SubsetEnsemble.this.m_Classifiers[index] = null;
                        }
                        completedClassifier(index, true);
                    } catch (Exception e) {
                        e.printStackTrace();
                        // Drop the untrained copy so prediction never consults it.
                        SubsetEnsemble.this.m_Classifiers[index] = null;
                        completedClassifier(index, false);
                    }
                }
            });
        }

        // wait() may return spuriously, so re-check the counters in a loop.
        // An interrupt does not abort the wait (the workers are still using
        // m_Data); it is re-asserted once all tasks have reported back.
        boolean interrupted = false;
        while (this.m_Completed + this.m_Failed < this.m_Classifiers.length) {
            try {
                wait();
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }

        // With no tasks at all completedClassifier() never runs, so perform
        // its clean-up here.
        if (this.m_Classifiers.length == 0) {
            this.m_ExecutorPool.shutdown();
            this.m_Data = null;
        }
    }

    /**
     * Records the outcome of one training task. When the last task reports
     * in, the pool is shut down, the waiting build thread is woken and the
     * training data reference is released.
     *
     * @param i index of the classifier slot that finished
     * @param z {@code true} if the task finished without an exception
     */
    protected synchronized void completedClassifier(int i, boolean z) {
        if (z) {
            this.m_Completed++;
        } else {
            this.m_Failed++;
            if (this.m_Debug) {
                System.err.println("Building of classifier " + i + " failed!");
            }
        }
        if (this.m_Completed + this.m_Failed == this.m_Classifiers.length) {
            if (this.m_Failed > 0 && this.m_Debug) {
                System.err.println("Problem building classifiers - some iterations failed.");
            }
            this.m_ExecutorPool.shutdown();
            block(false);
            this.m_Data = null;
        }
    }

    /**
     * Maps a classifier slot to the index of its attribute in the dataset,
     * i.e. the {@code i}-th attribute when the class attribute is skipped.
     *
     * @param i zero-based classifier slot
     * @return zero-based attribute index in {@link #m_Header}
     * @throws Exception (an {@link IllegalStateException}) if there is no such attribute
     */
    protected int getActualIndex(int i) throws Exception {
        int nonClassSeen = 0;
        for (int att = 0; att < this.m_Header.numAttributes(); att++) {
            if (att == this.m_Header.classIndex()) {
                continue;
            }
            if (nonClassSeen == i) {
                return att;
            }
            nonClassSeen++;
        }
        throw new IllegalStateException("Actual attribute index for index " + i + " could not be determined!");
    }

    /**
     * Builds the filter that reduces the data to the subset for one slot.
     *
     * <p>The subset keeps the slot's attribute, the class attribute and up to
     * {@link #m_NumRandomFeatures} further distinct attributes drawn with a
     * generator seeded by {@code i2} (capped at the number of attributes that
     * remain). Reads {@link #m_Data}, so it can only be used while a build is
     * in progress.
     *
     * @param i  zero-based classifier slot
     * @param i2 seed for choosing the additional random attributes
     * @param z  if {@code true}, instances with missing values are removed as well
     * @return the (unconfigured-for-input) filter
     * @throws Exception if the slot cannot be mapped to an attribute or the filter setup fails
     */
    protected Filter getFilter(int i, int i2, boolean z) throws Exception {
        HashSet<Integer> keep = new HashSet<Integer>();
        keep.add(Integer.valueOf(getActualIndex(i)));
        keep.add(Integer.valueOf(this.m_Data.classIndex()));

        if (this.m_NumRandomFeatures > 0) {
            // "Additional" features come on top of the two indices already kept.
            int extra = Math.min(this.m_NumRandomFeatures, this.m_Data.numAttributes() - 2);
            int targetSize = keep.size() + extra;
            Random random = new Random(i2);
            while (keep.size() < targetSize) {
                keep.add(Integer.valueOf(random.nextInt(this.m_Data.numAttributes())));
            }
        }

        int[] indices = new int[keep.size()];
        int pos = 0;
        for (Integer index : keep) {
            indices[pos++] = index.intValue();
        }
        Arrays.sort(indices);

        // Inverted selection: the listed attributes are kept, all others removed.
        Remove remove = new Remove();
        remove.setAttributeIndicesArray(indices);
        remove.setInvertSelection(true);
        if (!z) {
            return remove;
        }

        MultiFilter chain = new MultiFilter();
        chain.setFilters(new Filter[]{remove, new RemoveInstancesWithMissingValue()});
        return chain;
    }

    /**
     * Returns the training data reduced to the subset for one slot, with
     * instances containing missing values removed.
     *
     * @param i  zero-based classifier slot
     * @param i2 seed for choosing the additional random attributes
     * @return the filtered copy of {@link #m_Data}
     * @throws Exception if filtering fails
     */
    protected Instances getTrainingSet(int i, int i2) throws Exception {
        Filter filter = getFilter(i, i2, true);
        filter.setInputFormat(this.m_Data);
        return Filter.useFilter(this.m_Data, filter);
    }

    /**
     * Builds the ensemble: copies the data, drops instances without a class
     * value, trains the ZeroR backup model and then one classifier per
     * non-class attribute.
     *
     * @param instances the training data
     * @throws Exception if the data is not supported, no base classifier is set or training fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        this.m_Data = new Instances(instances);
        this.m_Data.deleteWithMissingClass();
        this.m_Header = new Instances(this.m_Data, 0);
        if (this.m_Classifier == null) {
            throw new Exception("A base classifier has not been specified!");
        }
        this.m_Classifiers = AbstractClassifier.makeCopies(this.m_Classifier, this.m_Data.numAttributes() - 1);
        startExecutorPool();
        this.m_Completed = 0;
        this.m_Failed = 0;
        this.m_BackupModel = new ZeroR();
        this.m_BackupModel.buildClassifier(this.m_Data);
        buildClassifiers();
    }

    /**
     * Selects the models applicable to an instance: those whose attribute is
     * not missing in the instance and for which a model was built.
     *
     * @param instance the instance to be classified
     * @return a {@code Vote} over the applicable models if there are several,
     *         the single applicable model if there is exactly one, otherwise
     *         the ZeroR backup model
     */
    protected Classifier constructEnsemble(Instance instance) {
        Vector<Classifier> members = new Vector<Classifier>();
        int slot = 0;
        for (int att = 0; att < instance.numAttributes() && slot < this.m_Classifiers.length; att++) {
            if (att == instance.classIndex()) {
                continue;
            }
            // Slots follow the non-class attributes one-to-one, so the slot
            // advances whether or not its model is used.
            Classifier candidate = this.m_Classifiers[slot];
            slot++;
            if (candidate != null && !instance.isMissing(att)) {
                members.add(candidate);
            }
        }

        if (members.size() > 1) {
            Vote vote = new Vote();
            vote.setCombinationRule(getCombinationRule());
            vote.setClassifiers(members.toArray(new Classifier[members.size()]));
            return vote;
        }
        if (members.size() == 1) {
            return members.get(0);
        }
        return this.m_BackupModel;
    }

    /**
     * Classifies an instance with the ensemble selected for it.
     *
     * @param instance the instance to classify
     * @return the prediction of the selected ensemble
     * @throws Exception if the prediction fails
     */
    public double classifyInstance(Instance instance) throws Exception {
        return constructEnsemble(instance).classifyInstance(instance);
    }

    /**
     * Returns the class distribution produced by the ensemble selected for an instance.
     *
     * @param instance the instance to classify
     * @return the distribution of the selected ensemble
     * @throws Exception if the prediction fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        return constructEnsemble(instance).distributionForInstance(instance);
    }

    /**
     * Describes the backup model and every classifier slot, or states that no
     * model has been built yet.
     *
     * @return the textual description
     */
    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (this.m_BackupModel == null) {
            sb.append("No model built yet!");
        } else {
            sb.append("--> Backup model\n");
            sb.append(this.m_BackupModel.toString());
            sb.append("\n");
            for (int i = 0; i < this.m_Classifiers.length; i++) {
                try {
                    sb.append("\n");
                    sb.append("--> Classifier #" + (i + 1) + " (for attribute #" + (getActualIndex(i) + 1) + "):\n");
                    if (this.m_Classifiers[i] == null) {
                        sb.append("No model built - no useful data available");
                    } else {
                        sb.append(this.m_Classifiers[i].toString());
                    }
                    sb.append("\n");
                } catch (Exception e) {
                    sb.append("Classifier #" + (i + 1) + ": skipped due to error\n");
                }
            }
        }
        return sb.toString();
    }

    /**
     * @return the revision string
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 2991 $");
    }

    /**
     * Command-line entry point.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new SubsetEnsemble(), strArr);
    }
}
