/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.functions.LinearRegression;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.supervised.attribute.PLSFilter;

/**
 * Regression forest made of completely random trees with linear-regression
 * leaves.
 *
 * <p>Training (see {@link #buildClassifier(Instances)}): instances with a
 * missing class are dropped, the class mean is subtracted, the attributes are
 * passed through a {@link PLSFilter}, and then {@link #getNumIterations()}
 * random trees are grown on the filtered data. Each leaf holds a
 * {@link LinearRegression} model. Prediction averages the outputs of all
 * trees and adds the class mean back.
 *
 * <p>NOTE(review): {@link #classifyInstance(Instance)} pushes every instance
 * through the single shared filter object, so concurrent predictions on one
 * forest look unsafe -- confirm before using from several threads.
 */
public class RandomRegressionForest
extends RandomizableClassifier
implements WeightedInstancesHandler {
    private static final long serialVersionUID = -3779643299723247400L;

    /** Default number of trees (option -N). */
    private static final int DEFAULT_NUM_ITERATIONS = 100;
    /** Default number of PLS components (option -C). */
    private static final int DEFAULT_PLS = 20;
    /** Default leaf threshold (option -M). */
    private static final int DEFAULT_MIN = 100;

    /** Root nodes of the trees; {@code null} until a model has been built. */
    protected Node[] m_Node;
    /** Number of trees to grow. */
    protected int m_NumIterations = DEFAULT_NUM_ITERATIONS;
    /** Number of PLS components to generate. */
    protected int m_PLS = DEFAULT_PLS;
    /** Leaf threshold: a node with fewer than twice this many instances becomes a leaf. */
    protected int m_Min = DEFAULT_MIN;
    /** Header (no instances) of the PLS-filtered training data; used for attribute names. */
    protected Instances m_Data;
    /** The PLS filter configured and fitted during training. */
    protected PLSFilter m_PLSFilter = null;
    /** Class mean that was subtracted during training and is added back at prediction. */
    protected double m_Mean = 0.0;

    /**
     * Returns a short description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "RandomRegressionForest: subtract mean and pls, then grow completely random trees (leaf: min .. 2min).\nplus local regression models (-S 1 -C), min >> numPLScomps";
    }

    /**
     * Lists the options of this classifier (-N, -M, -C) followed by the
     * options of the superclass.
     *
     * @return an enumeration of the available options
     */
    public Enumeration listOptions() {
        Vector<Object> result = new Vector<Object>();
        result.addElement(new Option("\tNumber of trees.\n\t(default 100)", "N", 1, "-N <num>"));
        result.addElement(new Option("\tLeaf threshold.\n\t(default 100)", "M", 1, "-M <num>"));
        result.addElement(new Option("\tNumber of PLS components.\n\t(default 20)", "C", 1, "-C <num>"));
        Enumeration enm = super.listOptions();
        while (enm.hasMoreElements()) {
            result.addElement(enm.nextElement());
        }
        return result.elements();
    }

    /**
     * Parses -C, -M and -N (each falling back to its default when absent) and
     * then hands the remaining options to the superclass.
     *
     * @param options the option array to parse
     * @throws Exception if an option value cannot be read or is not an integer
     */
    public void setOptions(String[] options) throws Exception {
        this.setPLS(readIntOption('C', options, DEFAULT_PLS));
        this.setMin(readIntOption('M', options, DEFAULT_MIN));
        this.setNumIterations(readIntOption('N', options, DEFAULT_NUM_ITERATIONS));
        super.setOptions(options);
    }

    /** Reads the integer value of a single-character option, or returns {@code fallback} when it is not given. */
    private static int readIntOption(char flag, String[] options, int fallback) throws Exception {
        String raw = Utils.getOption(flag, options);
        return raw.length() != 0 ? Integer.parseInt(raw) : fallback;
    }

    /**
     * Returns the current settings as an option array: -C, -M, -N followed by
     * the options of the superclass.
     *
     * @return the option array
     */
    public String[] getOptions() {
        List<String> result = new ArrayList<String>();
        result.add("-C");
        result.add("" + this.getPLS());
        result.add("-M");
        result.add("" + this.getMin());
        result.add("-N");
        result.add("" + this.getNumIterations());
        for (String inherited : super.getOptions()) {
            result.add(inherited);
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * @return tip text for the numIterations property
     */
    public String numIterationsTipText() {
        return "The number of iterations/trees.";
    }

    /**
     * @param value the number of trees to grow
     */
    public void setNumIterations(int value) {
        this.m_NumIterations = value;
    }

    /**
     * @return the number of trees to grow
     */
    public int getNumIterations() {
        return this.m_NumIterations;
    }

    /**
     * @return tip text for the min property
     */
    public String minTipText() {
        return "The leaf threshold.";
    }

    /**
     * @param value the leaf threshold; values below 1 are treated as 1 when trees are grown
     */
    public void setMin(int value) {
        this.m_Min = value;
    }

    /**
     * @return the leaf threshold
     */
    public int getMin() {
        return this.m_Min;
    }

    /**
     * @return tip text for the PLS property
     */
    public String PLSTipText() {
        return "The number of PLS components to generate.";
    }

    /**
     * @param value the number of PLS components to generate
     */
    public void setPLS(int value) {
        this.m_PLS = value;
    }

    /**
     * @return the number of PLS components to generate
     */
    public int getPLS() {
        return this.m_PLS;
    }

    /**
     * Stores the class mean of {@code data} in {@link #m_Mean} and returns a
     * copy of the data in which that mean has been subtracted from every
     * class value. The argument's class values are not meant to be touched;
     * all writes go to the copy.
     *
     * @param data the data whose class is to be centered
     * @return the centered copy
     */
    protected Instances centerClass(Instances data) {
        this.m_Mean = data.meanOrMode(data.classIndex());
        Instances centered = new Instances(data);
        int count = centered.numInstances();
        for (int i = 0; i < count; ++i) {
            Instance current = centered.instance(i);
            current.setClassValue(current.classValue() - this.m_Mean);
        }
        return centered;
    }

    /**
     * Returns the capabilities of this classifier: numeric and date
     * attributes, numeric and date class, and missing class values, on top of
     * whatever the superclass reports.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Builds the forest: checks capabilities, drops instances with a missing
     * class, centers the class, fits the PLS filter and grows the trees with
     * a {@link Random} seeded from {@link #getSeed()}.
     *
     * @param data the training data (not modified; a copy is used)
     * @throws IllegalArgumentException if the configured number of trees is negative
     * @throws Exception if the data cannot be handled or filtering/tree building fails
     */
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        int numTrees = this.getNumIterations();
        if (numTrees < 0) {
            // Fail with a readable message instead of a NegativeArraySizeException.
            throw new IllegalArgumentException("Number of trees must not be negative: " + numTrees);
        }

        Instances train = new Instances(data);
        train.deleteWithMissingClass();
        train = this.centerClass(train);

        this.m_PLSFilter = new PLSFilter();
        this.m_PLSFilter.setNumComponents(this.m_PLS);
        this.m_PLSFilter.setReplaceMissing(true);
        // The preprocessing tag belongs to setPreprocessing; it used to be handed
        // to setAlgorithm, which looks like a copy-paste slip.
        // NOTE(review): tag id 1 is taken over unchanged from the previous code --
        // confirm its meaning against PLSFilter.TAGS_PREPROCESSING / TAGS_ALGORITHM.
        this.m_PLSFilter.setPreprocessing(new SelectedTag(1, PLSFilter.TAGS_PREPROCESSING));
        this.m_PLSFilter.setAlgorithm(new SelectedTag(1, PLSFilter.TAGS_ALGORITHM));
        this.m_PLSFilter.setInputFormat(train);
        Instances filtered = Filter.useFilter(train, this.m_PLSFilter);

        Random r = new Random(this.getSeed());
        Node[] trees = new Node[numTrees];
        for (int j = 0; j < trees.length; ++j) {
            trees[j] = new Node(filtered, r, this.m_Min);
        }
        this.m_Node = trees;
        // Only the header is kept; Node.toString reads attribute names from it.
        this.m_Data = new Instances(filtered, 0);
    }

    /**
     * Predicts the class value of an instance: the instance is passed through
     * the fitted PLS filter, every tree is asked for a prediction, and the
     * average plus the stored class mean is returned. A forest with zero
     * trees predicts the stored class mean.
     *
     * @param instance the instance to classify
     * @return the predicted class value
     * @throws IllegalStateException if no model has been built yet
     * @throws Exception if filtering or a leaf model fails
     */
    public double classifyInstance(Instance instance) throws Exception {
        if (this.m_Node == null || this.m_PLSFilter == null) {
            throw new IllegalStateException("RandomRegressionForest: No model built yet.");
        }
        if (this.m_Node.length == 0) {
            // Avoid 0.0 / 0 == NaN; with no trees the only knowledge is the mean.
            return this.m_Mean;
        }
        this.m_PLSFilter.input(instance);
        this.m_PLSFilter.batchFinished();
        Instance projected = this.m_PLSFilter.output();
        double sum = 0.0;
        for (Node tree : this.m_Node) {
            sum += tree.classifyInstance(projected);
        }
        return this.m_Mean + sum / (double)this.m_Node.length;
    }

    /**
     * Returns a textual dump of all trees followed by the leaf models. Leaf
     * models are numbered LM0, LM1, ... consecutively across all trees.
     *
     * @return the description of the model, or a notice if none was built
     */
    @Override
    public String toString() {
        if (this.m_Node == null) {
            return "RandomRegressionForest: No model built yet.";
        }
        StringBuffer text = new StringBuffer();
        text.append("RandomRegressionForest: \n\n");
        List<String> models = new ArrayList<String>();
        for (Node tree : this.m_Node) {
            tree.toString(0, text, models);
            text.append("\n\n");
            // Terminate the separator so the following tree / model list starts on its own line.
            text.append("-------------------------------------\n");
        }
        for (int i = 0; i < models.size(); ++i) {
            text.append("LM").append(i).append(":\n").append(models.get(i)).append("\n");
        }
        return text.toString();
    }

    /**
     * Command-line entry point.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        RandomRegressionForest.runClassifier(new RandomRegressionForest(), args);
    }

    /**
     * @return the revision string
     */
    public String getRevision() {
        return "1.0";
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * A node of a random tree. A node is either a leaf, in which case
     * {@link #m_LinearReg} is set and {@link #m_SplitIndex} is -1, or an inner
     * node that sends instances whose value at {@link #m_SplitIndex} is less
     * than {@link #m_SplitValue} to {@link #m_Less} and all others to
     * {@link #m_More}.
     */
    public class Node
    implements Serializable {
        private static final long serialVersionUID = -3856383120244210709L;
        /** Threshold of the split (only meaningful for inner nodes). */
        protected double m_SplitValue;
        /** Index of the split attribute, or -1 if this node is a leaf. */
        protected int m_SplitIndex = -1;
        /** Leaf model; {@code null} for inner nodes. */
        protected LinearRegression m_LinearReg;
        /** Subtree for values below the threshold. */
        protected Node m_Less;
        /** Subtree for values at or above the threshold. */
        protected Node m_More;

        /**
         * Grows the subtree for {@code data}. The node becomes a leaf when it
         * holds fewer than {@code 2 * min} instances or when no acceptable
         * random split is found; otherwise the data is partitioned and both
         * children are grown recursively (less-side first, so the order in
         * which {@code r} is consumed is fixed).
         *
         * @param data the instances that reached this node
         * @param r the random number generator shared by the whole forest
         * @param min the leaf threshold; values below 1 are treated as 1
         * @throws Exception if building a leaf model fails
         */
        public Node(Instances data, Random r, int min) throws Exception {
            // A threshold below 1 would let empty or single-instance nodes reach
            // findRandomSplit (Random.nextInt rejects a non-positive bound) and
            // would permit splits that do not shrink the data.
            int leafSize = Math.max(1, min);
            // long arithmetic: 2 * min overflows int for very large thresholds.
            if ((long)data.numInstances() < 2L * leafSize) {
                this.turnIntoLeaf(data);
                return;
            }
            this.findRandomSplit(data, r, leafSize);
            if (this.m_SplitIndex == -1) {
                this.turnIntoLeaf(data);
                return;
            }
            // The sides are materialised one after the other so that the "more"
            // subset does not exist yet while the "less" subtree is being grown.
            this.m_Less = new Node(this.selectSide(data, leafSize, true), r, leafSize);
            this.m_More = new Node(this.selectSide(data, leafSize, false), r, leafSize);
        }

        /**
         * Collects the instances of one side of the current split. An instance
         * is on the less side iff its split-attribute value is less than the
         * threshold -- the same rule used by {@link #subsetSizesOK} and
         * {@link #classifyInstance}, so a value that is not comparable (NaN)
         * consistently lands on the more side.
         */
        private Instances selectSide(Instances data, int capacity, boolean lessSide) {
            Instances subset = new Instances(data, capacity);
            int count = data.numInstances();
            for (int i = 0; i < count; ++i) {
                Instance candidate = data.instance(i);
                boolean isLess = candidate.value(this.m_SplitIndex) < this.m_SplitValue;
                if (isLess == lessSide) {
                    subset.add(candidate);
                }
            }
            return subset;
        }

        /**
         * Makes this node a leaf by fitting a linear regression on
         * {@code data}. The model is configured to keep colinear attributes,
         * with selection tag 1, checks turned off and minimal mode on; the
         * enclosing class describes these leaf models as "-S 1 -C".
         *
         * @param data the instances that reached this node
         * @throws Exception if the regression cannot be built
         */
        public void turnIntoLeaf(Instances data) throws Exception {
            LinearRegression model = new LinearRegression();
            model.setEliminateColinearAttributes(false);
            model.setAttributeSelectionMethod(new SelectedTag(1, LinearRegression.TAGS_SELECTION));
            model.turnChecksOff();
            model.setMinimal(true);
            this.m_LinearReg = model;
            model.buildClassifier(data);
        }

        /**
         * Predicts the (centered) class value for an already filtered
         * instance by descending to a leaf and applying its model.
         *
         * @param instance the filtered instance
         * @return the prediction of the leaf model
         * @throws Exception if the leaf model fails
         */
        public double classifyInstance(Instance instance) throws Exception {
            if (this.m_LinearReg != null) {
                return this.m_LinearReg.classifyInstance(instance);
            }
            Node next = instance.value(this.m_SplitIndex) < this.m_SplitValue ? this.m_Less : this.m_More;
            return next.classifyInstance(instance);
        }

        /**
         * Tries to find a random split. Up to 10 random pairs of distinct
         * instances are drawn; for each pair up to 10 random attributes of the
         * first instance are tried. An attribute is usable if it is not the
         * class and the two instances differ on it; the threshold is a random
         * point between the two values. The first split that leaves at least
         * {@code min} instances on both sides is kept in
         * {@link #m_SplitIndex} / {@link #m_SplitValue}. A rejected candidate
         * resets {@link #m_SplitIndex} to -1.
         *
         * @param data the instances to split
         * @param r the random number generator
         * @param min the minimum number of instances per side
         */
        public void findRandomSplit(Instances data, Random r, int min) {
            int numInstances = data.numInstances();
            if (numInstances < 2) {
                // Two distinct instances are needed; nextInt(numInstances - 1) would throw.
                return;
            }
            int classIndex = data.classIndex();
            for (int pair = 0; pair < 10; ++pair) {
                // Draw two distinct positions: the second draw skips over the first.
                int first = r.nextInt(numInstances);
                int second = r.nextInt(numInstances - 1);
                if (second >= first) {
                    ++second;
                }
                Instance instance1 = data.instance(first);
                Instance instance2 = data.instance(second);
                int numValues = instance1.numValues();
                if (numValues <= 0) {
                    continue;
                }
                for (int retry = 0; retry < 10; ++retry) {
                    int slot = r.nextInt(numValues);
                    int attrIndex = instance1.index(slot);
                    if (attrIndex == classIndex) {
                        continue;
                    }
                    double v1 = instance1.valueSparse(slot);
                    double v2 = instance2.value(attrIndex);
                    if (v1 == v2) {
                        continue;
                    }
                    double fraction = r.nextDouble();
                    this.m_SplitIndex = attrIndex;
                    this.m_SplitValue = fraction * v1 + (1.0 - fraction) * v2;
                    if (this.subsetSizesOK(data, min)) {
                        return;
                    }
                    this.m_SplitIndex = -1;
                }
            }
        }

        /**
         * Checks whether the current split leaves at least {@code min}
         * instances on each side.
         *
         * @param data the instances to split
         * @param min the minimum number of instances per side
         * @return {@code true} if both sides are large enough
         */
        public boolean subsetSizesOK(Instances data, int min) {
            int total = data.numInstances();
            int smaller = 0;
            for (int i = 0; i < total; ++i) {
                if (data.instance(i).value(this.m_SplitIndex) < this.m_SplitValue) {
                    ++smaller;
                }
            }
            int larger = total - smaller;
            return smaller >= min && larger >= min;
        }

        /**
         * Appends {@code indent} copies of "| " to the buffer.
         *
         * @param indent the nesting depth
         * @param sb the buffer to append to
         */
        public void prefix(int indent, StringBuffer sb) {
            for (int level = 0; level < indent; ++level) {
                sb.append("| ");
            }
        }

        /**
         * Appends a textual description of this subtree. A leaf is printed as
         * "LM" plus the current size of {@code models} and its model text is
         * appended to {@code models}; an inner node prints both branches with
         * the attribute name taken from the enclosing forest's data header.
         *
         * @param indent the nesting depth
         * @param sb the buffer to append to
         * @param models collects the descriptions of the leaf models
         */
        public void toString(int indent, StringBuffer sb, List<String> models) {
            this.prefix(indent, sb);
            if (this.m_SplitIndex == -1) {
                sb.append("LM" + models.size() + "\n");
                models.add(this.m_LinearReg.toString());
                return;
            }
            String attrName = RandomRegressionForest.this.m_Data.attribute(this.m_SplitIndex).name();
            sb.append(attrName + " < " + this.m_SplitValue + "\n");
            this.m_Less.toString(indent + 1, sb, models);
            this.prefix(indent, sb);
            // The "more" branch also receives values equal to the threshold.
            sb.append(attrName + " >= " + this.m_SplitValue + "\n");
            this.m_More.toString(indent + 1, sb, models);
        }
    }
}

