/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    IterativeProfileHMMClassifierSingleHMM.java
 *    Copyright (C) 2010 Stefan Mutter
 *
 */
package weka.classifiers.sequence;

import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.Evaluation;
import weka.classifiers.functions.Logistic;
import weka.classifiers.sequence.core.Alphabet;
import weka.classifiers.sequence.core.BaumWelchLearner;
import weka.classifiers.sequence.core.ForwardAlgorithm;
import weka.classifiers.sequence.core.IllegalSymbolException;
import weka.classifiers.sequence.core.ImpossibleStateProbabilityException;
import weka.classifiers.sequence.core.InvalidStructureException;
import weka.classifiers.sequence.core.InvalidViterbiPathException;
import weka.classifiers.sequence.core.NumericStabilityException;
import weka.classifiers.sequence.core.ProbabilityPerStateCalculator;
import weka.classifiers.sequence.core.ProfileHMM;
import weka.classifiers.sequence.core.SufficientEmissionStatistics;
import weka.classifiers.sequence.core.ViterbiAlgorithm;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.RemoveWithValues;

/**
<!-- globalinfo-start -->
* A one-class PHMM classifier for the positive class. The dataset needs to be binary and the positive class needs to have index 0.
* <p/>
<!-- globalinfo-end -->
*
*
* @author Stefan Mutter (pHMM4weka@gmail.com)
* @version $Revision: 6 $
*/
public class IterativeProfileHMMClassifierSingleHMM extends IterativeProfileHMMClassifier{

	/** Version id used for serialisation. */
	private static final long serialVersionUID = 3134763907535654753L;

	/** The single profile HMM; null until initClassifier() creates it, and null when there are no training instances. */
	private ProfileHMM profileHMM;

	/** Log-likelihood reported by the most recent Baum-Welch round (set in next()). */
	private double logLikelihoodOfPHMM;

	/** Initial log-likelihood reported by the Baum-Welch learner in iteration 1. */
	private double initiallogLikelihoodOfPHMM;

	/** Training sequences (already cut to the restricted length) that are passed to the Baum-Welch learner. */
	private String[] allTrainingSequences;

	/** Log-likelihood of the previous round; compared with the current one for the convergence test. */
	private double oldLogLikelihoodOfPHMM;

	/** Number of the last Baum-Welch round that was run; -1 directly after initClassifier(). */
	private int numIterationPerClass;

	/** Training data after the RemoveWithValues filter in initClassifier() (the positive-class selection). */
	private Instances filtered;

	/** Logistic model built in calibrateClassificationClassifier() on forward scores; used by distributionForInstance(). */
	private Logistic logistic;

	/** Copy of the training data taken before the positive-class filter is applied. */
	private Instances preFiltered;

	/** Class frequencies of the data the model is trained on, computed in initClassifier(). */
	private double[] trainingClassDistribution;

	/** Scratch list of Viterbi scores filled during alignment and read during propositionalisation. */
	private List<Double> viterbiScore;

	/** Scratch list of forward scores filled during alignment and read during propositionalisation. */
	private List<Double> fwdScore;

	/** Scratch list of per-state score lists filled during alignment and read during propositionalisation. */
	private List<List<Double>> allScore;

	/** Iteration number most recently passed to next(); reset to 0 by initClassifier(). */
	private int iteration;


	/**
	 * Creates an untrained classifier with its default configuration: no model,
	 * no null model, Baum-Welch option 2, none of the optional score attributes
	 * switched on and the positive class at index 0.
	 */
	public IterativeProfileHMMClassifierSingleHMM() {
		super();

		// no model has been learnt yet
		profileHMM = null;

		// learner configuration
		logLikelihoodThreshold = 0.0001; //personal communication Prof A Krogh
		useNullModel = false;
		transitionsEmissionsNotInLog = false;
		learnInsertEmissions = false;
		setBaumWelchOption(2);

		// optional propositionalisation outputs are all switched off
		viterbiProb = false;
		fwdProb = false;
		allProb = false;
		allProbOnly = false;
		noBasic = false;
		noPathLogScores = false;

		// training progress
		logLikelihoodOfPHMM = Double.NEGATIVE_INFINITY;
		initiallogLikelihoodOfPHMM = Double.NEGATIVE_INFINITY;
		converged = new Vector<Boolean>();
		oldLogLikelihoodOfPHMM = Double.NEGATIVE_INFINITY;
		numIterationPerClass = 0;

		// one-class set-up
		positiveClassIndex = 0;
		filtered = null;
	}


	/**
	 * Turns every sequence instance of the given dataset into one row of a
	 * propositional dataset, using the alignment (and optional scores) of the
	 * sequence against the profile HMM.
	 *
	 * @param oldInstances the sequence data to transform
	 * @return the propositionalised dataset, one row per input instance
	 */
	public Instances propositionalise(Instances oldInstances) throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, ImpossibleStateProbabilityException {

		Instances transformed = createPropositionalisedInstancesFormat(oldInstances);

		final int total = oldInstances.numInstances();
		int position = 0;
		while(position < total){
			// the score buffers are per instance, so start with fresh ones
			viterbiScore = new Vector<Double>();
			fwdScore = new Vector<Double>();
			allScore = new Vector<List<Double>>();

			Instance current = oldInstances.instance(position);
			String[] alignment = doAlignmentForPropositionalisation(current);
			transformed = doPropositionalisation(alignment, transformed, current.classValue());
			position++;
		}
		return transformed;
	}

