/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees;

import adams.core.TechnicalInformation;
import adams.core.TechnicalInformationHandler;
import adams.core.base.BaseKeyValuePair;
import adams.core.management.LDD;
import adams.core.management.OS;
import adams.core.option.AbstractOption;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import weka.classifiers.simple.AbstractSimpleClassifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

public class XGBoost
extends AbstractSimpleClassifier
implements TechnicalInformationHandler {
    private static final long serialVersionUID = 7228620850250174821L;
    public static final String[] MIN_GLIBC_VERSION = new String[]{"2", "23"};
    @XGBoostParameter(value="booster")
    protected BoosterType m_BoosterType;
    @XGBoostParameter(value="verbosity")
    protected Verbosity m_Verbosity;
    @XGBoostParameter(value="nthread")
    protected int m_NumberOfThreads;
    @XGBoostParameter(value="eta")
    protected float m_Eta;
    @XGBoostParameter(value="gamma")
    protected float m_Gamma;
    @XGBoostParameter(value="max_depth")
    protected int m_MaxDepth;
    @XGBoostParameter(value="min_child_weight")
    protected float m_MinChildWeight;
    @XGBoostParameter(value="max_delta_step")
    protected float m_MaxDeltaStep;
    @XGBoostParameter(value="subsample")
    protected float m_Subsample;
    @XGBoostParameter(value="colsample_bytree")
    protected float m_ColumnSampleByTree;
    @XGBoostParameter(value="colsample_bylevel")
    protected float m_ColumnSampleByLevel;
    @XGBoostParameter(value="colsample_bynode")
    protected float m_ColumnSampleByNode;
    @XGBoostParameter(value="tree_method")
    protected TreeMethod m_TreeMethod;
    @XGBoostParameter(value="scale_pos_weight")
    protected float m_ScalePositiveWeights;
    @XGBoostParameter(value="process_type")
    protected ProcessType m_ProcessType;
    @XGBoostParameter(value="grow_policy")
    protected GrowPolicy m_GrowPolicy;
    @XGBoostParameter(value="max_leaves")
    protected int m_MaxLeaves;
    @XGBoostParameter(value="max_bin")
    protected int m_MaxBin;
    @XGBoostParameter(value="predictor")
    protected Predictor m_Predictor;
    @XGBoostParameter(value="num_parallel_tree")
    protected int m_NumberOfParallelTrees;
    @XGBoostParameter(value="sample_type")
    protected SampleType m_SampleType;
    @XGBoostParameter(value="normalize_type")
    protected NormaliseType m_NormaliseType;
    @XGBoostParameter(value="rate_drop")
    protected float m_RateDrop;
    @XGBoostParameter(value="one_drop")
    protected boolean m_OneDrop;
    @XGBoostParameter(value="skip_drop")
    protected float m_SkipDrop;
    @XGBoostParameter(value="lambda")
    protected float m_Lambda;
    @XGBoostParameter(value="alpha")
    protected float m_Alpha;
    @XGBoostParameter(value="updater")
    protected Updater m_Updater;
    @XGBoostParameter(value="feature_selector")
    protected FeatureSelector m_FeatureSelector;
    @XGBoostParameter(value="top_k")
    protected int m_TopK;
    @XGBoostParameter(value="tweedie_variance_power")
    protected float m_TweedieVariancePower;
    @XGBoostParameter(value="objective")
    protected Objective m_Objective;
    @XGBoostParameter(value="base_score")
    protected float m_BaseScore;
    @XGBoostParameter(value="seed")
    protected int m_Seed;
    protected int m_NumberOfRounds;
    protected BaseKeyValuePair[] m_OtherParameters;
    protected Booster m_Booster;
    protected Instances m_Header;
    protected Map<String, Object> m_Params;

    public String globalInfo() {
        return "Classifier implementing XGBoost.";
    }

    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("booster", "booster", (Object)BoosterType.GBTREE);
        this.m_OptionManager.add("verbosity", "verbosity", (Object)Verbosity.WARNING);
        this.m_OptionManager.add("nthread", "numThreads", (Object)-1);
        this.m_OptionManager.add("eta", "eta", (Object)Float.valueOf(0.3f), (Number)Float.valueOf(0.0f), (Number)Float.valueOf(1.0f));
        this.m_OptionManager.add("gamma", "gamma", (Object)Float.valueOf(0.0f), (Number)Float.valueOf(0.0f), (Number)Float.valueOf(Float.POSITIVE_INFINITY));
        this.m_OptionManager.add("max_depth", "maxDepth", (Object)6, (Number)0, (Number)Integer.MAX_VALUE);
        this.m_OptionManager.add("min_child_weight", "minChildWeight", (Object)Float.valueOf(1.0f), (Number)Float.valueOf(0.0f), (Number)Float.valueOf(Float.POSITIVE_INFINITY));
        this.m_OptionManager.add("max_delta_step", "maximumDeltaStep", (Object)Float.valueOf(0.0f), (Number)Float.valueOf(0.0f), (Number)Float.valueOf(Float.MAX_VALUE));
        this.m_OptionManager.add("subsample", "subsampleRatio", (Object)Float.valueOf(1.0f), (Number)Float.valueOf(Float.MIN_VALUE), (Number)Float.valueOf(1.0f));
        this.m_OptionManager.add("colsample_bytree", "columnSampleByTree", (Object)Float.valueOf(1.0f), (Number)Float.valueOf(Float.MIN_VALUE), (Number)Float.valueOf(1.0f));
        this.m_OptionManager.add("colsample_bylevel", "columnSampleByLevel", (Object)Float.valueOf(1.0f), (Number)Float.valueOf(Float.MIN_VALUE), (Number)Float.valueOf(1.0f));
        this.m_OptionManager.add("colsample_bynode", "columnSampleByNode", (Object)Float.valueOf(1.0f), (Number)Float.valueOf(Float.MIN_VALUE), (Number)Float.valueOf(1.0f));
        this.m_OptionManager.add("tree_method", "treeMethod", (Object)TreeMethod.AUTO);
        this.m_OptionManager.add("scale_pos_weight", "scalePositiveWeights", (Object)Float.valueOf(1.0f), (Number)Float.valueOf(0.0f), (Number)Float.valueOf(Float.MAX_VALUE));
        this.m_OptionManager.add("process_type", "processType", (Object)ProcessType.DEFAULT);
        this.m_OptionManager.add("grow_policy", "growPolicy", (Object)GrowPolicy.DEPTHWISE);
        this.m_OptionManager.add("max_leaves", "maxLeaves", (Object)0, (Number)0, (Number)Integer.MAX_VALUE);
        this.m_OptionManager.add("max_bin", "maxBin", (Object)256, (Number)2, (Number)Integer.MAX_VALUE);
        this.m_OptionManager.add("predictor", "predictor", (Object)Predictor.DEFAULT);
        this.m_OptionManager.add("num_parallel_tree", "numberOfParallelTrees", (Object)1, (Number)1, (Number)Integer.MAX_VALUE);
        this.m_OptionManager.add("sample_type", "sampleType", (Object)SampleType.UNIFORM);
        this.m_OptionManager.add("normalize_type", "normaliseType", (Object)NormaliseType.TREE);
        this.m_OptionManager.add("rate_drop", "rateDrop", (Object)Float.valueOf(0.0f), (Number)Float.valueOf(0.0f), (Number)Float.valueOf(1.0f));
        this.m_OptionManager.add("one_drop", "oneDrop", (Object)false);
        this.m_OptionManager.add("skip_drop", "skipDrop", (Object)Float.valueOf(0.0f), (Number)Float.valueOf(0.0f), (Number)Float.valueOf(1.0f));
        this.m_OptionManager.add("lambda", "lambda", (Object)Float.valueOf(1.0f));
        this.m_OptionManager.add("alpha", "alpha", (Object)Float.valueOf(0.0f));
        this.m_OptionManager.add("updater", "updater", (Object)Updater.SHOTGUN);
        this.m_OptionManager.add("feature_selector", "featureSelector", (Object)FeatureSelector.CYCLIC);
        this.m_OptionManager.add("top_k", "topK", (Object)0, (Number)0, (Number)Integer.MAX_VALUE);
        this.m_OptionManager.add("tweedie_variance_power", "tweedieVariancePower", (Object)Float.valueOf(1.5f), (Number)Float.valueOf(1.0f), (Number)Float.valueOf(2.0f));
        this.m_OptionManager.add("objective", "objective", (Object)Objective.LINEAR_REGRESSION);
        this.m_OptionManager.add("base_score", "baseScore", (Object)Float.valueOf(0.5f));
        this.m_OptionManager.add("seed", "seed", (Object)0);
        this.m_OptionManager.add("rounds", "numberOfRounds", (Object)2, (Number)1, (Number)Integer.MAX_VALUE);
        this.m_OptionManager.add("other_params", "otherParameters", (Object)new BaseKeyValuePair[0]);
    }

    public BoosterType getBooster() {
        return this.m_BoosterType;
    }

    public void setBooster(BoosterType value) {
        this.m_BoosterType = value;
        this.reset();
    }

    public String boosterTipText() {
        return "Which booster to use.";
    }

    public Verbosity getVerbosity() {
        return this.m_Verbosity;
    }

    public void setVerbosity(Verbosity value) {
        this.m_Verbosity = value;
        this.reset();
    }

    public String verbosityTipText() {
        return "Verbosity of printing messages.";
    }

    public int getNumThreads() {
        return this.m_NumberOfThreads;
    }

    public void setNumThreads(int value) {
        this.m_NumberOfThreads = value;
        this.reset();
    }

    public String numThreadsTipText() {
        return "The number of parallel threads used to run XGBoost.";
    }

    public float getEta() {
        return this.m_Eta;
    }

    public void setEta(float value) {
        this.m_Eta = value;
        this.reset();
    }

    public String etaTipText() {
        return "The step size shrinkage to use in updates to prevent overfitting.";
    }

    public float getGamma() {
        return this.m_Gamma;
    }

    public void setGamma(float value) {
        this.m_Gamma = value;
        this.reset();
    }

    public String gammaTipText() {
        return "The minimum loss reduction required to make a further partition on a leaf node of the tree.";
    }

    public int getMaxDepth() {
        return this.m_MaxDepth;
    }

    public void setMaxDepth(int value) {
        this.m_MaxDepth = value;
        this.reset();
    }

    public String maxDepthTipText() {
        return "The maximum depth of a tree.";
    }

    public float getMinChildWeight() {
        return this.m_MinChildWeight;
    }

    public void setMinChildWeight(float value) {
        this.m_MinChildWeight = value;
        this.reset();
    }

    public String minChildWeightTipText() {
        return "The minimum sum of instance weights (hessian) needed in a child.";
    }

    public float getMaximumDeltaStep() {
        return this.m_MaxDeltaStep;
    }

    public void setMaximumDeltaStep(float value) {
        this.m_MaxDeltaStep = value;
        this.reset();
    }

    public String maximumDeltaStepTipText() {
        return "The maximum delta step we allow each leaf output to be.";
    }

    public float getSubsampleRatio() {
        return this.m_Subsample;
    }

    public void setSubsampleRatio(float value) {
        this.m_Subsample = value;
        this.reset();
    }

    public String subsampleRatioTipText() {
        return "The sub-sample ratio of the training instances.";
    }

    public float getColumnSampleByTree() {
        return this.m_ColumnSampleByTree;
    }

    public void setColumnSampleByTree(float value) {
        this.m_ColumnSampleByTree = value;
        this.reset();
    }

    public String columnSampleByTreeTipText() {
        return "The sub-sample ratio of columns when constructing each tree.";
    }

    public float getColumnSampleByLevel() {
        return this.m_ColumnSampleByLevel;
    }

    public void setColumnSampleByLevel(float value) {
        this.m_ColumnSampleByLevel = value;
        this.reset();
    }

    public String columnSampleByLevelTipText() {
        return "The sub-sample ratio of columns for each level.";
    }

    public float getColumnSampleByNode() {
        return this.m_ColumnSampleByNode;
    }

    public void setColumnSampleByNode(float value) {
        this.m_ColumnSampleByNode = value;
        this.reset();
    }

    public String columnSampleByNodeTipText() {
        return "The sub-sample ratio of columns for each node (split).";
    }

    public TreeMethod getTreeMethod() {
        return this.m_TreeMethod;
    }

    public void setTreeMethod(TreeMethod value) {
        this.m_TreeMethod = value;
        this.reset();
    }

    public String treeMethodTipText() {
        return "The tree construction algorithm used in XGBoost.";
    }

    public float getScalePositiveWeights() {
        return this.m_ScalePositiveWeights;
    }

    public void setScalePositiveWeights(float value) {
        this.m_ScalePositiveWeights = value;
        this.reset();
    }

    public String scalePositiveWeightsTipText() {
        return "Scales the weights of positive examples by this factor.";
    }

    public ProcessType getProcessType() {
        return this.m_ProcessType;
    }

    public void setProcessType(ProcessType value) {
        this.m_ProcessType = value;
        this.reset();
    }

    public String processTypeTipText() {
        return "The type of boosting process to run.";
    }

    public GrowPolicy getGrowPolicy() {
        return this.m_GrowPolicy;
    }

    public void setGrowPolicy(GrowPolicy value) {
        this.m_GrowPolicy = value;
        this.reset();
    }

    public String growPolicyTipText() {
        return "The way new nodes are added to the tree.";
    }

    public int getMaxLeaves() {
        return this.m_MaxLeaves;
    }

    public void setMaxLeaves(int value) {
        this.m_MaxLeaves = value;
        this.reset();
    }

    public String maxLeavesTipText() {
        return "The maximum number of nodes to be added.";
    }

    public int getMaxBin() {
        return this.m_MaxBin;
    }

    public void setMaxBin(int value) {
        this.m_MaxBin = value;
        this.reset();
    }

    public String maxBinTipText() {
        return "The maximum number of discrete bins to bucket continuous features.";
    }

    public Predictor getPredictor() {
        return this.m_Predictor;
    }

    public void setPredictor(Predictor value) {
        this.m_Predictor = value;
        this.reset();
    }

    public String predictorTipText() {
        return "The type of predictor algorithm to use.";
    }

    public int getNumberOfParallelTrees() {
        return this.m_NumberOfParallelTrees;
    }

    public void setNumberOfParallelTrees(int value) {
        this.m_NumberOfParallelTrees = value;
        this.reset();
    }

    public String numberOfParallelTreesTipText() {
        return "The number of parallel trees constructed during each iteration.";
    }

    public SampleType getSampleType() {
        return this.m_SampleType;
    }

    public void setSampleType(SampleType value) {
        this.m_SampleType = value;
        this.reset();
    }

    public String sampleTypeTipText() {
        return "The type of sampling algorithm.";
    }

    public NormaliseType getNormaliseType() {
        return this.m_NormaliseType;
    }

    public void setNormaliseType(NormaliseType value) {
        this.m_NormaliseType = value;
        this.reset();
    }

    public String normaliseTypeTipText() {
        return "The type of normalisation algorithm.";
    }

    public float getRateDrop() {
        return this.m_RateDrop;
    }

    public void setRateDrop(float value) {
        this.m_RateDrop = value;
        this.reset();
    }

    public String rateDropTipText() {
        return "The dropout rate (a fraction of previous trees to drop during the dropout).";
    }

    public boolean getOneDrop() {
        return this.m_OneDrop;
    }

    public void setOneDrop(boolean value) {
        this.m_OneDrop = value;
        this.reset();
    }

    public String oneDropTipText() {
        return "Whether at least one tree is always dropped during the dropout.";
    }

    public float getSkipDrop() {
        return this.m_SkipDrop;
    }

    public void setSkipDrop(float value) {
        this.m_SkipDrop = value;
        this.reset();
    }

    public String skipDropTipText() {
        return "The probability of skipping the dropout procedure during a boosting iteration.";
    }

    public float getLambda() {
        return this.m_Lambda;
    }

    public void setLambda(float value) {
        this.m_Lambda = value;
        this.reset();
    }

    public String lambdaTipText() {
        return "The L2 regularisation term on weights.";
    }

    public float getAlpha() {
        return this.m_Alpha;
    }

    public void setAlpha(float value) {
        this.m_Alpha = value;
        this.reset();
    }

    public String alphaTipText() {
        return "The L1 regularisation term on weights.";
    }

    public Updater getUpdater() {
        return this.m_Updater;
    }

    public void setUpdater(Updater value) {
        this.m_Updater = value;
        this.reset();
    }

    public String updaterTipText() {
        return "The choice of algorithm to fit the linear model.";
    }

    public FeatureSelector getFeatureSelector() {
        return this.m_FeatureSelector;
    }

    public void setFeatureSelector(FeatureSelector value) {
        this.m_FeatureSelector = value;
        this.reset();
    }

    public String featureSelectorTipText() {
        return "The feature selection and ordering method.";
    }

    public int getTopK() {
        return this.m_TopK;
    }

    public void setTopK(int value) {
        this.m_TopK = value;
        this.reset();
    }

    public String topKTipText() {
        return "The number of top features to select when using the greedy or thrifty feature selector.";
    }

    public float getTweedieVariancePower() {
        return this.m_TweedieVariancePower;
    }

    public void setTweedieVariancePower(float value) {
        this.m_TweedieVariancePower = value;
        this.reset();
    }

    public String tweedieVariancePowerTipText() {
        return "The parameter that controls the variance of the Tweedie distribution.";
    }

    public Objective getObjective() {
        return this.m_Objective;
    }

    public void setObjective(Objective value) {
        this.m_Objective = value;
        this.reset();
    }

    public String objectiveTipText() {
        return "The learning objective.";
    }

    public float getBaseScore() {
        return this.m_BaseScore;
    }

    public void setBaseScore(float value) {
        this.m_BaseScore = value;
        this.reset();
    }

    public String baseScoreTipText() {
        return "The initial prediction score of all instances (global bias).";
    }

    public int getSeed() {
        return this.m_Seed;
    }

    public void setSeed(int value) {
        this.m_Seed = value;
        this.reset();
    }

    public String seedTipText() {
        return "The random number seed.";
    }

    public int getNumberOfRounds() {
        return this.m_NumberOfRounds;
    }

    public void setNumberOfRounds(int value) {
        this.m_NumberOfRounds = value;
        this.reset();
    }

    public String numberOfRoundsTipText() {
        return "The number of boosting rounds to perform.";
    }

    public BaseKeyValuePair[] getOtherParameters() {
        return this.m_OtherParameters;
    }

    public void setOtherParameters(BaseKeyValuePair[] value) {
        this.m_OtherParameters = value;
        this.reset();
    }

    public String otherParametersTipText() {
        return "Passes any additional parameters to XGBoost.";
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Chen, Tianqi and Guestrin, Carlos");
        result.setValue(TechnicalInformation.Field.TITLE, "XGBoost: A Scalable Tree Boosting System");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining");
        result.setValue(TechnicalInformation.Field.SERIES, "KDD '16");
        result.setValue(TechnicalInformation.Field.YEAR, "2016");
        result.setValue(TechnicalInformation.Field.ISBN, "978-1-4503-4232-2");
        result.setValue(TechnicalInformation.Field.LOCATION, "San Francisco, California, USA");
        result.setValue(TechnicalInformation.Field.PAGES, "785--794");
        result.setValue(TechnicalInformation.Field.URL, "http://doi.acm.org/10.1145/2939672.2939785");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "ACM");
        result.setValue(TechnicalInformation.Field.ADDRESS, "New York, NY, USA");
        result.setValue(TechnicalInformation.Field.KEYWORDS, "large-scale machine learning");
        return result;
    }

    protected int numberOfRequiredDMatrixColumns(Instances instances) {
        int nColumns = 0;
        int classIndex = instances.classIndex();
        for (int i = 0; i < instances.numAttributes(); ++i) {
            Attribute attribute = instances.attribute(i);
            if (classIndex == i) continue;
            if (attribute.isNumeric() || attribute.isDate()) {
                ++nColumns;
                continue;
            }
            if (!attribute.isNominal()) continue;
            nColumns += attribute.numValues();
        }
        return nColumns;
    }

    protected DMatrix instancesToDMatrix(Instance[] instances) throws XGBoostError {
        int nRows = instances.length;
        int nColumns = this.numberOfRequiredDMatrixColumns(this.m_Header);
        if (nRows == 0 || nColumns == 0) {
            return null;
        }
        float[] data = new float[nRows * nColumns];
        float[] labels = new float[nRows];
        float[] weights = new float[nRows];
        int classIndex = this.m_Header.classIndex();
        int insertionIndex = 0;
        for (int rowIndex = 0; rowIndex < nRows; ++rowIndex) {
            Instance instance = instances[rowIndex];
            double[] instanceData = instance.toDoubleArray();
            weights[rowIndex] = (float)instance.weight();
            labels[rowIndex] = instance.classIsMissing() ? 0.0f : (float)instanceData[classIndex];
            for (int i = 0; i < instanceData.length; ++i) {
                Attribute attribute = this.m_Header.attribute(i);
                if (i == classIndex) continue;
                if (attribute.isDate() || attribute.isNumeric()) {
                    data[insertionIndex] = (float)instanceData[i];
                    ++insertionIndex;
                    continue;
                }
                if (!attribute.isNominal()) continue;
                data[insertionIndex + (int)instanceData[i]] = 1.0f;
                insertionIndex += attribute.numValues();
            }
        }
        DMatrix dMatrix = new DMatrix(data, nRows, nColumns, (float)Utils.missingValue());
        dMatrix.setLabel(labels);
        dMatrix.setWeight(weights);
        return dMatrix;
    }

    protected Map<String, Object> createParamsFromOptions() {
        HashMap<String, Object> params = new HashMap<String, Object>();
        params.putAll(BaseKeyValuePair.toMap((BaseKeyValuePair[])this.getOtherParameters()));
        for (Field field : ((Object)((Object)this)).getClass().getDeclaredFields()) {
            AbstractOption paramOption;
            XGBoostParameter param = field.getAnnotation(XGBoostParameter.class);
            if (param == null || (paramOption = this.m_OptionManager.findByFlag(param.value())) == null || this.m_OptionManager.isDefaultValueByFlag(param.value())) continue;
            Object optionValue = paramOption.getCurrentValue();
            if (optionValue instanceof ParamValueProvider) {
                optionValue = ((ParamValueProvider)optionValue).paramValue();
            } else if (optionValue instanceof Enum) {
                optionValue = ((Enum)optionValue).name().toLowerCase();
            }
            params.put(param.value(), optionValue);
        }
        return params;
    }

    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_CLASS);
        capabilities.setMinimumNumberInstances(1);
        return capabilities;
    }

    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().test(instances);
        if (OS.isLinux() && LDD.compareTo((String[])MIN_GLIBC_VERSION) < 0) {
            throw new Exception("XGBoost requires a minimum glibc version of " + adams.core.Utils.flatten((Object[])MIN_GLIBC_VERSION, (String)".") + " but found only " + adams.core.Utils.flatten((Object[])LDD.version(), (String)".") + "!");
        }
        this.m_Header = new Instances(instances, 0);
        DMatrix train = this.instancesToDMatrix((Instance[])instances.toArray((Object[])new Instance[0]));
        if (train == null) {
            this.m_Booster = null;
            return;
        }
        this.m_Params = this.createParamsFromOptions();
        if (this.isLoggingEnabled()) {
            this.getLogger().info("XGBoost parameters: " + this.m_Params);
        }
        HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
        if (this.getVerbosity() != Verbosity.SILENT) {
            watches.put("train", train);
        }
        this.m_Booster = ml.dmlc.xgboost4j.java.XGBoost.train((DMatrix)train, this.m_Params, (int)this.m_NumberOfRounds, watches, null, null);
    }

    @Override
    public double classifyInstance(Instance instance) throws Exception {
        if (this.m_Booster == null || instance == null) {
            return Utils.missingValue();
        }
        DMatrix testData = this.instancesToDMatrix(new Instance[]{instance});
        if (testData == null) {
            return Utils.missingValue();
        }
        float[][] predictions = this.m_Booster.predict(testData);
        if (instance.classAttribute().isNumeric()) {
            return predictions[0][0];
        }
        return Math.round(predictions[0][0]);
    }

    public String toString() {
        StringBuilder result = new StringBuilder("XGBoost\n=======\n\n");
        result.append("Parameters: ");
        if (this.m_Params != null) {
            result.append(this.m_Params.toString());
        } else {
            result.append("No model built yet");
        }
        return result.toString();
    }

    public static void main(String[] args) {
        XGBoost.runClassifier(new XGBoost(), args);
    }

    @Retention(value=RetentionPolicy.RUNTIME)
    @Target(value={ElementType.FIELD})
    protected static @interface XGBoostParameter {
        public String value();
    }

    protected static interface ParamValueProvider {
        public Object paramValue();
    }

    public static enum FeatureSelector {
        CYCLIC,
        SHUFFLE,
        RANDOM,
        GREEDY,
        THRIFTY;

    }

    public static enum Updater {
        SHOTGUN,
        COORD_DESCENT;

    }

    public static enum NormaliseType {
        TREE,
        FOREST;

    }

    public static enum SampleType {
        UNIFORM,
        WEIGHTED;

    }

    public static enum Predictor implements ParamValueProvider
    {
        CPU,
        GPU,
        DEFAULT;


        @Override
        public String paramValue() {
            return this.name().toLowerCase() + "_predictor";
        }
    }

    public static enum GrowPolicy {
        DEPTHWISE,
        LOSSGUIDE;

    }

    public static enum ProcessType {
        DEFAULT,
        UPDATE;

    }

    public static enum TreeMethod {
        AUTO,
        EXACT,
        APPROX,
        HIST,
        GPU_EXACT,
        GPU_HIST;

    }

    public static enum Objective implements ParamValueProvider
    {
        LINEAR_REGRESSION("reg:linear"),
        LOGISTIC_REGRESSION("reg:logistic"),
        LOGISTIC_REGRESSION_FOR_BINARY_CLASSIFICATION("binary:logistic"),
        LOGIT_RAW_REGRESSION_FOR_BINARY_CLASSIFICATION("binary:logitraw"),
        HINGE_LOSS_FOR_BINARY_CLASSIFICATION("binary:hinge"),
        POISSON_REGRESSION_FOR_COUNT_DATA("count:poisson"),
        COX_REGRESSION("survival:cox"),
        SOFTMAX_MULTICLASS_CLASSIFICATION("multi:softmax"),
        SOFTPROB_MULTICLASS_CLASSIFICATION("multi:softprob"),
        LAMBDAMART_PAIRWISE_RANKING("rank:pairwise"),
        LAMBDAMART_MAXIMISE_NDCG("rank:ndcg"),
        LAMBDAMART_MAXIMISE_MAP("rank:map"),
        GAMMA_REGRESSION("reg:gamma"),
        TWEEDIE_REGRESSION("reg:tweedie");

        private final String m_ParamString;

        private Objective(String paramString) {
            this.m_ParamString = paramString;
        }

        @Override
        public String paramValue() {
            return this.m_ParamString;
        }
    }

    public static enum Verbosity implements ParamValueProvider
    {
        SILENT,
        WARNING,
        INFO,
        DEBUG;


        @Override
        public Integer paramValue() {
            return this.ordinal();
        }
    }

    public static enum BoosterType {
        GBTREE,
        GBLINEAR,
        DART;

    }
}

