/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    BaumWelchLearner.java
 *    Copyright (C) 2010 Stefan Mutter
 *
 */
package weka.classifiers.sequence.core;

import java.util.Arrays;
import java.util.List;
import java.util.Vector;

/**
<!-- globalinfo-start -->
* Class to train a PHMM using the Baum-Welch training algorithm, which is an instance of EM (expectation maximization).
<!-- globalinfo-end -->
* 
* @author Stefan Mutter (pHMM4weka@gmail.com)
* @version $Revision: 4 $
*/
public class BaumWelchLearner extends ProfileHMMAlgorithms {

  private static final long serialVersionUID = -39317981676880357L;

  /** Maximum tolerated absolute difference between the forward and backward score of a sequence. */
  private static final double FORWARD_BACKWARD_TOLERANCE = 0.0001;

  /**
   * Forward results, one per training sequence, cached by computeLoglikelihood() when training is
   * not memory sensitive; released (set to null) after each E-step.
   */
  protected List<ForwardAlgorithm> forwardAllSequences;

  /** The training sequences. */
  protected String[] sequences;

  /**
   * Expected emission counts in log space (initialised to log(0) = negative infinity), one row per
   * state and one column per alphabet symbol. Rows 0 .. (number of match states - 1) belong to the
   * match states; when insert emissions are learned, the following rows belong to insert states
   * 0 .. (number of match states - 2).
   */
  protected double[][] emissionCountMatrix;

  /** Training stops once the absolute change in log-likelihood is not larger than this value. */
  protected double logLikelihoodThreshold;

  /** Log-likelihood (or log-odds score) of the current model, optionally averaged; see computeLoglikelihood(). */
  protected double logLikelihood;

  /** Log-space pseudo count added to every transition count to stop transitions from becoming probability 0. */
  protected double transitionPseudoCount = Math.log(0.01);

  /** Log-space pseudo count added to every emission count to stop emissions from becoming probability 0. */
  protected double emissionPseudoCount = Math.log(0.01);

  /** The PHMM being trained (modified in place). */
  protected ProfileHMM hmm;

  /** Whether emission probabilities of insert states are learned as well. */
  private boolean learnInsertEmissions;

  /** Whether the log-likelihood is divided by the number of sequences. */
  private boolean averageLikelihoodOverSequenceNumber;

  /** Whether the log-likelihood is divided by the total number of residues. */
  private boolean averageLikelihoodOverResidueNumber;

  /** If true, forward matrices are recomputed in the E-step instead of being cached. */
  private boolean memorySensitive;

  /** Number of EM iterations performed by the last call to learn(int). */
  private int iteration;

  /** Log-likelihood of the model before the first EM iteration of the last training call. */
  private double initialLogLikelihood;

  /**
   * Returns the log-likelihood of the model before training started.
   * @return the initial log-likelihood, or negative infinity if no training was run yet
   */
  public double getInitialLogLikelihood() {
    return initialLogLikelihood;
  }

  /**
   * Constructor.
   * @param sequences the sequence dataset
   * @param logLikelihoodThreshold training stops when the difference in log-likelihood between two
   *        iterations is not larger than this threshold
   * @param hmm the PHMM
   * @param inserts true if emission probabilities in insert states are learned, false otherwise
   * @param memorySens if true, training is executed memory sensitive, which is more time consuming
   *        as forward matrices need to be re-calculated
   */
  public BaumWelchLearner(String[] sequences, double logLikelihoodThreshold, ProfileHMM hmm, boolean inserts, boolean memorySens) {
    super(hmm);
    this.sequences = sequences;
    this.logLikelihoodThreshold = logLikelihoodThreshold;
    this.logLikelihood = Double.NEGATIVE_INFINITY;
    this.initialLogLikelihood = Double.NEGATIVE_INFINITY;
    this.hmm = hmm;
    this.iteration = 0;
    this.averageLikelihoodOverSequenceNumber = false;
    this.averageLikelihoodOverResidueNumber = false;
    this.memorySensitive = memorySens;
    this.learnInsertEmissions = inserts;
    // hmm and learnInsertEmissions must be set before the count matrix is sized
    resetEmissionCounts();
  }