	/**
	 * Replaces every sequence instance by its vector of sufficient emission
	 * statistics (insert states excluded) plus the original class value.
	 *
	 * The number of attributes is taken from the first statistics row, so a
	 * dataset containing a single instance is handled as well.
	 *
	 * @param oldInstances the sequence data to transform
	 * @return a dataset with one numeric attribute per statistic and the class attribute last
	 * @throws IllegalArgumentException if the dataset contains no instances
	 */
	public Instances extractSufficientStatistics(Instances oldInstances) throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, NumericStabilityException{
		int numInstances = oldInstances.numInstances();
		if(numInstances == 0){
			throw new IllegalArgumentException("Cannot extract sufficient statistics from a dataset without instances");
		}

		int lengthLimit = this.getRestrictSequenceLength();
		String[] sequences = new String[numInstances];
		for(int k = 0; k < numInstances; k++){
			Instance current = oldInstances.instance(k);
			String sequence = current.stringValue(current.attribute(sequenceIndex));
			if(lengthLimit != -1 && sequence.length() > lengthLimit){
				sequence = sequence.substring(0, lengthLimit);
			}
			sequences[k] = sequence;
		}

		double[][] allAttributeValues = getSufficientStats(false, sequences);
		// row 0 always exists here; the previous use of row 1 failed for single-instance datasets
		Instances transformed = createSufficientStatsFormat(oldInstances, allAttributeValues[0].length);
		for(int i = 0; i < numInstances; i++){
			Instance newInst = new DenseInstance(transformed.numAttributes());
			newInst.setDataset(transformed);
			newInst.setClassValue(oldInstances.instance(i).classValue());
			for(int j = 0; j < allAttributeValues[i].length; j++){
				newInst.setValue(j, allAttributeValues[i][j]);
			}
			transformed.add(newInst);
		}

		return transformed;
	}


	/**
	 * Builds the empty header for the sufficient-statistics dataset:
	 * numeric attributes "SuffStats_0" ... followed by a nominal class
	 * attribute carrying the class labels of the original data.
	 *
	 * @param oldInstances the original data, source of relation name and class labels
	 * @param numNonClassAttributes the number of statistic attributes
	 * @return the empty dataset with the class index set to the last attribute
	 */
	private Instances createSufficientStatsFormat(Instances oldInstances, int numNonClassAttributes) {

		// class labels are copied from the original class attribute
		FastVector classLabels = new FastVector();
		int numLabels = oldInstances.numClasses();
		for (int label = 0; label < numLabels; label++) {
			classLabels.addElement(oldInstances.classAttribute().value(label));
		}

		// one numeric attribute per statistic, class attribute at the end
		FastVector attInfo = new FastVector(numNonClassAttributes+1);
		int position = 0;
		while(position < numNonClassAttributes){
			attInfo.addElement(new Attribute("SuffStats_"+position));
			position++;
		}
		attInfo.addElement(new Attribute("class",classLabels));

		String relation = oldInstances.relationName()+"_sufficientStatsSingleHMM";
		Instances transformed = new Instances(relation,attInfo,1);
		transformed.setClassIndex(numNonClassAttributes);
		return transformed;
	}


	/**
	 * Propositionalises a single instance. The header is derived from the
	 * dataset the instance belongs to.
	 *
	 * @param oldInstance the sequence instance to transform
	 * @return the propositionalised instance
	 */
	public Instance propositionaliseTestInstance(Instance oldInstance) throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, ImpossibleStateProbabilityException {

		Instances header = createPropositionalisedInstancesFormat(oldInstance.dataset());

		// fresh score buffers for this instance
		viterbiScore = new Vector<Double>();
		fwdScore = new Vector<Double>();
		allScore = new Vector<List<Double>>();

		String[] alignment = doAlignmentForPropositionalisation(oldInstance);
		Instances singleRow = doPropositionalisation(alignment, header, oldInstance.classValue());
		return singleRow.firstInstance();
	}

	/**
	 * Computes the sufficient emission statistics (insert states excluded)
	 * for a single instance and returns them as an instance with the
	 * original class value.
	 *
	 * @param oldInstance the sequence instance to transform
	 * @return the instance holding the statistics
	 */
	public Instance extractSufficientStatisticsTestInstance(Instance oldInstance) throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, NumericStabilityException{
		Instances source = oldInstance.dataset();

		// cut the sequence if a length restriction is active
		String sequence = oldInstance.stringValue(oldInstance.attribute(sequenceIndex));
		int lengthLimit = this.getRestrictSequenceLength();
		if(lengthLimit != -1 && sequence.length() > lengthLimit){
			sequence = sequence.substring(0, lengthLimit);
		}

		double[] stats = getSufficientStats(false, new String[]{sequence})[0];

		Instances transformed = createSufficientStatsFormat(source, stats.length);
		Instance row = new DenseInstance(transformed.numAttributes());
		row.setDataset(transformed);
		row.setClassValue(oldInstance.classValue());
		int column = 0;
		while(column < stats.length){
			row.setValue(column, stats[column]);
			column++;
		}
		transformed.add(row);
		return transformed.firstInstance();
	}


