/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    ViterbiAlgorithm.java
 *    Copyright (C) 2010 Stefan Mutter
 *
 */
package weka.classifiers.sequence.core;

import java.util.List;
import java.util.Vector;

/**
<!-- globalinfo-start -->
* calculates the Viterbi Path as a list of states and its score
<!-- globalinfo-end -->
*
* Fills three dynamic-programming matrices (match, insert, delete) for one
* sequence against a profile HMM. Rows are indexed by state index
* 0..getNumberMatchStates() and columns by sequence position. Row
* getNumberMatchStates() of the match matrix holds the end state. Sequence
* position 0 is emitted by match state 0 and the last position by the end
* state. Back-pointers recorded during the recursion are then followed from the
* end state back to the begin state to obtain the Viterbi path and the
* corresponding alignment string.
*
* @author Stefan Mutter (pHMM4weka@gmail.com)
* @version $Revision: 6 $
*/
public class ViterbiAlgorithm extends DynamicProgAlgorithms{

  private static final long serialVersionUID = 5491459703945080087L;

  // Back-pointers for every (state index, sequence position) cell of the three
  // DP matrices. Each one records the predecessor cell that produced the best
  // score. The Viterbi path is the one ending in the last match state at the
  // end of the sequence; following the pointers re-traces it.
  private ViterbiBacktrackObject[][] matchBacktrackPointer;
  private ViterbiBacktrackObject[][] insertBacktrackPointer;
  private ViterbiBacktrackObject[][] deleteBacktrackPointer;

  // The most recently reconstructed Viterbi path, ordered from begin to end state.
  private List<State> viterbiPath;

  // When true, the path, alignment and exp(score) are echoed to standard output.
  private boolean verbose;


  /**
   * Creates a Viterbi calculator for one sequence against a profile HMM.
   * Nothing is computed until calculateViterbiPath() is called.
   *
   * @param net the profile HMM to align against
   * @param sequence the sequence to align
   */
  public ViterbiAlgorithm(ProfileHMM net, String sequence){
    super(net,sequence);
    this.viterbiPath = new Vector<State>();
    this.verbose = false;
  }

  /**
   * Calculates the Viterbi path for the sequence. May be called repeatedly;
   * each call recomputes the matrices and replaces the stored path.
   *
   * @return the aligned sequence (capital sequence symbol for match, lower case for insertion and - for delete)
   * @throws IllegalSymbolException if an emission probability cannot be computed for a sequence symbol
   * @throws InvalidStructureException if a required transition is missing from the model
   * @throws InvalidViterbiPathException if back-tracking does not end in the begin state
   */
  public String calculateViterbiPath() throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException{
    int rows = net.getNumberMatchStates()+1;
    int columns = sequence.length();

    matchBacktrackPointer = new ViterbiBacktrackObject[rows][columns];
    insertBacktrackPointer = new ViterbiBacktrackObject[rows][columns];
    deleteBacktrackPointer = new ViterbiBacktrackObject[rows][columns];

    matchMatrix = new double[rows][columns];
    insertMatrix = new double[rows][columns];
    deleteMatrix = new double[rows][columns];

    // fill the DP matrices
    initialiseMatrices();
    recurse();
    terminate();

    // re-trace the best path through the back-pointers
    return reconstructViterbiPath();
  }

  /**
   * Re-traces the Viterbi path, stores it (replacing any previous path) and
   * renders it as an alignment string.
   *
   * @return the viterbi path as an aligned sequence to the model.
   * @throws InvalidViterbiPathException if back-tracking does not end in the begin state
   */
  private String reconstructViterbiPath() throws InvalidViterbiPathException {
    // Build into a fresh list so repeated calls never mix old and new paths;
    // on failure the previously stored path is left untouched.
    List<State> path = traceBackPath();
    viterbiPath = path;

    if(verbose){
      for (State state : path) {
        System.out.print(state.getFullNameId()+" ");
      }
      System.out.print("\t");
    }

    String alignment = buildAlignment(path);

    if(verbose){
      System.out.print(alignment);
      System.out.println("\t"+Math.exp(score));
    }

    return alignment;
  }

