/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    ForwardAlgorithm.java
 *    Copyright (C) 2010 Stefan Mutter
 *
 */
package weka.classifiers.sequence.core;


/**
<!-- globalinfo-start -->
* Calculates the forward score of a sequence under a profile HMM (PHMM).
<!-- globalinfo-end -->
*
* @author Stefan Mutter (pHMM4weka@gmail.com)
* @version $Revision: 6 $
*/
public class ForwardAlgorithm extends DynamicProgAlgorithms {

  private static final long serialVersionUID = -8804666261111359271L;

  public ForwardAlgorithm(ProfileHMM net, String sequence){
    super(net,sequence);
  }

  /**
 * calculates the forward matrices for match, insert and delete states. The forward score can be retrieved by calling getScore() from the parent class after this methods execution.
 * @throws IllegalSymbolException
 * @throws InvalidStructureException
 * @throws InvalidViterbiPathException
 */
public void calculateForward() throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException{
//  create Matrices one for match, delete and insert
    matchMatrix = new double[net.getNumberMatchStates()+1][sequence.length()];
    insertMatrix = new double[net.getNumberMatchStates()+1][sequence.length()];
    deleteMatrix = new double[net.getNumberMatchStates()+1][sequence.length()];

    //fill Matrices
    initialiseMatrices();
    recurse();
    terminate();
  }


  private void terminate() throws IllegalSymbolException, InvalidStructureException {

    EmissionState end = net.getEndState();
    EmissionState matchjMinus1 = net.getMatchState(net.getNumberMatchStates()-1);

    double emissionProb = calculateLogEmissionProbability(end,sequence.length()-1);
    double matchjMinus1ToEndTrans = getIncomingTransitionProbFrom(end, matchjMinus1);
    double matchScore = matchjMinus1ToEndTrans  + matchMatrix[net.getNumberMatchStates()-1][sequence.length()-2];

    matchMatrix[net.getNumberMatchStates()][sequence.length()-1] = emissionProb + matchScore;
    score = matchMatrix[net.getNumberMatchStates()][sequence.length()-1];

  }

  private void recurse() throws IllegalSymbolException, InvalidStructureException {
    for (int j = 1; j < net.getNumberMatchStates(); j++) {
      for (int i = 0; i < sequence.length()-1; i++) {
	if(i > 0){
	  updateMatchMatrix(j, i);
	}
	if(j < net.getNumberMatchStates()-1 && i < sequence.length()-2 && i > 0){
	  updateInsertMatrix(j, i);
	}
	if(j < net.getNumberMatchStates()-2 && i < sequence.length()-2){
	  updateDeleteMatrix(j, i);
	}
      }
    }
  }

  private void updateDeleteMatrix(int j, int i) throws InvalidStructureException {
    DeleteState deletej = net.getDeleteState(j);
    DeleteState deletejMinus1 = net.getDeleteState(j-1);
    EmissionState matchj = net.getMatchState(j);

    double deletejMinus1ToDeletejTrans = getIncomingTransitionProbFrom(deletej, deletejMinus1);
    double matchjToDeletejTrans = getIncomingTransitionProbFrom(deletej,matchj);
    double matchScore = matchjToDeletejTrans + matchMatrix[j][i];
    double deleteScore = deletejMinus1ToDeletejTrans + deleteMatrix[j-1][i];

    deleteMatrix[j][i] = logplus(matchScore, deleteScore);
  }

  private void updateInsertMatrix(int j, int i) throws IllegalSymbolException, InvalidStructureException {
    EmissionState insertj = net.getInsertState(j);
    EmissionState matchj = net.getMatchState(j);

    double emissionProb = calculateLogEmissionProbability(insertj,i);
    double insertjToInsertjTrans = getIncomingTransitionProbFrom(insertj, insertj);
    double matchjToInsertjTrans = getIncomingTransitionProbFrom(insertj, matchj);
    double matchScore = matchjToInsertjTrans + matchMatrix[j][i-1];
    double insertScore = insertjToInsertjTrans + insertMatrix[j][i-1];

    insertMatrix[j][i] = emissionProb + logplus(matchScore, insertScore);
  }

  private void updateMatchMatrix(int j, int i) throws IllegalSymbolException, InvalidStructureException {
    EmissionState matchj = net.getMatchState(j);
    EmissionState matchjMinus1 = net.getMatchState(j-1);
    EmissionState insertjMinus1 = net.getInsertState(j-1);

    double emissionProb = calculateLogEmissionProbability(matchj,i);
    double matchjMinus1ToMatchjTrans = getIncomingTransitionProbFrom(matchj, matchjMinus1);
    double insertjMinus1ToMatchjTrans = getIncomingTransitionProbFrom(matchj, insertjMinus1);
    double matchScore = matchjMinus1ToMatchjTrans + matchMatrix[j-1][i-1];
    double insertScore = insertjMinus1ToMatchjTrans + insertMatrix[j-1][i-1];
    double deleteScore = Double.NEGATIVE_INFINITY;;
    if(j>1){
      DeleteState deletejMinus2 = net.getDeleteState(j-2);
      double deletejMinus2ToMatchjTrans = getIncomingTransitionProbFrom(matchj, deletejMinus2);
      deleteScore = deletejMinus2ToMatchjTrans + deleteMatrix[j-2][i-1];

      double logApproxmitionHelper = Double.NEGATIVE_INFINITY;
      logApproxmitionHelper = logplus(deleteScore,matchScore);
      logApproxmitionHelper = logplus(logApproxmitionHelper,insertScore);
      matchMatrix[j][i] = emissionProb + logApproxmitionHelper;
    }
    else{
      matchMatrix[j][i] = emissionProb + logplus(matchScore, insertScore);
    }
  }

  private void initialiseMatrices() throws IllegalSymbolException, InvalidStructureException {
    for(int i = 0; i < net.getNumberMatchStates()+1; i++){
      for(int j = 0; j < sequence.length(); j++){
	matchMatrix[i][j]= Double.NEGATIVE_INFINITY;
	deleteMatrix[i][j]= Double.NEGATIVE_INFINITY;
	insertMatrix[i][j]= Double.NEGATIVE_INFINITY;
      }
    }

    //match
    EmissionState match0 = net.getMatchState(0);
    double emissionProb = calculateLogEmissionProbability(match0,0);
    matchMatrix[0][0] = emissionProb;

    //insert
    EmissionState insert0 = net.getInsertState(0);
    double matchToInsertTrans = getIncomingTransitionProbFrom(insert0, match0);
    double insertToInsertTrans = getIncomingTransitionProbFrom(insert0, insert0);
    for (int i = 1; i < sequence.length()-1; i++) {
      emissionProb = calculateLogEmissionProbability(insert0,i);
      double fromMatch = matchToInsertTrans + matchMatrix[0][i-1];
      double fromInsert = insertToInsertTrans + insertMatrix[0][i-1];

      insertMatrix[0][i] = emissionProb + logplus(fromMatch,fromInsert);
    }


    //delete
    DeleteState delete0 = net.getDeleteState(0);
    double matchToDeleteTrans = getIncomingTransitionProbFrom(delete0, match0);
    deleteMatrix[0][0] = matchToDeleteTrans + matchMatrix[0][0];

  }
}
