/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    FisherKernel.java (slight derivation of weka.classifiers.functions.supportVector.RBFKernel.java version 1.11)
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *    Copyright (C) 2005 J. Lindgren
 *    Copyright (C) 2010 Stefan Mutter
 *
 */

package weka.classifiers.functions.supportVector;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Vector;

import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.RemoveWithValues;

/**
 <!-- globalinfo-start -->
 * The Fisher kernel. K(x, y) = e^-((1/(2*sigma^2)) * <x-y, x-y>^2) as described in
 * Jaakkola et al. A Discriminative Framework for Detecting Remote Protein Homologies.
 * Journal of Computational Biology, 7 (1-2):95-114,2000 doi: 10.1089/10665270050081405
 *
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -G
 *  Use standard sigma (sigma = 7).
 *  (default: false, that means a data dependent sigma is calculated)</pre>
 *

 *
 <!-- options-end -->
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Shane Legg (shane@intelligenesis.net) (sparse vector code)
 * @author Stuart Inglis (stuart@reeltwo.com) (sparse vector code)
 * @author J. Lindgren (jtlindgr{at}cs.helsinki.fi) (RBF kernel)
 * @author Stefan Mutter (pHMM4weka@gmail.com) (Fisher Kernel extension)
 * @version $Revision: 6 $
 */
