/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    ProfileHMM.java
 *    Copyright (C) 2010 Stefan Mutter
 *
 */
package weka.classifiers.sequence.core;

import java.io.Serializable;

import weka.core.SelectedTag;
import weka.core.Tag;

/**
<!-- globalinfo-start -->
* Class representing a profile HMM (PHMM) whose transition probabilities are initially uniform. Note: support is complete for uniform initialisation of the emission probabilities, but not for random initialisation.
<!-- globalinfo-end -->
*
* @author Stefan Mutter (pHMM4weka@gmail.com)
* @version $Revision: 6 $
*/
public class ProfileHMM implements Serializable {

  private static final long serialVersionUID = -8530352277246148378L;

  // Uniform transition probabilities used for the initial model.

  /** Probability of moving from a match state to the next match state. */
  private static final double MATCH_TO_MATCH_PROB = 1.0 / 3.0;

  /** Probability of moving from a delete state to a match state. */
  private static final double DELETE_TO_MATCH_PROB = 0.5;

  /** Probability of moving from an insert state to the next match state. */
  private static final double INSERT_TO_MATCH_PROB = 0.5;

  /** Number of match states (match columns); {@link #setNumberMatchStates} requires at least 4. */
  protected int numberMatchStates;

  private Alphabet alphabet;

  protected EmissionState[] matchStates;

  protected EmissionState[] insertStates;

  protected DeleteState[] deleteStates;

  protected EmissionState end;

  protected EmissionState begin;

  protected Transition[][] transitions;

  protected NullModel nullModel;

  protected boolean useLogSpace = false;

  // Initialisation strategies for the emission probabilities. Declared final so they
  // are true compile-time constants (they were mutable statics before).
  private static final int INIT_UNIFORM = 0;

  private static final int INIT_RANDOM = 1;

  /** Tags describing the available emission-probability initialisation strategies. */
  public static final Tag[] TAGS_INIT = {
    new Tag(INIT_UNIFORM, "Initialise PHMM uniformly"),
    new Tag(INIT_RANDOM, "Initiliase PHMM randmoly"),
  };

  protected int initType = INIT_UNIFORM;

  /** Returns the number of match states (match columns) of this model. */
  public int getNumberMatchStates() {
    return numberMatchStates;
  }

  /**
   * Sets the number of match states (match columns).
   *
   * @param numberMatchStates the number of match columns; must be at least 4
   * @throws SequenceLengthException if fewer than 4 columns are requested
   */
  public void setNumberMatchStates(int numberMatchStates) throws SequenceLengthException {
    if (numberMatchStates < 4) {
      throw new SequenceLengthException("Provide at least 4 (match) columns.");
    }
    this.numberMatchStates = numberMatchStates;
  }

  /** Returns the alphabet the emission states of this model are defined over. */
  public Alphabet getAlphabet() {
    return alphabet;
  }

  /** Sets the alphabet the emission states of this model are defined over. */
  public void setAlphabet(Alphabet alpha) {
    this.alphabet = alpha;
  }

  /** Returns whether transition and emission values are stored in log space. */
  public boolean isUseLogSpace() {
    return useLogSpace;
  }

  /** Creates an empty, unbuilt model; fields must be configured by the caller. */
  public ProfileHMM() {

  }

  /**
   * Builds a PHMM with uniform transition and uniform emission initialisation.
   *
   * @param matchColumns number of match columns (at least 4)
   * @param alpha emission alphabet
   * @param useNullModel whether to attach a {@link NullModel}
   * @param useLogSpace whether probabilities are stored as natural logarithms
   * @throws IllegalArgumentException if {@code matchColumns} is rejected by
   *     {@link #setNumberMatchStates}
   */
  public ProfileHMM(int matchColumns, Alphabet alpha, boolean useNullModel, boolean useLogSpace) {
    this(matchColumns, alpha, useNullModel, useLogSpace, INIT_UNIFORM);
  }

  /**
   * Builds a PHMM with uniform transitions and uniform emission initialisation, optionally
   * using a null model whose distribution is estimated from the given sequences.
   *
   * @param matchColumns number of match columns (at least 4)
   * @param alpha emission alphabet
   * @param useNullModel whether to attach a null model at all
   * @param useBackgroundNull if a null model is used, whether it is a
   *     {@link BackgroundDistNullModel} built from {@code sequences}
   * @param sequences sequences used by the background null model
   * @param useLogSpace whether probabilities are stored as natural logarithms
   * @throws IllegalArgumentException if {@code matchColumns} is rejected or the sequences
   *     contain a symbol the background null model rejects
   * @throws IllegalStateException if the background null model reports a numeric
   *     stability problem
   */
  public ProfileHMM(int matchColumns, Alphabet alpha, boolean useNullModel,
      boolean useBackgroundNull, String[] sequences, boolean useLogSpace) {
    configure(matchColumns, alpha, useLogSpace, INIT_UNIFORM);
    if (!useNullModel) {
      this.nullModel = null;
    } else if (!useBackgroundNull) {
      this.nullModel = new NullModel(useLogSpace, this.alphabet);
    } else {
      // Previously these failures called System.exit(1), killing the host JVM.
      try {
        this.nullModel = new BackgroundDistNullModel(useLogSpace, this.alphabet, sequences);
      } catch (IllegalSymbolException e) {
        throw new IllegalArgumentException("Cannot build background null model: " + e, e);
      } catch (NumericStabilityException e) {
        throw new IllegalStateException("Cannot build background null model: " + e, e);
      }
    }
    buildModel();
  }

