/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.transformation;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.transformation.TransformationBasedMultiLabelLearner;
import mulan.data.DataUtils;
import mulan.data.MultiLabelInstances;
import mulan.data.Statistics;
import mulan.transformations.BinaryRelevanceTransformation;
import weka.attributeSelection.ASEvaluation;
import weka.attributeSelection.ASSearch;
import weka.attributeSelection.AttributeSelection;
import weka.attributeSelection.Ranker;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.lazy.IBk;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.trees.J48;
import weka.core.Attribute;
import weka.core.DistanceFunction;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.Utils;
import weka.core.neighboursearch.LinearNNSearch;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Multi-Label Stacking, also known as (BR)^2: a two-level, stacked binary relevance learner.
 *
 * <p>The base level trains one binary classifier per label and records cross-validated
 * (out-of-fold) positive-class confidences for every training instance. Those confidences
 * become the input attributes of a meta level of per-label classifiers. Each meta classifier
 * can be pruned to the {@code topkCorrelated} most related label predictions, and the original
 * feature values can optionally be added to its input.
 *
 * <p>When the base classifier is an {@link IBk}, the base level is replaced by a direct
 * k-nearest-neighbour label-frequency estimate computed with {@link LinearNNSearch}.
 */
public class MultiLabelStacking
extends TransformationBasedMultiLabelLearner
implements Serializable {
    private static final long serialVersionUID = 1L;
    /** Prototype meta-level classifier; copied once per label. */
    private Classifier metaClassifier;
    /** Per-label binary-relevance headers; instances are cleared after training. */
    private Instances[] baseLevelData;
    /** Per-label meta-level headers; instances are cleared after training. */
    private Instances[] metaLevelData;
    /** One base-level classifier per label. */
    private Classifier[] baseLevelEnsemble;
    /** One meta-level classifier per label (wrapped by {@link #metaLevelFilteredEnsemble}). */
    private Classifier[] metaLevelEnsemble;
    /** Meta-level classifiers wrapped in an attribute-pruning {@link Remove} filter. */
    private FilteredClassifier[] metaLevelFilteredEnsemble;
    /** Number of cross-validation folds used for out-of-fold base predictions. */
    private int numFolds;
    /** Copy of the training set used by the most recent base-level build. */
    protected Instances train;
    /** Out-of-fold base-level confidences, indexed [training instance][label]. */
    private double[][] baseLevelPredictions;
    /** Whether base-level confidences are min-max rescaled per label. */
    private boolean normalize;
    /** Per-label max out-of-fold confidence; {@code null} when no rescaling was performed. */
    private double[] maxProb;
    /** Per-label min out-of-fold confidence; {@code null} when no rescaling was performed. */
    private double[] minProb;
    /** Whether the original feature values are also given to the meta level. */
    private boolean includeAttrs;
    /** Fraction of labels whose base predictions each meta classifier keeps. */
    private double metaPercentage;
    /** Number of label predictions each meta classifier keeps. */
    private int topkCorrelated;
    /** Per-label meta-attribute indices kept when {@code topkCorrelated < numLabels}. */
    private int[][] selectedAttributes;
    /** Attribute evaluator used for pruning; {@code null} selects phi-correlation pruning. */
    private ASEvaluation eval;
    /** Neighbour search used when the base classifier is an {@link IBk}. */
    private LinearNNSearch lnn = null;
    /** When true, {@link #buildInternal} does nothing (levels are expected to be built manually). */
    private boolean partialBuild;

    @Override
    public String globalInfo() {
        return "This class is an implementation of the (BR)^2 or Multi-Label stacking method.\n\nFor more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Grigorios Tsoumakas, Anastasios Dimou, Eleftherios Spyromitros, Vasileios Mezaris, Ioannis Kompatsiaris, Ioannis Vlahavas");
        result.setValue(TechnicalInformation.Field.TITLE, "Correlation-Based Pruning of Stacked Binary Relevance Models for Multi-Label Learning");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Proc. ECML/PKDD 2009 Workshop on Learning from Multi-Label Data (MLD'09)");
        result.setValue(TechnicalInformation.Field.YEAR, "2009");
        result.setValue(TechnicalInformation.Field.PAGES, "101-116");
        result.setValue(TechnicalInformation.Field.LOCATION, "Bled, Slovenia");
        return result;
    }

    /** Creates a learner that uses J48 at both the base and the meta level. */
    public MultiLabelStacking() {
        this((Classifier)new J48(), (Classifier)new J48());
    }

    /**
     * Creates a learner with 10 folds, no pruning (metaPercentage 1.0), no rescaling,
     * and no original features at the meta level.
     *
     * @param baseClassifier prototype for the per-label base-level classifiers
     * @param metaClassifier prototype for the per-label meta-level classifiers
     */
    public MultiLabelStacking(Classifier baseClassifier, Classifier metaClassifier) {
        super(baseClassifier);
        this.metaClassifier = metaClassifier;
        this.numFolds = 10;
        this.metaPercentage = 1.0;
        this.eval = null;
        this.normalize = false;
        this.includeAttrs = false;
        this.partialBuild = false;
    }

    @Override
    protected void buildInternal(MultiLabelInstances dataSet) throws Exception {
        if (this.partialBuild) {
            return;
        }
        if (this.baseClassifier instanceof IBk) {
            this.buildBaseLevelKNN(dataSet);
        } else {
            this.buildBaseLevel(dataSet);
        }
        this.initializeMetaLevel(dataSet, this.metaClassifier, this.includeAttrs, this.metaPercentage, this.eval);
        this.buildMetaLevel();
    }

    /**
     * Prepares the meta level: copies the meta classifier per label and, when pruning applies,
     * selects for every label which label predictions its meta classifier will see.
     *
     * @param dataSet training data (used for phi-correlation pruning)
     * @param metaClassifier prototype meta-level classifier
     * @param includeAttrs whether the original features are added to the meta level
     * @param metaPercentage fraction of labels kept per meta classifier (at least one is kept)
     * @param eval attribute evaluator for pruning, or {@code null} for phi-correlation pruning
     * @throws Exception if copying the classifier or attribute selection fails
     */
    public void initializeMetaLevel(MultiLabelInstances dataSet, Classifier metaClassifier, boolean includeAttrs, double metaPercentage, ASEvaluation eval) throws Exception {
        this.metaClassifier = metaClassifier;
        this.metaLevelEnsemble = AbstractClassifier.makeCopies((Classifier)metaClassifier, (int)this.numLabels);
        this.metaLevelData = new Instances[this.numLabels];
        this.metaLevelFilteredEnsemble = new FilteredClassifier[this.numLabels];
        this.includeAttrs = includeAttrs;
        this.topkCorrelated = (int)Math.floor(metaPercentage * (double)this.numLabels);
        if (this.topkCorrelated < 1) {
            this.debug("Too small percentage, selecting k=1");
            this.topkCorrelated = 1;
        }
        if (this.topkCorrelated >= this.numLabels) {
            // Every label prediction is kept, so there is nothing to select.
            return;
        }
        this.selectedAttributes = new int[this.numLabels][];
        if (eval == null) {
            this.selectByPhiCorrelation(dataSet);
        } else {
            this.selectByAttributeEvaluator(eval);
        }
    }

    /** Keeps, for each label, the labels ranked highest by the phi coefficient. */
    private void selectByPhiCorrelation(MultiLabelInstances dataSet) throws Exception {
        Statistics phi = new Statistics();
        phi.calculatePhi(dataSet);
        for (int i = 0; i < this.numLabels; ++i) {
            this.selectedAttributes[i] = phi.topPhiCorrelatedLabels(i, this.topkCorrelated);
        }
    }

    /** Keeps, for each label, the labels ranked highest by {@code eval} with a {@link Ranker}. */
    private void selectByAttributeEvaluator(ASEvaluation eval) throws Exception {
        AttributeSelection attsel = new AttributeSelection();
        Ranker rankingMethod = new Ranker();
        rankingMethod.setNumToSelect(this.topkCorrelated);
        attsel.setEvaluator(eval);
        attsel.setSearch((ASSearch)rankingMethod);
        for (int i = 0; i < this.numLabels; ++i) {
            ArrayList<Attribute> attributes = new ArrayList<Attribute>();
            for (int j = 0; j < this.numLabels; ++j) {
                attributes.add(this.train.attribute(this.labelIndices[j]));
            }
            attributes.add(this.train.attribute(this.labelIndices[i]).copy("meta"));
            Instances labelsOnly = new Instances("Meta format", attributes, 0);
            labelsOnly.setClassIndex(this.numLabels);
            for (int k = 0; k < this.train.numInstances(); ++k) {
                double[] values = new double[this.numLabels + 1];
                for (int m = 0; m < this.numLabels; ++m) {
                    values[m] = this.trainLabelValue(k, this.labelIndices[m]);
                }
                values[this.numLabels] = this.trainLabelValue(k, this.labelIndices[i]);
                Instance metaInstance = DataUtils.createInstance(this.train.instance(k), 1.0, values);
                metaInstance.setDataset(labelsOnly);
                labelsOnly.add(metaInstance);
            }
            attsel.SelectAttributes(labelsOnly);
            this.selectedAttributes[i] = attsel.selectedAttributes();
            labelsOnly.delete();
        }
    }

    /**
     * Parses the nominal label value of a training instance as a number.
     * This relies on label values being numeric strings such as "0"/"1".
     */
    private double trainLabelValue(int instanceIndex, int labelAttributeIndex) {
        Attribute label = this.train.attribute(labelAttributeIndex);
        return Double.parseDouble(label.value((int)this.train.instance(instanceIndex).value(labelAttributeIndex)));
    }

    /**
     * Trains one base classifier per label. First it collects out-of-fold positive-class
     * confidences for every training instance with {@code numFolds}-fold stratified
     * cross-validation. Then it retrains each classifier on the full label data.
     *
     * @param trainingSet multi-label training data
     * @throws Exception if transformation, filtering or training fails
     */
    public void buildBaseLevel(MultiLabelInstances trainingSet) throws Exception {
        this.train = new Instances(trainingSet.getDataSet());
        this.baseLevelData = new Instances[this.numLabels];
        this.baseLevelEnsemble = AbstractClassifier.makeCopies((Classifier)this.baseClassifier, (int)this.numLabels);
        if (this.normalize) {
            this.maxProb = new double[this.numLabels];
            this.minProb = new double[this.numLabels];
            Arrays.fill(this.minProb, 1.0);
        } else {
            // Drop bounds from an earlier build so predictions are not rescaled.
            this.maxProb = null;
            this.minProb = null;
        }
        this.baseLevelPredictions = new double[this.train.numInstances()][this.numLabels];
        for (int labelIndex = 0; labelIndex < this.numLabels; ++labelIndex) {
            this.debug("Label: " + labelIndex);
            Instances labelData = BinaryRelevanceTransformation.transformInstances(this.train, this.labelIndices, this.labelIndices[labelIndex]);
            // Column 0 stores each row's position in train, so out-of-fold predictions can be written back.
            labelData = new Instances(this.attachIndexes(labelData));
            this.baseLevelData[labelIndex] = labelData;
            Random random = new Random(1L);
            labelData.randomize(random);
            labelData.stratify(this.numFolds);
            int positiveIndex = labelData.classAttribute().indexOfValue("1");
            this.debug("Creating meta-data");
            for (int fold = 0; fold < this.numFolds; ++fold) {
                this.debug("Label=" + labelIndex + ", Fold=" + fold);
                this.collectOutOfFoldPredictions(labelData, labelIndex, fold, random, positiveIndex);
            }
            this.baseLevelData[labelIndex] = this.detachIndexes(labelData);
            this.debug("Building base classifier on full data");
            this.baseLevelEnsemble[labelIndex].buildClassifier(this.baseLevelData[labelIndex]);
            // Only the header is needed afterwards (setDataset/classAttribute at prediction time).
            this.baseLevelData[labelIndex].delete();
        }
        if (this.normalize) {
            this.normalizePredictions();
        }
    }

    /**
     * Trains the label's base classifier on every fold except {@code fold}. Records its
     * positive-class confidence for each instance of {@code fold} and updates the per-label
     * min/max when rescaling is enabled.
     */
    private void collectOutOfFoldPredictions(Instances labelData, int labelIndex, int fold, Random random, int positiveIndex) throws Exception {
        Instances subtrain = labelData.trainCV(this.numFolds, fold, random);
        FilteredClassifier fil = new FilteredClassifier();
        fil.setClassifier(this.baseLevelEnsemble[labelIndex]);
        Remove remove = new Remove();
        // Hide the row-index column from the classifier.
        remove.setAttributeIndices("first");
        remove.setInputFormat(subtrain);
        fil.setFilter((Filter)remove);
        fil.buildClassifier(subtrain);
        Instances subtest = labelData.testCV(this.numFolds, fold);
        for (int i = 0; i < subtest.numInstances(); ++i) {
            Instance row = subtest.instance(i);
            double positive = fil.distributionForInstance(row)[positiveIndex];
            this.baseLevelPredictions[(int)row.value(0)][labelIndex] = positive;
            if (!this.normalize) {
                continue;
            }
            if (positive > this.maxProb[labelIndex]) {
                this.maxProb[labelIndex] = positive;
            }
            if (positive < this.minProb[labelIndex]) {
                this.minProb[labelIndex] = positive;
            }
        }
    }

    /**
     * Trains one meta classifier per label. Its inputs are the base-level confidences (plus the
     * original features when {@code includeAttrs} is set), pruned to the selected attributes, and
     * its target is the label itself.
     *
     * @throws Exception if filtering or training fails
     */
    public void buildMetaLevel() throws Exception {
        this.debug("Building the ensemble of the meta level classifiers");
        for (int i = 0; i < this.numLabels; ++i) {
            ArrayList<Attribute> attributes = new ArrayList<Attribute>();
            if (this.includeAttrs) {
                for (int j = 0; j < this.train.numAttributes(); ++j) {
                    attributes.add(this.train.attribute(j));
                }
            } else {
                for (int j = 0; j < this.numLabels; ++j) {
                    attributes.add(this.train.attribute(this.labelIndices[j]));
                }
            }
            attributes.add(this.train.attribute(this.labelIndices[i]).copy("meta"));
            Instances metaData = new Instances("Meta format", attributes, 0);
            metaData.setClassIndex(metaData.numAttributes() - 1);
            this.metaLevelData[i] = metaData;
            for (int l = 0; l < this.train.numInstances(); ++l) {
                double[] values = new double[metaData.numAttributes()];
                if (this.includeAttrs) {
                    for (int m = 0; m < this.featureIndices.length; ++m) {
                        values[m] = this.train.instance(l).value(this.featureIndices[m]);
                    }
                    // NOTE(review): this appears to assume the label attributes are the last ones in train.
                    System.arraycopy(this.baseLevelPredictions[l], 0, values, this.train.numAttributes() - this.numLabels, this.numLabels);
                } else {
                    System.arraycopy(this.baseLevelPredictions[l], 0, values, 0, this.numLabels);
                }
                values[values.length - 1] = this.trainLabelValue(l, this.labelIndices[i]);
                Instance metaInstance = DataUtils.createInstance(this.train.instance(l), 1.0, values);
                metaInstance.setDataset(metaData);
                metaInstance.setClassValue(values[values.length - 1] > 0.5 ? "1" : "0");
                metaData.add(metaInstance);
            }
            this.metaLevelFilteredEnsemble[i] = new FilteredClassifier();
            this.metaLevelFilteredEnsemble[i].setClassifier(this.metaLevelEnsemble[i]);
            Remove remove = new Remove();
            if (this.topkCorrelated < this.numLabels) {
                remove.setAttributeIndicesArray(this.selectedAttributes[i]);
            } else {
                remove.setAttributeIndices("first-last");
            }
            // Inverted: the listed attributes are the ones that are KEPT.
            remove.setInvertSelection(true);
            remove.setInputFormat(metaData);
            this.metaLevelFilteredEnsemble[i].setFilter((Filter)remove);
            this.debug("Building classifier for meta training set " + i);
            this.metaLevelFilteredEnsemble[i].buildClassifier(metaData);
            metaData.delete();
        }
    }

    /**
     * Builds the base level for an {@link IBk} base classifier. Instead of training classifiers,
     * each training instance's confidence for a label is the fraction of its k nearest neighbours
     * (distance over the non-label attributes, identical instances skipped) that have that label.
     *
     * @param trainingSet multi-label training data
     * @throws Exception if the neighbour search fails
     */
    public void buildBaseLevelKNN(MultiLabelInstances trainingSet) throws Exception {
        this.train = new Instances(trainingSet.getDataSet());
        // k-NN confidences are never rescaled; drop bounds from an earlier build.
        this.maxProb = null;
        this.minProb = null;
        EuclideanDistance dfunc = new EuclideanDistance();
        dfunc.setDontNormalize(false);
        dfunc.setAttributeIndices(this.oneBasedLabelRange());
        // Inverted: the distance uses every attribute EXCEPT the labels.
        dfunc.setInvertSelection(true);
        this.lnn = new LinearNNSearch();
        this.lnn.setSkipIdentical(true);
        this.lnn.setDistanceFunction((DistanceFunction)dfunc);
        this.lnn.setInstances(this.train);
        this.lnn.setMeasurePerformance(false);
        this.baseLevelPredictions = new double[this.train.numInstances()][this.numLabels];
        int numOfNeighbors = ((IBk)this.baseClassifier).getKNN();
        for (int i = 0; i < this.train.numInstances(); ++i) {
            Instances knn = this.lnn.kNearestNeighbours(this.train.instance(i), numOfNeighbors);
            this.baseLevelPredictions[i] = this.knnLabelFrequencies(knn, numOfNeighbors);
        }
    }

    /** Returns the label attribute positions as a 1-based, comma-separated Weka range string. */
    private String oneBasedLabelRange() {
        StringBuilder range = new StringBuilder();
        for (int i = 0; i < this.numLabels; ++i) {
            if (i > 0) {
                range.append(',');
            }
            range.append(this.labelIndices[i] + 1);
        }
        return range.toString();
    }

    /**
     * For each label, returns the fraction of the first {@code numOfNeighbors} neighbours whose
     * label value is "1". When fewer neighbours were returned, only those are counted. With no
     * neighbours, every fraction is 0.
     */
    private double[] knnLabelFrequencies(Instances neighbours, int numOfNeighbors) {
        double[] frequencies = new double[this.numLabels];
        int inspected = Math.min(numOfNeighbors, neighbours.numInstances());
        if (inspected <= 0) {
            return frequencies;
        }
        for (int j = 0; j < this.numLabels; ++j) {
            int labelAttribute = this.labelIndices[j];
            Attribute label = this.train.attribute(labelAttribute);
            int positives = 0;
            for (int k = 0; k < inspected; ++k) {
                if ("1".equals(label.value((int)neighbours.instance(k).value(labelAttribute)))) {
                    ++positives;
                }
            }
            frequencies[j] = (double)positives / (double)inspected;
        }
        return frequencies;
    }

    /** Min-max rescales every stored out-of-fold confidence, per label, with the observed bounds. */
    private void normalizePredictions() {
        for (double[] row : this.baseLevelPredictions) {
            for (int j = 0; j < this.numLabels; ++j) {
                row[j] = this.normalizeConfidence(row[j], j);
            }
        }
    }

    /**
     * Returns {@code (value - min) / (max - min)} for the label. The value is returned unchanged
     * when the observed range is not positive (all confidences equal, or no bounds recorded).
     */
    private double normalizeConfidence(double value, int labelIndex) {
        double range = this.maxProb[labelIndex] - this.minProb[labelIndex];
        if (!(range > 0.0)) {
            return value;
        }
        return (value - this.minProb[labelIndex]) / range;
    }

    /**
     * Predicts all labels. Base-level confidences are computed (and rescaled exactly as they were
     * during training), then each label's meta classifier decides the bipartition.
     *
     * @param instance the instance to classify
     * @return bipartition and meta-level positive-class confidences, or {@code null} if a meta
     *     classifier fails (the failure is logged)
     * @throws Exception if the base level or neighbour search fails
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        double[] confidences;
        if (this.baseClassifier instanceof IBk) {
            int numOfNeighbors = ((IBk)this.baseClassifier).getKNN();
            confidences = this.knnLabelFrequencies(this.lnn.kNearestNeighbours(instance, numOfNeighbors), numOfNeighbors);
        } else {
            confidences = this.baseLevelConfidences(instance);
        }
        if (this.maxProb != null && this.minProb != null) {
            // The meta classifiers were trained on rescaled confidences; feed them the same scale.
            for (int j = 0; j < this.numLabels; ++j) {
                confidences[j] = this.normalizeConfidence(confidences[j], j);
            }
        }
        double[] values;
        if (this.includeAttrs) {
            values = new double[instance.numAttributes() + 1];
            for (int m = 0; m < this.featureIndices.length; ++m) {
                values[m] = instance.value(this.featureIndices[m]);
            }
            System.arraycopy(confidences, 0, values, instance.numAttributes() - this.numLabels, confidences.length);
        } else {
            values = new double[this.numLabels + 1];
            System.arraycopy(confidences, 0, values, 0, confidences.length);
        }
        boolean[] bipartition = new boolean[this.numLabels];
        double[] metaconfidences = new double[this.numLabels];
        for (int labelIndex = 0; labelIndex < this.numLabels; ++labelIndex) {
            // The last slot is the (unknown) meta class value.
            values[values.length - 1] = 0.0;
            Instance newmetaInstance = DataUtils.createInstance(instance, 1.0, values);
            double[] distribution;
            try {
                distribution = this.metaLevelFilteredEnsemble[labelIndex].distributionForInstance(newmetaInstance);
            }
            catch (Exception e) {
                Logger.getLogger(MultiLabelStacking.class.getName()).log(Level.SEVERE, null, e);
                return null;
            }
            int maxIndex = distribution[0] > distribution[1] ? 0 : 1;
            Attribute classAttribute = this.metaLevelData[labelIndex].classAttribute();
            bipartition[labelIndex] = classAttribute.value(maxIndex).equals("1");
            metaconfidences[labelIndex] = distribution[classAttribute.indexOfValue("1")];
        }
        return new MultiLabelOutput(bipartition, metaconfidences);
    }

    /** Returns each base classifier's positive-class confidence for {@code instance}. */
    private double[] baseLevelConfidences(Instance instance) throws Exception {
        double[] confidences = new double[this.numLabels];
        for (int labelIndex = 0; labelIndex < this.numLabels; ++labelIndex) {
            Instance newInstance = BinaryRelevanceTransformation.transformInstance(instance, this.labelIndices, this.labelIndices[labelIndex]);
            newInstance.setDataset(this.baseLevelData[labelIndex]);
            double[] distribution = this.baseLevelEnsemble[labelIndex].distributionForInstance(newInstance);
            Attribute classAttribute = this.baseLevelData[labelIndex].classAttribute();
            confidences[labelIndex] = distribution[classAttribute.indexOfValue("1")];
        }
        return confidences;
    }

    /**
     * Returns a copy of {@code original} with a numeric "Index" attribute prepended. Each row's
     * index holds its position in {@code original}, and the class index is shifted by one.
     */
    protected Instances attachIndexes(Instances original) {
        ArrayList<Attribute> attributes = new ArrayList<Attribute>(original.numAttributes() + 1);
        for (int i = 0; i < original.numAttributes(); ++i) {
            attributes.add(original.attribute(i));
        }
        attributes.add(0, new Attribute("Index"));
        Instances transformed = new Instances("Meta format", attributes, 0);
        for (int i = 0; i < original.numInstances(); ++i) {
            Instance newInstance = (Instance)original.instance(i).copy();
            newInstance.setDataset(null);
            newInstance.insertAttributeAt(0);
            newInstance.setValue(0, (double)i);
            transformed.add(newInstance);
        }
        transformed.setClassIndex(original.classIndex() + 1);
        return transformed;
    }

    /** Returns {@code original} with its first attribute (the index added by attachIndexes) removed. */
    protected Instances detachIndexes(Instances original) throws Exception {
        Remove remove = new Remove();
        remove.setAttributeIndices("first");
        remove.setInputFormat(original);
        return Filter.useFilter((Instances)original, (Filter)remove);
    }

    /**
     * Serializes this learner to {@code filename}. I/O failures are logged, not thrown, and the
     * stream is always closed so buffered data reaches the file.
     *
     * @param filename destination file path
     */
    public void saveObject(String filename) {
        FileOutputStream fileOut = null;
        ObjectOutputStream out = null;
        try {
            fileOut = new FileOutputStream(filename);
            out = new ObjectOutputStream(fileOut);
            out.writeObject(this);
            out.flush();
        }
        catch (IOException ex) {
            Logger.getLogger(MultiLabelStacking.class.getName()).log(Level.SEVERE, null, ex);
        }
        finally {
            try {
                if (out != null) {
                    // Closing the object stream also closes the underlying file stream.
                    out.close();
                } else if (fileOut != null) {
                    fileOut.close();
                }
            }
            catch (IOException ex) {
                Logger.getLogger(MultiLabelStacking.class.getName()).log(Level.SEVERE, null, ex);
            }
        }
    }

    /** Enables per-label min-max rescaling of base-level confidences (non-IBk base level only). */
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** Sets whether the original features are also given to the meta level. */
    public void setIncludeAttrs(boolean includeAttrs) {
        this.includeAttrs = includeAttrs;
    }

    /** Sets the fraction of label predictions each meta classifier keeps. */
    public void setMetaPercentage(double metaPercentage) {
        this.metaPercentage = metaPercentage;
    }

    /** Sets the attribute evaluator used for pruning; {@code null} selects phi correlation. */
    public void setEval(ASEvaluation eval) {
        this.eval = eval;
    }

    /** Replaces the meta-level classifier prototype and re-creates the per-label copies. */
    public void setMetaAlgorithm(Classifier metaClassifier) throws Exception {
        this.metaClassifier = metaClassifier;
        this.metaLevelEnsemble = AbstractClassifier.makeCopies((Classifier)metaClassifier, (int)this.numLabels);
    }

    /** When true, {@link #buildInternal} is skipped so the levels can be built manually. */
    public void setPartialBuild(boolean partialBuild) {
        this.partialBuild = partialBuild;
    }

    /** Returns how many label predictions each meta classifier keeps (valid after initializeMetaLevel). */
    public int getTopkCorrelated() {
        return this.topkCorrelated;
    }
}

