/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.SingleClassifierEnhancer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Discretize;
import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;

public class MIBoost
extends SingleClassifierEnhancer
implements OptionHandler,
MultiInstanceCapabilitiesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -3808427225599279539L;
    protected Classifier[] m_Models;
    protected int m_NumClasses;
    protected int[] m_Classes;
    protected Instances m_Attributes;
    private int m_NumIterations = 100;
    protected double[] m_Beta;
    protected int m_MaxIterations = 10;
    protected int m_DiscretizeBin = 0;
    protected Discretize m_Filter = null;
    protected MultiInstanceToPropositional m_ConvertToSI = new MultiInstanceToPropositional();

    public String globalInfo() {
        return "MI AdaBoost method, considers the geometric mean of posterior of instances inside a bag (arithmatic mean of log-posterior) and the expectation for a bag is taken inside the loss function.\n\nFor more information about Adaboost, see:\n\n" + this.getTechnicalInformation().toString();
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Yoav Freund and Robert E. Schapire");
        result.setValue(TechnicalInformation.Field.TITLE, "Experiments with a new boosting algorithm");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Thirteenth International Conference on Machine Learning");
        result.setValue(TechnicalInformation.Field.YEAR, "1996");
        result.setValue(TechnicalInformation.Field.PAGES, "148-156");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "Morgan Kaufmann");
        result.setValue(TechnicalInformation.Field.ADDRESS, "San Francisco");
        return result;
    }

    public Enumeration<Option> listOptions() {
        Vector<Object> result = new Vector<Object>();
        result.addElement(new Option("\tThe number of bins in discretization\n\t(default 0, no discretization)", "B", 1, "-B <num>"));
        result.addElement(new Option("\tMaximum number of boost iterations.\n\t(default 10)", "R", 1, "-R <num>"));
        result.addAll(Collections.list(super.listOptions()));
        return result.elements();
    }

    public void setOptions(String[] options) throws Exception {
        this.setDebug(Utils.getFlag((char)'D', (String[])options));
        String bin = Utils.getOption((char)'B', (String[])options);
        if (bin.length() != 0) {
            this.setDiscretizeBin(Integer.parseInt(bin));
        } else {
            this.setDiscretizeBin(0);
        }
        String boostIterations = Utils.getOption((char)'R', (String[])options);
        if (boostIterations.length() != 0) {
            this.setMaxIterations(Integer.parseInt(boostIterations));
        } else {
            this.setMaxIterations(10);
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions((String[])options);
    }

    public String[] getOptions() {
        Vector<String> result = new Vector<String>(4);
        result.add("-R");
        result.add("" + this.getMaxIterations());
        result.add("-B");
        result.add("" + this.getDiscretizeBin());
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[result.size()]);
    }

    public String maxIterationsTipText() {
        return "The maximum number of boost iterations.";
    }

    public void setMaxIterations(int maxIterations) {
        this.m_MaxIterations = maxIterations;
    }

    public int getMaxIterations() {
        return this.m_MaxIterations;
    }

    public String discretizeBinTipText() {
        return "The number of bins in discretization.";
    }

    public void setDiscretizeBin(int bin) {
        this.m_DiscretizeBin = bin;
    }

    public int getDiscretizeBin() {
        return this.m_DiscretizeBin;
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        result.disable(Capabilities.Capability.MISSING_VALUES);
        result.disableAllClasses();
        result.disableAllClassDependencies();
        if (super.getCapabilities().handles(Capabilities.Capability.BINARY_CLASS)) {
            result.enable(Capabilities.Capability.BINARY_CLASS);
        }
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return result;
    }

    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    public void buildClassifier(Instances exps) throws Exception {
        this.getCapabilities().testWithFail(exps);
        Instances train = new Instances(exps);
        train.deleteWithMissingClass();
        this.m_NumClasses = train.numClasses();
        this.m_NumIterations = this.m_MaxIterations;
        if (this.m_Classifier == null) {
            throw new Exception("A base classifier has not been specified!");
        }
        if (!(this.m_Classifier instanceof WeightedInstancesHandler)) {
            throw new Exception("Base classifier cannot handle weighted instances!");
        }
        this.m_Models = AbstractClassifier.makeCopies((Classifier)this.m_Classifier, (int)this.getMaxIterations());
        if (this.m_Debug) {
            System.err.println("Base classifier: " + this.m_Classifier.getClass().getName());
        }
        this.m_Beta = new double[this.m_NumIterations];
        double N = train.numInstances();
        double sumNi = 0.0;
        int i = 0;
        while ((double)i < N) {
            int nn = train.instance(i).relationalValue(1).numInstances();
            sumNi += (double)nn;
            ++i;
        }
        i = 0;
        while ((double)i < N) {
            train.instance(i).setWeight(sumNi / N);
            ++i;
        }
        this.m_ConvertToSI.setInputFormat(train);
        Instances data = Filter.useFilter((Instances)train, (Filter)this.m_ConvertToSI);
        data.deleteAttributeAt(0);
        if (this.m_DiscretizeBin > 0) {
            this.m_Filter = new Discretize();
            this.m_Filter.setInputFormat(new Instances(data, 0));
            this.m_Filter.setBins(this.m_DiscretizeBin);
            data = Filter.useFilter((Instances)data, (Filter)this.m_Filter);
        }
        for (int m = 0; m < this.m_MaxIterations; ++m) {
            Instance exr;
            this.m_Models[m].buildClassifier(data);
            double[] err = new double[(int)N];
            double[] weights = new double[(int)N];
            boolean perfect = true;
            boolean tooWrong = true;
            int dataIdx = 0;
            int n = 0;
            while ((double)n < N) {
                Instance exn = train.instance(n);
                double nn = exn.relationalValue(1).numInstances();
                int p = 0;
                while ((double)p < nn) {
                    Instance testIns;
                    if ((int)this.m_Models[m].classifyInstance(testIns = data.instance(dataIdx++)) != (int)exn.classValue()) {
                        int n2 = n;
                        err[n2] = err[n2] + 1.0;
                    }
                    ++p;
                }
                weights[n] = exn.weight();
                int n3 = n;
                err[n3] = err[n3] / nn;
                if (err[n] > 0.5) {
                    perfect = false;
                }
                if (err[n] < 0.5) {
                    tooWrong = false;
                }
                ++n;
            }
            if (perfect || tooWrong) {
                this.m_Beta[m] = m == 0 ? 1.0 : 0.0;
                this.m_NumIterations = m + 1;
                if (!this.m_Debug) break;
                System.err.println("No errors");
                break;
            }
            double[] x = new double[]{0.0};
            double[][] b = new double[2][x.length];
            b[0][0] = Double.NaN;
            b[1][0] = Double.NaN;
            OptEng opt = new OptEng();
            opt.setWeights(weights);
            opt.setErrs(err);
            if (this.m_Debug) {
                System.out.println("Start searching for c... ");
            }
            x = opt.findArgmin(x, b);
            while (x == null) {
                x = opt.getVarbValues();
                if (this.m_Debug) {
                    System.out.println("200 iterations finished, not enough!");
                }
                x = opt.findArgmin(x, b);
            }
            if (this.m_Debug) {
                System.out.println("Finished.");
            }
            this.m_Beta[m] = x[0];
            if (this.m_Debug) {
                System.err.println("c = " + this.m_Beta[m]);
            }
            if (Double.isInfinite(this.m_Beta[m]) || Utils.smOrEq((double)this.m_Beta[m], (double)0.0)) {
                this.m_Beta[m] = m == 0 ? 1.0 : 0.0;
                this.m_NumIterations = m + 1;
                if (!this.m_Debug) break;
                System.err.println("Errors out of range!");
                break;
            }
            dataIdx = 0;
            double totWeights = 0.0;
            int r = 0;
            while ((double)r < N) {
                exr = train.instance(r);
                exr.setWeight(weights[r] * Math.exp(this.m_Beta[m] * (2.0 * err[r] - 1.0)));
                totWeights += exr.weight();
                ++r;
            }
            if (this.m_Debug) {
                System.err.println("Total weights = " + totWeights);
            }
            r = 0;
            while ((double)r < N) {
                exr = train.instance(r);
                double num = exr.relationalValue(1).numInstances();
                exr.setWeight(sumNi * exr.weight() / totWeights);
                int s = 0;
                while ((double)s < num) {
                    Instance inss = data.instance(dataIdx);
                    inss.setWeight(exr.weight() / num);
                    if (Double.isNaN(inss.weight())) {
                        throw new Exception("instance " + s + " in bag " + r + " has weight NaN!");
                    }
                    ++dataIdx;
                    ++s;
                }
                ++r;
            }
        }
    }

    public double[] distributionForInstance(Instance exmp) throws Exception {
        double[] rt = new double[this.m_NumClasses];
        Instances insts = new Instances(exmp.dataset(), 0);
        insts.add(exmp);
        insts = Filter.useFilter((Instances)insts, (Filter)this.m_ConvertToSI);
        insts.deleteAttributeAt(0);
        double n = insts.numInstances();
        if (this.m_DiscretizeBin > 0) {
            insts = Filter.useFilter((Instances)insts, (Filter)this.m_Filter);
        }
        int y = 0;
        while ((double)y < n) {
            Instance ins = insts.instance(y);
            for (int x = 0; x < this.m_NumIterations; ++x) {
                int n2 = (int)this.m_Models[x].classifyInstance(ins);
                rt[n2] = rt[n2] + this.m_Beta[x] / n;
            }
            ++y;
        }
        for (int i = 0; i < rt.length; ++i) {
            rt[i] = Math.exp(rt[i]);
        }
        Utils.normalize((double[])rt);
        return rt;
    }

    public String toString() {
        if (this.m_Models == null) {
            return "No model built yet!";
        }
        StringBuffer text = new StringBuffer();
        text.append("MIBoost: number of bins in discretization = " + this.m_DiscretizeBin + "\n");
        if (this.m_NumIterations == 0) {
            text.append("No model built yet.\n");
        } else if (this.m_NumIterations == 1) {
            text.append("No boosting possible, one classifier used: Weight = " + Utils.roundDouble((double)this.m_Beta[0], (int)2) + "\n");
            text.append("Base classifiers:\n" + this.m_Models[0].toString());
        } else {
            text.append("Base classifiers and their weights: \n");
            for (int i = 0; i < this.m_NumIterations; ++i) {
                text.append("\n\n" + i + ": Weight = " + Utils.roundDouble((double)this.m_Beta[i], (int)2) + "\nBase classifier:\n" + this.m_Models[i].toString());
            }
        }
        text.append("\n\nNumber of performed Iterations: " + this.m_NumIterations + "\n");
        return text.toString();
    }

    public String getRevision() {
        return RevisionUtils.extract((String)"$Revision: 10369 $");
    }

    public static void main(String[] argv) {
        MIBoost.runClassifier((Classifier)new MIBoost(), (String[])argv);
    }

    private class OptEng
    extends Optimization {
        private double[] weights;
        private double[] errs;

        private OptEng() {
        }

        public void setWeights(double[] w) {
            this.weights = w;
        }

        public void setErrs(double[] e) {
            this.errs = e;
        }

        protected double objectiveFunction(double[] x) throws Exception {
            double obj = 0.0;
            for (int i = 0; i < this.weights.length; ++i) {
                if (!Double.isNaN(obj += this.weights[i] * Math.exp(x[0] * (2.0 * this.errs[i] - 1.0)))) continue;
                throw new Exception("Objective function value is NaN!");
            }
            return obj;
        }

        protected double[] evaluateGradient(double[] x) throws Exception {
            double[] grad = new double[1];
            for (int i = 0; i < this.weights.length; ++i) {
                grad[0] = grad[0] + this.weights[i] * (2.0 * this.errs[i] - 1.0) * Math.exp(x[0] * (2.0 * this.errs[i] - 1.0));
                if (!Double.isNaN(grad[0])) continue;
                throw new Exception("Gradient is NaN!");
            }
            return grad;
        }

        public String getRevision() {
            return RevisionUtils.extract((String)"$Revision: 10369 $");
        }
    }
}

