/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.stats.OnlineNormalEstimator;

/**
 * Computes the potential scale reduction statistic (Gelman and Rubin's
 * "R-hat") from samples produced by several parallel chains.
 *
 * <p>One {@link OnlineNormalEstimator} is kept per chain, plus a global
 * estimator that is handed every sample from every chain. Samples are
 * added incrementally through {@link #update(int, double)}. {@link #rHat()}
 * may be called at any time to compute the statistic from the per-chain
 * estimators as they currently stand.
 *
 * <p>This class performs no synchronization.
 */
public class PotentialScaleReduction {

    /** Handed every sample passed to {@link #update(int, double)}, whatever its chain. */
    private final OnlineNormalEstimator mGlobalEstimator;

    /** Per-chain estimators; index {@code i} holds the estimator for chain {@code i}. */
    private final OnlineNormalEstimator[] mChainEstimators;

    /**
     * Constructs an instance tracking {@code numChains} chains. Each chain
     * and the global estimator get a newly constructed estimator.
     *
     * @param numChains number of chains; must be at least 2
     * @throws IllegalStateException if {@code numChains < 2}
     */
    public PotentialScaleReduction(int numChains) {
        // NOTE(review): IllegalArgumentException would be the conventional
        // type here; IllegalStateException is kept because callers may catch it.
        if (numChains < 2) {
            throw new IllegalStateException(
                "Need at least two chains. Found numChains=" + numChains);
        }
        this.mChainEstimators = new OnlineNormalEstimator[numChains];
        for (int chain = 0; chain < numChains; ++chain) {
            this.mChainEstimators[chain] = new OnlineNormalEstimator();
        }
        this.mGlobalEstimator = new OnlineNormalEstimator();
    }

    /**
     * Constructs an instance with one chain per row of {@code yss}. Each
     * value in a row is fed, in order, to that row's chain through
     * {@link #update(int, double)}.
     *
     * @param yss sample matrix; {@code yss[m]} holds the samples for chain {@code m}
     * @throws IllegalStateException if {@code yss.length < 2}
     */
    public PotentialScaleReduction(double[][] yss) {
        this(yss.length);
        // NOTE(review): update(...) is overridable and is invoked from a constructor.
        for (int chain = 0; chain < yss.length; ++chain) {
            for (double sample : yss[chain]) {
                this.update(chain, sample);
            }
        }
    }

    /**
     * Returns the number of chains this instance was constructed with.
     *
     * @return the number of chains
     */
    public int numChains() {
        return this.mChainEstimators.length;
    }

    /**
     * Returns the live estimator for the given chain (not a copy).
     *
     * @param chain zero-based chain index
     * @return the estimator for {@code chain}
     * @throws ArrayIndexOutOfBoundsException if {@code chain} is out of range
     */
    public OnlineNormalEstimator estimator(int chain) {
        return this.mChainEstimators[chain];
    }

    /**
     * Returns the live estimator that is handed every sample from every chain.
     *
     * @return the global estimator
     */
    public OnlineNormalEstimator globalEstimator() {
        return this.mGlobalEstimator;
    }

    /**
     * Passes {@code y} to the estimator for {@code chain}, then to the
     * global estimator.
     *
     * @param chain zero-based chain index
     * @param y sample value
     * @throws ArrayIndexOutOfBoundsException if {@code chain} is out of range,
     *     in which case the global estimator is not touched
     */
    public void update(int chain, double y) {
        OnlineNormalEstimator chainEstimator = this.mChainEstimators[chain];
        chainEstimator.handle(y);
        this.mGlobalEstimator.handle(y);
    }

    /**
     * Returns the potential scale reduction statistic for the current
     * state of the per-chain estimators.
     *
     * <p>Let {@code m = numChains()} and let {@code n} be the smallest
     * per-chain sample count. Let {@code W} be the average of the per-chain
     * unbiased variances. Let {@code B} be the sum of squared deviations of
     * the per-chain means from their average, divided by {@code m - 1}.
     * The result is
     * <pre>
     *   sqrt( ((n - 1) * W / n + B) / W )
     * </pre>
     * The global estimator is not consulted.
     *
     * <p>NOTE(review): if {@code W} is zero the result is NaN or infinite.
     * The value is presumably only meaningful once every chain has at least
     * two samples.
     *
     * @return the potential scale reduction statistic
     */
    public double rHat() {
        // One pass gathers the three independent per-chain aggregates.
        long minSamples = Long.MAX_VALUE;
        double varianceTotal = 0.0;
        double meanTotal = 0.0;
        for (OnlineNormalEstimator chainEstimator : this.mChainEstimators) {
            minSamples = Math.min(minSamples, chainEstimator.numSamples());
            varianceTotal += chainEstimator.varianceUnbiased();
            meanTotal += chainEstimator.mean();
        }
        double withinChain = varianceTotal / (double) this.numChains();
        double grandMean = meanTotal / (double) this.numChains();

        // The deviations need the grand mean, so they take a second pass.
        double betweenChain = 0.0;
        for (OnlineNormalEstimator chainEstimator : this.mChainEstimators) {
            betweenChain += square(chainEstimator.mean() - grandMean);
        }
        betweenChain /= (double) this.numChains() - 1.0;

        double pooledVariance =
            (double) (minSamples - 1L) * withinChain / (double) minSamples + betweenChain;
        return Math.sqrt(pooledVariance / withinChain);
    }

    /** Returns {@code x} multiplied by itself. */
    static double square(double x) {
        double product = x * x;
        return product;
    }
}

