/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.dt;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.tree.dt.AbstractCompressedNode;
import hex.tree.dt.AbstractSplittingRule;
import hex.tree.dt.CategoricalFeatureLimits;
import hex.tree.dt.CategoricalSplittingRule;
import hex.tree.dt.CompressedDT;
import hex.tree.dt.CompressedLeaf;
import hex.tree.dt.CompressedNode;
import hex.tree.dt.DTModel;
import hex.tree.dt.DataFeaturesLimits;
import hex.tree.dt.NumericFeatureLimits;
import hex.tree.dt.NumericSplittingRule;
import hex.tree.dt.binning.BinningStrategy;
import hex.tree.dt.binning.Histogram;
import hex.tree.dt.binning.SplitStatistics;
import hex.tree.dt.mrtasks.GetClassCountsMRTask;
import hex.tree.dt.mrtasks.ScoreDTTask;
import java.util.Arrays;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.Queue;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.log4j.Logger;
import water.DKV;
import water.Key;
import water.Keyed;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.util.Log;
import water.util.MathUtils;
import water.util.RandomUtils;

/**
 * Experimental single decision tree builder; only binomial models are supported
 * ({@link #can_build()} returns {@code Binomial}, and the driver rejects non-binary responses).
 *
 * <p>The tree is grown breadth-first into an array-backed complete binary tree. Positions are
 * visited in index order, and each visited position enqueues exactly two entries (left, right)
 * into a FIFO queue, so the children of position {@code i} land at {@code 2i + 1} and
 * {@code 2i + 2}. A {@code null} queue entry marks a position below a leaf: it produces no node
 * but still enqueues two {@code null} children so the index arithmetic stays aligned.
 */
public class DT extends ModelBuilder<DTModel, DTModel.DTParameters, DTModel.DTOutput> {

    /**
     * Largest supported {@code max_depth}. The tree array has {@code 2^(max_depth + 1) - 1} slots;
     * this bound keeps {@code 1 << (max_depth + 1)} from overflowing an int.
     */
    private static final int MAX_ARRAY_TREE_DEPTH = 29;

    private int _min_rows;
    // NOTE(review): never incremented in this class, so the debug log in buildDT always reports 0.
    int _nodesCount;
    int _leavesCount;
    private AbstractCompressedNode[] _tree;
    private DTModel _model;
    // NOTE(review): assigned in computeImpl but not read anywhere in this class.
    transient Random _rand;
    /**
     * Amount subtracted from each numeric feature's minimum in the initial limits. Presumably the
     * lower limit is treated as exclusive, so without the shift the minimum row would be lost;
     * confirm in NumericFeatureLimits.
     */
    public static final double EPSILON = 1.0E-6;
    /** A split whose criterion differs from the parent's entropy by less than this becomes a leaf. */
    public static final double MIN_IMPROVEMENT = 1.0E-6;
    private static final Logger LOG = Logger.getLogger(DT.class);

    public DT(DTModel.DTParameters parameters) {
        super(parameters);
        _min_rows = parameters._min_rows;
        _nodesCount = 0;
        _leavesCount = 0;
        _tree = null;
        init(true);
    }

    public DT(boolean startup_once) {
        super(new DTModel.DTParameters(), startup_once);
    }