  /**
   * Follows the back-pointers from the end state to the begin state.
   *
   * @return the Viterbi path ordered from begin state to end state
   * @throws InvalidViterbiPathException if back-tracking does not end in the begin state
   */
  private List<State> traceBackPath() throws InvalidViterbiPathException {
    int lastMatch = net.getNumberMatchStates()-1;
    List<State> reversePath = new Vector<State>();
    reversePath.add(net.getEndState());
    reversePath.add(net.getMatchState(lastMatch));

    // The end state emits the last position, so the last match state sits at
    // position length-2 and its predecessor at position length-3.
    int sequencePointer = sequence.length()-2;
    ViterbiBacktrackObject actualVitPos = matchBacktrackPointer[lastMatch][sequencePointer];
    sequencePointer--;

    // Invariant at the top of each iteration: a predecessor of type "M" or "I"
    // has state index i-1 and one of type "D" has index i-2 (this mirrors the
    // transitions used in updateMatchMatrix/updateInsertMatrix/updateDeleteMatrix).
    // sequencePointer is the position of the predecessor cell.
    for(int i = lastMatch; i > 0; i--){
      String type = actualVitPos.getType();
      if(type.equals("D")){
        // delete states are silent: their predecessor is at the same position
        actualVitPos = deleteBacktrackPointer[i-2][sequencePointer];
        reversePath.add(net.getDeleteState(i-2));
      }
      else{
        if(type.equals("M")){
          actualVitPos = matchBacktrackPointer[i-1][sequencePointer];
          reversePath.add(net.getMatchState(i-1));
        }
        else{
          // type "I": an insert state can loop on itself, so stay on the same
          // model column by cancelling the loop decrement
          actualVitPos = insertBacktrackPointer[i-1][sequencePointer];
          reversePath.add(net.getInsertState(i-1));
          i++;
        }
        // match and insert states emit, so their predecessor is one position earlier
        sequencePointer--;
      }
    }

    // Cells that are unreachable carry a back-pointer with a null state;
    // report them through the declared exception instead of an NPE.
    if(actualVitPos.getState() == null || !actualVitPos.getState().equals(net.getBeginState())){
      String reached = (actualVitPos.getState() == null) ? "null" : actualVitPos.getState().getFullNameId();
      throw new InvalidViterbiPathException("problem because of:"+reached+" "+sequencePointer+". We should be in begin state.");
    }
    reversePath.add(net.getBeginState());

    List<State> path = new Vector<State>(reversePath.size());
    for (int k = reversePath.size()-1; k >= 0; k--) {
      path.add(reversePath.get(k));
    }
    return path;
  }

  /**
   * Renders a path as an alignment string. The first (begin) and last (end)
   * states are skipped. Match states emit the upper-cased symbol, insert states
   * the lower-cased symbol, and delete states '-'.
   *
   * @param path the Viterbi path from begin to end state
   * @return the alignment string
   */
  private String buildAlignment(List<State> path) {
    StringBuilder alignment = new StringBuilder();
    int sequencePointer = 0;
    for (int i = 1; i < path.size()-1; i++) {
      String type = path.get(i).getName();
      if(type.equals("M")){
        // Character case mapping is locale-independent (String.toUpperCase()
        // would turn 'i' into a dotted capital I under a Turkish default locale)
        alignment.append(Character.toUpperCase(sequence.charAt(sequencePointer)));
        sequencePointer++;
      }
      else if(type.equals("D")){
        alignment.append('-');
      }
      else if(type.equals("I")){
        alignment.append(Character.toLowerCase(sequence.charAt(sequencePointer)));
        sequencePointer++;
      }
    }
    return alignment.toString();
  }

