/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.RegressionAnalysis;
import weka.classifiers.functions.SimpleLinearRegression;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Simple linear regression on a single attribute, with accessors for the
 * fitted slope, intercept and (optionally) their standard errors.
 *
 * <p>For every non-class attribute a weighted least-squares line is fitted
 * using the instances where both the attribute and the class are present.
 * The attribute with the lowest total squared error is kept; that error also
 * includes the error made by predicting the class mean for instances whose
 * attribute value is missing.
 *
 * <p>When {@code -additional-stats} is set, standard errors, t-statistics,
 * R^2, adjusted R^2 and the F-statistic are computed via
 * {@link RegressionAnalysis}; this is only accepted when every instance has
 * weight 1.
 */
public class SimpleLinearRegressionWithAccess
extends AbstractClassifier
implements WeightedInstancesHandler {
    static final long serialVersionUID = 1679336022895414137L;

    /** The chosen attribute, or {@code null} if no useful attribute was found. */
    protected Attribute m_attribute;
    /** Index of the chosen attribute (0 when none was found). */
    protected int m_attributeIndex;
    /** Slope of the fitted line. */
    protected double m_slope;
    /** Intercept of the fitted line, or the constant prediction when no attribute was chosen. */
    protected double m_intercept;
    /** Prediction used when the chosen attribute's value is missing. */
    protected double m_classMeanForMissing;
    /** Whether additional statistics are computed and printed. */
    protected boolean m_outputAdditionalStats;
    /** Degrees of freedom used for the additional statistics. */
    protected int m_df;
    /** Standard error of the slope; NaN unless additional statistics were computed. */
    protected double m_seSlope = Double.NaN;
    /** Standard error of the intercept; NaN unless additional statistics were computed. */
    protected double m_seIntercept = Double.NaN;
    /** t-statistic of the slope; NaN unless additional statistics were computed. */
    protected double m_tstatSlope = Double.NaN;
    /** t-statistic of the intercept; NaN unless additional statistics were computed. */
    protected double m_tstatIntercept = Double.NaN;
    /** R^2 value; NaN unless additional statistics were computed. */
    protected double m_rsquared = Double.NaN;
    /** Adjusted R^2 value; NaN unless additional statistics were computed. */
    protected double m_rsquaredAdj = Double.NaN;
    /** F-statistic; NaN unless additional statistics were computed. */
    protected double m_fstat = Double.NaN;
    /** If true, the "no useful attribute found" message is not printed. */
    protected boolean m_suppressErrorMessage = false;

    /**
     * Returns a short description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Learns a simple linear regression model. Picks the attribute that results in the lowest squared error. Can only deal with numeric attributes.\nMakes standard errors available.";
    }

    /**
     * Lists the {@code -additional-stats} option followed by the options of
     * the superclass.
     *
     * @return an enumeration of all available options
     */
    public Enumeration<Option> listOptions() {
        // Element type must be Option so that elements() matches the return type.
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tOutput additional statistics.", "additional-stats", 0, "-additional-stats"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Parses the {@code -additional-stats} flag, hands the options to the
     * superclass and then checks for options that nobody consumed.
     *
     * @param options the option array to parse
     * @throws Exception if the superclass or the remaining-options check rejects the options
     */
    public void setOptions(String[] options) throws Exception {
        this.setOutputAdditionalStats(Utils.getFlag("additional-stats", options));
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings: {@code -additional-stats} if enabled,
     * followed by the superclass options.
     *
     * @return the option array
     */
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.getOutputAdditionalStats()) {
            result.add("-additional-stats");
        }
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[result.size()]);
    }

    /**
     * Returns the tip text for the additional-statistics property.
     *
     * @return the tip text
     */
    public String outputAdditionalStatsTipText() {
        return "Output additional statistics (such as std deviation of coefficients and t-statistics)";
    }

    /**
     * Sets whether additional statistics are computed and printed.
     *
     * @param additional true to enable additional statistics
     */
    public void setOutputAdditionalStats(boolean additional) {
        this.m_outputAdditionalStats = additional;
    }

    /**
     * Returns whether additional statistics are computed and printed.
     *
     * @return true if additional statistics are enabled
     */
    public boolean getOutputAdditionalStats() {
        return this.m_outputAdditionalStats;
    }

    /**
     * Predicts the class value for an instance.
     *
     * @param inst the instance to predict
     * @return the intercept if no attribute was chosen; the stored class mean
     *         for missing values if the chosen attribute is missing in
     *         {@code inst}; otherwise {@code intercept + slope * value}
     * @throws Exception declared by the classifier interface
     */
    public double classifyInstance(Instance inst) throws Exception {
        if (this.m_attribute == null) {
            return this.m_intercept;
        }
        if (inst.isMissing(this.m_attributeIndex)) {
            return this.m_classMeanForMissing;
        }
        return this.m_intercept + this.m_slope * inst.value(this.m_attributeIndex);
    }

    /**
     * Returns the capabilities of this classifier: numeric and date
     * attributes and classes, with missing values allowed in both.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Builds the model: fits one weighted regression line per attribute and
     * keeps the one with the lowest squared error.
     *
     * <p>Instances with a missing class value are ignored. If no attribute
     * has any variance among its known values, the model predicts the
     * weighted class mean as a constant.
     *
     * @param insts the training data
     * @throws Exception if the data fails the capabilities test, if additional
     *         statistics are requested on data with non-unit weights, or if
     *         computing the additional statistics fails
     */
    public void buildClassifier(Instances insts) throws Exception {
        this.getCapabilities().testWithFail(insts);
        if (this.m_outputAdditionalStats) {
            this.requireUnitWeights(insts);
        }
        // Drop statistics of any previously built model so the accessors never
        // report values that belong to an older fit.
        this.resetAdditionalStats();

        int numAtts = insts.numAttributes();
        int numInsts = insts.numInstances();
        int classIdx = insts.classIndex();

        // Pass 1: per-attribute weighted sums over known values, and class
        // sums over the instances where the attribute is missing.
        double[] sum = new double[numAtts];
        double[] count = new double[numAtts];
        double[] classSumForMissing = new double[numAtts];
        double[] classSumSquaredForMissing = new double[numAtts];
        double classCount = 0.0;
        double classSum = 0.0;
        for (int j = 0; j < numInsts; ++j) {
            Instance inst = insts.instance(j);
            if (inst.classIsMissing()) {
                continue;
            }
            double weight = inst.weight();
            double classValue = inst.classValue();
            for (int i = 0; i < numAtts; ++i) {
                if (!inst.isMissing(i)) {
                    sum[i] += weight * inst.value(i);
                    count[i] += weight;
                } else {
                    classSumForMissing[i] += classValue * weight;
                    classSumSquaredForMissing[i] += classValue * classValue * weight;
                }
            }
            classCount += weight;
            classSum += weight * classValue;
        }

        // Means derived from pass 1; entries stay 0 where the divisor is not positive.
        double[] mean = new double[numAtts];
        double[] classMeanForMissing = new double[numAtts];
        double[] classMeanForKnown = new double[numAtts];
        for (int i = 0; i < numAtts; ++i) {
            if (i == classIdx) {
                continue;
            }
            if (count[i] > 0.0) {
                mean[i] = sum[i] / count[i];
            }
            if (classCount - count[i] > 0.0) {
                classMeanForMissing[i] = classSumForMissing[i] / (classCount - count[i]);
            }
            if (count[i] > 0.0) {
                classMeanForKnown[i] = (classSum - classSumForMissing[i]) / count[i];
            }
        }

        // Pass 2: weighted cross products and squared deviations per attribute.
        double[] slopes = new double[numAtts];
        double[] sumWeightedDiffsSquared = new double[numAtts];
        double[] sumWeightedClassDiffsSquared = new double[numAtts];
        for (int j = 0; j < numInsts; ++j) {
            Instance inst = insts.instance(j);
            if (inst.classIsMissing()) {
                continue;
            }
            for (int i = 0; i < numAtts; ++i) {
                if (inst.isMissing(i) || i == classIdx) {
                    continue;
                }
                double yDiff = inst.classValue() - classMeanForKnown[i];
                double weightedYDiff = inst.weight() * yDiff;
                double diff = inst.value(i) - mean[i];
                double weightedDiff = inst.weight() * diff;
                slopes[i] += weightedYDiff * diff;
                sumWeightedDiffsSquared[i] += weightedDiff * diff;
                sumWeightedClassDiffsSquared[i] += weightedYDiff * yDiff;
            }
        }

        // Pick the attribute with the smallest total squared error.
        double minSSE = Double.MAX_VALUE;
        this.m_attribute = null;
        int chosen = -1;
        double chosenSlope = Double.NaN;
        double chosenIntercept = Double.NaN;
        double chosenMeanForMissing = Double.NaN;
        for (int i = 0; i < numAtts; ++i) {
            // Attributes without variance among known values cannot be regressed on.
            if (i == classIdx || sumWeightedDiffsSquared[i] == 0.0) {
                continue;
            }
            double sseForMissing = classSumSquaredForMissing[i] - classSumForMissing[i] * classMeanForMissing[i];
            double numerator = slopes[i];
            double slope = numerator / sumWeightedDiffsSquared[i];
            double intercept = classMeanForKnown[i] - slope * mean[i];
            double sse = sumWeightedClassDiffsSquared[i] - slope * numerator;
            sse += sseForMissing;
            if (sse < minSSE) {
                minSSE = sse;
                chosen = i;
                chosenSlope = slope;
                chosenIntercept = intercept;
                chosenMeanForMissing = classMeanForMissing[i];
            }
        }

        if (chosen == -1) {
            if (!this.m_suppressErrorMessage) {
                System.err.println("----- no useful attribute found");
            }
            this.m_attribute = null;
            this.m_attributeIndex = 0;
            this.m_slope = 0.0;
            // Evaluates to NaN (0.0 / 0.0) when no instance has a class value.
            this.m_intercept = classSum / classCount;
            this.m_classMeanForMissing = 0.0;
            return;
        }

        this.m_attribute = insts.attribute(chosen);
        this.m_attributeIndex = chosen;
        this.m_slope = chosenSlope;
        this.m_intercept = chosenIntercept;
        this.m_classMeanForMissing = chosenMeanForMissing;
        if (this.m_outputAdditionalStats) {
            this.computeAdditionalStats(insts);
        }
    }

    /** Throws if any instance has a weight other than 1. */
    private void requireUnitWeights(Instances insts) throws Exception {
        for (int i = 0; i < insts.numInstances(); ++i) {
            if (insts.instance(i).weight() != 1.0) {
                throw new Exception("Can only compute additional statistics on unweighted data");
            }
        }
    }

    /** Restores the additional-statistics fields to their initial values. */
    private void resetAdditionalStats() {
        this.m_df = 0;
        this.m_seSlope = Double.NaN;
        this.m_seIntercept = Double.NaN;
        this.m_tstatSlope = Double.NaN;
        this.m_tstatIntercept = Double.NaN;
        this.m_rsquared = Double.NaN;
        this.m_rsquaredAdj = Double.NaN;
        this.m_fstat = Double.NaN;
    }

    /**
     * Computes standard errors, t-statistics, R^2, adjusted R^2 and the
     * F-statistic for the already chosen attribute, using only instances
     * where both the class and that attribute are present.
     */
    private void computeAdditionalStats(Instances insts) throws Exception {
        Instances usable = new Instances(insts, insts.numInstances());
        for (int i = 0; i < insts.numInstances(); ++i) {
            Instance inst = insts.instance(i);
            if (inst.classIsMissing() || inst.isMissing(this.m_attributeIndex)) {
                continue;
            }
            usable.add(inst);
        }
        // Two estimated coefficients: slope and intercept.
        this.m_df = usable.numInstances() - 2;
        double[] stdErrors = RegressionAnalysis.calculateStdErrorOfCoef(usable, this.m_attribute, this.m_slope, this.m_intercept, this.m_df);
        this.m_seSlope = stdErrors[0];
        this.m_seIntercept = stdErrors[1];
        double[] coef = new double[]{this.m_slope, this.m_intercept};
        double[] tStats = RegressionAnalysis.calculateTStats(coef, stdErrors, 2);
        this.m_tstatSlope = tStats[0];
        this.m_tstatIntercept = tStats[1];
        double ssr = RegressionAnalysis.calculateSSR(usable, this.m_attribute, this.m_slope, this.m_intercept);
        this.m_rsquared = RegressionAnalysis.calculateRSquared(usable, ssr);
        this.m_rsquaredAdj = RegressionAnalysis.calculateAdjRSquared(this.m_rsquared, usable.numInstances(), 2);
        this.m_fstat = RegressionAnalysis.calculateFStat(this.m_rsquared, usable.numInstances(), 2);
    }

    /**
     * Returns whether the last build selected an attribute.
     *
     * @return true if an attribute was chosen
     */
    public boolean foundUsefulAttribute() {
        return this.m_attribute != null;
    }

    /**
     * Returns the index of the chosen attribute.
     *
     * @return the attribute index (0 when no attribute was chosen)
     */
    public int getAttributeIndex() {
        return this.m_attributeIndex;
    }

    /**
     * Returns the slope of the fitted line.
     *
     * @return the slope
     */
    public double getSlope() {
        return this.m_slope;
    }

    /**
     * Returns the standard error of the slope.
     *
     * @return the standard error, or NaN unless the most recent build computed
     *         additional statistics
     */
    public double getSlopeSE() {
        return this.m_seSlope;
    }

    /**
     * Returns the intercept of the fitted line.
     *
     * @return the intercept
     */
    public double getIntercept() {
        return this.m_intercept;
    }

    /**
     * Returns the standard error of the intercept.
     *
     * @return the standard error, or NaN unless the most recent build computed
     *         additional statistics
     */
    public double getInterceptSE() {
        return this.m_seIntercept;
    }

    /**
     * Sets whether the "no useful attribute found" message is suppressed.
     *
     * @param s true to suppress the message
     */
    public void setSuppressErrorMessage(boolean s) {
        this.m_suppressErrorMessage = s;
    }

    /**
     * Returns a textual description of the model, including the regression
     * analysis table when additional statistics are enabled.
     *
     * @return the model description
     */
    public String toString() {
        StringBuilder text = new StringBuilder();
        if (this.m_attribute == null) {
            text.append("Predicting constant " + this.m_intercept);
        } else {
            String attName = this.m_attribute.name();
            text.append("Linear regression on " + attName + "\n\n");
            text.append(Utils.doubleToString(this.m_slope, 2) + " * " + attName);
            if (this.m_intercept > 0.0) {
                text.append(" + " + Utils.doubleToString(this.m_intercept, 2));
            } else {
                text.append(" - " + Utils.doubleToString(-this.m_intercept, 2));
            }
            text.append("\n\nPredicting " + Utils.doubleToString(this.m_classMeanForMissing, 2) + " if attribute value is missing.");
            if (this.m_outputAdditionalStats) {
                // First column is as wide as the longer of the attribute name and "Variable", plus padding.
                int attNameLength = Math.max(attName.length(), "Variable".length()) + 3;
                text.append("\n\nRegression Analysis:\n\n" + Utils.padRight("Variable", attNameLength) + "  Coefficient     SE of Coef        t-Stat");
                text.append("\n" + Utils.padRight(attName, attNameLength));
                text.append(Utils.doubleToString(this.m_slope, 12, 4));
                text.append("   " + Utils.doubleToString(this.m_seSlope, 12, 5));
                text.append("   " + Utils.doubleToString(this.m_tstatSlope, 12, 5));
                // One extra column of padding accounts for the leading newline in the label.
                text.append(Utils.padRight("\nconst", attNameLength + 1) + Utils.doubleToString(this.m_intercept, 12, 4));
                text.append("   " + Utils.doubleToString(this.m_seIntercept, 12, 5));
                text.append("   " + Utils.doubleToString(this.m_tstatIntercept, 12, 5));
                text.append("\n\nDegrees of freedom = " + Integer.toString(this.m_df));
                text.append("\nR^2 value = " + Utils.doubleToString(this.m_rsquared, 5));
                text.append("\nAdjusted R^2 = " + Utils.doubleToString(this.m_rsquaredAdj, 5));
                text.append("\nF-statistic = " + Utils.doubleToString(this.m_fstat, 5));
            }
        }
        text.append("\n");
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 11130 $");
    }

    /**
     * Runs this classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        // Run this class, not the plain SimpleLinearRegression it was derived from.
        SimpleLinearRegressionWithAccess.runClassifier(new SimpleLinearRegressionWithAccess(), argv);
    }
}

