/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.knn.approximate;

import cern.jet.random.Uniform;
import cern.jet.random.engine.MersenneTwister;
import cern.jet.random.engine.RandomEngine;
import jal.objects.BinaryPredicate;
import jal.objects.Sorting;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import org.openimaj.knn.DoubleNearestNeighbours;
import org.openimaj.util.array.IntArrayView;
import org.openimaj.util.pair.DoubleIntPair;
import org.openimaj.util.pair.DoubleObjectPair;
import org.openimaj.util.pair.IntDoublePair;

/**
 * An ensemble of randomised KD-trees built over a fixed set of {@code double[]} points,
 * used for approximate nearest-neighbour search.
 * <p>
 * Each tree splits on a dimension drawn at random from the few highest-variance dimensions
 * (estimated on a sample of points). A search descends every tree, then keeps expanding the
 * most promising untaken branches (shared across all trees) until a requested number of
 * distinct points has been checked.
 * <p>
 * NOTE(review): every point is assumed to have {@code pnts[0].length} dimensions — the split
 * code reads the dimensionality from the first point only.
 */
public class DoubleKDTreeEnsemble {
    /** Nodes holding more than this many point indices are split into two children. */
    private static final int leaf_max_points = 14;
    /** At most this many points of a node are sampled to estimate per-dimension variance. */
    private static final int varest_max_points = 128;
    /** The split dimension is drawn uniformly from (at most) this many highest-variance dimensions. */
    private static final int varest_max_randsz = 5;

    /** Random source shared by all trees for choosing split dimensions. */
    Uniform rng;
    /** Root node of every tree in the ensemble. */
    public final DoubleKDTreeNode[] trees;
    /** The indexed data points (held by reference, not copied). */
    public final double[][] pnts;

    /**
     * Build an ensemble of 8 trees over {@code pnts} using random seed 42.
     *
     * @param pnts the data points to index
     */
    public DoubleKDTreeEnsemble(double[][] pnts) {
        this(pnts, 8, 42);
    }

    /**
     * Build an ensemble of {@code ntrees} trees over {@code pnts} using random seed 42.
     *
     * @param pnts the data points to index
     * @param ntrees number of trees to build
     */
    public DoubleKDTreeEnsemble(double[][] pnts, int ntrees) {
        this(pnts, ntrees, 42);
    }

    /**
     * Build an ensemble of {@code ntrees} randomised trees over {@code pnts}.
     *
     * @param pnts the data points to index
     * @param ntrees number of trees to build
     * @param seed seed for the Mersenne-Twister generator that drives split selection
     */
    public DoubleKDTreeEnsemble(double[][] pnts, int ntrees, int seed) {
        final int N = pnts.length;
        this.pnts = pnts;
        this.rng = new Uniform(new MersenneTwister(seed));

        // Identity index list; each tree build permutes it in place, and the next tree simply
        // starts from whatever order the previous build left (leaves keep their own copies).
        final IntArrayView inds = new IntArrayView(N);
        for (int n = 0; n < N; ++n) {
            inds.setFast(n, n);
        }

        this.trees = new DoubleKDTreeNode[ntrees];
        for (int t = 0; t < ntrees; ++t) {
            this.trees[t] = new DoubleKDTreeNode(pnts, inds, this.rng);
        }
    }