    /**
     * Picks the split with the lowest criterion value across all non-constant features.
     *
     * @param histogram binned statistics of the rows belonging to the current node
     * @return the best splitting rule, or {@code null} if no feature yields an admissible split
     */
    private AbstractSplittingRule findBestSplit(Histogram histogram) {
        final int featuresNumber = histogram.featuresCount();
        AbstractSplittingRule best = null;
        for (int featureIndex = 0; featureIndex < featuresNumber; ++featureIndex) {
            if (histogram.isConstant(featureIndex)) {
                continue;
            }
            AbstractSplittingRule candidate = findBestSplitForFeature(histogram, featureIndex);
            // Strict '<' keeps the earliest feature on ties (and ignores NaN criteria).
            if (candidate != null && (best == null || candidate._criterionValue < best._criterionValue)) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Evaluates every candidate split of one feature and returns the one with the lowest criterion.
     *
     * <p>Candidates leaving fewer than {@code _min_rows} rows on either side are discarded. Each
     * surviving candidate gets its criterion value and feature index written onto its splitting
     * rule before the minimum is selected.
     *
     * @return the best rule for this feature, or {@code null} if no candidate survives the filter
     */
    private AbstractSplittingRule findBestSplitForFeature(Histogram histogram, int featureIndex) {
        return (_train.vec(featureIndex).isNumeric()
                        ? histogram.calculateSplitStatisticsForNumericFeature(featureIndex)
                        : histogram.calculateSplitStatisticsForCategoricalFeature(featureIndex))
                .stream()
                .filter(binStatistics -> binStatistics._leftCount >= _min_rows
                        && binStatistics._rightCount >= _min_rows)
                .peek(binStatistics -> Log.debug("split: " + binStatistics._splittingRule
                        + ", counts: " + binStatistics._leftCount + " " + binStatistics._rightCount))
                // Mutation lives in map(), not peek(): peek is intended for debugging only.
                .map(binStatistics -> {
                    binStatistics.setCriterionValue(calculateCriterionOfSplit(binStatistics))
                            .setFeatureIndex(featureIndex);
                    return binStatistics._splittingRule;
                })
                .min(Comparator.comparing(AbstractSplittingRule::getCriterionValue))
                .orElse(null);
    }

    /** Split criterion: the binary entropy reported by the split statistics. */
    private static double calculateCriterionOfSplit(SplitStatistics binStatistics) {
        return binStatistics.binaryEntropy();
    }

    /**
     * Returns the index of the class with the highest count; ties resolve to the lowest index.
     * NOTE(review): when {@code _nclass == 1} this returns {@code countsByClass[0]}, i.e. a count
     * rather than a class index. It is unreachable for the binomial-only builder; confirm the intent.
     */
    private int selectDecisionValue(int[] countsByClass) {
        if (_nclass == 1) {
            return countsByClass[0];
        }
        int currentMaxClass = 0;
        int currentMax = countsByClass[currentMaxClass];
        for (int c = 1; c < _nclass; ++c) {
            if (countsByClass[c] > currentMax) {
                currentMaxClass = c;
                currentMax = countsByClass[c];
            }
        }
        return currentMaxClass;
    }

    /**
     * Converts per-class counts into per-class fractions of the total.
     * NOTE(review): every entry is NaN when all counts are zero (0 / 0.0).
     */
    private double[] calculateProbability(int[] countsByClass) {
        int samplesCount = Arrays.stream(countsByClass).sum();
        return Arrays.stream(countsByClass).asDoubleStream().map(n -> n / (double) samplesCount).toArray();
    }

    /**
     * Stores a leaf at {@code nodeIndex} holding the majority class and the fraction of class-0 rows,
     * and increments the leaf counter.
     */
    public void makeLeafFromNode(int[] countsByClass, int nodeIndex) {
        _tree[nodeIndex] = new CompressedLeaf(selectDecisionValue(countsByClass), calculateProbability(countsByClass)[0]);
        ++_leavesCount;
    }

    /**
     * Builds the tree position {@code nodeIndex} from the next limits in the queue, then enqueues
     * the limits of its two child positions.
     *
     * <p>The position becomes a leaf (children enqueued as {@code null}) when any of these hold:
     * <ul>
     *   <li>it is at {@code max_depth};</li>
     *   <li>either class has at most {@code _min_rows} rows;</li>
     *   <li>no admissible split exists;</li>
     *   <li>the best split improves the parent entropy by less than {@link #MIN_IMPROVEMENT}.</li>
     * </ul>
     * Otherwise an internal node is stored, and the left/right limits are derived from the
     * chosen threshold (numeric) or category mask (categorical).
     *
     * @param limitsQueue FIFO of per-position feature limits; {@code null} marks an empty position
     * @param nodeIndex   array index of the position being built
     */
    public void buildNextNode(Queue<DataFeaturesLimits> limitsQueue, int nodeIndex) {
        DataFeaturesLimits actualLimits = limitsQueue.poll();
        if (actualLimits == null) {
            // Position lies below a leaf: nothing to build, but keep child slots aligned.
            enqueueEmptyChildren(limitsQueue);
            return;
        }
        int[] countsByClass = countClasses(actualLimits);
        if (nodeIndex == 0) {
            Log.info("Classes counts in dataset: 0 - " + countsByClass[0] + ", 1 - " + countsByClass[1]);
        }
        if (depthOfNode(nodeIndex) >= _parms._max_depth
                || countsByClass[0] <= _min_rows
                || countsByClass[1] <= _min_rows) {
            enqueueEmptyChildren(limitsQueue);
            makeLeafFromNode(countsByClass, nodeIndex);
            return;
        }
        Histogram histogram = new Histogram(_train, actualLimits, BinningStrategy.EQUAL_WIDTH);
        AbstractSplittingRule bestSplittingRule = findBestSplit(histogram);
        double criterionForTheParentNode = SplitStatistics.entropyBinarySplit(
                1.0 * countsByClass[0] / (countsByClass[0] + countsByClass[1]));
        if (bestSplittingRule == null
                || Math.abs(criterionForTheParentNode - bestSplittingRule._criterionValue) < MIN_IMPROVEMENT) {
            enqueueEmptyChildren(limitsQueue);
            makeLeafFromNode(countsByClass, nodeIndex);
            return;
        }
        _tree[nodeIndex] = new CompressedNode(bestSplittingRule);
        int splitFeatureIndex = bestSplittingRule.getFeatureIndex();
        DataFeaturesLimits limitsLeft;
        DataFeaturesLimits limitsRight;
        if (_train.vec(splitFeatureIndex).isNumeric()) {
            double threshold = ((NumericSplittingRule) bestSplittingRule).getThreshold();
            limitsLeft = actualLimits.updateMax(splitFeatureIndex, threshold);
            limitsRight = actualLimits.updateMin(splitFeatureIndex, threshold);
        } else {
            boolean[] mask = ((CategoricalSplittingRule) bestSplittingRule).getMask();
            limitsLeft = actualLimits.updateMask(splitFeatureIndex, mask);
            limitsRight = actualLimits.updateMaskExcluded(splitFeatureIndex, mask);
        }
        limitsQueue.add(limitsLeft);
        limitsQueue.add(limitsRight);
    }

    /** Enqueues two {@code null} entries: the (empty) children of a leaf or of an empty position. */
    private static void enqueueEmptyChildren(Queue<DataFeaturesLimits> limitsQueue) {
        limitsQueue.add(null);
        limitsQueue.add(null);
    }

    /**
     * Depth of the array position {@code nodeIndex}, i.e. floor(log2(nodeIndex + 1)). It is computed
     * with exact integer arithmetic so that exact powers of two are never subject to floating-point
     * rounding.
     */
    private static int depthOfNode(int nodeIndex) {
        return 31 - Integer.numberOfLeadingZeros(nodeIndex + 1);
    }

    /**
     * Builds the root limits from every column of {@code data} except the last (presumably the
     * response column; confirm with the caller).
     * <ul>
     *   <li>Numeric columns span {@code [min - EPSILON, max]}.</li>
     *   <li>Categorical columns are limited by their cardinality.</li>
     * </ul>
     */
    public static DataFeaturesLimits getInitialFeaturesLimits(Frame data) {
        return new DataFeaturesLimits(
                IntStream.range(0, data.numCols() - 1)
                        .mapToObj(data::vec)
                        .map(v -> v.isNumeric()
                                ? new NumericFeatureLimits(v.min() - EPSILON, v.max())
                                : new CategoricalFeatureLimits(v.cardinality()))
                        .collect(Collectors.toList()));
    }

    @Override
    protected Driver trainModelImpl() {
        return new DTDriver();
    }

    @Override
    public BuilderVisibility builderVisibility() {
        return BuilderVisibility.Experimental;
    }

    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Binomial};
    }