	/**
	 * Aligns the sequence of the given instance against the profile HMM and
	 * fills the score buffers (viterbiScore, fwdScore, allScore) as requested
	 * by the current options.
	 *
	 * When no model exists, an empty alignment is returned and each buffer
	 * receives a null placeholder at index 0; in that case no score
	 * normalisation is attempted (it used to fail while unboxing the
	 * placeholder).
	 *
	 * @param oldInstance the sequence instance to align
	 * @return a one-element array holding the Viterbi path, or "" when the path is not needed or no model exists
	 */
	private String[] doAlignmentForPropositionalisation(Instance oldInstance) throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, ImpossibleStateProbabilityException {
		String test = oldInstance.stringValue(oldInstance.attribute(sequenceIndex));
		int lengthLimit = this.getRestrictSequenceLength();
		if(lengthLimit != -1 && test.length() > lengthLimit){
			test = test.substring(0, lengthLimit);
		}

		String[] alignment = new String[]{""};

		if(profileHMM == null){
			// nothing to align against: keep the buffers index-aligned with placeholders
			viterbiScore.add(0,null);
			fwdScore.add(0,null);
			allScore.add(0,null);
			return alignment;
		}

		if(isViterbiProb() || !isNoBasic()){
			ViterbiAlgorithm vitAlg = new ViterbiAlgorithm(profileHMM,test);
			String path = vitAlg.calculateViterbiPath();
			// the path itself is only kept when the basic attributes are wanted
			if(!noBasic){
				alignment[0] = path;
			}
			viterbiScore.add(0,vitAlg.getScore());
		}
		if(isFwdProb()){
			ForwardAlgorithm fwd = new ForwardAlgorithm(profileHMM,test);
			fwd.calculateForward();
			fwdScore.add(0,fwd.getScore());
		}
		if(isAllProb()){
			ProbabilityPerStateCalculator calc = new ProbabilityPerStateCalculator(profileHMM,test);
			allScore.add(0,calc.getScores());
		}

		if(noPathLogScores){
			normalisePathLogScores(viterbiScore);
			normalisePathLogScores(fwdScore);
		}
		return alignment;
	}

	/**
	 * Replaces the log scores in the list by normalised probabilities
	 * (Utils.logs2probs followed by Utils.normalize). An empty list is left untouched.
	 */
	private static void normalisePathLogScores(List<Double> logScores){
		int size = logScores.size();
		if(size == 0){
			return;
		}
		double[] values = new double[size];
		for(int i = 0; i < size; i++){
			values[i] = logScores.get(i);
		}
		values = Utils.logs2probs(values);
		Utils.normalize(values);
		for(int i = 0; i < size; i++){
			logScores.set(i, values[i]);
		}
	}

	/**
	 * Converts alignment strings (and the buffered scores) into one new
	 * instance and appends it to the given dataset.
	 *
	 * Each alignment is read left to right: a '-' writes the nominal value
	 * "-" followed by a count of 0; an upper-case character is written as a
	 * nominal value; a run of lower-case characters is written as its
	 * length; where no lower-case run follows and the end of the alignment
	 * has not been reached, a count of 0 is written. Afterwards the Viterbi
	 * score, forward score and per-state scores are appended when the
	 * corresponding options are switched on.
	 *
	 * @param allAlignment the alignment strings
	 * @param transformed the dataset that receives the new instance
	 * @param classValue the class value of the new instance
	 * @return the dataset passed in, with the new instance added
	 */
	protected Instances doPropositionalisation(String[] allAlignment, Instances transformed, double classValue) throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException {

		Instance row = new DenseInstance(transformed.numAttributes());
		row.setDataset(transformed);
		row.setClassValue(classValue);

		int slot = 0;

		for(int k = 0; k < allAlignment.length; k++){
			final String path = allAlignment[k];
			final int pathLength = path.length();

			int pos = 0;
			while(pos < pathLength){
				char symbol = path.charAt(pos);

				if(symbol == '-'){
					row.setValue(slot, "-");
					slot++;
					row.setValue(slot, 0);
					slot++;
					pos++;
					continue;
				}

				if(Character.isUpperCase(symbol)){
					row.setValue(slot, symbol+"");
					slot++;
				}

				boolean notLast = pos < pathLength-1;
				if(notLast && Character.isLowerCase(path.charAt(pos+1))){
					// measure the lower-case run that starts right after this position
					int runEnd = pos+1;
					while(runEnd < pathLength && Character.isLowerCase(path.charAt(runEnd))){
						runEnd++;
					}
					row.setValue(slot, runEnd-pos-1);
					slot++;
					// continue with the first character after the run
					pos = runEnd-1;
				}
				else if(notLast){
					row.setValue(slot, 0);
					slot++;
				}
				pos++;
			}

			if(isViterbiProb()){
				row.setValue(slot, viterbiScore.get(k));
				slot++;
			}
			if(isFwdProb()){
				row.setValue(slot, fwdScore.get(k));
				slot++;
			}
			if(isAllProb()){
				List<Double> stateScores = allScore.get(k);
				int numScores = stateScores.size();
				double[] probs = new double[numScores];
				boolean asProbabilities = isAllProbOnly();
				if(asProbabilities){
					for(int n = 0; n < numScores; n++){
						probs[n] = stateScores.get(n);
					}
					probs = Utils.logs2probs(probs);
					Utils.normalize(probs);
				}
				for(int j = 0; j < numScores; j++){
					if(asProbabilities){
						row.setValue(slot, probs[j]);
					}
					else{
						row.setValue(slot, stateScores.get(j));
					}
					slot++;
				}
			}
		}

		transformed.add(row);
		return transformed;
	}

