/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.hmm;

import com.aliasi.symbol.SymbolTable;
import com.aliasi.tag.TagLattice;
import com.aliasi.util.Math;
import com.aliasi.util.ScoredObject;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * A lattice of forward and backward values over the tags of a token sequence,
 * computed from start, end, and per-token transition values.
 *
 * <p>Unscaled forward/backward products can underflow {@code double} on long
 * inputs, so each token's column is stored as a mantissa array
 * ({@code mForwards}, {@code mBacks}) plus a base-2 exponent
 * ({@code mForwardExps}, {@code mBackExps}). The represented value is
 * {@code mForwards[t][k] * 2^mForwardExps[t]}, and likewise for backward.
 *
 * <p>The arrays passed to the constructor are stored by reference, not copied.
 */
class TagWordLattice
extends TagLattice<String> {

    /**
     * Transition values indexed {@code [tokenIndex][sourceTagId][targetTagId]}.
     * Each entry links a tag at {@code tokenIndex - 1} to a tag at
     * {@code tokenIndex}. Slot 0 is never read and is not validated.
     */
    final double[][][] mTransitions;
    /** Scaled forward mantissas indexed {@code [tokenIndex][tagId]}. */
    final double[][] mForwards;
    /** Base-2 exponent applying to each token's forward column. */
    final double[] mForwardExps;
    /** Scaled backward mantissas indexed {@code [tokenIndex][tagId]}. */
    final double[][] mBacks;
    /** Base-2 exponent applying to each token's backward column. */
    final double[] mBackExps;
    /** Start values by tag id; they seed the forward column of the first token. */
    final double[] mStarts;
    /** End values by tag id; they seed the backward column of the last token. */
    final double[] mEnds;
    /** The tokens spanned by this lattice. */
    final String[] mTokens;
    /** Maps between tag ids and tag strings. */
    final SymbolTable mTagSymbolTable;
    /**
     * Sum over tags of forward times backward at the first token.
     * Set to 1.0 when there are no tokens. May underflow to 0.0 on long inputs.
     */
    double mTotal = Double.NaN;
    /**
     * Base-2 log of the total. It is computed from the scaled sum, so it stays
     * finite even when {@code mTotal} underflows.
     */
    double mLog2Total = Double.NaN;

    /** log2 of the smallest positive double; substituted for undefined conditional log2 values. */
    private static final double LOG2_MIN_VALUE = Math.log2(Double.MIN_VALUE);

    /**
     * Validates the supplied values and computes the forward, backward and total values.
     *
     * @param tokens tokens of the lattice
     * @param tagSymbolTable symbol table mapping tag ids to tags
     * @param startProbs start value for each tag id
     * @param endProbs end value for each tag id
     * @param transitProbs transition values indexed
     *     {@code [tokenIndex][sourceTagId][targetTagId]}; slot 0 is ignored
     * @throws IllegalArgumentException if any start value, end value, or
     *     transition value outside slot 0 is below 0 or above 1
     *     (NaN values are not rejected)
     */
    public TagWordLattice(String[] tokens, SymbolTable tagSymbolTable, double[] startProbs, double[] endProbs, double[][][] transitProbs) {
        checkProbabilities("startProbs", startProbs);
        checkProbabilities("endProbs", endProbs);
        // Slot 0 is skipped: no transition leads into the first token.
        for (int i = 1; i < transitProbs.length; ++i) {
            for (int j = 0; j < transitProbs[i].length; ++j) {
                for (int k = 0; k < transitProbs[i][j].length; ++k) {
                    double p = transitProbs[i][j][k];
                    if (p < 0.0 || p > 1.0) {
                        String msg = "transitProbs[" + i + "][" + j + "][" + k + "]=" + p;
                        throw new IllegalArgumentException(msg);
                    }
                }
            }
        }
        int numTags = tagSymbolTable.numSymbols();
        int numTokens = tokens.length;
        mStarts = startProbs;
        mEnds = endProbs;
        mTransitions = transitProbs;
        mTokens = tokens;
        mTagSymbolTable = tagSymbolTable;
        mForwards = new double[numTokens][numTags];
        mForwardExps = new double[numTokens];
        mBacks = new double[numTokens][numTags];
        mBackExps = new double[numTokens];
        computeAll();
    }

    /** Throws if any entry of {@code probs} is below 0 or above 1; NaN entries pass. */
    private static void checkProbabilities(String name, double[] probs) {
        for (int i = 0; i < probs.length; ++i) {
            if (probs[i] < 0.0 || probs[i] > 1.0) {
                String msg = name + "[" + i + "]=" + probs[i];
                throw new IllegalArgumentException(msg);
            }
        }
    }

    /** Returns the underlying token array (not a copy). */
    public String[] tokens() {
        return mTokens;
    }

    @Override
    public SymbolTable tagSymbolTable() {
        return mTagSymbolTable;
    }

    /**
     * Returns every tag paired with its base-2 log conditional value at the given token,
     * sorted with {@code ScoredObject.reverseComparator()}.
     *
     * <p>Each score is {@code log2ForwardBackward(tokenIndex, tag) - log2Total()}.
     * Positive scores are clamped to 0. NaN or negative-infinite scores are replaced by
     * {@code log2(Double.MIN_VALUE)}.
     *
     * @param tokenIndex index of the token
     * @return a new mutable list of scored tags
     */
    public List<ScoredObject<String>> log2ConditionalTagList(int tokenIndex) {
        double log2Total = log2Total();
        int numTags = mTagSymbolTable.numSymbols();
        List<ScoredObject<String>> scoredTagList = new ArrayList<ScoredObject<String>>(numTags);
        for (int tagId = 0; tagId < numTags; ++tagId) {
            double condLog2P = log2ForwardBackward(tokenIndex, tagId) - log2Total;
            if (condLog2P > 0.0) {
                // A conditional probability cannot exceed 1; presumably this absorbs rounding error.
                condLog2P = 0.0;
            } else if (Double.isNaN(condLog2P) || Double.isInfinite(condLog2P)) {
                condLog2P = LOG2_MIN_VALUE;
            }
            scoredTagList.add(new ScoredObject<String>(mTagSymbolTable.idToSymbol(tagId), condLog2P));
        }
        Collections.sort(scoredTagList, ScoredObject.reverseComparator());
        return scoredTagList;
    }

    /**
     * Array form of {@link #log2ConditionalTagList(int)}: same elements, same order.
     *
     * @param tokenIndex index of the token
     * @return a new array of scored tags
     */
    public ScoredObject<String>[] log2ConditionalTags(int tokenIndex) {
        List<ScoredObject<String>> scoredTagList = log2ConditionalTagList(tokenIndex);
        // Generic arrays cannot be instantiated directly; this array only ever holds ScoredObject<String>.
        @SuppressWarnings("unchecked")
        ScoredObject<String>[] scoredTags = (ScoredObject<String>[]) new ScoredObject[scoredTagList.size()];
        return scoredTagList.toArray(scoredTags);
    }

    /**
     * Returns, for each token, the tag with the largest forward-times-backward value.
     * Ties go to the lowest tag id.
     *
     * @return one tag per token
     */
    public String[] bestForwardBackward() {
        int numTags = mTagSymbolTable.numSymbols();
        String[] bestTags = new String[mTokens.length];
        for (int i = 0; i < bestTags.length; ++i) {
            // forward(i,t) * backward(i,t) equals mForwards[i][t] * mBacks[i][t] times
            // 2^(mForwardExps[i] + mBackExps[i]). That power of two is common to every tag
            // at token i, so comparing the scaled products picks the same tag. It also avoids
            // the unscaled product underflowing to 0.0 for all tags on long inputs.
            double[] fwds = mForwards[i];
            double[] backs = mBacks[i];
            int bestTagId = 0;
            double bestFB = fwds[0] * backs[0];
            for (int tagId = 1; tagId < numTags; ++tagId) {
                double fb = fwds[tagId] * backs[tagId];
                if (fb > bestFB) {
                    bestFB = fb;
                    bestTagId = tagId;
                }
            }
            bestTags[i] = mTagSymbolTable.idToSymbol(bestTagId);
        }
        return bestTags;
    }

    /** Returns the start value for the given tag id. */
    public double start(int tagId) {
        return mStarts[tagId];
    }

    /** Returns the base-2 log of the start value for the given tag id. */
    public double log2Start(int tagId) {
        return Math.log2(start(tagId));
    }

    /** Returns the end value for the given tag id. */
    public double end(int tagId) {
        return mEnds[tagId];
    }

    /** Returns the base-2 log of the end value for the given tag id. */
    public double log2End(int tagId) {
        return Math.log2(end(tagId));
    }

    /**
     * Returns the transition value from a source tag at {@code tokenIndex - 1}
     * to a target tag at {@code tokenIndex}.
     *
     * @throws IndexOutOfBoundsException if {@code tokenIndex} is 0 or otherwise out of range
     */
    public double transition(int tokenIndex, int sourceTagId, int targetTagId) {
        if (tokenIndex == 0) {
            String msg = "Token index must be > 0.";
            throw new IndexOutOfBoundsException(msg);
        }
        return mTransitions[tokenIndex][sourceTagId][targetTagId];
    }

    /** Returns the base-2 log of {@link #transition(int, int, int)}. */
    public double log2Transitions(int tokenIndex, int sourceTagId, int targetTagId) {
        return Math.log2(transition(tokenIndex, sourceTagId, targetTagId));
    }

    /** Returns the unscaled forward value; may underflow to 0.0 on long inputs. */
    public double forward(int tokenIndex, int tagId) {
        return mForwards[tokenIndex][tagId] * java.lang.Math.pow(2.0, mForwardExps[tokenIndex]);
    }

    /** Returns the base-2 log of the forward value, computed without underflow. */
    public double log2Forward(int tokenIndex, int tagId) {
        return Math.log2(mForwards[tokenIndex][tagId]) + mForwardExps[tokenIndex];
    }

    /** Returns the unscaled backward value; may underflow to 0.0 on long inputs. */
    public double backward(int tokenIndex, int tagId) {
        return mBacks[tokenIndex][tagId] * java.lang.Math.pow(2.0, mBackExps[tokenIndex]);
    }

    /** Returns the base-2 log of the backward value, computed without underflow. */
    public double log2Backward(int tokenIndex, int tagId) {
        return Math.log2(mBacks[tokenIndex][tagId]) + mBackExps[tokenIndex];
    }

    /** Returns forward times backward, unscaled; may underflow to 0.0 on long inputs. */
    public double forwardBackward(int tokenIndex, int tagId) {
        return forward(tokenIndex, tagId) * backward(tokenIndex, tagId);
    }

    /** Returns the base-2 log of forward times backward. */
    public double log2ForwardBackward(int tokenIndex, int tagId) {
        return log2Forward(tokenIndex, tagId) + log2Backward(tokenIndex, tagId);
    }

    /** Returns the total computed at construction (1.0 for an empty token sequence). */
    public double total() {
        return mTotal;
    }

    /** Returns the base-2 log of the total (0.0 for an empty token sequence). */
    public double log2Total() {
        return mLog2Total;
    }

    @Override
    public double logForward(int token, int tag) {
        return Math.logBase2ToNaturalLog(log2Forward(token, tag));
    }

    @Override
    public double logBackward(int token, int tag) {
        return Math.logBase2ToNaturalLog(log2Backward(token, tag));
    }

    @Override
    public double logZ() {
        return Math.logBase2ToNaturalLog(log2Total());
    }

    /**
     * Natural log of the transition from {@code tagFrom} at token {@code tokenFrom}
     * to {@code tagTo} at token {@code tokenFrom + 1}.
     */
    @Override
    public double logTransition(int tokenFrom, int tagFrom, int tagTo) {
        return Math.logBase2ToNaturalLog(log2Transitions(tokenFrom + 1, tagFrom, tagTo));
    }

    @Override
    public double logProbability(int tokenIndex, int tagId) {
        return Math.logBase2ToNaturalLog(log2ForwardBackward(tokenIndex, tagId));
    }

    @Override
    public double logProbability(int tokenTo, int tagFrom, int tagTo) {
        return logProbability(tokenTo - 1, new int[] { tagFrom, tagTo });
    }

    /**
     * Natural log value of the tag sequence {@code tags} starting at token {@code tokenFrom}.
     *
     * <p>Computed as forward of the first tag, plus backward of the last tag, plus every
     * intermediate transition, minus {@link #logZ()}. All terms are natural logs.
     */
    @Override
    public double logProbability(int tokenFrom, int[] tags) {
        int startTag = tags[0];
        int endTag = tags[tags.length - 1];
        int tokenTo = tokenFrom + tags.length - 1;
        double logProb = logForward(tokenFrom, startTag) + logBackward(tokenTo, endTag) - logZ();
        for (int n = 1; n < tags.length; ++n) {
            logProb += logTransition(tokenFrom + n - 1, tags[n - 1], tags[n]);
        }
        return logProb;
    }

    @Override
    public int numTokens() {
        return mTokens.length;
    }

    /** Returns a fixed-size list view backed by the token array. */
    @Override
    public List<String> tokenList() {
        return Arrays.asList(mTokens);
    }

    @Override
    public String token(int n) {
        return mTokens[n];
    }

    @Override
    public int numTags() {
        return mTagSymbolTable.numSymbols();
    }

    @Override
    public String tag(int n) {
        return mTagSymbolTable.idToSymbol(n);
    }

    /** Returns a new list of all tags in tag-id order. */
    @Override
    public List<String> tagList() {
        int numTags = numTags();
        ArrayList<String> result = new ArrayList<String>(numTags);
        for (int i = 0; i < numTags; ++i) {
            result.add(tag(i));
        }
        return result;
    }

    /** Fills the forward and backward columns, then the total; the order matters. */
    final void computeAll() {
        computeForward();
        computeBackward();
        computeTotal();
    }

    /**
     * Computes the total from the first token's forward and backward columns.
     * It reads the fields directly rather than calling the overridable
     * {@code tagSymbolTable()}, because this runs from the constructor.
     */
    private void computeTotal() {
        if (mForwards.length == 0) {
            mTotal = 1.0;
            mLog2Total = 0.0;
            return;
        }
        int numTags = mTagSymbolTable.numSymbols();
        double total = 0.0;
        for (int tagId = 0; tagId < numTags; ++tagId) {
            total += mForwards[0][tagId] * mBacks[0][tagId];
        }
        double exp = mForwardExps[0] + mBackExps[0];
        mLog2Total = Math.log2(total) + exp;
        mTotal = total * java.lang.Math.pow(2.0, exp);
    }

    /**
     * Fills the forward columns left to right and rescales each column, accumulating
     * its exponent. Start values were validated non-negative by the constructor.
     */
    private void computeForward() {
        if (mForwards.length == 0) {
            return;
        }
        int numTags = mTagSymbolTable.numSymbols();
        double[] firstColumn = mForwards[0];
        for (int tagId = 0; tagId < numTags; ++tagId) {
            firstColumn[tagId] = mStarts[tagId];
        }
        mForwardExps[0] = log2ScaleExp(firstColumn);
        for (int tokenId = 1; tokenId < mForwards.length; ++tokenId) {
            double[] prevColumn = mForwards[tokenId - 1];
            double[] column = mForwards[tokenId];
            double[][] transits = mTransitions[tokenId];
            for (int tagId = 0; tagId < numTags; ++tagId) {
                double f = 0.0;
                for (int prevTagId = 0; prevTagId < numTags; ++prevTagId) {
                    f += prevColumn[prevTagId] * transits[prevTagId][tagId];
                }
                column[tagId] = f;
            }
            // The sum over source tags can exceed 1 even with every transition in [0,1];
            // log2ScaleExp scales such columns down and returns a positive exponent.
            mForwardExps[tokenId] = log2ScaleExp(column) + mForwardExps[tokenId - 1];
        }
    }

    /**
     * Fills the backward columns right to left and rescales each column,
     * accumulating its exponent.
     */
    private void computeBackward() {
        if (mBacks.length == 0) {
            return;
        }
        int numTags = mTagSymbolTable.numSymbols();
        int lastTok = mBacks.length - 1;
        double[] lastColumn = mBacks[lastTok];
        for (int tagId = 0; tagId < numTags; ++tagId) {
            lastColumn[tagId] = mEnds[tagId];
        }
        mBackExps[lastTok] = log2ScaleExp(lastColumn);
        for (int tokenId = lastTok - 1; tokenId >= 0; --tokenId) {
            double[] nextColumn = mBacks[tokenId + 1];
            double[] column = mBacks[tokenId];
            double[][] transits = mTransitions[tokenId + 1];
            for (int tagId = 0; tagId < numTags; ++tagId) {
                double b = 0.0;
                for (int nextTagId = 0; nextTagId < numTags; ++nextTagId) {
                    b += nextColumn[nextTagId] * transits[tagId][nextTagId];
                }
                column[tagId] = b;
            }
            mBackExps[tokenId] = log2ScaleExp(column) + mBackExps[tokenId + 1];
        }
    }

    /**
     * Rescales {@code xs} in place by a power of two so that its maximum lies in
     * [0.5, 1]. Returns the exponent {@code e} such that each original value equals
     * the rescaled value times {@code 2^e}. The exponent is negative when values
     * were scaled up and positive when they were scaled down.
     *
     * <p>Returns 0 and leaves {@code xs} numerically unchanged when it is empty or
     * its maximum is 0. The same holds when {@code xs[0]} is NaN: the running
     * maximum then stays NaN.
     *
     * @param xs values to rescale in place
     * @return the base-2 exponent removed from {@code xs}
     * @throws IllegalArgumentException if the maximum is negative or infinite
     */
    static double log2ScaleExp(double[] xs) {
        if (xs.length == 0) {
            return 0.0;
        }
        double max = xs[0];
        for (int i = 1; i < xs.length; ++i) {
            if (max < xs[i]) {
                max = xs[i];
            }
        }
        if (max < 0.0 || Double.isInfinite(max)) {
            String msg = "Max must be >= 0 and finite. Found max=" + max;
            throw new IllegalArgumentException(msg);
        }
        if (max == 0.0) {
            return 0.0;
        }
        int exp = 0;
        while (max > 1.0) {
            ++exp;
            max *= 0.5;
        }
        while (max < 0.5) {
            --exp;
            max *= 2.0;
        }
        if (exp != 0) {
            // scalb applies 2^-exp as one correctly rounded multiply. Building a single
            // multiplier 2^-exp would overflow to Infinity for very small maxima and
            // turn zero entries into NaN.
            for (int j = 0; j < xs.length; ++j) {
                xs[j] = java.lang.Math.scalb(xs[j], -exp);
            }
        }
        return exp;
    }
}