  /**
   * Learns a PHMM from the set of sequences by repeating EM iterations until the change in
   * log-likelihood is not larger than the threshold or stopAfterIteration iterations were run.
   * At least one iteration is always performed.
   * @param stopAfterIteration maximum number of iterations; training can be stopped after a specified
   *        amount of iterations instead of by the difference in log-likelihood (standard)
   * @return the trained PHMM (the same instance passed to the constructor)
   * @throws NumericStabilityException if forward and backward scores disagree or the log-likelihood decreases
   * @throws IllegalSymbolException
   * @throws InvalidStructureException
   * @throws InvalidViterbiPathException
   */
  public ProfileHMM learn(int stopAfterIteration) throws NumericStabilityException, IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException {
    iteration = 0;
    computeLoglikelihood();
    initialLogLikelihood = logLikelihood;
    double oldLogLikelihood;
    do {
      iteration++;
      oldLogLikelihood = performEmIteration();
    } while ((Math.abs(oldLogLikelihood - logLikelihood) > logLikelihoodThreshold) && (iteration < stopAfterIteration));
    return hmm;
  }

  /**
   * Performs exactly one EM iteration on the current model. Unlike learn(int), the iteration
   * counter is neither reset nor advanced.
   * @return the PHMM after one iteration (the same instance passed to the constructor)
   * @throws NumericStabilityException if forward and backward scores disagree or the log-likelihood decreases
   * @throws IllegalSymbolException
   * @throws InvalidStructureException
   * @throws InvalidViterbiPathException
   */
  public ProfileHMM learnFast() throws NumericStabilityException, IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException {
    computeLoglikelihood();
    initialLogLikelihood = logLikelihood;
    performEmIteration();
    return hmm;
  }

  /**
   * Runs one E-step and M-step and recomputes the log-likelihood of the updated model.
   * computeLoglikelihood() must have been called for the current model beforehand.
   * @return the log-likelihood before this iteration
   * @throws NumericStabilityException if the new log-likelihood is smaller than the old one
   */
  private double performEmIteration() throws NumericStabilityException, IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException {
    double oldLogLikelihood = logLikelihood;

    //E-Step
    calculateExpectations();

    //M-Step
    maximumLikelihoodEstimator();

    //compare new loglikelihood to old one; EM must never decrease it
    computeLoglikelihood();
    if (oldLogLikelihood > logLikelihood) {
      throw new NumericStabilityException("LogLikelihood problem. Loglikelihood of new model ("+ logLikelihood+") smaller than the old one ("+oldLogLikelihood+") in iteration "+iteration+".");
    }
    return oldLogLikelihood;
  }

  /**
   * E-step: accumulates expected transition and emission counts over all training sequences and
   * then adds the emission pseudo count to every emission count.
   * computeLoglikelihood() must be called before this method, because unless training is memory
   * sensitive the forward results cached there are reused; the cache is released afterwards.
   * @throws NumericStabilityException if the forward and backward score of a sequence differ by more
   *         than FORWARD_BACKWARD_TOLERANCE (in either direction)
   */
  private void calculateExpectations() throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, NumericStabilityException {
    for (int j = 0; j < sequences.length; j++) {
      String sequence = sequences[j];
      ForwardAlgorithm fwd;
      if (memorySensitive) {
        // forward matrices were not cached, so recompute them
        fwd = new ForwardAlgorithm(hmm, sequence);
        fwd.calculateForward();
      } else {
        fwd = forwardAllSequences.get(j);
      }
      double probability = fwd.getScore();

      BackwardAlgorithm bwd = new BackwardAlgorithm(hmm, sequence);
      bwd.calculateBackward();
      double backwardScore = bwd.getScore();

      // forward and backward compute the same total score, so any deviation beyond the
      // tolerance (in either direction) signals numerical problems
      if (Math.abs(probability - backwardScore) > FORWARD_BACKWARD_TOLERANCE) {
        throw new NumericStabilityException("Forward score "+probability+" and backward score "+backwardScore+" aren't identical due to numerical instability.");
      }

      // append the end state's alphabet string to the sequence
      EmissionState end = hmm.getEndState();
      SpecialAlphabet endAlphabet = (SpecialAlphabet) end.getAlphabet();
      sequence = sequence + endAlphabet.toString();

      int lastPosition = sequence.length() - 2;
      for (int i = 0; i < lastPosition; i++) {
        calculateTransitionExpectations(sequence, i, fwd, bwd, probability);
        calculateEmissionExpectations(sequence, i, fwd, bwd, probability);
      }
      // the last position only contributes emissions, no transitions
      calculateEmissionExpectations(sequence, lastPosition, fwd, bwd, probability);
    }
    // release the cached forward matrices
    forwardAllSequences = null;
    System.gc();
    for (double[] row : emissionCountMatrix) {
      for (int j = 0; j < row.length; j++) {
        row[j] = logplus(row[j], emissionPseudoCount);
      }
    }
  }

