/*
 * Decompiled with CFR 0.152.
 */
package mulan.evaluation;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.Evaluation;
import mulan.evaluation.MultipleEvaluation;
import mulan.evaluation.measure.AveragePrecision;
import mulan.evaluation.measure.Coverage;
import mulan.evaluation.measure.ErrorSetSize;
import mulan.evaluation.measure.ExampleBasedAccuracy;
import mulan.evaluation.measure.ExampleBasedFMeasure;
import mulan.evaluation.measure.ExampleBasedPrecision;
import mulan.evaluation.measure.ExampleBasedRecall;
import mulan.evaluation.measure.ExampleBasedSpecificity;
import mulan.evaluation.measure.GeometricMeanAverageInterpolatedPrecision;
import mulan.evaluation.measure.GeometricMeanAveragePrecision;
import mulan.evaluation.measure.HammingLoss;
import mulan.evaluation.measure.HierarchicalLoss;
import mulan.evaluation.measure.IsError;
import mulan.evaluation.measure.MacroAUC;
import mulan.evaluation.measure.MacroFMeasure;
import mulan.evaluation.measure.MacroPrecision;
import mulan.evaluation.measure.MacroRecall;
import mulan.evaluation.measure.MacroSpecificity;
import mulan.evaluation.measure.MeanAverageInterpolatedPrecision;
import mulan.evaluation.measure.MeanAveragePrecision;
import mulan.evaluation.measure.Measure;
import mulan.evaluation.measure.MicroAUC;
import mulan.evaluation.measure.MicroFMeasure;
import mulan.evaluation.measure.MicroPrecision;
import mulan.evaluation.measure.MicroRecall;
import mulan.evaluation.measure.MicroSpecificity;
import mulan.evaluation.measure.OneError;
import mulan.evaluation.measure.RankingLoss;
import mulan.evaluation.measure.SubsetAccuracy;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Evaluates multi-label learners, either on a given data set or by k-fold
 * cross-validation.
 */
public class Evaluator {
    /** Shared logger for all diagnostics emitted by this class. */
    private static final Logger LOGGER = Logger.getLogger(Evaluator.class.getName());

    /** Seed of the {@link Random} used to shuffle the data before cross-validation. */
    private int seed = 1;

    /**
     * Sets the seed used to shuffle the data set before cross-validation.
     *
     * @param aSeed the random seed
     */
    public void setSeed(int aSeed) {
        this.seed = aSeed;
    }

    /**
     * Evaluates a learner on the given data using the given measures.
     * <p>
     * Every measure is reset first. Each instance is then scored as follows:
     * its label attributes are set to missing, the learner predicts on that
     * masked copy, and the prediction is compared with the true labels.
     * Instances with missing labels are skipped. A measure whose update throws
     * is logged once and not updated again for the rest of the run.
     *
     * @param learner  the learner to evaluate; must not be null
     * @param data     the evaluation data; must not be null
     * @param measures the measures to compute; must not be null
     * @return an {@link Evaluation} wrapping the given measures and data
     * @throws IllegalArgumentException if any argument is null
     * @throws Exception if the learner fails to make a prediction
     */
    public Evaluation evaluate(MultiLabelLearner learner, MultiLabelInstances data, List<Measure> measures) throws IllegalArgumentException, Exception {
        checkLearner(learner);
        checkData(data);
        checkMeasures(measures);
        for (Measure measure : measures) {
            measure.reset();
        }
        int numLabels = data.getNumLabels();
        int[] labelIndices = data.getLabelIndices();
        HashSet<Measure> failed = new HashSet<Measure>();
        Instances testData = data.getDataSet();
        int numInstances = testData.numInstances();
        for (int instanceIndex = 0; instanceIndex < numInstances; ++instanceIndex) {
            Instance instance = testData.instance(instanceIndex);
            // Without a complete ground truth the prediction cannot be scored.
            if (data.hasMissingLabels(instance)) {
                continue;
            }
            Instance labelsMissing = maskLabels(instance, numLabels, labelIndices);
            MultiLabelOutput output = learner.makePrediction(labelsMissing);
            boolean[] trueLabels = getTrueLabels(instance, numLabels, labelIndices);
            updateMeasures(measures, failed, output, trueLabels);
        }
        return new Evaluation(measures, data);
    }

    /**
     * Returns a copy of {@code instance}, attached to the same dataset, with
     * every label attribute set to missing. The learner therefore cannot see
     * the ground truth it is being scored against.
     */
    private Instance maskLabels(Instance instance, int numLabels, int[] labelIndices) {
        Instance masked = (Instance) instance.copy();
        masked.setDataset(instance.dataset());
        for (int i = 0; i < numLabels; ++i) {
            masked.setMissing(labelIndices[i]);
        }
        return masked;
    }

