/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.util.Math;
import java.util.Arrays;
import java.util.Random;

/**
 * Static utility methods for divergences between distributions, random
 * permutations and sampling, chi-squared independence statistics, simple
 * regressions, and descriptive statistics over {@code double} arrays.
 *
 * <p>Unqualified {@code Math} in this class refers to the project class
 * {@code com.aliasi.util.Math} (imported at the top of the file); the JDK class
 * is always written out as {@code java.lang.Math}.
 *
 * <p>All methods are static and the class keeps no state of its own.
 */
public class Statistics {

    // Utility class: not instantiable.
    private Statistics() {
    }

    /**
     * Computes the divergence of the Dirichlet distribution parameterized by
     * {@code ys} from the one parameterized by {@code xs}.
     *
     * <p>The value is
     * {@code logGamma(sum xs) - logGamma(sum ys)
     *   + sum_i [ logGamma(ys[i]) - logGamma(xs[i])
     *             + (xs[i] - ys[i]) * (digamma(xs[i]) - digamma(sum xs)) ]}.
     *
     * @param xs parameters of the first Dirichlet
     * @param ys parameters of the second Dirichlet
     * @return the divergence value computed by the formula above
     * @throws IllegalArgumentException if the arrays differ in length or any
     *     parameter is not positive and finite
     */
    public static double klDivergenceDirichlet(double[] xs, double[] ys) {
        verifyDivergenceDirichletArgs(xs, ys);
        double sumXs = sum(xs);
        double sumYs = sum(ys);
        double divergence = logGamma(sumXs) - logGamma(sumYs);
        // Loop invariant: digamma of the total concentration of xs.
        double digammaSumXs = Math.digamma(sumXs);
        for (int i = 0; i < xs.length; ++i) {
            divergence += logGamma(ys[i]) - logGamma(xs[i])
                    + (xs[i] - ys[i]) * (Math.digamma(xs[i]) - digammaSumXs);
        }
        return divergence;
    }

    /**
     * Validates a pair of Dirichlet parameter arrays: equal length, and every
     * entry positive and finite. All of {@code xs} is checked before any of
     * {@code ys}.
     *
     * @param xs first parameter array
     * @param ys second parameter array
     * @throws IllegalArgumentException if a check fails
     */
    static void verifyDivergenceDirichletArgs(double[] xs, double[] ys) {
        if (xs.length != ys.length) {
            String msg = "Parameter arrays must be the same length. Found xs.length=" + xs.length
                    + " ys.length=" + ys.length;
            throw new IllegalArgumentException(msg);
        }
        verifyPositiveFinite("xs", xs);
        verifyPositiveFinite("ys", ys);
    }

    /**
     * Throws if any entry of {@code values} is non-positive, infinite, or NaN;
     * {@code name} is the array name used in the error message.
     */
    private static void verifyPositiveFinite(String name, double[] values) {
        for (int i = 0; i < values.length; ++i) {
            double v = values[i];
            if (v <= 0.0 || Double.isInfinite(v) || Double.isNaN(v)) {
                String msg = "All parameters must be positive and finite. Found " + name + "[" + i + "]=" + v;
                throw new IllegalArgumentException(msg);
            }
        }
    }

    /**
     * Returns the average of {@link #klDivergenceDirichlet(double[], double[])}
     * taken in both argument orders.
     *
     * @param xs parameters of the first Dirichlet
     * @param ys parameters of the second Dirichlet
     * @return {@code (KL(xs, ys) + KL(ys, xs)) / 2}
     * @throws IllegalArgumentException if the arguments fail the checks of
     *     {@link #klDivergenceDirichlet(double[], double[])}
     */
    public static double symmetrizedKlDivergenceDirichlet(double[] xs, double[] ys) {
        return (klDivergenceDirichlet(xs, ys) + klDivergenceDirichlet(ys, xs)) / 2.0;
    }

    /**
     * Returns {@code Math.log2Gamma(x) / Math.log2(e)}.
     *
     * <p>NOTE(review): assuming the project helpers compute the base-2 log of
     * the gamma function and the base-2 log as their names suggest, this is the
     * natural-log gamma of {@code x} -- confirm against com.aliasi.util.Math.
     */
    static double logGamma(double x) {
        return Math.log2Gamma(x) / Math.log2(java.lang.Math.E);
    }

    /**
     * Returns the left-to-right sum of the array's elements (0.0 for an empty
     * array).
     */
    static double sum(double[] xs) {
        double total = 0.0;
        for (double x : xs) {
            total += x;
        }
        return total;
    }