  /**
   * Adds, in log space, forward + backward - sequence score for the symbol at sequencePosition to the
   * emission counts of every match state and, if insert emissions are learned, of insert states
   * 0 .. (number of match states - 2).
   */
  private void calculateEmissionExpectations(String sequence, int sequencePosition, ForwardAlgorithm fwd, BackwardAlgorithm bwd, double probabilityForSequence) throws IllegalSymbolException {
    int columns = hmm.getNumberMatchStates();
    char charAtSeq = sequence.charAt(sequencePosition);
    Alphabet alphabet = hmm.getAlphabet();
    int alphabetIndex = alphabet.indexOfAlphabetSymbol(String.valueOf(charAtSeq));
    for (int k = 0; k < columns; k++) {
      double forwardContribution = fwd.getMatchMatrixPos(k, sequencePosition);
      double backwardContribution = bwd.getMatchMatrixPos(k, sequencePosition);
      emissionCountMatrix[k][alphabetIndex] = logplus(emissionCountMatrix[k][alphabetIndex], forwardContribution + backwardContribution - probabilityForSequence);

      if (learnInsertEmissions && k < columns - 1) {
        // insert state rows follow the match state rows
        int position = k + columns;
        forwardContribution = fwd.getInsertMatrixPos(k, sequencePosition);
        backwardContribution = bwd.getInsertMatrixPos(k, sequencePosition);
        emissionCountMatrix[position][alphabetIndex] = logplus(emissionCountMatrix[position][alphabetIndex], forwardContribution + backwardContribution - probabilityForSequence);
      }
    }
  }

