/*
 * WekaClassifierErrors.java
 * Copyright (C) 2009-2011 University of Waikato, Hamilton, New Zealand
 */

package adams.flow.sink;

import java.awt.BorderLayout;
import java.awt.Color;
import java.util.ArrayList;

import javax.swing.JComponent;

import org.math.plot.Plot3DPanel;
import org.math.plot.plots.BarPlot;
import org.math.plot.plots.Plot;
import org.math.plot.utils.Array;

import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.Prediction;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.gui.visualize.Plot2D;
import weka.gui.visualize.PlotData2D;
import weka.gui.visualize.VisualizePanel;
import adams.data.weka.predictions.AbstractErrorScaler;
import adams.data.weka.predictions.AutoScaler;
import adams.flow.core.Token;
import adams.gui.core.BasePanel;

/**
 <!-- globalinfo-start -->
 * Actor for displaying classifier errors.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- flow-summary-start -->
 * Input/output:<br/>
 * - accepts:<br/>
 * &nbsp;&nbsp;&nbsp;weka.classifiers.Evaluation<br/>
 * <p/>
 <!-- flow-summary-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre>-D &lt;int&gt; (property: debugLevel)
 * &nbsp;&nbsp;&nbsp;The greater the number the more additional info the scheme may output to
 * &nbsp;&nbsp;&nbsp;the console (0 = off).
 * &nbsp;&nbsp;&nbsp;default: 0
 * &nbsp;&nbsp;&nbsp;minimum: 0
 * </pre>
 *
 * <pre>-name &lt;java.lang.String&gt; (property: name)
 * &nbsp;&nbsp;&nbsp;The name of the actor.
 * &nbsp;&nbsp;&nbsp;default: ClassifierErrors
 * </pre>
 *
 * <pre>-annotation &lt;adams.core.base.BaseText&gt; (property: annotations)
 * &nbsp;&nbsp;&nbsp;The annotations to attach to this actor.
 * &nbsp;&nbsp;&nbsp;default:
 * </pre>
 *
 * <pre>-skip (property: skip)
 * &nbsp;&nbsp;&nbsp;If set to true, transformation is skipped and the input token is just forwarded
 * &nbsp;&nbsp;&nbsp;as it is.
 * </pre>
 *
 * <pre>-width &lt;int&gt; (property: width)
 * &nbsp;&nbsp;&nbsp;The width of the dialog.
 * &nbsp;&nbsp;&nbsp;default: 640
 * &nbsp;&nbsp;&nbsp;minimum: 1
 * </pre>
 *
 * <pre>-height &lt;int&gt; (property: height)
 * &nbsp;&nbsp;&nbsp;The height of the dialog.
 * &nbsp;&nbsp;&nbsp;default: 480
 * &nbsp;&nbsp;&nbsp;minimum: 1
 * </pre>
 *
 * <pre>-x &lt;int&gt; (property: x)
 * &nbsp;&nbsp;&nbsp;The X position of the dialog (&gt;=0: absolute, -1: left, -2: center, -3: right
 * &nbsp;&nbsp;&nbsp;).
 * &nbsp;&nbsp;&nbsp;default: -1
 * &nbsp;&nbsp;&nbsp;minimum: -3
 * </pre>
 *
 * <pre>-y &lt;int&gt; (property: y)
 * &nbsp;&nbsp;&nbsp;The Y position of the dialog (&gt;=0: absolute, -1: top, -2: center, -3: bottom
 * &nbsp;&nbsp;&nbsp;).
 * &nbsp;&nbsp;&nbsp;default: -1
 * &nbsp;&nbsp;&nbsp;minimum: -3
 * </pre>
 *
 * <pre>-writer &lt;adams.gui.print.JComponentWriter [options]&gt; (property: writer)
 * &nbsp;&nbsp;&nbsp;The writer to use for generating the graphics output.
 * &nbsp;&nbsp;&nbsp;default: adams.gui.print.NullWriter
 * </pre>
 *
 * <pre>-scaler &lt;adams.data.weka.predictions.AbstractErrorScaler [options]&gt; (property: errorScaler)
 * &nbsp;&nbsp;&nbsp;The scaler to use for scaling the errors.
 * &nbsp;&nbsp;&nbsp;default: adams.data.weka.predictions.AutoScaler
 * </pre>
 *
 * <pre>-plot-type &lt;TWO_DIMENSIONAL|THREE_DIMENSIONAL&gt; (property: plotType)
 * &nbsp;&nbsp;&nbsp;The type of plot to produce.
 * &nbsp;&nbsp;&nbsp;default: TWO_DIMENSIONAL
 * </pre>
 *
 <!-- options-end -->
 *
 * @author  fracpete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 3948 $
 */