    /**
     * Approximate k-nearest-neighbour query.
     * <p>
     * At least {@code numnn} and at most {@code N} distinct points are examined (the requested
     * {@code nchecks} is clamped into that range); the closest of those examined are written to
     * the front of {@code ret_nns} in ascending order of distance. Entries of {@code ret_nns}
     * beyond the number written are left untouched.
     *
     * @param qu the query vector
     * @param numnn number of neighbours wanted
     * @param ret_nns output array of (point index, distance) pairs; must have room for
     *        {@code min(numnn, N)} entries
     * @param nchecks requested number of distinct points to examine
     */
    void search(double[] qu, int numnn, IntDoublePair[] ret_nns, int nchecks) {
        final int N = this.pnts.length;
        // Examine at least numnn points, but never more than exist.
        nchecks = Math.min(Math.max(nchecks, numnn), N);

        // Untaken branches, ordered so the one with the smallest distance estimate comes first.
        final PriorityQueue<DoubleObjectPair<DoubleKDTreeNode>> pri_branch =
                new PriorityQueue<DoubleObjectPair<DoubleKDTreeNode>>(11,
                        new Comparator<DoubleObjectPair<DoubleKDTreeNode>>() {
                            @Override
                            public int compare(DoubleObjectPair<DoubleKDTreeNode> o1,
                                    DoubleObjectPair<DoubleKDTreeNode> o2)
                            {
                                return Double.compare(o1.first, o2.first);
                            }
                        });

        final List<IntDoublePair> nns = new ArrayList<IntDoublePair>(3 * nchecks / 2);
        // Trees overlap completely, so a point may be reached from several of them; check each once.
        final boolean[] seen = new boolean[N];

        // Greedy descent of every tree to the leaf containing the query.
        for (final DoubleKDTreeNode tree : this.trees) {
            tree.search(qu, pri_branch, nns, seen, this.pnts, 0.0);
        }

        // Keep exploring the most promising untaken branches until enough points have been checked.
        while (nns.size() < nchecks) {
            final DoubleObjectPair<DoubleKDTreeNode> pr = pri_branch.poll();
            if (pr == null) {
                // Nothing left to explore (e.g. an ensemble with no trees).
                break;
            }
            pr.second.search(qu, pri_branch, nns, seen, this.pnts, pr.first);
        }

        final IntDoublePair[] nns_arr = nns.toArray(new IntDoublePair[nns.size()]);
        // Never ask partial_sort (or arraycopy) for more elements than were actually gathered.
        final int nret = Math.min(numnn, nns_arr.length);
        Sorting.partial_sort(nns_arr, 0, nret, nns_arr.length, new BinaryPredicate() {
            public boolean apply(Object lhs, Object rhs) {
                return ((IntDoublePair) lhs).second < ((IntDoublePair) rhs).second;
            }
        });
        System.arraycopy(nns_arr, 0, ret_nns, 0, nret);
    }

    /**
     * A node of a randomised KD-tree. Internal nodes have a non-null {@link #left} child and
     * {@link InternalNodeData}; leaves have {@code left == null} and {@link LeafNodeData}.
     */
    public static class DoubleKDTreeNode {
        /** Child holding points whose split coordinate is below the discriminator (null for a leaf). */
        DoubleKDTreeNode left;
        /** Either {@link InternalNodeData} or {@link LeafNodeData}, depending on the node kind. */
        NodeData node_data;
        /** Random source used when choosing the split dimension. */
        private Uniform rng;

        /** @return true if this node has no children */
        boolean is_leaf() {
            return this.left == null;
        }

        /**
         * Choose a split for the points referenced by {@code inds}.
         * <p>
         * Per-dimension mean and variance are estimated from at most {@code varest_max_points}
         * of the points (the first ones in {@code inds}); the dimension is then picked uniformly
         * at random from the {@code min(varest_max_randsz, D)} highest-variance dimensions.
         *
         * @param pnts all data points
         * @param inds indices (into {@code pnts}) of the points in this node; assumed non-empty
         * @return (split dimension, sample mean of that dimension)
         */
        IntDoublePair choose_split(double[][] pnts, IntArrayView inds) {
            final int D = pnts[0].length;
            final double[] sum_x = new double[D];
            final double[] sum_xx = new double[D];

            final int count = Math.min(inds.size(), varest_max_points);
            for (int n = 0; n < count; ++n) {
                final double[] p = pnts[inds.getFast(n)];
                for (int d = 0; d < D; ++d) {
                    sum_x[d] += p[d];
                    sum_xx[d] += p[d] * p[d];
                }
            }

            // (variance, dimension) pairs; unbiased sample variance, 0 for a single sample.
            final DoubleIntPair[] var_dim = new DoubleIntPair[D];
            for (int d = 0; d < D; ++d) {
                final DoubleIntPair vd = new DoubleIntPair();
                vd.first = count <= 1 ? 0.0
                        : (sum_xx[d] - 1.0 / count * sum_x[d] * sum_x[d]) / (count - 1);
                vd.second = d;
                var_dim[d] = vd;
            }

            // Bring the nrand highest-variance dimensions (ties: higher index first) to the front.
            final int nrand = Math.min(varest_max_randsz, D);
            Sorting.partial_sort(var_dim, 0, nrand, var_dim.length, new BinaryPredicate() {
                public boolean apply(Object arg0, Object arg1) {
                    final DoubleIntPair p1 = (DoubleIntPair) arg0;
                    final DoubleIntPair p2 = (DoubleIntPair) arg1;
                    if (p1.first > p2.first) {
                        return true;
                    }
                    if (p2.first > p1.first) {
                        return false;
                    }
                    return p1.second > p2.second;
                }
            });

            final int randd = var_dim[this.rng.nextIntFromTo(0, nrand - 1)].second;
            return new IntDoublePair(randd, sum_x[randd] / count);
        }

