/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.transformation;

import java.util.Arrays;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.transformation.TransformationBasedMultiLabelLearner;
import mulan.core.Util;
import mulan.data.LabelSet;
import mulan.data.MultiLabelInstances;
import mulan.transformations.LabelPowersetTransformation;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Label Powerset multi-label learner.
 *
 * <p>The training set is converted by {@link LabelPowersetTransformation} into a single-label
 * problem whose class values are bit strings; each one is decoded back into a {@link LabelSet}
 * with {@link LabelSet#fromBitString(String)}. The wrapped base classifier is trained on that
 * single-label data, and its prediction is decoded into a bipartition plus per-label
 * confidences.
 */
public class LabelPowerset
extends TransformationBasedMultiLabelLearner {
    /** How per-label confidences are derived (0, 1 or 2); see {@link #setConfidenceCalculationMethod(int)}. */
    private int confidenceCalculationMethod = 1;
    /** When true, the bipartition is recomputed by thresholding the confidences. */
    protected boolean makePredictionsBasedOnConfidences = false;
    /** Confidences strictly greater than this value are predicted relevant (when thresholding is enabled). */
    protected double threshold = 0.5;
    /** Transformation fitted by {@link #buildInternal(MultiLabelInstances)} and reused at prediction time. */
    protected LabelPowersetTransformation transformation;
    /** Random source handed to {@code Util.RandomIndexOfMax}, presumably for tie-breaking between equal maxima. */
    protected Random Rand = new Random(1L);

    /**
     * Creates a Label Powerset learner around a single-label classifier.
     *
     * @param classifier the base classifier trained on the transformed (single-label) data
     */
    public LabelPowerset(Classifier classifier) {
        super(classifier);
    }

    /**
     * Chooses whether the returned bipartition is obtained by thresholding the confidences
     * (see {@link #setThreshold(double)}) instead of being taken from the predicted label set.
     *
     * @param value true to threshold confidences, false to use the predicted label set directly
     */
    public void setMakePredictionsBasedOnConfidences(boolean value) {
        this.makePredictionsBasedOnConfidences = value;
    }

    /**
     * Re-seeds the random generator used when selecting the most probable class value.
     *
     * @param s the new seed
     */
    public void setSeed(int s) {
        this.Rand = new Random(s);
    }

    /**
     * Sets the confidence threshold used when predictions are based on confidences.
     *
     * @param t labels whose confidence is strictly greater than {@code t} are predicted relevant
     */
    public void setThreshold(double t) {
        this.threshold = t;
    }

    /**
     * Selects how per-label confidences are computed.
     * <ul>
     *   <li>0 – the 0/1 membership values of the predicted label set;</li>
     *   <li>1 (default) – the probability of the predicted class value for labels in the set, and one
     *       minus that probability for labels outside it;</li>
     *   <li>2 – for each label, the summed probability of every class value whose label set contains it.</li>
     * </ul>
     * Any other value is ignored and the current method is kept.
     *
     * @param method 0, 1 or 2
     */
    public void setConfidenceCalculationMethod(int method) {
        if (method == 0 || method == 1 || method == 2) {
            this.confidenceCalculationMethod = method;
        }
    }

    /**
     * Transforms the multi-label training data and trains the base classifier on it.
     *
     * <p>When the transformed class attribute has a single value, the base classifier is not
     * trained; {@link #makePredictionInternal(Instance)} then always predicts that value.
     *
     * @param mlData the multi-label training data
     * @throws Exception if the transformation or the base classifier fails
     */
    @Override
    protected void buildInternal(MultiLabelInstances mlData) throws Exception {
        this.transformation = new LabelPowersetTransformation();
        this.debug("Transforming the training set.");
        Instances transformedData = this.transformation.transformInstances(mlData);
        this.debug("Building single-label classifier.");
        if (transformedData.attribute(transformedData.numAttributes() - 1).numValues() > 1) {
            this.baseClassifier.buildClassifier(transformedData);
        }
    }

    /**
     * Predicts the label set of one instance.
     *
     * <p>Failures from the transformation, the base classifier or bit-string decoding are
     * propagated to the caller. Previously they were logged and swallowed, which led to a
     * {@code NullPointerException}, or to stale label sets being reused in confidence method 2.
     *
     * @param instance the multi-label instance to classify
     * @return the bipartition together with per-label confidences
     * @throws Exception if transforming, classifying or decoding fails
     */
    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        if (this.transformation.getTransformedFormat().classAttribute().numValues() == 1) {
            // buildInternal skipped training in this case: the single known label set is predicted.
            LabelSet onlyLabelSet = this.labelSetOfClassValue(0);
            return new MultiLabelOutput(onlyLabelSet.toBooleanArray(), onlyLabelSet.toDoubleArray());
        }
        Instance transformedInstance = this.transformation.transformInstance(instance, this.labelIndices);
        double[] distribution = this.baseClassifier.distributionForInstance(transformedInstance);
        int classIndex = Util.RandomIndexOfMax(distribution, this.Rand);
        LabelSet predicted = this.labelSetOfClassValue(classIndex);
        boolean[] bipartition = predicted.toBooleanArray();
        // Confidences are computed from the un-thresholded bipartition (method 1 reads it).
        double[] confidences = this.computeConfidences(predicted, bipartition, distribution, classIndex);
        if (this.makePredictionsBasedOnConfidences) {
            for (int i = 0; i < confidences.length; ++i) {
                bipartition[i] = confidences[i] > this.threshold;
            }
        }
        return new MultiLabelOutput(bipartition, confidences);
    }

    /**
     * Decodes the label set encoded by one value of the transformed class attribute.
     *
     * @param classValueIndex index of the class value
     * @return the decoded label set
     * @throws Exception if the class value cannot be decoded by {@link LabelSet#fromBitString(String)}
     */
    private LabelSet labelSetOfClassValue(int classValueIndex) throws Exception {
        String bitString = this.transformation.getTransformedFormat().classAttribute().value(classValueIndex);
        return LabelSet.fromBitString(bitString);
    }

    /**
     * Computes one confidence per label according to {@link #confidenceCalculationMethod}.
     *
     * @param predicted the label set of the selected class value
     * @param bipartition the bipartition of {@code predicted} (before any thresholding)
     * @param distribution class-value probabilities returned by the base classifier
     * @param classIndex index of the selected class value in {@code distribution}
     * @return the per-label confidences
     * @throws Exception if a class value cannot be decoded (method 2)
     */
    private double[] computeConfidences(LabelSet predicted, boolean[] bipartition,
            double[] distribution, int classIndex) throws Exception {
        switch (this.confidenceCalculationMethod) {
            case 0: {
                // Size the copy from the array itself so it always matches bipartition's length,
                // independent of what LabelSet.size() counts.
                double[] bits = predicted.toDoubleArray();
                return Arrays.copyOf(bits, bits.length);
            }
            case 1: {
                double[] confidences = new double[this.numLabels];
                double prob = distribution[classIndex];
                for (int i = 0; i < this.numLabels; ++i) {
                    confidences[i] = bipartition[i] ? prob : 1.0 - prob;
                }
                return confidences;
            }
            case 2: {
                // Marginalize: each label accumulates the probability of every class value containing it.
                double[] confidences = new double[this.numLabels];
                for (int i = 0; i < distribution.length; ++i) {
                    double[] bits = this.labelSetOfClassValue(i).toDoubleArray();
                    for (int j = 0; j < this.numLabels; ++j) {
                        if (bits[j] == 1.0) {
                            confidences[j] += distribution[i];
                        }
                    }
                }
                return confidences;
            }
            default:
                // Unreachable through the public setter, which only accepts 0, 1 or 2.
                throw new IllegalStateException(
                        "Unknown confidence calculation method: " + this.confidenceCalculationMethod);
        }
    }
}

