package weka.classifiers.meta;

import adams.core.io.FileUtils;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.SingleClassifierEnhancer;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.functions.LinearRegressionJ;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/* loaded from: input_file:weka/classifiers/meta/Corr.class */
public class Corr extends SingleClassifierEnhancer implements WeightedInstancesHandler {
    private static final long serialVersionUID = -8615616615151098897L;

    /** Filter reducing the data to the selected attributes, or null when no reduction was applied. */
    protected Remove m_remove;

    /** Not read by any code in this class; kept for compatibility. */
    protected int m_classIndex = -1;

    /** Requested number of attributes to keep; 0 means (numInstances + 1) / 2. */
    protected int m_k = 0;

    /**
     * 0-based indices of the kept non-class attributes in ascending order, or null when
     * the data was not reduced by {@link #keepIndices(Instances)}.
     */
    protected int[] m_subset;

    /**
     * Coefficients of the linear model (one per non-class attribute used, intercept last);
     * only set when the base classifier is a LinearRegressionJ, otherwise null.
     */
    protected double[] m_coeffs;

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the class name of LinearRegression
     */
    protected String defaultClassifierString() {
        return LinearRegression.class.getName();
    }

    /**
     * Returns a short description of this classifier.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Assume NO MISSING VALUES, all attributes must be  NUMERIC (or 0/1 maybe ...). Simple attribute selection for regression: select k most correlated attrs ...";
    }

    /**
     * Lists the options of the superclass followed by the -K option.
     *
     * @return an enumeration of all available options
     */
    public Enumeration<Option> listOptions() {
        Vector<Option> result = new Vector<Option>();
        Enumeration<?> inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            result.addElement((Option) inherited.nextElement());
        }
        result.addElement(new Option("\tThe number of attrs. (default: 0 = numEx/2)", "K", 1, "-K <int>"));
        return result.elements();
    }

    /**
     * Parses -K (number of attributes, defaulting to 0 when absent) and then the
     * superclass options.
     *
     * @param strArr the options
     * @throws Exception if an option cannot be parsed
     */
    public void setOptions(String[] strArr) throws Exception {
        String option = Utils.getOption('K', strArr);
        if (option.length() > 0) {
            setNumattrs(Integer.parseInt(option));
        } else {
            setNumattrs(0);
        }
        super.setOptions(strArr);
    }

    /**
     * Returns the current options: -K followed by the superclass options.
     *
     * @return the current option settings
     */
    public String[] getOptions() {
        ArrayList<String> result = new ArrayList<String>();
        result.add("-K");
        result.add("" + getNumattrs());
        result.addAll(Arrays.asList(super.getOptions()));
        return result.toArray(new String[result.size()]);
    }

    /**
     * @return the requested number of attributes (0 = derived from the number of instances)
     */
    public int getNumattrs() {
        return this.m_k;
    }

    /**
     * @param i the requested number of attributes (0 = derived from the number of instances)
     */
    public void setNumattrs(int i) {
        this.m_k = i;
    }

    /**
     * @return the tip text for the numattrs property
     */
    public String numattrsTipText() {
        return "The number of attributes.";
    }

    /**
     * Computes, for every attribute (the class included), the squared sample correlation
     * with the class: ((sumXY - n * meanX * meanY) / ((n - 1) * sX * sY))^2.
     * Instance weights are ignored.
     *
     * @param instances the data
     * @return one squared correlation per attribute
     */
    public double[] sampleCorrs(Instances instances) {
        double[] means = means(instances);
        double[] devs = sampleDevs(instances, means);
        double[] sumXY = sumXY(instances);
        int n = instances.numInstances();
        int classIndex = instances.classIndex();
        double[] result = new double[means.length];
        for (int i = 0; i < means.length; i++) {
            double r = (sumXY[i] - ((n * means[i]) * means[classIndex])) / (((n - 1.0d) * devs[i]) * devs[classIndex]);
            result[i] = r * r;
            if (getDebug()) {
                System.out.println("x " + i + " sXY " + sumXY[i] + " m " + means[i] + " mY " + means[classIndex] + " s " + devs[i] + " sY " + devs[classIndex]);
            }
        }
        return result;
    }

    /**
     * Computes, for every attribute, the (unweighted) sum over all instances of
     * class value times attribute value.
     *
     * @param instances the data
     * @return one sum per attribute
     */
    public double[] sumXY(Instances instances) {
        int classIndex = instances.classIndex();
        int numAttributes = instances.numAttributes();
        double[] result = new double[numAttributes];
        for (int i = 0; i < instances.numInstances(); i++) {
            Instance instance = instances.instance(i);
            double y = instance.value(classIndex);
            for (int j = 0; j < numAttributes; j++) {
                result[j] += y * instance.value(j);
            }
        }
        return result;
    }

    /**
     * Computes the (unweighted) sample standard deviation of every attribute, using n - 1
     * in the denominator.
     *
     * @param instances the data
     * @param dArr the per-attribute means; its length determines how many attributes are processed
     * @return one standard deviation per attribute
     */
    public double[] sampleDevs(Instances instances, double[] dArr) {
        int length = dArr.length;
        double[] result = new double[length];
        for (int i = 0; i < instances.numInstances(); i++) {
            Instance instance = instances.instance(i);
            for (int j = 0; j < length; j++) {
                double diff = instance.value(j) - dArr[j];
                result[j] += diff * diff;
            }
        }
        double factor = 1.0d / (instances.numInstances() - 1.0d);
        for (int j = 0; j < length; j++) {
            result[j] = Math.sqrt(result[j] * factor);
        }
        return result;
    }

    /**
     * For every non-class attribute x computes Syy - Sxy^2 / Sxx, where the sums are
     * instance-weighted and centered on the (unweighted) means from {@link #means(Instances)}.
     * This is the residual sum of squares of a simple linear regression of the class on x.
     * The class attribute, attributes with Sxx == 0, and non-finite results are reported
     * as Double.MAX_VALUE.
     *
     * @param instances the data
     * @return one value per attribute (smaller = better fit)
     */
    public double[] msq(Instances instances) {
        int numAttributes = instances.numAttributes();
        int classIndex = instances.classIndex();
        double[] means = means(instances);
        double[] sxy = new double[numAttributes];
        double[] sxx = new double[numAttributes];
        for (int i = 0; i < instances.numInstances(); i++) {
            Instance instance = instances.instance(i);
            double dy = instance.value(classIndex) - means[classIndex];
            for (int j = 0; j < numAttributes; j++) {
                double dx = instance.value(j) - means[j];
                double weighted = instance.weight() * dx;
                sxy[j] += weighted * dy;
                sxx[j] += weighted * dx;
            }
        }
        double[] result = new double[numAttributes];
        Arrays.fill(result, Double.MAX_VALUE);
        for (int j = 0; j < result.length; j++) {
            if (j != classIndex && sxx[j] != 0.0d) {
                // sxx[classIndex] is the weighted sum of squares of the class, i.e. Syy
                double rss = sxx[classIndex] - ((sxy[j] * sxy[j]) / sxx[j]);
                if (!Double.isInfinite(rss) && !Double.isNaN(rss)) {
                    result[j] = rss;
                }
            }
        }
        return result;
    }

    /**
     * Computes the (unweighted) mean of every attribute.
     *
     * @param instances the data
     * @return one mean per attribute
     */
    public double[] means(Instances instances) {
        int numAttributes = instances.numAttributes();
        double[] result = new double[numAttributes];
        for (int i = 0; i < instances.numInstances(); i++) {
            Instance instance = instances.instance(i);
            for (int j = 0; j < numAttributes; j++) {
                result[j] += instance.value(j);
            }
        }
        double factor = 1.0d / instances.numInstances();
        for (int j = 0; j < numAttributes; j++) {
            result[j] *= factor;
        }
        return result;
    }

    /**
     * Selects the attributes with the smallest {@link #msq(Instances)} values and stores
     * them in {@link #m_subset} in ascending index order.
     *
     * @param instances the data
     * @return 1-based range string for Remove: the class index first, then the selected attributes
     */
    public String keepIndices(Instances instances) {
        double[] msq = msq(instances);
        int[] order = Utils.sort(msq);
        if (getDebug()) {
            System.out.println("msq " + Arrays.toString(msq));
            System.out.println("order " + Arrays.toString(order));
        }
        int numattrs = getNumattrs();
        if (numattrs == 0) {
            numattrs = (instances.numInstances() + 1) / 2;
        }
        // never try to keep more attributes than there are non-class attributes
        numattrs = Math.max(0, Math.min(numattrs, instances.numAttributes() - 1));
        int classIndex = instances.classIndex();
        if (getDebug()) {
            System.out.println("keep " + instances.numInstances() + " => " + numattrs + " ci " + classIndex);
        }
        int[] subset = new int[numattrs];
        int count = 0;
        for (int i = 0; i < order.length && count < numattrs; i++) {
            // The class has MAX_VALUE in msq, but constant attributes do too, and the order of
            // ties is not guaranteed, so the class must be skipped explicitly.
            if (order[i] != classIndex) {
                subset[count++] = order[i];
            }
        }
        assert count == numattrs;
        // Remove keeps attributes in their original order, so the coefficients of a model
        // trained on the reduced data follow ascending attribute index. classifyInstance pairs
        // m_coeffs[i] with m_subset[i], hence the subset must be in ascending order as well.
        Arrays.sort(subset);
        StringBuilder sb = new StringBuilder("" + (1 + classIndex));
        for (int index : subset) {
            sb.append(',').append(1 + index);
        }
        if (getDebug()) {
            System.out.println("keep " + instances.numInstances() + " " + numattrs + " " + sb.toString());
            System.out.println("indices " + subset.length + " " + Arrays.toString(subset));
        }
        this.m_subset = subset;
        return sb.toString();
    }

    /**
     * Selects the attributes with the highest squared correlation with the class
     * (see {@link #sampleCorrs(Instances)}). Unlike {@link #keepIndices(Instances)}, this
     * does not modify {@link #m_subset}.
     *
     * @param instances the data
     * @return 1-based range string for Remove: the class index first, then the selected attributes
     */
    public String keepIndicesBasedOnCorrelation(Instances instances) {
        double[] corrs = sampleCorrs(instances);
        int[] order = Utils.sort(corrs);
        if (getDebug()) {
            System.out.println("corrs " + Arrays.toString(corrs));
            System.out.println("order " + Arrays.toString(order));
        }
        int numattrs = getNumattrs();
        if (numattrs == 0) {
            numattrs = (instances.numInstances() + 1) / 2;
        }
        numattrs = Math.max(0, Math.min(numattrs, instances.numAttributes() - 1));
        int classIndex = instances.classIndex();
        StringBuilder sb = new StringBuilder("" + (1 + classIndex));
        int[] kept = new int[numattrs];
        int count = 0;
        // Walk down from the highest correlation. The class (perfectly correlated with itself)
        // is skipped without consuming a slot, so numattrs real attributes are kept.
        // NOTE(review): Utils.sort places NaN correlations (e.g. constant attributes) at the
        // end, so such attributes are picked first - confirm this is intended.
        for (int i = order.length - 1; i >= 0 && count < numattrs; i--) {
            int index = order[i];
            if (index != classIndex) {
                sb.append(',').append(1 + index);
                kept[count++] = index;
            }
        }
        if (getDebug()) {
            System.out.println("keep " + instances.numInstances() + " " + numattrs + " " + sb.toString());
            System.out.println("indices " + kept.length + " " + Arrays.toString(kept));
        }
        return sb.toString();
    }

    /**
     * Reduces the data to the attributes chosen by {@link #keepIndices(Instances)} when
     * reduction applies. Returns the data unchanged (and clears {@link #m_remove} and
     * {@link #m_subset}) when K exceeds the number of non-class attributes, or when K is 0
     * and there are at least twice as many instances as non-class attributes.
     *
     * @param instances the data
     * @return the (possibly) reduced data
     * @throws Exception if filtering fails
     */
    public Instances getSubset(Instances instances) throws Exception {
        this.m_remove = null;
        // clear any subset from a previous build so classifyInstance does not use it
        this.m_subset = null;
        if (instances.numAttributes() - 1 < getNumattrs()) {
            return instances;
        }
        if (getNumattrs() == 0 && instances.numInstances() >= 2 * (instances.numAttributes() - 1)) {
            return instances;
        }
        this.m_remove = new Remove();
        this.m_remove.setOptions(new String[]{"-R", keepIndices(instances), "-V"});
        this.m_remove.setInputFormat(instances);
        return Filter.useFilter(instances, this.m_remove);
    }

    /**
     * Builds the base classifier on the (possibly) reduced data. When the base classifier
     * is a LinearRegressionJ, its coefficients are cached in {@link #m_coeffs}
     * (non-class coefficients in attribute order, intercept last).
     *
     * @param instances the training data
     * @throws Exception if building fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        // discard coefficients from a previous build, e.g. with a different base classifier
        this.m_coeffs = null;
        Instances subset = getSubset(instances);
        this.m_Classifier.buildClassifier(subset);
        if (this.m_Classifier instanceof LinearRegressionJ) {
            double[] coefficients = ((LinearRegressionJ) this.m_Classifier).coefficients();
            if (getDebug()) {
                System.out.println("coeffs " + coefficients.length + " " + Arrays.toString(coefficients));
                System.out.println(this.m_Classifier);
            }
            // coefficients holds one slot per attribute of subset (the class slot is skipped)
            // followed by the intercept
            int classIndex = subset.classIndex();
            double[] coeffs = new double[coefficients.length - 1];
            int count = 0;
            for (int i = 0; i < coeffs.length; i++) {
                if (i != classIndex) {
                    coeffs[count++] = coefficients[i];
                }
            }
            coeffs[count] = coefficients[coeffs.length];
            this.m_coeffs = coeffs;
        }
    }

    /**
     * Serializes the given object to the file "xxx.ser" in the working directory.
     *
     * @param obj the object to serialize
     * @throws Exception if writing fails
     */
    public void saveObject(Object obj) throws Exception {
        saveObject(obj, "xxx.ser");
    }

    /**
     * Serializes the given object to the given file, closing the streams even on failure.
     *
     * @param obj the object to serialize
     * @param filename the file to write
     * @throws Exception if writing fails
     */
    public void saveObject(Object obj, String filename) throws Exception {
        FileOutputStream fileOutputStream = new FileOutputStream(filename);
        ObjectOutputStream objectOutputStream = null;
        try {
            objectOutputStream = new ObjectOutputStream(fileOutputStream);
            objectOutputStream.writeObject(obj);
            // flush explicitly so write errors are not swallowed by the quiet close below
            objectOutputStream.flush();
        } finally {
            if (objectOutputStream != null) {
                FileUtils.closeQuietly(objectOutputStream);
            }
            FileUtils.closeQuietly(fileOutputStream);
        }
    }

    /**
     * Predicts the class value. Uses the cached linear coefficients when available:
     * over {@link #m_subset} if the data was reduced, otherwise over all non-class
     * attributes. Otherwise delegates to the base classifier, filtering the instance first
     * if a Remove filter was used during training.
     *
     * @param instance the instance to classify
     * @return the predicted value
     * @throws Exception if prediction fails
     */
    public double classifyInstance(Instance instance) throws Exception {
        if (this.m_coeffs == null) {
            if (this.m_remove != null) {
                this.m_remove.input(instance);
                this.m_remove.batchFinished();
                instance = this.m_remove.output();
            }
            return this.m_Classifier.classifyInstance(instance);
        }
        if (this.m_subset != null) {
            double prediction = this.m_coeffs[this.m_coeffs.length - 1];
            for (int i = 0; i < this.m_subset.length; i++) {
                prediction += this.m_coeffs[i] * instance.value(this.m_subset[i]);
            }
            return prediction;
        }
        int classIndex = instance.classIndex();
        int count = 0;
        double prediction = 0.0d;
        for (int i = 0; i < instance.numAttributes(); i++) {
            if (i != classIndex) {
                prediction += this.m_coeffs[count++] * instance.value(i);
            }
        }
        return prediction + this.m_coeffs[count];
    }

    /**
     * @return a copy of the cached linear coefficients (intercept last), or null if none
     */
    public double[] getCoeffs() {
        if (this.m_coeffs == null) {
            return null;
        }
        return (double[]) this.m_coeffs.clone();
    }

    /**
     * @return a copy of the selected attribute indices (0-based, ascending), or null if the
     *         data was not reduced
     */
    public int[] getSubset() {
        if (this.m_subset == null) {
            return null;
        }
        return (int[]) this.m_subset.clone();
    }

    /**
     * @return the string representation of the base classifier
     */
    @Override
    public String toString() {
        return this.m_Classifier.toString();
    }

    /**
     * @return the revision string
     */
    public String getRevision() {
        return "$Revision$";
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param strArr the command-line options
     * @throws Exception if running fails
     */
    public static void main(String[] strArr) throws Exception {
        runClassifier(new Corr(), strArr);
    }
}