	/**
	 * Builds the empty header of the propositionalised dataset.
	 *
	 * Layout (each part only if enabled): for a model with n match states,
	 * 2n-1 alternating nominal and numeric attributes ("Match"/"Match_Delete"
	 * and "Insert"), then "ViterbiScore", "ForwardScore", the per-state
	 * score attributes ("Score4Match", "Score4Insert", "Score4Delete") and
	 * finally the nominal class attribute.
	 *
	 * The class index is derived from the attributes that were actually
	 * added, and the model-dependent parts are skipped when no model
	 * exists, so the method no longer fails (or mis-places the class index)
	 * for a classifier without a profile HMM.
	 *
	 * @param oldInstances the original data, source of relation name and class labels
	 * @return the empty dataset with the class index set to the last attribute
	 */
	public Instances createPropositionalisedInstancesFormat(Instances oldInstances){

		final boolean hasModel = profileHMM != null;
		final int matchStates = hasModel ? profileHMM.getNumberMatchStates() : 0;

		FastVector attInfo = new FastVector();

		if(!noBasic && hasModel){
			// nominal values: every alphabet symbol plus the "-" marker
			Alphabet usedAlphabet = profileHMM.getAlphabet();
			FastVector my_nominal_values = new FastVector();
			for(int i = 0; i < usedAlphabet.alphabetSize();i++){
				my_nominal_values.addElement(usedAlphabet.getSymbolAtIndex(i));
			}
			my_nominal_values.addElement("-");

			int numberOfAttributes = (matchStates*2)-1;
			int counter = 0;
			for(int j = 0; j < numberOfAttributes; j++){
				if(j%2==0){
					// first and last column are named "Match", the others "Match_Delete"
					String name = (j == 0 || j == numberOfAttributes -1) ? "Match"+counter : "Match_Delete"+counter;
					attInfo.addElement(new Attribute(name,my_nominal_values));
				}
				else{
					attInfo.addElement(new Attribute("Insert"+counter));
					counter++;
				}
			}
		}
		if(isViterbiProb()){
			attInfo.addElement(new Attribute("ViterbiScore"));
		}
		if(isFwdProb()){
			attInfo.addElement(new Attribute("ForwardScore"));
		}
		if(isAllProb() && hasModel){
			for(int j = 1; j < matchStates-1; j++){
				attInfo.addElement(new Attribute("Score4Match"+j));
			}
			for(int j = 0; j < matchStates-1; j++){
				attInfo.addElement(new Attribute("Score4Insert"+j));
			}
			for(int j = 0; j < matchStates-2; j++){
				attInfo.addElement(new Attribute("Score4Delete"+j));
			}
		}

		FastVector my_nominal_class_values = new FastVector();
		for (int i = 0; i < oldInstances.numClasses(); i++) {
			my_nominal_class_values.addElement(oldInstances.classAttribute().value(i));
		}

		// the class attribute always goes last; its index is whatever was added before it
		int classPosition = attInfo.size();
		attInfo.addElement(new Attribute("class",my_nominal_class_values));
		Instances transformed = new Instances(oldInstances.relationName()+"_propositionalizedSingleHMM",attInfo,1);
		transformed.setClassIndex(classPosition);

		return transformed;
	}

	/**
	 * Trains the classifier: initialises it on the given data and then runs
	 * training rounds (numbered from 1) until fullyConverged() reports true.
	 *
	 * @param data the training data
	 * @throws Exception if initialisation or a training round fails
	 */
	public void buildClassifier(Instances data) throws Exception {
		initClassifier(data);

		int round = 0;
		do{
			round++;
			next(round);
		}while(!fullyConverged());
	}


	/**
	 * Predicts the class with the highest probability in the distribution
	 * of the instance.
	 *
	 * @param instance the instance to classify
	 * @return the index of the most probable class, or a missing value if that probability is 0
	 * @throws Exception if the distribution cannot be computed
	 */
	public double classifyInstance(Instance instance) throws Exception {
		double[] dist = distributionForInstance(instance);
		int best = Utils.maxIndex(dist);
		return dist[best] == 0 ? Utils.missingValue() : best;
	}



