/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    SufficientEmissionStatistics.java
 *    Copyright (C) 2010 Stefan Mutter
 *
 */
package weka.classifiers.sequence.core;

/**
<!-- globalinfo-start -->
* Class calculating the sufficient emission statistics as described in
 * Jaakkola et al.: A Discriminative Framework for Detecting Remote Protein Homologies.
 * Journal of Computational Biology, 7(1-2):95-114, 2000. doi: 10.1089/10665270050081405
 * 
 * However, no Dirichlet mixture distribution is used.
<!-- globalinfo-end -->
* 
* @author Stefan Mutter (pHMM4weka@gmail.com)
* @version $Revision: 4 $
*/
public class SufficientEmissionStatistics extends ProfileHMMAlgorithms {

  private static final long serialVersionUID = 360896614085679006L;

  /**
   * Whether insert states were requested. Currently this only influences the
   * size of the initial working matrix allocated by the constructor;
   * {@link #getStats(String)} always works on match states only.
   */
  private boolean includeInserts;

  /**
   * Working matrix in log space (initialised to log(0) = negative infinity).
   * Rows are model states, columns are alphabet indices. After
   * {@link #getStats(String)} it has one row per match state and holds the
   * accumulated log counts minus the model's emission value for that
   * (state, symbol) pair.
   */
  protected double[][] emissionCountMatrix;

  /** Result of the most recent call to {@link #getStats(String)}. */
  private double[] sufficientEmissionStats;

  /**
   * Creates a sufficient-emission-statistics calculator for a profile HMM.
   *
   * @param net the profile HMM whose match-state emissions are scored
   * @param includeInserts if true, the initial working matrix gets
   *        {@code 2 * numberMatchStates - 1} rows (presumably match plus insert
   *        states) instead of {@code numberMatchStates}; note that
   *        {@link #getStats(String)} currently reallocates it for match states only
   */
  public SufficientEmissionStatistics(ProfileHMM net, boolean includeInserts) {
    super(net);
    this.includeInserts = includeInserts;
    int numberMatchStates = this.net.getNumberMatchStates();
    int alphabetSize = (this.net.getAlphabet()).alphabetSize();
    int rows = includeInserts ? (2 * numberMatchStates) - 1 : numberMatchStates;
    this.emissionCountMatrix = newLogZeroMatrix(rows, alphabetSize);
    this.sufficientEmissionStats = new double[numberMatchStates * alphabetSize];
  }

  /**
   * Computes the sufficient emission statistics of a sequence under the model:
   * one value per (match state, alphabet symbol) pair, laid out row-major as
   * {@code index = matchState * alphabetSize + symbol}.
   *
   * <p>Each call returns a newly allocated array, so results returned by
   * earlier calls are never modified by later ones.
   *
   * @param sequence the sequence to score against the model
   * @return the statistics, of length {@code numberMatchStates * alphabetSize}
   * @throws IllegalSymbolException propagated from the alphabet lookups or the
   *         forward/backward computations
   * @throws InvalidStructureException propagated from the forward/backward computations
   * @throws InvalidViterbiPathException propagated from the forward/backward computations
   * @throws NumericStabilityException if the accumulated log counts of all
   *         match states sum to negative infinity, so no normalizer exists
   */
  public double[] getStats(String sequence) throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, NumericStabilityException {
    ForwardAlgorithm fwd = new ForwardAlgorithm(net, sequence);
    fwd.calculateForward();
    double probability = fwd.getScore();
    BackwardAlgorithm bwd = new BackwardAlgorithm(net, sequence);
    bwd.calculateBackward();

    Alphabet alphabet = net.getAlphabet();
    int numberMatchStates = net.getNumberMatchStates();
    int numberSymbols = alphabet.alphabetSize();

    // Only match states are accumulated (independently of includeInserts):
    // the result vector has room for match states only.
    this.emissionCountMatrix = newLogZeroMatrix(numberMatchStates, numberSymbols);

    for (int i = 0; i < sequence.length(); i++) {
      calculateEmissionFrequency(sequence, i, fwd, bwd, probability);
    }

    double logNormalizer = -logSumOfCounts();

    double[] stats = new double[numberMatchStates * numberSymbols];
    int index = 0;
    for (int column = 0; column < numberMatchStates; column++) {
      // The state does not depend on the symbol, so look it up once per row.
      EmissionState matchState = net.getMatchState(column);
      for (int symbol = 0; symbol < numberSymbols; symbol++) {
        double probSymbolState = matchState.getEmissionProbability(alphabet.getSymbolAtIndex(symbol));
        // NOTE(review): subtracting treats the emission value as a log
        // probability (log(count / p)) -- confirm getEmissionProbability's scale.
        emissionCountMatrix[column][symbol] -= probSymbolState;
        // NOTE(review): logplus presumably log-adds, giving count/p + 1/total
        // rather than count/(p * total); confirm this matches the intended
        // Fisher-score formula before changing it.
        stats[index++] = Math.exp(logplus(emissionCountMatrix[column][symbol], logNormalizer));
      }
    }

    sufficientEmissionStats = stats;
    return stats;
  }