        /**
         * Split this (internal) node: choose a discriminator, partition {@code inds} in place
         * around it and recursively build the two children.
         *
         * @param pnts all data points
         * @param inds indices of the points in this node; reordered in place
         */
        void split_points(double[][] pnts, IntArrayView inds) {
            final InternalNodeData data = (InternalNodeData) this.node_data;
            final IntDoublePair spl = this.choose_split(pnts, inds);
            data.disc_dim = spl.first;
            data.disc = spl.second;

            // In-place partition: indices of points below the discriminator end up in [0, l).
            final int N = inds.size();
            int l = 0;
            int r = N;
            while (l != r) {
                if (pnts[inds.getFast(l)][data.disc_dim] < data.disc) {
                    ++l;
                } else {
                    --r;
                    final int t = inds.getFast(l);
                    inds.setFast(l, inds.getFast(r));
                    inds.setFast(r, t);
                }
            }

            // A one-sided partition would hand all N points to one child (e.g. when the sampled
            // dimension is constant), risking endless recursion; halve the list instead. The
            // stored discriminator then no longer separates the children exactly, so search
            // relies on the branch queue to reach points on the "wrong" side.
            if (l == 0 || l == N) {
                l = N / 2;
            }

            this.left = new DoubleKDTreeNode(pnts, inds.subView(0, l), this.rng);
            data.right = new DoubleKDTreeNode(pnts, inds.subView(l, N), this.rng);
        }

        /** Create an empty node (no children, no node data). */
        public DoubleKDTreeNode() {
        }

        /**
         * Build the subtree over the points referenced by {@code inds}.
         *
         * @param pnts all data points
         * @param inds indices of the points in this subtree; reordered in place when split
         * @param rng random source for split selection
         */
        public DoubleKDTreeNode(double[][] pnts, IntArrayView inds, Uniform rng) {
            this.rng = rng;
            if (inds.size() > leaf_max_points) {
                this.node_data = new InternalNodeData();
                this.split_points(pnts, inds);
            } else {
                this.node_data = new LeafNodeData();
                // Copy, so later in-place reordering of the shared index view cannot affect this leaf.
                ((LeafNodeData) this.node_data).indices = inds.toArray();
            }
        }

        /**
         * Descend from this node to the leaf on the query's side of each split, queueing every
         * untaken branch, then check all not-yet-seen points in that leaf.
         *
         * @param qu the query vector
         * @param pri_branch queue of untaken branches, keyed by accumulated distance estimate
         * @param nns receives (point index, distance) for every newly checked point
         * @param seen flags of points already checked; updated here
         * @param pnts all data points
         * @param mindsq distance estimate already accumulated on the way to this node
         */
        void search(double[] qu, PriorityQueue<DoubleObjectPair<DoubleKDTreeNode>> pri_branch,
                List<IntDoublePair> nns, boolean[] seen, double[][] pnts, double mindsq)
        {
            DoubleKDTreeNode cur = this;
            while (!cur.is_leaf()) {
                final InternalNodeData data = (InternalNodeData) cur.node_data;
                final double diff = qu[data.disc_dim] - data.disc;
                final DoubleKDTreeNode other;
                if (diff < 0.0) {
                    other = data.right;
                    cur = cur.left;
                } else {
                    other = cur.left;
                    cur = data.right;
                }
                // The untaken side is at least |diff| away along this dimension; its priority adds
                // diff^2 to the estimate inherited from the ancestors.
                pri_branch.add(new DoubleObjectPair<DoubleKDTreeNode>(mindsq + diff * diff, other));
            }

            final int[] cur_inds = ((LeafNodeData) cur.node_data).indices;
            final double[] dsq = new double[1];
            for (final int ci : cur_inds) {
                if (seen[ci]) {
                    continue;
                }
                // NOTE(review): metric is whatever DoubleNearestNeighbours.distanceFunc computes —
                // presumably squared Euclidean given the diff^2 branch keys; confirm.
                DoubleNearestNeighbours.distanceFunc(qu, new double[][] { pnts[ci] }, dsq);
                nns.add(new IntDoublePair(ci, dsq[0]));
                seen[ci] = true;
            }
        }

        /** Data of a leaf node: the indices of the points it holds. */
        class LeafNodeData
        extends NodeData {
            int[] indices;

            LeafNodeData() {
            }
        }

        /** Data of an internal node: the split (dimension and value) and the right child. */
        class InternalNodeData
        extends NodeData {
            /** Child holding points whose split coordinate is not below {@link #disc}. */
            DoubleKDTreeNode right;
            /** Split value (sample mean of {@link #disc_dim}). */
            double disc;
            /** Split dimension. */
            int disc_dim;

            InternalNodeData() {
            }
        }

        /** Common base type for per-node payloads. */
        class NodeData {
            NodeData() {
            }
        }
    }
}

