/*
 * Decompiled with CFR 0.152.
 */
package meka.classifiers.multitarget;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import meka.classifiers.multilabel.Evaluation;
import meka.classifiers.multilabel.MultilabelClassifier;
import meka.classifiers.multitarget.CR;
import meka.classifiers.multitarget.MultiTargetClassifier;
import meka.core.A;
import meka.core.M;
import meka.core.MLEvalUtils;
import meka.core.MLUtils;
import meka.core.Result;
import meka.core.StatUtils;
import meka.core.SuperLabelUtils;
import meka.filters.multilabel.SuperNodeFilter;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/**
 * Super Class Classifier (SCC): a multi-target classifier that merges groups of target
 * variables into "super classes" and trains a multi-label / multi-target base classifier
 * ({@code m_Classifier}) on the transformed problem.
 * <p>
 * {@link #buildClassifier(Instances)} chooses the grouping (a partition of the target
 * indices) on an internal train/validation split:
 * <ol>
 * <li>A {@code CR} baseline is evaluated.</li>
 * <li>A pairwise matrix is computed from its validation result with
 * {@code StatUtils.condDepMatrix}.</li>
 * <li>A partition is searched by simulated annealing against that matrix.</li>
 * <li>Optionally ({@code -N}), the partition is refined by direct internal evaluation.</li>
 * </ol>
 * The final model is then trained on all of the data.
 */