  /**
   * Adds the expected (log-space) usage of every outgoing transition of match state k, insert state k
   * and, for k smaller than (number of match states - 3), delete state k between sequencePosition and
   * the following position to the transition's expectation count.
   */
  private void calculateTransitionExpectations(String sequence, int sequencePosition, ForwardAlgorithm fwd, BackwardAlgorithm bwd, double probabilityForSequence) throws IllegalSymbolException {
    int columns = hmm.getNumberMatchStates();
    for (int k = 0; k < columns - 1; k++) {
      // transitions leaving match state k
      EmissionState matchK = hmm.getMatchState(k);
      for (int l = 0; l < (matchK.getAllOutgoing()).size(); l++) {
        BaumWelchTransition actualTransition = (BaumWelchTransition) (matchK.getAllOutgoing()).get(l);
        State endStateofTransition = actualTransition.getEnd();
        double forwardContribution = fwd.getMatchMatrixPos(k, sequencePosition);
        double emissionContribution = 0;
        double backwardContribution;
        if (endStateofTransition instanceof EmissionState) {
          String typ = endStateofTransition.getName();
          emissionContribution = calculateLogEmissionProbability((EmissionState) endStateofTransition, sequence, sequencePosition + 1);
          if (typ.equalsIgnoreCase("M")) {
            backwardContribution = bwd.getMatchMatrixPos(k + 1, sequencePosition + 1);
          } else {
            backwardContribution = bwd.getInsertMatrixPos(k, sequencePosition + 1);
          }
        } else {
          // non-emitting target (read from the delete matrix): no symbol is consumed, so the
          // backward value is taken at the same sequence position (original author noted "seq+1 ?")
          backwardContribution = bwd.getDeleteMatrixPos(k, sequencePosition);
        }
        addExpectedTransitionCount(actualTransition, forwardContribution, emissionContribution, backwardContribution, probabilityForSequence);
      }

      // transitions leaving insert state k (their targets are always emitting states)
      EmissionState insertK = hmm.getInsertState(k);
      for (int l = 0; l < (insertK.getAllOutgoing()).size(); l++) {
        BaumWelchTransition actualTransition = (BaumWelchTransition) (insertK.getAllOutgoing()).get(l);
        State endStateofTransition = actualTransition.getEnd();
        double forwardContribution = fwd.getInsertMatrixPos(k, sequencePosition);
        double emissionContribution = calculateLogEmissionProbability((EmissionState) endStateofTransition, sequence, sequencePosition + 1);
        double backwardContribution;
        if (endStateofTransition.getName().equalsIgnoreCase("M")) {
          backwardContribution = bwd.getMatchMatrixPos(k + 1, sequencePosition + 1);
        } else {
          backwardContribution = bwd.getInsertMatrixPos(k, sequencePosition + 1);
        }
        addExpectedTransitionCount(actualTransition, forwardContribution, emissionContribution, backwardContribution, probabilityForSequence);
      }

      // transitions leaving delete state k (only processed for k < columns-3, as in maximumLikelihoodEstimator)
      if (k < columns - 3) {
        DeleteState deleteK = hmm.getDeleteState(k);
        for (int l = 0; l < (deleteK.getAllOutgoing()).size(); l++) {
          BaumWelchTransition actualTransition = (BaumWelchTransition) (deleteK.getAllOutgoing()).get(l);
          State endStateofTransition = actualTransition.getEnd();
          double forwardContribution = fwd.getDeleteMatrixPos(k, sequencePosition);
          double emissionContribution = 0;
          double backwardContribution;
          if (endStateofTransition instanceof EmissionState) {
            emissionContribution = calculateLogEmissionProbability((EmissionState) endStateofTransition, sequence, sequencePosition + 1);
            backwardContribution = bwd.getMatchMatrixPos(k + 2, sequencePosition + 1);
          } else {
            // non-emitting target: same sequence position (original author noted "seq+1 ?")
            backwardContribution = bwd.getDeleteMatrixPos(k + 1, sequencePosition);
          }
          addExpectedTransitionCount(actualTransition, forwardContribution, emissionContribution, backwardContribution, probabilityForSequence);
        }
      }
    }
  }

  /**
   * Adds one path contribution, forward + transition + emission + backward - sequence score (all in
   * log space), to the expectation count stored on the transition.
   */
  private void addExpectedTransitionCount(BaumWelchTransition transition, double forwardContribution, double emissionContribution, double backwardContribution, double probabilityForSequence) throws IllegalSymbolException {
    double pathProbability = forwardContribution + getTransitionProbability(transition) + emissionContribution + backwardContribution - probabilityForSequence;
    transition.setExpectationLogCount(logplus(transition.getExpectationLogCount(), pathProbability));
  }

  /**
   * M-step: re-estimates transition and emission probabilities from the accumulated expectation
   * counts and then starts a fresh emission count matrix for the next E-step.
   */
  private void maximumLikelihoodEstimator() throws NumericStabilityException {
    int columns = hmm.getNumberMatchStates();
    for (int k = 0; k < columns - 1; k++) {
      EmissionState matchK = hmm.getMatchState(k);
      EmissionState insertK = hmm.getInsertState(k);

      //transitions
      doMaximumLikelihoodEstimatesTransitions(matchK);
      doMaximumLikelihoodEstimatesTransitions(insertK);
      if (k < columns - 3) {
        DeleteState deleteK = hmm.getDeleteState(k);
        doMaximumLikelihoodEstimatesTransitions(deleteK);
      }

      //emissions
      doMaximumLikelihoodEstimatesEmissions(k);
      if (learnInsertEmissions) {
        doMaximumLikelihoodEstimatesEmissions(k + columns);
      }
    }
    // the last match state has emissions but no transitions handled above
    doMaximumLikelihoodEstimatesEmissions(columns - 1);
    System.gc();
    resetEmissionCounts();
  }

