/*
 * Decompiled with CFR 0.152.
 */
package mulan.experiments;

import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import mulan.classifier.meta.thresholding.OneThreshold;
import mulan.classifier.transformation.EnsembleOfPrunedSets;
import mulan.classifier.transformation.PrunedSets;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.Evaluation;
import mulan.evaluation.Evaluator;
import mulan.evaluation.MultipleEvaluation;
import mulan.evaluation.measure.BipartitionMeasureBase;
import mulan.evaluation.measure.ExampleBasedAccuracy;
import mulan.evaluation.measure.ExampleBasedFMeasure;
import mulan.evaluation.measure.HammingLoss;
import mulan.evaluation.measure.Measure;
import mulan.experiments.Experiment;
import weka.classifiers.Classifier;
import weka.classifiers.functions.SMO;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.Utils;

/**
 * Experiment driver for the Ensemble of Pruned Sets method (Read, ICDM'08).
 *
 * <p>The data set is evaluated with 5 repetitions of 2-fold cross-validation. For every outer
 * training fold, a grid search over the PrunedSets parameters runs first, using inner 5-fold
 * cross-validation with an SMO base classifier:
 * <ul>
 *   <li>p in {5, 4, 3, 2}</li>
 *   <li>b in {1, 2, 3}</li>
 *   <li>strategy in {A, B}</li>
 * </ul>
 * The best setting for each measure is then used to build an {@link EnsembleOfPrunedSets}. That
 * ensemble is wrapped in a {@link OneThreshold} tuned for the same measure and evaluated on the
 * outer test fold.
 */
