/*
 * Decompiled with CFR 0.152.
 */
package moa.learners.featureanalysis;

import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.capabilities.CapabilitiesHandler;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.classifiers.core.splitcriteria.InfoGainSplitCriterion;
import moa.classifiers.trees.HoeffdingTree;
import moa.core.Measurement;
import moa.core.Utils;
import moa.learners.featureanalysis.FeatureImportanceClassifier;
import moa.options.ClassOption;

/**
 * Classifier that delegates learning and prediction to a wrapped {@link HoeffdingTree}
 * and derives per-feature importance scores from the split nodes of that tree.
 *
 * <p>Two scoring methods are selectable through {@link #featureImportanceOption}:
 * "MDI" (see {@link #calcNodeDecreaseImpurity}) and "COVER" (see {@link #calcNodeCover}).
 * For every split node the score of the node is added to the feature the split tests on.
 *
 * <p>Raw (unnormalized) scores are cached in {@link #featureImportances} and are only
 * recomputed when the node count of the wrapped tree differs from the count seen at the
 * previous recomputation.
 */
public class FeatureImportanceHoeffdingTree
extends AbstractClassifier
implements MultiClassClassifier,
CapabilitiesHandler,
FeatureImportanceClassifier {
    public ClassOption treeLearnerOption = new ClassOption("treeLearner", 'l', "Decision Tree learner.", HoeffdingTree.class, "HoeffdingTree");
    public MultiChoiceOption featureImportanceOption = new MultiChoiceOption("featureImportance", 'o', "Which method to use for feature importance estimations.", new String[]{"MDI", "COVER"}, new String[]{"MDI", "COVER"}, 0);
    /** Wrapped tree learner; {@code null} until {@link #resetLearningImpl()} has run. */
    protected HoeffdingTree treeLearner = null;
    /** Cached raw (never normalized) scores; {@code null} until the first training instance. */
    protected double[] featureImportances;
    /** Node count of the wrapped tree when the cache was last recomputed. */
    protected int nodeCountAtLastFeatureImportanceInquiry = 0;
    /** Number of times the cached scores were actually recomputed. */
    protected int featureImportancesInquiries = 0;
    protected static final int FEATURE_IMPORTANCE_MDI = 0;
    protected static final int FEATURE_IMPORTANCE_COVER = 1;

    /**
     * Returns the importance score of every feature.
     *
     * <p>The cached raw scores are refreshed first if the wrapped tree has a root and its
     * node count changed since the last refresh.
     *
     * @param normalize when {@code true}, a new array holding the scores divided by their
     *                  sum is returned; when {@code false}, the internal raw score array
     *                  is returned
     * @return the scores, or {@code null} if no instance has been seen yet. If
     *         normalization is requested but the scores sum to zero (or the sum is NaN),
     *         an unnormalized copy is returned instead of dividing by that sum.
     */
    @Override
    public double[] getFeatureImportances(boolean normalize) {
        if (this.treeLearner == null || this.featureImportances == null) {
            return this.featureImportances;
        }
        if (this.treeLearner.getTreeRoot() != null) {
            int nodeCount = this.treeLearner.getNodeCount();
            // Any change of the node count (growth or shrinkage) means the topology changed.
            if (nodeCount != this.nodeCountAtLastFeatureImportanceInquiry) {
                ++this.featureImportancesInquiries;
                this.featureImportances = new double[this.featureImportances.length];
                switch (this.featureImportanceOption.getChosenIndex()) {
                    case FEATURE_IMPORTANCE_MDI: {
                        this.calcMeanDecreaseImpurity(this.treeLearner.getTreeRoot());
                        break;
                    }
                    case FEATURE_IMPORTANCE_COVER: {
                        this.calcMeanCover(this.treeLearner.getTreeRoot());
                        break;
                    }
                    default: {
                        break;
                    }
                }
                // Recorded only after a successful traversal so a failed one is retried.
                this.nodeCountAtLastFeatureImportanceInquiry = nodeCount;
            }
        }
        if (!normalize) {
            return this.featureImportances;
        }
        // Normalize into a copy: the cache must stay raw so that the normalize flag is
        // honoured on every call, not only on calls that trigger a recomputation.
        double[] normalized = this.featureImportances.clone();
        double sumFeatureScores = Utils.sum(normalized);
        if (sumFeatureScores == 0.0 || Double.isNaN(sumFeatureScores)) {
            // E.g. a tree without split nodes: dividing would turn every score into NaN.
            return normalized;
        }
        for (int i = 0; i < normalized.length; ++i) {
            normalized[i] = normalized[i] / sumFeatureScores;
        }
        return normalized;
    }

    /**
     * Returns the indices of the {@code k} highest scoring features, best first.
     *
     * @param k         number of features requested; capped at the number of features
     * @param normalize forwarded to {@link #getFeatureImportances(boolean)}
     * @return feature indices ordered by decreasing score, or {@code null} if no scores
     *         are available yet
     */
    @Override
    public int[] getTopKFeatures(int k, boolean normalize) {
        // Query once: every call may trigger a full tree traversal.
        double[] importances = this.getFeatureImportances(normalize);
        if (importances == null) {
            return null;
        }
        if (k > importances.length) {
            k = importances.length;
        }
        int[] topK = new int[k];
        // Work on a copy so that masking selected entries never touches the cached scores.
        double[] remainingScores = importances.clone();
        for (int i = 0; i < k; ++i) {
            int currentTop = Utils.maxIndex(remainingScores);
            topK[i] = currentTop;
            // Negative infinity cannot outrank any remaining finite score.
            remainingScores[currentTop] = Double.NEGATIVE_INFINITY;
        }
        return topK;
    }

    /**
     * Recursively adds the cover of every split node below (and including) {@code node}
     * to the score of the feature the split tests on. Non-split nodes are ignored.
     */
    private void calcMeanCover(HoeffdingTree.Node node) {
        if (!(node instanceof HoeffdingTree.SplitNode)) {
            return;
        }
        HoeffdingTree.SplitNode splitNode = (HoeffdingTree.SplitNode)node;
        int attributeIndex = splitNode.getSplitTest().getAttsTestDependsOn()[0];
        this.checkAttributeIndex(attributeIndex);
        this.featureImportances[attributeIndex] += this.calcNodeCover(splitNode);
        for (HoeffdingTree.Node childNode : splitNode.getChildren()) {
            if (childNode == null) continue;
            this.calcMeanCover(childNode);
        }
    }

    /**
     * Computes the cover of a split node: the sum of the class distribution observed at
     * the leaves reachable through it.
     *
     * @param splitNode the split node to score
     * @return the summed class distribution weight
     */
    public double calcNodeCover(HoeffdingTree.SplitNode splitNode) {
        double[] thisNodeClassDistributionAtLeaves = splitNode.getObservedClassDistributionAtLeavesReachableThroughThisNode();
        return Utils.sum(thisNodeClassDistributionAtLeaves);
    }

    /**
     * Recursively adds the impurity decrease of every split node below (and including)
     * {@code node} to the score of the feature the split tests on. Non-split nodes are
     * ignored.
     */
    private void calcMeanDecreaseImpurity(HoeffdingTree.Node node) {
        if (!(node instanceof HoeffdingTree.SplitNode)) {
            return;
        }
        HoeffdingTree.SplitNode splitNode = (HoeffdingTree.SplitNode)node;
        int attributeIndex = splitNode.getSplitTest().getAttsTestDependsOn()[0];
        this.checkAttributeIndex(attributeIndex);
        this.featureImportances[attributeIndex] += this.calcNodeDecreaseImpurity(splitNode);
        for (HoeffdingTree.Node childNode : splitNode.getChildren()) {
            if (childNode == null) continue;
            this.calcMeanDecreaseImpurity(childNode);
        }
    }

    /**
     * Computes the impurity decrease of a split node: the entropy of the class
     * distribution reachable through the node minus the weighted entropies of its
     * children, each child weighted by its share of the node's total weight.
     *
     * @param splitNode the split node to score
     * @return the entropy decrease, or {@code 0.0} when the node carries no weight
     */
    public double calcNodeDecreaseImpurity(HoeffdingTree.SplitNode splitNode) {
        double[] thisNodeClassDistributionAtLeaves = splitNode.getObservedClassDistributionAtLeavesReachableThroughThisNode();
        double thisNodeWeight = Utils.sum(thisNodeClassDistributionAtLeaves);
        if (thisNodeWeight <= 0.0) {
            // No observed weight: child fractions would be 0/0 and poison the scores with NaN.
            return 0.0;
        }
        double thisNodeEntropy = InfoGainSplitCriterion.computeEntropy(thisNodeClassDistributionAtLeaves);
        double weightedChildrenEntropy = 0.0;
        for (HoeffdingTree.Node childNode : splitNode.getChildren()) {
            if (childNode == null) continue;
            double[] childDistribution = childNode.getObservedClassDistributionAtLeavesReachableThroughThisNode();
            // Keep the weight as a double: truncating to int drops fractional weight.
            double childWeight = Utils.sum(childDistribution);
            double childEntropy = InfoGainSplitCriterion.computeEntropy(childDistribution);
            weightedChildrenEntropy += childWeight / thisNodeWeight * childEntropy;
        }
        return thisNodeEntropy - weightedChildrenEntropy;
    }

    /**
     * Fails fast with a descriptive message when a split refers to a feature index that
     * does not fit the score array.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code attributeIndex} is out of range
     */
    private void checkAttributeIndex(int attributeIndex) {
        if (attributeIndex < 0 || attributeIndex >= this.featureImportances.length) {
            throw new ArrayIndexOutOfBoundsException("Error with attributeIndex: " + attributeIndex + " is outside the " + this.featureImportances.length + " tracked features");
        }
    }

    /** Delegates prediction to the wrapped tree learner. */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        return this.treeLearner.getVotesForInstance(instance);
    }

    /** Clears the cached scores and counters and re-creates and resets the wrapped tree learner. */
    @Override
    public void resetLearningImpl() {
        this.featureImportances = null;
        this.nodeCountAtLastFeatureImportanceInquiry = 0;
        this.featureImportancesInquiries = 0;
        this.treeLearner = (HoeffdingTree)this.getPreparedClassOption(this.treeLearnerOption);
        this.treeLearner.resetLearning();
    }

    /**
     * Trains the wrapped tree on the instance. On the first instance the score array is
     * sized to the instance's attribute count minus one.
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.featureImportances == null) {
            this.featureImportances = new double[instance.numAttributes() - 1];
        }
        this.treeLearner.trainOnInstance(instance);
    }

    /** Reports the measurements of the wrapped tree learner. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return this.treeLearner.getModelMeasurements();
    }

    /** Appends the description of the wrapped tree learner. */
    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        this.treeLearner.getModelDescription(out, indent);
    }

    /** Returns whether the wrapped learner is randomizable; {@code false} while it is unset. */
    @Override
    public boolean isRandomizable() {
        if (this.treeLearner == null) {
            return false;
        }
        return this.treeLearner.isRandomizable();
    }
}