public class WekaClassifierErrors
  extends AbstractGraphicalDisplay
  implements DisplayPanelProvider {

  /** for serialization. */
  private static final long serialVersionUID = 3247255046513744115L;

  /**
   * Helper class for generating visualization data (predicted vs actual,
   * plus per-point shapes and sizes) from an Evaluation object.
   * The data is generated lazily on first access and then cached.
   *
   * @author  fracpete (fracpete at waikato dot ac dot nz)
   * @version $Revision: 3948 $
   */
  public static class DataGenerator {

    /** the underlying Evaluation object. */
    protected Evaluation m_Evaluation;

    /** the underlying data (attribute 0: predicted, attribute 1: actual/class). */
    protected Instances m_PlotInstances;

    /** for storing the plot shapes (one Integer per prediction). */
    protected FastVector m_PlotShapes;

    /** for storing the plot sizes (one entry per prediction). */
    protected FastVector m_PlotSizes;

    /** the scaler scheme to use. */
    protected AbstractErrorScaler m_ErrorScaler;

    /** whether the data has already been processed. */
    protected boolean m_Processed;

    /**
     * Initializes the generator.
     *
     * @param eval	the Evaluation object to use
     * @param scaler	the scaler scheme to use for the errors
     */
    public DataGenerator(Evaluation eval, AbstractErrorScaler scaler) {
      super();

      m_Evaluation  = eval;
      m_ErrorScaler = scaler;
      m_Processed   = false;
    }

    /**
     * Processes the data if necessary: builds the dataset and scales the
     * raw sizes/errors with the error scaler. If the scaler cannot handle
     * the class attribute or scaling fails, the plot data is reset to an
     * empty dataset with no shapes/sizes.
     */
    protected void process() {
      Capabilities		cap;
      ArrayList<Integer>	scaled;

      if (m_Processed)
	return;

      m_Processed = true;

      createDataset(m_Evaluation);

      try {
	cap = m_ErrorScaler.getCapabilities();
	cap.testWithFail(m_PlotInstances.classAttribute(), true);
	scaled = m_ErrorScaler.scale(m_PlotSizes);
	m_PlotSizes = new FastVector();
	m_PlotSizes.addAll(scaled);
      }
      catch (Exception e) {
	// best effort: report the problem and fall back to an empty plot
	e.printStackTrace();
	m_PlotInstances = new Instances(m_PlotInstances, 0);
	m_PlotSizes     = new FastVector();
	m_PlotShapes    = new FastVector();
      }
    }

    /**
     * Returns the underlying Evaluation object.
     *
     * @return		the Evaluation object
     */
    public Evaluation getEvaluation() {
      return m_Evaluation;
    }

    /**
     * Returns the scaling scheme.
     *
     * @return		the scaler
     */
    public AbstractErrorScaler getErrorScaler() {
      return m_ErrorScaler;
    }

    /**
     * Returns the generated dataset that is plotted.
     *
     * @return		the dataset
     */
    public Instances getPlotInstances() {
      process();

      return m_PlotInstances;
    }

    /**
     * Generates a dataset containing the predicted (attribute 0) vs actual
     * (attribute 1, class) values, together with one shape and one raw size
     * entry per prediction. For nominal classes the shape marks errors and
     * the size is the default; for numeric classes the size holds the error
     * (predicted - actual), or null if either value is missing.
     *
     * @param eval	for obtaining the dataset information and predictions
     */
    protected void createDataset(Evaluation eval) {
      ArrayList<Attribute>	atts;
      Attribute			classAtt;
      FastVector		preds;
      int			i;
      double[]			values;
      Instance			inst;
      Prediction		pred;

      m_PlotShapes = new FastVector();
      m_PlotSizes  = new FastVector();
      classAtt     = eval.getHeader().classAttribute();
      preds        = eval.predictions();

      // generate header
      atts     = new ArrayList<Attribute>();
      atts.add(classAtt.copy("predicted" + classAtt.name()));
      atts.add((Attribute) classAtt.copy());
      m_PlotInstances = new Instances(
	  eval.getHeader().relationName() + "-classifier_errors", atts, preds.size());
      m_PlotInstances.setClassIndex(m_PlotInstances.numAttributes() - 1);

      // add data
      for (i = 0; i < preds.size(); i++) {
        pred   = (Prediction) preds.elementAt(i);
        values = new double[]{pred.predicted(), pred.actual()};
        inst   = new DenseInstance(pred.weight(), values);
        m_PlotInstances.add(inst);

        if (classAtt.isNominal()) {
          if (weka.core.Utils.isMissingValue(pred.actual()) || weka.core.Utils.isMissingValue(pred.predicted())) {
            m_PlotShapes.addElement(Integer.valueOf(Plot2D.MISSING_SHAPE));
          }
          else if (pred.predicted() != pred.actual()) {
            // set to default error point shape
            m_PlotShapes.addElement(Integer.valueOf(Plot2D.ERROR_SHAPE));
          }
          else {
            // otherwise set to constant (automatically assigned) point shape
            m_PlotShapes.addElement(Integer.valueOf(Plot2D.CONST_AUTOMATIC_SHAPE));
          }
          m_PlotSizes.addElement(Integer.valueOf(Plot2D.DEFAULT_SHAPE_SIZE));
        }
        else {
          // store the error (to be converted to a point size later by the scaler)
          Double errd = null;
          if (!weka.core.Utils.isMissingValue(pred.actual()) && !weka.core.Utils.isMissingValue(pred.predicted())) {
            errd = Double.valueOf(pred.predicted() - pred.actual());
            m_PlotShapes.addElement(Integer.valueOf(Plot2D.CONST_AUTOMATIC_SHAPE));
          }
          else {
            // missing shape if actual class not present or prediction is missing
            m_PlotShapes.addElement(Integer.valueOf(Plot2D.MISSING_SHAPE));
          }
          m_PlotSizes.addElement(errd);
        }
      }
    }

    /**
     * Assembles and returns the plot. The relation name of the dataset gets
     * added automatically.
     *
     * @return			the plot
     * @throws Exception	if plot generation fails
     */
    public PlotData2D getPlotData() throws Exception {
      PlotData2D 	result;

      process();

      result = new PlotData2D(m_PlotInstances);
      result.setShapeSize(m_PlotSizes);
      result.setShapeType(m_PlotShapes);
      result.setPlotName("Classifier Errors" + " (" + m_PlotInstances.relationName() + ")");
      result.addInstanceNumberAttribute();

      return result;
    }

    /**
     * Assembles and returns the plot. The relation name of the dataset gets
     * added automatically. x=Actual, y=Predicted, z=scaled plot size (derived
     * from the error for numeric classes).
     *
     * @return			the plot
     * @throws Exception	if plot generation fails
     */
    public Plot getJMathPlot() throws Exception {
      BarPlot	result;
      double[]	x;
      double[]	y;
      double[]	z;
      int	i;

      process();

      x = m_PlotInstances.attributeToDoubleArray(1);
      y = m_PlotInstances.attributeToDoubleArray(0);
      z = new double[m_PlotSizes.size()];
      for (i = 0; i < m_PlotSizes.size(); i++)
	z[i] = ((Number) m_PlotSizes.get(i)).doubleValue();

      result = new BarPlot(
	  "Classifier Errors" + " (" + m_PlotInstances.relationName() + ")",
	  Color.RED,
	  Array.mergeColumns(x, y, z));

      return result;
    }
  }

  /**
   * The type of plot to produce.
   *
   * @author  fracpete (fracpete at waikato dot ac dot nz)
   * @version $Revision: 3948 $
   */
  public enum PlotType {
    /** using the Weka visualization. */
    TWO_DIMENSIONAL,
    /** using the jmathplot visualization. */
    THREE_DIMENSIONAL
  }

  /** the Weka plot panel (only created for TWO_DIMENSIONAL). */
  protected VisualizePanel m_VisualizePanel;

  /** the JMathPlot panel (only created for THREE_DIMENSIONAL). */
  protected Plot3DPanel m_JMathPlotPanel;

  /** The scheme for scaling the errors. */
  protected AbstractErrorScaler m_ErrorScaler;

  /** the type of plot to produce. */
  protected PlotType m_PlotType;

  /**
   * Returns a string describing the object.
   *
   * @return 			a description suitable for displaying in the gui
   */
  public String globalInfo() {
    return "Actor for displaying classifier errors.";
  }

  /**
   * Adds options to the internal list of options.
   */
  public void defineOptions() {
    super.defineOptions();

    m_OptionManager.add(
	    "scaler", "errorScaler",
	    new AutoScaler());

    m_OptionManager.add(
	    "plot-type", "plotType",
	    PlotType.TWO_DIMENSIONAL);
  }

  /**
   * Returns the default width for the dialog.
   *
   * @return		the default width
   */
  protected int getDefaultWidth() {
    return 640;
  }

  /**
   * Returns the default height for the dialog.
   *
   * @return		the default height
   */
  protected int getDefaultHeight() {
    return 480;
  }

  /**
   * Sets the scheme for scaling the errors.
   *
   * @param value 	the scheme
   */
  public void setErrorScaler(AbstractErrorScaler value) {
    m_ErrorScaler = value;
    reset();
  }

  /**
   * Returns the scheme to use for scaling the errors.
   *
   * @return 		the scheme
   */
  public AbstractErrorScaler getErrorScaler() {
    return m_ErrorScaler;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return 		tip text for this property suitable for
   * 			displaying in the GUI or for listing the options.
   */
  public String errorScalerTipText() {
    return "The scaler to use for scaling the errors.";
  }

  /**
   * Sets the type of plot to produce.
   *
   * @param value 	the type
   */
  public void setPlotType(PlotType value) {
    m_PlotType = value;
    reset();
  }

  /**
   * Returns the type of plot to produce.
   *
   * @return 		the type
   */
  public PlotType getPlotType() {
    return m_PlotType;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return 		tip text for this property suitable for
   * 			displaying in the GUI or for listing the options.
   */
  public String plotTypeTipText() {
    return "The type of plot to produce.";
  }

  /**
   * Creates a 3D plot panel with its axes labelled Actual, Predicted and Error.
   *
   * @return		the configured panel
   */
  private static Plot3DPanel createPlot3DPanel() {
    Plot3DPanel	result;

    result = new Plot3DPanel();
    result.setAxisLabel(0, "Actual");
    result.setAxisLabel(1, "Predicted");
    result.setAxisLabel(2, "Error");

    return result;
  }

  /**
   * Adds the Weka 2D plot of the generator's data to the panel and colours
   * the points by class. If the panel still shows its initial axes (X=0,
   * Y=1), the axes are switched to actual class (X) vs predicted class (Y).
   *
   * @param panel		the panel to add the plot to
   * @param generator		the generator supplying the plot data
   * @throws Exception		if generating or adding the plot fails
   */
  private static void addWekaPlot(VisualizePanel panel, DataGenerator generator) throws Exception {
    PlotData2D	plotdata;
    int		classIndex;

    plotdata = generator.getPlotData();
    plotdata.setPlotName(generator.getPlotInstances().relationName());
    panel.addPlot(plotdata);
    panel.setColourIndex(plotdata.getPlotInstances().classIndex());
    if ((panel.getXIndex() == 0) && (panel.getYIndex() == 1)) {
      try {
	classIndex = panel.getInstances().classIndex();
	panel.setXIndex(classIndex);      // class
	panel.setYIndex(classIndex - 1);  // predicted class
      }
      catch (Exception ignored) {
	// best effort: keep the panel's default axes if they cannot be switched
      }
    }
  }

  /**
   * Adds the 3D bar plot of the generator's data to the panel, plus a
   * "Diagonal" reference line spanning the combined range of actual and
   * predicted values (for nominal classes: the label index range).
   *
   * @param panel		the panel to add the plots to
   * @param generator		the generator supplying the plot data
   * @throws Exception		if generating or adding the plot fails
   */
  private static void addJMathPlot(Plot3DPanel panel, DataGenerator generator) throws Exception {
    Instances	data;
    double	min;
    double	max;

    panel.addPlot(generator.getJMathPlot());

    data = generator.getPlotInstances();
    if (data.attribute(1).isNumeric()) {
      // the diagonal must cover both the actual (1) and predicted (0) ranges
      min = Math.min(data.attributeStats(1).numericStats.min, data.attributeStats(0).numericStats.min);
      max = Math.max(data.attributeStats(1).numericStats.max, data.attributeStats(0).numericStats.max);
    }
    else {
      min = 0;
      max = data.attribute(1).numValues() - 1;
    }
    panel.addLinePlot(
	"Diagonal",
	Color.DARK_GRAY,
	new double[]{min, max},
	new double[]{min, max},
	new double[]{0.0, 0.0});
  }

  /**
   * Clears the content of whichever plot panel is currently in use.
   */
  public void clearPanel() {
    if (m_VisualizePanel != null)
      m_VisualizePanel.removeAllPlots();
    if (m_JMathPlotPanel != null)
      m_JMathPlotPanel.removeAllPlots();
  }

  /**
   * Creates the panel to display in the dialog.
   *
   * @return		the panel
   */
  protected BasePanel newPanel() {
    BasePanel	result;

    result = new BasePanel(new BorderLayout());
    switch (m_PlotType) {
      case TWO_DIMENSIONAL:
	m_VisualizePanel = new VisualizePanel();
	result.add(m_VisualizePanel, BorderLayout.CENTER);
	break;

      case THREE_DIMENSIONAL:
	m_JMathPlotPanel = createPlot3DPanel();
	result.add(m_JMathPlotPanel, BorderLayout.CENTER);
	break;

      default:
	throw new IllegalStateException("Unhandled plot type: " + m_PlotType);
    }

    return result;
  }

  /**
   * Returns the class that the consumer accepts.
   *
   * @return		<!-- flow-accepts-start -->weka.classifiers.Evaluation.class<!-- flow-accepts-end -->
   */
  public Class[] accepts() {
    return new Class[]{Evaluation.class};
  }

  /**
   * Displays the token (the panel and dialog have already been created at
   * this stage). Evaluation objects without predictions are reported and
   * skipped.
   *
   * @param token	the token to display
   */
  protected void display(Token token) {
    DataGenerator	generator;
    Evaluation		eval;

    try {
      eval = (Evaluation) token.getPayload();
      if (eval.predictions() == null) {
	getSystemErr().println("No predictions available from Evaluation object!");
	return;
      }
      generator = new DataGenerator(eval, m_ErrorScaler);
      switch (m_PlotType) {
	case TWO_DIMENSIONAL:
	  addWekaPlot(m_VisualizePanel, generator);
	  break;

	case THREE_DIMENSIONAL:
	  addJMathPlot(m_JMathPlotPanel, generator);
	  break;

	default:
	  throw new IllegalStateException("Unhandled plot type: " + m_PlotType);
      }
    }
    catch (Exception e) {
      getSystemErr().printStackTrace(e);
    }
  }

  /**
   * Removes all graphical components.
   */
  protected void cleanUpGUI() {
    if (m_VisualizePanel != null) {
      m_VisualizePanel.removeAllPlots();
      m_VisualizePanel = null;
    }
    if (m_JMathPlotPanel != null) {
      m_JMathPlotPanel.removeAllPlots();
      m_JMathPlotPanel = null;
    }

    super.cleanUpGUI();
  }

  /**
   * Creates a new panel for the token.
   *
   * @param token	the token to display in a new panel
   * @return		the generated panel
   */
  public AbstractDisplayPanel createDisplayPanel(Token token) {
    AbstractDisplayPanel	result;
    String			name;

    name = "Classifier errors (" + ((Evaluation) token.getPayload()).getHeader().relationName() + ")";

    result = new AbstractComponentDisplayPanel(name) {
      private static final long serialVersionUID = -7362768698548152899L;
      /** the Weka plot panel (only created for TWO_DIMENSIONAL). */
      protected VisualizePanel m_VisualizePanel;
      /** the JMathPlot panel (only created for THREE_DIMENSIONAL). */
      protected Plot3DPanel m_JMathPlotPanel;
      protected void initGUI() {
	super.initGUI();
	setLayout(new BorderLayout());
	switch (m_PlotType) {
	  case TWO_DIMENSIONAL:
	    m_VisualizePanel = new VisualizePanel();
	    add(m_VisualizePanel, BorderLayout.CENTER);
	    break;

	  case THREE_DIMENSIONAL:
	    m_JMathPlotPanel = WekaClassifierErrors.createPlot3DPanel();
	    add(m_JMathPlotPanel, BorderLayout.CENTER);
	    break;

	  default:
	    throw new IllegalStateException("Unhandled plot type: " + m_PlotType);
	}
      }
      public void display(Token token) {
	try {
	  Evaluation eval = (Evaluation) token.getPayload();
	  // same precondition as the actor's own display(); reported via the catch below
	  if (eval.predictions() == null)
	    throw new IllegalStateException("No predictions available from Evaluation object!");
	  DataGenerator generator = new DataGenerator(eval, m_ErrorScaler);
	  switch (m_PlotType) {
	    case TWO_DIMENSIONAL:
	      WekaClassifierErrors.addWekaPlot(m_VisualizePanel, generator);
	      break;

	    case THREE_DIMENSIONAL:
	      WekaClassifierErrors.addJMathPlot(m_JMathPlotPanel, generator);
	      break;

	    default:
	      throw new IllegalStateException("Unhandled plot type: " + m_PlotType);
	  }
	}
	catch (Exception e) {
	  getSystemErr().printStackTrace(e);
	}
      }
      public JComponent supplyComponent() {
	// return whichever panel was created for the current plot type
	if (m_VisualizePanel != null)
	  return m_VisualizePanel;
	return m_JMathPlotPanel;
      }
      public void cleanUp() {
	if (m_VisualizePanel != null)
	  m_VisualizePanel.removeAllPlots();
	if (m_JMathPlotPanel != null)
	  m_JMathPlotPanel.removeAllPlots();
      }
    };
    result.display(token);

    return result;
  }

  /**
   * Returns whether the created display panel requires a scroll pane or not.
   *
   * @return		true if the display panel requires a scroll pane
   */
  public boolean displayPanelRequiresScrollPane() {
    return true;
  }
}