  /**
   * Allocates a new emission count matrix sized for the current HMM and fills it with log(0).
   * It has one row per match state plus, when insert emissions are learned, one row for each of the
   * first (number of match states - 1) insert states.
   */
  private void resetEmissionCounts() {
    int columns = this.hmm.getNumberMatchStates();
    int rows = learnInsertEmissions ? (2 * columns) - 1 : columns;
    this.emissionCountMatrix = new double[rows][(this.hmm.getAlphabet()).alphabetSize()];
    for (double[] row : emissionCountMatrix) {
      Arrays.fill(row, Double.NEGATIVE_INFINITY);
    }
  }

  /**
   * Replaces the emission distribution of the state belonging to the given count matrix row (match
   * state for rows below the number of match states, insert state otherwise) by the accumulated
   * counts, which are passed on via SimpleDistribution.setProbWithArray.
   */
  private void doMaximumLikelihoodEstimatesEmissions(int stateIndex) throws NumericStabilityException {
    double[] allEmissions = emissionCountMatrix[stateIndex];
    int columns = hmm.getNumberMatchStates();
    EmissionState state;
    if (stateIndex < columns) {
      state = hmm.getMatchState(stateIndex);
    } else {
      state = hmm.getInsertState(stateIndex - columns);
    }
    state.setDistribution(null);
    SimpleDistribution simpleDist = new SimpleDistribution(hmm.getAlphabet(), hmm.isUseLogSpace());
    simpleDist.setProbWithArray(allEmissions);
    state.setDistribution(simpleDist);
  }

  /**
   * Sets each outgoing transition probability of the given state to its (pseudo-count smoothed)
   * expectation count divided by the sum over all outgoing transitions, in log space, and resets
   * the expectation counts to log(0).
   * @param state the state whose outgoing transitions are re-estimated
   */
  protected void doMaximumLikelihoodEstimatesTransitions(State state) {
    double sum = Double.NEGATIVE_INFINITY;
    for (int l = 0; l < (state.getAllOutgoing()).size(); l++) {
      BaumWelchTransition actualTransition = (BaumWelchTransition) (state.getAllOutgoing()).get(l);
      sum = logplus(sum, actualTransition.getExpectationLogCount());
      sum = logplus(sum, transitionPseudoCount);
    }
    for (int l = 0; l < (state.getAllOutgoing()).size(); l++) {
      BaumWelchTransition actualTransition = (BaumWelchTransition) (state.getAllOutgoing()).get(l);
      double withPseudoCount = logplus(actualTransition.getExpectationLogCount(), transitionPseudoCount);
      setTransitionProbability(actualTransition, withPseudoCount - sum);
      actualTransition.setExpectationLogCount(Double.NEGATIVE_INFINITY);
    }
  }

  /**
   * Computes the summed forward score of all sequences under the current model and stores it in
   * logLikelihood, divided by the number of sequences or by the total number of residues if the
   * corresponding averaging option is enabled. Unless training is memory sensitive, the forward
   * results are cached in forwardAllSequences for the following E-step.
   */
  private void computeLoglikelihood() throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, NumericStabilityException {
    double tempLogLikelihood = 0;
    forwardAllSequences = new Vector<ForwardAlgorithm>();
    // long, so very large datasets cannot overflow the residue total
    long sumOfAllResidues = 0;
    for (String sequence : sequences) {
      sumOfAllResidues += sequence.length();
      ForwardAlgorithm fwd = new ForwardAlgorithm(hmm, sequence);
      fwd.calculateForward();
      if (!memorySensitive) {
        forwardAllSequences.add(fwd);
      }
      tempLogLikelihood += fwd.getScore();
    }
    if (averageLikelihoodOverSequenceNumber) {
      logLikelihood = tempLogLikelihood / sequences.length;
    } else if (averageLikelihoodOverResidueNumber) {
      logLikelihood = tempLogLikelihood / sumOfAllResidues;
    } else {
      logLikelihood = tempLogLikelihood;
    }
  }

