/*
 * Decompiled with CFR 0.152.
 */
package weka.attributeSelection;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;
import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.AttributeEvaluator;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 * Attribute evaluator that scores each attribute by its Pearson correlation with the class.
 *
 * <p>Scoring runs on a copy of the data from which instances with a missing class are removed
 * and remaining missing values are replaced. Each nominal attribute is expanded into one 0/1
 * indicator per value; its per-value scores are combined into one merit via an average weighted
 * by how many instances take each value. When the class is nominal it is expanded the same way,
 * and absolute per-class-value correlations are averaged, weighted by class-value frequency.
 */
public class CorrelationAttributeEval
extends ASEvaluation
implements AttributeEvaluator,
OptionHandler {
    private static final long serialVersionUID = -4931946995055872438L;
    /** Merit per attribute, indexed by attribute index; null until {@link #buildEvaluator} runs. */
    protected double[] m_correlations;
    /** Whether per-value correlations for nominal attributes are collected and printed. */
    protected boolean m_detailedOutput = false;
    /** Per-value report from the last build; null when that build ran without detailed output. */
    protected StringBuffer m_detailedOutputBuff;

    /**
     * Returns the description shown in the GUI.
     *
     * @return a description of this evaluator
     */
    public String globalInfo() {
        return "CorrelationAttributeEval :\n\nEvaluates the worth of an attribute by measuring the correlation (Pearson's) between it and the class.\n\nNominal attributes are considered on a value by value basis by treating each value as an indicator. An overall correlation for a nominal attribute is arrived at via a weighted average.\n";
    }

    /**
     * Describes which data this evaluator accepts.
     *
     * <p>Numeric (and date) classes are declared alongside nominal ones because
     * {@link #buildEvaluator} has a dedicated scoring path for classes that are numeric.
     *
     * @return the capabilities of this evaluator
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        // attributes
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        // class
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Lists the options this evaluator understands.
     *
     * @return an enumeration holding the single {@code -D} option
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(1);
        newVector.addElement(new Option("\tOutput detailed info for nominal attributes", "D", 0, "-D"));
        return newVector.elements();
    }

    /**
     * Parses the option list; {@code -D} turns on detailed per-value output.
     *
     * @param options the options to parse
     * @throws Exception if the options cannot be parsed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.setOutputDetailedInfo(Utils.getFlag('D', options));
    }

    /**
     * Returns the current settings as an option list.
     *
     * @return {@code {"-D"}} when detailed output is on, otherwise an empty array
     */
    @Override
    public String[] getOptions() {
        if (this.getOutputDetailedInfo()) {
            return new String[]{"-D"};
        }
        return new String[0];
    }

    /**
     * Returns the GUI tip text for the detailed-output property.
     *
     * @return the tip text
     */
    public String outputDetailedInfoTipText() {
        return "Output per value correlation for nominal attributes";
    }

    /**
     * Sets whether per-value correlations for nominal attributes are recorded on the next build.
     *
     * @param d true to record detailed output
     */
    public void setOutputDetailedInfo(boolean d) {
        this.m_detailedOutput = d;
    }

    /**
     * Returns whether detailed per-value output is enabled.
     *
     * @return true if detailed output is enabled
     */
    public boolean getOutputDetailedInfo() {
        return this.m_detailedOutput;
    }

    /**
     * Returns the merit of one attribute, as computed by the last call to {@link #buildEvaluator}.
     *
     * @param attribute index of the attribute to look up
     * @return the attribute's merit; 0 for the class and for attributes that are neither nominal
     *     nor numeric
     * @throws IllegalStateException if the evaluator has not been built yet
     */
    @Override
    public double evaluateAttribute(int attribute) throws Exception {
        if (this.m_correlations == null) {
            throw new IllegalStateException("Correlation attribute evaluator has not been built yet.");
        }
        return this.m_correlations[attribute];
    }

    /**
     * Returns a short description, plus the per-value report when detailed output is enabled and
     * the last build produced one.
     *
     * @return a textual description of this evaluator
     */
    @Override
    public String toString() {
        StringBuffer buff = new StringBuffer();
        if (this.m_correlations == null) {
            buff.append("Correlation attribute evaluator has not been built yet.");
        } else {
            buff.append("\tCorrelation Ranking Filter");
            // The report may be null if detailed output was switched on after the last build.
            if (this.m_detailedOutput && this.m_detailedOutputBuff != null
                    && this.m_detailedOutputBuff.length() > 0) {
                buff.append("\n\tDetailed output for nominal attributes");
                buff.append(this.m_detailedOutputBuff);
            }
        }
        return buff.toString();
    }

    /**
     * Computes a merit for every attribute of {@code data}.
     *
     * <p>Works on a copy: instances with a missing class are dropped and remaining missing values
     * are replaced with {@link ReplaceMissingValues}. The class attribute and attributes that are
     * neither nominal nor numeric keep a merit of 0.
     *
     * @param data the training data; the caller's object is not modified
     * @throws Exception if filtering the data fails
     */
    @Override
    public void buildEvaluator(Instances data) throws Exception {
        data = new Instances(data);
        data.deleteWithMissingClass();
        ReplaceMissingValues rmv = new ReplaceMissingValues();
        rmv.setInputFormat(data);
        data = Filter.useFilter(data, rmv);

        int classIndex = data.classIndex();
        this.m_correlations = new double[data.numAttributes()];
        // Always start from a fresh report so text from an earlier build is never reused.
        this.m_detailedOutputBuff = this.m_detailedOutput ? new StringBuffer() : null;

        List<Integer> numericIndexes = new ArrayList<Integer>();
        List<Integer> nominalIndexes = new ArrayList<Integer>();
        for (int i = 0; i < data.numAttributes(); ++i) {
            if (i == classIndex) {
                continue;
            }
            if (data.attribute(i).isNominal()) {
                nominalIndexes.add(i);
            } else if (data.attribute(i).isNumeric()) {
                numericIndexes.add(i);
            }
        }
        double[][][] nomAtts = this.binarizeNominalAttributes(data, nominalIndexes);

        if (data.classAttribute().isNumeric()) {
            this.scoreAgainstNumericClass(data, numericIndexes, nominalIndexes, nomAtts);
        } else {
            this.scoreAgainstNominalClass(data, numericIndexes, nominalIndexes, nomAtts);
        }

        if (this.m_detailedOutputBuff != null && this.m_detailedOutputBuff.length() > 0) {
            this.m_detailedOutputBuff.append("\n");
        }
    }

    /**
     * Builds one 0/1 indicator column per value of each nominal attribute: result[att][v][inst]
     * is 1.0 when instance {@code inst} takes value {@code v}. Entries for attributes not listed
     * in {@code nominalIndexes} stay null.
     */
    private double[][][] binarizeNominalAttributes(Instances data, List<Integer> nominalIndexes) {
        int numInstances = data.numInstances();
        double[][][] nomAtts = new double[data.numAttributes()][][];
        for (Integer att : nominalIndexes) {
            nomAtts[att] = new double[data.attribute(att).numValues()][numInstances];
            // Sparse instances do not store value index 0, so every instance starts flagged as
            // value 0 and the flag is moved to the observed value below.
            if (nomAtts[att].length > 0) {
                Arrays.fill(nomAtts[att][0], 1.0);
            }
        }
        if (nominalIndexes.isEmpty()) {
            return nomAtts;
        }
        for (int i = 0; i < numInstances; ++i) {
            Instance current = data.instance(i);
            for (int j = 0; j < current.numValues(); ++j) {
                double[][] indicators = nomAtts[current.index(j)];
                if (indicators == null) {
                    continue; // numeric attribute or the class
                }
                indicators[(int) current.valueSparse(j)][i] += 1.0;
                indicators[0][i] -= 1.0;
            }
        }
        return nomAtts;
    }

    /** Scores numeric and nominal attributes against a numeric class. */
    private void scoreAgainstNumericClass(Instances data, List<Integer> numericIndexes,
            List<Integer> nominalIndexes, double[][][] nomAtts) {
        int numInstances = data.numInstances();
        double[] classVals = data.attributeToDoubleArray(data.classIndex());

        for (Integer att : numericIndexes) {
            double[] attVals = data.attributeToDoubleArray(att);
            // Note: unlike every other case, this score keeps its sign (no absolute value).
            double corr = Utils.correlation(attVals, classVals, attVals.length);
            // A constant attribute carries no information; do not let it score a perfect 1.0.
            if (corr == 1.0 && Utils.variance(attVals) == 0.0) {
                corr = 0.0;
            }
            this.m_correlations[att] = corr;
        }

        for (Integer att : nominalIndexes) {
            this.appendDetailHeader(data.attribute(att).name());
            double weightedSum = 0.0;
            double totalCount = 0.0;
            for (int v = 0; v < data.attribute(att).numValues(); ++v) {
                double[] indicator = nomAtts[att][v];
                double count = Utils.sum(indicator);
                double corr = 0.0;
                // An indicator that is all 0s or all 1s is constant and is scored 0.
                if (count != numInstances && count != 0.0) {
                    corr = Math.abs(Utils.correlation(indicator, classVals, classVals.length));
                }
                weightedSum += count * corr;
                totalCount += count;
                this.appendDetail(data.attribute(att).value(v), corr);
            }
            this.m_correlations[att] = totalCount > 0.0 ? weightedSum / totalCount : 0.0;
        }
    }

    /** Scores numeric and nominal attributes against a nominal class, one class value at a time. */
    private void scoreAgainstNominalClass(Instances data, List<Integer> numericIndexes,
            List<Integer> nominalIndexes, double[][][] nomAtts) {
        int numInstances = data.numInstances();
        int numClasses = data.classAttribute().numValues();

        // One 0/1 indicator column per class value, plus how often each class value occurs.
        double[][] binarizedClasses = new double[numClasses][numInstances];
        double[] classValCounts = new double[numClasses];
        for (int i = 0; i < numInstances; ++i) {
            int cls = (int) data.instance(i).classValue();
            binarizedClasses[cls][i] = 1.0;
            classValCounts[cls] += 1.0;
        }
        double sumClass = Utils.sum(classValCounts);

        for (Integer att : numericIndexes) {
            double[] attVals = data.attributeToDoubleArray(att);
            // Loop-invariant: the attribute's variance does not depend on the class value.
            boolean constantAttribute = Utils.variance(attVals) == 0.0;
            double weightedSum = 0.0;
            for (int k = 0; k < numClasses; ++k) {
                double corr = Math.abs(Utils.correlation(attVals, binarizedClasses[k], attVals.length));
                if (corr == 1.0 && constantAttribute) {
                    corr = 0.0;
                }
                weightedSum += classValCounts[k] * corr;
            }
            // Guarded like the nominal-attribute averages: no instances means no evidence.
            this.m_correlations[att] = sumClass > 0.0 ? weightedSum / sumClass : 0.0;
        }

        for (Integer att : nominalIndexes) {
            this.appendDetailHeader(data.attribute(att).name());
            double sumForAtt = 0.0;
            double corrForAtt = 0.0;
            for (int v = 0; v < data.attribute(att).numValues(); ++v) {
                double[] indicator = nomAtts[att][v];
                double sumForValue = Utils.sum(indicator);
                sumForAtt += sumForValue;
                double avgCorrForValue = 0.0;
                // A constant indicator (all 0s or all 1s) is scored 0 against every class value.
                if (sumForValue != numInstances && sumForValue != 0.0) {
                    double sumCorr = 0.0;
                    for (int k = 0; k < numClasses; ++k) {
                        sumCorr += classValCounts[k]
                            * Math.abs(Utils.correlation(indicator, binarizedClasses[k], numInstances));
                    }
                    avgCorrForValue = sumClass > 0.0 ? sumCorr / sumClass : 0.0;
                }
                corrForAtt += sumForValue * avgCorrForValue;
                this.appendDetail(data.attribute(att).value(v), avgCorrForValue);
            }
            this.m_correlations[att] = sumForAtt > 0.0 ? corrForAtt / sumForAtt : 0.0;
        }
    }

    /** Starts a new attribute section in the detailed report, if detailed output is enabled. */
    private void appendDetailHeader(String attributeName) {
        if (this.m_detailedOutput) {
            this.m_detailedOutputBuff.append("\n\n").append(attributeName);
        }
    }

    /** Appends one "value: score" line to the detailed report, if detailed output is enabled. */
    private void appendDetail(String valueName, double score) {
        if (!this.m_detailedOutput) {
            return;
        }
        this.m_detailedOutputBuff.append("\n\t").append(valueName).append(": ");
        this.m_detailedOutputBuff.append(Utils.doubleToString(score, 6));
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9049 $");
    }

    /**
     * Runs this evaluator from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        CorrelationAttributeEval.runEvaluator(new CorrelationAttributeEval(), args);
    }
}