    @Override
    public boolean isSupervised() {
        return true;
    }

    /**
     * Scores the training frame (and the validation frame, if one is configured) with the built
     * model and stores the resulting metrics on the model output.
     * NOTE(review): not invoked from within this class.
     */
    protected final void makeModelMetrics() {
        ModelMetrics.MetricBuilder metricsBuilder = new ScoreDTTask(_model).doAll(_train).getMetricsBuilder();
        _model._output._training_metrics = metricsBuilder.makeModelMetrics(_model, _parms.train(), null, null);
        if (_parms._valid != null) {
            Frame v = new Frame(valid());
            metricsBuilder = new ScoreDTTask(_model).doAll(v).getMetricsBuilder();
            _model._output._validation_metrics = metricsBuilder.makeModelMetrics(_model, v, null, null);
        }
    }

    /**
     * Counts training rows per class within {@code featuresLimits}.
     * {@code null} means the full initial limits.
     */
    private int[] countClasses(DataFeaturesLimits featuresLimits) {
        GetClassCountsMRTask task = new GetClassCountsMRTask(
                featuresLimits == null
                        ? getInitialFeaturesLimits(_train).toDoubles()
                        : featuresLimits.toDoubles(),
                _nclass);
        task.doAll(_train);
        return task._countsByClass;
    }

