/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    BackwardAlgorithm.java
 *    Copyright (C) 2010 Stefan Mutter
 *
 */
package weka.classifiers.sequence.core;


/**
<!-- globalinfo-start -->
* Class implementing the backward algorithm for profile hidden Markov models (PHMMs).
<!-- globalinfo-end -->
*
* @author Stefan Mutter (pHMM4weka@gmail.com)
* @version $Revision: 6 $
*/
public class BackwardAlgorithm extends DynamicProgAlgorithms {

  private static final long serialVersionUID = 7601518046923307693L;

  /**
   * Creates a backward-algorithm evaluator for one sequence against one PHMM.
   * This constructor only forwards to the superclass; the backward matrices
   * are built by {@link #calculateBackward()}.
   *
   * @param net the PHMM under consideration
   * @param sequence the sequence to be evaluated
   */
  public BackwardAlgorithm(ProfileHMM net, String sequence) {
    super(net, sequence);
  }

  /**
   * Runs the backward algorithm.
   *
   * <p>Allocates one matrix each for match, insert and delete states, with
   * (numberMatchStates + 1) rows (state index) and sequence.length() columns
   * (sequence position). It then seeds them, fills them from the end of the
   * model/sequence towards the start, and stores the result in {@code score}.
   * Cell values are combined with {@code logplus} and log emission
   * probabilities, i.e. they are kept in log space.
   *
   * @throws IllegalSymbolException propagated from the PHMM / emission helpers
   * @throws InvalidStructureException propagated from the PHMM / transition helpers
   * @throws InvalidViterbiPathException kept in the signature for compatibility;
   *         none of the methods this one calls declares it
   */
  public void calculateBackward() throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException {
    final int rows = net.getNumberMatchStates() + 1;
    final int cols = sequence.length();

    matchMatrix = new double[rows][cols];
    insertMatrix = new double[rows][cols];
    deleteMatrix = new double[rows][cols];

    initialiseMatrices();
    recurse();
    terminate();
  }

  /**
   * Computes the final score: the backward value of match state 0 at
   * position 0 plus that state's log emission probability for position 0.
   * The result (probability or log likelihood of the sequence given the PHMM)
   * is written to {@code score}.
   *
   * @throws IllegalSymbolException propagated from the emission helper
   * @throws InvalidStructureException propagated from the PHMM lookup
   */
  private void terminate() throws IllegalSymbolException, InvalidStructureException {
    // Evaluate the emission before reading the matrix cell, matching the
    // original evaluation order.
    final double firstEmission = calculateLogEmissionProbability(net.getMatchState(0), 0);
    score = matchMatrix[0][0] + firstEmission;
  }

  /**
   * Fills the backward matrices.
   *
   * <p>The outer loop walks state index from numberMatchStates-2 down to 0.
   * The inner loop walks sequence position from length-3 down to 0. Cells
   * outside these ranges keep the values set by {@link #initialiseMatrices()}.
   *
   * @throws IllegalSymbolException propagated from the update helpers
   * @throws InvalidStructureException propagated from the update helpers
   */
  private void recurse() throws IllegalSymbolException, InvalidStructureException {
    final int states = net.getNumberMatchStates();

    int state = states - 2;
    while (state >= 0) {
      // Delete rows are only recomputed up to numberMatchStates-4; row
      // numberMatchStates-3 keeps its seeded value and rows above it stay at
      // NEGATIVE_INFINITY.
      final boolean refreshDelete = state <= states - 4;

      for (int pos = sequence.length() - 3; pos >= 0; pos--) {
        // Delete must come before match: updateMatchMatrix reads
        // deleteMatrix[state][pos], which is (re)computed just here.
        if (refreshDelete) {
          updateDeleteMatrix(state, pos);
        }
        updateInsertMatrix(state, pos);
        updateMatchMatrix(state, pos);
      }
      state--;
    }
  }

  /**
   * Computes deleteMatrix[state][pos] as logplus of two terms.
   * <ul>
   * <li>via match state+2: transition(D_state, M_state+2) + matchMatrix[state+2][pos+1]
   *     + logEmission(M_state+2, pos+1)</li>
   * <li>via delete state+1: transition(D_state, D_state+1) + deleteMatrix[state+1][pos]</li>
   * </ul>
   * The delete state emits nothing, so the second term stays at the same position.
   */
  private void updateDeleteMatrix(int state, int pos) throws InvalidStructureException, IllegalSymbolException {
    final DeleteState source = net.getDeleteState(state);
    final DeleteState nextDelete = net.getDeleteState(state + 1);
    final EmissionState skipTarget = net.getMatchState(state + 2);

    final double toSkipTarget = getIncomingTransitionProbTo(source, skipTarget);
    final double toNextDelete = getIncomingTransitionProbTo(source, nextDelete);
    final double skipTargetEmission = calculateLogEmissionProbability(skipTarget, pos + 1);

    final double viaMatch = toSkipTarget + matchMatrix[state + 2][pos + 1] + skipTargetEmission;
    final double viaDelete = toNextDelete + deleteMatrix[state + 1][pos];

    deleteMatrix[state][pos] = logplus(viaMatch, viaDelete);
  }

