/*
 * Decompiled with CFR 0.152.
 */
package weka.attributeSelection;

import java.util.BitSet;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.SubsetEvaluator;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.rules.ZeroR;
import weka.core.Capabilities;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Attribute subset evaluator that scores a candidate subset by cross-validating a
 * base learner on the training data restricted to that subset (plus the class
 * attribute).
 *
 * <p>Cross-validation is repeated (at most five times) while the standard deviation
 * of the repeated estimates, relative to their mean, exceeds {@link #getThreshold()}.
 * The returned merit is oriented so that larger is better: error-type measures
 * (default, accuracy/error rate, RMSE, MAE) are negated, while F-measure and AUC are
 * returned unchanged.
 */
public class WrapperSubsetEval
extends ASEvaluation
implements SubsetEvaluator,
OptionHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -4573057658746728675L;

    /** Maximum number of repeated cross-validation runs per evaluated subset. */
    private static final int MAX_REPEATS = 5;

    /** Training data supplied to {@link #buildEvaluator(Instances)}; null until built. */
    private Instances m_trainInstances;
    /** Index of the class attribute in the training data. */
    private int m_classIndex;
    /** Number of attributes in the training data. */
    private int m_numAttribs;
    /** Number of instances in the training data. */
    private int m_numInstances;
    /** Evaluation for the cross-validation run in progress; null between evaluations. */
    private Evaluation m_Evaluation;
    /** Learner whose cross-validated performance is used to score subsets. */
    private Classifier m_BaseClassifier;
    /** Number of cross-validation folds. */
    private int m_folds;
    /** Seed for the random number generator handed to cross-validation. */
    private int m_seed;
    /** Relative standard deviation above which another cross-validation is run. */
    private double m_threshold;

    public static final int EVAL_DEFAULT = 1;
    public static final int EVAL_ACCURACY = 2;
    public static final int EVAL_RMSE = 3;
    public static final int EVAL_MAE = 4;
    public static final int EVAL_FMEASURE = 5;
    public static final int EVAL_AUC = 6;
    public static final Tag[] TAGS_EVALUATION = new Tag[]{new Tag(EVAL_DEFAULT, "Default: accuracy (discrete class); RMSE (numeric class)"), new Tag(EVAL_ACCURACY, "Accuracy (discrete class only)"), new Tag(EVAL_RMSE, "RMSE (of the class probabilities for discrete class)"), new Tag(EVAL_MAE, "MAE (of the class probabilities for discrete class)"), new Tag(EVAL_FMEASURE, "F-measure (discrete class only)"), new Tag(EVAL_AUC, "AUC (area under the ROC curve - discrete class only)")};

    /**
     * Command-line names accepted by {@code -E}, indexed by {@code EVAL_*} ID.
     * The default measure (and unused index 0) has no name and is never written out.
     */
    private static final String[] MEASURE_NAMES = {null, null, "acc", "rmse", "mae", "f-meas", "auc"};

    /** The selected evaluation measure; one of the {@code EVAL_*} constants. */
    protected int m_evaluationMeasure = EVAL_DEFAULT;

    /**
     * Returns a description of this evaluator for display in the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    public String globalInfo() {
        return "WrapperSubsetEval:\n\nEvaluates attribute sets by using a learning scheme. Cross validation is used to estimate the accuracy of the learning scheme for a set of attributes.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference (Kohavi and John, 1997) for this method.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Ron Kohavi and George H. John");
        result.setValue(TechnicalInformation.Field.YEAR, "1997");
        result.setValue(TechnicalInformation.Field.TITLE, "Wrappers for feature subset selection");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Artificial Intelligence");
        result.setValue(TechnicalInformation.Field.VOLUME, "97");
        result.setValue(TechnicalInformation.Field.NUMBER, "1-2");
        result.setValue(TechnicalInformation.Field.PAGES, "273-324");
        result.setValue(TechnicalInformation.Field.NOTE, "Special issue on relevance");
        result.setValue(TechnicalInformation.Field.ISSN, "0004-3702");
        return result;
    }

    /** Creates an evaluator with default options (see {@link #resetOptions()}). */
    public WrapperSubsetEval() {
        this.resetOptions();
    }

    /**
     * Lists the command-line options of this evaluator, followed by the options of
     * the base classifier when it is an {@link OptionHandler}.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tclass name of base learner to use for \taccuracy estimation.\n\tPlace any classifier options LAST on the command line\n\tfollowing a \"--\". eg.:\n\t\t-B weka.classifiers.bayes.NaiveBayes ... -- -K\n\t(default: weka.classifiers.rules.ZeroR)", "B", 1, "-B <base learner>"));
        newVector.addElement(new Option("\tnumber of cross validation folds to use for estimating accuracy.\n\t(default=5)", "F", 1, "-F <num>"));
        newVector.addElement(new Option("\tSeed for cross validation accuracy estimation.\n\t(default = 1)", "R", 1, "-R <seed>"));
        newVector.addElement(new Option("\tthreshold by which to execute another cross validation\n\t(standard deviation---expressed as a percentage of the mean).\n\t(default: 0.01 (1%))", "T", 1, "-T <num>"));
        newVector.addElement(new Option("\tPerformance evaluation measure to use for selecting attributes.\n\t(Default = accuracy for discrete class and rmse for numeric class)", "E", 1, "-E <acc | rmse | mae | f-meas | auc>"));
        if (this.m_BaseClassifier instanceof OptionHandler) {
            newVector.addElement(new Option("", "", 0, "\nOptions specific to scheme " + this.m_BaseClassifier.getClass().getName() + ":"));
            Enumeration enu = ((OptionHandler) this.m_BaseClassifier).listOptions();
            while (enu.hasMoreElements()) {
                newVector.addElement((Option) enu.nextElement());
            }
        }
        return newVector.elements();
    }

    /**
     * Parses options. All settings are first reset to their defaults, so any option
     * not given here reverts to its default value.
     *
     * <p>Options following {@code "--"} are handed to the base classifier named by
     * {@code -B} (ZeroR when {@code -B} is absent).
     *
     * @param options the option array to parse
     * @throws IllegalArgumentException if {@code -E} names an unknown measure
     * @throws Exception if the classifier cannot be created or a number is malformed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.resetOptions();
        String optionString = Utils.getOption('B', options);
        if (optionString.length() == 0) {
            optionString = ZeroR.class.getName();
        }
        this.setClassifier(AbstractClassifier.forName(optionString, Utils.partitionOptions(options)));
        optionString = Utils.getOption('F', options);
        if (optionString.length() != 0) {
            this.setFolds(Integer.parseInt(optionString));
        }
        optionString = Utils.getOption('R', options);
        if (optionString.length() != 0) {
            this.setSeed(Integer.parseInt(optionString));
        }
        optionString = Utils.getOption('T', options);
        if (optionString.length() != 0) {
            this.setThreshold(Double.parseDouble(optionString));
        }
        optionString = Utils.getOption('E', options);
        if (optionString.length() != 0) {
            int measure = -1;
            for (int id = 0; id < MEASURE_NAMES.length; ++id) {
                if (optionString.equals(MEASURE_NAMES[id])) {
                    measure = id;
                    break;
                }
            }
            if (measure < 0) {
                throw new IllegalArgumentException("Invalid evaluation measure");
            }
            this.setEvaluationMeasure(new SelectedTag(measure, TAGS_EVALUATION));
        }
    }

    /** @return tip text for the evaluationMeasure property */
    public String evaluationMeasureTipText() {
        return "The measure used to evaluate the performance of attribute combinations.";
    }

    /** @return the currently selected evaluation measure */
    public SelectedTag getEvaluationMeasure() {
        return new SelectedTag(this.m_evaluationMeasure, TAGS_EVALUATION);
    }

    /**
     * Sets the evaluation measure. Ignored unless {@code newMethod} was built from
     * {@link #TAGS_EVALUATION}.
     *
     * @param newMethod the new measure
     */
    public void setEvaluationMeasure(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_EVALUATION) {
            this.m_evaluationMeasure = newMethod.getSelectedTag().getID();
        }
    }

    /** @return tip text for the threshold property */
    public String thresholdTipText() {
        return "Repeat xval if stdev of mean exceeds this value.";
    }

    /** @param t relative standard deviation above which another cross-validation is run */
    public void setThreshold(double t) {
        this.m_threshold = t;
    }

    /** @return relative standard deviation above which another cross-validation is run */
    public double getThreshold() {
        return this.m_threshold;
    }

    /** @return tip text for the folds property */
    public String foldsTipText() {
        return "Number of xval folds to use when estimating subset accuracy.";
    }

    /** @param f number of cross-validation folds */
    public void setFolds(int f) {
        this.m_folds = f;
    }

    /** @return number of cross-validation folds */
    public int getFolds() {
        return this.m_folds;
    }

    /** @return tip text for the seed property */
    public String seedTipText() {
        return "Seed to use for randomly generating xval splits.";
    }

    /** @param s seed for the cross-validation random number generator */
    public void setSeed(int s) {
        this.m_seed = s;
    }

    /** @return seed for the cross-validation random number generator */
    public int getSeed() {
        return this.m_seed;
    }

    /** @return tip text for the classifier property */
    public String classifierTipText() {
        return "Classifier to use for estimating the accuracy of subsets";
    }

    /** @param newClassifier learner used to score subsets */
    public void setClassifier(Classifier newClassifier) {
        this.m_BaseClassifier = newClassifier;
    }

    /** @return learner used to score subsets */
    public Classifier getClassifier() {
        return this.m_BaseClassifier;
    }

    /**
     * Returns the current settings as an option array accepted by
     * {@link #setOptions(String[])}. {@code -E} is included only when a non-default
     * measure is selected. Base classifier options follow {@code "--"}.
     *
     * @return the current options
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.getClassifier() != null) {
            result.add("-B");
            result.add(this.getClassifier().getClass().getName());
        }
        result.add("-F");
        result.add("" + this.getFolds());
        result.add("-T");
        result.add("" + this.getThreshold());
        result.add("-R");
        result.add("" + this.getSeed());
        // Previously omitted, which made a getOptions()/setOptions() round trip lose the measure.
        String measureName = this.evaluationMeasureName();
        if (measureName != null) {
            result.add("-E");
            result.add(measureName);
        }
        result.add("--");
        if (this.m_BaseClassifier instanceof OptionHandler) {
            for (String option : ((OptionHandler) this.m_BaseClassifier).getOptions()) {
                result.add(option);
            }
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Returns the {@code -E} name of the current measure, or null when it is the
     * default (or an ID outside the known range) and therefore has no name.
     */
    private String evaluationMeasureName() {
        if (this.m_evaluationMeasure >= 0 && this.m_evaluationMeasure < MEASURE_NAMES.length) {
            return MEASURE_NAMES[this.m_evaluationMeasure];
        }
        return null;
    }

    /**
     * Restores the defaults: no training data, ZeroR base learner, 5 folds,
     * seed 1, threshold 0.01, and the default evaluation measure.
     */
    protected void resetOptions() {
        this.m_trainInstances = null;
        this.m_Evaluation = null;
        this.m_BaseClassifier = new ZeroR();
        this.m_folds = 5;
        this.m_seed = 1;
        this.m_threshold = 0.01;
        // Reset here too, otherwise setOptions() without -E keeps a previously chosen measure.
        this.m_evaluationMeasure = EVAL_DEFAULT;
    }

    /**
     * Returns the base classifier's capabilities with all dependencies enabled.
     * Numeric and date classes are allowed only for measures that support them
     * (not accuracy, F-measure, or AUC). The minimum number of instances is set to
     * the number of folds.
     *
     * @return the capabilities of this evaluator
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result;
        if (this.getClassifier() == null) {
            result = super.getCapabilities();
            result.disableAll();
        } else {
            result = this.getClassifier().getCapabilities();
        }
        for (Capabilities.Capability cap : Capabilities.Capability.values()) {
            result.enableDependency(cap);
        }
        boolean discreteClassOnly = this.m_evaluationMeasure == EVAL_ACCURACY
                || this.m_evaluationMeasure == EVAL_FMEASURE
                || this.m_evaluationMeasure == EVAL_AUC;
        result.disable(Capabilities.Capability.NUMERIC_CLASS);
        result.disable(Capabilities.Capability.DATE_CLASS);
        if (!discreteClassOnly) {
            result.enable(Capabilities.Capability.NUMERIC_CLASS);
            result.enable(Capabilities.Capability.DATE_CLASS);
        }
        result.setMinimumNumberInstances(this.getFolds());
        return result;
    }

    /**
     * Checks the data against {@link #getCapabilities()} and stores it for later
     * subset evaluation.
     *
     * @param data the training data
     * @throws Exception if the data is not supported
     */
    @Override
    public void buildEvaluator(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        this.m_trainInstances = data;
        this.m_classIndex = this.m_trainInstances.classIndex();
        this.m_numAttribs = this.m_trainInstances.numAttributes();
        this.m_numInstances = this.m_trainInstances.numInstances();
    }

    /**
     * Scores an attribute subset by cross-validating the base classifier on the
     * data reduced to the subset plus the class attribute.
     *
     * <p>Cross-validation runs at least twice and at most {@value #MAX_REPEATS} times,
     * stopping early once {@link #repeat(double[], int)} reports the estimates are
     * stable. The mean over the runs is returned, negated for error-type measures so
     * that larger is always better.
     *
     * @param subset bit {@code i} set means attribute {@code i} is in the subset
     * @return the merit of the subset
     * @throws Exception if filtering or cross-validation fails
     */
    @Override
    public double evaluateSubset(BitSet subset) throws Exception {
        Random random = new Random(this.m_seed);
        Instances reduced = this.reduceToSubset(subset);
        double[] repError = new double[MAX_REPEATS];
        int runs = 0;
        try {
            while (runs < MAX_REPEATS) {
                this.m_Evaluation = new Evaluation(reduced);
                this.m_Evaluation.crossValidateModel(this.m_BaseClassifier, reduced, this.m_folds, random, new Object[0]);
                repError[runs] = this.currentMeasure();
                ++runs;
                if (!this.repeat(repError, runs)) {
                    break;
                }
            }
        } finally {
            // Release the evaluation even when cross-validation throws.
            this.m_Evaluation = null;
        }
        double evalMetric = 0.0;
        for (int r = 0; r < runs; ++r) {
            evalMetric += repError[r];
        }
        evalMetric /= (double) runs;
        switch (this.m_evaluationMeasure) {
            case EVAL_DEFAULT:
            case EVAL_ACCURACY:
            case EVAL_RMSE:
            case EVAL_MAE:
                // Error measures: lower is better, so negate to obtain a merit.
                evalMetric = -evalMetric;
                break;
            default:
                break;
        }
        return evalMetric;
    }

    /**
     * Returns a filtered copy of the training data that keeps only the attributes
     * whose bits are set in {@code subset} (below the attribute count) plus the
     * class attribute.
     */
    private Instances reduceToSubset(BitSet subset) throws Exception {
        int numSelected = 0;
        for (int i = 0; i < this.m_numAttribs; ++i) {
            if (subset.get(i)) {
                ++numSelected;
            }
        }
        int[] keep = new int[numSelected + 1];
        int j = 0;
        for (int i = 0; i < this.m_numAttribs; ++i) {
            if (subset.get(i)) {
                keep[j++] = i;
            }
        }
        keep[j] = this.m_classIndex;
        Remove delTransform = new Remove();
        // Inverted selection: the listed indices are kept, everything else removed.
        delTransform.setInvertSelection(true);
        delTransform.setAttributeIndicesArray(keep);
        Instances trainCopy = new Instances(this.m_trainInstances);
        delTransform.setInputFormat(trainCopy);
        return Filter.useFilter(trainCopy, delTransform);
    }

    /**
     * Reads the selected measure from {@link #m_Evaluation}; returns 0.0 for an
     * unrecognised measure ID.
     */
    private double currentMeasure() throws Exception {
        switch (this.m_evaluationMeasure) {
            case EVAL_DEFAULT:
            case EVAL_ACCURACY:
                return this.m_Evaluation.errorRate();
            case EVAL_RMSE:
                return this.m_Evaluation.rootMeanSquaredError();
            case EVAL_MAE:
                return this.m_Evaluation.meanAbsoluteError();
            case EVAL_FMEASURE:
                return this.m_Evaluation.weightedFMeasure();
            case EVAL_AUC:
                return this.m_Evaluation.weightedAreaUnderROC();
            default:
                return 0.0;
        }
    }

    /**
     * Describes the evaluator's configuration, or states that it has not been built.
     *
     * @return a human-readable description
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        if (this.m_trainInstances == null) {
            text.append("\tWrapper subset evaluator has not been built yet\n");
            return text.toString();
        }
        text.append("\tWrapper Subset Evaluator\n");
        text.append("\tLearning scheme: " + this.getClassifier().getClass().getName() + "\n");
        text.append("\tScheme options: ");
        if (this.m_BaseClassifier instanceof OptionHandler) {
            for (String option : ((OptionHandler) this.m_BaseClassifier).getOptions()) {
                text.append(String.valueOf(option) + " ");
            }
        }
        text.append("\n");
        switch (this.m_evaluationMeasure) {
            case EVAL_DEFAULT:
            case EVAL_ACCURACY:
                text.append(this.classIsNumeric() ? "\tSubset evaluation: RMSE\n" : "\tSubset evaluation: classification error\n");
                break;
            case EVAL_RMSE:
                text.append(this.classIsNumeric() ? "\tSubset evaluation: RMSE\n" : "\tSubset evaluation: RMSE (probability estimates)\n");
                break;
            case EVAL_MAE:
                text.append(this.classIsNumeric() ? "\tSubset evaluation: MAE\n" : "\tSubset evaluation: MAE (probability estimates)\n");
                break;
            case EVAL_FMEASURE:
                text.append("\tSubset evaluation: F-measure\n");
                break;
            case EVAL_AUC:
                text.append("\tSubset evaluation: area under the ROC curve\n");
                break;
            default:
                break;
        }
        text.append("\tNumber of folds for accuracy estimation: " + this.m_folds + "\n");
        return text.toString();
    }

    /** Returns whether the class attribute of the stored training data is numeric. */
    private boolean classIsNumeric() {
        return this.m_trainInstances.attribute(this.m_classIndex).isNumeric();
    }

    /**
     * Decides whether another cross-validation run is needed.
     *
     * <p>Always true after a single run; one value has zero spread, so the ratio test
     * alone would stop immediately. Otherwise it returns whether the population
     * standard deviation of the first {@code entries} values, divided by their mean,
     * exceeds the threshold. A 0/0 ratio is NaN and so yields false.
     *
     * @param repError per-run measure values
     * @param entries number of valid leading entries in {@code repError}
     * @return true if another run should be performed
     */
    private boolean repeat(double[] repError, int entries) {
        if (entries == 1) {
            return true;
        }
        double mean = 0.0;
        for (int k = 0; k < entries; ++k) {
            mean += repError[k];
        }
        mean /= (double) entries;
        double sumSq = 0.0;
        for (int k = 0; k < entries; ++k) {
            double diff = repError[k] - mean;
            sumSq += diff * diff;
        }
        double stdDev = Math.sqrt(sumSq / (double) entries);
        return stdDev / mean > this.m_threshold;
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5928 $");
    }

    /**
     * Runs this evaluator from the command line.
     *
     * @param args command-line options
     */
    public static void main(String[] args) {
        WrapperSubsetEval.runEvaluator(new WrapperSubsetEval(), args);
    }
}

