/*
 * Decompiled with CFR 0.152.
 */
package weka.estimators;

import java.util.Random;
import java.util.Vector;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.matrix.Matrix;
import weka.estimators.ConditionalEstimator;
import weka.estimators.Estimator;
import weka.estimators.MahalanobisEstimator;

/**
 * Conditional estimator that summarises all observed (given, data) pairs with a
 * single weighted 2x2 covariance matrix and answers conditional queries through
 * a {@link MahalanobisEstimator}.
 *
 * <p>Pairs are stored in three parallel vectors kept sorted lexicographically by
 * (conditioning value, data value). Adding a pair identical to a stored one
 * merges it by summing the weights.
 */
public class NNConditionalEstimator
implements ConditionalEstimator {
    /** Observed data values; parallel to m_CondValues and m_Weights. */
    private Vector<Double> m_Values = new Vector<Double>();
    /** Observed conditioning values; the primary sort key. */
    private Vector<Double> m_CondValues = new Vector<Double>();
    /** Accumulated weight of each distinct (given, data) pair. */
    private Vector<Double> m_Weights = new Vector<Double>();
    /** Total of every weight passed to addValue. */
    private double m_SumOfWeights;
    /** Weighted mean of the conditioning values, set by calculateCovariance(). */
    private double m_CondMean;
    /** Weighted mean of the data values, set by calculateCovariance(). */
    private double m_ValueMean;
    /** Cached covariance of (data, given); null when it must be recomputed. */
    private Matrix m_Covariance;
    /** Cleared once any stored pair ends up with a weight other than 1 (not read in this class). */
    private boolean m_AllWeightsOne = true;
    // Intentionally left non-final: a final compile-time constant would remove the
    // class's static initializer, which can alter an implicitly computed serialVersionUID.
    private static double TWO_PI = Math.PI * 2;

    /**
     * Binary-searches the stored pairs, which are sorted by (conditioning value,
     * data value), for the given pair.
     *
     * @param key the conditioning value (primary sort key)
     * @param secondaryKey the data value (secondary sort key)
     * @return the index of an identical stored pair if one is found, otherwise the
     *         index at which the pair must be inserted to keep the order
     */
    private int findNearestPair(double key, double secondaryKey) {
        int low = 0;
        int high = this.m_CondValues.size();
        while (low < high) {
            int middle = (low + high) >>> 1;
            double current = this.m_CondValues.elementAt(middle).doubleValue();
            double secondary = this.m_Values.elementAt(middle).doubleValue();
            if (current == key && secondary == secondaryKey) {
                return middle;
            }
            // Each branch strictly shrinks [low, high), so the loop terminates even
            // when a NaN makes every comparison false (the previous form spun forever).
            if (current < key || (current == key && secondary < secondaryKey)) {
                low = middle + 1;
            } else {
                high = middle;
            }
        }
        return low;
    }

    /**
     * Recomputes the weighted means and the weighted sample covariance matrix of
     * (data value, conditioning value) over all stored pairs. Row/column 0 is the
     * data value and row/column 1 is the conditioning value. Entries are divided
     * by (sum of weights - 1).
     */
    private void calculateCovariance() {
        int count = this.m_Values.size();
        double sumValues = 0.0;
        double sumConds = 0.0;
        for (int i = 0; i < count; ++i) {
            double weight = this.m_Weights.elementAt(i).doubleValue();
            sumValues += this.m_Values.elementAt(i).doubleValue() * weight;
            sumConds += this.m_CondValues.elementAt(i).doubleValue() * weight;
        }
        this.m_ValueMean = sumValues / this.m_SumOfWeights;
        this.m_CondMean = sumConds / this.m_SumOfWeights;
        double c00 = 0.0;
        double c01 = 0.0;
        double c11 = 0.0;
        for (int i = 0; i < count; ++i) {
            double x = this.m_Values.elementAt(i).doubleValue();
            double y = this.m_CondValues.elementAt(i).doubleValue();
            double weight = this.m_Weights.elementAt(i).doubleValue();
            c00 += (x - this.m_ValueMean) * (x - this.m_ValueMean) * weight;
            c01 += (x - this.m_ValueMean) * (y - this.m_CondMean) * weight;
            c11 += (y - this.m_CondMean) * (y - this.m_CondMean) * weight;
        }
        // NOTE(review): when the total weight is <= 1 this denominator is zero or
        // negative, producing non-finite or negative (co)variances.
        double denominator = this.m_SumOfWeights - 1.0;
        c00 /= denominator;
        c01 /= denominator;
        c11 /= denominator;
        this.m_Covariance = new Matrix(2, 2);
        this.m_Covariance.set(0, 0, c00);
        this.m_Covariance.set(0, 1, c01);
        this.m_Covariance.set(1, 0, c01); // symmetric
        this.m_Covariance.set(1, 1, c11);
    }

    /**
     * Zero-mean Gaussian density evaluated at {@code x}. Not referenced elsewhere
     * in this class.
     *
     * @param x the point at which to evaluate the density
     * @param variance the variance of the Gaussian
     * @return the density value
     */
    private double normalKernel(double x, double variance) {
        return Math.exp(-x * x / (2.0 * variance)) / Math.sqrt(variance * TWO_PI);
    }

    /**
     * Adds a weighted observation. A pair identical to one already stored has its
     * weight added to the stored weight; otherwise the pair is inserted in sorted
     * position. Either way the cached covariance is invalidated.
     *
     * @param data the observed data value
     * @param given the conditioning value
     * @param weight the weight of the observation
     */
    @Override
    public void addValue(double data, double given, double weight) {
        int insertIndex = this.findNearestPair(given, data);
        boolean isNewPair = insertIndex >= this.m_Values.size()
            || this.m_CondValues.elementAt(insertIndex).doubleValue() != given
            || this.m_Values.elementAt(insertIndex).doubleValue() != data;
        if (isNewPair) {
            this.m_CondValues.insertElementAt(Double.valueOf(given), insertIndex);
            this.m_Values.insertElementAt(Double.valueOf(data), insertIndex);
            this.m_Weights.insertElementAt(Double.valueOf(weight), insertIndex);
            if (weight != 1.0) {
                this.m_AllWeightsOne = false;
            }
        } else {
            double newWeight = this.m_Weights.elementAt(insertIndex).doubleValue() + weight;
            this.m_Weights.setElementAt(Double.valueOf(newWeight), insertIndex);
            this.m_AllWeightsOne = false;
        }
        this.m_SumOfWeights += weight;
        this.m_Covariance = null;
    }

    /**
     * Returns an estimator for the data value conditioned on {@code given}. The
     * covariance is recomputed first if it is stale. The estimator is built from
     * the covariance matrix, {@code given} minus the conditioning mean, and the
     * data-value mean.
     *
     * @param given the conditioning value
     * @return a new {@link MahalanobisEstimator}
     */
    @Override
    public Estimator getEstimator(double given) {
        if (this.m_Covariance == null) {
            this.calculateCovariance();
        }
        return new MahalanobisEstimator(this.m_Covariance, given - this.m_CondMean, this.m_ValueMean);
    }

    /**
     * Returns the probability of {@code data} as reported by the estimator from
     * {@link #getEstimator(double)} for {@code given}.
     *
     * @param data the data value
     * @param given the conditioning value
     * @return the estimator's probability for {@code data}
     */
    @Override
    public double getProbability(double data, double given) {
        return this.getEstimator(given).getProbability(data);
    }

    /**
     * Describes the estimator: number of distinct stored pairs, both means and
     * the covariance matrix (recomputed first if stale).
     *
     * @return a human-readable summary
     */
    @Override
    public String toString() {
        if (this.m_Covariance == null) {
            this.calculateCovariance();
        }
        String result = "NN Conditional Estimator. " + this.m_CondValues.size() + " data points.  Mean = " + Utils.doubleToString(this.m_ValueMean, 4, 2) + "  Conditional mean = " + Utils.doubleToString(this.m_CondMean, 4, 2);
        result = result + "  Covariance Matrix: \n" + this.m_Covariance;
        return result;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8034 $");
    }

    /**
     * Small demo. Optional arguments:
     * <ul>
     *   <li>argv[0]: random seed (default 42)</li>
     *   <li>argv[1]: conditioning value to query (default random in [0, 99])</li>
     *   <li>argv[2]: number of random points (default 50)</li>
     * </ul>
     * The demo prints the generated points, then the probabilities at 0, 5, ..., 100.
     *
     * @param argv the command-line arguments
     */
    public static void main(String[] argv) {
        try {
            int seed = 42;
            if (argv.length > 0) {
                seed = Integer.parseInt(argv[0]);
            }
            NNConditionalEstimator newEst = new NNConditionalEstimator();
            Random r = new Random(seed);
            int numPoints = 50;
            if (argv.length > 2) {
                numPoints = Integer.parseInt(argv[2]);
            }
            for (int i = 0; i < numPoints; ++i) {
                int x = Math.abs(r.nextInt() % 100);
                int y = Math.abs(r.nextInt() % 100);
                System.out.println("# " + x + "  " + y);
                newEst.addValue(x, y, 1.0);
            }
            int cond = argv.length > 1 ? Integer.parseInt(argv[1]) : Math.abs(r.nextInt() % 100);
            System.out.println("## Conditional = " + cond);
            Estimator result = newEst.getEstimator(cond);
            for (int i = 0; i <= 100; i += 5) {
                System.out.println(" " + i + "  " + result.getProbability(i));
            }
        }
        catch (Exception e) {
            System.out.println(e.getMessage());
        }
    }
}