public class ICDM08EnsembleOfPrunedSets
extends Experiment {

    /** Outer experiment: number of randomized repetitions of the cross-validation. */
    private static final int REPETITIONS = 5;
    /** Outer experiment: number of cross-validation folds per repetition. */
    private static final int OUTER_FOLDS = 2;
    /** Number of folds used by the inner parameter-search cross-validation. */
    private static final int INNER_FOLDS = 5;
    /** Integer argument passed to {@link OneThreshold} (presumably its internal fold count; TODO confirm). */
    private static final int THRESHOLD_FOLDS = 5;
    /** PrunedSets strategies tried during the parameter search, in search order. */
    private static final PrunedSets.Strategy[] STRATEGIES = {PrunedSets.Strategy.A, PrunedSets.Strategy.B};

    /** The best PrunedSets parameter combination found for one evaluation measure. */
    private static final class PrunedSetsSetting {
        final int p;
        final int b;
        final PrunedSets.Strategy strategy;
        /** Absolute distance between the inner-CV mean and the measure's ideal value. */
        final double diff;

        PrunedSetsSetting(int p, int b, PrunedSets.Strategy strategy, double diff) {
            this.p = p;
            this.b = b;
            this.strategy = strategy;
            this.diff = diff;
        }
    }

    /**
     * Runs the full experiment and prints progress and results to standard output.
     *
     * @param args command-line options, read with Weka's {@code Utils.getOption}:
     *             {@code path} (prefix concatenated directly with the file stem, so it
     *             should end with a separator) and {@code filestem}. The data is loaded
     *             from {@code <path><filestem>.arff} and {@code <path><filestem>.xml}.
     */
    public static void main(String[] args) {
        try {
            String path = Utils.getOption("path", args);
            String filestem = Utils.getOption("filestem", args);
            System.out.println("Loading the data set");
            MultiLabelInstances dataSet = new MultiLabelInstances(path + filestem + ".arff", path + filestem + ".xml");
            // All three measures are created up front. The original assigned index 2
            // of a 2-element array, which threw ArrayIndexOutOfBoundsException.
            Measure[] evaluationMeasures = new Measure[]{
                    new ExampleBasedAccuracy(), new HammingLoss(), new ExampleBasedFMeasure()};
            Map<String, MultipleEvaluation> result = new HashMap<String, MultipleEvaluation>();
            for (Measure m : evaluationMeasures) {
                result.put(m.getName(), new MultipleEvaluation(dataSet));
            }
            // Fixed seed so the fold splits are reproducible across runs.
            Random random = new Random(1L);
            for (int repetition = 0; repetition < REPETITIONS; ++repetition) {
                dataSet.getDataSet().randomize(random);
                for (int fold = 0; fold < OUTER_FOLDS; ++fold) {
                    System.out.println("Experiment " + (repetition * OUTER_FOLDS + fold + 1));
                    Instances train = dataSet.getDataSet().trainCV(OUTER_FOLDS, fold);
                    MultiLabelInstances multiTrain = new MultiLabelInstances(train, dataSet.getLabelsMetaData());
                    Instances test = dataSet.getDataSet().testCV(OUTER_FOLDS, fold);
                    MultiLabelInstances multiTest = new MultiLabelInstances(test, dataSet.getLabelsMetaData());
                    Map<String, PrunedSetsSetting> best = searchParameters(evaluationMeasures, multiTrain);
                    for (Measure m : evaluationMeasures) {
                        Evaluation e = evaluateBestSetting(m, best.get(m.getName()), multiTrain, multiTest);
                        result.get(m.getName()).addEvaluation(e);
                    }
                }
            }
            for (Measure m : evaluationMeasures) {
                System.out.println(m.getName());
                MultipleEvaluation measureResult = result.get(m.getName());
                measureResult.calculateStatistics();
                System.out.println(measureResult);
            }
        }
        catch (Exception e) {
            // Top-level boundary of a command-line experiment: report and exit.
            e.printStackTrace();
        }
    }

    /**
     * Grid-searches PrunedSets parameters by inner cross-validation on the training data.
     *
     * <p>For each measure, the chosen setting is the one whose inner-CV mean is closest to the
     * measure's ideal value. Ties go to the setting evaluated last, because the comparison is {@code <=}.
     *
     * @param evaluationMeasures measures to optimize; copies are used for the evaluation itself
     * @param multiTrain the outer training fold
     * @return the best setting per measure name; a measure whose inner-CV scores never compared
     *         {@code <=} the running best (e.g. all NaN) has no entry
     * @throws Exception propagated from Mulan model building or evaluation
     */
    private static Map<String, PrunedSetsSetting> searchParameters(Measure[] evaluationMeasures,
                                                                   MultiLabelInstances multiTrain) throws Exception {
        Map<String, PrunedSetsSetting> best = new HashMap<String, PrunedSetsSetting>();
        System.out.println("Searching parameters");
        for (int p = 5; p > 1; --p) {
            for (int b = 1; b < 4; ++b) {
                for (PrunedSets.Strategy strategy : STRATEGIES) {
                    PrunedSets ps = new PrunedSets(new SMO(), p, strategy, b);
                    List<Measure> measures = new LinkedList<Measure>();
                    for (Measure m : evaluationMeasures) {
                        measures.add(m.makeCopy());
                    }
                    System.out.print("p=" + p + " b=" + b + " strategy=" + strategy.name() + " ");
                    MultipleEvaluation innerResult = new Evaluator().crossValidate(ps, multiTrain, measures, INNER_FOLDS);
                    for (Measure m : evaluationMeasures) {
                        String name = m.getName();
                        double mean = innerResult.getMean(name);
                        System.out.print(name + ": " + mean + " ");
                        double diff = Math.abs(m.getIdealValue() - mean);
                        PrunedSetsSetting current = best.get(name);
                        double bestSoFar = current == null ? Double.MAX_VALUE : current.diff;
                        if (diff <= bestSoFar) {
                            best.put(name, new PrunedSetsSetting(p, b, strategy, diff));
                        }
                    }
                    System.out.println();
                }
            }
        }
        return best;
    }

    /**
     * Trains an EnsembleOfPrunedSets with the selected setting and evaluates it on the test fold.
     * A OneThreshold wrapper, tuned for {@code m}, sits around the ensemble.
     *
     * @param m the measure the setting was selected for (a copy is handed to OneThreshold)
     * @param setting the setting chosen by {@link #searchParameters}, or {@code null} if none was found
     * @param multiTrain the outer training fold
     * @param multiTest the outer test fold
     * @return the evaluation of the thresholded ensemble on {@code multiTest}
     * @throws IllegalStateException if no setting was found for {@code m}
     * @throws Exception propagated from Mulan model building or evaluation
     */
    private static Evaluation evaluateBestSetting(Measure m, PrunedSetsSetting setting,
                                                  MultiLabelInstances multiTrain,
                                                  MultiLabelInstances multiTest) throws Exception {
        System.out.println(m.getName());
        if (setting == null) {
            throw new IllegalStateException("No PrunedSets setting selected for measure " + m.getName()
                    + " (inner cross-validation produced no comparable score)");
        }
        System.out.println("Best p: " + setting.p);
        System.out.println("Best strategy: " + setting.strategy);
        System.out.println("Best b: " + setting.b);
        // NOTE(review): 63.0 / 10 / 0.5 are presumably sample percentage, number of models and
        // prediction threshold of EnsembleOfPrunedSets -- confirm against its constructor.
        EnsembleOfPrunedSets eps = new EnsembleOfPrunedSets(63.0, 10, 0.5, setting.p, setting.strategy, setting.b, new SMO());
        OneThreshold ot = new OneThreshold(eps, (BipartitionMeasureBase) m.makeCopy(), THRESHOLD_FOLDS);
        ot.build(multiTrain);
        System.out.println("Best threshold: " + ot.getThreshold());
        Evaluation e = new Evaluator().evaluate(ot, multiTest);
        System.out.println(e.toCSV());
        return e;
    }

    /**
     * Returns the bibliographic reference for the replicated paper.
     *
     * @return a CONFERENCE entry for Read, ICDM'08, pages 995-1000
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.CONFERENCE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Read, Jesse");
        result.setValue(TechnicalInformation.Field.TITLE, "Multi-label Classification using Ensembles of Pruned Sets");
        result.setValue(TechnicalInformation.Field.PAGES, "995-1000");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "ICDM'08: Eighth IEEE International Conference on Data Mining");
        result.setValue(TechnicalInformation.Field.YEAR, "2008");
        return result;
    }
}