public class SCC
extends MultilabelClassifier
implements Randomizable,
MultiTargetClassifier,
TechnicalInformationHandler {
    /** Filter producing the super-class representation for the current partition (replaced by trainClassifier). */
    private SuperNodeFilter f = new SuperNodeFilter();
    /** Pruning value passed to the filter's {@code setP} ({@code -P}); a negative value draws from [0, |P|). */
    private int m_P = 1;
    /** Number of internal-evaluation refinement iterations after annealing ({@code -N}); 0 disables refinement. */
    private int m_N = 0;
    /** Limit passed to the filter's {@code setN} ({@code -L}); a negative value draws from [0, |L|). */
    private int m_L = 2;
    /** Number of simulated annealing iterations ({@code -I}). */
    private int m_I = 1000;
    /** Open option ({@code -O}); only stored and round-tripped by the option methods, not used for learning here. */
    private int m_O = 0;
    /** Percentage of the shuffled training data used as the internal training split; the rest validates. */
    private static final int i_SPLIT = 67;
    /** Key of the {@code Result.output} metric used as the score to maximise. */
    private static final String i_ErrFn = "Exact match";
    /** Random source; reseeded from {@link #m_S} at the start of every build. */
    private Random rand = null;
    /** Random seed (see {@link #setSeed(int)} / {@link #getSeed()}). */
    protected int m_S = 0;

    @Override
    public String globalInfo() {
        return "Super Class Classifier (SCC).\nLike a multi-target-capable PS. Removes examples with P-infrequent labelsets from the training data, then makes super classes out of what's left; and then trains a standard ML classifier on them.\nFor more information see:\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for this method.
     *
     * @return a journal-article reference (IEEE Transactions on Knowledge and Data Engineering)
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        // Published in a journal, so ARTICLE is the matching entry type for the JOURNAL field below.
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Jesse Read, Concha Bielza, Pedro Larranaga");
        result.setValue(TechnicalInformation.Field.TITLE, "Multi-Dimensional Classification with Super-Classes");
        result.setValue(TechnicalInformation.Field.JOURNAL, "IEEE Transactions on Knowledge and Data Engineering");
        result.setValue(TechnicalInformation.Field.YEAR, "2013");
        return result;
    }

    /**
     * Scores a partition against a pairwise matrix, using a threshold of zero.
     *
     * @see #rating(int[][], double[][], double)
     */
    private double rating(int[][] partition, double[][] M2) {
        return this.rating(partition, M2, 0.0);
    }

    /**
     * Scores a partition of the target indices against a pairwise matrix.
     * <p>
     * Each unordered pair {@code j < k} contributes {@code M2[j][k] - CRITICAL}. The
     * contribution is added if {@code j} and {@code k} share a block of {@code partition}
     * and subtracted otherwise. Only the upper triangle of {@code M2} is read.
     * <p>
     * Side effect: every block of {@code partition} is sorted in place.
     *
     * @param partition blocks of target indices, each assumed to lie in {@code [0, M2.length)}
     * @param M2 square matrix of pairwise scores
     * @param CRITICAL threshold subtracted from every pairwise score
     * @return (sum over pairs kept together) minus (sum over pairs kept apart)
     */
    private double rating(int[][] partition, double[][] M2, double CRITICAL) {
        int L = M2.length;
        boolean[][] together = new boolean[L][L];
        for (int[] block : partition) {
            // After sorting, block[a] < block[b] for a < b, so only the upper triangle gets marked.
            Arrays.sort(block);
            for (int a = 0; a < block.length; ++a) {
                for (int b = a + 1; b < block.length; ++b) {
                    together[block[a]][block[b]] = true;
                }
            }
        }
        double sumTogether = 0.0;
        double sumApart = 0.0;
        for (int j = 0; j < L; ++j) {
            for (int k = j + 1; k < L; ++k) {
                double score = M2[j][k] - CRITICAL;
                if (together[j][k]) {
                    sumTogether += score;
                } else {
                    sumApart += score;
                }
            }
        }
        return sumTogether - sumApart;
    }

    /**
     * Produces a random neighbour of {@code partition} by moving one randomly chosen index.
     * <p>
     * A random element of a random block {@code from} is moved to a random block {@code to}.
     * If {@code to == from}, the element is instead split off into a new singleton block at
     * the end. A block left empty is removed.
     * <p>
     * The array passed in may be modified, so callers should pass a copy.
     *
     * @param partition current partition; every block is assumed to be non-empty
     * @param r random source
     * @return the mutated partition (possibly a new outer array)
     */
    private int[][] mutateCombinations(int[][] partition, Random r) {
        int from = r.nextInt(partition.length);
        int i = r.nextInt(partition[from].length);
        int to = r.nextInt(partition.length);
        int moved = partition[from][i];
        if (to == from) {
            partition = Arrays.copyOf(partition, partition.length + 1);
            partition[partition.length - 1] = new int[]{moved};
        } else {
            partition[to] = A.append(partition[to], moved);
        }
        partition[from] = A.delete(partition[from], i);
        if (partition[from].length == 0) {
            // Remove the emptied block: overwrite it with the last block and shrink the array.
            partition[from] = partition[partition.length - 1];
            partition = Arrays.copyOf(partition, partition.length - 1);
        }
        return partition;
    }

    /**
     * Transforms {@code D} into the super-class representation for {@code partition} and trains
     * {@code m_Classifier} on it.
     * <p>
     * Side effects:
     * <ul>
     * <li>replaces the filter with a new one configured for {@code partition};</li>
     * <li>sets {@code m_InstancesTemplate} to the transformed data.</li>
     * </ul>
     *
     * @param h ignored. {@code m_Classifier} is always the model trained, because
     *          {@link #distributionForInstance(Instance)} predicts with it.
     * @param D training data in the original target space
     * @param partition blocks of target indices to merge into super classes
     * @throws Exception if filtering or training fails
     */
    public void trainClassifier(Classifier h, Instances D, int[][] partition) throws Exception {
        if (this.rand == null) {
            // May be called before buildClassifier has seeded the generator.
            this.rand = new Random(this.m_S);
        }
        this.f = new SuperNodeFilter();
        this.f.setIndices(partition);
        // A negative option value means "draw uniformly from [0, |value|)".
        this.f.setP(this.m_P >= 0 ? this.m_P : this.rand.nextInt(Math.abs(this.m_P)));
        this.f.setN(this.m_L >= 0 ? this.m_L : this.rand.nextInt(Math.abs(this.m_L)));
        Instances D_ = this.f.process(D);
        if (this.getDebug()) {
            int N = D.numInstances();
            int U = MLUtils.numberOfUniqueCombinations(D);
            System.out.println("PS(" + this.f.getP() + "," + this.m_L + ") reduced: " + N + " -> " + D_.numInstances() + " / " + U + " -> " + MLUtils.numberOfUniqueCombinations(D_));
        }
        this.m_InstancesTemplate = D_;
        this.m_Classifier.buildClassifier(D_);
    }

    /**
     * Trains on {@code D_train} with {@code partition} (via {@link #trainClassifier}), then
     * evaluates {@code h} on {@code D_test}.
     * <p>
     * {@code h} is evaluated as-is and must be a {@code MultilabelClassifier}. To evaluate the
     * freshly trained model it should be this SCC instance, whose
     * {@link #distributionForInstance(Instance)} maps test instances into the super-class space.
     *
     * @return the evaluation result, with dataset and classifier info attached and
     *         {@code output} filled by {@code Result.getStats(result, "1")}
     * @throws Exception if training or evaluation fails
     */
    public Result testClassifier(Classifier h, Instances D_train, Instances D_test, int[][] partition) throws Exception {
        this.trainClassifier(this.m_Classifier, D_train, partition);
        Result result = Evaluation.testClassifier((MultilabelClassifier)h, D_test);
        if (h instanceof MultiTargetClassifier || Evaluation.isMT(D_test)) {
            result.setInfo("Type", "MT");
        } else if (h instanceof MultilabelClassifier) {
            result.setInfo("Threshold", MLEvalUtils.getThreshold(result.predictions, D_train, "PCut1"));
            result.setInfo("Type", "ML");
        }
        result.setValue("N_train", D_train.numInstances());
        result.setValue("N_test", D_test.numInstances());
        result.setValue("LCard_train", MLUtils.labelCardinality(D_train));
        result.setValue("LCard_test", MLUtils.labelCardinality(D_test));
        result.setInfo("Classifier_name", h.getClass().getName());
        result.setInfo("Classifier_info", h.toString());
        result.setInfo("Dataset_name", MLUtils.getDatasetName(D_test));
        result.output = Result.getStats(result, "1");
        return result;
    }

    /**
     * Chooses a partition of the targets on an internal split of {@code D}, then trains the
     * final model on all of {@code D} with that partition.
     *
     * @param D training data; its class index is the number of targets
     * @throws IllegalStateException if the base classifier is not a {@code MultilabelClassifier}
     * @throws Exception if any internal training or evaluation fails
     */
    @Override
    public void buildClassifier(Instances D) throws Exception {
        if (!(this.m_Classifier instanceof MultilabelClassifier)) {
            throw new IllegalStateException("SCC requires a multi-label / multi-target base classifier, but got "
                    + (this.m_Classifier == null ? "null" : this.m_Classifier.getClass().getName()));
        }
        int L = D.classIndex();
        this.rand = new Random(this.m_S);

        // 0. Internal split: the first i_SPLIT percent of the shuffled data trains, the remainder validates.
        Instances D_r = new Instances(D);
        D_r.randomize(this.rand);
        int nTrain = D_r.numInstances() * i_SPLIT / 100;
        Instances D_train = new Instances(D_r, 0, nTrain);
        Instances D_test = new Instances(D_r, nTrain, D_r.numInstances() - nTrain);

        // 1. Baseline: a CR model built from the base classifier's inner single-target learner.
        if (this.getDebug()) {
            System.out.print("1. BUILD & Evaluate BR: ");
        }
        CR cr = new CR();
        cr.setClassifier(((MultilabelClassifier)this.m_Classifier).getClassifier());
        Result result_1 = Evaluation.evaluateModel(cr, D_train, D_test, "PCut1", "5");
        double acc1 = result_1.output.get(i_ErrFn);
        if (this.getDebug()) {
            System.out.println(" " + acc1);
        }

        int[][] partition = SuperLabelUtils.generatePartition(MLUtils.gen_indices(L), this.rand);

        // 2. Pairwise matrix derived from the baseline's validation result.
        if (this.getDebug()) {
            System.out.println("2. GET ERR-CHI-SQUARED MATRIX: ");
        }
        double[][] MER = StatUtils.condDepMatrix(D_test, result_1);
        if (this.getDebug()) {
            System.out.println(M.toString(MER));
        }

        // 3. Search partitions against that matrix.
        if (this.getDebug()) {
            System.out.println("3. COMBINE NODES TO FIND THE BEST COMBINATION ACCORDING TO CHI");
        }
        partition = this.annealPartition(partition, MER);

        // 4. Optional refinement by evaluating the full SCC model on the validation split.
        if (this.m_N > 0) {
            partition = this.refinePartition(partition, D_train, D_test);
        }

        if (this.getDebug()) {
            System.out.println("4. TRAIN " + SuperLabelUtils.toString(partition));
        }
        this.trainClassifier(this.m_Classifier, D, partition);
    }

    /** Runs {@code m_I} simulated-annealing steps from {@code partition}, scoring candidates with rating(). */
    private int[][] annealPartition(int[][] partition, double[][] MER) {
        double w = this.rating(partition, MER);
        if (this.getDebug()) {
            System.out.println("@0 : " + SuperLabelUtils.toString(partition) + "\t(" + w + ")");
        }
        for (int i = 0; i < this.m_I; ++i) {
            int[][] candidate = this.mutateCombinations(M.deep_copy(partition), this.rand);
            double wCandidate = this.rating(candidate, MER);
            if (wCandidate > w) {
                partition = candidate;
                w = wCandidate;
                if (this.getDebug()) {
                    System.out.println("@" + i + " : " + SuperLabelUtils.toString(partition) + "\t(" + w + ")");
                }
            } else {
                // Accept a non-improving move with probability 2 * (1 - sigma(diff * i / 1000)).
                // This is 1 when diff * i == 0 and shrinks as iterations and the score gap grow.
                double diff = Math.abs(wCandidate - w);
                double p = 2.0 * (1.0 - SCC.sigma(diff * (double)i / 1000.0));
                if (p > this.rand.nextDouble()) {
                    if (this.getDebug()) {
                        System.out.println("@" + i + " : " + SuperLabelUtils.toString(candidate) + "\t(" + wCandidate + ")*");
                    }
                    partition = candidate;
                    w = wCandidate;
                }
            }
        }
        return partition;
    }

    /** Hill-climbs for {@code m_N} steps on the validation score of the full SCC model. */
    private int[][] refinePartition(int[][] partition, Instances D_train, Instances D_test) throws Exception {
        if (this.getDebug()) {
            System.out.println("4. REFINING THE INITIAL SET WITH SOME OLD-FASHIONED INTERNAL EVAL");
        }
        // Evaluate `this`, not the bare base classifier. The base classifier is trained in the
        // super-class space, so test instances must pass through distributionForInstance.
        double w = this.testClassifier(this, D_train, D_test, partition).output.get(i_ErrFn);
        if (this.getDebug()) {
            System.out.println("@0 : " + SuperLabelUtils.toString(partition) + "\t(" + w + ")");
        }
        for (int i = 0; i < this.m_N; ++i) {
            int[][] candidate = this.mutateCombinations(M.deep_copy(partition), this.rand);
            // testClassifier trains on the candidate partition itself; no separate training call is needed.
            double wCandidate = this.testClassifier(this, D_train, D_test, candidate).output.get(i_ErrFn);
            if (wCandidate > w) {
                w = wCandidate;
                partition = candidate;
                if (this.getDebug()) {
                    System.out.println("@" + (i + 1) + "' : " + SuperLabelUtils.toString(partition) + "\t(" + w + ")");
                }
            }
        }
        return partition;
    }

    /**
     * Predicts all targets of {@code x}.
     * <p>
     * The base classifier's super-class predictions are decoded back into per-target value
     * indices.
     *
     * @param x instance in the original target space (its class index is the number of targets L)
     * @return an array of length {@code 2 * L}:
     *         <ul>
     *         <li>{@code [0, L)}: index of the predicted value of each target;</li>
     *         <li>{@code [L, 2L)}: the value the base classifier reported in the second half of
     *         its output for the corresponding super class (presumably a confidence; confirm
     *         against the base classifier).</li>
     *         </ul>
     *         If the base classifier throws, all zeros are returned.
     * @throws Exception if mapping {@code x} onto the training template fails
     */
    @Override
    public double[] distributionForInstance(Instance x) throws Exception {
        int L = x.classIndex();
        double[] y = new double[L * 2];
        int L_ = this.m_InstancesTemplate.classIndex();
        Instance x_ = MLUtils.setTemplate(x, this.f.getTemplate(), this.m_InstancesTemplate);
        double[] y_;
        try {
            y_ = ((MultilabelClassifier)this.m_Classifier).distributionForInstance(x_);
        }
        catch (Exception e) {
            // Best effort: report the failure and fall back to the all-zero prediction.
            System.err.println("EXCEPTION !!! setting to " + Arrays.toString(y) + " (" + e + ")");
            return y;
        }
        for (int j = 0; j < L_; ++j) {
            // Super-class attribute j encodes which original targets it covers and their joint values.
            int[] idxs = SuperNodeFilter.decodeClasses(this.m_InstancesTemplate.attribute(j).name());
            String[] vals = SuperNodeFilter.decodeValue(this.m_InstancesTemplate.attribute(j).value((int)Math.round(y_[j])));
            for (int i = 0; i < idxs.length; ++i) {
                y[idxs[i]] = x.dataset().attribute(idxs[i]).indexOfValue(vals[i]);
                y[idxs[i] + L] = y_[j + L_];
            }
        }
        return y;
    }

    @Override
    public void setSeed(int s) {
        this.m_S = s;
    }

    @Override
    public int getSeed() {
        return this.m_S;
    }

    public static void main(String[] args) {
        MultilabelClassifier.evaluation(new SCC(), args);
    }

    /**
     * Logistic sigmoid.
     *
     * @param a input value
     * @return {@code 1 / (1 + e^(-a))}
     */
    public static final double sigma(double a) {
        return 1.0 / (1.0 + Math.exp(-a));
    }

    @Override
    public Enumeration listOptions() {
        Vector<Object> newVector = new Vector<Object>();
        newVector.addElement(new Option("\tSets the number of simulated annealing iterations\n\tdefault: " + this.m_I, "I", 1, "-I <value>"));
        newVector.addElement(new Option("\tSets the number of internal-evaluation refinement iterations\n\tdefault: " + this.m_N, "N", 1, "-N <value>"));
        newVector.addElement(new Option("\tSets the pruning number for PS\n\tdefault: " + this.m_P, "P", 1, "-P <value>"));
        newVector.addElement(new Option("\tSets the limit for PS (was N) \n\tdefault: " + this.m_L, "L", 1, "-L <value>"));
        newVector.addElement(new Option("\tAnother random open option.\n\tdefault: " + this.m_O, "O", 1, "-O <value>"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement(enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses {@code -I}, {@code -L}, {@code -N}, {@code -P} and {@code -O}; absent flags keep
     * their current values. Remaining options go to the superclass.
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.m_I = SCC.intOption('I', options, this.m_I);
        this.m_L = SCC.intOption('L', options, this.m_L);
        this.m_N = SCC.intOption('N', options, this.m_N);
        this.m_P = SCC.intOption('P', options, this.m_P);
        this.m_O = SCC.intOption('O', options, this.m_O);
        super.setOptions(options);
    }

    /** Reads (and removes) integer option {@code -flag} from {@code options}, or returns {@code current} if absent. */
    private static int intOption(char flag, String[] options, int current) throws Exception {
        return Utils.getOptionPos(flag, options) >= 0 ? Integer.parseInt(Utils.getOption(flag, options)) : current;
    }

    @Override
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 10];
        int current = 0;
        options[current++] = "-I";
        options[current++] = String.valueOf(this.m_I);
        options[current++] = "-N";
        options[current++] = String.valueOf(this.m_N);
        options[current++] = "-P";
        options[current++] = String.valueOf(this.m_P);
        options[current++] = "-L";
        options[current++] = String.valueOf(this.m_L);
        options[current++] = "-O";
        options[current++] = String.valueOf(this.m_O);
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        return options;
    }
}

