/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    ProbabilityPerStateCalculator.java
 *    Copyright (C) 2010 Stefan Mutter
 *
 */
package weka.classifiers.sequence.core;

import java.util.Vector;

/**
<!-- globalinfo-start -->
* Calculates the probability of each state in the PHMM given a sequence using the forward and backward algorithm.
<!-- globalinfo-end -->
*
* <p>The forward and backward algorithms are run once, in the constructor.
* {@link #getScores()} combines the two resulting matrices into one value per
* state and may be called repeatedly; every call returns a freshly built vector.</p>
*
* @author Stefan Mutter (pHMM4weka@gmail.com)
* @version $Revision: 4 $
*/
public class ProbabilityPerStateCalculator extends ProfileHMMAlgorithms {

  private static final long serialVersionUID = -2163580005173691767L;

  /** Upper bound tolerated for the log score of the first and last match state before the sanity check fails. */
  private static final double LOG_SCORE_TOLERANCE = 0.00000001;

  /** Selects which pair of forward/backward matrices a state index refers to. */
  private enum StateType { MATCH, INSERT, DELETE }

  /** Number of match states as reported by the PHMM passed to the constructor. */
  private final int numberOfMatchStates;

  /** Whether the PHMM works in log space; if not, results are exponentiated before being returned. */
  private final boolean isLogSpace;

  /** The sequence the per-state scores are calculated for. */
  private final String testSequence;

  /** Forward algorithm, already evaluated for {@link #testSequence}. */
  private final ForwardAlgorithm fwd;

  /** Backward algorithm, already evaluated for {@link #testSequence}. */
  private final BackwardAlgorithm bwd;

  /** Score of the whole sequence as returned by the forward algorithm; used to normalise each state's sum. */
  private final double scoreOfTestSequence;

  /** The vector produced by the most recent call to {@link #getScores()} (empty before the first call). */
  private Vector<Double> scores4AllStates;


  /**
   * Constructor. Runs the forward and the backward algorithm for the given
   * sequence and stores the forward score for later normalisation.
   *
   * @param net the PHMM
   * @param sequence the sequence for which the state probabilities are calculated
   * @throws IllegalSymbolException
   * @throws InvalidStructureException
   * @throws InvalidViterbiPathException
   */
  public ProbabilityPerStateCalculator(ProfileHMM net, String sequence) throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException {
    super(net);
    this.numberOfMatchStates = net.getNumberMatchStates();
    this.isLogSpace = net.isUseLogSpace();
    this.testSequence = sequence;
    this.fwd = new ForwardAlgorithm(net, testSequence);
    this.bwd = new BackwardAlgorithm(net, testSequence);
    fwd.calculateForward();
    bwd.calculateBackward();
    this.scoreOfTestSequence = fwd.getScore();
    this.scores4AllStates = new Vector<Double>();
  }

  /**
   * Returns a vector containing the probabilities, or their log space
   * equivalent, of each state given the sequence from the constructor.
   *
   * <p>The values are ordered as follows: match states with index
   * {@code 1 .. numberOfMatchStates-2}, then insert states with index
   * {@code 0 .. numberOfMatchStates-2}, then delete states with index
   * {@code 0 .. numberOfMatchStates-3}. The first and last match state are
   * not part of the result; they are only sanity-checked.</p>
   *
   * <p>A new vector is built on every call, so calling this method more than
   * once does not accumulate duplicate entries.</p>
   *
   * @return a probability for each state or its log space equivalent
   * @throws ImpossibleStateProbabilityException if the sanity check on the first or last match state fails
   */
  public Vector<Double> getScores() throws ImpossibleStateProbabilityException {
    testScores();

    // Build into a fresh vector: appending to the field would duplicate all
    // entries whenever this method is called more than once.
    Vector<Double> scores = new Vector<Double>();

    // match (probability of first and last match state is always 1 in PHMM (0 in logSpace))
    addScores(scores, StateType.MATCH, 1, numberOfMatchStates - 1);
    // insert
    addScores(scores, StateType.INSERT, 0, numberOfMatchStates - 1);
    // delete
    addScores(scores, StateType.DELETE, 0, numberOfMatchStates - 2);

    scores4AllStates = scores;
    return scores;
  }

  /**
   * Appends the score of every state of the given type with an index in
   * {@code [fromState, toStateExclusive)} to {@code target}. Scores are
   * exponentiated first when the PHMM does not use log space.
   */
  private void addScores(Vector<Double> target, StateType type, int fromState, int toStateExclusive) {
    for (int state = fromState; state < toStateExclusive; state++) {
      double score = logPosterior(type, state);
      if (!isLogSpace) {
        score = Math.exp(score);
      }
      target.add(score);
    }
  }

  /**
   * Accumulates, over all sequence positions, the combined forward/backward
   * entry of one state using the inherited {@code logplus}, and then combines
   * the result with the negated sequence score using the inherited {@code logsum}.
   */
  private double logPosterior(StateType type, int state) {
    double sumPerState = Double.NEGATIVE_INFINITY;
    for (int pos = 0; pos < testSequence.length(); pos++) {
      sumPerState = logplus(sumPerState, forwardBackwardTerm(type, state, pos));
    }
    return logsum(sumPerState, -scoreOfTestSequence);
  }

  /**
   * Combines the forward and backward matrix entry of one state at one
   * sequence position with the inherited {@code logsum}.
   */
  private double forwardBackwardTerm(StateType type, int state, int pos) {
    switch (type) {
      case MATCH:
        return logsum(fwd.getMatchMatrixPos(state, pos), bwd.getMatchMatrixPos(state, pos));
      case INSERT:
        return logsum(fwd.getInsertMatrixPos(state, pos), bwd.getInsertMatrixPos(state, pos));
      case DELETE:
        return logsum(fwd.getDeleteMatrixPos(state, pos), bwd.getDeleteMatrixPos(state, pos));
      default:
        throw new IllegalArgumentException("Unknown state type: " + type);
    }
  }

  /**
   * For testing purposes: checks that the log score of the first and of the
   * last match state does not exceed the tolerance.
   *
   * NOTE(review): the check is one-sided; a score well below 0 (probability
   * below 1) is not rejected although the message says the probability
   * "should be 1". Behaviour is kept as it was; confirm whether a two-sided
   * check is intended.
   *
   * @throws ImpossibleStateProbabilityException if one of the two scores is larger than the tolerance
   */
  private void testScores() throws ImpossibleStateProbabilityException {
    double firstStateScore = logPosterior(StateType.MATCH, 0);
    if (firstStateScore > LOG_SCORE_TOLERANCE) {
      throw new ImpossibleStateProbabilityException("Probablity of first match state M0 should be 1, but it is"+Math.exp(firstStateScore));
    }

    int lastState = numberOfMatchStates - 1;
    double lastStateScore = logPosterior(StateType.MATCH, lastState);
    if (lastStateScore > LOG_SCORE_TOLERANCE) {
      throw new ImpossibleStateProbabilityException("Probablity of last match state M"+lastState+" should be 1, but it is"+Math.exp(lastStateScore));
    }
  }

}
