/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.mi.TLDSimple_Optm;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/**
 * Multi-instance classifier implementing a simplified TLD model. For every
 * attribute inside a bag, each class is described by a within-bag variance
 * sigma^2 that is fixed and estimated from the training bags, and by a pair of
 * parameters (w, m) fitted with {@code TLDSimple_Optm}. A bag is scored by the
 * difference of its log-likelihood terms under the positive and the negative
 * class model and compared against a cut-off.
 *
 * <p>The bag contents are read from the relational attribute at index 1;
 * {@link #getCapabilities()} requests ONLY_MULTIINSTANCE data.
 */
public class TLDSimple
extends RandomizableClassifier
implements OptionHandler,
MultiInstanceCapabilitiesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = 9040995947243286591L;
    /** Per positive bag: mean of each attribute (NaN when all values are missing). */
    protected double[][] m_MeanP = null;
    /** Per negative bag: mean of each attribute (NaN when all values are missing). */
    protected double[][] m_MeanN = null;
    /** Per positive bag: total weight of the non-missing values of each attribute. */
    protected double[][] m_SumP = null;
    /** Per negative bag: total weight of the non-missing values of each attribute. */
    protected double[][] m_SumN = null;
    /** Pooled within-bag variance (sigma^2) of each attribute, positive class. */
    protected double[] m_SgmSqP;
    /** Pooled within-bag variance (sigma^2) of each attribute, negative class. */
    protected double[] m_SgmSqN;
    /** Fitted positive-class parameters, laid out as [w_0, m_0, w_1, m_1, ...]. */
    protected double[] m_ParamsP = null;
    /** Fitted negative-class parameters, laid out as [w_0, m_0, w_1, m_1, ...]. */
    protected double[] m_ParamsN = null;
    /** Number of attributes inside a bag. */
    protected int m_Dimension = 0;
    /** Not read or written by this class. */
    protected double[] m_Class = null;
    /** Not read by this class; the problem is binary. */
    protected int m_NumClasses = 2;
    /** Tolerance: lower bound for w and threshold for treating a variance as zero. */
    public static double ZERO = 1.0E-12;
    /** Number of optimisation runs (random restarts) per attribute and class. */
    protected int m_Run = 1;
    /** A bag whose log-odds score exceeds this value is classified as positive. */
    protected double m_Cutoff;
    /** Whether the cut-off is chosen empirically on the training bags. */
    protected boolean m_UseEmpiricalCutOff = false;
    /**
     * Running sum of per-attribute log-likelihood differences. It is updated on
     * every scoring call but never read by this class; kept so the serialized
     * form stays unchanged.
     */
    private double[] m_LkRatio;
    /** Structure of the bag-internal relation, used for attribute names. */
    private Instances m_Attribute = null;

    /**
     * Returns a description of the classifier for the GUI.
     *
     * @return the description, including the technical reference
     */
    public String globalInfo() {
        return "A simpler version of TLD, mu random but sigma^2 fixed and estimated via data.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference this classifier is based on.
     *
     * @return the technical information
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.MASTERSTHESIS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Xin Xu");
        result.setValue(TechnicalInformation.Field.YEAR, "2003");
        result.setValue(TechnicalInformation.Field.TITLE, "Statistical learning in multiple instance problem");
        result.setValue(TechnicalInformation.Field.SCHOOL, "University of Waikato");
        result.setValue(TechnicalInformation.Field.ADDRESS, "Hamilton, NZ");
        result.setValue(TechnicalInformation.Field.NOTE, "0657.594");
        return result;
    }

    /**
     * Returns the capabilities for the outer (bag-level) data.
     *
     * @return nominal/relational attributes, binary class, missing class values,
     *         multi-instance data only
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.BINARY_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return result;
    }

    /**
     * Returns the capabilities for the data inside the bags.
     *
     * @return nominal/numeric/date attributes with missing values, no class
     */
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    /**
     * Builds the classifier in three steps:
     * <ol>
     * <li>it estimates per-bag and per-class statistics;</li>
     * <li>it fits (w, m) for every attribute and both classes;</li>
     * <li>it sets the cut-off.</li>
     * </ol>
     * Bags with class value 1 are treated as positive, all others as negative.
     *
     * @param exs the training bags; the method works on a copy
     * @throws IllegalArgumentException if, after removing bags with a missing
     *         class, one of the two classes has no bags
     * @throws Exception if the data is not supported or the optimisation fails
     */
    public void buildClassifier(Instances exs) throws Exception {
        this.getCapabilities().testWithFail(exs);
        exs = new Instances(exs);
        exs.deleteWithMissingClass();

        this.m_Dimension = exs.attribute(1).relation().numAttributes();
        this.m_Attribute = exs.attribute(1).relation().stringFreeStructure();

        Instances pos = new Instances(exs, 0);
        Instances neg = new Instances(exs, 0);
        for (int i = 0; i < exs.numInstances(); ++i) {
            Instance example = exs.instance(i);
            if (example.classValue() == 1.0) {
                pos.add(example);
            } else {
                neg.add(example);
            }
        }
        int pnum = pos.numInstances();
        int nnum = neg.numInstances();
        // Both class models need at least one bag; without this check the
        // restart sampling fails with Random.nextInt(0) and the prior cut-off
        // becomes infinite.
        if (pnum == 0 || nnum == 0) {
            throw new IllegalArgumentException("TLDSimple needs at least one bag of each class (positive bags: "
                    + pnum + ", negative bags: " + nnum + ").");
        }

        this.m_MeanP = new double[pnum][this.m_Dimension];
        this.m_SumP = new double[pnum][this.m_Dimension];
        this.m_MeanN = new double[nnum][this.m_Dimension];
        this.m_SumN = new double[nnum][this.m_Dimension];
        this.m_ParamsP = new double[2 * this.m_Dimension];
        this.m_ParamsN = new double[2 * this.m_Dimension];
        this.m_SgmSqP = new double[this.m_Dimension];
        this.m_SgmSqN = new double[this.m_Dimension];

        double[] pMM = new double[this.m_Dimension];
        double[] pVM = new double[this.m_Dimension];
        double[] nMM = new double[this.m_Dimension];
        double[] nVM = new double[this.m_Dimension];
        collectBagStatistics(pos, this.m_Dimension, this.m_MeanP, this.m_SumP, this.m_SgmSqP, pMM, pVM);
        collectBagStatistics(neg, this.m_Dimension, this.m_MeanN, this.m_SumN, this.m_SgmSqN, nMM, nVM);

        // One generator shared by all attributes and both classes, so the
        // restart sequence is reproducible from m_Seed.
        Random whichEx = new Random(this.m_Seed);
        for (int x = 0; x < this.m_Dimension; ++x) {
            // NOTE(review): assumes TLDSimple_Optm follows weka.core.Optimization's
            // layout: bounds[0] = lower bounds, bounds[1] = upper bounds, NaN =
            // unconstrained. Here only w (variable 0) is bounded below by ZERO.
            double[][] bounds = new double[2][2];
            bounds[0][0] = ZERO;
            bounds[0][1] = Double.NaN;
            bounds[1][0] = Double.NaN;
            bounds[1][1] = Double.NaN;

            this.fitParams(x, column(this.m_SumP, x), column(this.m_MeanP, x), this.m_SgmSqP[x],
                    this.m_MeanP, startParams(pVM[x], pMM[x]), bounds, whichEx, this.m_ParamsP, "m_SgmSqP");
            this.fitParams(x, column(this.m_SumN, x), column(this.m_MeanN, x), this.m_SgmSqN[x],
                    this.m_MeanN, startParams(nVM[x], nMM[x]), bounds, whichEx, this.m_ParamsN, "m_SgmSqN");
        }

        this.m_LkRatio = new double[this.m_Dimension];
        if (this.m_UseEmpiricalCutOff) {
            double[] pLogOdds = new double[pnum];
            double[] nLogOdds = new double[nnum];
            for (int p = 0; p < pnum; ++p) {
                pLogOdds[p] = this.likelihoodRatio(this.m_SumP[p], this.m_MeanP[p]);
            }
            for (int q = 0; q < nnum; ++q) {
                nLogOdds[q] = this.likelihoodRatio(this.m_SumN[q], this.m_MeanN[q]);
            }
            this.findCutOff(pLogOdds, nLogOdds);
        } else {
            // threshold log(nnum / pnum), derived from the class proportions
            this.m_Cutoff = -Math.log((double) pnum / (double) nnum);
        }
        if (this.getDebug()) {
            System.err.println("\n\n???Cut-off=" + this.m_Cutoff);
        }
    }

    /**
     * Accumulates the statistics of the bags of one class.
     *
     * @param bags the bags of one class (relational attribute at index 1)
     * @param dim number of attributes inside a bag
     * @param mean output: per-bag mean/mode of each attribute
     * @param sum output: per-bag total weight of the non-missing values
     * @param sgmSq output: pooled within-bag variance per attribute (0 when no
     *        bag has more than one unit of weight and a positive variance)
     * @param meanOfMeans output: average of the non-NaN bag means per attribute
     * @param varOfMeans output: sample variance of the non-NaN bag means per
     *        attribute (NaN or infinite when only one bag contributes)
     */
    private static void collectBagStatistics(Instances bags, int dim, double[][] mean, double[][] sum,
            double[] sgmSq, double[] meanOfMeans, double[] varOfMeans) {
        double[] effNum = new double[dim];
        double[] invN = new double[dim];
        double[] numOneIns = new double[dim];
        for (int v = 0; v < bags.numInstances(); ++v) {
            Instances bag = bags.instance(v).relationalValue(1);
            double[] variance = new double[dim];
            for (int k = 0; k < bag.numAttributes(); ++k) {
                mean[v][k] = bag.meanOrMode(k);
                variance[k] = bag.variance(k);
            }
            for (int w = 0; w < dim; ++w) {
                if (variance[w] <= 0.0) {
                    variance[w] = 0.0;
                }
                if (Double.isNaN(mean[v][w])) {
                    continue; // attribute entirely missing in this bag
                }
                for (int u = 0; u < bag.numInstances(); ++u) {
                    if (!bag.instance(u).isMissing(w)) {
                        sum[v][w] += bag.instance(u).weight();
                    }
                }
                meanOfMeans[w] += mean[v][w];
                varOfMeans[w] += mean[v][w] * mean[v][w];
                if (sum[v][w] > 1.0 && variance[w] > ZERO) {
                    // rescale the bag's variance by (n - 1) / n before pooling
                    sgmSq[w] += variance[w] * (sum[v][w] - 1.0) / sum[v][w];
                    effNum[w] += 1.0;
                    invN[w] += 1.0 / sum[v][w];
                } else {
                    numOneIns[w] += 1.0;
                }
            }
        }
        for (int u = 0; u < dim; ++u) {
            sgmSq[u] = sgmSq[u] != 0.0 ? sgmSq[u] / (effNum[u] - invN[u]) : 0.0;
            effNum[u] += numOneIns[u];
            meanOfMeans[u] /= effNum[u];
            varOfMeans[u] = varOfMeans[u] / (effNum[u] - 1.0)
                    - meanOfMeans[u] * meanOfMeans[u] * effNum[u] / (effNum[u] - 1.0);
        }
    }

    /**
     * Builds the optimiser's starting point (w, m) for one attribute.
     *
     * @param varOfMeans variance of the bag means; replaced by 1.0 when it is not
     *        greater than ZERO (this includes NaN, e.g. when only one bag exists)
     * @param meanOfMeans mean of the bag means
     * @return a new two-element array {w, m}
     */
    private static double[] startParams(double varOfMeans, double meanOfMeans) {
        return new double[] {varOfMeans > ZERO ? varOfMeans : 1.0, meanOfMeans};
    }

    /**
     * Extracts one column of a row-major matrix.
     *
     * @param rows the matrix
     * @param col the column index
     * @return a new array holding rows[i][col] for every row i
     */
    private static double[] column(double[][] rows, int col) {
        double[] out = new double[rows.length];
        for (int i = 0; i < rows.length; ++i) {
            out[i] = rows[i][col];
        }
        return out;
    }

    /**
     * Fits the (w, m) parameters of one attribute for one class and keeps the
     * best (lowest-minimum) of m_Run optimisation runs. A run whose minimum is
     * NaN does not count and is repeated. After each run, the next starting
     * point is derived from the mean of a randomly chosen bag that has a value
     * for the attribute. If no bag has such a value, no restart is possible and
     * the search stops after the first run instead of looping forever.
     *
     * @param x attribute index
     * @param sum per-bag weight of the non-missing values of attribute x
     * @param mean per-bag mean of attribute x
     * @param sgmSq the class's within-bag variance of attribute x
     * @param bagMeans per-bag means of all attributes (source of restart points)
     * @param start initial {w, m}
     * @param bounds optimisation bounds handed to the optimiser
     * @param rnd random source for choosing restart bags
     * @param params output array; receives w at 2*x and m at 2*x+1
     * @param label variable name printed with sgmSq in debug mode
     * @throws Exception if the optimiser fails
     */
    private void fitParams(int x, double[] sum, double[] mean, double sgmSq, double[][] bagMeans,
            double[] start, double[][] bounds, Random rnd, double[] params, String label) throws Exception {
        boolean canRestart = false;
        for (double[] bagMean : bagMeans) {
            if (!Double.isNaN(bagMean[x])) {
                canRestart = true;
                break;
            }
        }
        double minVal = Double.MAX_VALUE;
        double[] current = start;
        for (int run = 0; run < this.m_Run; ++run) {
            TLDSimple_Optm opt = new TLDSimple_Optm();
            opt.setNum(sum);
            opt.setSgmSq(sgmSq);
            if (this.getDebug()) {
                System.out.println(label + "[" + x + "]= " + sgmSq);
            }
            opt.setXBar(mean);
            current = opt.findArgmin(current, bounds);
            while (current == null) {
                // iteration limit reached: resume from where the optimiser stopped
                current = opt.getVarbValues();
                if (this.getDebug()) {
                    System.out.println("!!! 200 iterations finished, not enough!");
                }
                current = opt.findArgmin(current, bounds);
            }
            double thisMin = opt.getMinFunction();
            if (Double.isNaN(thisMin)) {
                current = new double[2];
                --run; // invalid run: repeat it
            } else if (thisMin < minVal) {
                minVal = thisMin;
                params[2 * x] = current[0];
                params[2 * x + 1] = current[1];
            }
            if (!canRestart) {
                break;
            }
            int pick = rnd.nextInt(bagMeans.length);
            while (Double.isNaN(bagMeans[pick][x])) {
                pick = rnd.nextInt(bagMeans.length);
            }
            double m = bagMeans[pick][x];
            current[0] = (m - current[1]) * (m - current[1]);
            current[1] = m;
        }
    }

    /**
     * Classifies a bag: positive (1) if its log-odds score exceeds the
     * cut-off, otherwise negative (0).
     *
     * @param ex the bag to classify
     * @return 1.0 or 0.0
     * @throws Exception declared for the Classifier contract
     */
    public double classifyInstance(Instance ex) throws Exception {
        return this.bagLogOdds(ex) > this.m_Cutoff ? 1.0 : 0.0;
    }

    /**
     * Returns the class distribution of a bag as the logistic function of its
     * log-odds score. The cut-off is not used here.
     *
     * @param ex the bag
     * @return {P(negative), P(positive)}
     * @throws Exception declared for the Classifier contract
     */
    public double[] distributionForInstance(Instance ex) throws Exception {
        double logOdds = this.bagLogOdds(ex);
        double[] distribution = new double[2];
        distribution[0] = 1.0 / (1.0 + Math.exp(logOdds));
        distribution[1] = 1.0 - distribution[0];
        return distribution;
    }

    /**
     * Summarises a bag and scores it. The bag is reduced to per-attribute means
     * and per-attribute total weights of non-missing values, which are then
     * passed to likelihoodRatio().
     *
     * @param ex the bag (relational attribute at index 1)
     * @return the log-odds score
     */
    private double bagLogOdds(Instance ex) {
        Instances bag = ex.relationalValue(1);
        double[] n = new double[this.m_Dimension];
        double[] xBar = new double[this.m_Dimension];
        for (int i = 0; i < bag.numAttributes(); ++i) {
            xBar[i] = bag.meanOrMode(i);
        }
        for (int w = 0; w < this.m_Dimension; ++w) {
            for (int u = 0; u < bag.numInstances(); ++u) {
                if (!bag.instance(u).isMissing(w)) {
                    n[w] += bag.instance(u).weight();
                }
            }
        }
        return this.likelihoodRatio(n, xBar);
    }

    /**
     * Scores a bag from its per-attribute summary. For every attribute whose
     * bag mean is not NaN, the term
     * {@code log(w*n + sigma^2) + n*(m - xBar)^2 / (w*n + sigma^2)} is computed
     * with the positive and with the negative parameters. The score is the
     * difference of the negated sums (positive minus negative), divided by
     * m_Dimension. Larger values favour the positive class. Each call also
     * adds the per-attribute term differences to m_LkRatio.
     *
     * @param n per-attribute total weight of the non-missing values
     * @param xBar per-attribute bag mean (NaN attributes are skipped)
     * @return the log-odds score
     */
    private double likelihoodRatio(double[] n, double[] xBar) {
        double LLP = 0.0;
        double LLN = 0.0;
        for (int x = 0; x < this.m_Dimension; ++x) {
            if (Double.isNaN(xBar[x])) {
                continue;
            }
            double w = this.m_ParamsP[2 * x];
            double m = this.m_ParamsP[2 * x + 1];
            double varP = w * n[x] + this.m_SgmSqP[x];
            double llp = Math.log(varP) + n[x] * (m - xBar[x]) * (m - xBar[x]) / varP;
            LLP -= llp;
            w = this.m_ParamsN[2 * x];
            m = this.m_ParamsN[2 * x + 1];
            double varN = w * n[x] + this.m_SgmSqN[x];
            double lln = Math.log(varN) + n[x] * (m - xBar[x]) * (m - xBar[x]) / varN;
            LLN -= lln;
            this.m_LkRatio[x] += llp - lln;
        }
        // Parenthesised so the 1/m_Dimension normalisation applies to both class
        // terms. Without the parentheses only LLN is scaled, and two identical
        // class models would not give a zero score.
        return (LLP - LLN) / (double) this.m_Dimension;
    }

    /**
     * Chooses the cut-off that maximises the number of correctly classified
     * training bags, breaking ties by the split closest to 0. If every negative
     * score is at most the lowest positive score, the cut-off is the midpoint
     * between the highest negative and the lowest positive score. Both arrays
     * must be non-empty (guaranteed by buildClassifier).
     *
     * @param pos log-odds scores of the positive training bags
     * @param neg log-odds scores of the negative training bags
     */
    private void findCutOff(double[] pos, double[] neg) {
        int[] pOrder = Utils.sort(pos);
        int[] nOrder = Utils.sort(neg);
        int pNum = pos.length;
        int nNum = neg.length;
        int p = 0;
        int n = 0;
        double fstAccu = 0.0;   // negatives at or below the split (correct)
        double sndAccu = pNum;  // positives above the split (correct)
        double maxAccu = 0.0;
        double minDistTo0 = Double.MAX_VALUE;
        // count negatives scoring no higher than the lowest positive
        while (n < nNum && pos[pOrder[0]] >= neg[nOrder[n]]) {
            ++n;
            fstAccu += 1.0;
        }
        if (n >= nNum) {
            this.m_Cutoff = (neg[nOrder[nNum - 1]] + pos[pOrder[0]]) / 2.0;
            return;
        }
        while (p < pNum && n < nNum) {
            double split;
            if (pos[pOrder[p]] >= neg[nOrder[n]]) {
                fstAccu += 1.0;
                split = neg[nOrder[n]];
                ++n;
            } else {
                sndAccu -= 1.0;
                split = pos[pOrder[p]];
                ++p;
            }
            double accu = fstAccu + sndAccu;
            if (accu > maxAccu || (accu == maxAccu && Math.abs(split) < minDistTo0)) {
                maxAccu = accu;
                this.m_Cutoff = split;
                minDistTo0 = Math.abs(split);
            }
        }
    }

    /**
     * Lists the options of this classifier followed by those of the superclass.
     *
     * @return an enumeration of the available options
     */
    public Enumeration<Option> listOptions() {
        Vector<Option> result = new Vector<Option>();
        result.addElement(new Option("\tSet whether or not use empirical\n\tlog-odds cut-off instead of 0", "C", 0, "-C"));
        result.addElement(new Option("\tSet the number of multiple runs \n\tneeded for searching the MLE.", "R", 1, "-R <numOfRuns>"));
        result.addAll(Collections.list(super.listOptions()));
        return result.elements();
    }

    /**
     * Parses the options: -C (use the empirical cut-off) and -R (number of
     * runs, default 1), then the superclass options.
     *
     * @param options the option array; consumed options are removed
     * @throws Exception if an option is invalid or options remain unparsed
     */
    public void setOptions(String[] options) throws Exception {
        this.setUsingCutOff(Utils.getFlag('C', options));
        String runString = Utils.getOption('R', options);
        this.setNumRuns(runString.length() != 0 ? Integer.parseInt(runString) : 1);
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current option settings.
     *
     * @return the options, including those of the superclass
     */
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.getUsingCutOff()) {
            result.add("-C");
        }
        result.add("-R");
        result.add("" + this.getNumRuns());
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[result.size()]);
    }

    /**
     * @return tip text for the numRuns property
     */
    public String numRunsTipText() {
        return "The number of runs to perform.";
    }

    /**
     * @param numRuns number of optimisation runs per attribute and class
     */
    public void setNumRuns(int numRuns) {
        this.m_Run = numRuns;
    }

    /**
     * @return number of optimisation runs per attribute and class
     */
    public int getNumRuns() {
        return this.m_Run;
    }

    /**
     * @return tip text for the usingCutOff property
     */
    public String usingCutOffTipText() {
        return "Whether to use an empirical cutoff.";
    }

    /**
     * @param cutOff whether to choose the cut-off empirically on training data
     */
    public void setUsingCutOff(boolean cutOff) {
        this.m_UseEmpiricalCutOff = cutOff;
    }

    /**
     * @return whether the cut-off is chosen empirically on training data
     */
    public boolean getUsingCutOff() {
        return this.m_UseEmpiricalCutOff;
    }

    /**
     * Describes the fitted model: for every bag attribute, sigma^2, w and m of
     * the positive and the negative class.
     *
     * @return the model description, or a notice if no model has been built
     */
    public String toString() {
        if (this.m_ParamsP == null || this.m_ParamsN == null || this.m_SgmSqP == null
                || this.m_SgmSqN == null || this.m_Attribute == null) {
            return "TLDSimple: No model built yet.";
        }
        StringBuffer text = new StringBuffer("\n\nTLDSimple:\n");
        for (int x = 0; x < this.m_Dimension; ++x) {
            double sgm = this.m_SgmSqP[x];
            double w = this.m_ParamsP[2 * x];
            double m = this.m_ParamsP[2 * x + 1];
            text.append("\n" + this.m_Attribute.attribute(x).name() + "\nPositive: " + "sigma^2=" + sgm + ", w=" + w + ", m=" + m + "\n");
            sgm = this.m_SgmSqN[x];
            w = this.m_ParamsN[2 * x];
            m = this.m_ParamsN[2 * x + 1];
            text.append("Negative: sigma^2=" + sgm + ", w=" + w + ", m=" + m + "\n");
        }
        return text.toString();
    }

    /**
     * @return the revision string
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10369 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        runClassifier(new TLDSimple(), args);
    }
}