    /** Training driver: validates parameters, builds the tree and publishes it to the DKV. */
    private class DTDriver extends Driver {

        private void dtChecks() {
            if (_parms._max_depth < 1) {
                error("_parms._max_depth", "Max depth has to be at least 1");
            }
            // The tree array has 2^(max_depth + 1) - 1 slots; larger depths overflow an int.
            if (_parms._max_depth > MAX_ARRAY_TREE_DEPTH) {
                error("_parms._max_depth", "Max depth has to be at most " + MAX_ARRAY_TREE_DEPTH);
            }
            if (_train.hasNAs()) {
                error("_train", "NaNs are not supported yet");
            }
            if (_train.hasInfs()) {
                error("_train", "Infs are not supported");
            }
            if (!_response.isCategorical()) {
                error("_response", "Only categorical response is supported");
            }
            if (!_response.isBinary()) {
                error("_response", "Only binary response is supported");
            }
        }

        @Override
        public void computeImpl() {
            _model = null;
            try {
                init(true);
                dtChecks();
                if (error_count() > 0) {
                    throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(DT.this);
                }
                _rand = RandomUtils.getRNG(_parms._seed);
                _model = new DTModel(dest(), _parms, new DTModel.DTOutput(DT.this));
                _model.delete_and_lock(_job);
                buildDT();
                LOG.info(_model.toString());
            } finally {
                // Release the lock taken by delete_and_lock even if building fails.
                if (_model != null) {
                    _model.unlock(_job);
                }
            }
        }

        /**
         * Grows the tree, wraps it in a {@link CompressedDT}, stores it in the DKV and records its key
         * on the model output.
         */
        private void buildDT() {
            buildDTIteratively();
            Log.debug("depth: " + _parms._max_depth + ", nodes count: " + _nodesCount);
            CompressedDT compressedDT = new CompressedDT(_tree, _leavesCount);
            _model._output._treeKey = compressedDT._key;
            DKV.put(compressedDT);
            _job.update(1L);
            _model.update(_job);
        }

        /** Allocates the tree array and fills every position breadth-first, starting from the root. */
        private void buildDTIteratively() {
            // Complete binary tree of depth d has 2^(d+1) - 1 slots; dtChecks bounds d so this fits.
            int treeLength = (1 << (_parms._max_depth + 1)) - 1;
            _tree = new AbstractCompressedNode[treeLength];
            // LinkedList (not ArrayDeque): the queue must accept null entries for empty positions.
            Queue<DataFeaturesLimits> limitsQueue = new LinkedList<>();
            limitsQueue.add(getInitialFeaturesLimits(_train));
            for (int nodeIndex = 0; nodeIndex < treeLength; ++nodeIndex) {
                buildNextNode(limitsQueue, nodeIndex);
            }
        }
    }
}