    /**
     * Computes the divergence {@code sum_i p[i] * log2(p[i] / q[i])} over the
     * indices where {@code p[i] > 0} and {@code p[i] != q[i]} (terms with
     * {@code p[i] == q[i]} would contribute zero and are skipped).
     *
     * <p>The logarithm is {@code com.aliasi.util.Math.log2} (presumably base 2,
     * as named). The inputs are only range-checked per element; they are not
     * required to sum to one.
     *
     * @param p first array of values in [0, 1]
     * @param q second array of values in [0, 1]
     * @return the accumulated divergence
     * @throws IllegalArgumentException if the arrays differ in length or any
     *     element is outside [0, 1] or not a finite number
     */
    public static double klDivergence(double[] p, double[] q) {
        verifyDivergenceArgs(p, q);
        double divergence = 0.0;
        for (int i = 0; i < p.length; ++i) {
            if (p[i] > 0.0 && p[i] != q[i]) {
                divergence += p[i] * Math.log2(p[i] / q[i]);
            }
        }
        return divergence;
    }

    /**
     * Validates two arrays used as discrete distributions: equal length, and
     * each element in [0, 1], finite and not NaN. Elements are checked in index
     * order, {@code p[i]} before {@code q[i]}.
     *
     * @param p first array
     * @param q second array
     * @throws IllegalArgumentException if a check fails
     */
    static void verifyDivergenceArgs(double[] p, double[] q) {
        if (p.length != q.length) {
            String msg = "Input distributions must have same length. Found p.length=" + p.length
                    + " q.length=" + q.length;
            throw new IllegalArgumentException(msg);
        }
        for (int i = 0; i < p.length; ++i) {
            if (p[i] < 0.0 || p[i] > 1.0 || Double.isNaN(p[i]) || Double.isInfinite(p[i])) {
                String msg = "p[i] must be between 0.0 and 1.0 inclusive. found p[" + i + "]=" + p[i];
                throw new IllegalArgumentException(msg);
            }
            if (q[i] < 0.0 || q[i] > 1.0 || Double.isNaN(q[i]) || Double.isInfinite(q[i])) {
                String msg = "q[i] must be between 0.0 and 1.0 inclusive. found q[" + i + "] =" + q[i];
                throw new IllegalArgumentException(msg);
            }
        }
    }

    /**
     * Returns the average of {@link #klDivergence(double[], double[])} taken in
     * both argument orders.
     *
     * @param p first array of values in [0, 1]
     * @param q second array of values in [0, 1]
     * @return {@code (KL(p, q) + KL(q, p)) / 2}
     * @throws IllegalArgumentException if the arguments fail
     *     {@link #verifyDivergenceArgs(double[], double[])}
     */
    public static double symmetrizedKlDivergence(double[] p, double[] q) {
        verifyDivergenceArgs(p, q);
        return (klDivergence(p, q) + klDivergence(q, p)) / 2.0;
    }

    /**
     * Returns the average of the {@link #klDivergence(double[], double[])} of
     * {@code p} and of {@code q} from their element-wise midpoint
     * {@code m[i] = (p[i] + q[i]) / 2}.
     *
     * @param p first array of values in [0, 1]
     * @param q second array of values in [0, 1]
     * @return {@code (KL(p, m) + KL(q, m)) / 2}
     * @throws IllegalArgumentException if the arguments fail
     *     {@link #verifyDivergenceArgs(double[], double[])}
     */
    public static double jsDivergence(double[] p, double[] q) {
        // Validate first so a length mismatch is reported before indexing q.
        verifyDivergenceArgs(p, q);
        double[] m = new double[p.length];
        for (int i = 0; i < p.length; ++i) {
            m[i] = (p[i] + q[i]) / 2.0;
        }
        return (klDivergence(p, m) + klDivergence(q, m)) / 2.0;
    }

    /**
     * Returns a random permutation of {@code 0 .. length-1} using a freshly
     * constructed {@link Random}.
     *
     * @param length number of elements to permute
     * @return a new array holding the permuted indices
     */
    public static int[] permutation(int length) {
        return permutation(length, new Random());
    }