  /**
   * Fills the end-state cell (last row of the match matrix, last position).
   * It is reachable only from the last match state at the previous position.
   * The result is stored in score.
   */
  private void terminate() throws IllegalSymbolException, InvalidStructureException {
    // The other cells of the end-state row keep the -infinity written by
    // initialiseMatrices(); recurse() never touches that row.
    int endRow = net.getNumberMatchStates();
    int lastPosition = sequence.length()-1;

    EmissionState end = net.getEndState();
    EmissionState matchjMinus1 = net.getMatchState(endRow-1);

    double emissionProb = calculateLogEmissionProbability(end,lastPosition);
    double matchjMinus1ToEndTrans = getIncomingTransitionProbFrom(end, matchjMinus1);
    double matchScore = matchjMinus1ToEndTrans + matchMatrix[endRow-1][lastPosition-1];

    matchMatrix[endRow][lastPosition] = emissionProb + matchScore;
    matchBacktrackPointer[endRow][lastPosition] = new ViterbiBacktrackObject(net.getMatchState(endRow-1),lastPosition-1);
    score = matchMatrix[endRow][lastPosition];
  }

  /**
   * Fills the interior of the matrices column by column (model state j),
   * then position by position (i). The bounds keep every insert and delete
   * state able to reach a later match state before the end state.
   */
  private void recurse() throws IllegalSymbolException, InvalidStructureException {
    for (int j = 1; j < net.getNumberMatchStates(); j++) {
      for (int i = 1; i < sequence.length()-1; i++) {
        updateMatchMatrix(j, i);
        if(j < net.getNumberMatchStates()-1 && i < sequence.length()-2){
          updateInsertMatrix(j, i);
        }
        if(j < net.getNumberMatchStates()-2 && i < sequence.length()-2){
          updateDeleteMatrix(j, i);
        }
      }
    }
  }

  /**
   * Delete state j at position i is reached from match state j or delete
   * state j-1 at the same position. It is silent, so there is no emission term.
   */
  private void updateDeleteMatrix(int j, int i) throws InvalidStructureException {
    DeleteState deletej = net.getDeleteState(j);
    DeleteState deletejMinus1 = net.getDeleteState(j-1);
    EmissionState matchj = net.getMatchState(j);
    double deletejMinus1ToDeletejTrans = getIncomingTransitionProbFrom(deletej, deletejMinus1);
    double matchjToDeletejTrans = getIncomingTransitionProbFrom(deletej, matchj);
    double matchScore = matchjToDeletejTrans + matchMatrix[j][i];
    double deleteScore = deletejMinus1ToDeletejTrans + deleteMatrix[j-1][i];
    if(matchScore > deleteScore){
      deleteMatrix[j][i] = matchScore;
      deleteBacktrackPointer[j][i] = new ViterbiBacktrackObject(net.getMatchState(j),i);
    }
    else{
      deleteMatrix[j][i] = deleteScore;
      deleteBacktrackPointer[j][i] = new ViterbiBacktrackObject(net.getDeleteState(j-1),i);
    }
  }

  /**
   * Insert state j at position i is reached from match state j or from
   * itself at position i-1, and emits symbol i.
   */
  private void updateInsertMatrix(int j, int i) throws IllegalSymbolException, InvalidStructureException {
    EmissionState insertj = net.getInsertState(j);
    EmissionState matchj = net.getMatchState(j);

    double emissionProb = calculateLogEmissionProbability(insertj,i);
    double insertjToInsertjTrans = getIncomingTransitionProbFrom(insertj, insertj);
    double matchjToInsertjTrans = getIncomingTransitionProbFrom(insertj, matchj);
    double matchScore = matchjToInsertjTrans + matchMatrix[j][i-1];
    double insertScore = insertjToInsertjTrans + insertMatrix[j][i-1];
    if(matchScore > insertScore){
      insertMatrix[j][i] = emissionProb + matchScore;
      insertBacktrackPointer[j][i] = new ViterbiBacktrackObject(net.getMatchState(j),i-1);
    }
    else{
      insertMatrix[j][i] = emissionProb + insertScore;
      insertBacktrackPointer[j][i] = new ViterbiBacktrackObject(net.getInsertState(j),i-1);
    }
  }

