/*
 * Decompiled with CFR 0.152.
 */
package meka.experiment;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import meka.classifiers.multilabel.BR;
import meka.classifiers.multilabel.Evaluation;
import meka.classifiers.multilabel.MultilabelClassifier;
import meka.core.Result;
import meka.experiment.MekaSplitEvaluator;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Summarizable;
import weka.core.Utils;
import weka.experiment.ClassifierSplitEvaluator;
import weka.experiment.ClassifierSplitEvaluator;

/**
 * Split evaluator that trains a MEKA {@link MultilabelClassifier} on a train split, evaluates
 * it on a test split and reports the outcome as one flat result row.
 *
 * <p>Every result row has the same column layout, in this order:
 * <ol>
 *   <li>the dataset/timing values read from {@link Result#getValue},</li>
 *   <li>the aggregate metrics read from the stats map,</li>
 *   <li>two per-label metrics ({@code Accuracy[i]}, {@code Harmonic[i]}) for each of
 *       {@link #getTotalNumClasses()} labels,</li>
 *   <li>three serialized-size measurements,</li>
 *   <li>an instance-ID column, only when {@code getAttributeID() >= 0},</li>
 *   <li>the classifier summary,</li>
 *   <li>one column per configured additional measure.</li>
 * </ol>
 * {@link #getResultNames()}, {@link #getResultTypes()} and
 * {@link #getResult(Instances, Instances)} all follow this layout.
 */