    /**
     * Returns a random permutation of {@code 0 .. length-1} drawn with the
     * supplied randomizer, using a Fisher-Yates shuffle.
     *
     * <p>The swap partner for position {@code i} is drawn from {@code [0, i]}
     * inclusive. Drawing from {@code [0, i)} instead (Sattolo's variant) can
     * only produce single-cycle permutations: no element ever stays in place,
     * and a length-2 request would always yield {@code [1, 0]}.
     *
     * @param length number of elements to permute
     * @param random source of randomness
     * @return a new array holding the permuted indices
     */
    public static int[] permutation(int length, Random random) {
        int[] xs = new int[length];
        for (int i = 0; i < xs.length; ++i) {
            xs[i] = i;
        }
        for (int i = xs.length - 1; i > 0; --i) {
            // Bound is exclusive, so i + 1 lets the element stay where it is.
            int pos = random.nextInt(i + 1);
            int temp = xs[pos];
            xs[pos] = xs[i];
            xs[i] = temp;
        }
        return xs;
    }

    /**
     * Computes the chi-squared independence statistic for a two-by-two
     * contingency table given as its four cell counts.
     *
     * <p>Expected counts are derived from the marginal proportions
     * {@code p1 = (both + oneOnly) / n} and {@code p2 = (both + twoOnly) / n}
     * where {@code n} is the total; the result is the sum over the four cells
     * of {@code (found - expected)^2 / expected}. No guard is applied for a
     * zero total or a zero expected count, so the result may be NaN or
     * infinite in those cases.
     *
     * @param both count for the cell where both outcomes hold
     * @param oneOnly count for the cell where only the first holds
     * @param twoOnly count for the cell where only the second holds
     * @param neither count for the cell where neither holds
     * @return the chi-squared statistic
     * @throws IllegalArgumentException if any count is negative, infinite or NaN
     */
    public static double chiSquaredIndependence(double both, double oneOnly, double twoOnly, double neither) {
        assertNonNegative("both", both);
        assertNonNegative("oneOnly", oneOnly);
        assertNonNegative("twoOnly", twoOnly);
        assertNonNegative("neither", neither);
        double n = both + oneOnly + twoOnly + neither;
        double p1 = (both + oneOnly) / n;
        double p2 = (both + twoOnly) / n;
        double eBoth = n * p1 * p2;
        double eOneOnly = n * p1 * (1.0 - p2);
        double eTwoOnly = n * (1.0 - p1) * p2;
        double eNeither = n * (1.0 - p1) * (1.0 - p2);
        return csTerm(both, eBoth)
                + csTerm(oneOnly, eOneOnly)
                + csTerm(twoOnly, eTwoOnly)
                + csTerm(neither, eNeither);
    }

    /**
     * Fits {@code y = beta0 + beta1 * x} to the paired samples by ordinary
     * least squares.
     *
     * @param xs x values
     * @param ys y values, parallel to {@code xs}
     * @return a two-element array {@code {beta0, beta1}} (intercept, slope)
     * @throws IllegalArgumentException if the arrays differ in length, have
     *     fewer than two elements, or the slope denominator
     *     {@code n * sum(x^2) - (sum x)^2} is exactly zero
     */
    public static double[] linearRegression(double[] xs, double[] ys) {
        if (xs.length != ys.length) {
            String msg = "Require parallel arrays of x and y values. Found xs.length=" + xs.length
                    + " ys.length=" + ys.length;
            throw new IllegalArgumentException(msg);
        }
        if (xs.length < 2) {
            String msg = "Require arrays of length >= 2. Found xs.length=" + xs.length;
            throw new IllegalArgumentException(msg);
        }
        double n = xs.length;
        double xSum = 0.0;
        double ySum = 0.0;
        double xySum = 0.0;
        double xxSum = 0.0;
        for (int i = 0; i < xs.length; ++i) {
            double x = xs[i];
            double y = ys[i];
            xSum += x;
            ySum += y;
            xxSum += x * x;
            xySum += x * y;
        }
        double denominator = n * xxSum - xSum * xSum;
        if (denominator == 0.0) {
            String msg = "Ill formed input. Denominator for beta1 is zero. Most likely cause is fewer than 2 distinct inputs.";
            throw new IllegalArgumentException(msg);
        }
        double beta1 = (n * xySum - xSum * ySum) / denominator;
        double beta0 = (ySum - beta1 * xSum) / n;
        return new double[]{beta0, beta1};
    }

