/*
 * Decompiled with CFR 0.152.
 */
package org.ddogleg.nn.alg;

import java.util.List;
import org.ddogleg.nn.alg.AxisSplitRule;
import org.ddogleg.nn.alg.AxisSplitRuleMax;
import org.ddogleg.nn.alg.AxisSplitter;
import org.ddogleg.sorting.QuickSelectArray;

/**
 * {@link AxisSplitter} which divides a set of points at the median value along a single axis.
 *
 * <p>The axis is chosen by an {@link AxisSplitRule}, which is handed the per-axis spread of the
 * points (sum of squared deviations from the per-axis mean, not divided by the point count). The
 * median point itself is not placed in either output list; it is exposed through
 * {@link #getSplitPoint()} and {@link #getSplitData()}.
 *
 * <p>Mutable work arrays are reused between calls, so a single instance should not be shared
 * across threads without external synchronization.
 *
 * @param <D> type of the data associated with each point
 */
public class AxisSplitterMedian<D>
implements AxisSplitter<D> {
    /** Dimension of each point, set by {@link #setDimension(int)}. */
    private int N;
    /** Work array: per-axis mean of the points being split. */
    private double[] mean;
    /** Work array: per-axis sum of squared deviations from {@link #mean}. */
    private double[] var;
    /** Work array: coordinate of every point along the split axis; grown on demand. */
    private double[] tmp = new double[1];
    /** Work array: point indices filled by the quick-select; grown together with {@link #tmp}. */
    private int[] indexes = new int[1];

    /** Rule which picks the split axis from the per-axis spread. */
    AxisSplitRule splitRule;
    /** Axis selected by the most recent call to {@link #splitData}. */
    int splitAxis;
    /** Median point selected by the most recent call to {@link #splitData}. */
    double[] splitPoint;
    /** Data associated with {@link #splitPoint}, or null if no data list was supplied. */
    D splitData;

    /**
     * Creates a splitter which uses the given rule to choose the split axis.
     *
     * @param splitRule rule used to select the split axis; must be non-null before
     *                  {@link #setDimension(int)} is called
     */
    public AxisSplitterMedian(AxisSplitRule splitRule) {
        this.splitRule = splitRule;
    }

    /** Creates a splitter which uses {@link AxisSplitRuleMax} to choose the split axis. */
    public AxisSplitterMedian() {
        this(new AxisSplitRuleMax());
    }

    /**
     * Configures the splitter (and its rule) for points of dimension {@code N}.
     *
     * @param N number of coordinates in each point
     * @throws RuntimeException if no split rule has been provided
     */
    @Override
    public void setDimension(int N) {
        // Validate before touching any state so that a failed call leaves the splitter unchanged.
        if (this.splitRule == null) {
            throw new RuntimeException("You must call setRule() before setDimension()");
        }
        this.N = N;
        this.mean = new double[N];
        this.var = new double[N];
        this.splitRule.setDimension(N);
    }

    /**
     * Splits {@code points} at the median along the axis chosen by the split rule.
     *
     * <p>Points ranked before the median are appended to {@code left}, points ranked after it to
     * {@code right}. The median point is stored as the split point and is not added to either list.
     * When {@code data} is non-null, the associated data is appended to {@code leftData} /
     * {@code rightData} in the same order, and the median's data becomes the split data.
     *
     * @param points points to split; must contain at least one point
     * @param data optional data parallel to {@code points}, or null
     * @param left receives the points before the median
     * @param leftData receives the data for {@code left}; ignored when {@code data} is null
     * @param right receives the points after the median
     * @param rightData receives the data for {@code right}; ignored when {@code data} is null
     * @throws IllegalArgumentException if {@code points} is empty
     */
    @Override
    public void splitData(List<double[]> points, List<D> data, List<double[]> left, List<D> leftData, List<double[]> right, List<D> rightData) {
        int numPoints = points.size();
        if (numPoints == 0) {
            // Without this guard the mean becomes NaN and the median lookup fails obscurely.
            throw new IllegalArgumentException("Can't split an empty list of points");
        }

        computeAxisVariance(points);
        splitAxis = splitRule.select(var);

        int medianNum = numPoints / 2;
        quickSelect(points, splitAxis, medianNum);

        int medianIndex = indexes[medianNum];
        splitPoint = points.get(medianIndex);

        // Everything ranked before the median goes left, everything after goes right.
        copyInto(points, data, 0, medianNum, left, leftData);
        copyInto(points, data, medianNum + 1, numPoints, right, rightData);

        splitData = data == null ? null : data.get(medianIndex);
    }

    /** {@inheritDoc} */
    @Override
    public double[] getSplitPoint() {
        return this.splitPoint;
    }

    /** {@inheritDoc} */
    @Override
    public D getSplitData() {
        return this.splitData;
    }

    /** {@inheritDoc} */
    @Override
    public int getSplitAxis() {
        return this.splitAxis;
    }

    /**
     * Appends the points referenced by {@code indexes[start..end)} to {@code outPoints} and, when
     * {@code data} is non-null, the matching data to {@code outData}.
     */
    private void copyInto(List<double[]> points, List<D> data, int start, int end,
                          List<double[]> outPoints, List<D> outData) {
        for (int i = start; i < end; i++) {
            int index = indexes[i];
            outPoints.add(points.get(index));
            if (data != null) {
                outData.add(data.get(index));
            }
        }
    }

    /**
     * Computes the per-axis mean of {@code points} into {@link #mean} and the per-axis sum of
     * squared deviations from that mean into {@link #var}.
     *
     * <p>The spread is intentionally not divided by the number of points. The list must be
     * non-empty, otherwise the mean is NaN.
     */
    private void computeAxisVariance(List<double[]> points) {
        int numPoints = points.size();

        for (int j = 0; j < N; j++) {
            mean[j] = 0.0;
            var[j] = 0.0;
        }

        for (double[] p : points) {
            for (int j = 0; j < N; j++) {
                mean[j] += p[j];
            }
        }
        for (int j = 0; j < N; j++) {
            mean[j] /= numPoints;
        }

        for (double[] p : points) {
            for (int j = 0; j < N; j++) {
                double d = mean[j] - p[j];
                var[j] += d * d;
            }
        }
    }

    /**
     * Copies each point's coordinate along {@code splitAxis} into {@link #tmp} and runs
     * {@code QuickSelectArray.selectIndex} on it, writing point indices into {@link #indexes}.
     *
     * <p>NOTE(review): {@link #splitData} assumes that afterwards {@code indexes[medianNum]} is the
     * median and that {@code indexes[0..medianNum)} / {@code indexes(medianNum..numPoints)} hold
     * the points on either side of it. This is inferred from how the indices are consumed; confirm
     * against {@code QuickSelectArray}.
     */
    private void quickSelect(List<double[]> points, int splitAxis, int medianNum) {
        int numPoints = points.size();
        if (tmp.length < numPoints) {
            tmp = new double[numPoints];
            indexes = new int[numPoints];
        }
        for (int i = 0; i < numPoints; i++) {
            tmp[i] = points.get(i)[splitAxis];
        }
        QuickSelectArray.selectIndex(tmp, medianNum, numPoints, indexes);
    }
}