	/**
	 * Computes the class distribution of an instance: the forward score of
	 * its (possibly truncated) sequence under the profile HMM is handed to
	 * the logistic model as the single attribute "score".
	 *
	 * @param instance the instance to classify
	 * @return the class distribution; all zeros if no model exists
	 * @throws Exception if no positive class is configured, the training data had more than two classes, or scoring fails
	 */
	public double [] distributionForInstance(Instance instance) throws Exception {
		if(positiveClassIndex < 0){
			throw new Exception("no pure PHMM classification possible");
		}
		if(profileHMM == null){
			// a freshly allocated array already holds 0.0 for both classes
			return new double[2];
		}

		// score the (possibly truncated) sequence
		String testSequence = instance.stringValue(instance.attribute(sequenceIndex));
		int lengthLimit = this.getRestrictSequenceLength();
		if(lengthLimit != -1 && testSequence.length() > lengthLimit){
			testSequence = testSequence.substring(0, lengthLimit);
		}
		ForwardAlgorithm fwd = new ForwardAlgorithm(profileHMM,testSequence);
		fwd.calculateForward();

		// header: numeric score plus the class labels of the training data
		FastVector classInfo = new FastVector(2);
		int numLabels = preFiltered.numClasses();
		for (int label = 0; label < numLabels; label++) {
			if(label == 2){
				throw new Exception("Only on binary datasets");
			}
			classInfo.addElement(preFiltered.classAttribute().value(label));
		}
		FastVector attInfo = new FastVector(2);
		attInfo.addElement(new Attribute("score"));
		attInfo.addElement(new Attribute("class",classInfo));
		Instances instancesLogistic = new Instances("logistic", attInfo, 1);
		instancesLogistic.setClassIndex(1);

		// the instance to score: class unknown, forward score as the only value
		Instance scored = new DenseInstance(instancesLogistic.numAttributes());
		scored.setDataset(instancesLogistic);
		scored.setClassValue(Utils.missingValue());
		scored.setValue(0, fwd.getScore());

		return logistic.distributionForInstance(scored);
	}

	/**
	 * Fits the logistic model that turns forward scores into class
	 * probabilities. Every sequence in the given data contributes one
	 * example labelled with the positive class (its own forward score) and
	 * one labelled with the other class (the forward score of a shuffled
	 * copy of the sequence).
	 *
	 * Sequences are cut to the restricted sequence length before scoring,
	 * exactly as during training and in distributionForInstance(), so the
	 * calibration scores are comparable to the scores seen at prediction time.
	 *
	 * Does nothing when no positive class is configured.
	 *
	 * @param data the instances of the positive class
	 * @throws Exception if the training data had more than two classes or scoring/fitting fails
	 */
	private void calibrateClassificationClassifier(Instances data) throws Exception{
		if(positiveClassIndex < 0)
			return;
		logistic = new Logistic();
		FastVector attInfo = new FastVector(2);
		FastVector classInfo = new FastVector(2);
		for (int i = 0; i < preFiltered.numClasses(); i++) {
			classInfo.addElement(preFiltered.classAttribute().value(i));
			if(i == 2){
				throw new Exception("Calibration method only works on binary datasets");
			}
		}
		attInfo.addElement(new Attribute("score"));
		attInfo.addElement(new Attribute("class",classInfo));
		Instances instancesLogistic = new Instances("logistic", attInfo, 1);
		instancesLogistic.setClassIndex(1);

		// collect the sequences once, truncated like everywhere else in this class
		int lengthLimit = this.getRestrictSequenceLength();
		String[] sequences = new String[data.numInstances()];
		for(int i = 0; i < sequences.length; i++){
			Instance current = data.instance(i);
			String sequence = current.stringValue(current.attribute(sequenceIndex));
			if(lengthLimit != -1 && sequence.length() > lengthLimit){
				sequence = sequence.substring(0, lengthLimit);
			}
			sequences[i] = sequence;
		}

		// the label that is not the positive one
		double negativeClassIndex = (positiveClassIndex == 0) ? 1.0 : 0.0;

		// positives first, then the shuffled negatives (same order as before)
		for(int i = 0; i < sequences.length; i++){
			addCalibrationExample(instancesLogistic, sequences[i], positiveClassIndex);
		}
		for(int i = 0; i < sequences.length; i++){
			addCalibrationExample(instancesLogistic, pertubate(sequences[i], 1), negativeClassIndex);
		}

		logistic.buildClassifier(instancesLogistic);
	}

	/**
	 * Scores the sequence with the forward algorithm and appends the score,
	 * labelled with the given class value, to the calibration data.
	 */
	private void addCalibrationExample(Instances target, String sequence, double classValue) throws Exception{
		ForwardAlgorithm fwd = new ForwardAlgorithm(profileHMM,sequence);
		fwd.calculateForward();
		Instance newInst = new DenseInstance(target.numAttributes());
		newInst.setDataset(target);
		newInst.setClassValue(classValue);
		newInst.setValue(0, fwd.getScore());
		target.add(newInst);
	}

	/**
	 * Returns a random permutation of the characters of the given sequence.
	 * Positions are drawn with a generator seeded with <code>seed</code>
	 * until every position has been used once, so the result is
	 * reproducible for a given sequence and seed.
	 *
	 * An empty sequence is returned unchanged (drawing from an empty range
	 * would fail).
	 *
	 * @param realSequence the sequence to shuffle
	 * @param seed the seed of the random number generator
	 * @return the shuffled sequence
	 */
	private static String pertubate(String realSequence, int seed){
		int length = realSequence.length();
		if(length == 0){
			return realSequence;
		}
		Random randGen = new Random(seed);
		boolean[] used = new boolean[length];
		StringBuilder mixedUpSequence = new StringBuilder(length);
		while(mixedUpSequence.length() < length){
			// same draw order as before, so results for a given seed are unchanged
			int index = randGen.nextInt(length);
			if(!used[index]){
				mixedUpSequence.append(realSequence.charAt(index));
				used[index] = true;
			}
		}
		return mixedUpSequence.toString();
	}