  /**
   * Returns the natural-log emission probability of the symbol at sequencePosition in the given state.
   * If useNullModel is set, the value is taken relative to the null model emission probability, and
   * the end state contributes log(1) = 0. If logForTransitionsEmissions is set, the probabilities
   * returned by the state and the null model are used without taking a logarithm (presumably they
   * are already in log space — confirm).
   */
  private double calculateLogEmissionProbability(EmissionState actual, String sequence, int sequencePosition) throws IllegalSymbolException {
    double prob = actual.getEmissionProbability(""+sequence.charAt(sequencePosition));
    if (!logForTransitionsEmissions) {
      if (useNullModel) {
        if (actual.equals(net.getEndState())) {
          prob = 1.0;
        } else {
          prob = prob / ((net.getNullModel()).getNullModelEmissionProbability(""+sequence.charAt(sequencePosition)));
        }
      }
      return Math.log(prob);
    }
    if (useNullModel) {
      if (actual.equals(net.getEndState())) {
        prob = 0.0;
      } else {
        prob = prob - ((net.getNullModel()).getNullModelEmissionProbability(""+sequence.charAt(sequencePosition)));
      }
    }
    return prob;
  }

  /** @return true if emission probabilities of insert states are learned */
  public boolean isLearnInsertEmissions() {
    return learnInsertEmissions;
  }

  /** @return the PHMM being trained */
  public ProfileHMM getHmm() {
    return hmm;
  }

  /** @return true if the log-likelihood is averaged over the number of sequences */
  public boolean isAverageLikelihoodOverSequenceNumber() {
    return averageLikelihoodOverSequenceNumber;
  }

  /**
   * Enables or disables averaging the log-likelihood over the number of sequences. Averaging modes
   * are mutually exclusive: enabling this one disables averaging over residues, while disabling it
   * leaves the residue setting untouched.
   * @param averageLikelihoodOverSequenceNumber true to enable averaging over sequences
   */
  public void setAverageLikelihoodOverSequenceNumber(
      boolean averageLikelihoodOverSequenceNumber) {
    if (averageLikelihoodOverSequenceNumber) {
      averageLikelihoodOverResidueNumber = false;
    }
    this.averageLikelihoodOverSequenceNumber = averageLikelihoodOverSequenceNumber;
  }

  /** @return true if the log-likelihood is averaged over the total number of residues */
  public boolean isAverageLikelihoodOverResidueNumber() {
    return averageLikelihoodOverResidueNumber;
  }

  /**
   * Enables or disables averaging the log-likelihood over the total number of residues. Averaging
   * modes are mutually exclusive: enabling this one disables averaging over sequences, while
   * disabling it leaves the sequence setting untouched.
   * @param averageLikelihoodOverResidueNumber true to enable averaging over residues
   */
  public void setAverageLikelihoodOverResidueNumber(
      boolean averageLikelihoodOverResidueNumber) {
    if (averageLikelihoodOverResidueNumber) {
      averageLikelihoodOverSequenceNumber = false;
    }
    this.averageLikelihoodOverResidueNumber = averageLikelihoodOverResidueNumber;
  }

  /** @return true if forward matrices are recomputed instead of cached during training */
  public boolean isMemorySensitive() {
    return memorySensitive;
  }

  /** @param memorySensitive true to recompute forward matrices instead of caching them */
  public void setMemorySensitive(boolean memorySensitive) {
    this.memorySensitive = memorySensitive;
  }

  /** @return the number of EM iterations performed by the last call to learn(int) */
  public int getIteration() {
    return iteration;
  }

  /** @return the (optionally averaged) log-likelihood of the current model */
  public double getLogLikelihood() {
    return logLikelihood;
  }

}