    /**
     * Transforms each {@code y} to {@code ln((maxValue - y) / y)} and returns
     * the {@link #linearRegression(double[], double[])} of the transformed
     * values on {@code xs}.
     *
     * <p>The {@code ys} themselves are not validated; values outside
     * {@code (0, maxValue)} produce NaN or infinite transformed values that are
     * passed on to the linear fit.
     *
     * @param xs x values
     * @param ys y values, parallel to {@code xs}
     * @param maxValue upper bound used in the transform
     * @return the {@code {beta0, beta1}} coefficients of the linear fit to the
     *     transformed values
     * @throws IllegalArgumentException if {@code maxValue} is not positive and
     *     finite, or the arrays fail the checks of
     *     {@link #linearRegression(double[], double[])}
     */
    public static double[] logisticRegression(double[] xs, double[] ys, double maxValue) {
        if (maxValue <= 0.0 || Double.isInfinite(maxValue) || Double.isNaN(maxValue)) {
            String msg = "Require finite max value > 0. Found maxValue=" + maxValue;
            throw new IllegalArgumentException(msg);
        }
        double[] logisticYs = new double[ys.length];
        for (int i = 0; i < ys.length; ++i) {
            logisticYs[i] = java.lang.Math.log((maxValue - ys[i]) / ys[i]);
        }
        return linearRegression(xs, logisticYs);
    }

    /**
     * Computes the chi-squared independence statistic for a rectangular
     * contingency matrix: the sum over all cells of
     * {@code (found - expected)^2 / expected}, where the expected count of a
     * cell is {@code rowSum * colSum / total}.
     *
     * <p>No guard is applied for zero expected counts, so the result may be NaN
     * or infinite when a row or column sums to zero.
     *
     * @param contingencyMatrix matrix of counts, at least 2 x 2
     * @return the chi-squared statistic
     * @throws IllegalArgumentException if there are fewer than two rows or
     *     columns, the rows differ in length, or any value is negative,
     *     infinite or NaN
     */
    public static double chiSquaredIndependence(double[][] contingencyMatrix) {
        int numRows = contingencyMatrix.length;
        if (numRows < 2) {
            String msg = "Require at least two rows. Found numRows=" + numRows;
            throw new IllegalArgumentException(msg);
        }
        int numCols = contingencyMatrix[0].length;
        if (numCols < 2) {
            String msg = "Require at least two cols. Found numCols=" + numCols;
            throw new IllegalArgumentException(msg);
        }
        // New double arrays are zero-initialized, so no explicit fill is needed.
        double[] rowSums = new double[numRows];
        double[] colSums = new double[numCols];
        double totalCount = 0.0;
        for (int i = 0; i < numRows; ++i) {
            if (contingencyMatrix[i].length != numCols) {
                String msg = "Matrix must be rectangular. Row 0 length=" + numCols
                        + " Row " + i + " length=" + contingencyMatrix[i].length;
                throw new IllegalArgumentException(msg);
            }
            for (int j = 0; j < numCols; ++j) {
                double val = contingencyMatrix[i][j];
                if (Double.isInfinite(val) || val < 0.0 || Double.isNaN(val)) {
                    String msg = "Values must be finite non-negative. Found matrix[" + i + "][" + j + "]=" + val;
                    throw new IllegalArgumentException(msg);
                }
                rowSums[i] += val;
                colSums[j] += val;
                totalCount += val;
            }
        }
        double result = 0.0;
        for (int i = 0; i < numRows; ++i) {
            for (int j = 0; j < numCols; ++j) {
                result += csTerm(contingencyMatrix[i][j], rowSums[i] * colSums[j] / totalCount);
            }
        }
        return result;
    }

    /**
     * Returns a new array holding each input ratio divided by the sum of all
     * ratios; the input array is not modified.
     *
     * @param probabilityRatios finite, non-negative ratios
     * @return the ratios rescaled by their sum
     * @throws IllegalArgumentException if any ratio is negative, infinite or
     *     NaN, or if the ratios do not sum to a value greater than zero
     */
    public static double[] normalize(double[] probabilityRatios) {
        for (int i = 0; i < probabilityRatios.length; ++i) {
            double ratio = probabilityRatios[i];
            if (ratio < 0.0 || Double.isInfinite(ratio) || Double.isNaN(ratio)) {
                String msg = "Probabilities must be finite non-negative. Found probabilityRatios[" + i + "]=" + ratio;
                throw new IllegalArgumentException(msg);
            }
        }
        double sum = Math.sum(probabilityRatios);
        if (sum <= 0.0) {
            String msg = "Ratios must sum to number greater than zero. Found sum=" + sum;
            throw new IllegalArgumentException(msg);
        }
        double[] result = new double[probabilityRatios.length];
        for (int i = 0; i < probabilityRatios.length; ++i) {
            result[i] = probabilityRatios[i] / sum;
        }
        return result;
    }

