/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.QuickInfoHelper;
import adams.core.base.BaseObject;
import adams.core.base.BaseString;
import adams.core.option.OptionHandler;
import adams.data.spreadsheet.DefaultSpreadSheet;
import adams.data.spreadsheet.HeaderRow;
import adams.data.spreadsheet.Row;
import adams.data.spreadsheet.SpreadSheet;
import adams.data.spreadsheet.SpreadSheetColumnIndex;
import adams.flow.core.Token;
import adams.flow.transformer.AbstractSpreadSheetTransformer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;

/**
 * Flow transformer that turns a spreadsheet containing an actual and a predicted class label
 * column into a confusion matrix spreadsheet.
 * <p>
 * The first column of the generated matrix holds the (prefixed) actual labels, the remaining
 * columns correspond to the (prefixed) predicted labels. Rows and columns share the same label
 * order. Cells contain counts, or, when a probability column is available, the summed
 * probabilities. Optionally, the values get normalized (see {@link MatrixValues}).
 */
public class ConfusionMatrix
extends AbstractSpreadSheetTransformer {
    private static final long serialVersionUID = 6499246835313298302L;

    /** The label used for rows whose actual or predicted cell is missing. */
    private static final String MISSING_LABEL = "?";

    /** The column with the actual labels. */
    protected SpreadSheetColumnIndex m_ActualColumn;
    /** The prefix for the actual labels (first column of the matrix). */
    protected String m_ActualPrefix;
    /** The column with the predicted labels. */
    protected SpreadSheetColumnIndex m_PredictedColumn;
    /** The prefix for the predicted labels (header of the matrix). */
    protected String m_PredictedPrefix;
    /** The optional column with the probabilities of the predictions. */
    protected SpreadSheetColumnIndex m_ProbabilityColumn;
    /** The type of values to generate. */
    protected MatrixValues m_MatrixValues;
    /** The class labels that enforce a custom label order (empty = alphabetical). */
    protected BaseString[] m_ClassLabels;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Generates a confusion matrix from the specified actual and predicted columns containing class labels.\nCan take a probability column (of prediction) into account for generating weighted counts.";
    }

    /**
     * Adds the options of this transformer to the option manager.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("actual-column", "actualColumn", new SpreadSheetColumnIndex("1"));
        this.m_OptionManager.add("actual-prefix", "actualPrefix", "a: ");
        this.m_OptionManager.add("predicted-column", "predictedColumn", new SpreadSheetColumnIndex("2"));
        this.m_OptionManager.add("predicted-prefix", "predictedPrefix", "p: ");
        this.m_OptionManager.add("probability-column", "probabilityColumn", new SpreadSheetColumnIndex(""));
        this.m_OptionManager.add("matrix-values", "matrixValues", MatrixValues.COUNTS);
        this.m_OptionManager.add("class-labels", "classLabels", new BaseString[0]);
    }

    /**
     * Returns a quick info about the actor, which will be displayed in the GUI.
     *
     * @return the quick info
     */
    public String getQuickInfo() {
        String result = QuickInfoHelper.toString((OptionHandler) this, "actualColumn", (Object) this.m_ActualColumn, "actual: ");
        result = result + QuickInfoHelper.toString((OptionHandler) this, "predictedColumn", (Object) this.m_PredictedColumn, ", predicted: ");
        result = result + QuickInfoHelper.toString((OptionHandler) this, "probabilityColumn", (Object) (this.m_ProbabilityColumn.isEmpty() ? "-none-" : this.m_ProbabilityColumn.getIndex()), ", probability: ");
        result = result + QuickInfoHelper.toString((OptionHandler) this, "matrixValues", (Object) this.m_MatrixValues, ", values: ");
        return result;
    }

    /**
     * Sets the column with the actual labels.
     *
     * @param value the column
     */
    public void setActualColumn(SpreadSheetColumnIndex value) {
        this.m_ActualColumn = value;
        this.reset();
    }

    /**
     * Returns the column with the actual labels.
     *
     * @return the column
     */
    public SpreadSheetColumnIndex getActualColumn() {
        return this.m_ActualColumn;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String actualColumnTipText() {
        return "The column with the actual labels.";
    }

    /**
     * Sets the prefix for the actual labels.
     *
     * @param value the prefix
     */
    public void setActualPrefix(String value) {
        this.m_ActualPrefix = value;
        this.reset();
    }

    /**
     * Returns the prefix for the actual labels.
     *
     * @return the prefix
     */
    public String getActualPrefix() {
        return this.m_ActualPrefix;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String actualPrefixTipText() {
        return "The prefix for the actual labels.";
    }

    /**
     * Sets the column with the predicted labels.
     *
     * @param value the column
     */
    public void setPredictedColumn(SpreadSheetColumnIndex value) {
        this.m_PredictedColumn = value;
        this.reset();
    }

    /**
     * Returns the column with the predicted labels.
     *
     * @return the column
     */
    public SpreadSheetColumnIndex getPredictedColumn() {
        return this.m_PredictedColumn;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String predictedColumnTipText() {
        return "The column with the predicted labels.";
    }

    /**
     * Sets the prefix for the predicted labels.
     *
     * @param value the prefix
     */
    public void setPredictedPrefix(String value) {
        this.m_PredictedPrefix = value;
        this.reset();
    }

    /**
     * Returns the prefix for the predicted labels.
     *
     * @return the prefix
     */
    public String getPredictedPrefix() {
        return this.m_PredictedPrefix;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String predictedPrefixTipText() {
        return "The prefix for the predicted labels.";
    }

    /**
     * Sets the (optional) column with the probabilities.
     *
     * @param value the column
     */
    public void setProbabilityColumn(SpreadSheetColumnIndex value) {
        this.m_ProbabilityColumn = value;
        this.reset();
    }

    /**
     * Returns the (optional) column with the probabilities.
     *
     * @return the column
     */
    public SpreadSheetColumnIndex getProbabilityColumn() {
        return this.m_ProbabilityColumn;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String probabilityColumnTipText() {
        return "The (optional) column with the probabilities; if not available probability of 1 is assumed.";
    }

    /**
     * Sets the type of values to generate.
     *
     * @param value the type
     */
    public void setMatrixValues(MatrixValues value) {
        this.m_MatrixValues = value;
        this.reset();
    }

    /**
     * Returns the type of values to generate.
     *
     * @return the type
     */
    public MatrixValues getMatrixValues() {
        return this.m_MatrixValues;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String matrixValuesTipText() {
        return "The type of values to generate.";
    }

    /**
     * Sets the class labels that enforce a custom label order.
     *
     * @param value the labels
     */
    public void setClassLabels(BaseString[] value) {
        this.m_ClassLabels = value;
        this.reset();
    }

    /**
     * Returns the class labels that enforce a custom label order.
     *
     * @return the labels
     */
    public BaseString[] getClassLabels() {
        return this.m_ClassLabels;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String classLabelsTipText() {
        return "The class labels to use for enforcing order other than alphabetical.";
    }

    /**
     * Generates the confusion matrix from the spreadsheet in the input token and stores it in the
     * output token.
     *
     * @return null if everything is fine, otherwise an error message
     */
    protected String doExecute() {
        SpreadSheet sheet = (SpreadSheet) this.m_InputToken.getPayload();
        this.m_ActualColumn.setData(sheet);
        this.m_PredictedColumn.setData(sheet);
        this.m_ProbabilityColumn.setData(sheet);
        int actCol = this.m_ActualColumn.getIntIndex();
        int predCol = this.m_PredictedColumn.getIntIndex();
        int probCol = this.m_ProbabilityColumn.getIntIndex();
        if (actCol == -1) {
            return "Actual column not found: " + this.m_ActualColumn;
        }
        if (predCol == -1) {
            return "Predicted column not found: " + this.m_PredictedColumn;
        }

        // Actual (rows) and predicted (columns) share one label list, hence one index map.
        List<String> labels = this.determineLabels(sheet, actCol, predCol);
        HashMap<String, Integer> indices = new HashMap<>();
        SpreadSheet matrix = this.createMatrix(labels, indices);
        String result = fillMatrix(sheet, matrix, indices, actCol, predCol, probCol);
        if (result != null) {
            return result;
        }
        this.applyMatrixValues(matrix);
        this.m_OutputToken = new Token(matrix);
        return null;
    }

    /**
     * Determines the (duplicate-free) label order for rows and columns of the matrix: either the
     * user-supplied class labels followed by any further labels found in the data, or the sorted
     * union of the labels in both columns. The missing-value label is prepended if any row has a
     * missing actual or predicted cell.
     */
    private List<String> determineLabels(SpreadSheet sheet, int actCol, int predCol) {
        List<String> labels = new ArrayList<>();
        boolean customOrder = (this.m_ClassLabels != null) && (this.m_ClassLabels.length > 0);
        if (customOrder) {
            addAbsent(labels, Arrays.asList(BaseObject.toStringArray((BaseObject[]) this.m_ClassLabels)));
        }
        addAbsent(labels, sheet.getCellValues(actCol));
        addAbsent(labels, sheet.getCellValues(predCol));
        if (!customOrder) {
            Collections.sort(labels);
        }
        // Avoid a duplicate row/column if "?" already occurs as a genuine label.
        if (hasMissingCell(sheet, actCol, predCol) && !labels.contains(MISSING_LABEL)) {
            labels.add(0, MISSING_LABEL);
        }
        return labels;
    }

    /** Appends every value not yet present in {@code target}, preserving encounter order. */
    private static void addAbsent(List<String> target, List<String> values) {
        for (String value : values) {
            if (!target.contains(value)) {
                target.add(value);
            }
        }
    }

    /** Returns whether any row has a present-but-missing actual or predicted cell. */
    private static boolean hasMissingCell(SpreadSheet sheet, int actCol, int predCol) {
        for (Row r : sheet.rows()) {
            if (isMissingCell(r, actCol) || isMissingCell(r, predCol)) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether the row has a cell at the given column and that cell is missing. */
    private static boolean isMissingCell(Row row, int col) {
        return row.hasCell(col) && row.getCell(col).isMissing();
    }

    /**
     * Creates the zero-initialized matrix: header "x" plus prefixed predicted labels, one row per
     * prefixed actual label. Fills {@code indices} with label -> 0-based label position; the
     * matrix row is that position and the matrix column is position + 1.
     */
    private SpreadSheet createMatrix(List<String> labels, HashMap<String, Integer> indices) {
        DefaultSpreadSheet matrix = new DefaultSpreadSheet();
        HeaderRow header = matrix.getHeaderRow();
        header.addCell("0").setContentAsString("x");
        for (int i = 0; i < labels.size(); ++i) {
            header.addCell("" + (i + 1)).setContentAsString(this.m_PredictedPrefix + labels.get(i));
            indices.put(labels.get(i), i);
        }
        for (int i = 0; i < labels.size(); ++i) {
            Row row = matrix.addRow();
            for (int n = 0; n < matrix.getColumnCount(); ++n) {
                row.getCell(n).setContent(Integer.valueOf(0));
            }
            row.addCell(0).setContentAsString(this.m_ActualPrefix + labels.get(i));
        }
        return matrix;
    }

    /**
     * Accumulates counts (no probability column) or probabilities into the matrix. Rows lacking
     * the actual or predicted cell are skipped; missing cells map to the missing-value label.
     *
     * @return null if successful, otherwise an error message
     */
    private static String fillMatrix(SpreadSheet sheet, SpreadSheet matrix, HashMap<String, Integer> indices,
                                     int actCol, int predCol, int probCol) {
        for (int i = 0; i < sheet.getRowCount(); ++i) {
            Row row = sheet.getRow(i);
            if (!row.hasCell(actCol) || !row.hasCell(predCol)) {
                continue;
            }
            String actLabel = row.getCell(actCol).isMissing() ? MISSING_LABEL : row.getCell(actCol).getContent();
            String predLabel = row.getCell(predCol).isMissing() ? MISSING_LABEL : row.getCell(predCol).getContent();
            Integer actIndex = indices.get(actLabel);
            Integer predIndex = indices.get(predLabel);
            if ((actIndex == null) || (predIndex == null)) {
                return "Failed to locate labels of row #" + (i + 1) + " in matrix: actual=" + actLabel + ", predicted=" + predLabel;
            }
            int r = actIndex;
            int c = predIndex + 1;  // column 0 holds the actual labels
            if (probCol == -1) {
                matrix.getCell(r, c).setContent(Long.valueOf(matrix.getCell(r, c).toLong() + 1L));
            } else {
                matrix.getCell(r, c).setContent(Double.valueOf(matrix.getCell(r, c).toDouble() + probabilityOf(row, probCol)));
            }
        }
        return null;
    }

    /** Returns the row's probability, assuming 1 if the cell is absent or missing. */
    private static double probabilityOf(Row row, int probCol) {
        if (!row.hasCell(probCol) || row.getCell(probCol).isMissing()) {
            return 1.0;
        }
        return row.getCell(probCol).toDouble();
    }

    /**
     * Normalizes the matrix according to {@link #m_MatrixValues}. Note that the "percentages"
     * are fractions in [0, 1] (values are divided by the sum, not multiplied by 100). Sums are
     * computed in double precision so that probability-weighted cells are not truncated.
     */
    private void applyMatrixValues(SpreadSheet matrix) {
        switch (this.m_MatrixValues) {
            case COUNTS: {
                break;
            }
            case PERCENTAGES: {
                double sum = 0.0;
                for (int i = 0; i < matrix.getRowCount(); ++i) {
                    sum += sumRow(matrix, i);
                }
                if (sum > 0) {
                    for (int i = 0; i < matrix.getRowCount(); ++i) {
                        scaleRow(matrix, i, sum);
                    }
                }
                break;
            }
            case PERCENTAGES_PER_ROW: {
                for (int i = 0; i < matrix.getRowCount(); ++i) {
                    double sum = sumRow(matrix, i);
                    if (sum > 0) {
                        scaleRow(matrix, i, sum);
                    }
                }
                break;
            }
            default: {
                throw new IllegalStateException("Unhandled matrix values: " + this.m_MatrixValues);
            }
        }
    }

    /** Sums the value cells (all columns except the label column) of the given matrix row. */
    private static double sumRow(SpreadSheet matrix, int row) {
        double sum = 0.0;
        for (int n = 1; n < matrix.getColumnCount(); ++n) {
            sum += matrix.getCell(row, n).toDouble();
        }
        return sum;
    }

    /** Divides the value cells of the given matrix row by {@code divisor}. */
    private static void scaleRow(SpreadSheet matrix, int row, double divisor) {
        for (int n = 1; n < matrix.getColumnCount(); ++n) {
            matrix.getCell(row, n).setContent(Double.valueOf(matrix.getCell(row, n).toDouble() / divisor));
        }
    }

    /**
     * The type of values the confusion matrix contains.
     */
    public static enum MatrixValues {
        /** Raw counts (or summed probabilities). */
        COUNTS,
        /** Values divided by the sum over the whole matrix (fractions in [0, 1]). */
        PERCENTAGES,
        /** Values divided by the sum of their row (fractions in [0, 1]). */
        PERCENTAGES_PER_ROW;

    }
}