	/**
	 * Determines the number of match columns as the average sequence length
	 * of the given instances, rounded to the nearest integer.
	 *
	 * The average is computed in floating point; previously the integer
	 * division truncated the value before it was rounded.
	 *
	 * @param inst the instances holding the sequences
	 * @return the rounded average sequence length
	 * @throws IllegalArgumentException if there are no instances
	 */
	private int determineColumns(Instances inst) {
		int numInstances = inst.numInstances();
		if(numInstances == 0){
			throw new IllegalArgumentException("Cannot determine the number of match columns without any sequences");
		}
		// long accumulator: the sum of many long sequences may exceed the int range
		long totalLength = 0;
		for(int i = 0; i < numInstances; i++){
			Instance current = inst.instance(i);
			totalLength += current.stringValue(current.attribute(sequenceIndex)).length();
		}

		return (int) Math.round((double) totalLength / numInstances);
	}

	/**
	 * Lists the available options. This class adds none of its own, so the
	 * result contains exactly the options of the superclass.
	 *
	 * @return an enumeration of all the available options.
	 */
	public Enumeration listOptions() {
		Vector collected = new Vector();
		for(Enumeration inherited = super.listOptions(); inherited.hasMoreElements(); ){
			collected.addElement(inherited.nextElement());
		}
		return collected.elements();
	}

	/**
	 * Returns the current settings, which are those of the superclass.
	 *
	 * @return an array of strings suitable for passing to setOptions()
	 */
	public String [] getOptions() {
		Vector collected = new Vector();
		for (String option : super.getOptions()) {
			collected.add(option);
		}
		String[] asArray = new String[collected.size()];
		return (String[]) collected.toArray(asArray);
	}

	/**
	 * Parses a given list of options by passing it on to the superclass;
	 * this class defines no options of its own. <p/>
	 *
	 * @param options the list of options as an array of strings
	 * @throws Exception if an option is not supported
	 */
	public void setOptions(String[] options) throws Exception {
		super.setOptions(options);
	}

	/**
	 * Returns a short textual summary of the model: positive class index,
	 * number of match states, Baum-Welch iterations, final and initial score.
	 *
	 * @return the summary, or a note that no model has been built yet
	 */
	public String toString() {
		if(profileHMM == null){
			return "No model built yet.";
		}

		StringBuilder summary = new StringBuilder();
		summary.append("\nOne class classifier ProfileHMM:\n");
		summary.append("  model for class:    \t").append(positiveClassIndex);
		summary.append("\n  match states:    \t").append(profileHMM.getNumberMatchStates());
		summary.append("\n  iterations in BW:\t").append(numIterationPerClass);
		summary.append("\n  final score:    \t").append(logLikelihoodOfPHMM);
		summary.append("\n  initial score:\t").append(initiallogLikelihoodOfPHMM).append("\n");
		return summary.toString();
	}

	/**
	 * Command line entry point: evaluates a new classifier with the given
	 * options and prints the evaluation result. Any exception is reported
	 * via its stack trace.
	 *
	 * @param argv the options
	 */
	public static void main(String [] argv) {
		try {
			IterativeProfileHMMClassifierSingleHMM classifier = new IterativeProfileHMMClassifierSingleHMM();
			String report = Evaluation.evaluateModel(classifier, argv);
			System.out.println(report);
		} catch (Exception e) {
			e.printStackTrace();
		}
	}


	/**
	 * Returns the revision string of this class.
	 *
	 * @return the fixed revision "1.0"
	 */
	public String getRevision() {
		final String revision = "1.0";
		return revision;
	}


