/*
 * Decompiled with CFR 0.152.
 */
package mulan.data;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import mulan.data.LabelPairsDependenceIdentifier;
import mulan.data.LabelsPair;
import mulan.data.MultiLabelInstances;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Scores the conditional dependence of every label pair of a multi-label data set.
 * <p>
 * For an ordered label pair (a, b) two models are trained on copies of the base learner,
 * both predicting label a: an "independent" model from which every other label attribute
 * is removed, and a "dependent" model that additionally keeps label b as an input.
 * Their per-fold accuracies from a {@code numFolds}-fold cross-validation are compared
 * with a paired t statistic; a larger score means that seeing label b gives a larger and
 * more consistent accuracy gain when predicting label a.
 */
public class ConditionalDependenceIdentifier
implements LabelPairsDependenceIdentifier,
Serializable {
    /** Threshold reported by {@link #getCriticalValue()}; how it is applied is up to callers. */
    private double criticalValue = 3.25;
    /** Prototype learner; every model is trained on a fresh copy of it. */
    private Classifier baseLearner;
    /** Number of cross-validation folds per dependence test (at least 2). */
    private int numFolds = 10;
    /** Seed used to shuffle the data set and to build the cross-validation folds. */
    protected int seed;

    /**
     * Creates an identifier that trains copies of the given classifier.
     *
     * @param classifier the base learner; must not be {@code null}
     * @throws IllegalArgumentException if {@code classifier} is {@code null}
     */
    public ConditionalDependenceIdentifier(Classifier classifier) {
        if (classifier == null) {
            throw new IllegalArgumentException("classifier must not be null");
        }
        this.baseLearner = classifier;
    }

    /**
     * Tests both orientations of every label pair and keeps the higher-scoring one.
     *
     * @param mlInstances the multi-label data set; it is copied before being shuffled
     * @return one {@link LabelsPair} per unordered label pair, sorted with
     *         {@code Collections.reverseOrder()} (reverse of the pairs' natural ordering)
     */
    @Override
    public LabelsPair[] calculateDependence(MultiLabelInstances mlInstances) {
        int numLabels = mlInstances.getNumLabels();
        int folds = this.numFolds;
        // Within one call the data, the seed and the fold count are fixed, so every
        // testDependence call produces the same folds. The independent model depends only
        // on its target label and the fold, so each one is trained once here and shared by
        // all pairs with that target. The cache is local, so it is neither shared with other
        // identifiers (which may use a different learner) nor kept after the call.
        FilteredClassifier[][] indepModels = new FilteredClassifier[numLabels][folds];
        LabelsPair[] pairs = new LabelsPair[numLabels * (numLabels - 1) / 2];
        int ind = 0;
        for (int i = 0; i < numLabels - 1; ++i) {
            for (int j = i + 1; j < numLabels; ++j) {
                int[] comb1 = {i, j};
                int[] comb2 = {j, i};
                double val1 = this.testDependence(comb1, mlInstances, folds, indepModels);
                double val2 = this.testDependence(comb2, mlInstances, folds, indepModels);
                pairs[ind++] = val1 >= val2 ? new LabelsPair(comb1, val1) : new LabelsPair(comb2, val2);
            }
        }
        Arrays.sort(pairs, Collections.reverseOrder());
        return pairs;
    }

    /**
     * Cross-validates the independent and the dependent model for one ordered label pair.
     *
     * @param comb        {@code comb[0]} is the label to predict, {@code comb[1]} the label
     *                    offered as an extra input to the dependent model
     * @param mlData      the multi-label data set (copied, not modified)
     * @param numFolds    number of cross-validation folds
     * @param indepModels per-call cache of independent models, indexed [label][fold]
     * @return the t statistic from {@link #applyTtest}, or -1.0 if the test failed
     */
    private double testDependence(int[] comb, MultiLabelInstances mlData, int numFolds,
            FilteredClassifier[][] indepModels) {
        try {
            int numLabels = mlData.getNumLabels();
            int[] labelIndices = mlData.getLabelIndices();
            int target = comb[0];
            int classIndex = labelIndices[target];
            int[] indepRemove = labelIndicesExcept(labelIndices, numLabels, target, target);
            int[] depRemove = labelIndicesExcept(labelIndices, numLabels, target, comb[1]);
            double[] indepAcc = new double[numFolds];
            double[] depAcc = new double[numFolds];
            Instances workingSet = new Instances(mlData.getDataSet());
            Random random = new Random(this.seed);
            workingSet.randomize(random);
            for (int fold = 0; fold < numFolds; ++fold) {
                Instances train = workingSet.trainCV(numFolds, fold, random);
                Instances test = workingSet.testCV(numFolds, fold);
                FilteredClassifier indepModel = indepModels[target][fold];
                if (indepModel == null) {
                    indepModel = this.buildModel(indepRemove, classIndex, train);
                    indepModels[target][fold] = indepModel;
                }
                FilteredClassifier depModel = this.buildModel(depRemove, classIndex, train);
                indepAcc[fold] = accuracy(indepModel, classIndex, train, test);
                depAcc[fold] = accuracy(depModel, classIndex, train, test);
            }
            return this.applyTtest(indepAcc, depAcc);
        }
        catch (Exception e) {
            // Best effort: one failing pair must not abort the whole scan. Scoring the
            // partially filled (zero) accuracy arrays would yield a meaningless statistic,
            // so report the "no dependence" sentinel instead.
            Logger.getLogger(ConditionalDependenceIdentifier.class.getSimpleName()).log(Level.SEVERE,
                    "Dependence test failed for label pair (" + comb[0] + ", " + comb[1] + ")", e);
            return -1.0;
        }
    }

    /**
     * Returns the attribute indices of all labels except those at label positions
     * {@code keepA} and {@code keepB} (pass the same position twice to keep just one).
     */
    private static int[] labelIndicesExcept(int[] labelIndices, int numLabels, int keepA, int keepB) {
        int[] result = new int[numLabels - (keepA == keepB ? 1 : 2)];
        int n = 0;
        for (int label = 0; label < numLabels; ++label) {
            if (label != keepA && label != keepB) {
                result[n++] = labelIndices[label];
            }
        }
        return result;
    }

    /**
     * Evaluates a trained model on a test fold and returns its percentage of correct predictions.
     * <p>
     * The folds are passed unfiltered on purpose: the {@link FilteredClassifier} applies its
     * own {@link Remove} filter, so only the class index has to be set on the data.
     */
    private static double accuracy(FilteredClassifier model, int classIndex, Instances train,
            Instances test) throws Exception {
        train.setClassIndex(classIndex);
        test.setClassIndex(classIndex);
        Evaluation evaluation = new Evaluation(train);
        evaluation.evaluateModel(model, test, new Object[0]);
        return evaluation.pctCorrect();
    }

    /**
     * Paired t statistic comparing per-fold accuracies of the independent model
     * ({@code val1}) and the dependent model ({@code val2}).
     * <p>
     * NOTE(review): when every fold shows exactly the same difference the variance of the
     * differences is zero and 0.0 is returned, even if the dependent model is better.
     *
     * @return -1.0 if the independent model has the higher mean accuracy, otherwise the
     *         non-negative value {@code (mean2 - mean1) * sqrt(n(n-1) / sum((d - mean d)^2))}
     */
    private double applyTtest(double[] val1, double[] val2) {
        int count = val1.length;
        double sum1 = 0.0;
        double sum2 = 0.0;
        for (int i = 0; i < count; ++i) {
            sum1 += val1[i];
            sum2 += val2[i];
        }
        double avg1 = sum1 / (double)count;
        double avg2 = sum2 / (double)count;
        if (avg1 > avg2) {
            return -1.0;
        }
        double varDiff = 0.0;
        for (int i = 0; i < count; ++i) {
            double var1 = val1[i] - avg1;
            double var2 = val2[i] - avg2;
            varDiff += Math.pow(var1 - var2, 2.0);
        }
        double m = 0.0;
        if (varDiff != 0.0) {
            // Widen before multiplying so the product cannot overflow int.
            m = Math.sqrt((double)count * (count - 1) / varDiff);
        }
        // avg2 >= avg1 here, so this is non-negative (and never -0.0).
        return (avg2 - avg1) * m;
    }

    /**
     * Trains a copy of the base learner to predict {@code classIndex}, with the given
     * attributes removed by a {@link Remove} filter wrapped in a {@link FilteredClassifier}.
     * Sets the class index of {@code trainDataset} as a side effect.
     */
    private FilteredClassifier buildModel(int[] indicesToRemove, int classIndex, Instances trainDataset) throws Exception {
        trainDataset.setClassIndex(classIndex);
        Remove remove = new Remove();
        remove.setAttributeIndicesArray(indicesToRemove);
        remove.setInvertSelection(false);
        remove.setInputFormat(trainDataset);
        FilteredClassifier model = new FilteredClassifier();
        model.setClassifier(AbstractClassifier.makeCopy(this.baseLearner));
        model.setFilter(remove);
        model.buildClassifier(trainDataset);
        return model;
    }

    /** Sets the threshold reported by {@link #getCriticalValue()}. */
    public void setCriticalValue(double criticalValue) {
        this.criticalValue = criticalValue;
    }

    /** @return the configured critical value (3.25 unless changed) */
    @Override
    public double getCriticalValue() {
        return this.criticalValue;
    }

    /** @return the seed used for shuffling and fold construction */
    public int getSeed() {
        return this.seed;
    }

    /** Sets the seed used for shuffling and fold construction. */
    public void setSeed(int seed) {
        this.seed = seed;
    }

    /** @return the number of cross-validation folds */
    public int getNumFolds() {
        return this.numFolds;
    }

    /**
     * Sets the number of cross-validation folds.
     *
     * @throws IllegalArgumentException if {@code numFolds} is less than 2
     */
    public void setNumFolds(int numFolds) {
        if (numFolds < 2) {
            throw new IllegalArgumentException("numFolds must be at least 2, got " + numFolds);
        }
        this.numFolds = numFolds;
    }
}