  /**
   * Builds a PHMM with uniform transitions and the given emission initialisation.
   *
   * @param matchColumns number of match columns (at least 4)
   * @param alpha emission alphabet
   * @param useNullModel whether to attach a {@link NullModel}
   * @param useLogSpace whether probabilities are stored as natural logarithms
   * @param initUniformly emission initialisation strategy; one of the IDs in
   *     {@link #TAGS_INIT}
   * @throws IllegalArgumentException if {@code matchColumns} is rejected or
   *     {@code initUniformly} is not a known initialisation type
   */
  public ProfileHMM(int matchColumns, Alphabet alpha, boolean useNullModel, boolean useLogSpace,
      int initUniformly) {
    configure(matchColumns, alpha, useLogSpace, initUniformly);
    this.nullModel = useNullModel ? new NullModel(useLogSpace, this.alphabet) : null;
    buildModel();
  }

  /**
   * Applies the settings shared by every building constructor.
   * Invalid input is reported immediately instead of surfacing later as an unrelated
   * NegativeArraySizeException or NullPointerException.
   */
  private void configure(int matchColumns, Alphabet alpha, boolean logSpace, int type) {
    this.setAlphabet(alpha);
    this.useLogSpace = logSpace;
    try {
      this.setNumberMatchStates(matchColumns);
    } catch (SequenceLengthException e) {
      throw new IllegalArgumentException("Invalid number of match columns: " + matchColumns, e);
    }
    if (type != INIT_UNIFORM && type != INIT_RANDOM) {
      throw new IllegalArgumentException("Unknown initialisation type: " + type);
    }
    this.initType = type;
  }

  /** Creates all states and wires the transitions in linear or log space. */
  private void buildModel() {
    try {
      buildStates();
      if (useLogSpace) {
        buildTransitionsLogSpace();
      } else {
        buildTransitions();
      }
    } catch (InvalidStructureException e) {
      // NOTE(review): kept best-effort as before; the model may be only partially wired.
      e.printStackTrace();
    }
  }

  /** Creates one emission state for the given column using the current {@code initType}. */
  private EmissionState createEmissionState(String name, int column) {
    switch (initType) {
      case INIT_UNIFORM:
        return new EmissionState(name, column, alphabet,
            new UniformDistribution(alphabet, useLogSpace));
      case INIT_RANDOM:
        return new EmissionState(name, column, alphabet,
            new RandomDistribution(alphabet, useLogSpace, column));
      default:
        throw new IllegalStateException("Unknown initialisation type: " + initType);
    }
  }

  /**
   * Allocates the states of the model: {@code numberMatchStates} match states,
   * one fewer insert state, two fewer delete states, plus a begin and an end state.
   */
  protected void buildStates() {
    matchStates = new EmissionState[numberMatchStates];
    insertStates = new EmissionState[numberMatchStates - 1];
    deleteStates = new DeleteState[numberMatchStates - 2];

    for (int i = 0; i < deleteStates.length; i++) {
      deleteStates[i] = new DeleteState("D", i);
    }
    for (int i = 0; i < matchStates.length; i++) {
      matchStates[i] = createEmissionState("M", i);
      if (i < insertStates.length) {
        insertStates[i] = createEmissionState("I", i);
      }
    }
    SpecialAlphabet endAlphabet = new SpecialAlphabet();
    end = new EmissionState("E", 0, endAlphabet, new UniformDistribution(endAlphabet, useLogSpace));

    BeginAlphabet beginAlphabet = new BeginAlphabet();
    begin = new EmissionState("B", 0, beginAlphabet,
        new UniformDistribution(beginAlphabet, useLogSpace));
  }

  /** Connects all states with the uniform initial transition probabilities. */
  protected void buildTransitions() throws InvalidStructureException {
    wireTransitions(false);
  }

  /** Connects all states with the natural logarithms of the initial transition probabilities. */
  protected void buildTransitionsLogSpace() throws InvalidStructureException {
    wireTransitions(true);
  }

  /** Returns {@code probability}, or its natural logarithm when {@code logSpace} is set. */
  private static double weight(double probability, boolean logSpace) {
    return logSpace ? Math.log(probability) : probability;
  }

