/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.sequence.core;

import java.util.List;
import java.util.Vector;
import weka.classifiers.sequence.core.Alphabet;
import weka.classifiers.sequence.core.BackwardAlgorithm;
import weka.classifiers.sequence.core.BaumWelchTransition;
import weka.classifiers.sequence.core.DeleteState;
import weka.classifiers.sequence.core.EmissionState;
import weka.classifiers.sequence.core.ForwardAlgorithm;
import weka.classifiers.sequence.core.IllegalSymbolException;
import weka.classifiers.sequence.core.InvalidStructureException;
import weka.classifiers.sequence.core.InvalidViterbiPathException;
import weka.classifiers.sequence.core.NumericStabilityException;
import weka.classifiers.sequence.core.ProfileHMM;
import weka.classifiers.sequence.core.ProfileHMMAlgorithms;
import weka.classifiers.sequence.core.SimpleDistribution;
import weka.classifiers.sequence.core.SpecialAlphabet;
import weka.classifiers.sequence.core.State;

/**
 * Re-estimates the parameters of a {@link ProfileHMM} from a set of training sequences with the
 * Baum-Welch (expectation-maximization) procedure.
 *
 * <p>Each iteration consists of three steps:
 * <ol>
 *   <li>An expectation step, which accumulates expected transition and emission counts from the
 *       forward and backward matrices of every training sequence.</li>
 *   <li>A maximization step, which sets the model's transition probabilities and emission
 *       distributions from those counts plus pseudo counts.</li>
 *   <li>A recomputation of the (optionally averaged) log-likelihood of the training data.</li>
 * </ol>
 * All counts and pseudo counts are kept in log space and combined with {@code logplus}.
 */
