/*
 * Decompiled with CFR 0.152.
 */
package adams.flow.transformer;

import adams.core.QuickInfoHelper;
import adams.core.Range;
import adams.core.option.OptionHandler;
import adams.data.spreadsheet.DataRow;
import adams.data.spreadsheet.DefaultSpreadSheet;
import adams.data.spreadsheet.SpreadSheet;
import adams.flow.container.DL4JEvaluationContainer;
import adams.flow.core.Token;
import adams.flow.transformer.AbstractTransformer;
import adams.ml.dl4j.EvaluationHelper;
import adams.ml.dl4j.EvaluationStatistic;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;

/**
 * Transformer that turns selected statistics of a DL4J evaluation object into
 * a two-column spreadsheet ("Statistic", "Value").
 *
 * <p>The input is either a classification {@link Evaluation}, a
 * {@link RegressionEvaluation}, or a {@link DL4JEvaluationContainer} whose
 * "Evaluation" value holds one of the two. Only statistics flagged as
 * classification statistics are emitted for an {@link Evaluation}, and only
 * statistics flagged as regression statistics for a
 * {@link RegressionEvaluation}.</p>
 */
public class DL4JEvaluationValues
extends AbstractTransformer {
    private static final long serialVersionUID = -1977976026411517458L;

    /** The statistics to extract. */
    protected EvaluationStatistic[] m_StatisticValues;

    /** The range of class label indices used for per-class statistics. */
    protected Range m_ClassIndex;

    /** The range of columns to generate regression statistics for. */
    protected Range m_RegressionColumns;

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Generates a spreadsheet from statistics of an Evaluation object.";
    }

    /**
     * Adds the options ("statistic", "index", "regression-columns") to the
     * internal option manager, together with their default values.
     */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("statistic", "statisticValues", (Object) new EvaluationStatistic[]{EvaluationStatistic.ACCURACY, EvaluationStatistic.F1});
        this.m_OptionManager.add("index", "classIndex", (Object) new Range("first"));
        this.m_OptionManager.add("regression-columns", "regressionColumns", (Object) new Range("last"));
    }

    /**
     * Sets the statistics to extract and resets the actor.
     *
     * @param value the statistics
     */
    public void setStatisticValues(EvaluationStatistic[] value) {
        this.m_StatisticValues = value;
        this.reset();
    }

    /**
     * Returns the statistics to extract.
     *
     * @return the statistics
     */
    public EvaluationStatistic[] getStatisticValues() {
        return this.m_StatisticValues;
    }

    /**
     * Returns the tip text for the statisticValues property.
     *
     * @return the tip text
     */
    public String statisticValuesTipText() {
        return "The evaluation values to extract and turn into a spreadsheet.";
    }

    /**
     * Sets the range of class label indices and resets the actor.
     *
     * @param value the class label range
     */
    public void setClassIndex(Range value) {
        this.m_ClassIndex = value;
        this.reset();
    }

    /**
     * Returns the range of class label indices.
     *
     * @return the class label range
     */
    public Range getClassIndex() {
        return this.m_ClassIndex;
    }

    /**
     * Returns the tip text for the classIndex property.
     *
     * @return the tip text
     */
    public String classIndexTipText() {
        return "The range of class label indices (eg used for AUC).";
    }

    /**
     * Sets the range of regression columns and resets the actor.
     *
     * @param value the column range
     */
    public void setRegressionColumns(Range value) {
        this.m_RegressionColumns = value;
        this.reset();
    }

    /**
     * Returns the range of regression columns.
     *
     * @return the column range
     */
    public Range getRegressionColumns() {
        return this.m_RegressionColumns;
    }

    /**
     * Returns the tip text for the regressionColumns property.
     *
     * @return the tip text
     */
    public String regressionColumnsTipText() {
        return "The range of columns to get regression statistics for.";
    }

    /**
     * Returns a short summary of the current setup: class labels, regression
     * columns and the number of selected statistics.
     *
     * @return the summary
     */
    public String getQuickInfo() {
        int count = this.m_StatisticValues.length;
        String result = QuickInfoHelper.toString(this, "classIndex", this.m_ClassIndex, "class labels: ");
        result += QuickInfoHelper.toString(this, "regressionColumns", this.m_RegressionColumns, ", reg cols: ");
        result += QuickInfoHelper.toString(this, "statisticValues", count + " value" + (count != 1 ? "s" : ""), ", ");
        return result;
    }

    /**
     * Returns the classes that this actor accepts as input.
     *
     * @return classification evaluation, regression evaluation and the
     *         evaluation container
     */
    public Class[] accepts() {
        // RegressionEvaluation is handled explicitly in doExecute(), so it has to be advertised here too.
        return new Class[]{Evaluation.class, RegressionEvaluation.class, DL4JEvaluationContainer.class};
    }

    /**
     * Adds the value of a classification statistic as a new row.
     *
     * @param eval       the evaluation to obtain the value from
     * @param sheet      the spreadsheet to append the row to
     * @param statistic  the statistic to retrieve
     * @param classIndex the class label index handed to the evaluation helper;
     *                   displayed as {@code classIndex + 1}
     * @param useIndex   whether to append the label number to the name of
     *                   per-class statistics
     * @return null if successful, otherwise the error message
     */
    protected String addStatistic(Evaluation eval, SpreadSheet sheet, EvaluationStatistic statistic, int classIndex, boolean useIndex) {
        try {
            // Retrieve the value first: a failing lookup must not leave an empty row behind.
            double value = EvaluationHelper.getValue(eval, statistic, classIndex);
            String name = statistic.toDisplay();
            if (useIndex && statistic.isPerClass()) {
                name += " (label #" + (classIndex + 1) + ")";
            }
            this.addRow(sheet, name, value);
            return null;
        }
        catch (Exception e) {
            return this.handleException("Error retrieving value for '" + statistic + "':\n", e);
        }
    }

    /**
     * Adds the value of a regression statistic as a new row.
     *
     * @param eval      the regression evaluation to obtain the value from
     * @param sheet     the spreadsheet to append the row to
     * @param statistic the statistic to retrieve
     * @param column    the column index handed to the evaluation helper;
     *                  displayed as {@code column + 1}
     * @param useCol    whether to append the column number to the name
     * @return null if successful, otherwise the error message
     */
    protected String addStatistic(RegressionEvaluation eval, SpreadSheet sheet, EvaluationStatistic statistic, int column, boolean useCol) {
        try {
            // Retrieve the value first: a failing lookup must not leave an empty row behind.
            double value = EvaluationHelper.getValue(eval, statistic, column);
            String name = statistic.toDisplay();
            if (useCol) {
                name += " (col #" + (column + 1) + ")";
            }
            this.addRow(sheet, name, value);
            return null;
        }
        catch (Exception e) {
            return this.handleException("Error retrieving value for '" + statistic + "':\n", e);
        }
    }

    /**
     * Appends a name/value row; the row key is the current row count.
     *
     * @param sheet the spreadsheet to append to
     * @param name  the display name of the statistic
     * @param value the value of the statistic, stored as string
     */
    private void addRow(SpreadSheet sheet, String name, double value) {
        DataRow row = sheet.addRow("" + sheet.getRowCount());
        row.addCell(0).setContent(name);
        row.addCell(1).setContent(Double.toString(value));
    }

    /**
     * Appends an error message to the messages collected so far, separated by
     * a newline.
     *
     * @param current the messages collected so far, null if none
     * @param msg     the message to append, null if there is nothing to add
     * @return the combined messages, null if there are still none
     */
    private static String appendError(String current, String msg) {
        if (msg == null) {
            return current;
        }
        return current == null ? msg : current + "\n" + msg;
    }

    /**
     * Adds the selected classification statistics to the sheet. With a single
     * class index every classification statistic is computed for that label;
     * with several indices the overall statistics are listed once, followed
     * by the per-class statistics for each label.
     *
     * @param eval  the classification evaluation
     * @param sheet the spreadsheet to fill
     * @return null if successful, otherwise the collected error messages
     */
    private String addClassificationStatistics(Evaluation eval, SpreadSheet sheet) {
        String result = null;
        // NOTE(review): the size of the false-negatives map is used as the number of class labels - confirm.
        this.m_ClassIndex.setMax(eval.falseNegatives().size());
        int[] indices = this.m_ClassIndex.getIntIndices();

        if (indices.length == 1) {
            for (EvaluationStatistic statistic : this.m_StatisticValues) {
                if (statistic.isClassification()) {
                    result = appendError(result, this.addStatistic(eval, sheet, statistic, indices[0], true));
                }
            }
        } else if (indices.length > 1) {
            // Overall statistics first; the class index is not shown for these, so 0 is passed.
            for (EvaluationStatistic statistic : this.m_StatisticValues) {
                if (statistic.isClassification() && !statistic.isPerClass()) {
                    result = appendError(result, this.addStatistic(eval, sheet, statistic, 0, false));
                }
            }
            // Then the per-class statistics, grouped by label.
            for (int index : indices) {
                for (EvaluationStatistic statistic : this.m_StatisticValues) {
                    if (statistic.isClassification() && statistic.isPerClass()) {
                        result = appendError(result, this.addStatistic(eval, sheet, statistic, index, true));
                    }
                }
            }
        }
        return result;
    }

    /**
     * Adds the selected regression statistics to the sheet, one row per
     * statistic and selected column.
     *
     * @param eval  the regression evaluation
     * @param sheet the spreadsheet to fill
     * @return null if successful, otherwise the collected error messages
     */
    private String addRegressionStatistics(RegressionEvaluation eval, SpreadSheet sheet) {
        String result = null;
        this.m_RegressionColumns.setMax(eval.numColumns());
        int[] indices = this.m_RegressionColumns.getIntIndices();

        for (EvaluationStatistic statistic : this.m_StatisticValues) {
            if (!statistic.isRegression()) {
                continue;
            }
            for (int index : indices) {
                result = appendError(result, this.addStatistic(eval, sheet, statistic, index, true));
            }
        }
        return result;
    }

    /**
     * Executes the flow item: unwraps the evaluation object from the input
     * token, builds the spreadsheet and stores it in the output token.
     *
     * @return null if everything is fine, otherwise the error message(s)
     */
    protected String doExecute() {
        Object payload = this.m_InputToken.getPayload();
        Object evalObj = payload instanceof DL4JEvaluationContainer
            ? ((DL4JEvaluationContainer) payload).getValue("Evaluation")
            : payload;

        // Fail explicitly rather than forwarding a header-only sheet as if nothing was wrong.
        if (!(evalObj instanceof Evaluation) && !(evalObj instanceof RegressionEvaluation)) {
            return "Unsupported evaluation object: " + (evalObj == null ? "null" : evalObj.getClass().getName());
        }

        SpreadSheet sheet = new DefaultSpreadSheet();
        sheet.getHeaderRow().addCell("0").setContent("Statistic");
        sheet.getHeaderRow().addCell("1").setContent("Value");

        String result = evalObj instanceof Evaluation
            ? this.addClassificationStatistics((Evaluation) evalObj, sheet)
            : this.addRegressionStatistics((RegressionEvaluation) evalObj, sheet);

        // The sheet is forwarded even if individual statistics failed; the errors are reported via the return value.
        this.m_OutputToken = new Token(sheet);
        return result;
    }

    /**
     * Returns the classes of the objects that this actor generates.
     *
     * @return the spreadsheet class
     */
    public Class[] generates() {
        return new Class[]{SpreadSheet.class};
    }
}

