/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.tree;

import org.apache.spark.Logging;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.Algo$;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impl.TimeTracker;
import org.apache.spark.mllib.tree.impurity.Variance$;
import org.apache.spark.mllib.tree.loss.Loss;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.mutable.StringBuilder;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;

/**
 * Decompiled (CFR) form of what appears to be the Scala singleton {@code object GradientBoostedTrees}.
 * It holds the static-style training entry points ({@code train}) and the shared boosting loop.
 *
 * <p>NOTE(review): this is decompiler output, not hand-written Java.
 * <ul>
 *   <li>Names containing {@code $} ({@code MODULE$}, {@code org$apache$spark$...}) are
 *       compiler-mangled. Other compiled classes presumably reference them, so do not rename them.</li>
 *   <li>Anonymous classes written as {@code new Serializable(x){...}} pass constructor arguments to
 *       an interface. That is not legal Java source, so this file is not expected to compile as-is.</li>
 * </ul>
 */
public final class GradientBoostedTrees$
implements Logging,
Serializable {
    /** The single instance of this module; assigned by the private constructor below. */
    public static final GradientBoostedTrees$ MODULE$;
    /** Backing field for the logger used by the {@code Logging} trait; excluded from serialization. */
    private transient Logger org$apache$spark$Logging$$log_;

    // Creates the singleton when the class is initialized; the constructor assigns MODULE$.
    static {
        new GradientBoostedTrees$();
    }

    /** Getter for the transient logger backing field required by the {@code Logging} trait. */
    public Logger org$apache$spark$Logging$$log_() {
        return this.org$apache$spark$Logging$$log_;
    }

    /** Setter for the transient logger backing field required by the {@code Logging} trait. */
    public void org$apache$spark$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$Logging$$log_ = x$1;
    }

    // ---- Logging trait forwarders ----
    // Each method below delegates to the matching static implementation on the Logging trait's
    // implementation class, passing this instance. Messages are supplied as Function0 thunks.

    public String logName() {
        return Logging.class.logName((Logging)this);
    }

    public Logger log() {
        return Logging.class.log((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.class.logInfo((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.class.logDebug((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.class.logTrace((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.class.logWarning((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.class.logError((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.class.logInfo((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.class.logDebug((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.class.logTrace((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.class.logWarning((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.class.logError((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.class.isTraceEnabled((Logging)this);
    }

    /**
     * Trains a gradient-boosted trees model on {@code input}.
     *
     * <p>Constructs a new {@link GradientBoostedTrees} with {@code boostingStrategy} and returns the
     * result of its {@code run(input)}.
     *
     * @param input training data
     * @param boostingStrategy boosting configuration handed to the {@link GradientBoostedTrees} constructor
     * @return the model produced by {@code run(input)}
     */
    public GradientBoostedTreesModel train(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy) {
        return new GradientBoostedTrees(boostingStrategy).run(input);
    }

    /**
     * Java-friendly overload of {@link #train(RDD, BoostingStrategy)}.
     *
     * <p>Unwraps the {@code JavaRDD} via {@code input.rdd()} and delegates to that overload.
     *
     * @param input training data as a {@code JavaRDD}
     * @param boostingStrategy boosting configuration
     * @return the trained model
     */
    public GradientBoostedTreesModel train(JavaRDD<LabeledPoint> input, BoostingStrategy boostingStrategy) {
        return this.train((RDD<LabeledPoint>)input.rdd(), boostingStrategy);
    }

    /**
     * Runs the gradient boosting loop and returns the resulting tree ensemble.
     *
     * <p>The mangled name suggests a Scala-private method of the {@code GradientBoostedTrees} object
     * that was made public in bytecode. NOTE(review): callers are outside this view, presumably the
     * {@link GradientBoostedTrees} class; confirm before changing the signature.
     *
     * <p>Outline:
     * <ol>
     *   <li>Validate {@code boostingStrategy}. Copy its tree strategy, forcing the Regression algo and
     *       Variance impurity for every base tree.</li>
     *   <li>Persist {@code input} at MEMORY_AND_DISK if its storage level is NONE. It is unpersisted
     *       at the end only in that case.</li>
     *   <li>Train tree 0 directly on {@code input} with weight 1.0.</li>
     *   <li>For each later iteration m, train a tree on pseudo-residuals and give it weight
     *       {@code learningRate}. A pseudo-residual replaces each label with the negative loss
     *       gradient at the current prediction. Then update the running prediction/error RDD.</li>
     *   <li>If {@code validate2}, track the mean validation error. Stop when the improvement
     *       {@code bestValidateError - currentValidateError} is below
     *       {@code validationTol * max(currentValidateError, 0.01)}; this also fires when the error
     *       increases. Remember {@code bestM}, the tree count with the best validation error so far.</li>
     * </ol>
     *
     * @param input training data
     * @param validationInput data used to measure validation error. Its initial prediction/error RDD
     *        is constructed even when {@code validate2} is false. It is only checkpointed, updated and
     *        reduced with {@code mean()} when {@code validate2} is true.
     * @param boostingStrategy supplies numIterations, loss, learningRate, validationTol and the tree
     *        strategy (which is copied, not mutated)
     * @param validate2 whether to use {@code validationInput} for early stopping and truncation
     * @return a model whose algo comes from the original {@code boostingStrategy.treeStrategy()}, not
     *         the Regression copy. With {@code validate2}, trees and weights are truncated to the first
     *         {@code bestM}; otherwise all {@code numIterations} trees are kept.
     */
    public GradientBoostedTreesModel org$apache$spark$mllib$tree$GradientBoostedTrees$$boost(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput, BoostingStrategy boostingStrategy, boolean validate2) {
        // Decompiler temporary holding the result of the "persist if not cached" if-expression.
        boolean bl;
        TimeTracker timer = new TimeTracker();
        timer.start("total");
        timer.start("init");
        boostingStrategy.assertValid();
        int numIterations = boostingStrategy.numIterations();
        DecisionTreeModel[] baseLearners = new DecisionTreeModel[numIterations];
        double[] baseLearnerWeights = new double[numIterations];
        Loss loss2 = boostingStrategy.loss();
        double learningRate = boostingStrategy.learningRate();
        // Work on a copy so the caller's treeStrategy (whose algo is used for the returned model) is untouched.
        Strategy treeStrategy = boostingStrategy.treeStrategy().copy();
        double validationTol = boostingStrategy.validationTol();
        // Every base learner is a regression tree with variance impurity, regardless of the caller's algo.
        treeStrategy.algo_$eq(Algo$.MODULE$.Regression());
        treeStrategy.impurity_$eq(Variance$.MODULE$);
        treeStrategy.assertValid();
        StorageLevel storageLevel = input.getStorageLevel();
        StorageLevel storageLevel2 = StorageLevel$.MODULE$.NONE();
        // Decompiled null-safe equality: the branch runs when input's storage level equals StorageLevel.NONE.
        if (!(storageLevel != null ? !storageLevel.equals(storageLevel2) : storageLevel2 != null)) {
            input.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
            bl = true;
        } else {
            bl = false;
        }
        // Records whether this method persisted input, so only that case is unpersisted at the end.
        boolean persistedInput = bl;
        // Each new prediction/error RDD is passed to update(); the interval comes from the tree strategy.
        PeriodicRDDCheckpointer<RDD> predErrorCheckpointer = new PeriodicRDDCheckpointer<RDD>(treeStrategy.getCheckpointInterval(), input.sparkContext());
        PeriodicRDDCheckpointer<RDD<Tuple2<Object, Object>>> validatePredErrorCheckpointer = new PeriodicRDDCheckpointer<RDD<Tuple2<Object, Object>>>(treeStrategy.getCheckpointInterval(), input.sparkContext());
        timer.stop("init");
        this.logDebug((Function0<String>)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final String apply() {
                return "##########";
            }
        });
        this.logDebug((Function0<String>)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final String apply() {
                return "Building tree 0";
            }
        });
        this.logDebug((Function0<String>)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final String apply() {
                return "##########";
            }
        });
        timer.start("building tree 0");
        // Tree 0 is fit on the raw labels and carries a fixed weight of 1.0.
        DecisionTreeModel firstTreeModel = new DecisionTree(treeStrategy).run(input);
        double firstTreeWeight = 1.0;
        baseLearners[0] = firstTreeModel;
        baseLearnerWeights[0] = firstTreeWeight;
        // Boxed in an ObjectRef because the logging closures capture it while the loop reassigns it.
        // Presumably each element is (prediction, error): below, the first component is read as the
        // prediction and the mean of the second is logged as the error. TODO confirm against
        // GradientBoostedTreesModel.computeInitialPredictionAndError.
        ObjectRef predError = ObjectRef.create(GradientBoostedTreesModel$.MODULE$.computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss2));
        predErrorCheckpointer.update((RDD)predError.elem);
        this.logDebug((Function0<String>)new Serializable(predError){
            public static final long serialVersionUID = 0L;
            private final ObjectRef predError$1;

            public final String apply() {
                return new StringBuilder().append((Object)"error of gbt = ").append((Object)BoxesRunTime.boxToDouble((double)RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions((RDD)this.predError$1.elem, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), (Ordering)Ordering.Double$.MODULE$).values()).mean())).toString();
            }
            {
                this.predError$1 = predError$1;
            }
        });
        timer.stop("building tree 0");
        // Constructed even when validate2 is false. It is only checkpointed, updated and reduced under validate2.
        RDD<Tuple2<Object, Object>> validatePredError = GradientBoostedTreesModel$.MODULE$.computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss2);
        if (validate2) {
            validatePredErrorCheckpointer.update(validatePredError);
        }
        // Mean validation error of the one-tree model; unused (0.0) when not validating.
        double bestValidateError = validate2 ? RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(validatePredError, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), (Ordering)Ordering.Double$.MODULE$).values()).mean() : 0.0;
        // Number of leading trees to keep when validating; starts at 1 (tree 0 only).
        int bestM = 1;
        // Boxed in an IntRef because the iteration log closure captures it.
        IntRef m = IntRef.create((int)1);
        boolean doneLearning = false;
        while (m.elem < numIterations && !doneLearning) {
            // Pseudo-residuals: each label is replaced by the negative loss gradient at the current prediction.
            // NOTE(review): RDD.zip requires predError to be element-aligned with input; presumably true
            // because predError is derived from input. Confirm.
            RDD data = ((RDD)predError.elem).zip(input, ClassTag$.MODULE$.apply(LabeledPoint.class)).map((Function1)new Serializable(loss2){
                public static final long serialVersionUID = 0L;
                private final Loss loss$1;

                public final LabeledPoint apply(Tuple2<Tuple2<Object, Object>, LabeledPoint> x0$1) {
                    Tuple2<Tuple2<Object, Object>, LabeledPoint> tuple2 = x0$1;
                    if (tuple2 != null) {
                        Tuple2 tuple22 = (Tuple2)tuple2._1();
                        LabeledPoint point = (LabeledPoint)tuple2._2();
                        if (tuple22 != null) {
                            double pred = tuple22._1$mcD$sp();
                            LabeledPoint labeledPoint = new LabeledPoint(-this.loss$1.gradient(pred, point.label()), point.features());
                            return labeledPoint;
                        }
                    }
                    // Decompiled pattern-match fallthrough: a null pair or null inner tuple raises MatchError.
                    throw new MatchError(tuple2);
                }
                {
                    this.loss$1 = loss$1;
                }
            }, ClassTag$.MODULE$.apply(LabeledPoint.class));
            // Timer key is "building tree <m>", built with Scala string interpolation.
            timer.start(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"building tree ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)m.elem)})));
            this.logDebug((Function0<String>)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return "###################################################";
                }
            });
            this.logDebug((Function0<String>)new Serializable(m){
                public static final long serialVersionUID = 0L;
                private final IntRef m$1;

                public final String apply() {
                    return new StringBuilder().append((Object)"Gradient boosting tree iteration ").append((Object)BoxesRunTime.boxToInteger((int)this.m$1.elem)).toString();
                }
                {
                    this.m$1 = m$1;
                }
            });
            this.logDebug((Function0<String>)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return "###################################################";
                }
            });
            DecisionTreeModel model = new DecisionTree(treeStrategy).run((RDD<LabeledPoint>)data);
            timer.stop(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"building tree ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)m.elem)})));
            baseLearners[m.elem] = model;
            // Every tree after the first uses the fixed learning rate as its weight, whatever the loss.
            baseLearnerWeights[m.elem] = learningRate;
            // Fold the new tree into the running prediction/error RDD for the training data.
            predError.elem = GradientBoostedTreesModel$.MODULE$.updatePredictionError(input, (RDD<Tuple2<Object, Object>>)((RDD)predError.elem), baseLearnerWeights[m.elem], baseLearners[m.elem], loss2);
            predErrorCheckpointer.update((RDD)predError.elem);
            this.logDebug((Function0<String>)new Serializable(predError){
                public static final long serialVersionUID = 0L;
                private final ObjectRef predError$1;

                public final String apply() {
                    return new StringBuilder().append((Object)"error of gbt = ").append((Object)BoxesRunTime.boxToDouble((double)RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions((RDD)this.predError$1.elem, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), (Ordering)Ordering.Double$.MODULE$).values()).mean())).toString();
                }
                {
                    this.predError$1 = predError$1;
                }
            });
            if (validate2) {
                validatePredError = GradientBoostedTreesModel$.MODULE$.updatePredictionError(validationInput, validatePredError, baseLearnerWeights[m.elem], baseLearners[m.elem], loss2);
                validatePredErrorCheckpointer.update(validatePredError);
                double currentValidateError = RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(RDD$.MODULE$.rddToPairRDDFunctions(validatePredError, ClassTag$.MODULE$.Double(), ClassTag$.MODULE$.Double(), (Ordering)Ordering.Double$.MODULE$).values()).mean();
                // Early stop when the improvement over the best error is below a relative tolerance
                // (floor 0.01). This also covers the case where the validation error got worse.
                if (bestValidateError - currentValidateError < validationTol * Math.max(currentValidateError, 0.01)) {
                    doneLearning = true;
                } else if (currentValidateError < bestValidateError) {
                    bestValidateError = currentValidateError;
                    // Keep trees 0..m, i.e. m + 1 trees.
                    bestM = m.elem + 1;
                }
            }
            ++m.elem;
        }
        timer.stop("total");
        this.logInfo((Function0<String>)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final String apply() {
                return "Internal timing for DecisionTree:";
            }
        });
        this.logInfo((Function0<String>)new Serializable(timer){
            public static final long serialVersionUID = 0L;
            private final TimeTracker timer$1;

            public final String apply() {
                return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{this.timer$1}));
            }
            {
                this.timer$1 = timer$1;
            }
        });
        predErrorCheckpointer.deleteAllCheckpoints();
        validatePredErrorCheckpointer.deleteAllCheckpoints();
        // Decompiled form of "if (persistedInput) input.unpersist()"; the local is unused.
        Object object = persistedInput ? input.unpersist(input.unpersist$default$1()) : BoxedUnit.UNIT;
        // The model's algo comes from the caller's original tree strategy, not the Regression copy.
        // With validation, trees and weights are truncated to the best-scoring prefix of length bestM.
        return validate2 ? new GradientBoostedTreesModel(boostingStrategy.treeStrategy().algo(), (DecisionTreeModel[])Predef$.MODULE$.refArrayOps((Object[])baseLearners).slice(0, bestM), (double[])Predef$.MODULE$.doubleArrayOps(baseLearnerWeights).slice(0, bestM)) : new GradientBoostedTreesModel(boostingStrategy.treeStrategy().algo(), baseLearners, baseLearnerWeights);
    }

    /** Preserves the singleton across Java deserialization by returning {@code MODULE$}. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Publishes this instance as {@code MODULE$} and runs the {@code Logging} trait initializer. */
    private GradientBoostedTrees$() {
        MODULE$ = this;
        Logging.class.$init$((Logging)this);
    }
}