  /**
   * Creates every transition of the PHMM, in the same order for linear and log space.
   *
   * @param logSpace whether transition values are stored as natural logarithms
   */
  private void wireTransitions(boolean logSpace) throws InvalidStructureException {
    int last = matchStates.length - 1;
    double certain = weight(1.0, logSpace); // 1.0, or exactly 0.0 in log space
    double matchToMatch = weight(MATCH_TO_MATCH_PROB, logSpace);
    // The remaining match mass is split evenly between the insert and the delete state.
    double matchToSide = weight((1.0 - MATCH_TO_MATCH_PROB) / 2.0, logSpace);

    connect(begin, matchStates[0], certain);
    for (int i = 0; i < last - 1; i++) {
      connect(matchStates[i], matchStates[i + 1], matchToMatch);
      connect(matchStates[i], insertStates[i], matchToSide);
      connect(matchStates[i], deleteStates[i], matchToSide);
    }

    // The penultimate match state has no delete successor, so the insert state gets
    // all the remaining mass.
    connect(matchStates[last - 1], matchStates[last], matchToMatch);
    connect(matchStates[last - 1], insertStates[insertStates.length - 1],
        weight(1.0 - MATCH_TO_MATCH_PROB, logSpace));

    connect(matchStates[last], end, certain);

    double insertToMatch = weight(INSERT_TO_MATCH_PROB, logSpace);
    double insertToSelf = weight(1.0 - INSERT_TO_MATCH_PROB, logSpace);
    for (int i = 0; i < insertStates.length; i++) {
      connect(insertStates[i], matchStates[i + 1], insertToMatch);
      connect(insertStates[i], insertStates[i], insertToSelf);
    }

    // Delete state i follows match state i and skips match state i+1.
    double deleteToMatch = weight(DELETE_TO_MATCH_PROB, logSpace);
    double deleteToDelete = weight(1.0 - DELETE_TO_MATCH_PROB, logSpace);
    for (int i = 0; i < deleteStates.length - 1; i++) {
      connect(deleteStates[i], matchStates[i + 2], deleteToMatch);
      connect(deleteStates[i], deleteStates[i + 1], deleteToDelete);
    }
    connect(deleteStates[deleteStates.length - 1], matchStates[last], certain);
  }

  /** Creates a Baum-Welch transition and registers it on both endpoints. */
  private void connect(State from, State to, double value) throws InvalidStructureException {
    Transition transition = new BaumWelchTransition(value, from, to);
    from.addOutgoingTransition(transition);
    to.addIncomingTransition(transition);
  }

  /** Returns the match state of column {@code i}. */
  public EmissionState getMatchState(int i) {
    return matchStates[i];
  }

  /** Returns the end state. */
  public EmissionState getEndState() {
    return end;
  }

  /** Returns the begin state. */
  public EmissionState getBeginState() {
    return begin;
  }

  /** Returns the insert state of column {@code i}. */
  public EmissionState getInsertState(int i) {
    return insertStates[i];
  }

  /** Returns the delete state with index {@code i}. */
  public DeleteState getDeleteState(int i) {
    return deleteStates[i];
  }

  /** Concatenates the string forms of all match, insert, delete and end states, column by column. */
  @Override
  public String toString() {
    StringBuilder output = new StringBuilder();
    output.append(matchStates[0].toString());
    output.append(insertStates[0].toString());
    for (int i = 1; i < numberMatchStates - 1; i++) {
      output.append(matchStates[i].toString());
      output.append(insertStates[i].toString());
      output.append(deleteStates[i - 1].toString());
    }
    output.append(matchStates[numberMatchStates - 1].toString());
    output.append(end.toString());
    return output.toString();
  }

  /** Concatenates the emission information of all match and insert states and the end state. */
  public String emmissionsToString() {
    StringBuilder output = new StringBuilder();
    for (int i = 0; i < numberMatchStates - 1; i++) {
      output.append(matchStates[i].outputEmissionInformation());
      output.append(insertStates[i].outputEmissionInformation());
    }
    output.append(matchStates[numberMatchStates - 1].outputEmissionInformation());
    output.append(end.outputEmissionInformation());
    return output.toString();
  }

  /** Returns the null model, or {@code null} if none is used. */
  public NullModel getNullModel() {
    return nullModel;
  }

  /** Replaces the null model; {@code null} disables it. */
  public void setNullModel(NullModel nullModel) {
    this.nullModel = nullModel;
  }

  /** Returns the current emission initialisation strategy as a Weka tag. */
  public SelectedTag getInitType() {
    return new SelectedTag(initType, TAGS_INIT);
  }

  /**
   * Sets the emission initialisation strategy. The call is ignored unless {@code newType}
   * refers to {@link #TAGS_INIT}. It only affects states built afterwards.
   */
  public void setInitType(SelectedTag newType) {
    if (newType.getTags() == TAGS_INIT) {
      initType = newType.getSelectedTag().getID();
    }
  }

  /** Returns the serialization version identifier of this class. */
  public static long getSerialVersionUID() {
    return serialVersionUID;
  }

}