  /**
   * Computes insertMatrix[state][pos] as logplus of two terms.
   * <ul>
   * <li>via match state+1: transition(I_state, M_state+1) + matchMatrix[state+1][pos+1]
   *     + logEmission(M_state+1, pos+1)</li>
   * <li>via the self loop: transition(I_state, I_state) + insertMatrix[state][pos+1]
   *     + logEmission(I_state, pos+1)</li>
   * </ul>
   */
  private void updateInsertMatrix(int state, int pos) throws InvalidStructureException, IllegalSymbolException {
    final EmissionState insert = net.getInsertState(state);
    final EmissionState nextMatch = net.getMatchState(state + 1);

    final double toNextMatch = getIncomingTransitionProbTo(insert, nextMatch);
    final double selfLoop = getIncomingTransitionProbTo(insert, insert);
    final double nextMatchEmission = calculateLogEmissionProbability(nextMatch, pos + 1);
    final double insertEmission = calculateLogEmissionProbability(insert, pos + 1);

    final double viaMatch = toNextMatch + matchMatrix[state + 1][pos + 1] + nextMatchEmission;
    final double viaSelfLoop = selfLoop + insertMatrix[state][pos + 1] + insertEmission;

    insertMatrix[state][pos] = logplus(viaMatch, viaSelfLoop);
  }

  /**
   * Computes matchMatrix[state][pos] by combining the following terms.
   * <ul>
   * <li>via match state+1: transition(M_state, M_state+1) + matchMatrix[state+1][pos+1]
   *     + logEmission(M_state+1, pos+1)</li>
   * <li>via insert state: transition(M_state, I_state) + insertMatrix[state][pos+1]
   *     + logEmission(I_state, pos+1)</li>
   * <li>only when state &lt;= numberMatchStates-3, via delete state:
   *     transition(M_state, D_state) + deleteMatrix[state][pos]</li>
   * </ul>
   * The two-term case uses a single logplus. The three-term case accumulates
   * with logplus starting from NEGATIVE_INFINITY.
   */
  private void updateMatchMatrix(int state, int pos) throws InvalidStructureException, IllegalSymbolException {
    final EmissionState current = net.getMatchState(state);
    final EmissionState nextMatch = net.getMatchState(state + 1);
    final EmissionState sameInsert = net.getInsertState(state);

    final double toNextMatch = getIncomingTransitionProbTo(current, nextMatch);
    final double toInsert = getIncomingTransitionProbTo(current, sameInsert);
    final double nextMatchEmission = calculateLogEmissionProbability(nextMatch, pos + 1);
    final double insertEmission = calculateLogEmissionProbability(sameInsert, pos + 1);

    final double viaMatch = toNextMatch + matchMatrix[state + 1][pos + 1] + nextMatchEmission;
    final double viaInsert = toInsert + insertMatrix[state][pos + 1] + insertEmission;

    if (state > net.getNumberMatchStates() - 3) {
      // No delete transition is considered for these rows.
      matchMatrix[state][pos] = logplus(viaMatch, viaInsert);
      return;
    }

    final DeleteState sameDelete = net.getDeleteState(state);
    final double toDelete = getIncomingTransitionProbTo(current, sameDelete);
    final double viaDelete = toDelete + deleteMatrix[state][pos];

    double total = Double.NEGATIVE_INFINITY;
    total = logplus(total, viaMatch);
    total = logplus(total, viaInsert);
    total = logplus(total, viaDelete);
    matchMatrix[state][pos] = total;
  }

  /**
   * Seeds the backward matrices.
   *
   * <p>Every cell of the match, delete and insert matrices is first set to
   * NEGATIVE_INFINITY. Then two match cells are set to 0.0:
   * [numberMatchStates][length-1] and [numberMatchStates-1][length-2].
   * Finally, deleteMatrix[numberMatchStates-3][length-3] is seeded with the
   * path from delete state numberMatchStates-3 into match state
   * numberMatchStates-1 at position length-2.
   *
   * @throws IllegalSymbolException propagated from the emission helper
   * @throws InvalidStructureException propagated from the PHMM / transition helpers
   */
  private void initialiseMatrices() throws IllegalSymbolException, InvalidStructureException {
    fillWithNegativeInfinity(matchMatrix);
    fillWithNegativeInfinity(deleteMatrix);
    fillWithNegativeInfinity(insertMatrix);

    final int states = net.getNumberMatchStates();
    final int length = sequence.length();

    // match
    matchMatrix[states][length - 1] = 0.0;
    matchMatrix[states - 1][length - 2] = 0.0;

    // delete
    final EmissionState finalMatch = net.getMatchState(states - 1);
    final DeleteState seededDelete = net.getDeleteState(states - 3);
    final double deleteToFinalMatch = getIncomingTransitionProbTo(seededDelete, finalMatch);
    final double finalMatchEmission = calculateLogEmissionProbability(finalMatch, length - 2);
    // Read the cell back rather than using 0.0 directly, so the arithmetic is
    // identical (including the sign of a zero result).
    final double finalMatchScore = matchMatrix[states - 1][length - 2];

    deleteMatrix[states - 3][length - 3] = deleteToFinalMatch + finalMatchEmission + finalMatchScore;
  }

  /**
   * Sets every cell of a rectangular matrix to NEGATIVE_INFINITY.
   */
  private static void fillWithNegativeInfinity(double[][] matrix) {
    for (double[] row : matrix) {
      for (int col = 0; col < row.length; col++) {
        row[col] = Double.NEGATIVE_INFINITY;
      }
    }
  }

}
