/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.lm;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.lm.IntNode;
import com.aliasi.lm.IntSeqCounter;
import com.aliasi.util.ObjectToCounterMap;

/**
 * An {@link IntSeqCounter} that stores counts for integer sequences in a
 * trie rooted at a single {@link IntNode}.
 *
 * <p>Sequences are stored up to a fixed maximum length supplied at
 * construction time. This class validates arguments and delegates storage,
 * lookup, pruning and rescaling to the root node.
 */
public class TrieIntSeqCounter
implements IntSeqCounter {
    private final int mMaxLength;
    final IntNode mRootNode;

    /**
     * Constructs an empty counter for sequences of at most
     * {@code maxLength} integers.
     *
     * @param maxLength maximum stored sequence length; must be non-negative
     * @throws IllegalArgumentException if {@code maxLength < 0}
     */
    public TrieIntSeqCounter(int maxLength) {
        if (maxLength < 0) {
            String msg = "Max length must be >= 0. Found maxLength=" + maxLength;
            throw new IllegalArgumentException(msg);
        }
        this.mMaxLength = maxLength;
        this.mRootNode = new IntNode();
    }

    /**
     * Prunes the trie by delegating to the root node's {@code prune}.
     * NOTE(review): presumably removes entries whose count is below
     * {@code minCount} -- confirm against {@code IntNode.prune}.
     *
     * @param minCount threshold passed through to the root node
     */
    public void prune(int minCount) {
        this.mRootNode.prune(minCount);
    }

    /**
     * Rescales stored counts by delegating to the root node's {@code rescale}.
     * NOTE(review): presumably multiplies every count by
     * {@code countMultiplier} -- confirm against {@code IntNode.rescale}.
     *
     * @param countMultiplier factor passed through to the root node
     */
    public void rescale(double countMultiplier) {
        this.mRootNode.rescale(countMultiplier);
    }

    /**
     * Returns the maximum sequence length given at construction.
     *
     * @return the maximum stored sequence length
     */
    public int maxLength() {
        return this.mMaxLength;
    }

    /**
     * For every position {@code i} in {@code [start, end)}, passes the window
     * {@code is[i, min(i + maxLength(), end))} to the root node's
     * {@code increment}, incrementing by one.
     *
     * @param is array holding the sequence
     * @param start first index (inclusive)
     * @param end last index (exclusive)
     * @throws IndexOutOfBoundsException if the bounds are outside {@code is}
     *     or {@code end < start}
     */
    public void incrementSubsequences(int[] is, int start, int end) {
        TrieIntSeqCounter.checkBoundaries(is, start, end);
        for (int i = start; i < end; ++i) {
            this.mRootNode.increment(is, i, this.windowEnd(i, end));
        }
    }

    /**
     * Same as {@link #incrementSubsequences(int[], int, int)} but increments
     * each window by {@code count}. A zero count is a no-op.
     *
     * @param is array holding the sequence
     * @param start first index (inclusive)
     * @param end last index (exclusive)
     * @param count non-negative amount to add
     * @throws IndexOutOfBoundsException if the bounds are outside {@code is}
     *     or {@code end < start}
     * @throws IllegalArgumentException if {@code count < 0}
     */
    public void incrementSubsequences(int[] is, int start, int end, int count) {
        TrieIntSeqCounter.checkBoundaries(is, start, end);
        TrieIntSeqCounter.checkCount(count);
        if (count == 0) {
            return;
        }
        for (int i = start; i < end; ++i) {
            this.mRootNode.increment(is, i, this.windowEnd(i, end), count);
        }
    }

    /**
     * Returns the exclusive end of the window beginning at {@code from}: the
     * smaller of {@code end} and {@code from + maxLength()}, computed without
     * int overflow.
     */
    private int windowEnd(int from, int end) {
        int maxLength = this.maxLength();
        // end - from lies in [0, end] and cannot overflow, whereas
        // from + maxLength wraps negative when maxLength is near
        // Integer.MAX_VALUE. If end - from > maxLength, then
        // from + maxLength < end and the sum is safe.
        return end - from <= maxLength ? end : from + maxLength;
    }

    /**
     * Throws if {@code count} is negative.
     *
     * @param count value to validate
     * @throws IllegalArgumentException if {@code count < 0}
     */
    static void checkCount(int count) {
        if (count < 0) {
            String msg = "Counts must be non-negative. Found count=" + count;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Increments the single sequence ending at {@code end} by {@code count}.
     * Only the last {@code maxLength()} integers of {@code is[start, end)}
     * are passed to the root node. A zero count is a no-op.
     *
     * @param is array holding the sequence
     * @param start first index (inclusive)
     * @param end last index (exclusive)
     * @param count non-negative amount to add
     * @throws IndexOutOfBoundsException if the bounds are outside {@code is}
     *     or {@code end < start}
     * @throws IllegalArgumentException if {@code count < 0}
     */
    public void incrementSequence(int[] is, int start, int end, int count) {
        TrieIntSeqCounter.checkBoundaries(is, start, end);
        TrieIntSeqCounter.checkCount(count);
        if (count == 0) {
            return;
        }
        // end - maxLength() cannot overflow: end >= 0 and maxLength() >= 0.
        this.mRootNode.incrementSequence(is, Math.max(start, end - this.maxLength()), end, count);
    }

    /**
     * Collects every stored sequence of exactly {@code nGram} integers whose
     * count is at least {@code minCount}, mapped to that count. Each key is a
     * fresh array owned by the returned map.
     *
     * @param nGram sequence length to collect; must be positive
     * @param minCount minimum count for inclusion
     * @return map from n-gram to its count
     * @throws IllegalArgumentException if {@code nGram < 1}
     */
    public ObjectToCounterMap<int[]> nGramCounts(int nGram, int minCount) {
        if (nGram < 1) {
            String msg = "Ngrams must be positive. Found n-gram=" + nGram;
            throw new IllegalArgumentException(msg);
        }
        ObjectToCounterMap<int[]> result = new ObjectToCounterMap<int[]>();
        this.addNGramCounts(minCount, 0, nGram, new int[nGram], result);
        return result;
    }

    /**
     * Returns the trie size reported by the root node.
     *
     * @return the root node's {@code trieSize()}
     */
    public int trieSize() {
        return this.mRootNode.trieSize();
    }

    /**
     * Passes every stored sequence of exactly {@code nGram} integers whose
     * count is at least {@code minCount} to {@code handler}.
     *
     * <p>The same buffer array is handed to the handler on every call and is
     * overwritten afterwards; handlers that retain an n-gram must copy it.
     *
     * @param nGram sequence length to enumerate; must be positive
     * @param minCount minimum count for an n-gram to be handled
     * @param handler receiver of each qualifying n-gram
     * @throws IllegalArgumentException if {@code nGram < 1}
     */
    public void handleNGrams(int nGram, int minCount, ObjectHandler<int[]> handler) {
        if (nGram < 1) {
            String msg = "Ngrams must be positive. Found n-gram=" + nGram;
            throw new IllegalArgumentException(msg);
        }
        this.handleNGrams(minCount, 0, nGram, new int[nGram], handler);
    }

    @Override
    public int count(int[] is, int start, int end) {
        TrieIntSeqCounter.checkBoundaries(is, start, end);
        IntNode dtr = this.mRootNode.getDtr(is, start, end);
        return dtr == null ? 0 : dtr.count();
    }

    @Override
    public long extensionCount(int[] is, int start, int end) {
        TrieIntSeqCounter.checkBoundaries(is, start, end);
        IntNode dtr = this.mRootNode.getDtr(is, start, end);
        return dtr == null ? 0L : dtr.extensionCount();
    }

    @Override
    public int numExtensions(int[] is, int start, int end) {
        TrieIntSeqCounter.checkBoundaries(is, start, end);
        IntNode dtr = this.mRootNode.getDtr(is, start, end);
        return dtr == null ? 0 : dtr.numExtensions();
    }

    @Override
    public int[] observedIntegers() {
        return this.mRootNode.observedIntegers();
    }

    /**
     * Returns the integers that follow {@code is[start, end)} in the trie, as
     * reported by the root node.
     *
     * @throws IndexOutOfBoundsException if the bounds are outside {@code is}
     *     or {@code end < start} (validated like the other range queries)
     */
    @Override
    public int[] integersFollowing(int[] is, int start, int end) {
        TrieIntSeqCounter.checkBoundaries(is, start, end);
        return this.mRootNode.integersFollowing(is, start, end);
    }

    @Override
    public String toString() {
        return this.mRootNode.toString(null);
    }

    /** Delegates to the root node's single-argument {@code decrement}. */
    void decrementUnigram(int symbol) {
        this.mRootNode.decrement(symbol);
    }

    /** Delegates to the root node's {@code decrement(symbol, count)}. */
    void decrementUnigram(int symbol, int count) {
        this.mRootNode.decrement(symbol, count);
    }

    /**
     * Depth-first enumeration helper for the public {@code handleNGrams}.
     * Extends the prefix {@code buf[0, pos)} with every integer that follows
     * it. When {@code pos == nGram}, hands {@code buf} to the handler if the
     * full sequence's count is at least {@code minCount}.
     */
    void handleNGrams(int minCount, int pos, int nGram, int[] buf, ObjectHandler<int[]> handler) {
        if (pos == nGram) {
            if (this.count(buf, 0, nGram) >= minCount) {
                handler.handle(buf);
            }
            return;
        }
        // Followers are only needed to extend the prefix, so the lookup is
        // skipped at the leaves.
        for (int follower : this.integersFollowing(buf, 0, pos)) {
            buf[pos] = follower;
            this.handleNGrams(minCount, pos + 1, nGram, buf, handler);
        }
    }

    /**
     * Depth-first helper for {@link #nGramCounts(int, int)}. Same traversal as
     * the recursive {@code handleNGrams}, but stores a copy of each qualifying
     * n-gram with its count in {@code counter}.
     */
    void addNGramCounts(int minCount, int pos, int nGram, int[] buf, ObjectToCounterMap<int[]> counter) {
        if (pos == nGram) {
            int count = this.count(buf, 0, nGram);
            if (count >= minCount) {
                // buf is reused by the traversal, so the key must be a copy.
                counter.set(buf.clone(), count);
            }
            return;
        }
        for (int follower : this.integersFollowing(buf, 0, pos)) {
            buf[pos] = follower;
            this.addNGramCounts(minCount, pos + 1, nGram, buf, counter);
        }
    }

    /**
     * Validates that {@code [start, end)} is a well-formed range within
     * {@code is}.
     *
     * @throws IndexOutOfBoundsException if {@code start < 0},
     *     {@code end > is.length}, or {@code end < start}
     */
    static void checkBoundaries(int[] is, int start, int end) {
        if (start < 0) {
            String msg = "Start must be in array range. Found start=" + start;
            throw new IndexOutOfBoundsException(msg);
        }
        if (end > is.length) {
            String msg = "End must be in array range. Found end=" + end + " Length=" + is.length;
            throw new IndexOutOfBoundsException(msg);
        }
        if (end < start) {
            String msg = "End must be at or after start. Found start=" + start + " Found end=" + end;
            throw new IndexOutOfBoundsException(msg);
        }
    }
}