public class FisherKernel
extends CachedKernel {

  /** for serialization */
  static final long serialVersionUID = 5247117544316387852L;

  /** The precalculated dotproducts of &lt;inst_i,inst_i&gt; */
  protected double m_kernelPrecalc[];

  /** Gamma for the kernel. */
  protected double sigma = 7;

  /** flag to indicate if data dependent sigma is calculated*/
  protected boolean useStandardSigma = false;

  private Instances dataToFilter;

  /** calculates data dependent sigma at first sight of the data, is set to false afterwards*/
  private boolean first = true;


  /**
   * default constructor - does nothing.
   */
  public FisherKernel() {
    super();
  }

  /**
   * Constructor. Initializes m_kernelPrecalc[].
   *
   * @param data	the data to use
   * @param cacheSize	the size of the cache
   * @param gamma	the bandwidth
   * @throws Exception	if something goes wrong
   */
  public FisherKernel(Instances data, int cacheSize)
  throws Exception {

    super();
    //dataToFilter = data;
    setCacheSize(cacheSize);

    buildKernel(data);
  }

  private void calculateSigma() throws Exception {

    Instances preFiltered = new Instances(m_data);

    RemoveWithValues filter  = new RemoveWithValues();
    String[] options = new String [5];
    options[0] = "-L";
    options[1] = "1";
    options[2] = "-C";
    options[3] = (preFiltered.classIndex()+1)+"";
    options[4] = "-V";
    filter.setInputFormat(preFiltered);
    filter.setOptions(options);
    //      filter.setInvertSelection(true);
    //      filter.setAttributeIndex(data.classIndex()+"");
    //      System.out.println(filter.getAttributeIndex());
    //      filter.setNominalIndices(classIndexToKeep+"");
    //      System.out.println(filter.getNominalIndices());
    Instances filtered = Filter.useFilter(preFiltered, filter);
    Instances class1 = new Instances(filtered);
    class1.setRelationName(preFiltered.relationName()+"_class1");

    filter  = new RemoveWithValues();
    options = new String [4];
    options[0] = "-L";
    options[1] = "1";
    options[2] = "-C";
    options[3] = (preFiltered.classIndex()+1)+"";
    filter.setInputFormat(preFiltered);
    filter.setOptions(options);
    //      filter.setInvertSelection(true);
    //      filter.setAttributeIndex(data.classIndex()+"");
    //      System.out.println(filter.getAttributeIndex());
    //      filter.setNominalIndices(classIndexToKeep+"");
    //      System.out.println(filter.getNominalIndices());
    filtered = Filter.useFilter(preFiltered, filter);
    Instances class2 = new Instances(filtered);
    class1.setRelationName(preFiltered.relationName()+"_class2");

    double[] saveMinDist = new double[class1.numInstances()];


    //calculate sigma

    for(int i = 0; i < class1.numInstances(); i++){
      double[] findMinDist = new double[class2.numInstances()];
      Instance inst1 = class1.instance(i);
      for(int j = 0; j < class2.numInstances(); j++ ){
	Instance inst2 = class2.instance(j);
	findMinDist[j] = Math.sqrt((((-2.) * dotProd(inst1, inst2) )+ dotProd(inst1, inst1) + dotProd(inst2, inst2)));
      }
      saveMinDist[i] = findMinDist[Utils.minIndex(findMinDist)];
    }

    Arrays.sort(saveMinDist);

    if(saveMinDist.length % 2 == 1){
      sigma = saveMinDist[saveMinDist.length/2];
    }
    else{
      sigma = 0.5 *(saveMinDist[saveMinDist.length/2]+saveMinDist[(saveMinDist.length/2)-1]);
    }
    System.out.println("sigma: "+sigma);
  }


  /**
   * Returns a string describing the kernel
   *
   * @return a description suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String globalInfo() {
    return
    "The Fisher kernel. K(x, y) = e^-((1/(2*sigma^2)) * <x-y, x-y>^2)";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return 		an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector		result;
    Enumeration		en;

    result = new Vector();

    en = super.listOptions();
    while (en.hasMoreElements())
      result.addElement(en.nextElement());

    if(useStandardSigma){
      result.add("-G");
    }

    return result.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -G
   *  Use standard sigma (sigma = 7).
   *  (default: false, that means a data dependent sigma is calculated)</pre>
   *
   *
   <!-- options-end -->
   *
   * @param options 	the list of options as an array of strings
   * @throws Exception 	if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    useStandardSigma = Utils.getFlag('G', options);

    super.setOptions(options);
  }

  /**
   * Gets the current settings of the Kernel.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    int       i;
    Vector    result;
    String[]  options;

    result = new Vector();
    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    if(useStandardSigma){
      result.add("-G");
    }

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   *
   * @param id1   	the index of instance 1
   * @param id2		the index of instance 2
   * @param inst1	the instance 1 object
   * @return 		the dot product
   * @throws Exception 	if something goes wrong
   */
  protected double evaluate(int id1, int id2, Instance inst1)
  throws Exception {

    if (id1 == id2) {
      return 1.0;
    } else {
      double precalc1;
      if (id1 == -1)
	precalc1 = dotProd(inst1, inst1);
      else
	precalc1 = m_kernelPrecalc[id1];
      Instance inst2 = m_data.instance(id2);
      double result = Math.exp((1/(2*sigma*sigma))
	  * (2. * dotProd(inst1, inst2) - precalc1 - m_kernelPrecalc[id2]));
      //System.out.println(sigma);
      return result;
    }
  }





  /**
   * initializes variables etc.
   *
   * @param data	the data to use
   */
  protected void initVars(Instances data) {
    super.initVars(data);

    m_kernelPrecalc = new double[data.numInstances()];
  }

  /**
   * Returns the Capabilities of this kernel.
   *
   * @return            the capabilities of this object
   * @see               Capabilities
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enableAllClasses();
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * builds the kernel with the given data. Initializes the kernel cache.
   * The actual size of the cache in bytes is (64 * cacheSize).
   *
   * @param data	the data to base the kernel on
   * @throws Exception	if something goes wrong
   */
  public void buildKernel(Instances data) throws Exception {
    // does kernel handle the data?
    if (!getChecksTurnedOff())
      getCapabilities().testWithFail(data);

    initVars(data);
    if(first && !useStandardSigma){
      calculateSigma();
      first = false;
    }

    for (int i = 0; i < data.numInstances(); i++)
      m_kernelPrecalc[i] = dotProd(data.instance(i), data.instance(i));
  }

  /**
   * returns a string representation for the Kernel
   *
   * @return 		a string representaiton of the kernel
   */
  public String toString() {

    return "Fisher kernel: K(x,y) = e^-((1/(2*sigma*sigma)) * <x-y,x-y>^2) ";
  }

  /**
   * Returns the revision string.
   *
   * @return		the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 6 $");
  }
  /**
   * returns true if standard sigma is used, false otherwise
   *
   * @return a flag to indicate whether standard sigma is used
   */
  public boolean isUseStandardSigma() {
    return useStandardSigma;
  }

  /**
   * sets the use of a data independent standard sigma (default false)
   *
   * @param useStandardSigma	boolean flag to indicate the use of standard sigma
   */
  public void setUseStandardSigma(boolean useStandardSigma) {
    this.useStandardSigma = useStandardSigma;
  }
}

