/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.evaluation;

import java.util.List;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.evaluation.ClusterEvaluation;

/**
 * Cluster evaluation that compares a clustering against the true class labels
 * of a {@link ClassificationDataSet} using the Adjusted Rand Index (ARI).
 * <p>
 * The value returned by {@link #evaluate(int[], DataSet)} is {@code 1 - ARI},
 * so a clustering that reproduces the class labels exactly scores {@code 0}.
 * NOTE(review): presumably this inversion exists because ClusterEvaluation
 * scores are meant to be minimized - confirm against the interface contract.
 */
public class AdjustedRandIndex
implements ClusterEvaluation {
    /**
     * Computes {@code 1 - ARI} between the cluster designations and the class
     * labels of the given data set.
     * <p>
     * Data points whose designation is negative are skipped entirely. Every
     * remaining point contributes its weight (rather than a count of one) to
     * the cluster/label contingency table.
     *
     * @param designations the cluster index assigned to each data point, by
     *                     data point index; negative values are ignored
     * @param dataSet      the data set that was clustered; must be a
     *                     {@link ClassificationDataSet}
     * @return {@code 1 - ARI}; {@code 0.0} when the denominator of the index is
     *         exactly zero (for unit weights this happens only when both
     *         partitions are identical and trivial: all singletons, or a
     *         single group); {@code Double.NaN} when the total weight
     *         {@code n} of the assigned points does not give a positive
     *         number of pairs {@code n (n - 1) / 2}
     * @throws IllegalArgumentException if {@code dataSet} is not a
     *                                  {@link ClassificationDataSet}
     */
    @Override
    public double evaluate(int[] designations, DataSet dataSet) {
        if (!(dataSet instanceof ClassificationDataSet)) {
            throw new IllegalArgumentException("Adjusted Rand Index can only be calculated for classification data sets");
        }
        ClassificationDataSet cds = (ClassificationDataSet)dataSet;

        // The number of clusters is one more than the largest designation seen.
        int clusters = 0;
        for (int clusterID : designations) {
            clusters = Math.max(clusterID + 1, clusters);
        }

        // Weighted contingency table (cluster x label) and its marginals.
        double[] truthSums = new double[cds.getClassSize()];
        double[] clusterSums = new double[clusters];
        double[][] table = new double[clusters][truthSums.length];
        double n = 0.0;
        for (int i = 0; i < designations.length; ++i) {
            int cluster = designations[i];
            if (cluster < 0) {
                continue;
            }
            int label = cds.getDataPointCategory(i);
            double weight = cds.getDataPoint(i).getWeight();
            table[cluster][label] += weight;
            truthSums[label] += weight;
            clusterSums[cluster] += weight;
            n += weight;
        }

        // Without at least one pair of points the index is undefined.
        double totalPairs = pairs(n);
        if (!(totalPairs > 0.0)) {
            return Double.NaN;
        }

        // Sum of pair counts over the label marginals.
        double addLTerm = 0.0;
        for (double b_j : truthSums) {
            addLTerm += pairs(b_j);
        }

        // Sum of pair counts over the cluster marginals and over every cell.
        double addCTerm = 0.0;
        double sumAllTable = 0.0;
        for (int i = 0; i < table.length; ++i) {
            addCTerm += pairs(clusterSums[i]);
            for (double n_ij : table[i]) {
                sumAllTable += pairs(n_ij);
            }
        }

        // Chance-correction term addCTerm * addLTerm / totalPairs. Dividing
        // before multiplying keeps the intermediate value small, and plain
        // arithmetic (unlike exp/log) stays defined for zero or negative terms.
        double longMultTerm = addCTerm * (addLTerm / totalPairs);
        double denominator = addCTerm / 2.0 + addLTerm / 2.0 - longMultTerm;
        if (denominator == 0.0) {
            // 0/0 case: treated as perfect agreement (ARI = 1) instead of NaN.
            return 0.0;
        }
        return 1.0 - (sumAllTable - longMultTerm) / denominator;
    }

    /**
     * Not supported: this overload carries no class label argument.
     *
     * @param dataSets the data points grouped by cluster (unused)
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public double evaluate(List<List<DataPoint>> dataSets) {
        throw new UnsupportedOperationException("Adjusted Rand Index requires the true data set labels, call evaluate(int[] designations, DataSet dataSet) instead");
    }

    /**
     * Returns a new, independent evaluator; this class declares no fields,
     * so there is nothing to copy.
     *
     * @return a new {@code AdjustedRandIndex}
     */
    @Override
    public ClusterEvaluation clone() {
        return new AdjustedRandIndex();
    }

    /**
     * Expanded form of "x choose 2", which is also usable for the
     * non-integer totals that weighted data points produce.
     *
     * @param x a (possibly fractional) weighted count
     * @return {@code x * (x - 1) / 2}
     */
    private static double pairs(double x) {
        return x * (x - 1.0) / 2.0;
    }
}