    /**
     * Feeds one prediction to every measure that has not failed yet. A measure
     * that throws is added to {@code failed} and logged once. This is
     * deliberately best-effort: one incompatible measure must not abort the
     * whole evaluation.
     */
    private void updateMeasures(List<Measure> measures, HashSet<Measure> failed, MultiLabelOutput output, boolean[] trueLabels) {
        for (Measure measure : measures) {
            if (failed.contains(measure)) {
                continue;
            }
            try {
                measure.update(output, trueLabels);
            } catch (Exception ex) {
                failed.add(measure);
                LOGGER.log(Level.WARNING, "Measure " + measure.getClass().getSimpleName()
                        + " failed and will be skipped for the remaining instances", ex);
            }
        }
    }

    /** Rejects a null learner. */
    private void checkLearner(MultiLabelLearner learner) {
        if (learner == null) {
            throw new IllegalArgumentException("Learner to be evaluated is null.");
        }
    }

    /** Rejects a null evaluation data object. */
    private void checkData(MultiLabelInstances data) {
        if (data == null) {
            throw new IllegalArgumentException("Evaluation data object is null.");
        }
    }

    /** Rejects a null list of measures. */
    private void checkMeasures(List<Measure> measures) {
        if (measures == null) {
            throw new IllegalArgumentException("List of evaluation measures to compute is null.");
        }
    }

    /** Rejects a fold count below two. */
    private void checkFolds(int someFolds) {
        if (someFolds < 2) {
            throw new IllegalArgumentException("Number of folds must be at least two or higher.");
        }
    }

    /**
     * Evaluates a learner using a default set of measures. The set is chosen
     * from the kinds of output (bipartition, ranking, confidences) the learner
     * produces.
     *
     * @param learner the learner to evaluate; must not be null
     * @param data    the evaluation data; must not be null
     * @return the resulting {@link Evaluation}
     * @throws IllegalArgumentException if any argument is null
     * @throws Exception if the learner fails to make a prediction
     */
    public Evaluation evaluate(MultiLabelLearner learner, MultiLabelInstances data) throws IllegalArgumentException, Exception {
        checkLearner(learner);
        checkData(data);
        List<Measure> measures = prepareMeasures(learner, data);
        return evaluate(learner, data, measures);
    }

    /**
     * Builds the default measure list for {@code learner}.
     * <p>
     * A copy of the learner makes one probe prediction on the first instance of
     * {@code data}. The kinds of output in that prediction decide which
     * measures are added. A hierarchical loss is also added when the label
     * metadata describes a hierarchy. Any failure is logged, and whatever
     * measures were collected so far are returned (possibly none).
     */
    private List<Measure> prepareMeasures(MultiLabelLearner learner, MultiLabelInstances data) {
        List<Measure> measures = new ArrayList<Measure>();
        Instances dataSet = data.getDataSet();
        if (dataSet.numInstances() == 0) {
            LOGGER.log(Level.WARNING, "Evaluation data is empty; no measures can be selected.");
            return measures;
        }
        int numOfLabels = data.getNumLabels();
        try {
            MultiLabelLearner copyOfLearner = learner.makeCopy();
            MultiLabelOutput prediction = copyOfLearner.makePrediction(dataSet.instance(0));
            if (prediction.hasBipartition()) {
                measures.add(new HammingLoss());
                measures.add(new SubsetAccuracy());
                measures.add(new ExampleBasedPrecision());
                measures.add(new ExampleBasedRecall());
                measures.add(new ExampleBasedFMeasure());
                measures.add(new ExampleBasedAccuracy());
                measures.add(new ExampleBasedSpecificity());
                measures.add(new MicroPrecision(numOfLabels));
                measures.add(new MicroRecall(numOfLabels));
                measures.add(new MicroFMeasure(numOfLabels));
                measures.add(new MicroSpecificity(numOfLabels));
                measures.add(new MacroPrecision(numOfLabels));
                measures.add(new MacroRecall(numOfLabels));
                measures.add(new MacroFMeasure(numOfLabels));
                measures.add(new MacroSpecificity(numOfLabels));
            }
            if (prediction.hasRanking()) {
                measures.add(new AveragePrecision());
                measures.add(new Coverage());
                measures.add(new OneError());
                measures.add(new IsError());
                measures.add(new ErrorSetSize());
                measures.add(new RankingLoss());
            }
            if (prediction.hasConfidences()) {
                measures.add(new MeanAveragePrecision(numOfLabels));
                measures.add(new GeometricMeanAveragePrecision(numOfLabels));
                measures.add(new MeanAverageInterpolatedPrecision(numOfLabels, 10));
                measures.add(new GeometricMeanAverageInterpolatedPrecision(numOfLabels, 10));
                measures.add(new MicroAUC(numOfLabels));
                measures.add(new MacroAUC(numOfLabels));
            }
            if (data.getLabelsMetaData().isHierarchy()) {
                measures.add(new HierarchicalLoss(data));
            }
        } catch (Exception ex) {
            LOGGER.log(Level.SEVERE, "Failed to determine the default evaluation measures", ex);
        }
        return measures;
    }