public class MekaClassifierSplitEvaluator
extends ClassifierSplitEvaluator
implements MekaSplitEvaluator {
    static final long serialVersionUID = -8511241602760467265L;

    /** Number of key columns: scheme name, scheme options and scheme version. */
    private static final int KEY_SIZE = 3;

    /** Values read via {@link Result#getValue}; each entry is both the lookup name and the column name. */
    private static final String[] RESULT_VALUE_NAMES = {
        "N_train", "N_test", "LCard_train", "LCard_test", "Build_time", "Test_time", "Total_time"
    };

    /** Column names of the aggregate metrics, in output order (parallel to {@link #METRIC_STAT_KEYS}). */
    private static final String[] METRIC_COLUMN_NAMES = {
        "Accuracy", "Hamming_score", "Exact_match", "Jaccard_dist", "Hamming_loss", "ZeroOne_loss",
        "Harmonic_score", "One_error", "Rank_loss", "Avg_precision", "Log_Loss_D", "Log_Loss_L",
        "F-micro", "F-macro_D", "F-macro_L", "N_empty"
    };

    /** Keys used to look up each {@link #METRIC_COLUMN_NAMES} entry in the stats map (same order). */
    private static final String[] METRIC_STAT_KEYS = {
        "Accuracy", "Hamming score", "Exact match", "Jaccard dist", "Hamming loss", "ZeroOne loss",
        "Harmonic score", "One error", "Rank loss", "Avg precision", "Log Loss D", "Log Loss L",
        "F-micro", "F-macro_D", "F-macro_L", "N_empty"
    };

    /** Column names of the serialized-size measurements (model, train set, test set). */
    private static final String[] SIZE_COLUMN_NAMES = {
        "Serialized_Model_Size", "Serialized_Train_Set_Size", "Serialized_Test_Set_Size"
    };

    /**
     * Number of columns that depend neither on the label count, nor on the instance-ID option,
     * nor on the additional measures: value columns + metric columns + size columns + summary
     * (27 for the tables above).
     */
    private static final int RESULT_SIZE =
        RESULT_VALUE_NAMES.length + METRIC_COLUMN_NAMES.length + SIZE_COLUMN_NAMES.length + 1;

    /** Number of labels for which per-label columns are emitted. */
    protected int m_TotalNumClasses = 0;

    /** Creates an evaluator whose default template classifier is {@link BR}. */
    public MekaClassifierSplitEvaluator() {
        this.m_Template = new BR();
        this.updateOptions();
    }

    /**
     * Sets the number of labels for which per-label result columns are produced.
     *
     * @param value the number of labels
     */
    @Override
    public void setTotalNumClasses(int value) {
        this.m_TotalNumClasses = value;
    }

    /**
     * Returns the number of labels for which per-label result columns are produced.
     *
     * @return the number of labels
     */
    @Override
    public int getTotalNumClasses() {
        return this.m_TotalNumClasses;
    }

    /**
     * Returns a description of this split evaluator.
     *
     * @return the description text
     */
    public String globalInfo() {
        return " A SplitEvaluator that produces results for a MEKA classification scheme on a nominal class attribute.";
    }

    /**
     * Returns the types of the key columns; all three are strings.
     *
     * @return a new array of {@value #KEY_SIZE} empty-string type markers
     */
    public Object[] getKeyTypes() {
        return new Object[]{"", "", ""};
    }

    /**
     * Returns the names of the key columns.
     *
     * @return scheme name, scheme options and scheme version ID column names
     */
    public String[] getKeyNames() {
        return new String[]{"Scheme", "Scheme_options", "Scheme_version_ID"};
    }

    /**
     * Returns the key values describing the current template classifier.
     *
     * @return the template's class name, its option string and its version
     */
    public Object[] getKey() {
        return new Object[]{this.m_Template.getClass().getName(), this.m_ClassifierOptions, this.m_ClassifierVersion};
    }

    /**
     * Returns the type of every result column: a {@code Double} marker for numeric columns and
     * an empty string for the instance-ID and summary columns.
     *
     * @return one type marker per result column
     */
    public Object[] getResultTypes() {
        int overallLength = this.resultRowLength();
        Object[] resultTypes = new Object[overallLength];
        Double doub = Double.valueOf(0.0);
        // Values, metrics, per-label pairs and sizes are all numeric and come first.
        int numericColumns = RESULT_VALUE_NAMES.length + METRIC_COLUMN_NAMES.length
            + 2 * this.m_TotalNumClasses + SIZE_COLUMN_NAMES.length;
        int current = 0;
        while (current < numericColumns) {
            resultTypes[current++] = doub;
        }
        if (this.getAttributeID() >= 0) {
            resultTypes[current++] = "";
        }
        resultTypes[current++] = "";
        int addm = this.additionalMeasureCount();
        for (int i = 0; i < addm; ++i) {
            resultTypes[current++] = doub;
        }
        if (current != overallLength) {
            throw new AssertionError("ResultTypes didn't fit RESULT_SIZE: " + current + " != " + overallLength);
        }
        return resultTypes;
    }

    /**
     * Returns the name of every result column, in the layout described on the class.
     *
     * @return one name per result column
     */
    public String[] getResultNames() {
        int overallLength = this.resultRowLength();
        String[] resultNames = new String[overallLength];
        int current = 0;
        for (String name : RESULT_VALUE_NAMES) {
            resultNames[current++] = name;
        }
        for (String name : METRIC_COLUMN_NAMES) {
            resultNames[current++] = name;
        }
        for (int i = 0; i < this.m_TotalNumClasses; ++i) {
            resultNames[current++] = "Accuracy[" + i + "]";
            resultNames[current++] = "Harmonic[" + i + "]";
        }
        for (String name : SIZE_COLUMN_NAMES) {
            resultNames[current++] = name;
        }
        if (this.getAttributeID() >= 0) {
            resultNames[current++] = "Instance_ID";
        }
        resultNames[current++] = "Summary";
        int addm = this.additionalMeasureCount();
        for (int i = 0; i < addm; ++i) {
            resultNames[current++] = this.m_AdditionalMeasures[i];
        }
        if (current != overallLength) {
            throw new AssertionError("ResultNames didn't fit RESULT_SIZE: " + current + " != " + overallLength);
        }
        return resultNames;
    }

    /**
     * Looks up a metric in the stats map.
     *
     * @param values the stats map
     * @param name the metric key
     * @return the mapped value (which may itself be {@code null}) if the key is present,
     *     otherwise {@link Utils#missingValue()}
     */
    protected Double getEvaluationMetric(HashMap<String, Double> values, String name) {
        // Deliberately not a ternary: mixing Double and double there would unbox a null value.
        if (values.containsKey(name)) {
            return values.get(name);
        }
        return Utils.missingValue();
    }

    /**
     * Trains a fresh copy of the template classifier on {@code train}, evaluates it on
     * {@code test} and returns one result row.
     *
     * @param train the training split
     * @param test the test split
     * @return one value per result column, matching {@link #getResultNames()}
     * @throws Exception if no classifier is set, or if training, evaluation or serialization fails
     */
    public Object[] getResult(Instances train, Instances test) throws Exception {
        if (this.m_Template == null) {
            throw new Exception("No classifier has been specified");
        }
        int overallLength = this.resultRowLength();
        Object[] result = new Object[overallLength];
        this.m_Classifier = AbstractClassifier.makeCopy((Classifier) this.m_Template);
        // NOTE(review): "PCut1" and "3" are passed through to MEKA unchanged. They are presumably
        // a threshold option and a verbosity level; confirm against meka's Evaluation/Result.
        Result res = Evaluation.evaluateModel((MultilabelClassifier) this.m_Classifier, train, test, "PCut1", "3");
        HashMap<String, Double> map = Result.getStats(res, "3");
        this.m_result = res.toString();

        int current = 0;
        for (String name : RESULT_VALUE_NAMES) {
            result[current++] = res.getValue(name);
        }
        for (String key : METRIC_STAT_KEYS) {
            result[current++] = this.getEvaluationMetric(map, key);
        }
        for (int i = 0; i < this.m_TotalNumClasses; ++i) {
            result[current++] = this.getEvaluationMetric(map, "Accuracy[" + i + "]");
            result[current++] = this.getEvaluationMetric(map, "Harmonic[" + i + "]");
        }
        if (this.getNoSizeDetermination()) {
            for (int i = 0; i < SIZE_COLUMN_NAMES.length; ++i) {
                result[current++] = -1.0;
            }
        } else {
            result[current++] = serializedSize(this.m_Classifier);
            result[current++] = serializedSize(train);
            result[current++] = serializedSize(test);
        }
        if (this.getAttributeID() >= 0) {
            result[current++] = this.instanceIds(test);
        }
        result[current++] = this.m_Classifier instanceof Summarizable
            ? ((Summarizable) this.m_Classifier).toSummaryString()
            : null;
        int addm = this.additionalMeasureCount();
        for (int i = 0; i < addm; ++i) {
            // Exactly one slot per measure, so a failing measure can no longer shift later columns.
            result[current++] = this.additionalMeasureValue(i);
        }
        if (current != overallLength) {
            throw new AssertionError("Results didn't fit RESULT_SIZE: " + current + " != " + overallLength);
        }
        return result;
    }

    /**
     * Returns the number of configured additional measures.
     *
     * @return the number of additional measures, or 0 when none are set
     */
    private int additionalMeasureCount() {
        return this.m_AdditionalMeasures != null ? this.m_AdditionalMeasures.length : 0;
    }

    /**
     * Computes the number of columns in a result row for the current configuration.
     *
     * @return the result row width
     */
    private int resultRowLength() {
        int length = RESULT_SIZE + this.additionalMeasureCount() + 2 * this.m_TotalNumClasses;
        if (this.getAttributeID() >= 0) {
            ++length;
        }
        return length;
    }

    /**
     * Returns the number of bytes Java serialization produces for {@code obj}.
     *
     * @param obj the object to serialize
     * @return the serialized size in bytes
     * @throws IOException if serialization fails
     */
    private static double serializedSize(Object obj) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
            out.flush();
        }
        return bytes.size();
    }

    /**
     * Joins the values of the configured ID attribute for all test instances with {@code '|'}.
     * The numeric value is used for numeric attributes, the string value otherwise.
     *
     * @param test the test split
     * @return the joined IDs; empty when {@code test} has no instances
     */
    private String instanceIds(Instances test) {
        int attId = this.getAttributeID();
        boolean numeric = test.attribute(attId).isNumeric();
        StringBuilder ids = new StringBuilder();
        for (int i = 0; i < test.numInstances(); ++i) {
            if (i > 0) {
                ids.append('|');
            }
            if (numeric) {
                ids.append(test.instance(i).value(attId));
            } else {
                ids.append(test.instance(i).stringValue(attId));
            }
        }
        return ids.toString();
    }

    /**
     * Queries the trained classifier for the {@code index}-th configured additional measure.
     *
     * @param index position in {@code m_AdditionalMeasures}
     * @return the measure value, or {@code null} in any of these cases: the classifier does not
     *     produce the measure, it reports the measure as missing, or computing it throws (the
     *     exception is printed to stderr)
     */
    private Double additionalMeasureValue(int index) {
        if (!this.m_doesProduce[index]) {
            return null;
        }
        try {
            double dv = ((AdditionalMeasureProducer) this.m_Classifier).getMeasure(this.m_AdditionalMeasures[index]);
            return Utils.isMissingValue(dv) ? null : Double.valueOf(dv);
        } catch (Exception ex) {
            // Best effort: report the measure as missing rather than failing the whole row.
            System.err.println(ex);
            return null;
        }
    }

    /**
     * Sets the template classifier, which must be a MEKA multi-label classifier.
     *
     * @param newClassifier the classifier to use
     * @throws IllegalArgumentException if {@code newClassifier} is {@code null} or is not a
     *     {@link MultilabelClassifier}
     */
    public void setClassifier(Classifier newClassifier) {
        if (!(newClassifier instanceof MultilabelClassifier)) {
            String provided = newClassifier == null ? "null" : newClassifier.getClass().getName();
            throw new IllegalArgumentException("Classifier must be a " + MultilabelClassifier.class.getName() + ", provided: " + provided);
        }
        super.setClassifier(newClassifier);
    }

    /**
     * Returns the tip text for the classifier property.
     *
     * @return the superclass tip text, extended with the multi-label requirement
     */
    public String classifierTipText() {
        return super.classifierTipText() + ", must be a " + MultilabelClassifier.class.getName();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10376 $");
    }
}