	/**
	 * Prepares the classifier for iterative training: checks the data,
	 * selects the instances of the positive class, extracts their (possibly
	 * truncated) sequences and creates the initial profile HMM according to
	 * the configured background distribution.
	 *
	 * Notes on behaviour:
	 * <ul>
	 * <li>Instances with a missing class are removed from the dataset that is passed in.</li>
	 * <li>The filter options are set before the input format, the order Weka filters expect.</li>
	 * <li>Convergence flags of an earlier initialisation are discarded, so a re-initialised
	 *     classifier cannot be blocked by stale entries.</li>
	 * <li>If no training instance remains, no model is created, the training sequences are an
	 *     empty array and the classifier is marked as converged.</li>
	 * </ul>
	 *
	 * @param data the training data (one string attribute plus the class attribute)
	 * @throws Exception if the data cannot be handled or the model cannot be created
	 */
	public void initClassifier(Instances data) throws Exception {
		//  can classifier handle the data?
		getCapabilities().testWithFail(data);
		data.deleteWithMissingClass();

		if(positiveClassIndex >= 0){
			// keep the complete data and continue with the positive class only
			preFiltered = new Instances(data);
			filtered = selectByPositiveLabel(data, data.classIndex(), true);
			data = new Instances(filtered);
			data.setRelationName(preFiltered.relationName());
		}

		if(data.numAttributes() > 2)
			throw new Exception ("Dataset has to consist of exactly one string attribute and a class attribute");
		if(data.classIndex() == 1)
			sequenceIndex = 0;
		else
			sequenceIndex = 1;

		iteration = 0;

		int numInstances = data.numInstances();
		trainingClassDistribution = new double[data.numClasses()];
		for(int i = 0; i < numInstances; i++){
			trainingClassDistribution[(int)data.instance(i).classValue()]++;
		}
		// Utils.normalize rejects an all-zero array, so only normalise when something was counted
		if(numInstances > 0){
			Utils.normalize(trainingClassDistribution);
		}

		// start every initialisation with a clean convergence record
		converged.clear();

		if(numInstances == 0){
			profileHMM = null;
			allTrainingSequences = new String[0];
			converged.add(0,true);
			numIterationPerClass = -1;
			return;
		}

		allTrainingSequences = collectTruncatedSequences(data);

		// model length: restricted sequence length, else restricted match columns, else average length
		int matchColumns;
		if(this.getRestrictSequenceLength() != -1){
			matchColumns = this.getRestrictSequenceLength();
		}
		else if(this.getRestrictMatchColumns() != -1){
			matchColumns = this.getRestrictMatchColumns();
		}
		else{
			matchColumns = determineColumns(data);
		}

		if(backDist == UniformBackgroundDist){
			profileHMM = new ProfileHMM(matchColumns,this.getAlphabet(),useNullModel,!transitionsEmissionsNotInLog);
		}
		else{
			String[] sequencesForBackgroundDist = null;
			if(backDist == Pos4BackgroundDist){
				sequencesForBackgroundDist = allTrainingSequences;
			}
			else{
				if(backDist == All4BackgroundDist){
					sequencesForBackgroundDist = collectTruncatedSequences(preFiltered);
				}
				if(backDist == Neg4BackgroundDist){
					if(positiveClassIndex >= 0){
						// everything except the positive class
						Instances negOnly = new Instances(selectByPositiveLabel(preFiltered, data.classIndex(), false));
						negOnly.setRelationName(preFiltered.relationName());
						sequencesForBackgroundDist = collectTruncatedSequences(negOnly);
					}
				}
			}
			profileHMM = new ProfileHMM(matchColumns,this.getAlphabet(),useNullModel,true,sequencesForBackgroundDist,!transitionsEmissionsNotInLog);
		}

		converged.add(0,false);
		logLikelihoodOfPHMM = Double.NEGATIVE_INFINITY;
		initiallogLikelihoodOfPHMM = Double.NEGATIVE_INFINITY;
		oldLogLikelihoodOfPHMM = Double.NEGATIVE_INFINITY;
		numIterationPerClass = -1;
	}

	/**
	 * Applies a RemoveWithValues filter on the class attribute: keeps only the instances of
	 * the positive class (keepPositive = true, option -V) or removes exactly those (false).
	 */
	private Instances selectByPositiveLabel(Instances source, int classIndex, boolean keepPositive) throws Exception {
		String label = (positiveClassIndex+1)+"";
		String column = (classIndex+1)+"";
		String[] options = keepPositive
			? new String[]{"-L", label, "-C", column, "-V"}
			: new String[]{"-L", label, "-C", column};
		RemoveWithValues filter = new RemoveWithValues();
		// options first: the input format is validated against the configured attribute
		filter.setOptions(options);
		filter.setInputFormat(source);
		return Filter.useFilter(source, filter);
	}

	/**
	 * Reads the sequence of every instance, cut to the restricted sequence length if one is set.
	 */
	private String[] collectTruncatedSequences(Instances source){
		int lengthLimit = this.getRestrictSequenceLength();
		String[] sequences = new String[source.numInstances()];
		for(int k = 0; k < sequences.length; k++){
			Instance current = source.instance(k);
			String sequence = current.stringValue(current.attribute(sequenceIndex));
			if(lengthLimit != -1 && sequence.length() > lengthLimit){
				sequence = sequence.substring(0, lengthLimit);
			}
			sequences[k] = sequence;
		}
		return sequences;
	}


	/**
	 * Performs one training round. Unless the model has already converged,
	 * the Baum-Welch learner is run on the training sequences, the learnt
	 * model replaces the current one and convergence is recorded when the
	 * log-likelihood changed by no more than the threshold. Afterwards the
	 * logistic calibration model is rebuilt.
	 *
	 * Missing training sequences (none available, or released via
	 * resetAllTrainingSequences()) skip the learning step instead of
	 * failing, and calibration is skipped when there is no model to score with.
	 *
	 * @param iteration the number of this round; the first round is 1
	 * @throws Exception if learning or calibration fails
	 */
	public void next(int iteration) throws Exception {

		this.iteration = iteration;

		boolean hasSequences = allTrainingSequences != null && allTrainingSequences.length != 0;
		if(hasSequences && converged.get(0)== false){

			BaumWelchLearner bwl = new BaumWelchLearner(allTrainingSequences, logLikelihoodThreshold, profileHMM, learnInsertEmissions, isMemorySensitive());
			if(getBaumWelchOption() == 1){
				bwl.setAverageLikelihoodOverSequenceNumber(true);
			}
			if(getBaumWelchOption() == 2){
				bwl.setAverageLikelihoodOverResidueNumber(true);
			}

			ProfileHMM learnt = bwl.learnFast();

			double logLikelihood = bwl.getLogLikelihood();
			logLikelihoodOfPHMM = logLikelihood;
			if(iteration == 1){
				initiallogLikelihoodOfPHMM = bwl.getInitialLogLikelihood();
			}

			profileHMM = learnt;
			numIterationPerClass = iteration;
			if(Math.abs(oldLogLikelihoodOfPHMM - logLikelihood) <= logLikelihoodThreshold){
				converged.set(0, true);
			}
			oldLogLikelihoodOfPHMM = logLikelihood;
		}

		// calibration scores sequences with the model, so it needs one
		if(profileHMM != null){
			calibrateClassificationClassifier(filtered);
		}
	}

