/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.meta.Corr;
import weka.core.Instance;
import weka.core.Instances;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A node of a randomized regression tree whose leaves hold linear models.
 *
 * <p>An internal node ({@code splitIndex != -1}) sends an instance to {@link #less} when
 * {@code instance.value(splitIndex) < splitValue} and to {@link #more} otherwise. A leaf
 * ({@code splitIndex == -1}) evaluates the linear model stored in {@link #m_coeffs}. When
 * {@link #m_subset} is non-null, that model is restricted to the attribute indices in it. The
 * result is clamped to the range of class values the leaf was trained on.
 *
 * <p>The split-search methods temporarily overwrite {@code splitIndex}/{@code splitValue} while
 * they score candidates, and they leave the chosen split (or {@code splitIndex == -1}) behind.
 */
class Node
implements Serializable {
    /** Split threshold; meaningful only while {@code splitIndex != -1}. */
    protected double splitValue = 0.0;
    /** Attribute index tested by this node, or {@code -1} for a leaf. */
    protected int splitIndex = -1;
    /** Leaf model coefficients; the entry after the per-attribute coefficients is the intercept. */
    protected double[] m_coeffs;
    /** Attribute indices used by the leaf model, or {@code null} to use every non-class attribute. */
    protected int[] m_subset;
    /** Child for instances with {@code value(splitIndex) < splitValue}. */
    protected Node less;
    /** Child for all other instances (including those with a NaN split value). */
    protected Node more;
    /** Smallest class value seen when this leaf was built. */
    protected double min;
    /** Largest class value seen when this leaf was built. */
    protected double max;

    /**
     * Turns this node into a leaf. This delegates to {@link #turnIntoLeafCORR}.
     *
     * @param data  training instances of the leaf; must be non-empty
     * @param ridge ridge value passed to the leaf regression
     * @throws Exception if the leaf model cannot be built
     */
    public void turnIntoLeaf(List<Instance> data, double ridge) throws Exception {
        this.turnIntoLeafCORR(data, ridge);
    }

    /**
     * Turns this node into a leaf whose model is a {@code LinearRegression} over every non-class
     * attribute.
     *
     * @param data  training instances of the leaf; must be non-empty
     * @param ridge ridge value passed to the regression via {@code -R}
     * @throws Exception if the regression cannot be built
     */
    public void turnIntoLeafLR(List<Instance> data, double ridge) throws Exception {
        this.splitIndex = -1;
        this.computeMinMax(data);
        LinearRegression regression = newLeafRegression(ridge);
        Instances trainData = toInstances(data);
        regression.buildClassifier(trainData);
        // NOTE(review): assumes coefficients() holds one entry per attribute (class slot
        // included) followed by the intercept. The class slot is dropped so that m_coeffs lines
        // up with the non-class attributes in the order leafPrediction walks them, with the
        // intercept stored last.
        double[] coeffs = regression.coefficients();
        this.m_coeffs = new double[coeffs.length - 1];
        int offset = 0;
        for (int i = 0; i < this.m_coeffs.length; ++i) {
            if (i == trainData.classIndex()) {
                continue;
            }
            this.m_coeffs[offset++] = coeffs[i];
        }
        this.m_coeffs[offset] = coeffs[this.m_coeffs.length];
        this.m_subset = null;
    }

    /**
     * Turns this node into a leaf whose model is fitted by the project's {@code Corr}
     * meta-classifier wrapped around a {@code LinearRegression}.
     *
     * <p>The leaf keeps the attribute subset and coefficients that {@code Corr} reports. If
     * {@code getSubset()} returns null, {@link #leafPrediction} falls back to all non-class
     * attributes.
     *
     * @param data  training instances of the leaf; must be non-empty
     * @param ridge ridge value passed to the wrapped regression via {@code -R}
     * @throws Exception if the model cannot be built
     */
    public void turnIntoLeafCORR(List<Instance> data, double ridge) throws Exception {
        this.splitIndex = -1;
        this.computeMinMax(data);
        Corr corr = new Corr();
        corr.setClassifier((Classifier)newLeafRegression(ridge));
        corr.buildClassifier(toInstances(data));
        this.m_subset = corr.getSubset();
        this.m_coeffs = corr.getCoeffs();
    }

    /**
     * Creates the regression used by leaves, configured with the options this class always uses
     * ({@code -C -S 1 -R ridge}) and with input checks turned off.
     */
    private static LinearRegression newLeafRegression(double ridge) throws Exception {
        LinearRegression regression = new LinearRegression();
        regression.setOptions(new String[]{"-C", "-S", "1", "-R", "" + ridge});
        regression.turnChecksOff();
        return regression;
    }

    /** Copies {@code data} into a new {@code Instances} sharing the header of the first instance. */
    private static Instances toInstances(List<Instance> data) {
        Instances trainData = new Instances(data.get(0).dataset(), data.size());
        for (Instance instance : data) {
            trainData.add(instance);
        }
        return trainData;
    }

    /**
     * Recursively grows a (sub)tree from {@code data}.
     *
     * <p>The node becomes a leaf in any of these cases:
     * <ul>
     *   <li>{@code level <= 0};</li>
     *   <li>fewer than 10 instances remain;</li>
     *   <li>no split is found;</li>
     *   <li>the chosen split leaves one side empty.</li>
     * </ul>
     * The split search sorts {@code data} in place, so the caller's list order changes.
     *
     * @param data        training instances reaching this node; must be non-empty
     * @param r           source of randomness for choosing candidate split attributes
     * @param level       remaining depth
     * @param ridge       ridge value for the leaf regressions
     * @param trials      number of random attributes tried per split search
     * @param comparators per-attribute comparators used to sort {@code data}; presumably
     *                    {@code comparators[i]} orders by attribute {@code i} — confirm
     * @throws Exception if building a leaf model fails
     */
    public Node(List<Instance> data, Random r, int level, double ridge, int trials, Comparator<Instance>[] comparators) throws Exception {
        if (level <= 0 || data.size() < 10) {
            this.turnIntoLeaf(data, ridge);
            return;
        }
        this.findRandomSplitMedian(data, r, trials, comparators);
        if (this.splitIndex == -1) {
            this.turnIntoLeaf(data, ridge);
            return;
        }
        // Partition in a single pass. Anything not "< splitValue" goes to the "more" side. That
        // includes NaN (missing) values, which fail both < and >=. This matches how
        // classifyInstance routes instances, so no training instance is lost.
        ArrayList<Instance> lessData = new ArrayList<Instance>(data.size());
        ArrayList<Instance> moreData = new ArrayList<Instance>(data.size());
        for (Instance instance : data) {
            if (instance.value(this.splitIndex) < this.splitValue) {
                lessData.add(instance);
            } else {
                moreData.add(instance);
            }
        }
        if (lessData.isEmpty() || moreData.isEmpty()) {
            this.turnIntoLeaf(data, ridge);
            return;
        }
        this.less = new Node(lessData, r, level - 1, ridge, trials, comparators);
        this.more = new Node(moreData, r, level - 1, ridge, trials, comparators);
    }

    /**
     * Records the smallest and largest class value in {@code data} into {@link #min} and
     * {@link #max}.
     *
     * @param data instances to scan; must be non-empty
     * @throws Exception declared for callers; not thrown by this body
     */
    public void computeMinMax(List<Instance> data) throws Exception {
        this.max = this.min = data.get(0).classValue();
        for (Instance instance : data) {
            double v = instance.classValue();
            if (v > this.max) {
                this.max = v;
            }
            if (v < this.min) {
                this.min = v;
            }
        }
    }

    /**
     * Predicts the class value of {@code instance}.
     *
     * <p>Internal nodes route the instance to {@link #less} when its split value is
     * {@code < splitValue} and to {@link #more} otherwise. A leaf evaluates its linear model and
     * clamps the result to {@code [min, max]}. If the model yields NaN, the leaf returns the
     * midpoint of that range.
     *
     * @param instance instance to classify
     * @return the predicted class value
     * @throws Exception declared for callers; propagated from the leaf model evaluation
     */
    public double classifyInstance(Instance instance) throws Exception {
        if (this.splitIndex == -1) {
            double v = this.leafPrediction(instance);
            if (Double.isNaN(v)) {
                return 0.5 * (this.min + this.max);
            }
            if (v > this.max) {
                return this.max;
            }
            if (v < this.min) {
                return this.min;
            }
            return v;
        }
        if (instance.value(this.splitIndex) < this.splitValue) {
            return this.less.classifyInstance(instance);
        }
        return this.more.classifyInstance(instance);
    }

    /**
     * Evaluates the leaf's linear model on {@code instance} without clamping.
     *
     * <p>When {@link #m_subset} is set, the model is the intercept ({@code m_coeffs} last entry)
     * plus {@code m_coeffs[i] * value(m_subset[i])}. Otherwise it walks every non-class
     * attribute in index order, then adds the intercept that follows those coefficients.
     *
     * @param instance instance to evaluate
     * @return the raw linear-model output
     * @throws Exception declared for callers; not thrown by this body
     */
    public double leafPrediction(Instance instance) throws Exception {
        if (this.m_subset != null) {
            double sum = this.m_coeffs[this.m_coeffs.length - 1];
            for (int i = 0; i < this.m_subset.length; ++i) {
                sum += this.m_coeffs[i] * instance.value(this.m_subset[i]);
            }
            return sum;
        }
        int offset = 0;
        double sum = 0.0;
        for (int i = 0; i < instance.numAttributes(); ++i) {
            if (i == instance.classIndex()) {
                continue;
            }
            sum += this.m_coeffs[offset++] * instance.value(i);
        }
        return sum + this.m_coeffs[offset];
    }

    /** Appends {@code indent} copies of {@code "| "} to {@code sb}. */
    public void prefix(int indent, StringBuffer sb) {
        for (int i = 0; i < indent; ++i) {
            sb.append("| ");
        }
    }

    /**
     * Appends a textual view of this subtree to {@code sb}, labelling each leaf
     * {@code "LM" + models.size()}.
     *
     * <p>NOTE(review): this method never adds to {@code models}, so every leaf gets the same
     * label. Presumably a description of each leaf model was meant to be appended here; confirm
     * against the caller.
     *
     * @param indent nesting depth used for the {@code "| "} prefix
     * @param sb     buffer to append to
     * @param models list whose size numbers the leaves
     * @param header dataset header used to look up attribute names
     */
    public void toString(int indent, StringBuffer sb, List<String> models, Instances header) {
        this.prefix(indent, sb);
        if (this.splitIndex == -1) {
            sb.append("LM" + models.size() + "\n");
        } else {
            sb.append(header.attribute(this.splitIndex).name() + " < " + this.splitValue + "\n");
            this.less.toString(indent + 1, sb, models, header);
            this.prefix(indent, sb);
            // The "more" branch receives value >= splitValue (see classifyInstance).
            sb.append(header.attribute(this.splitIndex).name() + " >= " + this.splitValue + "\n");
            this.more.toString(indent + 1, sb, models, header);
        }
    }

    /**
     * Appends a textual view of this subtree to {@code sb}, printing {@code splitValue} at
     * leaves.
     *
     * <p>NOTE(review): the leaf builders in this class never assign {@code splitValue}, so a leaf
     * prints a leftover split-search threshold (or 0.0), not a prediction. Confirm intent.
     *
     * @param indent nesting depth used for the {@code "| "} prefix
     * @param sb     buffer to append to
     * @param header dataset header used to look up attribute names
     */
    public void toString(int indent, StringBuffer sb, Instances header) {
        this.prefix(indent, sb);
        if (this.splitIndex == -1) {
            sb.append("target = " + this.splitValue + "\n");
        } else {
            sb.append(header.attribute(this.splitIndex).name() + " < " + this.splitValue + "\n");
            this.less.toString(indent + 1, sb, header);
            this.prefix(indent, sb);
            // The "more" branch receives value >= splitValue (see classifyInstance).
            sb.append(header.attribute(this.splitIndex).name() + " >= " + this.splitValue + "\n");
            this.more.toString(indent + 1, sb, header);
        }
    }

    /**
     * Searches for a split by drawing random pairs of distinct instances.
     *
     * <p>Up to 10 pairs are drawn. For each pair it tries {@code numTrials} random attributes
     * from the first instance's stored (sparse) values. Each candidate threshold is a random
     * convex combination of the two instances' values, scored by {@link #splitSSE}. The first
     * pair that yields any scored candidate wins, with its lowest-SSE candidate. If no pair
     * yields one, {@code splitIndex} is set to {@code -1}.
     *
     * @param data      training instances
     * @param r         source of randomness
     * @param numTrials attribute draws per pair
     */
    public void findRandomSplit(List<Instance> data, Random r, int numTrials) {
        int size = data.size();
        if (size < 2) {
            // A pair of distinct instances is impossible; r.nextInt(size - 1) would throw.
            this.splitIndex = -1;
            return;
        }
        int classIndex = data.get(0).classIndex();
        for (int pairs = 0; pairs < 10; ++pairs) {
            int index1 = r.nextInt(size);
            // Draw from size - 1 slots and skip over index1 so the two indices differ.
            int index2 = r.nextInt(size - 1);
            if (index2 >= index1) {
                ++index2;
            }
            Instance instance1 = data.get(index1);
            Instance instance2 = data.get(index2);
            int numValues = instance1.numValues();
            int bestSplitIndex = -1;
            double bestSplitValue = 0.0;
            double minSSE = Double.MAX_VALUE;
            if (numValues > 0) {
                for (int trial = 0; trial < numTrials; ++trial) {
                    int index = r.nextInt(numValues);
                    int attrIndex = instance1.index(index);
                    if (attrIndex == classIndex) {
                        continue;
                    }
                    double v1 = instance1.valueSparse(index);
                    double v2 = instance2.value(attrIndex);
                    if (v1 == v2) {
                        continue;
                    }
                    double fraction = r.nextDouble();
                    this.splitIndex = attrIndex;
                    this.splitValue = fraction * v1 + (1.0 - fraction) * v2;
                    double sse = this.splitSSE(data);
                    if (sse < minSSE) {
                        minSSE = sse;
                        bestSplitIndex = this.splitIndex;
                        bestSplitValue = this.splitValue;
                    }
                }
            }
            if (bestSplitIndex > -1) {
                this.splitIndex = bestSplitIndex;
                this.splitValue = bestSplitValue;
                return;
            }
        }
        // No usable candidate: do not leave the last attribute tried in splitIndex.
        this.splitIndex = -1;
    }

    /**
     * Returns the summed squared error of the class values on each side of the current split
     * ({@code splitIndex}, {@code splitValue}).
     *
     * <p>Each side's error is measured around its own mean, accumulated in one pass with
     * Welford's update and unit weights.
     *
     * @param data instances to score
     * @return SSE of the "less" side plus SSE of the "more" side
     */
    public double splitSSE(List<Instance> data) {
        double wSum1 = 0.0;
        double wSum2 = 0.0;
        double mean1 = 0.0;
        double mean2 = 0.0;
        double sse1 = 0.0;
        double sse2 = 0.0;
        for (Instance instance : data) {
            double value = instance.classValue();
            double w = 1.0;
            if (instance.value(this.splitIndex) < this.splitValue) {
                if (wSum1 > 0.0) {
                    double oldMean = mean1;
                    wSum1 += w;
                    mean1 += (value - oldMean) / wSum1;
                    sse1 += (value - oldMean) * (value - mean1);
                } else {
                    mean1 = value;
                    wSum1 = w;
                }
            } else if (wSum2 > 0.0) {
                double oldMean = mean2;
                wSum2 += w;
                mean2 += (value - oldMean) / wSum2;
                sse2 += (value - oldMean) * (value - mean2);
            } else {
                mean2 = value;
                wSum2 = w;
            }
        }
        return sse1 + sse2;
    }

    /**
     * Chooses a median split on the best of {@code numTrials} randomly drawn attributes.
     *
     * <p>For each non-class attribute drawn, {@code train} is sorted in place with that
     * attribute's comparator and scored by {@link #sse(List)}. The lowest-scoring attribute is
     * kept along with its median threshold. If no attribute scores below
     * {@code Double.MAX_VALUE}, or {@code train} has fewer than two instances,
     * {@code splitIndex} is left at {@code -1}.
     *
     * @param train       training instances; reordered in place
     * @param r           source of randomness
     * @param numTrials   number of attribute draws
     * @param comparators per-attribute comparators indexed by attribute
     */
    public void findRandomSplitMedian(List<Instance> train, Random r, int numTrials, Comparator<Instance>[] comparators) {
        this.splitIndex = -1;
        if (train.size() < 2) {
            // sse(List) needs two instances to place a median threshold.
            return;
        }
        int classIndex = train.get(0).classIndex();
        int numAttributes = train.get(0).numAttributes();
        int bestSplitIndex = -1;
        double bestSplitValue = 0.0;
        double minSSE = Double.MAX_VALUE;
        for (int trial = 0; trial < numTrials; ++trial) {
            int attrIndex = r.nextInt(numAttributes);
            if (attrIndex == classIndex) {
                continue;
            }
            this.splitIndex = attrIndex;
            Collections.sort(train, comparators[attrIndex]);
            double sse = this.sse(train);
            if (sse < minSSE) {
                minSSE = sse;
                bestSplitIndex = attrIndex;
                bestSplitValue = this.splitValue;
            }
        }
        if (bestSplitIndex > -1) {
            this.splitIndex = bestSplitIndex;
            this.splitValue = bestSplitValue;
        } else {
            // Every candidate scored NaN or not below MAX_VALUE; the scratch attribute left in
            // splitIndex was never accepted, so report "no split".
            this.splitIndex = -1;
        }
    }

    /**
     * Sets {@code splitValue} to the midpoint of the two middle instances' values of attribute
     * {@code splitIndex}. Returns the summed SSE of the positional halves
     * {@code [0, n/2)} and {@code [n/2, n)}.
     *
     * <p>{@code data} is expected to be sorted by {@code splitIndex} and to hold at least two
     * instances. NOTE(review): when the two middle values tie, the threshold partition can differ
     * from the positional halves being scored.
     *
     * @param data sorted instances
     * @return SSE of the two halves
     */
    public double sse(List<Instance> data) {
        int middle = data.size() / 2;
        this.splitValue = 0.5 * (data.get(middle - 1).value(this.splitIndex) + data.get(middle).value(this.splitIndex));
        double sse1 = this.sse(data, 0, middle);
        double sse2 = this.sse(data, middle, data.size());
        return sse1 + sse2;
    }

    /**
     * Returns the sum of squared deviations of the class values in {@code data[from, to)} from
     * their mean. An empty range returns 0.
     */
    public double sse(List<Instance> data, int from, int to) {
        double mean = this.mean(data, from, to);
        double sse = 0.0;
        for (int i = from; i < to; ++i) {
            double delta = mean - data.get(i).classValue();
            sse += delta * delta;
        }
        return sse;
    }

    /**
     * Returns the mean class value of {@code data[from, to)}. An empty range yields NaN
     * ({@code 0.0 / 0.0}).
     */
    public double mean(List<Instance> data, int from, int to) {
        double sum = 0.0;
        for (int i = from; i < to; ++i) {
            sum += data.get(i).classValue();
        }
        return sum / (double)(to - from);
    }

    /**
     * Like {@link #findRandomSplitMedian}, but scores each drawn attribute with its best
     * threshold over all positions ({@link #sseALL}) instead of the median.
     *
     * <p>If no attribute scores below {@code Double.MAX_VALUE}, or {@code train} has fewer than
     * two instances, {@code splitIndex} is left at {@code -1}.
     *
     * @param train       training instances; reordered in place
     * @param r           source of randomness
     * @param numTrials   number of attribute draws
     * @param comparators per-attribute comparators indexed by attribute
     */
    public void findRandomSplitALL(List<Instance> train, Random r, int numTrials, Comparator<Instance>[] comparators) {
        this.splitIndex = -1;
        if (train.size() < 2) {
            // A single instance (or none) cannot be separated by any threshold.
            return;
        }
        int classIndex = train.get(0).classIndex();
        int numAttributes = train.get(0).numAttributes();
        int bestSplitIndex = -1;
        double bestSplitValue = 0.0;
        double minSSE = Double.MAX_VALUE;
        for (int trial = 0; trial < numTrials; ++trial) {
            int attrIndex = r.nextInt(numAttributes);
            if (attrIndex == classIndex) {
                continue;
            }
            this.splitIndex = attrIndex;
            Collections.sort(train, comparators[attrIndex]);
            double sse = this.sseALL(train);
            if (sse < minSSE) {
                minSSE = sse;
                bestSplitIndex = attrIndex;
                bestSplitValue = this.splitValue;
            }
        }
        if (bestSplitIndex > -1) {
            this.splitIndex = bestSplitIndex;
            this.splitValue = bestSplitValue;
        } else {
            // No candidate was accepted; do not leave the last scratch attribute in splitIndex.
            this.splitIndex = -1;
        }
    }

    /**
     * Scans every cut point of {@code data}, which is expected to be sorted by attribute
     * {@code splitIndex}, and stores the best threshold in {@code splitValue}.
     *
     * <p>The scan works as follows:
     * <ul>
     *   <li>All instances start in the right-hand group.</li>
     *   <li>Instances are moved one at a time to the left-hand group. Both groups' SSE are
     *       maintained incrementally with Welford's update and its inverse, using unit weights.
     *       The inverse step uses {@code S' = S - (x - m')(x - m)}.</li>
     *   <li>Cut points between two equal attribute values are skipped. Their midpoint would
     *       equal both values, so the scored partition could not be realized by the threshold.</li>
     * </ul>
     * If no cut point beats the unsplit SSE, {@code splitValue} becomes {@code Double.MAX_VALUE}
     * and the unsplit SSE is returned.
     *
     * @param data instances sorted by {@code splitIndex}
     * @return the lowest SSE found (left plus right)
     */
    public double sseALL(List<Instance> data) {
        double wSum1 = 0.0;
        double wSum2 = 0.0;
        double mean1 = 0.0;
        double mean2 = 0.0;
        double sse1 = 0.0;
        double sse2 = 0.0;
        for (Instance instance : data) {
            double value = instance.classValue();
            double w = 1.0;
            if (wSum2 > 0.0) {
                double oldMean = mean2;
                wSum2 += w;
                mean2 += (value - oldMean) / wSum2;
                sse2 += (value - oldMean) * (value - mean2);
            } else {
                mean2 = value;
                wSum2 = w;
            }
        }
        double minSSE = sse1 + sse2;
        double bestSplitValue = Double.MAX_VALUE;
        for (int i = 0; i < data.size() - 1; ++i) {
            Instance instance = data.get(i);
            double value = instance.classValue();
            double w = 1.0;
            // Add instance i to the left-hand group.
            if (wSum1 > 0.0) {
                double oldMean = mean1;
                wSum1 += w;
                mean1 += (value - oldMean) / wSum1;
                sse1 += (value - oldMean) * (value - mean1);
            } else {
                mean1 = value;
                wSum1 = w;
            }
            // Remove instance i from the right-hand group (inverse Welford step).
            if (wSum2 > w) {
                double oldMean = mean2;
                mean2 = (mean2 * wSum2 - value) / (wSum2 - w);
                sse2 -= (value - mean2) * (value - oldMean);
                if (sse2 < 0.0) {
                    // Downdating can cancel catastrophically; an SSE is never negative.
                    sse2 = 0.0;
                }
                wSum2 -= w;
            } else {
                wSum2 = 0.0;
                mean2 = 0.0;
                sse2 = 0.0;
            }
            double here = instance.value(this.splitIndex);
            double next = data.get(i + 1).value(this.splitIndex);
            if (here == next) {
                continue;
            }
            if (sse1 + sse2 < minSSE) {
                minSSE = sse1 + sse2;
                bestSplitValue = 0.5 * (here + next);
            }
        }
        this.splitValue = bestSplitValue;
        return minSSE;
    }
}

