/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import adams.core.io.FileUtils;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.SingleClassifierEnhancer;
import weka.classifiers.functions.LinearRegression;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

public class Corr
extends SingleClassifierEnhancer
implements WeightedInstancesHandler {
    private static final long serialVersionUID = -8615616615151098897L;
    protected Remove m_remove;
    protected int m_classIndex = -1;
    protected int m_k = 0;
    protected int[] m_subset;
    protected double[] m_coeffs;

    protected String defaultClassifierString() {
        return LinearRegression.class.getName();
    }

    public String globalInfo() {
        return "Assume NO MISSING VALUES, all attributes must be  NUMERIC (or 0/1 maybe ...). Simple attribute selection for regression: select k most correlated attrs ...";
    }

    public Enumeration listOptions() {
        Vector newVector = new Vector();
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement(enu.nextElement());
        }
        newVector.addElement(new Option("\tThe number of attrs. (default: 0 = numEx/2)", "K", 1, "-K <int>"));
        return newVector.elements();
    }

    public void setOptions(String[] options) throws Exception {
        String tmpStr = Utils.getOption((char)'K', (String[])options);
        if (tmpStr.length() > 0) {
            this.setNumattrs(Integer.parseInt(tmpStr));
        } else {
            this.setNumattrs(0);
        }
        super.setOptions(options);
    }

    public String[] getOptions() {
        String[] options;
        ArrayList<String> result = new ArrayList<String>();
        result.add("-K");
        result.add("" + this.getNumattrs());
        for (String option : options = super.getOptions()) {
            result.add(option);
        }
        return result.toArray(new String[result.size()]);
    }

    public int getNumattrs() {
        return this.m_k;
    }

    public void setNumattrs(int k) {
        this.m_k = k;
    }

    public String numattrsTipText() {
        return "The number of attributes.";
    }

    public double[] sampleCorrs(Instances data) {
        double[] mean = this.means(data);
        double[] s = this.sampleDevs(data, mean);
        double[] sumXY = this.sumXY(data);
        int n = data.numInstances();
        int k = mean.length;
        double[] corrs = new double[k];
        int yIndex = data.classIndex();
        for (int i = 0; i < k; ++i) {
            double r = (sumXY[i] - (double)n * mean[i] * mean[yIndex]) / (((double)n - 1.0) * s[i] * s[yIndex]);
            corrs[i] = r * r;
            if (!this.getDebug()) continue;
            System.out.println("x " + i + " sXY " + sumXY[i] + " m " + mean[i] + " mY " + mean[yIndex] + " s " + s[i] + " sY " + s[yIndex]);
        }
        return corrs;
    }

    public double[] sumXY(Instances data) {
        int classIndex = data.classIndex();
        int k = data.numAttributes();
        double[] sum = new double[k];
        for (int i = 0; i < data.numInstances(); ++i) {
            Instance instance = data.instance(i);
            double classValue = instance.value(classIndex);
            for (int j = 0; j < k; ++j) {
                int n = j;
                sum[n] = sum[n] + classValue * instance.value(j);
            }
        }
        return sum;
    }

    public double[] sampleDevs(Instances data, double[] mean) {
        int j;
        int k = mean.length;
        double[] sum = new double[k];
        for (int i = 0; i < data.numInstances(); ++i) {
            Instance instance = data.instance(i);
            j = 0;
            while (j < k) {
                double delta = instance.value(j) - mean[j];
                int n = j++;
                sum[n] = sum[n] + delta * delta;
            }
        }
        double factor = 1.0 / ((double)data.numInstances() - 1.0);
        for (j = 0; j < k; ++j) {
            sum[j] = Math.sqrt(sum[j] * factor);
        }
        return sum;
    }

    public double[] msq(Instances data) {
        int k = data.numAttributes();
        int classIndex = data.classIndex();
        double[] mean = this.means(data);
        double[] slope = new double[k];
        double[] sumWeightedDiffSquared = new double[k];
        for (int j = 0; j < data.numInstances(); ++j) {
            Instance inst = data.instance(j);
            double yDiff = inst.value(classIndex) - mean[classIndex];
            int i = 0;
            while (i < k) {
                double xDiff = inst.value(i) - mean[i];
                double weightedXDiff = inst.weight() * xDiff;
                int n = i;
                slope[n] = slope[n] + weightedXDiff * yDiff;
                int n2 = i++;
                sumWeightedDiffSquared[n2] = sumWeightedDiffSquared[n2] + weightedXDiff * xDiff;
            }
        }
        double[] msq = new double[k];
        Arrays.fill(msq, Double.MAX_VALUE);
        for (int i = 0; i < msq.length; ++i) {
            double sse;
            if (i == classIndex || sumWeightedDiffSquared[i] == 0.0 || Double.isInfinite(sse = sumWeightedDiffSquared[classIndex] - slope[i] * slope[i] / sumWeightedDiffSquared[i]) || Double.isNaN(sse)) continue;
            msq[i] = sse;
        }
        return msq;
    }

    public double[] means(Instances data) {
        int j;
        int k = data.numAttributes();
        double[] sum = new double[k];
        for (int i = 0; i < data.numInstances(); ++i) {
            Instance instance = data.instance(i);
            for (j = 0; j < k; ++j) {
                int n = j;
                sum[n] = sum[n] + instance.value(j);
            }
        }
        double factor = 1.0 / (double)data.numInstances();
        j = 0;
        while (j < k) {
            int n = j++;
            sum[n] = sum[n] * factor;
        }
        return sum;
    }

    public String keepIndices(Instances data) {
        int numKeep;
        double[] msq = this.msq(data);
        int[] ascending = Utils.sort((double[])msq);
        if (this.getDebug()) {
            System.out.println("msq " + Arrays.toString(msq));
            System.out.println("order " + Arrays.toString(ascending));
        }
        if ((numKeep = this.getNumattrs()) == 0) {
            numKeep = (data.numInstances() + 1) / 2;
        }
        int classIndex = data.classIndex();
        if (this.getDebug()) {
            System.out.println("keep " + data.numInstances() + " => " + numKeep + " ci " + classIndex);
        }
        StringBuilder sb = new StringBuilder("" + (1 + classIndex));
        int[] indices = new int[numKeep];
        for (int i = 0; i < numKeep; ++i) {
            int index = ascending[i];
            assert (index != classIndex);
            sb.append("," + (1 + index));
            indices[i] = index;
        }
        if (this.getDebug()) {
            System.out.println("keep " + data.numInstances() + " " + numKeep + " " + sb.toString());
            System.out.println("indices " + indices.length + " " + Arrays.toString(indices));
        }
        this.m_subset = indices;
        return sb.toString();
    }

    public String keepIndicesBasedOnCorrelation(Instances data) {
        int numKeep;
        double[] corrs = this.sampleCorrs(data);
        int[] ascending = Utils.sort((double[])corrs);
        if (this.getDebug()) {
            System.out.println("corrs " + Arrays.toString(corrs));
            System.out.println("order " + Arrays.toString(ascending));
        }
        if ((numKeep = this.getNumattrs()) == 0) {
            numKeep = (data.numInstances() + 1) / 2;
        }
        int classIndex = data.classIndex();
        StringBuilder sb = new StringBuilder("" + (1 + classIndex));
        int[] indices = new int[numKeep];
        int offset = 0;
        for (int i = ascending.length - numKeep; i < ascending.length; ++i) {
            int index = ascending[i];
            if (index == classIndex) continue;
            sb.append("," + (1 + index));
            indices[offset++] = index;
        }
        if (this.getDebug()) {
            System.out.println("keep " + data.numInstances() + " " + numKeep + " " + sb.toString());
            System.out.println("indices " + indices.length + " " + Arrays.toString(indices));
        }
        return sb.toString();
    }

    public Instances getSubset(Instances data) throws Exception {
        this.m_remove = null;
        if (data.numAttributes() - 1 < this.getNumattrs()) {
            return data;
        }
        if (this.getNumattrs() == 0 && data.numInstances() >= 2 * (data.numAttributes() - 1)) {
            return data;
        }
        this.m_remove = new Remove();
        String toKeep = this.keepIndices(data);
        this.m_remove.setOptions(new String[]{"-R", toKeep, "-V"});
        this.m_remove.setInputFormat(data);
        return Filter.useFilter((Instances)data, (Filter)this.m_remove);
    }

    public void buildClassifier(Instances data) throws Exception {
        Instances train = this.getSubset(data);
        this.m_Classifier.buildClassifier(train);
        if (this.m_Classifier instanceof LinearRegression) {
            double[] coeffs = ((LinearRegression)this.m_Classifier).coefficients();
            if (this.getDebug()) {
                System.out.println("coeffs " + coeffs.length + " " + Arrays.toString(coeffs));
            }
            if (this.getDebug()) {
                System.out.println(this.m_Classifier);
            }
            this.m_coeffs = new double[coeffs.length - 1];
            int offset = 0;
            for (int i = 0; i < this.m_coeffs.length; ++i) {
                if (i == train.classIndex()) continue;
                this.m_coeffs[offset++] = coeffs[i];
            }
            this.m_coeffs[offset] = coeffs[this.m_coeffs.length];
        }
    }

    public void saveObject(Object o) throws Exception {
        FileOutputStream fos = new FileOutputStream("xxx.ser");
        ObjectOutputStream oos = new ObjectOutputStream(fos);
        oos.writeObject(o);
        FileUtils.closeQuietly((OutputStream)oos);
        FileUtils.closeQuietly((OutputStream)fos);
    }

    public double classifyInstance(Instance instance) throws Exception {
        if (this.m_coeffs != null) {
            if (this.m_subset != null) {
                double sum = this.m_coeffs[this.m_coeffs.length - 1];
                for (int i = 0; i < this.m_subset.length; ++i) {
                    sum += this.m_coeffs[i] * instance.value(this.m_subset[i]);
                }
                return sum;
            }
            int offset = 0;
            double sum = 0.0;
            for (int i = 0; i < instance.numAttributes(); ++i) {
                if (i == instance.classIndex()) continue;
                sum += this.m_coeffs[offset++] * instance.value(i);
            }
            return sum += this.m_coeffs[offset];
        }
        if (this.m_remove != null) {
            this.m_remove.input(instance);
            this.m_remove.batchFinished();
            instance = this.m_remove.output();
        }
        return this.m_Classifier.classifyInstance(instance);
    }

    public double[] getCoeffs() {
        return (double[])this.m_coeffs.clone();
    }

    public int[] getSubset() {
        if (this.m_subset == null) {
            return null;
        }
        return (int[])this.m_subset.clone();
    }

    public String toString() {
        return this.m_Classifier.toString();
    }

    public String getRevision() {
        return "$Revision: 10824 $";
    }

    public static void main(String[] argv) throws Exception {
        Corr.runClassifier((Classifier)new Corr(), (String[])argv);
    }
}

