/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.sink;

import adams.data.weka.predictions.AbstractErrorScaler;
import adams.data.weka.predictions.AutoScaler;
import adams.flow.core.Token;
import adams.flow.sink.AbstractComponentDisplayPanel;
import adams.flow.sink.AbstractDisplayPanel;
import adams.flow.sink.AbstractGraphicalDisplay;
import adams.flow.sink.DisplayPanelProvider;
import adams.gui.core.BasePanel;
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Component;
import java.awt.LayoutManager;
import java.util.ArrayList;
import javax.swing.JComponent;
import org.math.plot.Plot3DPanel;
import org.math.plot.plots.BarPlot;
import org.math.plot.plots.Plot;
import org.math.plot.utils.Array;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.Prediction;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.gui.visualize.PlotData2D;
import weka.gui.visualize.VisualizePanel;

/**
 * Actor for displaying classifier errors contained in a Weka {@link Evaluation}.
 * <p>
 * Depending on the configured {@link PlotType}, the errors are shown either in a
 * 2D Weka {@link VisualizePanel} (predicted vs actual, with error-scaled point
 * sizes) or in a 3D JMathPlot {@link Plot3DPanel} (actual/predicted/error bars).
 * The incoming evaluation must have predictions recorded.
 */
public class WekaClassifierErrors
extends AbstractGraphicalDisplay
implements DisplayPanelProvider {
    private static final long serialVersionUID = 3247255046513744115L;

    /** The 2D Weka plot panel; only created for {@link PlotType#TWO_DIMENSIONAL}. */
    protected VisualizePanel m_VisualizePanel;

    /** The 3D JMathPlot panel; only created for {@link PlotType#THREE_DIMENSIONAL}. */
    protected Plot3DPanel m_JMathPlotPanel;

    /** The scaler applied to the per-prediction errors before plotting. */
    protected AbstractErrorScaler m_ErrorScaler;

    /** Which kind of plot (2D or 3D) to produce. */
    protected PlotType m_PlotType;

    /**
     * Returns a short description of the actor.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Actor for displaying classifier errors.";
    }

    /**
     * Registers the "scaler" (default {@link AutoScaler}) and "plot-type"
     * (default {@link PlotType#TWO_DIMENSIONAL}) options.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("scaler", "errorScaler", new AutoScaler());
        this.m_OptionManager.add("plot-type", "plotType", PlotType.TWO_DIMENSIONAL);
    }

    /**
     * @return the default width of the display (640)
     */
    protected int getDefaultWidth() {
        return 640;
    }

    /**
     * @return the default height of the display (480)
     */
    protected int getDefaultHeight() {
        return 480;
    }

    /**
     * Sets the scaler to use for the errors and resets the actor.
     *
     * @param value the scaler
     */
    public void setErrorScaler(AbstractErrorScaler value) {
        this.m_ErrorScaler = value;
        this.reset();
    }

    /**
     * @return the scaler used for the errors
     */
    public AbstractErrorScaler getErrorScaler() {
        return this.m_ErrorScaler;
    }

    /**
     * @return the tip text for the "errorScaler" property
     */
    public String errorScalerTipText() {
        return "The scaler to use for scaling the errors.";
    }

    /**
     * Sets the type of plot to produce and resets the actor.
     *
     * @param value the plot type
     */
    public void setPlotType(PlotType value) {
        this.m_PlotType = value;
        this.reset();
    }

    /**
     * @return the type of plot to produce
     */
    public PlotType getPlotType() {
        return this.m_PlotType;
    }

    /**
     * @return the tip text for the "plotType" property
     */
    public String plotTypeTipText() {
        return "The type of plot to produce.";
    }

    /**
     * Removes all plots from whichever plot panel is currently in use.
     * Only one of the two panels exists for a given plot type, so both are
     * null-checked (previously this threw an NPE in 3D mode).
     */
    public void clearPanel() {
        if (this.m_VisualizePanel != null) {
            this.m_VisualizePanel.removeAllPlots();
        }
        if (this.m_JMathPlotPanel != null) {
            this.m_JMathPlotPanel.removeAllPlots();
        }
    }

    /**
     * Creates the actor's panel, containing the plot panel matching the
     * current plot type.
     *
     * @return the new panel
     * @throws IllegalStateException if the plot type is not handled
     */
    protected BasePanel newPanel() {
        BasePanel result = new BasePanel(new BorderLayout());
        switch (this.m_PlotType) {
            case TWO_DIMENSIONAL:
                this.m_VisualizePanel = new VisualizePanel();
                result.add(this.m_VisualizePanel, BorderLayout.CENTER);
                break;
            case THREE_DIMENSIONAL:
                this.m_JMathPlotPanel = createJMathPlotPanel();
                result.add(this.m_JMathPlotPanel, BorderLayout.CENTER);
                break;
            default:
                throw new IllegalStateException("Unhandled plot type: " + this.m_PlotType);
        }
        return result;
    }

    /**
     * @return the accepted input types: {@link Evaluation}
     */
    public Class[] accepts() {
        return new Class[]{Evaluation.class};
    }

    /**
     * Adds the errors of the {@link Evaluation} in the token to the current
     * plot panel. Problems are reported on the actor's error output rather
     * than propagated.
     *
     * @param token the token containing an {@link Evaluation}
     */
    protected void display(Token token) {
        try {
            Evaluation eval = (Evaluation) token.getPayload();
            if (eval.predictions() == null) {
                this.getSystemErr().println("No predictions available from Evaluation object!");
                return;
            }
            plotErrors(new DataGenerator(eval, this.m_ErrorScaler), this.m_PlotType, this.m_VisualizePanel, this.m_JMathPlotPanel);
        }
        catch (Exception e) {
            this.getSystemErr().printStackTrace(e);
        }
    }

    /**
     * Clears and releases both plot panels before delegating to the superclass.
     */
    protected void cleanUpGUI() {
        if (this.m_VisualizePanel != null) {
            this.m_VisualizePanel.removeAllPlots();
            this.m_VisualizePanel = null;
        }
        if (this.m_JMathPlotPanel != null) {
            this.m_JMathPlotPanel.removeAllPlots();
            this.m_JMathPlotPanel = null;
        }
        super.cleanUpGUI();
    }

    /**
     * Creates a stand-alone display panel for the {@link Evaluation} in the
     * token and immediately displays its errors.
     *
     * @param token the token containing an {@link Evaluation}
     * @return the display panel
     */
    public AbstractDisplayPanel createDisplayPanel(Token token) {
        String name = "Classifier errors (" + ((Evaluation)token.getPayload()).getHeader().relationName() + ")";
        AbstractComponentDisplayPanel result = new AbstractComponentDisplayPanel(name){
            private static final long serialVersionUID = -7362768698548152899L;
            /** 2D plot panel; only set for TWO_DIMENSIONAL. */
            protected VisualizePanel m_VisualizePanel;
            /** 3D plot panel; only set for THREE_DIMENSIONAL. */
            protected Plot3DPanel m_JMathPlotPanel;

            protected void initGUI() {
                super.initGUI();
                this.setLayout(new BorderLayout());
                switch (WekaClassifierErrors.this.m_PlotType) {
                    case TWO_DIMENSIONAL:
                        this.m_VisualizePanel = new VisualizePanel();
                        this.add(this.m_VisualizePanel, BorderLayout.CENTER);
                        break;
                    case THREE_DIMENSIONAL:
                        this.m_JMathPlotPanel = WekaClassifierErrors.createJMathPlotPanel();
                        this.add(this.m_JMathPlotPanel, BorderLayout.CENTER);
                        break;
                    default:
                        throw new IllegalStateException("Unhandled plot type: " + WekaClassifierErrors.this.m_PlotType);
                }
            }

            public void display(Token token) {
                try {
                    Evaluation eval = (Evaluation) token.getPayload();
                    // same guard as the actor's own display(): no predictions, nothing to plot
                    if (eval.predictions() == null) {
                        WekaClassifierErrors.this.getSystemErr().println("No predictions available from Evaluation object!");
                        return;
                    }
                    WekaClassifierErrors.plotErrors(
                        new DataGenerator(eval, WekaClassifierErrors.this.m_ErrorScaler),
                        WekaClassifierErrors.this.m_PlotType,
                        this.m_VisualizePanel,
                        this.m_JMathPlotPanel);
                }
                catch (Exception e) {
                    WekaClassifierErrors.this.getSystemErr().printStackTrace(e);
                }
            }

            /**
             * @return the plot panel in use (2D or 3D, depending on plot type)
             */
            public JComponent supplyComponent() {
                if (this.m_VisualizePanel != null) {
                    return this.m_VisualizePanel;
                }
                return this.m_JMathPlotPanel;
            }

            /**
             * Removes all plots from whichever plot panel is in use.
             */
            public void cleanUp() {
                if (this.m_VisualizePanel != null) {
                    this.m_VisualizePanel.removeAllPlots();
                }
                if (this.m_JMathPlotPanel != null) {
                    this.m_JMathPlotPanel.removeAllPlots();
                }
            }
        };
        result.display(token);
        return result;
    }

    /**
     * @return true, the display panel is wrapped in a scroll pane
     */
    public boolean displayPanelRequiresScrollPane() {
        return true;
    }

    /**
     * Creates a 3D plot panel with the axes labelled "Actual", "Predicted" and "Error".
     *
     * @return the new panel
     */
    protected static Plot3DPanel createJMathPlotPanel() {
        Plot3DPanel result = new Plot3DPanel();
        result.setAxisLabel(0, "Actual");
        result.setAxisLabel(1, "Predicted");
        result.setAxisLabel(2, "Error");
        return result;
    }

    /**
     * Adds the errors produced by the generator to the panel matching the plot type.
     *
     * @param generator      the data generator
     * @param type           the plot type
     * @param visualizePanel the 2D panel (used for TWO_DIMENSIONAL)
     * @param jmathPlotPanel the 3D panel (used for THREE_DIMENSIONAL)
     * @throws Exception if generating or adding the plot fails
     * @throws IllegalStateException if the plot type is not handled
     */
    protected static void plotErrors(DataGenerator generator, PlotType type, VisualizePanel visualizePanel, Plot3DPanel jmathPlotPanel) throws Exception {
        switch (type) {
            case TWO_DIMENSIONAL:
                add2DPlot(generator, visualizePanel);
                break;
            case THREE_DIMENSIONAL:
                add3DPlot(generator, jmathPlotPanel);
                break;
            default:
                throw new IllegalStateException("Unhandled plot type: " + type);
        }
    }

    /**
     * Adds the generator's 2D plot data to the Weka panel and colours it by class.
     *
     * @param generator the data generator
     * @param panel     the 2D panel
     * @throws Exception if generating or adding the plot fails
     */
    protected static void add2DPlot(DataGenerator generator, VisualizePanel panel) throws Exception {
        PlotData2D plotdata = generator.getPlotData();
        plotdata.setPlotName(generator.getPlotInstances().relationName());
        panel.addPlot(plotdata);
        panel.setColourIndex(plotdata.getPlotInstances().classIndex());
        // only pick axes if the panel is still at its initial X=0/Y=1 selection:
        // X = class attribute, Y = the attribute directly before it
        if (panel.getXIndex() == 0 && panel.getYIndex() == 1) {
            try {
                panel.setXIndex(panel.getInstances().classIndex());
                panel.setYIndex(panel.getInstances().classIndex() - 1);
            }
            catch (Exception ignored) {
                // best effort: keep the panel's current axes if they cannot be changed
            }
        }
    }

    /**
     * Adds the generator's 3D bar plot plus a grey "Diagonal" reference line
     * (actual == predicted, error 0) spanning the full value range.
     *
     * @param generator the data generator
     * @param panel     the 3D panel
     * @throws Exception if generating or adding the plot fails
     */
    protected static void add3DPlot(DataGenerator generator, Plot3DPanel panel) throws Exception {
        panel.addPlot(generator.getJMathPlot());
        Instances data = generator.getPlotInstances();
        double min;
        double max;
        if (data.attribute(1).isNumeric()) {
            min = Math.min(data.attributeStats(1).numericStats.min, data.attributeStats(0).numericStats.min);
            // the diagonal must reach the larger of the two maxima (was Math.min)
            max = Math.max(data.attributeStats(1).numericStats.max, data.attributeStats(0).numericStats.max);
        }
        else {
            // nominal class: values are label indices 0..numValues-1
            min = 0.0;
            max = data.attribute(1).numValues() - 1;
        }
        panel.addLinePlot("Diagonal", Color.DARK_GRAY, new double[]{min, max}, new double[]{min, max}, new double[]{0.0, 0.0});
    }

    /**
     * The supported plot types.
     */
    public static enum PlotType {
        TWO_DIMENSIONAL,
        THREE_DIMENSIONAL;

    }

    /**
     * Turns the predictions of an {@link Evaluation} into plot data: a
     * dataset of (predicted, actual) pairs plus per-point shapes and sizes
     * (sizes being the scaled errors). Processing happens lazily, once.
     */
    public static class DataGenerator {
        /** Shape for predictions with a missing actual or predicted value.
         *  NOTE(review): appears to mirror weka.gui.visualize.Plot2D.MISSING_SHAPE — confirm. */
        private static final int SHAPE_MISSING = 2000;
        /** Shape for misclassified nominal predictions.
         *  NOTE(review): appears to mirror Plot2D.ERROR_SHAPE — confirm. */
        private static final int SHAPE_ERROR = 1000;
        /** Default shape. NOTE(review): appears to mirror Plot2D.CONST_AUTOMATIC_SHAPE — confirm. */
        private static final int SHAPE_DEFAULT = -1;
        /** Unscaled size used for every point when the class is nominal. */
        private static final int NOMINAL_SIZE = 2;

        protected Evaluation m_Evaluation;
        protected Instances m_PlotInstances;
        protected FastVector m_PlotShapes;
        protected FastVector m_PlotSizes;
        protected AbstractErrorScaler m_ErrorScaler;
        protected boolean m_Processed;

        /**
         * @param eval   the evaluation to take the predictions from
         * @param scaler the scaler to apply to the errors
         */
        public DataGenerator(Evaluation eval, AbstractErrorScaler scaler) {
            this.m_Evaluation = eval;
            this.m_ErrorScaler = scaler;
            this.m_Processed = false;
        }

        /**
         * Builds the dataset and scales the sizes, once. If the scaler rejects
         * the class attribute or scaling fails, the data is replaced with an
         * empty dataset (same header) and empty shape/size lists.
         */
        protected void process() {
            if (this.m_Processed) {
                return;
            }
            this.m_Processed = true;
            this.createDataset(this.m_Evaluation);
            try {
                Capabilities cap = this.m_ErrorScaler.getCapabilities();
                cap.testWithFail(this.m_PlotInstances.classAttribute(), true);
                ArrayList<Integer> scaled = this.m_ErrorScaler.scale((ArrayList)this.m_PlotSizes);
                this.m_PlotSizes = new FastVector();
                this.m_PlotSizes.addAll(scaled);
            }
            catch (Exception e) {
                e.printStackTrace();
                this.m_PlotInstances = new Instances(this.m_PlotInstances, 0);
                this.m_PlotSizes = new FastVector();
                this.m_PlotShapes = new FastVector();
            }
        }

        /**
         * @return the underlying evaluation
         */
        public Evaluation getEvaluation() {
            return this.m_Evaluation;
        }

        /**
         * @return the scaler in use
         */
        public AbstractErrorScaler getErrorScaler() {
            return this.m_ErrorScaler;
        }

        /**
         * @return the (predicted, actual) dataset, processing first if necessary
         */
        public Instances getPlotInstances() {
            this.process();
            return this.m_PlotInstances;
        }

        /**
         * Creates the (predicted, actual) dataset with the actual value as the
         * class, plus unscaled shapes and sizes. For a nominal class every size
         * is {@link #NOMINAL_SIZE} and misclassifications get the error shape;
         * for a numeric class the size is the error (predicted - actual), or
         * null if either value is missing.
         *
         * @param eval the evaluation to take the predictions from
         */
        protected void createDataset(Evaluation eval) {
            this.m_PlotShapes = new FastVector();
            this.m_PlotSizes = new FastVector();
            Attribute classAtt = eval.getHeader().classAttribute();
            FastVector preds = eval.predictions();
            ArrayList<Attribute> atts = new ArrayList<Attribute>();
            atts.add(classAtt.copy("predicted" + classAtt.name()));
            atts.add((Attribute) classAtt.copy());
            this.m_PlotInstances = new Instances(eval.getHeader().relationName() + "-classifier_errors", atts, preds.size());
            this.m_PlotInstances.setClassIndex(this.m_PlotInstances.numAttributes() - 1);
            for (int i = 0; i < preds.size(); ++i) {
                Prediction pred = (Prediction) preds.elementAt(i);
                double actual = pred.actual();
                double predicted = pred.predicted();
                this.m_PlotInstances.add(new DenseInstance(pred.weight(), new double[]{predicted, actual}));
                boolean missing = Utils.isMissingValue(actual) || Utils.isMissingValue(predicted);
                if (classAtt.isNominal()) {
                    int shape;
                    if (missing) {
                        shape = SHAPE_MISSING;
                    }
                    else if (predicted != actual) {
                        shape = SHAPE_ERROR;
                    }
                    else {
                        shape = SHAPE_DEFAULT;
                    }
                    this.m_PlotShapes.addElement(Integer.valueOf(shape));
                    this.m_PlotSizes.addElement(Integer.valueOf(NOMINAL_SIZE));
                }
                else {
                    Double errd = null;
                    if (missing) {
                        this.m_PlotShapes.addElement(Integer.valueOf(SHAPE_MISSING));
                    }
                    else {
                        errd = Double.valueOf(predicted - actual);
                        this.m_PlotShapes.addElement(Integer.valueOf(SHAPE_DEFAULT));
                    }
                    this.m_PlotSizes.addElement(errd);
                }
            }
        }

        /**
         * Creates the 2D plot data, with an added instance-number attribute.
         *
         * @return the plot data
         * @throws Exception if the plot data cannot be configured
         */
        public PlotData2D getPlotData() throws Exception {
            this.process();
            PlotData2D result = new PlotData2D(this.m_PlotInstances);
            result.setShapeSize(this.m_PlotSizes);
            result.setShapeType(this.m_PlotShapes);
            result.setPlotName("Classifier Errors (" + this.m_PlotInstances.relationName() + ")");
            result.addInstanceNumberAttribute();
            return result;
        }

        /**
         * Creates a red 3D bar plot with x = actual, y = predicted and
         * z = scaled error. Null size entries are plotted as 0.
         *
         * @return the plot
         * @throws Exception if the plot cannot be created
         */
        public Plot getJMathPlot() throws Exception {
            this.process();
            double[] x = this.m_PlotInstances.attributeToDoubleArray(1);
            double[] y = this.m_PlotInstances.attributeToDoubleArray(0);
            double[] z = new double[this.m_PlotSizes.size()];
            for (int i = 0; i < z.length; ++i) {
                Object size = this.m_PlotSizes.get(i);
                // createDataset stores null for missing values; avoid an NPE here
                z[i] = (size == null) ? 0.0 : ((Number) size).doubleValue();
            }
            return new BarPlot("Classifier Errors (" + this.m_PlotInstances.relationName() + ")", Color.RED, Array.mergeColumns(x, y, z));
        }
    }
}