	/**
	 * Returns a deep copy of this classifier, created by serialising and
	 * deserialising it. (The previous implementation called itself and
	 * therefore never returned.)
	 *
	 * @return an independent copy of this classifier
	 * @throws CloneNotSupportedException if the classifier cannot be copied
	 */
	public Object clone() throws CloneNotSupportedException{
		Object copy;
		try{
			copy = new weka.core.SerializedObject(this).getObject();
		}
		catch(Exception e){
			CloneNotSupportedException failure = new CloneNotSupportedException("Could not copy the classifier: " + e);
			failure.initCause(e);
			throw failure;
		}
		if(copy == null){
			throw new CloneNotSupportedException("Could not copy the classifier");
		}
		return copy;
	}

	/**
	 * Tells whether training has finished.
	 *
	 * @return true if at least one convergence flag exists and all flags are set
	 */
	public boolean fullyConverged(){
		if(converged.isEmpty()){
			return false;
		}
		int position = 0;
		while(position < converged.size()){
			if(converged.get(position) == false){
				return false;
			}
			position++;
		}
		return true;
	}


	/**
	 * Returns the number of the most recent training round.
	 *
	 * @return the iteration last passed to next(), or 0 after initialisation
	 */
	public int getIteration() {
		return this.iteration;
	}


	/**
	 * Drops the reference to the stored training sequences.
	 */
	public void resetAllTrainingSequences() {
		allTrainingSequences = null;
	}


	/**
	 * Returns the index of the class the model is built for.
	 *
	 * @return the positive class index
	 */
	public int getClassIndexToKeep() {
		return this.positiveClassIndex;
	}


	/**
	 * Sets the index of the class the model is built for.
	 *
	 * @param classIndexToKeep the new positive class index
	 */
	public void setClassIndexToKeep(int classIndexToKeep) {
		positiveClassIndex = classIndexToKeep;
	}


	/**
	 * Returns the filtered training data produced in initClassifier().
	 *
	 * @return the filtered instances, or null if none have been stored
	 */
	public Instances getOneClassClassificationArff(){
		return this.filtered;
	}



	/**
	 * Returns the profile HMM of this classifier.
	 *
	 * @return the model, or null if none exists
	 */
	public ProfileHMM getProfileHMM() {
		return this.profileHMM;
	}


	/**
	 * Replaces the profile HMM of this classifier.
	 *
	 * @param profileHMM the new model
	 */
	public void setProfileHMM(ProfileHMM profileHMM) {
		// the parameter shadows the field, hence the explicit qualifier
		this.profileHMM = profileHMM;
	}

	/**
	 * Returns the profile HMM of this classifier. There is only one model,
	 * so the index is ignored.
	 *
	 * @param i ignored
	 * @return the model, or null if none exists
	 */
	public ProfileHMM getProfileHMM(int i) {
		return this.profileHMM;
	}


	/**
	 * Replaces the profile HMM of this classifier. There is only one model,
	 * so the index is ignored.
	 *
	 * @param profileHMM the new model
	 * @param i ignored
	 */
	public void setProfileHMM(ProfileHMM profileHMM, int i) {
		// the parameter shadows the field, hence the explicit qualifier
		this.profileHMM = profileHMM;
	}


	/**
	 * Returns the serialisation version id of this class.
	 *
	 * @return the serialVersionUID
	 */
	public static long getSerialVersionUID() {
		return IterativeProfileHMMClassifierSingleHMM.serialVersionUID;
	}

	/**
	 * Computes the sufficient emission statistics of each sequence with
	 * respect to the profile HMM.
	 *
	 * @param includeInserts passed on to SufficientEmissionStatistics
	 * @param sequences the sequences to process
	 * @return one row of statistics per sequence, in input order
	 */
	public double[][] getSufficientStats(boolean includeInserts, String[] sequences) throws IllegalSymbolException, InvalidStructureException, InvalidViterbiPathException, NumericStabilityException{
		final int sequenceCount = sequences.length;
		final int initialWidth = profileHMM.getNumberMatchStates()*(profileHMM.getAlphabet()).alphabetSize();
		double[][] allStats = new double[sequenceCount][initialWidth];

		int row = 0;
		while(row < sequenceCount){
			// a fresh statistics object per sequence; its result replaces the pre-allocated row
			SufficientEmissionStatistics stats = new SufficientEmissionStatistics(profileHMM,includeInserts);
			allStats[row] = stats.getStats(sequences[row]);
			row++;
		}
		return allStats;
	}

}
