/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.meta;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.meta.MultiLabelMetaLearner;
import mulan.classifier.transformation.BinaryRelevance;
import mulan.data.MultiLabelInstances;
import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

public class RAkEL
extends MultiLabelMetaLearner {
    private int seed = 0;
    private Random rnd;
    double[][] sumVotesIncremental;
    double[][] lengthVotesIncremental;
    double[] sumVotes;
    double[] lengthVotes;
    int numOfModels;
    double threshold = 0.5;
    int sizeOfSubset = 3;
    int[][] classIndicesPerSubset;
    int[][] absoluteIndicesToRemove;
    MultiLabelLearner[] subsetClassifiers;
    private Remove[] remove;
    HashSet<String> combinations;

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Grigorios Tsoumakas and Ioannis Katakis and Ioannis Vlahavas");
        result.setValue(TechnicalInformation.Field.TITLE, "Random k-Labelsets for Multi-Label Classification");
        result.setValue(TechnicalInformation.Field.JOURNAL, "IEEE Transactions on Knowledge and Data Engineering");
        result.setValue(TechnicalInformation.Field.PAGES, "1079-1089");
        result.setValue(TechnicalInformation.Field.VOLUME, "23");
        result.setValue(TechnicalInformation.Field.NUMBER, "7");
        result.setValue(TechnicalInformation.Field.YEAR, "2011");
        return result;
    }

    public RAkEL() {
        this(new BinaryRelevance((Classifier)new J48()));
    }

    public RAkEL(MultiLabelLearner baseLearner) {
        super(baseLearner);
    }

    public RAkEL(MultiLabelLearner baseLearner, int models, int subset) {
        super(baseLearner);
        this.sizeOfSubset = subset;
        this.numOfModels = models;
    }

    public RAkEL(MultiLabelLearner baseLearner, int models, int subset, double threshold) {
        super(baseLearner);
        this.sizeOfSubset = subset;
        this.numOfModels = models;
        this.threshold = threshold;
    }

    public void setSeed(int x) {
        this.seed = x;
    }

    public void setSizeOfSubset(int size) {
        this.sizeOfSubset = size;
        this.classIndicesPerSubset = new int[this.numOfModels][this.sizeOfSubset];
    }

    public int getSizeOfSubset() {
        return this.sizeOfSubset;
    }

    public void setNumModels(int models) {
        this.numOfModels = models;
    }

    public int getNumModels() {
        return this.numOfModels;
    }

    public static int binomial(int n, int m) {
        int[] b = new int[n + 1];
        b[0] = 1;
        for (int i = 1; i <= n; ++i) {
            b[i] = 1;
            for (int j = i - 1; j > 0; --j) {
                int n2 = j;
                b[n2] = b[n2] + b[j - 1];
            }
        }
        return b[m];
    }

    @Override
    protected void buildInternal(MultiLabelInstances trainingData) throws Exception {
        this.rnd = new Random(this.seed);
        this.combinations = new HashSet();
        if (this.sizeOfSubset >= this.numLabels) {
            throw new IllegalArgumentException("Size of subsets should be less than the number of labels");
        }
        if (this.numOfModels == 0) {
            this.numOfModels = Math.min(2 * this.numLabels, RAkEL.binomial(this.numLabels, this.sizeOfSubset));
        }
        this.classIndicesPerSubset = new int[this.numOfModels][this.sizeOfSubset];
        this.absoluteIndicesToRemove = new int[this.numOfModels][this.sizeOfSubset];
        this.subsetClassifiers = new MultiLabelLearner[this.numOfModels];
        this.remove = new Remove[this.numOfModels];
        for (int i = 0; i < this.numOfModels; ++i) {
            this.updateClassifier(trainingData, i);
        }
    }

    private void updateClassifier(MultiLabelInstances mlTrainData, int model) throws Exception {
        boolean[] selected;
        if (this.combinations == null) {
            this.combinations = new HashSet();
        }
        Instances trainData = mlTrainData.getDataSet();
        do {
            selected = new boolean[this.numLabels];
            for (int j = 0; j < this.sizeOfSubset; ++j) {
                int randomLabel = this.rnd.nextInt(this.numLabels);
                while (selected[randomLabel]) {
                    randomLabel = this.rnd.nextInt(this.numLabels);
                }
                selected[randomLabel] = true;
                this.classIndicesPerSubset[model][j] = randomLabel;
            }
            Arrays.sort(this.classIndicesPerSubset[model]);
        } while (!this.combinations.add(Arrays.toString(this.classIndicesPerSubset[model])));
        this.debug("Building model " + (model + 1) + "/" + this.numOfModels + ", subset: " + Arrays.toString(this.classIndicesPerSubset[model]));
        this.absoluteIndicesToRemove[model] = new int[this.numLabels - this.sizeOfSubset];
        int k = 0;
        for (int j = 0; j < this.numLabels; ++j) {
            if (selected[j]) continue;
            this.absoluteIndicesToRemove[model][k] = this.labelIndices[j];
            ++k;
        }
        this.remove[model] = new Remove();
        this.remove[model].setAttributeIndicesArray(this.absoluteIndicesToRemove[model]);
        this.remove[model].setInputFormat(trainData);
        this.remove[model].setInvertSelection(false);
        Instances trainSubset = Filter.useFilter((Instances)trainData, (Filter)this.remove[model]);
        this.subsetClassifiers[model] = this.getBaseLearner().makeCopy();
        this.subsetClassifiers[model].build(mlTrainData.reintegrateModifiedDataSet(trainSubset));
    }

    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        double[] sumConf = new double[this.numLabels];
        this.sumVotes = new double[this.numLabels];
        this.lengthVotes = new double[this.numLabels];
        for (int i = 0; i < this.numOfModels; ++i) {
            this.remove[i].input(instance);
            this.remove[i].batchFinished();
            Instance newInstance = this.remove[i].output();
            MultiLabelOutput subsetMLO = this.subsetClassifiers[i].makePrediction(newInstance);
            for (int j = 0; j < this.sizeOfSubset; ++j) {
                int n = this.classIndicesPerSubset[i][j];
                sumConf[n] = sumConf[n] + subsetMLO.getConfidences()[j];
                int n2 = this.classIndicesPerSubset[i][j];
                this.sumVotes[n2] = this.sumVotes[n2] + (subsetMLO.getBipartition()[j] ? 1.0 : 0.0);
                int n3 = this.classIndicesPerSubset[i][j];
                this.lengthVotes[n3] = this.lengthVotes[n3] + 1.0;
            }
        }
        double[] confidence1 = new double[this.numLabels];
        double[] confidence2 = new double[this.numLabels];
        boolean[] bipartition = new boolean[this.numLabels];
        for (int i = 0; i < this.numLabels; ++i) {
            if (this.lengthVotes[i] != 0.0) {
                confidence1[i] = this.sumVotes[i] / this.lengthVotes[i];
                confidence2[i] = sumConf[i] / this.lengthVotes[i];
            } else {
                confidence1[i] = 0.0;
                confidence2[i] = 0.0;
            }
            bipartition[i] = confidence1[i] >= this.threshold;
        }
        MultiLabelOutput mlo = new MultiLabelOutput(bipartition, confidence1);
        return mlo;
    }

    @Override
    public String globalInfo() {
        return "Class implementing a generalized version of the RAkEL (RAndom k-labELsets) algorithm. For more information, see\n\n" + this.getTechnicalInformation().toString();
    }
}