    /**
     * Returns {@code (observedProb - expectedProb) / (1 - expectedProb)}.
     *
     * <p>The arguments are not validated; {@code expectedProb == 1.0} divides
     * by zero and yields NaN or an infinite value.
     *
     * @param observedProb observed probability
     * @param expectedProb expected probability
     * @return the kappa value given by the formula above
     */
    public static double kappa(double observedProb, double expectedProb) {
        return (observedProb - expectedProb) / (1.0 - expectedProb);
    }

    /**
     * Returns the sum of the values (via {@code com.aliasi.util.Math.sum})
     * divided by the number of values. An empty array divides by zero.
     *
     * @param xs values to average
     * @return the mean of {@code xs}
     */
    public static double mean(double[] xs) {
        return Math.sum(xs) / (double) xs.length;
    }

    /**
     * Returns the average squared deviation of the values from their mean,
     * dividing by the number of values (not by one less).
     *
     * @param xs values
     * @return the variance of {@code xs}
     */
    public static double variance(double[] xs) {
        return variance(xs, mean(xs));
    }

    /**
     * Returns the square root of {@link #variance(double[])}.
     *
     * @param xs values
     * @return the standard deviation of {@code xs}
     */
    public static double standardDeviation(double[] xs) {
        return java.lang.Math.sqrt(variance(xs));
    }