  /**
   * Match state j at position i is reached at position i-1 from match state
   * j-1, from insert state j-1, or (for j > 1) from delete state j-2. It emits
   * symbol i.
   */
  private void updateMatchMatrix(int j, int i) throws IllegalSymbolException, InvalidStructureException {
    EmissionState matchj = net.getMatchState(j);
    EmissionState matchjMinus1 = net.getMatchState(j-1);
    EmissionState insertjMinus1 = net.getInsertState(j-1);

    double emissionProb = calculateLogEmissionProbability(matchj,i);
    double matchjMinus1ToMatchjTrans = getIncomingTransitionProbFrom(matchj, matchjMinus1);
    double insertjMinus1ToMatchjTrans = getIncomingTransitionProbFrom(matchj, insertjMinus1);
    double matchScore = matchjMinus1ToMatchjTrans + matchMatrix[j-1][i-1];
    double insertScore = insertjMinus1ToMatchjTrans + insertMatrix[j-1][i-1];
    double deleteScore = Double.NEGATIVE_INFINITY;
    if(j>1){
      DeleteState deletejMinus2 = net.getDeleteState(j-2);
      double deletejMinus2ToMatchjTrans = getIncomingTransitionProbFrom(matchj, deletejMinus2);
      deleteScore = deletejMinus2ToMatchjTrans + deleteMatrix[j-2][i-1];
    }

    // For j == 1 there is no delete predecessor (delete index would be -1),
    // so the delete branch must never be chosen, even on -infinity ties.
    if(matchScore > insertScore){
      if((matchScore > deleteScore) || j == 1){
        matchMatrix[j][i] = emissionProb + matchScore;
        matchBacktrackPointer[j][i] = new ViterbiBacktrackObject(net.getMatchState(j-1),i-1);
      }
      else{
        matchMatrix[j][i] = emissionProb + deleteScore;
        matchBacktrackPointer[j][i] = new ViterbiBacktrackObject(net.getDeleteState(j-2),i-1);
      }
    }
    else{
      if((insertScore > deleteScore) || j == 1){
        matchMatrix[j][i] = emissionProb + insertScore;
        matchBacktrackPointer[j][i] = new ViterbiBacktrackObject(net.getInsertState(j-1),i-1);
      }
      else{
        matchMatrix[j][i] = emissionProb + deleteScore;
        matchBacktrackPointer[j][i] = new ViterbiBacktrackObject(net.getDeleteState(j-2),i-1);
      }
    }
  }