public class BaumWelchLearner
extends ProfileHMMAlgorithms {
    private static final long serialVersionUID = -39317981676880357L;
    /** Largest tolerated absolute difference between the forward and backward score of a sequence. */
    private static final double FORWARD_BACKWARD_TOLERANCE = 1.0E-4;
    /** Forward results of the last log-likelihood computation, reused by the expectation step unless memory-sensitive. */
    protected List<ForwardAlgorithm> forwardAllSequences;
    protected String[] sequences;
    /** Log-space emission counts: one row per match state, followed (optionally) by one row per insert state. */
    protected double[][] emissionCountMatrix;
    protected double logLikelihoodThreshold;
    protected double logLikelihood;
    protected double transitionPseudoCount = Math.log(0.01);
    protected double emissionPseudoCount = Math.log(0.01);
    protected ProfileHMM hmm;
    private boolean learnInsertEmissions;
    private boolean averageLikelihoodOverSequenceNumber;
    private boolean averageLikelihoodOverResidueNumber;
    private boolean memorySensitive;
    private int iteration;
    private double initialLogLikelihood;

    /**
     * Returns the log-likelihood computed at the start of the most recent {@link #learn(int)} or
     * {@link #learnFast()} call, before any re-estimation.
     *
     * @return the initial log-likelihood, or {@code Double.NEGATIVE_INFINITY} if learning has not started
     */
    public double getInitialLogLikelihood() {
        return this.initialLogLikelihood;
    }

    /**
     * Creates a learner for the given model and training data.
     *
     * @param sequences              the training sequences
     * @param logLikelihoodThreshold {@link #learn(int)} stops once the absolute change in
     *                               log-likelihood between two iterations is not larger than this value
     * @param hmm                    the model whose transition probabilities and emission distributions are re-estimated
     * @param inserts                whether the emission distributions of insert states are re-estimated as well
     * @param memorySens             if {@code true}, forward matrices are recomputed during the expectation
     *                               step instead of being kept from the log-likelihood computation
     */
    public BaumWelchLearner(String[] sequences, double logLikelihoodThreshold, ProfileHMM hmm, boolean inserts, boolean memorySens) {
        super(hmm);
        this.sequences = sequences;
        this.logLikelihoodThreshold = logLikelihoodThreshold;
        this.logLikelihood = Double.NEGATIVE_INFINITY;
        this.initialLogLikelihood = Double.NEGATIVE_INFINITY;
        this.hmm = hmm;
        this.iteration = 0;
        this.averageLikelihoodOverSequenceNumber = false;
        this.averageLikelihoodOverResidueNumber = false;
        this.memorySensitive = memorySens;
        this.learnInsertEmissions = inserts;
        this.resetEmissionCountMatrix();
    }

    /**
     * Runs Baum-Welch iterations until the log-likelihood change falls to or below the threshold
     * or {@code stopAfterIteration} iterations have been performed.
     *
     * <p>At least one iteration is always performed, even if {@code stopAfterIteration <= 0}.
     *
     * @param stopAfterIteration maximum number of iterations
     * @return the trained model (the same instance passed to the constructor)
     * @throws NumericStabilityException if an iteration decreases the log-likelihood, or forward and
     *                                   backward scores of a sequence disagree
     */
    public ProfileHMM learn(int stopAfterIteration) throws NumericStabilityException, IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException {
        this.iteration = 0;
        this.computeLoglikelihood();
        this.initialLogLikelihood = this.logLikelihood;
        double oldLogLikelihood;
        do {
            ++this.iteration;
            oldLogLikelihood = this.runIteration();
        } while (Math.abs(oldLogLikelihood - this.logLikelihood) > this.logLikelihoodThreshold && this.iteration < stopAfterIteration);
        return this.hmm;
    }

    /**
     * Performs exactly one Baum-Welch iteration. The iteration counter is not changed.
     *
     * @return the re-estimated model (the same instance passed to the constructor)
     * @throws NumericStabilityException if the iteration decreases the log-likelihood, or forward and
     *                                   backward scores of a sequence disagree
     */
    public ProfileHMM learnFast() throws NumericStabilityException, IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException {
        this.computeLoglikelihood();
        this.initialLogLikelihood = this.logLikelihood;
        this.runIteration();
        return this.hmm;
    }

    /**
     * Executes one expectation/maximization step and recomputes the log-likelihood.
     *
     * @return the log-likelihood before the step
     * @throws NumericStabilityException if the new log-likelihood is smaller than the old one
     */
    private double runIteration() throws NumericStabilityException, IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException {
        double oldLogLikelihood = this.logLikelihood;
        this.calculateExpectations();
        this.maximumLikelihoodEstimator();
        this.computeLoglikelihood();
        if (oldLogLikelihood > this.logLikelihood) {
            throw new NumericStabilityException("LogLikelihood problem. Loglikelihood of new model (" + this.logLikelihood + ") smaller than the old one (" + oldLogLikelihood + ") in iteration " + this.iteration + ".");
        }
        return oldLogLikelihood;
    }

    /**
     * Expectation step: accumulates log-space transition and emission counts over all sequences,
     * then adds the emission pseudo count to every emission count.
     */
    private void calculateExpectations() throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, NumericStabilityException {
        for (int j = 0; j < this.sequences.length; ++j) {
            String sequence = this.sequences[j];
            ForwardAlgorithm fwd;
            if (this.memorySensitive) {
                fwd = new ForwardAlgorithm(this.hmm, sequence);
                fwd.calculateForward();
            } else {
                fwd = this.forwardAllSequences.get(j);
            }
            double probability = fwd.getScore();
            BackwardAlgorithm bwd = new BackwardAlgorithm(this.hmm, sequence);
            bwd.calculateBackward();
            double backwardScore = bwd.getScore();
            // Forward and backward compute the same score; a deviation in either direction is an error.
            if (Math.abs(probability - backwardScore) > FORWARD_BACKWARD_TOLERANCE) {
                throw new NumericStabilityException("Forward score " + probability + " and backward score " + backwardScore + " aren't identical due to numerical instability.");
            }
            EmissionState end = this.hmm.getEndState();
            SpecialAlphabet endAlphabet = (SpecialAlphabet) end.getAlphabet();
            // Append the end state's alphabet text so position i + 1 can be looked up for transitions.
            String extended = sequence + endAlphabet.toString();
            int lastPosition = extended.length() - 2;
            for (int i = 0; i < lastPosition; ++i) {
                this.calculateTransitionExpectations(extended, i, fwd, bwd, probability);
                this.calculateEmissionExpectations(extended, i, fwd, bwd, probability);
            }
            this.calculateEmissionExpectations(extended, lastPosition, fwd, bwd, probability);
        }
        // The cached forward results belong to the current model, which is about to change.
        this.forwardAllSequences = null;
        for (double[] row : this.emissionCountMatrix) {
            for (int a = 0; a < row.length; ++a) {
                row[a] = BaumWelchLearner.logplus(row[a], this.emissionPseudoCount);
            }
        }
    }

    /**
     * Adds the posterior log-probability of emitting the symbol at {@code sequencePosition} to the
     * emission counts of every match state and, if enabled, every insert state.
     */
    private void calculateEmissionExpectations(String sequence, int sequencePosition, ForwardAlgorithm fwd, BackwardAlgorithm bwd, double probabilityForSequence) throws IllegalSymbolException {
        int columns = this.hmm.getNumberMatchStates();
        Alphabet alphabet = this.hmm.getAlphabet();
        int alphabetIndex = alphabet.indexOfAlphabetSymbol(String.valueOf(sequence.charAt(sequencePosition)));
        for (int k = 0; k < columns; ++k) {
            double matchPosterior = fwd.getMatchMatrixPos(k, sequencePosition) + bwd.getMatchMatrixPos(k, sequencePosition) - probabilityForSequence;
            this.emissionCountMatrix[k][alphabetIndex] = BaumWelchLearner.logplus(this.emissionCountMatrix[k][alphabetIndex], matchPosterior);
            // Insert states exist only for columns 0 .. columns - 2; their rows start at index `columns`.
            if (this.learnInsertEmissions && k < columns - 1) {
                int row = k + columns;
                double insertPosterior = fwd.getInsertMatrixPos(k, sequencePosition) + bwd.getInsertMatrixPos(k, sequencePosition) - probabilityForSequence;
                this.emissionCountMatrix[row][alphabetIndex] = BaumWelchLearner.logplus(this.emissionCountMatrix[row][alphabetIndex], insertPosterior);
            }
        }
    }

    /**
     * Adds the posterior log-probability of every outgoing transition of the match, insert and
     * (where present) delete states, taken between {@code sequencePosition} and the next position,
     * to that transition's expectation count.
     */
    private void calculateTransitionExpectations(String sequence, int sequencePosition, ForwardAlgorithm fwd, BackwardAlgorithm bwd, double probabilityForSequence) throws IllegalSymbolException {
        int columns = this.hmm.getNumberMatchStates();
        int next = sequencePosition + 1;
        for (int k = 0; k < columns - 1; ++k) {
            EmissionState matchK = this.hmm.getMatchState(k);
            for (int l = 0; l < matchK.getAllOutgoing().size(); ++l) {
                BaumWelchTransition transition = (BaumWelchTransition) matchK.getAllOutgoing().get(l);
                State target = transition.getEnd();
                double forwardContribution = fwd.getMatchMatrixPos(k, sequencePosition);
                double emissionContribution = 0.0;
                double backwardContribution;
                if (target instanceof EmissionState) {
                    emissionContribution = this.calculateLogEmissionProbability((EmissionState) target, sequence, next);
                    backwardContribution = target.getName().equalsIgnoreCase("M") ? bwd.getMatchMatrixPos(k + 1, next) : bwd.getInsertMatrixPos(k, next);
                } else {
                    // Non-emitting target: no symbol consumed, so the backward value stays at this position.
                    backwardContribution = bwd.getDeleteMatrixPos(k, sequencePosition);
                }
                this.addTransitionExpectation(transition, forwardContribution, emissionContribution, backwardContribution, probabilityForSequence);
            }
            EmissionState insertK = this.hmm.getInsertState(k);
            for (int l = 0; l < insertK.getAllOutgoing().size(); ++l) {
                BaumWelchTransition transition = (BaumWelchTransition) insertK.getAllOutgoing().get(l);
                EmissionState target = (EmissionState) transition.getEnd();
                double forwardContribution = fwd.getInsertMatrixPos(k, sequencePosition);
                double emissionContribution = this.calculateLogEmissionProbability(target, sequence, next);
                double backwardContribution = target.getName().equalsIgnoreCase("M") ? bwd.getMatchMatrixPos(k + 1, next) : bwd.getInsertMatrixPos(k, next);
                this.addTransitionExpectation(transition, forwardContribution, emissionContribution, backwardContribution, probabilityForSequence);
            }
            if (k >= columns - 3) {
                continue;
            }
            DeleteState deleteK = this.hmm.getDeleteState(k);
            for (int l = 0; l < deleteK.getAllOutgoing().size(); ++l) {
                BaumWelchTransition transition = (BaumWelchTransition) deleteK.getAllOutgoing().get(l);
                State target = transition.getEnd();
                double forwardContribution = fwd.getDeleteMatrixPos(k, sequencePosition);
                double emissionContribution = 0.0;
                double backwardContribution;
                if (target instanceof EmissionState) {
                    emissionContribution = this.calculateLogEmissionProbability((EmissionState) target, sequence, next);
                    backwardContribution = bwd.getMatchMatrixPos(k + 2, next);
                } else {
                    backwardContribution = bwd.getDeleteMatrixPos(k + 1, sequencePosition);
                }
                this.addTransitionExpectation(transition, forwardContribution, emissionContribution, backwardContribution, probabilityForSequence);
            }
        }
    }

    /** Log-adds one path's posterior log-probability to the transition's expectation count. */
    private void addTransitionExpectation(BaumWelchTransition transition, double forwardContribution, double emissionContribution, double backwardContribution, double probabilityForSequence) {
        double pathProbability = forwardContribution + this.getTransitionProbability(transition) + emissionContribution + backwardContribution - probabilityForSequence;
        transition.setExpectationLogCount(BaumWelchLearner.logplus(transition.getExpectationLogCount(), pathProbability));
    }

    /**
     * Maximization step: re-estimates all transition probabilities and emission distributions from
     * the accumulated counts, then clears the emission counts for the next iteration.
     */
    private void maximumLikelihoodEstimator() throws NumericStabilityException {
        int columns = this.hmm.getNumberMatchStates();
        for (int k = 0; k < columns - 1; ++k) {
            this.doMaximumLikelihoodEstimatesTransitions(this.hmm.getMatchState(k));
            this.doMaximumLikelihoodEstimatesTransitions(this.hmm.getInsertState(k));
            if (k < columns - 3) {
                DeleteState deleteK = this.hmm.getDeleteState(k);
                this.doMaximumLikelihoodEstimatesTransitions(deleteK);
            }
            this.doMaximumLikelihoodEstimatesEmissions(k);
            if (this.learnInsertEmissions) {
                this.doMaximumLikelihoodEstimatesEmissions(k + columns);
            }
        }
        this.doMaximumLikelihoodEstimatesEmissions(columns - 1);
        this.resetEmissionCountMatrix();
    }

    /** (Re)allocates the emission count matrix with every log count set to negative infinity (count 0). */
    private void resetEmissionCountMatrix() {
        int rows = this.learnInsertEmissions ? 2 * this.hmm.getNumberMatchStates() - 1 : this.hmm.getNumberMatchStates();
        this.emissionCountMatrix = new double[rows][this.hmm.getAlphabet().alphabetSize()];
        for (double[] row : this.emissionCountMatrix) {
            for (int a = 0; a < row.length; ++a) {
                row[a] = Double.NEGATIVE_INFINITY;
            }
        }
    }

    /**
     * Replaces the emission distribution of one state with a new distribution built from its
     * emission-count row.
     *
     * @param stateIndex a match-state index ({@code < columns}) or {@code columns + insertIndex}
     */
    private void doMaximumLikelihoodEstimatesEmissions(int stateIndex) throws NumericStabilityException {
        double[] allEmissions = this.emissionCountMatrix[stateIndex];
        int columns = this.hmm.getNumberMatchStates();
        EmissionState state = stateIndex < columns ? this.hmm.getMatchState(stateIndex) : this.hmm.getInsertState(stateIndex - columns);
        state.setDistribution(null);
        SimpleDistribution simpleDist = new SimpleDistribution(this.hmm.getAlphabet(), this.hmm.isUseLogSpace());
        simpleDist.setProbWithArray(allEmissions);
        state.setDistribution(simpleDist);
    }

    /**
     * Sets each outgoing transition of {@code state} to
     * {@code log(count + pseudo) - log(sum over outgoing of (count + pseudo))}, computed in log space,
     * and resets every transition's expectation count to negative infinity.
     *
     * @param state the state whose outgoing transitions (assumed to be {@link BaumWelchTransition}s) are normalised
     */
    protected void doMaximumLikelihoodEstimatesTransitions(State state) {
        double sum = Double.NEGATIVE_INFINITY;
        for (int l = 0; l < state.getAllOutgoing().size(); ++l) {
            BaumWelchTransition transition = (BaumWelchTransition) state.getAllOutgoing().get(l);
            sum = BaumWelchLearner.logplus(sum, transition.getExpectationLogCount());
            sum = BaumWelchLearner.logplus(sum, this.transitionPseudoCount);
        }
        for (int l = 0; l < state.getAllOutgoing().size(); ++l) {
            BaumWelchTransition transition = (BaumWelchTransition) state.getAllOutgoing().get(l);
            double withPseudoCount = BaumWelchLearner.logplus(transition.getExpectationLogCount(), this.transitionPseudoCount);
            this.setTransitionProbability(transition, withPseudoCount - sum);
            transition.setExpectationLogCount(Double.NEGATIVE_INFINITY);
        }
    }

    /**
     * Computes the summed forward score of all sequences and stores it in {@link #logLikelihood}.
     * The sum is optionally averaged over the number of sequences or residues. Unless
     * memory-sensitive, the forward results are kept for the next expectation step.
     * Averaging is skipped when the divisor would be zero.
     */
    private void computeLoglikelihood() throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, NumericStabilityException {
        this.forwardAllSequences = new Vector<ForwardAlgorithm>(this.sequences.length);
        double tempLogLikelihood = 0.0;
        long sumOfAllResidues = 0L;
        for (String sequence : this.sequences) {
            sumOfAllResidues += sequence.length();
            ForwardAlgorithm fwd = new ForwardAlgorithm(this.hmm, sequence);
            fwd.calculateForward();
            if (!this.memorySensitive) {
                this.forwardAllSequences.add(fwd);
            }
            tempLogLikelihood += fwd.getScore();
        }
        if (this.averageLikelihoodOverSequenceNumber && this.sequences.length > 0) {
            this.logLikelihood = tempLogLikelihood / (double) this.sequences.length;
        } else if (this.averageLikelihoodOverResidueNumber && sumOfAllResidues > 0L) {
            this.logLikelihood = tempLogLikelihood / (double) sumOfAllResidues;
        } else {
            this.logLikelihood = tempLogLikelihood;
        }
    }

    /**
     * Returns the log emission probability of the symbol at {@code sequencePosition} in state
     * {@code actual}. With a null model this is the log-odds against the null model. The end state
     * then contributes log 1 = 0.
     */
    private double calculateLogEmissionProbability(EmissionState actual, String sequence, int sequencePosition) throws IllegalSymbolException {
        String symbol = "" + sequence.charAt(sequencePosition);
        double prob = actual.getEmissionProbability(symbol);
        if (!this.logForTransitionsEmissions) {
            // Probabilities are stored linearly: divide by the null model, then take the log.
            if (this.useNullModel) {
                prob = actual.equals(this.net.getEndState()) ? 1.0 : prob / this.net.getNullModel().getNullModelEmissionProbability(symbol);
            }
            return Math.log(prob);
        }
        // Probabilities are already logs: subtract the null model's log value.
        if (this.useNullModel) {
            prob = actual.equals(this.net.getEndState()) ? 0.0 : prob - this.net.getNullModel().getNullModelEmissionProbability(symbol);
        }
        return prob;
    }

    /** @return whether insert-state emission distributions are re-estimated */
    public boolean isLearnInsertEmissions() {
        return this.learnInsertEmissions;
    }

    /** @return the model being trained */
    public ProfileHMM getHmm() {
        return this.hmm;
    }

    /** @return whether the log-likelihood is averaged over the number of sequences */
    public boolean isAverageLikelihoodOverSequenceNumber() {
        return this.averageLikelihoodOverSequenceNumber;
    }

    /**
     * Enables or disables averaging the log-likelihood over the number of sequences. Enabling it
     * disables averaging over residues, because the two options are mutually exclusive.
     */
    public void setAverageLikelihoodOverSequenceNumber(boolean averageLikelihoodOverSequenceNumber) {
        if (averageLikelihoodOverSequenceNumber) {
            this.averageLikelihoodOverResidueNumber = false;
        }
        this.averageLikelihoodOverSequenceNumber = averageLikelihoodOverSequenceNumber;
    }

    /** @return whether the log-likelihood is averaged over the total number of residues */
    public boolean isAverageLikelihoodOverResidueNumber() {
        return this.averageLikelihoodOverResidueNumber;
    }

    /**
     * Enables or disables averaging the log-likelihood over the total number of residues. Enabling
     * it disables averaging over sequences, because the two options are mutually exclusive.
     */
    public void setAverageLikelihoodOverResidueNumber(boolean averageLikelihoodOverResidueNumber) {
        if (averageLikelihoodOverResidueNumber) {
            this.averageLikelihoodOverSequenceNumber = false;
        }
        this.averageLikelihoodOverResidueNumber = averageLikelihoodOverResidueNumber;
    }

    /** @return whether forward matrices are recomputed instead of cached between steps */
    public boolean isMemorySensitive() {
        return this.memorySensitive;
    }

    /** @param memorySensitive whether forward matrices are recomputed instead of cached between steps */
    public void setMemorySensitive(boolean memorySensitive) {
        this.memorySensitive = memorySensitive;
    }

    /** @return the number of iterations performed by the most recent {@link #learn(int)} call */
    public int getIteration() {
        return this.iteration;
    }

    /** @return the most recently computed (possibly averaged) log-likelihood */
    public double getLogLikelihood() {
        return this.logLikelihood;
    }
}

