/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions.supportVector;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.functions.supportVector.CachedKernel;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.RemoveWithValues;

/**
 * Gaussian-type kernel K(x, y) = exp(-||x - y||^2 / (2 * sigma^2)).
 *
 * <p>Unless {@code -G} / {@link #setUseStandardSigma(boolean)} is set, the first call to
 * {@link #buildKernel(Instances)} estimates {@code sigma} from the training data. The estimate is
 * the median, over instances having the first class value, of the distance to the nearest
 * instance having another class value. Later builds reuse that estimate.
 */
public class FisherKernel extends CachedKernel {
    static final long serialVersionUID = 5247117544316387852L;
    /** Self dot products {@code <x_i, x_i>} of the training instances, indexed like the data passed to buildKernel. */
    protected double[] m_kernelPrecalc;
    /** Kernel width; starts at 7.0 and is replaced by the data-driven estimate unless useStandardSigma is set. */
    protected double sigma = 7.0;
    /** When true, keep the preset sigma instead of estimating it (option -G). */
    protected boolean useStandardSigma = false;
    /** Not referenced in this class; left in place because removing a field of a Serializable class can affect reading older serialized models. */
    private Instances dataToFilter;
    /** True until sigma has been estimated once. */
    private boolean first = true;

    /** Creates an unbuilt kernel. */
    public FisherKernel() {
    }

    /**
     * Creates the kernel and builds it on the given data.
     *
     * @param data the training data
     * @param cacheSize the kernel cache size
     * @throws Exception if the data is not supported or building fails
     */
    public FisherKernel(Instances data, int cacheSize) throws Exception {
        this.setCacheSize(cacheSize);
        this.buildKernel(data);
    }

    /**
     * Estimates sigma from m_data.
     *
     * <p>The estimate is the median, over instances whose class is the first class value, of the
     * Euclidean distance to the nearest instance whose class is not the first value. Instances
     * with a missing class take part in neither group. If either group is empty, no distance can
     * be formed and sigma is left unchanged.
     *
     * @throws Exception if filtering the data fails
     */
    private void calculateSigma() throws Exception {
        Instances labelled = new Instances(this.m_data);
        // An instance with an unknown label belongs to neither group.
        labelled.deleteWithMissingClass();
        Instances class1 = this.selectByFirstClassValue(labelled, true, "_class1");
        Instances class2 = this.selectByFirstClassValue(labelled, false, "_class2");
        if (class1.numInstances() == 0 || class2.numInstances() == 0) {
            // Nearest opposite-class distances are undefined; keep the current sigma.
            return;
        }
        // <y, y> for every class2 instance is reused for each class1 instance.
        double[] selfDot2 = new double[class2.numInstances()];
        for (int j = 0; j < selfDot2.length; ++j) {
            Instance inst2 = class2.instance(j);
            selfDot2[j] = this.dotProd(inst2, inst2);
        }
        double[] minDist = new double[class1.numInstances()];
        for (int i = 0; i < minDist.length; ++i) {
            Instance inst1 = class1.instance(i);
            double selfDot1 = this.dotProd(inst1, inst1);
            double best = Double.POSITIVE_INFINITY;
            for (int j = 0; j < selfDot2.length; ++j) {
                // ||x - y||^2 = -2<x,y> + <x,x> + <y,y>. Rounding can push it slightly below
                // zero, so clamp it to keep sqrt from returning NaN.
                double squared = -2.0 * this.dotProd(inst1, class2.instance(j)) + selfDot1 + selfDot2[j];
                best = Math.min(best, Math.sqrt(Math.max(0.0, squared)));
            }
            minDist[i] = best;
        }
        this.sigma = median(minDist);
        System.out.println("sigma: " + this.sigma);
    }

    /**
     * Selects, with RemoveWithValues on the class attribute, the instances whose class is the
     * first class value (keepFirst true) or is not the first class value (keepFirst false).
     *
     * @param data the data to filter; its class index must be set
     * @param keepFirst whether to keep, rather than drop, the instances having the first class value
     * @param nameSuffix appended to the relation name of the result
     * @return the selected instances
     * @throws Exception if the filter cannot be configured or applied
     */
    private Instances selectByFirstClassValue(Instances data, boolean keepFirst, String nameSuffix) throws Exception {
        String classIndex = data.classIndex() + 1 + "";
        // RemoveWithValues drops matching instances; -V inverts that so only matches are kept.
        String[] options = keepFirst
            ? new String[]{"-L", "1", "-C", classIndex, "-V"}
            : new String[]{"-L", "1", "-C", classIndex};
        RemoveWithValues filter = new RemoveWithValues();
        // Weka requires setInputFormat to be the last call before the filter is applied.
        filter.setOptions(options);
        filter.setInputFormat(data);
        Instances selected = Filter.useFilter(data, filter);
        selected.setRelationName(data.relationName() + nameSuffix);
        return selected;
    }