  /**
   * Sets every cell to -infinity, then seeds the boundary rows and columns.
   * Match state 0 emits position 0 directly after the begin state. Insert
   * state 0 and delete state 0 are fed only from match state 0. All other
   * boundary cells get a back-pointer with a null state.
   */
  private void initialiseMatrices() throws IllegalSymbolException, InvalidStructureException {
    for(int i = 0; i < net.getNumberMatchStates()+1; i++){
      for(int j = 0; j < sequence.length(); j++){
        matchMatrix[i][j]= Double.NEGATIVE_INFINITY;
        deleteMatrix[i][j]= Double.NEGATIVE_INFINITY;
        insertMatrix[i][j]= Double.NEGATIVE_INFINITY;
      }
    }

    // match row 0 / column 0
    EmissionState match0 = net.getMatchState(0);
    double emissionProb = calculateLogEmissionProbability(match0,0);
    matchMatrix[0][0] = emissionProb;
    matchBacktrackPointer[0][0] = new ViterbiBacktrackObject(net.getBeginState(),0);
    for (int i = 1; i < sequence.length(); i++) {
      matchMatrix[0][i] = Double.NEGATIVE_INFINITY;
      matchBacktrackPointer[0][i] = new ViterbiBacktrackObject(null,0);
    }
    for (int j = 1; j < net.getNumberMatchStates(); j++) {
      matchMatrix[j][0] = Double.NEGATIVE_INFINITY;
      matchBacktrackPointer[j][0] = new ViterbiBacktrackObject(null,0);
    }

    // insert row 0: entered from match 0 at position 1, then loops on itself
    insertMatrix[0][0] = Double.NEGATIVE_INFINITY;
    insertBacktrackPointer[0][0] = new ViterbiBacktrackObject(null,0);
    EmissionState insert0 = net.getInsertState(0);
    double matchToInsertTrans = getIncomingTransitionProbFrom(insert0, match0);
    double insertToInsertTrans = getIncomingTransitionProbFrom(insert0, insert0);
    for (int i = 1; i < sequence.length()-1; i++) {
      emissionProb = calculateLogEmissionProbability(insert0,i);
      double fromMatch = matchToInsertTrans + matchMatrix[0][i-1];
      double fromInsert = insertToInsertTrans + insertMatrix[0][i-1];

      if(i == 1){
        insertMatrix[0][i] = emissionProb+fromMatch;
        insertBacktrackPointer[0][i] = new ViterbiBacktrackObject(match0,i-1);
      }
      else{
        insertMatrix[0][i] = emissionProb+fromInsert;
        insertBacktrackPointer[0][i] = new ViterbiBacktrackObject(insert0,i-1);
      }
    }
    insertMatrix[0][sequence.length()-1] = Double.NEGATIVE_INFINITY;
    insertBacktrackPointer[0][sequence.length()-1] = new ViterbiBacktrackObject(null,0);
    for(int j = 1; j < net.getNumberMatchStates(); j++){
      insertMatrix[j][0] = Double.NEGATIVE_INFINITY;
      insertBacktrackPointer[j][0] = new ViterbiBacktrackObject(null,0);
    }

    // delete row 0: reachable only from match 0 at position 0
    DeleteState delete0 = net.getDeleteState(0);
    double matchToDeleteTrans = getIncomingTransitionProbFrom(delete0, match0);
    deleteMatrix[0][0] = matchMatrix[0][0]+matchToDeleteTrans;
    deleteBacktrackPointer[0][0] = new ViterbiBacktrackObject(match0,0);

    for (int i = 1; i < sequence.length(); i++) {
      deleteMatrix[0][i] = Double.NEGATIVE_INFINITY;
      deleteBacktrackPointer[0][i] = new ViterbiBacktrackObject(null,0);
    }
    for (int j = 1; j < net.getNumberMatchStates(); j++) {
      deleteMatrix[j][0] = Double.NEGATIVE_INFINITY;
      deleteBacktrackPointer[j][0] = new ViterbiBacktrackObject(null,0);
    }
  }

  /**
   * Returns the score of the Viterbi path. Needs calculateViterbiPath() first.
   * The score is in log space: it is a sum of log emission terms, and the
   * verbose output prints Math.exp(score).
   *
   * @return the score of the Viterbi path
   * @see weka.classifiers.sequence.core.DynamicProgAlgorithms#getScore()
   */
  public double getScore(){
    return score;
  }

  /**
   * @return whether the path and alignment are echoed to standard output
   */
  public boolean isVerbose() {
    return verbose;
  }

  /**
   * @param verbose true to echo the path, alignment and exp(score) to standard output
   */
  public void setVerbose(boolean verbose) {
    this.verbose = verbose;
  }

  /**
   * Returns the Viterbi path as a list of states, from begin state to end
   * state. Needs calculateViterbiPath() first; before that the list is empty.
   *
   * @return the Viterbi path as list of states
   */
  public List<State> getViterbiPath() {
    return viterbiPath;
  }

}
