/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.knn.lsh;

import gnu.trove.TIntCollection;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.hash.TIntHashSet;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.openimaj.knn.IncrementalNearestNeighbours;
import org.openimaj.util.comparator.DistanceComparator;
import org.openimaj.util.hash.HashFunction;
import org.openimaj.util.hash.HashFunctionFactory;
import org.openimaj.util.pair.IntFloatPair;
import org.openimaj.util.queue.BoundedPriorityQueue;

/**
 * Approximate nearest-neighbour search based on locality-sensitive hashing.
 * <p>
 * Every indexed object is hashed into one bucket in each of a number of
 * tables, each table using its own {@link HashFunction}. A query is hashed the
 * same way; the union of the buckets it falls into (across all tables) forms
 * the candidate set, and the candidates are then ranked exactly with the
 * supplied {@link DistanceComparator}. Indexed objects are identified by the
 * 0-based order in which they were added.
 * <p>
 * NOTE(review): ranking keeps the candidates with the smallest
 * {@code distanceFcn.compare} values (via
 * {@link IntFloatPair#SECOND_ITEM_ASCENDING_COMPARATOR}); this assumes the
 * comparator returns a distance (smaller = closer), not a similarity.
 *
 * @param <OBJECT> the type of object being indexed and queried
 */
public class LSHNearestNeighbours<OBJECT>
implements IncrementalNearestNeighbours<OBJECT, float[], IntFloatPair> {
    /** Comparator used to rank candidate objects against a query. */
    protected DistanceComparator<OBJECT> distanceFcn;
    /** The hash tables, one per hash function. */
    protected List<Table<OBJECT>> tables;
    /** All indexed objects in insertion order; an object's position is its id. */
    protected List<OBJECT> data = new ArrayList<OBJECT>();

    /**
     * Construct with one table per supplied hash function.
     *
     * @param tableHashes the hash function for each table
     * @param distanceFcn comparator used to rank candidates
     */
    public LSHNearestNeighbours(List<HashFunction<OBJECT>> tableHashes, DistanceComparator<OBJECT> distanceFcn) {
        this.distanceFcn = distanceFcn;
        this.tables = new ArrayList<Table<OBJECT>>(tableHashes.size());
        for (HashFunction<OBJECT> hash : tableHashes) {
            this.tables.add(new Table<OBJECT>(hash));
        }
    }

    /**
     * Construct with {@code numTables} tables, each using a fresh hash
     * function obtained from {@code factory.create()}.
     *
     * @param factory source of the per-table hash functions
     * @param numTables number of tables to create
     * @param distanceFcn comparator used to rank candidates
     */
    public LSHNearestNeighbours(HashFunctionFactory<OBJECT> factory, int numTables, DistanceComparator<OBJECT> distanceFcn) {
        this.distanceFcn = distanceFcn;
        this.tables = new ArrayList<Table<OBJECT>>(numTables);
        for (int i = 0; i < numTables; ++i) {
            this.tables.add(new Table<OBJECT>(factory.create()));
        }
    }

    /**
     * @return the number of hash tables
     */
    public int numTables() {
        return this.tables.size();
    }

    /**
     * Index every object in the collection, in iteration order.
     *
     * @param d the objects to add
     */
    public void addAll(Collection<OBJECT> d) {
        for (OBJECT point : d) {
            this.add(point);
        }
    }

    /**
     * Index every object in the array, in array order.
     *
     * @param d the objects to add
     */
    public void addAll(OBJECT[] d) {
        for (OBJECT point : d) {
            this.add(point);
        }
    }

    /**
     * Index a single object in every table.
     *
     * @param o the object to add
     * @return the id (0-based insertion index) assigned to the object
     */
    @Override
    public int add(OBJECT o) {
        int index = this.data.size();
        this.data.add(o);
        for (Table<OBJECT> table : this.tables) {
            table.insertPoint(o, index);
        }
        return index;
    }

    /**
     * Collect the candidate ids for each query.
     *
     * @param queries the query objects
     * @return one candidate-id set per query, in query order
     */
    public TIntHashSet[] search(OBJECT[] queries) {
        TIntHashSet[] candidates = new TIntHashSet[queries.length];
        for (int i = 0; i < queries.length; ++i) {
            candidates[i] = this.search(queries[i]);
        }
        return candidates;
    }

    /**
     * Collect the ids of all indexed objects sharing a bucket with the query
     * in at least one table.
     *
     * @param query the query object
     * @return the (possibly empty) set of candidate ids
     */
    public TIntHashSet search(OBJECT query) {
        TIntHashSet candidates = new TIntHashSet();
        for (Table<OBJECT> table : this.tables) {
            TIntCollection bucket = table.searchPoint(query);
            if (bucket != null) {
                candidates.addAll(bucket);
            }
        }
        return candidates;
    }

    /**
     * Compute the per-table hash codes for each object.
     *
     * @param data the objects to hash
     * @return {@code ids[i][j]} is the hash of object {@code i} in table {@code j}
     */
    public int[][] getBucketId(OBJECT[] data) {
        int[][] ids = new int[data.length][];
        for (int i = 0; i < data.length; ++i) {
            ids[i] = this.getBucketId(data[i]);
        }
        return ids;
    }

    /**
     * Compute the hash code of an object in every table.
     *
     * @param point the object to hash
     * @return {@code ids[j]} is the object's hash in table {@code j}
     */
    public int[] getBucketId(OBJECT point) {
        int[] ids = new int[this.tables.size()];
        for (int j = 0; j < ids.length; ++j) {
            ids[j] = this.tables.get(j).function.computeHashCode(point);
        }
        return ids;
    }

    /**
     * Find the nearest indexed object for each query.
     *
     * @param qus the queries
     * @param argmins receives, at index {@code i}, the id of the nearest
     *        neighbour of query {@code i}, or -1 if it has no candidates
     * @param mins receives, at index {@code i}, the corresponding distance
     *        ({@code Float.MAX_VALUE} when there is no neighbour)
     */
    @Override
    public void searchNN(OBJECT[] qus, int[] argmins, float[] mins) {
        // One row per query: searchKNN indexes the result arrays by query.
        int[][] knnIdx = new int[qus.length][1];
        float[][] knnDst = new float[qus.length][1];
        this.searchKNN(qus, 1, knnIdx, knnDst);
        unwrapNearest(knnIdx, knnDst, argmins, mins);
    }

    /**
     * Find the {@code K} nearest indexed objects for each query.
     *
     * @param qus the queries
     * @param K the number of neighbours per query
     * @param argmins row {@code i} receives the ids of query {@code i}'s
     *        neighbours, nearest first; unfilled slots are -1
     * @param mins row {@code i} receives the matching distances; unfilled
     *        slots are {@code Float.MAX_VALUE}
     */
    public void searchKNN(OBJECT[] qus, int K, int[][] argmins, float[][] mins) {
        for (int i = 0; i < qus.length; ++i) {
            this.knnForQuery(qus[i], K, argmins[i], mins[i]);
        }
    }

    /**
     * Find the nearest indexed object for each query.
     *
     * @param qus the queries
     * @param argmins receives, at index {@code i}, the id of the nearest
     *        neighbour of query {@code i}, or -1 if it has no candidates
     * @param mins receives, at index {@code i}, the corresponding distance
     *        ({@code Float.MAX_VALUE} when there is no neighbour)
     */
    @Override
    public void searchNN(List<OBJECT> qus, int[] argmins, float[] mins) {
        // One row per query: searchKNN indexes the result arrays by query.
        int[][] knnIdx = new int[qus.size()][1];
        float[][] knnDst = new float[qus.size()][1];
        this.searchKNN(qus, 1, knnIdx, knnDst);
        unwrapNearest(knnIdx, knnDst, argmins, mins);
    }

    /**
     * Find the {@code K} nearest indexed objects for each query.
     *
     * @param qus the queries
     * @param K the number of neighbours per query
     * @param argmins row {@code i} receives the ids of query {@code i}'s
     *        neighbours, nearest first; unfilled slots are -1
     * @param mins row {@code i} receives the matching distances; unfilled
     *        slots are {@code Float.MAX_VALUE}
     */
    public void searchKNN(List<OBJECT> qus, int K, int[][] argmins, float[][] mins) {
        int size = qus.size();
        for (int i = 0; i < size; ++i) {
            this.knnForQuery(qus.get(i), K, argmins[i], mins[i]);
        }
    }

    /** Copy column 0 of per-query K=1 results into flat per-query arrays. */
    private static void unwrapNearest(int[][] knnIdx, float[][] knnDst, int[] argmins, float[] mins) {
        for (int i = 0; i < knnIdx.length; ++i) {
            argmins[i] = knnIdx[i][0];
            mins[i] = knnDst[i][0];
        }
    }

    /** Gather one query's LSH candidates and rank them exactly into the given rows. */
    private void knnForQuery(OBJECT query, int K, int[] argmins, float[] mins) {
        int[] ids = this.search(query).toArray();
        List<OBJECT> candidates = new ArrayList<OBJECT>(ids.length);
        for (int id : ids) {
            candidates.add(this.data.get(id));
        }
        this.exactNN(candidates, ids, query, K, argmins, mins);
    }

    /**
     * Exhaustively rank {@code subset} against the query and write the best
     * {@code K} into {@code argmins}/{@code mins}; {@code ids[j]} is the global
     * id of {@code subset.get(j)}.
     */
    private void exactNN(List<OBJECT> subset, int[] ids, OBJECT query, int K, int[] argmins, float[] mins) {
        final int actualK = Math.min(K, subset.size());
        // Slots beyond the number of available candidates mean "no neighbour".
        for (int k = actualK; k < K; ++k) {
            argmins[k] = -1;
            mins[k] = Float.MAX_VALUE;
        }
        if (actualK == 0) {
            return;
        }
        BoundedPriorityQueue<IntFloatPair> queue =
                new BoundedPriorityQueue<IntFloatPair>(actualK, IntFloatPair.SECOND_ITEM_ASCENDING_COMPARATOR);
        // actualK + 1 reusable pairs: actualK fill the queue, one is left over as scratch.
        List<IntFloatPair> pool = new ArrayList<IntFloatPair>(actualK + 1);
        for (int i = 0; i <= actualK; ++i) {
            pool.add(new IntFloatPair());
        }
        List<IntFloatPair> ranked = this.rankCandidates(subset, query, queue, pool);
        for (int k = 0; k < actualK; ++k) {
            IntFloatPair p = ranked.get(k);
            if (p.first < 0) {
                // A sentinel pair was never displaced (possible if a distance is
                // NaN or not below Float.MAX_VALUE); report "no neighbour"
                // rather than indexing ids[-1].
                argmins[k] = -1;
                mins[k] = Float.MAX_VALUE;
            } else {
                argmins[k] = ids[p.first];
                mins[k] = p.second;
            }
        }
    }

    /**
     * Seed the queue with sentinel pairs from {@code pool}, offer every
     * candidate's distance (reusing whichever pair the queue hands back), and
     * return the queue's contents in comparator order. Pair {@code first}
     * holds the index into {@code subset}.
     */
    private List<IntFloatPair> rankCandidates(List<OBJECT> subset, OBJECT query, BoundedPriorityQueue<IntFloatPair> queue, List<IntFloatPair> pool) {
        IntFloatPair scratch = null;
        for (IntFloatPair p : pool) {
            p.second = Float.MAX_VALUE;
            p.first = -1;
            scratch = queue.offerItem(p);
        }
        final int size = subset.size();
        for (int i = 0; i < size; ++i) {
            scratch.second = (float) this.distanceFcn.compare(query, subset.get(i));
            scratch.first = i;
            scratch = queue.offerItem(scratch);
        }
        return queue.toOrderedListDestructive();
    }

    /**
     * @return the number of indexed objects
     */
    @Override
    public int size() {
        return this.data.size();
    }

    /**
     * @return a live, read-only view of the indexed objects in id order
     *         (mutators inherited from {@link AbstractList} throw
     *         {@link UnsupportedOperationException})
     */
    public List<OBJECT> getData() {
        return new AbstractList<OBJECT>(){

            @Override
            public OBJECT get(int index) {
                return LSHNearestNeighbours.this.data.get(index);
            }

            @Override
            public int size() {
                return LSHNearestNeighbours.this.data.size();
            }
        };
    }

    /**
     * @param i an object id
     * @return the indexed object with that id
     */
    public OBJECT get(int i) {
        return this.data.get(i);
    }

    /**
     * Index every object in the list, in list order.
     *
     * @param d the objects to add
     * @return the ids assigned to the objects, in list order
     */
    @Override
    public int[] addAll(List<OBJECT> d) {
        int[] indexes = new int[d.size()];
        for (int i = 0; i < indexes.length; ++i) {
            indexes[i] = this.add(d.get(i));
        }
        return indexes;
    }

    /**
     * Find up to {@code K} nearest indexed objects for a single query.
     *
     * @param query the query
     * @param K the maximum number of neighbours
     * @return (id, distance) pairs, nearest first; fewer than {@code K} if
     *         the query has fewer candidates
     */
    @Override
    public List<IntFloatPair> searchKNN(OBJECT query, int K) {
        int[] idx = new int[K];
        float[] dst = new float[K];
        this.knnForQuery(query, K, idx, dst);
        List<IntFloatPair> res = new ArrayList<IntFloatPair>();
        for (int k = 0; k < K; ++k) {
            if (idx[k] != -1) {
                res.add(new IntFloatPair(idx[k], dst[k]));
            }
        }
        return res;
    }

    /**
     * Find the nearest indexed object for a single query.
     *
     * @param query the query
     * @return the (id, distance) of the nearest neighbour, or {@code null} if
     *         the query shares no bucket with any indexed object
     */
    @Override
    public IntFloatPair searchNN(OBJECT query) {
        int[] idx = new int[1];
        float[] dst = new float[1];
        this.knnForQuery(query, 1, idx, dst);
        if (idx[0] == -1) {
            return null;
        }
        return new IntFloatPair(idx[0], dst[0]);
    }

    /**
     * A single hash table: maps a hash code to the ids of the objects that
     * hashed to it.
     */
    private static class Table<OBJECT> {
        private final TIntObjectHashMap<TIntArrayList> table = new TIntObjectHashMap<TIntArrayList>();
        final HashFunction<OBJECT> function;

        public Table(HashFunction<OBJECT> function) {
            this.function = function;
        }

        /** Append {@code pid} to the bucket for {@code point}'s hash, creating it if needed. */
        protected void insertPoint(OBJECT point, int pid) {
            int hash = this.function.computeHashCode(point);
            TIntArrayList bucket = this.table.get(hash);
            if (bucket == null) {
                bucket = new TIntArrayList();
                this.table.put(hash, bucket);
            }
            bucket.add(pid);
        }

        /** @return the (live) bucket for {@code point}'s hash, or {@code null} if empty */
        protected TIntArrayList searchPoint(OBJECT point) {
            return this.table.get(this.function.computeHashCode(point));
        }
    }
}