    /**
     * Returns {@code sqrt(ssXY^2 / (ssXX * ssYY))}, where {@code ssXX} and
     * {@code ssYY} are the sums of squared deviations from the means and
     * {@code ssXY} is the sum of the products of paired deviations.
     *
     * <p>Because the square root of a squared quantity is taken, the result is
     * never negative: it is the magnitude of the correlation, with the sign of
     * {@code ssXY} discarded. If either array has zero spread the ratio is
     * {@code 0/0} and the result is NaN.
     *
     * @param xs first array of values
     * @param ys second array of values, parallel to {@code xs}
     * @return the non-negative correlation magnitude
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double correlation(double[] xs, double[] ys) {
        if (xs.length != ys.length) {
            String msg = "xs and ys must be the same length. Found xs.length=" + xs.length
                    + " ys.length=" + ys.length;
            throw new IllegalArgumentException(msg);
        }
        double meanXs = mean(xs);
        double meanYs = mean(ys);
        double ssXX = sumSquareDiffs(xs, meanXs);
        double ssYY = sumSquareDiffs(ys, meanYs);
        double ssXY = sumSquareDiffs(xs, ys, meanXs, meanYs);
        return java.lang.Math.sqrt(ssXY * ssXY / (ssXX * ssYY));
    }

    /**
     * Draws {@code x = random.nextDouble() * lastElement} and returns the
     * smallest index {@code i} such that {@code x <= cumulativeProbRatios[i]},
     * found by binary search.
     *
     * <p>The search is only meaningful if the array is non-decreasing; this is
     * not checked.
     *
     * @param cumulativeProbRatios running totals of the outcome ratios
     * @param random source of randomness
     * @return the sampled index
     * @throws IllegalArgumentException if the array is empty
     */
    public static int sample(double[] cumulativeProbRatios, Random random) {
        if (cumulativeProbRatios.length == 0) {
            String msg = "Require at least one cumulative ratio. Found cumulativeProbRatios.length=0";
            throw new IllegalArgumentException(msg);
        }
        int low = 0;
        int high = cumulativeProbRatios.length - 1;
        double x = random.nextDouble() * cumulativeProbRatios[high];
        while (low < high) {
            // Unsigned shift keeps the midpoint correct even if low + high overflows.
            // Since low < high, mid is always strictly below high, so the range shrinks.
            int mid = (low + high) >>> 1;
            if (x > cumulativeProbRatios[mid]) {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        return low;
    }

    /**
     * Computes
     * {@code log2Gamma(k * alpha) - k * log2Gamma(alpha)
     *   + (alpha - 1) * sum_i log2(xs[i])}
     * where {@code k = xs.length}; i.e. the array form of this method with
     * every concentration equal to {@code alpha}.
     *
     * <p>Elements of {@code xs} are only range-checked to [0, 1]; they are not
     * required to sum to one.
     *
     * @param alpha concentration parameter shared by all dimensions
     * @param xs point at which to evaluate
     * @return the value of the formula above
     * @throws IllegalArgumentException if {@code alpha} is not positive and
     *     finite or any element of {@code xs} is outside [0, 1]
     */
    public static double dirichletLog2Prob(double alpha, double[] xs) {
        verifyAlpha(alpha);
        verifyDistro(xs);
        int k = xs.length;
        double result = Math.log2Gamma((double) k * alpha) - (double) k * Math.log2Gamma(alpha);
        double alphaMinus1 = alpha - 1.0;
        for (int i = 0; i < k; ++i) {
            result += alphaMinus1 * Math.log2(xs[i]);
        }
        return result;
    }

    /**
     * Computes
     * {@code log2Gamma(sum alphas) - sum_i log2Gamma(alphas[i])
     *   + sum_i (alphas[i] - 1) * log2(xs[i])}.
     *
     * <p>Elements of {@code xs} are only range-checked to [0, 1]; they are not
     * required to sum to one.
     *
     * @param alphas per-dimension concentration parameters
     * @param xs point at which to evaluate, parallel to {@code alphas}
     * @return the value of the formula above
     * @throws IllegalArgumentException if the arrays differ in length, any
     *     alpha is not positive and finite, or any element of {@code xs} is
     *     outside [0, 1]
     */
    public static double dirichletLog2Prob(double[] alphas, double[] xs) {
        if (alphas.length != xs.length) {
            String msg = "Dirichlet prior alphas and distribution xs must be the same length. Found alphas.length="
                    + alphas.length + " xs.length=" + xs.length;
            throw new IllegalArgumentException(msg);
        }
        for (double alpha : alphas) {
            verifyAlpha(alpha);
        }
        verifyDistro(xs);
        double result = 0.0;
        double alphaSum = 0.0;
        for (int i = 0; i < alphas.length; ++i) {
            alphaSum += alphas[i];
            result -= Math.log2Gamma(alphas[i]);
        }
        result += Math.log2Gamma(alphaSum);
        for (int i = 0; i < xs.length; ++i) {
            result += (alphas[i] - 1.0) * Math.log2(xs[i]);
        }
        return result;
    }

    /**
     * Throws unless {@code alpha} is positive, finite and not NaN.
     *
     * @throws IllegalArgumentException if the check fails
     */
    static void verifyAlpha(double alpha) {
        if (Double.isNaN(alpha) || Double.isInfinite(alpha) || alpha <= 0.0) {
            String msg = "Concentration parameter must be positive and finite. Found alpha=" + alpha;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Throws unless every element of {@code xs} lies in [0, 1] and is finite
     * and not NaN. The sum of the elements is not checked.
     *
     * @throws IllegalArgumentException if the check fails
     */
    static void verifyDistro(double[] xs) {
        for (int i = 0; i < xs.length; ++i) {
            if (xs[i] < 0.0 || xs[i] > 1.0 || Double.isNaN(xs[i]) || Double.isInfinite(xs[i])) {
                String msg = "All xs must be between 0.0 and 1.0 inclusive. Found xs[" + i + "]=" + xs[i];
                throw new IllegalArgumentException(msg);
            }
        }
    }

    /**
     * Returns the sum over the array of {@code (xs[i] - mean)^2}.
     */
    static double sumSquareDiffs(double[] xs, double mean) {
        double total = 0.0;
        for (double x : xs) {
            double diff = x - mean;
            total += diff * diff;
        }
        return total;
    }

    /**
     * Returns the sum over {@code i} of
     * {@code (xs[i] - meanXs) * (ys[i] - meanYs)}; iterates over the length of
     * {@code xs}.
     */
    static double sumSquareDiffs(double[] xs, double[] ys, double meanXs, double meanYs) {
        double total = 0.0;
        for (int i = 0; i < xs.length; ++i) {
            total += (xs[i] - meanXs) * (ys[i] - meanYs);
        }
        return total;
    }

    /**
     * Returns the sum of squared deviations from the supplied mean divided by
     * the number of values.
     */
    static double variance(double[] xs, double mean) {
        return sumSquareDiffs(xs, mean) / (double) xs.length;
    }

    /**
     * Throws unless {@code value} is finite, not NaN and at least zero;
     * {@code variableName} is used in the error message.
     *
     * @throws IllegalArgumentException if the check fails
     */
    static void assertNonNegative(String variableName, double value) {
        if (Double.isInfinite(value) || Double.isNaN(value) || value < 0.0) {
            String msg = "Require finite non-negative value. Found " + variableName + " =" + value;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Returns one chi-squared term, {@code (found - expected)^2 / expected}.
     */
    private static double csTerm(double found, double expected) {
        double diff = found - expected;
        return diff * diff / expected;
    }
}