  /**
   * Log-sums every entry of {@link #emissionCountMatrix}.
   *
   * @return the log of the total accumulated count
   * @throws NumericStabilityException if the matrix has rows but the total is
   *         negative infinity (nothing to normalize by)
   */
  private double logSumOfCounts() throws NumericStabilityException {
    double logTotal = Double.NEGATIVE_INFINITY;
    for (double[] row : emissionCountMatrix) {
      for (double logCount : row) {
        logTotal = logplus(logTotal, logCount);
      }
    }
    // Checked once over the whole matrix: the normalizer is the global total,
    // so an individual state without emissions is not an error by itself.
    // A model with no match states produces no statistics and is not an error.
    if (emissionCountMatrix.length > 0 && logTotal == Double.NEGATIVE_INFINITY) {
      throw new NumericStabilityException("Can't normalize array. Sum is negative infinity.");
    }
    return logTotal;
  }

  /**
   * Log-adds, for every match state k, the contribution of state k emitting
   * the symbol at {@code sequencePosition} into {@link #emissionCountMatrix}.
   *
   * @param sequence the scored sequence
   * @param sequencePosition index of the symbol within {@code sequence}
   * @param fwd completed forward computation for {@code sequence}
   * @param bwd completed backward computation for {@code sequence}
   * @param probabilityForSequence the forward score of the sequence
   * @throws IllegalSymbolException propagated from the alphabet lookup
   */
  private void calculateEmissionFrequency(String sequence, int sequencePosition, ForwardAlgorithm fwd, BackwardAlgorithm bwd, double probabilityForSequence) throws IllegalSymbolException {
    int columns = net.getNumberMatchStates();
    Alphabet alphabet = net.getAlphabet();
    int alphabetIndex = alphabet.indexOfAlphabetSymbol(String.valueOf(sequence.charAt(sequencePosition)));
    for (int k = 0; k < columns; k++) {
      // Presumably log f_k(i) + log b_k(i) - log P(x), i.e. the log posterior of
      // match state k at this position -- assumes fwd/bwd/getScore are in log space.
      double logContribution = fwd.getMatchMatrixPos(k, sequencePosition)
          + bwd.getMatchMatrixPos(k, sequencePosition)
          - probabilityForSequence;
      emissionCountMatrix[k][alphabetIndex] = logplus(emissionCountMatrix[k][alphabetIndex], logContribution);
    }
  }

  /**
   * Allocates a {@code rows x cols} matrix with every entry set to log(0),
   * i.e. negative infinity.
   */
  private static double[][] newLogZeroMatrix(int rows, int cols) {
    double[][] matrix = new double[rows][cols];
    for (double[] row : matrix) {
      java.util.Arrays.fill(row, Double.NEGATIVE_INFINITY);
    }
    return matrix;
  }
}