    /**
     * Returns the median of a non-empty array, sorting the array in place.
     *
     * @param values the values; must contain at least one element
     * @return the middle value, or the mean of the two middle values for even lengths
     */
    private static double median(double[] values) {
        Arrays.sort(values);
        int mid = values.length / 2;
        return values.length % 2 == 1 ? values[mid] : 0.5 * (values[mid] + values[mid - 1]);
    }

    /**
     * Returns a description of the kernel for the GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return "The Fisher kernel. K(x, y) = e^-((1/(2*sigma^2)) * <x-y, x-y>)";
    }

    /**
     * Lists the options of the superclass followed by {@code -G}.
     *
     * @return an enumeration of {@link Option} objects
     */
    public Enumeration listOptions() {
        Vector<Option> result = new Vector<Option>();
        Enumeration<?> en = super.listOptions();
        while (en.hasMoreElements()) {
            result.addElement((Option) en.nextElement());
        }
        // Always describe -G; the enumeration must contain Option objects, never raw strings.
        result.addElement(new Option(
            "\tUse the standard sigma (7.0) instead of estimating it\n\tfrom the training data.",
            "G", 0, "-G"));
        return result.elements();
    }

    /**
     * Parses {@code -G}, then passes the remaining options to the superclass.
     *
     * @param options the options; {@code -G} is consumed from the array
     * @throws Exception if an option is invalid
     */
    public void setOptions(String[] options) throws Exception {
        this.useStandardSigma = Utils.getFlag('G', options);
        super.setOptions(options);
    }

    /**
     * Returns the superclass options, plus {@code -G} when the standard sigma is used.
     *
     * @return the current options
     */
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        for (String option : super.getOptions()) {
            result.add(option);
        }
        if (this.useStandardSigma) {
            result.add("-G");
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Computes K(inst1, x_id2) = exp((2<x,y> - <x,x> - <y,y>) / (2 * sigma^2)).
     *
     * @param id1 index of inst1 in the training data, or -1 if inst1 is not a training instance
     * @param id2 index of the second (training) instance
     * @param inst1 the first instance
     * @return the kernel value; 1.0 when id1 equals id2
     * @throws Exception if the dot product cannot be computed
     */
    protected double evaluate(int id1, int id2, Instance inst1) throws Exception {
        if (id1 == id2) {
            return 1.0;
        }
        // Reuse the cached <x, x> for training instances; compute it for outside instances.
        double precalc1 = id1 == -1 ? this.dotProd(inst1, inst1) : this.m_kernelPrecalc[id1];
        Instance inst2 = this.m_data.instance(id2);
        return Math.exp(1.0 / (2.0 * this.sigma * this.sigma)
            * (2.0 * this.dotProd(inst1, inst2) - precalc1 - this.m_kernelPrecalc[id2]));
    }

    /**
     * Initializes the superclass state and allocates the self dot product cache.
     *
     * @param data the training data
     */
    protected void initVars(Instances data) {
        super.initVars(data);
        this.m_kernelPrecalc = new double[data.numInstances()];
    }

    /**
     * Adds numeric attributes, all class types and missing class values to the superclass capabilities.
     *
     * @return the capabilities of this kernel
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enableAllClasses();
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Builds the kernel for the given data.
     *
     * <p>The steps are: check capabilities (unless checks are off), initialize state, estimate
     * sigma on the first build (unless the standard sigma is requested), and cache
     * {@code <x_i, x_i>} for every instance.
     *
     * @param data the training data
     * @throws Exception if the data is not supported or sigma estimation fails
     */
    public void buildKernel(Instances data) throws Exception {
        if (!this.getChecksTurnedOff()) {
            this.getCapabilities().testWithFail(data);
        }
        this.initVars(data);
        // calculateSigma reads m_data. NOTE(review): presumably set by super.initVars; confirm in CachedKernel.
        if (this.first && !this.useStandardSigma) {
            this.calculateSigma();
            this.first = false;
        }
        for (int i = 0; i < data.numInstances(); ++i) {
            Instance inst = data.instance(i);
            this.m_kernelPrecalc[i] = this.dotProd(inst, inst);
        }
    }

    /**
     * Returns the kernel formula.
     *
     * @return a textual description of the kernel
     */
    public String toString() {
        return "Fisher kernel: K(x,y) = e^-((1/(2*sigma*sigma)) * <x-y,x-y>) ";
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 6 $");
    }

    /**
     * Reports whether the preset sigma is used instead of a data-driven estimate.
     *
     * @return true if the standard sigma is used
     */
    public boolean isUseStandardSigma() {
        return this.useStandardSigma;
    }

    /**
     * Sets whether to keep the preset sigma instead of estimating it from the data.
     *
     * @param useStandardSigma true to skip sigma estimation
     */
    public void setUseStandardSigma(boolean useStandardSigma) {
        this.useStandardSigma = useStandardSigma;
    }
}