    /**
     * Reads the true label assignment of {@code instance}. A label counts as
     * present when the nominal value of its attribute is {@code "1"}.
     */
    private boolean[] getTrueLabels(Instance instance, int numLabels, int[] labelIndices) {
        boolean[] trueLabels = new boolean[numLabels];
        for (int counter = 0; counter < numLabels; ++counter) {
            int classIdx = labelIndices[counter];
            String classValue = instance.attribute(classIdx).value((int) instance.value(classIdx));
            trueLabels[counter] = classValue.equals("1");
        }
        return trueLabels;
    }

    /**
     * Cross-validates a learner, using default measures chosen per fold.
     *
     * @param learner   the learner to evaluate; must not be null
     * @param data      the data to split into folds; must not be null
     * @param someFolds the number of folds; must be at least 2
     * @return the per-fold evaluations together with their statistics
     * @throws IllegalArgumentException if an argument is null or {@code someFolds < 2}
     */
    public MultipleEvaluation crossValidate(MultiLabelLearner learner, MultiLabelInstances data, int someFolds) {
        checkLearner(learner);
        checkData(data);
        checkFolds(someFolds);
        return innerCrossValidate(learner, data, false, null, someFolds);
    }

    /**
     * Cross-validates a learner using the given measures.
     *
     * @param learner   the learner to evaluate; must not be null
     * @param data      the data to split into folds; must not be null
     * @param measures  the measures to compute; must not be null
     * @param someFolds the number of folds; must be at least 2
     * @return the per-fold evaluations together with their statistics
     * @throws IllegalArgumentException if an argument is null or {@code someFolds < 2}
     */
    public MultipleEvaluation crossValidate(MultiLabelLearner learner, MultiLabelInstances data, List<Measure> measures, int someFolds) {
        checkLearner(learner);
        checkData(data);
        checkMeasures(measures);
        // Validate the fold count up front, as the other overload does. Otherwise
        // an invalid value fails inside every fold and is only logged.
        checkFolds(someFolds);
        return innerCrossValidate(learner, data, true, measures, someFolds);
    }

    /**
     * Shared cross-validation loop.
     * <p>
     * A copy of the data is shuffled with {@link #seed}. For each fold, a fresh
     * copy of the learner is trained on the training split and evaluated on the
     * test split. A fold that throws is logged and leaves a null entry in the
     * evaluation array.
     */
    private MultipleEvaluation innerCrossValidate(MultiLabelLearner learner, MultiLabelInstances data, boolean hasMeasures, List<Measure> measures, int someFolds) {
        Evaluation[] evaluation = new Evaluation[someFolds];
        Instances workingSet = new Instances(data.getDataSet());
        workingSet.randomize(new Random(this.seed));
        for (int i = 0; i < someFolds; ++i) {
            System.out.println("Fold " + (i + 1) + "/" + someFolds);
            try {
                Instances train = workingSet.trainCV(someFolds, i);
                Instances test = workingSet.testCV(someFolds, i);
                MultiLabelInstances mlTrain = new MultiLabelInstances(train, data.getLabelsMetaData());
                MultiLabelInstances mlTest = new MultiLabelInstances(test, data.getLabelsMetaData());
                MultiLabelLearner clone = learner.makeCopy();
                clone.build(mlTrain);
                if (hasMeasures) {
                    // NOTE(review): the same Measure objects are reused, and reset,
                    // on every fold. If Evaluation keeps a reference to the list
                    // rather than a snapshot of the values, all folds will report the
                    // last fold's results. Confirm, and copy the measures per fold if so.
                    evaluation[i] = evaluate(clone, mlTest, measures);
                } else {
                    evaluation[i] = evaluate(clone, mlTest);
                }
            } catch (Exception ex) {
                LOGGER.log(Level.SEVERE, "Cross-validation fold " + (i + 1) + "/" + someFolds + " failed", ex);
            }
        }
        MultipleEvaluation me = new MultipleEvaluation(evaluation, data);
        me.calculateStatistics();
        return me;
    }
}

