/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    MarginCurve.java
 *    Copyright (C) 2002-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.evaluation;

import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;

import java.util.ArrayList;

/**
 * Generates points illustrating the prediction margin. The margin is defined as
 * the difference between the probability predicted for the actual class and the
 * highest probability predicted for the other classes. One hypothesis as to the
 * good performance of boosting algorithms is that they increase the margins on
 * the training data and this gives better performance on test data.
 * 
 * @author Len Trigg (len@reeltwo.com)
 * @version $Revision: 15751 $
 */
public class MarginCurve implements RevisionHandler {

  /**
   * Calculates the cumulative margin distribution for the set of predictions,
   * returning the result as a set of Instances. The predictions are visited in
   * the order given by {@link Utils#sort(double[])} on their margins, and one
   * data point is produced per prediction. These points are preceded by an
   * initial point at margin -1 with zero weight. The structure of these
   * Instances is as follows:
   * <ul>
   * <li><b>Margin</b> contains the margin value (which should be plotted as an
   * x-coordinate)
   * <li><b>Current</b> contains the weight of the prediction at this point
   * (plot as y axis)
   * <li><b>Cumulative</b> contains the total weight of all predictions visited
   * up to and including this point (plot as y axis)
   * </ul>
   * Predictions with tied margins each produce their own data point.
   *
   * @param predictions the predictions to analyse; every element is cast to
   *          {@link NominalPrediction}
   * @return datapoints as a set of instances, null if no predictions have been
   *         made.
   */
  public Instances getCurve(ArrayList<Prediction> predictions) {

    if (predictions == null || predictions.isEmpty()) {
      return null;
    }

    Instances insts = makeHeader();
    double[] margins = getMargins(predictions);
    int[] sorted = Utils.sort(margins);

    // Accumulate in a double: instance weights are doubles, and "int += double"
    // performs an implicit narrowing cast that would truncate fractional weights.
    double totalWeight = 0;

    // Start the curve at margin -1 with zero weight. Margins are differences of
    // two probabilities, so -1 is their lower bound.
    insts.add(makeInstance(-1, 0, 0));
    for (int index : sorted) {
      double weight = ((NominalPrediction) predictions.get(index)).weight();
      totalWeight += weight;
      insts.add(makeInstance(margins[index], weight, totalWeight));
    }
    return insts;
  }

  /**
   * Pulls all the margin values out of a list of NominalPredictions.
   *
   * @param predictions a list whose elements are NominalPredictions
   * @return an array of margin values, index-aligned with predictions.
   */
  private double[] getMargins(ArrayList<Prediction> predictions) {

    double[] margins = new double[predictions.size()];
    for (int i = 0; i < margins.length; i++) {
      NominalPrediction pred = (NominalPrediction) predictions.get(i);
      margins[i] = pred.margin();
    }
    return margins;
  }

  /**
   * Creates an Instances object with the attributes we will be calculating:
   * the numeric attributes "Margin", "Current" and "Cumulative", in that order.
   *
   * @return the Instances structure.
   */
  private Instances makeHeader() {

    ArrayList<Attribute> fv = new ArrayList<Attribute>();
    fv.add(new Attribute("Margin"));
    fv.add(new Attribute("Current"));
    fv.add(new Attribute("Cumulative"));
    return new Instances("MarginCurve", fv, 100);
  }

  /**
   * Creates an Instance object (with instance weight 1.0) holding one data
   * point of the curve.
   *
   * @param margin the margin for this data point.
   * @param current the weight of the prediction at this data point.
   * @param cumulative the total weight of all predictions visited up to and
   *          including this data point.
   * @return the Instance object.
   */
  private Instance makeInstance(double margin, double current, double cumulative) {

    double[] vals = { margin, current, cumulative };
    return new DenseInstance(1.0, vals);
  }

  /**
   * Returns the revision string.
   * 
   * @return the revision
   */
  @Override
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 15751 $");
  }
}
