/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.core.attributeclassobservers;

import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.Deque;
import moa.classifiers.core.AttributeSplitSuggestion;
import moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver;
import moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
import moa.classifiers.core.splitcriteria.SplitCriterion;
import moa.core.DoubleVector;
import moa.core.ObjectRepository;
import moa.options.AbstractOptionHandler;
import moa.tasks.TaskMonitor;

/**
 * Numeric attribute observer that records every distinct observed value in an
 * unbalanced binary search tree, together with per-class weight totals, and
 * evaluates a binary split at each recorded value.
 *
 * <p>Each tree node keeps two class distributions over the values routed through it:
 * those that are {@code <=} its cut point ({@code classCountsLeft}) and those that
 * are greater ({@code classCountsRight}). The tree is never rebalanced. Monotonic
 * input (for example steadily increasing values) therefore degenerates it into a
 * chain whose depth equals the number of distinct values. For that reason, both
 * insertion and split search are iterative rather than recursive, so deep chains
 * cannot overflow the call stack.
 */
public class BinaryTreeNumericAttributeClassObserver
extends AbstractOptionHandler
implements NumericAttributeClassObserver {
    private static final long serialVersionUID = 1L;

    /** Root of the value tree; {@code null} until the first non-NaN value is observed. */
    protected Node root = null;

    /**
     * Records one weighted observation of this attribute.
     *
     * @param attVal   the attribute value; NaN (missing) values are ignored
     * @param classVal index of the observed class label
     * @param weight   weight of the observation
     */
    public void observeAttributeClass(double attVal, int classVal, double weight) {
        if (Double.isNaN(attVal)) {
            return;
        }
        if (this.root == null) {
            this.root = new Node(attVal, classVal, weight);
        } else {
            this.root.insertValue(attVal, classVal, weight);
        }
    }

    /**
     * Not supported by this observer.
     *
     * @return always {@code 0.0}
     */
    public double probabilityOfAttributeValueGivenClass(double attVal, int classVal) {
        return 0.0;
    }

    /**
     * Returns the best binary split over all values recorded so far.
     *
     * @param criterion    split criterion used to score each candidate cut point
     * @param preSplitDist class distribution before splitting
     * @param attIndex     index of the observed attribute, stored in the resulting test
     * @param binaryOnly   unused; this observer only produces binary splits
     * @return the highest-merit suggestion, or {@code null} if no value has been observed
     */
    public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(SplitCriterion criterion, double[] preSplitDist, int attIndex, boolean binaryOnly) {
        return this.searchForBestSplitOption(this.root, null, null, null, null, false, criterion, preSplitDist, attIndex);
    }

    /**
     * Evaluates a binary split at every node of the subtree rooted at {@code currentNode}
     * and returns the best suggestion found, starting from {@code currentBestOption}.
     *
     * <p>Nodes are visited in pre-order: a node, then its whole left subtree, then its
     * right subtree. Only a strictly greater merit replaces the incumbent, so ties
     * resolve to the earliest node in that order. The traversal uses an explicit stack
     * instead of recursion, because the tree may be as deep as the number of distinct
     * values. Child subtrees are therefore handled inside this call rather than by
     * re-invoking this method.
     *
     * <p>The split built for each node is
     * {@code new NumericAttributeBinaryTest(attIndex, cut_point, true)}, paired with a
     * left distribution covering values {@code <=} the cut point. NOTE(review): the
     * {@code true} flag presumably means "value equal to the cut point passes"; confirm
     * against {@code NumericAttributeBinaryTest}.
     *
     * @param currentNode       root of the subtree to search; {@code null} returns {@code currentBestOption}
     * @param currentBestOption best suggestion found so far, or {@code null}
     * @param actualParentLeft  the parent's {@code classCountsLeft} array; read only when {@code leftChild}
     * @param parentLeft        left distribution of the parent's split, or {@code null} when
     *                          {@code currentNode} is the tree root
     * @param parentRight       right distribution of the parent's split (ignored when {@code parentLeft} is null)
     * @param leftChild         whether {@code currentNode} is its parent's left child
     * @param criterion         split criterion used to score candidates
     * @param preSplitDist      class distribution before splitting
     * @param attIndex          attribute index stored in the generated tests
     * @return the best suggestion found (possibly {@code currentBestOption} itself)
     */
    protected AttributeSplitSuggestion searchForBestSplitOption(Node currentNode, AttributeSplitSuggestion currentBestOption, double[] actualParentLeft, double[] parentLeft, double[] parentRight, boolean leftChild, SplitCriterion criterion, double[] preSplitDist, int attIndex) {
        AttributeSplitSuggestion best = currentBestOption;
        Deque<SearchFrame> pending = new ArrayDeque<>();
        if (currentNode != null) {
            pending.push(new SearchFrame(currentNode, actualParentLeft, parentLeft, parentRight, leftChild));
        }
        while (!pending.isEmpty()) {
            SearchFrame frame = pending.pop();
            Node node = frame.node;
            double[][] postSplitDists = this.computePostSplitDists(node, frame.actualParentLeft, frame.parentLeft, frame.parentRight, frame.leftChild);
            double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);
            if (best == null || merit > best.merit) {
                best = new AttributeSplitSuggestion(new NumericAttributeBinaryTest(attIndex, node.cut_point, true), postSplitDists, merit);
            }
            double[] nodeLeftCounts = node.classCountsLeft.getArrayRef();
            // Push the right child first so the entire left subtree is popped and
            // evaluated before it, which preserves pre-order (and thus tie-breaking).
            if (node.right != null) {
                pending.push(new SearchFrame(node.right, nodeLeftCounts, postSplitDists[0], postSplitDists[1], false));
            }
            if (node.left != null) {
                pending.push(new SearchFrame(node.left, nodeLeftCounts, postSplitDists[0], postSplitDists[1], true));
            }
        }
        return best;
    }

    /**
     * Computes the {left, right} class distributions produced by splitting at
     * {@code node.cut_point}. At the tree root they are read directly from the
     * node's counters; otherwise they are derived incrementally from the parent's split.
     *
     * @return a two-element array: the left distribution, then the right distribution
     */
    private double[][] computePostSplitDists(Node node, double[] actualParentLeft, double[] parentLeft, double[] parentRight, boolean leftChild) {
        DoubleVector leftDist = new DoubleVector();
        DoubleVector rightDist = new DoubleVector();
        if (parentLeft == null) {
            // Tree root: every value passed through this node, so its counters are the exact split.
            leftDist.addValues(node.classCountsLeft);
            rightDist.addValues(node.classCountsRight);
        } else {
            leftDist.addValues(parentLeft);
            rightDist.addValues(parentRight);
            if (leftChild) {
                // Weight recorded exactly at the parent's cut point: the parent's "<="
                // counts minus everything that descended into this (left) subtree.
                DoubleVector exactParentDist = new DoubleVector();
                exactParentDist.addValues(actualParentLeft);
                exactParentDist.subtractValues(node.classCountsLeft);
                exactParentDist.subtractValues(node.classCountsRight);
                // Lowering the cut from the parent's value to this node's value moves
                // this node's right side plus the parent's exact-value weight from the
                // left partition to the right one.
                leftDist.subtractValues(node.classCountsRight);
                rightDist.addValues(node.classCountsRight);
                rightDist.addValues(exactParentDist);
                leftDist.subtractValues(exactParentDist);
            } else {
                // Raising the cut from the parent's value to this node's value moves
                // this node's "<=" side from the right partition to the left one.
                leftDist.addValues(node.classCountsLeft);
                rightDist.subtractValues(node.classCountsLeft);
            }
        }
        return new double[][]{leftDist.getArrayRef(), rightDist.getArrayRef()};
    }

    public void getDescription(StringBuilder sb, int indent) {
    }

    protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
    }

    /** A pending visit in the split search: a node plus the split state inherited from its parent. */
    private static final class SearchFrame {
        final Node node;
        final double[] actualParentLeft;
        final double[] parentLeft;
        final double[] parentRight;
        final boolean leftChild;

        SearchFrame(Node node, double[] actualParentLeft, double[] parentLeft, double[] parentRight, boolean leftChild) {
            this.node = node;
            this.actualParentLeft = actualParentLeft;
            this.parentLeft = parentLeft;
            this.parentRight = parentRight;
            this.leftChild = leftChild;
        }
    }

    /** One distinct attribute value in the tree, with class weights on either side of it. */
    protected class Node
    implements Serializable {
        private static final long serialVersionUID = 1L;
        /** Attribute value stored at this node; also the candidate split threshold. */
        public double cut_point;
        /** Class weights of values routed through this node that are {@code <=} cut_point, including the node's own value. */
        public DoubleVector classCountsLeft = new DoubleVector();
        /** Class weights of values routed through this node that are greater than cut_point. */
        public DoubleVector classCountsRight = new DoubleVector();
        /** Subtree holding values below cut_point, or {@code null}. */
        public Node left;
        /** Subtree holding values above cut_point, or {@code null}. */
        public Node right;

        /**
         * Creates a node for {@code val} and counts this first observation on its left side.
         */
        public Node(double val, int label, double weight) {
            this.cut_point = val;
            this.classCountsLeft.addToValue(label, weight);
        }

        /**
         * Adds one weighted observation to this subtree.
         *
         * <p>Walks down from this node, adding the weight to the side counter of every
         * node passed. The walk stops at a node whose cut point equals {@code val} (its
         * left counter absorbs the weight), or else creates a new leaf. It is iterative
         * because the tree is unbalanced and can be as deep as the number of distinct
         * values.
         */
        public void insertValue(double val, int label, double weight) {
            Node node = this;
            while (true) {
                if (val == node.cut_point) {
                    node.classCountsLeft.addToValue(label, weight);
                    return;
                }
                if (val < node.cut_point) {
                    node.classCountsLeft.addToValue(label, weight);
                    if (node.left == null) {
                        node.left = new Node(val, label, weight);
                        return;
                    }
                    node = node.left;
                } else {
                    // Also reached when either operand is NaN, since both comparisons above are false.
                    node.classCountsRight.addToValue(label, weight);
                    if (node.right == null) {
                        node.right = new Node(val, label, weight);
                        return;
                    }
                    node = node.right;
                }
            }
        }
    }
}

