/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.meta.thresholding;

import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.InvalidDataException;
import mulan.classifier.MultiLabelLearner;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.meta.MultiLabelMetaLearner;
import mulan.classifier.transformation.BinaryRelevance;
import mulan.core.MulanRuntimeException;
import mulan.data.InvalidDataFormatException;
import mulan.data.LabelsMetaData;
import mulan.data.MultiLabelInstances;
import mulan.evaluation.measure.BipartitionMeasureBase;
import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.Utils;

/**
 * Meta-learner implementing RCut (rank-based cut) thresholding.
 *
 * <p>The wrapped base learner must produce a ranking of the labels. A prediction marks as
 * relevant every label whose rank is less than or equal to the cut-off {@code t}. The cut-off
 * is chosen while building the model:
 * <ul>
 *   <li>no measure supplied: the rounded label cardinality of the training data;</li>
 *   <li>measure supplied, {@code folds == 0}: tuned on the training data itself;</li>
 *   <li>measure supplied, {@code folds != 0}: tuned by cross-validation
 *       (fewer than 2 folds is rejected).</li>
 * </ul>
 */
public class RCut
extends MultiLabelMetaLearner {
    /** Rank cut-off: labels with rank {@code <= t} are predicted as relevant. */
    private int t = 0;
    /** Measure used to tune {@link #t}; {@code null} selects the cardinality-based default. */
    private BipartitionMeasureBase measure;
    /** Number of cross-validation folds; 0 means tuning is done on the training data itself. */
    private int folds;
    /** Copy of the base learner taken at construction time; {@code null} if not (successfully) copied. */
    private MultiLabelLearner foldLearner;

    /**
     * Creates an RCut instance around a {@link BinaryRelevance} learner that uses {@link J48}.
     */
    public RCut() {
        this(new BinaryRelevance(new J48()));
    }

    /**
     * Creates an RCut instance without a tuning measure.
     *
     * @param baseLearner the underlying multi-label learner
     */
    public RCut(MultiLabelLearner baseLearner) {
        super(baseLearner);
    }

    /**
     * Creates an RCut instance whose cut-off is tuned against a measure using cross-validation.
     *
     * @param baseLearner the underlying multi-label learner
     * @param aMeasure the measure to tune the cut-off against
     * @param someFolds the number of cross-validation folds
     */
    public RCut(MultiLabelLearner baseLearner, BipartitionMeasureBase aMeasure, int someFolds) {
        super(baseLearner);
        this.measure = aMeasure;
        this.folds = someFolds;
        try {
            this.foldLearner = baseLearner.makeCopy();
        }
        catch (Exception ex) {
            // Construction stays best-effort: autoTuneThreshold falls back to copying the
            // base learner when no copy could be made here.
            Logger.getLogger(RCut.class.getName()).log(Level.SEVERE,
                    "Could not copy the base learner for cross-validated tuning", ex);
        }
    }

    /**
     * Creates an RCut instance whose cut-off is tuned against a measure on the training data.
     *
     * @param baseLearner the underlying multi-label learner
     * @param aMeasure the measure to tune the cut-off against
     */
    public RCut(MultiLabelLearner baseLearner, BipartitionMeasureBase aMeasure) {
        super(baseLearner);
        this.measure = aMeasure;
    }

    /**
     * Tunes {@link #t} by cross-validation: for each fold a learner copy is trained on the
     * training split, per-threshold differences are computed on the test split, and the
     * threshold with the smallest summed difference is selected.
     *
     * @param trainingData the full training set to split into folds
     * @param measure the measure passed on to {@link #computeThreshold}
     * @param folds the number of folds, at least 2
     * @throws IllegalArgumentException if {@code folds} is less than 2
     * @throws Exception if copying, building or evaluating the learner fails
     */
    private void autoTuneThreshold(MultiLabelInstances trainingData, BipartitionMeasureBase measure, int folds) throws InvalidDataFormatException, Exception {
        if (folds < 2) {
            throw new IllegalArgumentException("folds should be more than 1");
        }
        // foldLearner is null when the constructor's makeCopy() failed (it is only logged);
        // copy the base learner instead of failing with a NullPointerException.
        MultiLabelLearner template = this.foldLearner != null ? this.foldLearner : this.baseLearner;
        MultiLabelLearner tempLearner = template.makeCopy();

        double[] totalDiff = new double[this.numLabels + 1];
        LabelsMetaData labelsMetaData = trainingData.getLabelsMetaData();
        for (int f = 0; f < folds; ++f) {
            Instances train = trainingData.getDataSet().trainCV(folds, f);
            MultiLabelInstances trainMulti = new MultiLabelInstances(train, labelsMetaData);
            Instances test = trainingData.getDataSet().testCV(folds, f);
            MultiLabelInstances testMulti = new MultiLabelInstances(test, labelsMetaData);
            tempLearner.build(trainMulti);
            double[] diff = this.computeThreshold(tempLearner, testMulti, measure);
            for (int k = 0; k < diff.length; ++k) {
                totalDiff[k] += diff[k];
            }
        }
        this.t = Utils.minIndex(totalDiff);
    }

    /**
     * Walks over all instances without missing labels, obtains the learner's ranking and
     * derives the bipartition for every candidate threshold {@code 0..numLabels}.
     *
     * <p>NOTE(review): the bipartitions are built but never scored against the true labels
     * with {@code measure}, so the returned array is never written to and stays all zeros.
     * Callers that take the minimum index of it will presumably always get 0 - TODO confirm
     * and wire in the measure update once the intended scoring is settled.
     *
     * @param learner an already built learner used to rank the labels
     * @param data the instances to evaluate on
     * @param measure the measure to tune against; currently only reset
     * @return an array of length {@code numLabels + 1}, indexed by threshold
     * @throws Exception if the learner fails to make a prediction
     */
    private double[] computeThreshold(MultiLabelLearner learner, MultiLabelInstances data, BipartitionMeasureBase measure) throws Exception {
        double[] diff = new double[this.numLabels + 1];
        measure.reset();
        for (int j = 0; j < data.getNumInstances(); ++j) {
            Instance instance = data.getDataSet().instance(j);
            if (data.hasMissingLabels(instance)) {
                continue;
            }
            MultiLabelOutput mlo = learner.makePrediction(instance);

            // Ground truth: a label is relevant when its nominal value is the string "1".
            boolean[] trueLabels = new boolean[this.numLabels];
            for (int counter = 0; counter < this.numLabels; ++counter) {
                int classIdx = this.labelIndices[counter];
                String classValue = instance.attribute(classIdx).value((int)instance.value(classIdx));
                trueLabels[counter] = classValue.equals("1");
            }

            int[] ranking = mlo.getRanking();
            for (int threshold = 0; threshold <= this.numLabels; ++threshold) {
                boolean[] bipartition = new boolean[this.numLabels];
                for (int k = 0; k < this.numLabels; ++k) {
                    bipartition[k] = ranking[k] <= threshold;
                }
                // See the NOTE(review) above: bipartition/trueLabels are not evaluated yet.
            }
        }
        return diff;
    }

    /**
     * Builds the base learner and selects the rank cut-off {@link #t}.
     *
     * @param trainingData the training data
     * @throws MulanRuntimeException if the base learner's output has no ranking
     * @throws Exception if building or evaluating the base learner fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances trainingData) throws Exception {
        this.baseLearner.build(trainingData);
        MultiLabelOutput mlo = this.baseLearner.makePrediction(trainingData.getDataSet().firstInstance());
        if (!mlo.hasRanking()) {
            throw new MulanRuntimeException("Learner is not a ranker");
        }
        if (this.measure == null) {
            // Default: cut at the rounded label cardinality of the training data. The value
            // was previously overwritten by a hard-coded 2 right after being computed.
            this.t = (int)Math.round(trainingData.getCardinality());
        } else if (this.folds == 0) {
            double[] diff = this.computeThreshold(this.baseLearner, trainingData, this.measure);
            this.t = Utils.minIndex(diff);
        } else {
            this.autoTuneThreshold(trainingData, this.measure, this.folds);
        }
    }

    /**
     * Returns the bibliographic reference of the paper describing the thresholding strategy.
     *
     * @return the technical information for the RCut method
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Yiming Yang");
        result.setValue(TechnicalInformation.Field.TITLE, "A study of thresholding strategies for text categorization");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the 24th annual international ACM SIGIR conference on Research and development in information retrieval");
        result.setValue(TechnicalInformation.Field.PAGES, "137 - 145");
        result.setValue(TechnicalInformation.Field.LOCATION, "New Orleans, Louisiana, United States");
        result.setValue(TechnicalInformation.Field.YEAR, "2001");
        return result;
    }

    /**
     * Predicts as relevant every label that the base learner ranks at position {@code t} or
     * better.
     *
     * @param instance the instance to classify
     * @return an output holding only the resulting bipartition
     * @throws Exception if the base learner fails to make a prediction
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception, InvalidDataException {
        MultiLabelOutput mlo = this.baseLearner.makePrediction(instance);
        int[] ranking = mlo.getRanking();
        boolean[] predictedLabels = new boolean[this.numLabels];
        for (int i = 0; i < this.numLabels; ++i) {
            predictedLabels[i] = ranking[i] <= this.t;
        }
        return new MultiLabelOutput(predictedLabels);
    }

    /**
     * Sets the debug flag on this learner and forwards it to the base learner.
     *
     * @param debug the new debug flag
     */
    @Override
    public void setDebug(boolean debug) {
        super.setDebug(debug);
        this.baseLearner.setDebug(debug);
    }

    /**
     * Returns a short description of this learner followed by its literature reference.
     *
     * @return the description text
     */
    @Override
    public String globalInfo() {
        return "Class that implements RCut (Rank-based cut). It selects the k top ranked labels for each instance, where k is a parameter provided by the user or automatically tuned.\n\nFor more information, see:\n\n" + this.getTechnicalInformation().toString();
    }
}

