package weka.classifiers.trees;

import adams.core.TechnicalInformation;
import adams.core.TechnicalInformationHandler;
import adams.core.base.BaseKeyValuePair;
import adams.core.management.LDD;
import adams.core.management.OS;
import adams.core.option.AbstractOption;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoostError;
import weka.classifiers.simple.AbstractSimpleClassifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

/* loaded from: input_file:weka/classifiers/trees/XGBoost.class */
/**
 * Weka classifier wrapping the XGBoost gradient boosting library (xgboost4j).
 *
 * <p>Fields annotated with {@link XGBoostParameter} are translated into XGBoost parameters by
 * {@link #createParamsFromOptions()}; only options whose current value differs from the option
 * default are passed on. Additional raw parameters can be supplied via
 * {@link #setOtherParameters(BaseKeyValuePair[])}. Nominal attributes are one-hot encoded,
 * numeric and date attributes are passed through as floats.
 */
public class XGBoost extends AbstractSimpleClassifier implements TechnicalInformationHandler {
    private static final long serialVersionUID = 7228620850250174821L;

    /** Minimum glibc version (major, minor) checked on Linux before training. */
    public static final String[] MIN_GLIBC_VERSION = {"2", "23"};

    @XGBoostParameter("booster")
    protected BoosterType m_BoosterType;

    @XGBoostParameter("verbosity")
    protected Verbosity m_Verbosity;

    @XGBoostParameter("nthread")
    protected int m_NumberOfThreads;

    @XGBoostParameter("eta")
    protected float m_Eta;

    @XGBoostParameter("gamma")
    protected float m_Gamma;

    @XGBoostParameter("max_depth")
    protected int m_MaxDepth;

    @XGBoostParameter("min_child_weight")
    protected float m_MinChildWeight;

    @XGBoostParameter("max_delta_step")
    protected float m_MaxDeltaStep;

    @XGBoostParameter("subsample")
    protected float m_Subsample;

    @XGBoostParameter("colsample_bytree")
    protected float m_ColumnSampleByTree;

    @XGBoostParameter("colsample_bylevel")
    protected float m_ColumnSampleByLevel;

    @XGBoostParameter("colsample_bynode")
    protected float m_ColumnSampleByNode;

    @XGBoostParameter("tree_method")
    protected TreeMethod m_TreeMethod;

    @XGBoostParameter("scale_pos_weight")
    protected float m_ScalePositiveWeights;

    @XGBoostParameter("process_type")
    protected ProcessType m_ProcessType;

    @XGBoostParameter("grow_policy")
    protected GrowPolicy m_GrowPolicy;

    @XGBoostParameter("max_leaves")
    protected int m_MaxLeaves;

    @XGBoostParameter("max_bin")
    protected int m_MaxBin;

    @XGBoostParameter("predictor")
    protected Predictor m_Predictor;

    @XGBoostParameter("num_parallel_tree")
    protected int m_NumberOfParallelTrees;

    @XGBoostParameter("sample_type")
    protected SampleType m_SampleType;

    @XGBoostParameter("normalize_type")
    protected NormaliseType m_NormaliseType;

    @XGBoostParameter("rate_drop")
    protected float m_RateDrop;

    @XGBoostParameter("one_drop")
    protected boolean m_OneDrop;

    @XGBoostParameter("skip_drop")
    protected float m_SkipDrop;

    @XGBoostParameter("lambda")
    protected float m_Lambda;

    @XGBoostParameter("alpha")
    protected float m_Alpha;

    @XGBoostParameter("updater")
    protected Updater m_Updater;

    @XGBoostParameter("feature_selector")
    protected FeatureSelector m_FeatureSelector;

    @XGBoostParameter("top_k")
    protected int m_TopK;

    @XGBoostParameter("tweedie_variance_power")
    protected float m_TweedieVariancePower;

    @XGBoostParameter("objective")
    protected Objective m_Objective;

    @XGBoostParameter("base_score")
    protected float m_BaseScore;

    @XGBoostParameter("seed")
    protected int m_Seed;

    /** Number of boosting rounds passed to XGBoost.train (not an XGBoost parameter). */
    protected int m_NumberOfRounds;

    /** Extra key/value parameters handed to XGBoost verbatim. */
    protected BaseKeyValuePair[] m_OtherParameters;

    /** The trained model; null when no model has been built (or training had no data). */
    protected Booster m_Booster;

    /** Empty copy of the training data, used to encode instances consistently. */
    protected Instances m_Header;

    /** The parameters the current model was trained with. */
    protected Map<String, Object> m_Params;

    /** Values for the "booster" parameter (lower-cased constant name). */
    public enum BoosterType {
        GBTREE,
        GBLINEAR,
        DART
    }

    /** Values for the "feature_selector" parameter (lower-cased constant name). */
    public enum FeatureSelector {
        CYCLIC,
        SHUFFLE,
        RANDOM,
        GREEDY,
        THRIFTY
    }

    /** Values for the "grow_policy" parameter (lower-cased constant name). */
    public enum GrowPolicy {
        DEPTHWISE,
        LOSSGUIDE
    }

    /** Values for the "normalize_type" parameter (lower-cased constant name). */
    public enum NormaliseType {
        TREE,
        FOREST
    }

    /** Values for the "objective" parameter; each constant carries its XGBoost string. */
    public enum Objective implements ParamValueProvider {
        LINEAR_REGRESSION("reg:linear"),
        LOGISTIC_REGRESSION("reg:logistic"),
        LOGISTIC_REGRESSION_FOR_BINARY_CLASSIFICATION("binary:logistic"),
        LOGIT_RAW_REGRESSION_FOR_BINARY_CLASSIFICATION("binary:logitraw"),
        HINGE_LOSS_FOR_BINARY_CLASSIFICATION("binary:hinge"),
        POISSON_REGRESSION_FOR_COUNT_DATA("count:poisson"),
        COX_REGRESSION("survival:cox"),
        SOFTMAX_MULTICLASS_CLASSIFICATION("multi:softmax"),
        SOFTPROB_MULTICLASS_CLASSIFICATION("multi:softprob"),
        LAMBDAMART_PAIRWISE_RANKING("rank:pairwise"),
        LAMBDAMART_MAXIMISE_NDCG("rank:ndcg"),
        LAMBDAMART_MAXIMISE_MAP("rank:map"),
        GAMMA_REGRESSION("reg:gamma"),
        TWEEDIE_REGRESSION("reg:tweedie");

        private final String m_ParamString;

        Objective(String str) {
            this.m_ParamString = str;
        }

        @Override
        public String paramValue() {
            return this.m_ParamString;
        }
    }

    /**
     * Implemented by option enums whose XGBoost parameter value is not simply the lower-cased
     * constant name.
     */
    public interface ParamValueProvider {
        /** Returns the value to pass to XGBoost for this constant. */
        Object paramValue();
    }

    /** Values for the "predictor" parameter, rendered as e.g. "cpu_predictor". */
    public enum Predictor implements ParamValueProvider {
        CPU,
        GPU,
        DEFAULT;

        @Override
        public String paramValue() {
            // Locale.ROOT keeps the result independent of the JVM's default locale.
            return name().toLowerCase(Locale.ROOT) + "_predictor";
        }
    }

    /** Values for the "process_type" parameter (lower-cased constant name). */
    public enum ProcessType {
        DEFAULT,
        UPDATE
    }

    /** Values for the "sample_type" parameter (lower-cased constant name). */
    public enum SampleType {
        UNIFORM,
        WEIGHTED
    }

    /** Values for the "tree_method" parameter (lower-cased constant name). */
    public enum TreeMethod {
        AUTO,
        EXACT,
        APPROX,
        HIST,
        GPU_EXACT,
        GPU_HIST
    }

    /** Values for the "updater" parameter (lower-cased constant name). */
    public enum Updater {
        SHOTGUN,
        COORD_DESCENT
    }

    /** Values for the "verbosity" parameter, passed as the constant's ordinal (0-3). */
    public enum Verbosity implements ParamValueProvider {
        SILENT,
        WARNING,
        INFO,
        DEBUG;

        @Override
        public Integer paramValue() {
            return Integer.valueOf(ordinal());
        }
    }

    /**
     * Marks a field as backed by the option whose flag is {@link #value()}; the same string is
     * used as the XGBoost parameter name.
     */
    @Target({ElementType.FIELD})
    @Retention(RetentionPolicy.RUNTIME)
    public @interface XGBoostParameter {
        String value();
    }

    public String globalInfo() {
        return "Classifier implementing XGBoost.";
    }

    /** Registers all options; flags of annotated fields double as XGBoost parameter names. */
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("booster", "booster", BoosterType.GBTREE);
        this.m_OptionManager.add("verbosity", "verbosity", Verbosity.WARNING);
        this.m_OptionManager.add("nthread", "numThreads", -1);
        this.m_OptionManager.add("eta", "eta", Float.valueOf(0.3f), Float.valueOf(0.0f), Float.valueOf(1.0f));
        this.m_OptionManager.add("gamma", "gamma", Float.valueOf(0.0f), Float.valueOf(0.0f), Float.valueOf(Float.POSITIVE_INFINITY));
        this.m_OptionManager.add("max_depth", "maxDepth", 6, 0, Integer.MAX_VALUE);
        this.m_OptionManager.add("min_child_weight", "minChildWeight", Float.valueOf(1.0f), Float.valueOf(0.0f), Float.valueOf(Float.POSITIVE_INFINITY));
        this.m_OptionManager.add("max_delta_step", "maximumDeltaStep", Float.valueOf(0.0f), Float.valueOf(0.0f), Float.valueOf(Float.MAX_VALUE));
        this.m_OptionManager.add("subsample", "subsampleRatio", Float.valueOf(1.0f), Float.valueOf(Float.MIN_VALUE), Float.valueOf(1.0f));
        this.m_OptionManager.add("colsample_bytree", "columnSampleByTree", Float.valueOf(1.0f), Float.valueOf(Float.MIN_VALUE), Float.valueOf(1.0f));
        this.m_OptionManager.add("colsample_bylevel", "columnSampleByLevel", Float.valueOf(1.0f), Float.valueOf(Float.MIN_VALUE), Float.valueOf(1.0f));
        this.m_OptionManager.add("colsample_bynode", "columnSampleByNode", Float.valueOf(1.0f), Float.valueOf(Float.MIN_VALUE), Float.valueOf(1.0f));
        this.m_OptionManager.add("tree_method", "treeMethod", TreeMethod.AUTO);
        this.m_OptionManager.add("scale_pos_weight", "scalePositiveWeights", Float.valueOf(1.0f), Float.valueOf(0.0f), Float.valueOf(Float.MAX_VALUE));
        this.m_OptionManager.add("process_type", "processType", ProcessType.DEFAULT);
        this.m_OptionManager.add("grow_policy", "growPolicy", GrowPolicy.DEPTHWISE);
        this.m_OptionManager.add("max_leaves", "maxLeaves", 0, 0, Integer.MAX_VALUE);
        this.m_OptionManager.add("max_bin", "maxBin", 256, 2, Integer.MAX_VALUE);
        this.m_OptionManager.add("predictor", "predictor", Predictor.DEFAULT);
        this.m_OptionManager.add("num_parallel_tree", "numberOfParallelTrees", 1, 1, Integer.MAX_VALUE);
        this.m_OptionManager.add("sample_type", "sampleType", SampleType.UNIFORM);
        this.m_OptionManager.add("normalize_type", "normaliseType", NormaliseType.TREE);
        this.m_OptionManager.add("rate_drop", "rateDrop", Float.valueOf(0.0f), Float.valueOf(0.0f), Float.valueOf(1.0f));
        this.m_OptionManager.add("one_drop", "oneDrop", false);
        this.m_OptionManager.add("skip_drop", "skipDrop", Float.valueOf(0.0f), Float.valueOf(0.0f), Float.valueOf(1.0f));
        this.m_OptionManager.add("lambda", "lambda", Float.valueOf(1.0f));
        this.m_OptionManager.add("alpha", "alpha", Float.valueOf(0.0f));
        this.m_OptionManager.add("updater", "updater", Updater.SHOTGUN);
        this.m_OptionManager.add("feature_selector", "featureSelector", FeatureSelector.CYCLIC);
        this.m_OptionManager.add("top_k", "topK", 0, 0, Integer.MAX_VALUE);
        this.m_OptionManager.add("tweedie_variance_power", "tweedieVariancePower", Float.valueOf(1.5f), Float.valueOf(1.0f), Float.valueOf(2.0f));
        this.m_OptionManager.add("objective", "objective", Objective.LINEAR_REGRESSION);
        this.m_OptionManager.add("base_score", "baseScore", Float.valueOf(0.5f));
        this.m_OptionManager.add("seed", "seed", 0);
        this.m_OptionManager.add("rounds", "numberOfRounds", 2, 1, Integer.MAX_VALUE);
        this.m_OptionManager.add("other_params", "otherParameters", new BaseKeyValuePair[0]);
    }

    public BoosterType getBooster() {
        return this.m_BoosterType;
    }

    public void setBooster(BoosterType boosterType) {
        this.m_BoosterType = boosterType;
        reset();
    }

    public String boosterTipText() {
        return "Which booster to use.";
    }

    public Verbosity getVerbosity() {
        return this.m_Verbosity;
    }

    public void setVerbosity(Verbosity verbosity) {
        this.m_Verbosity = verbosity;
        reset();
    }

    public String verbosityTipText() {
        return "Verbosity of printing messages.";
    }

    public int getNumThreads() {
        return this.m_NumberOfThreads;
    }

    public void setNumThreads(int i) {
        this.m_NumberOfThreads = i;
        reset();
    }

    public String numThreadsTipText() {
        return "The number of parallel threads used to run XGBoost.";
    }

    public float getEta() {
        return this.m_Eta;
    }

    public void setEta(float f) {
        this.m_Eta = f;
        reset();
    }

    public String etaTipText() {
        return "The step size shrinkage to use in updates to prevent overfitting.";
    }

    public float getGamma() {
        return this.m_Gamma;
    }

    public void setGamma(float f) {
        this.m_Gamma = f;
        reset();
    }

    public String gammaTipText() {
        return "The minimum loss reduction required to make a further partition on a leaf node of the tree.";
    }

    public int getMaxDepth() {
        return this.m_MaxDepth;
    }

    public void setMaxDepth(int i) {
        this.m_MaxDepth = i;
        reset();
    }

    public String maxDepthTipText() {
        return "The maximum depth of a tree.";
    }

    public float getMinChildWeight() {
        return this.m_MinChildWeight;
    }

    public void setMinChildWeight(float f) {
        this.m_MinChildWeight = f;
        reset();
    }

    public String minChildWeightTipText() {
        return "The minimum sum of instance weights (hessian) needed in a child.";
    }

    public float getMaximumDeltaStep() {
        return this.m_MaxDeltaStep;
    }

    public void setMaximumDeltaStep(float f) {
        this.m_MaxDeltaStep = f;
        reset();
    }

    public String maximumDeltaStepTipText() {
        return "The maximum delta step we allow each leaf output to be.";
    }

    public float getSubsampleRatio() {
        return this.m_Subsample;
    }

    public void setSubsampleRatio(float f) {
        this.m_Subsample = f;
        reset();
    }

    public String subsampleRatioTipText() {
        return "The sub-sample ratio of the training instances.";
    }

    public float getColumnSampleByTree() {
        return this.m_ColumnSampleByTree;
    }

    public void setColumnSampleByTree(float f) {
        this.m_ColumnSampleByTree = f;
        reset();
    }

    public String columnSampleByTreeTipText() {
        return "The sub-sample ratio of columns when constructing each tree.";
    }

    public float getColumnSampleByLevel() {
        return this.m_ColumnSampleByLevel;
    }

    public void setColumnSampleByLevel(float f) {
        this.m_ColumnSampleByLevel = f;
        reset();
    }

    public String columnSampleByLevelTipText() {
        return "The sub-sample ratio of columns for each level.";
    }

    public float getColumnSampleByNode() {
        return this.m_ColumnSampleByNode;
    }

    public void setColumnSampleByNode(float f) {
        this.m_ColumnSampleByNode = f;
        reset();
    }

    public String columnSampleByNodeTipText() {
        return "The sub-sample ratio of columns for each node (split).";
    }

    public TreeMethod getTreeMethod() {
        return this.m_TreeMethod;
    }

    public void setTreeMethod(TreeMethod treeMethod) {
        this.m_TreeMethod = treeMethod;
        reset();
    }

    public String treeMethodTipText() {
        return "The tree construction algorithm used in XGBoost.";
    }

    public float getScalePositiveWeights() {
        return this.m_ScalePositiveWeights;
    }

    public void setScalePositiveWeights(float f) {
        this.m_ScalePositiveWeights = f;
        reset();
    }

    public String scalePositiveWeightsTipText() {
        return "Scales the weights of positive examples by this factor.";
    }

    public ProcessType getProcessType() {
        return this.m_ProcessType;
    }

    public void setProcessType(ProcessType processType) {
        this.m_ProcessType = processType;
        reset();
    }

    public String processTypeTipText() {
        return "The type of boosting process to run.";
    }

    public GrowPolicy getGrowPolicy() {
        return this.m_GrowPolicy;
    }

    public void setGrowPolicy(GrowPolicy growPolicy) {
        this.m_GrowPolicy = growPolicy;
        reset();
    }

    public String growPolicyTipText() {
        return "The way new nodes are added to the tree.";
    }

    public int getMaxLeaves() {
        return this.m_MaxLeaves;
    }

    public void setMaxLeaves(int i) {
        this.m_MaxLeaves = i;
        reset();
    }

    public String maxLeavesTipText() {
        return "The maximum number of nodes to be added.";
    }

    public int getMaxBin() {
        return this.m_MaxBin;
    }

    public void setMaxBin(int i) {
        this.m_MaxBin = i;
        reset();
    }

    public String maxBinTipText() {
        return "The maximum number of discrete bins to bucket continuous features.";
    }

    public Predictor getPredictor() {
        return this.m_Predictor;
    }

    public void setPredictor(Predictor predictor) {
        this.m_Predictor = predictor;
        reset();
    }

    public String predictorTipText() {
        return "The type of predictor algorithm to use.";
    }

    public int getNumberOfParallelTrees() {
        return this.m_NumberOfParallelTrees;
    }

    public void setNumberOfParallelTrees(int i) {
        this.m_NumberOfParallelTrees = i;
        reset();
    }

    public String numberOfParallelTreesTipText() {
        return "The number of parallel trees constructed during each iteration.";
    }

    public SampleType getSampleType() {
        return this.m_SampleType;
    }

    public void setSampleType(SampleType sampleType) {
        this.m_SampleType = sampleType;
        reset();
    }

    public String sampleTypeTipText() {
        return "The type of sampling algorithm.";
    }

    public NormaliseType getNormaliseType() {
        return this.m_NormaliseType;
    }

    public void setNormaliseType(NormaliseType normaliseType) {
        this.m_NormaliseType = normaliseType;
        reset();
    }

    public String normaliseTypeTipText() {
        return "The type of normalisation algorithm.";
    }

    public float getRateDrop() {
        return this.m_RateDrop;
    }

    public void setRateDrop(float f) {
        this.m_RateDrop = f;
        reset();
    }

    public String rateDropTipText() {
        return "The dropout rate (a fraction of previous trees to drop during the dropout).";
    }

    public boolean getOneDrop() {
        return this.m_OneDrop;
    }

    public void setOneDrop(boolean z) {
        this.m_OneDrop = z;
        reset();
    }

    public String oneDropTipText() {
        return "Whether at least one tree is always dropped during the dropout.";
    }

    public float getSkipDrop() {
        return this.m_SkipDrop;
    }

    public void setSkipDrop(float f) {
        this.m_SkipDrop = f;
        reset();
    }

    public String skipDropTipText() {
        return "The probability of skipping the dropout procedure during a boosting iteration.";
    }

    public float getLambda() {
        return this.m_Lambda;
    }

    public void setLambda(float f) {
        this.m_Lambda = f;
        reset();
    }

    public String lambdaTipText() {
        return "The L2 regularisation term on weights.";
    }

    public float getAlpha() {
        return this.m_Alpha;
    }

    public void setAlpha(float f) {
        this.m_Alpha = f;
        reset();
    }

    public String alphaTipText() {
        return "The L1 regularisation term on weights.";
    }

    public Updater getUpdater() {
        return this.m_Updater;
    }

    public void setUpdater(Updater updater) {
        this.m_Updater = updater;
        reset();
    }

    public String updaterTipText() {
        return "The choice of algorithm to fit the linear model.";
    }

    public FeatureSelector getFeatureSelector() {
        return this.m_FeatureSelector;
    }

    public void setFeatureSelector(FeatureSelector featureSelector) {
        this.m_FeatureSelector = featureSelector;
        reset();
    }

    public String featureSelectorTipText() {
        return "The feature selection and ordering method.";
    }

    public int getTopK() {
        return this.m_TopK;
    }

    public void setTopK(int i) {
        this.m_TopK = i;
        reset();
    }

    public String topKTipText() {
        return "The number of top features to select when using the greedy or thrifty feature selector.";
    }

    public float getTweedieVariancePower() {
        return this.m_TweedieVariancePower;
    }

    public void setTweedieVariancePower(float f) {
        this.m_TweedieVariancePower = f;
        reset();
    }

    public String tweedieVariancePowerTipText() {
        return "The parameter that controls the variance of the Tweedie distribution.";
    }

    public Objective getObjective() {
        return this.m_Objective;
    }

    public void setObjective(Objective objective) {
        this.m_Objective = objective;
        reset();
    }

    public String objectiveTipText() {
        return "The learning objective.";
    }

    public float getBaseScore() {
        return this.m_BaseScore;
    }

    public void setBaseScore(float f) {
        this.m_BaseScore = f;
        reset();
    }

    public String baseScoreTipText() {
        return "The initial prediction score of all instances (global bias).";
    }

    public int getSeed() {
        return this.m_Seed;
    }

    public void setSeed(int i) {
        this.m_Seed = i;
        reset();
    }

    public String seedTipText() {
        return "The random number seed.";
    }

    public int getNumberOfRounds() {
        return this.m_NumberOfRounds;
    }

    public void setNumberOfRounds(int i) {
        this.m_NumberOfRounds = i;
        reset();
    }

    public String numberOfRoundsTipText() {
        return "The number of boosting rounds to perform.";
    }

    public BaseKeyValuePair[] getOtherParameters() {
        return this.m_OtherParameters;
    }

    public void setOtherParameters(BaseKeyValuePair[] baseKeyValuePairArr) {
        this.m_OtherParameters = baseKeyValuePairArr;
        reset();
    }

    public String otherParametersTipText() {
        return "Passes any additional parameters to XGBoost.";
    }

    /** Returns the citation for the XGBoost paper (Chen &amp; Guestrin, KDD 2016). */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "Chen, Tianqi and Guestrin, Carlos");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "XGBoost: A Scalable Tree Boosting System");
        technicalInformation.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining");
        technicalInformation.setValue(TechnicalInformation.Field.SERIES, "KDD '16");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "2016");
        technicalInformation.setValue(TechnicalInformation.Field.ISBN, "978-1-4503-4232-2");
        technicalInformation.setValue(TechnicalInformation.Field.LOCATION, "San Francisco, California, USA");
        technicalInformation.setValue(TechnicalInformation.Field.PAGES, "785--794");
        technicalInformation.setValue(TechnicalInformation.Field.URL, "http://doi.acm.org/10.1145/2939672.2939785");
        technicalInformation.setValue(TechnicalInformation.Field.PUBLISHER, "ACM");
        technicalInformation.setValue(TechnicalInformation.Field.ADDRESS, "New York, NY, USA");
        technicalInformation.setValue(TechnicalInformation.Field.KEYWORDS, "large-scale machine learning");
        return technicalInformation;
    }

    /**
     * Computes how many feature columns the DMatrix needs for the given dataset: one per
     * numeric/date attribute and one per label of each nominal attribute (one-hot encoding).
     * The class attribute and attributes of other types contribute no columns.
     *
     * @param instances the dataset (only its header is inspected)
     * @return the number of feature columns
     */
    protected int numberOfRequiredDMatrixColumns(Instances instances) {
        int columns = 0;
        int classIndex = instances.classIndex();
        for (int att = 0; att < instances.numAttributes(); att++) {
            if (att == classIndex) {
                continue;
            }
            Attribute attribute = instances.attribute(att);
            if (attribute.isNumeric() || attribute.isDate()) {
                columns++;
            } else if (attribute.isNominal()) {
                columns += attribute.numValues();
            }
        }
        return columns;
    }

    /**
     * Encodes instances as a dense XGBoost DMatrix using {@link #m_Header} as the schema.
     *
     * <p>Numeric/date values are copied as floats, nominal values are one-hot encoded (a missing
     * nominal value marks all of its columns as missing), the label is the class value (0 when
     * the class is missing) and the row weight is the instance weight. Weka's missing-value
     * marker is used as the DMatrix missing value.
     *
     * @param instanceArr the instances to encode
     * @return the matrix, or null if there are no instances or no feature columns
     * @throws XGBoostError if xgboost4j fails to create the matrix
     */
    protected DMatrix instancesToDMatrix(Instance[] instanceArr) throws XGBoostError {
        int numRows = instanceArr.length;
        int numCols = numberOfRequiredDMatrixColumns(this.m_Header);
        if (numRows == 0 || numCols == 0) {
            return null;
        }
        float missing = (float) Utils.missingValue();
        float[] data = new float[numRows * numCols];
        float[] labels = new float[numRows];
        float[] weights = new float[numRows];
        int classIndex = this.m_Header.classIndex();
        int pos = 0;
        for (int row = 0; row < numRows; row++) {
            Instance instance = instanceArr[row];
            double[] values = instance.toDoubleArray();
            weights[row] = (float) instance.weight();
            labels[row] = instance.classIsMissing() ? 0.0f : (float) values[classIndex];
            for (int att = 0; att < values.length; att++) {
                if (att == classIndex) {
                    continue;
                }
                Attribute attribute = this.m_Header.attribute(att);
                if (attribute.isDate() || attribute.isNumeric()) {
                    data[pos++] = (float) values[att];
                } else if (attribute.isNominal()) {
                    int numValues = attribute.numValues();
                    if (instance.isMissing(att)) {
                        // casting the missing marker to an index would wrongly flag the first
                        // label; let XGBoost treat the whole one-hot group as missing instead
                        Arrays.fill(data, pos, pos + numValues, missing);
                    } else {
                        data[pos + (int) values[att]] = 1.0f;
                    }
                    pos += numValues;
                }
            }
        }
        DMatrix dMatrix = new DMatrix(data, numRows, numCols, missing);
        dMatrix.setLabel(labels);
        dMatrix.setWeight(weights);
        return dMatrix;
    }

    /**
     * Builds the XGBoost parameter map: first the "other parameters", then every
     * {@link XGBoostParameter}-annotated field (declared on this class or any superclass/subclass
     * in the hierarchy) whose option value differs from its default. Enum values are converted via
     * {@link ParamValueProvider} or else to their lower-cased name.
     *
     * @return a new, mutable parameter map
     */
    protected Map<String, Object> createParamsFromOptions() {
        Map<String, Object> params = new HashMap<>();
        params.putAll(BaseKeyValuePair.toMap(getOtherParameters()));
        // walk the hierarchy: getDeclaredFields() alone misses the fields of XGBoost itself
        // when called on a subclass
        for (Class<?> cls = getClass(); cls != null; cls = cls.getSuperclass()) {
            for (Field field : cls.getDeclaredFields()) {
                XGBoostParameter annotation = field.getAnnotation(XGBoostParameter.class);
                if (annotation == null) {
                    continue;
                }
                String flag = annotation.value();
                AbstractOption option = this.m_OptionManager.findByFlag(flag);
                if (option == null || this.m_OptionManager.isDefaultValueByFlag(flag)) {
                    continue;
                }
                Object value = option.getCurrentValue();
                if (value instanceof ParamValueProvider) {
                    value = ((ParamValueProvider) value).paramValue();
                } else if (value instanceof Enum) {
                    // Locale.ROOT: e.g. "HIST" must not become "hıst" under a Turkish locale
                    value = ((Enum<?>) value).name().toLowerCase(Locale.ROOT);
                }
                params.put(flag, value);
            }
        }
        return params;
    }

    @Override // weka.classifiers.simple.AbstractSimpleClassifier
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_CLASS);
        capabilities.setMinimumNumberInstances(1);
        return capabilities;
    }

    /**
     * Trains the booster on the given data.
     *
     * <p>For "multi:*" objectives with a nominal class, "num_class" is filled in from the data
     * unless already supplied. The training DMatrix is released after training.
     *
     * @param instances the training data
     * @throws Exception if the data is not supported, glibc is too old (Linux), or training fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().test(instances);
        if (OS.isLinux() && LDD.compareTo(MIN_GLIBC_VERSION) < 0) {
            throw new Exception("XGBoost requires a minimum glibc version of " + adams.core.Utils.flatten(MIN_GLIBC_VERSION, ".") + " but found only " + adams.core.Utils.flatten(LDD.version(), ".") + "!");
        }
        // drop any previous model so a failed rebuild does not leave a stale booster behind
        this.m_Booster = null;
        this.m_Header = new Instances(instances, 0);
        DMatrix train = instancesToDMatrix((Instance[]) instances.toArray(new Instance[0]));
        if (train == null) {
            return;
        }
        try {
            this.m_Params = new HashMap<>(createParamsFromOptions());
            Object objective = this.m_Params.get("objective");
            if (objective != null
                    && objective.toString().startsWith("multi:")
                    && !this.m_Params.containsKey("num_class")
                    && this.m_Header.classAttribute().isNominal()) {
                // XGBoost refuses multi-class objectives without num_class
                this.m_Params.put("num_class", this.m_Header.numClasses());
            }
            if (isLoggingEnabled()) {
                getLogger().info("XGBoost parameters: " + this.m_Params);
            }
            Map<String, DMatrix> watches = new HashMap<>();
            if (getVerbosity() != Verbosity.SILENT) {
                watches.put("train", train);
            }
            this.m_Booster = ml.dmlc.xgboost4j.java.XGBoost.train(train, this.m_Params, this.m_NumberOfRounds, watches, (IObjective) null, (IEvaluation) null);
        } finally {
            // free the native matrix now instead of waiting for finalization
            train.dispose();
        }
    }

    /**
     * Predicts the class value of a single instance.
     *
     * @param instance the instance to classify
     * @return the numeric prediction, or a valid class index for a nominal class; the missing
     *     value if there is no model, no instance, or nothing to encode
     * @throws Exception if prediction fails
     */
    @Override // weka.classifiers.simple.AbstractSimpleClassifier
    public double classifyInstance(Instance instance) throws Exception {
        if (this.m_Booster == null || instance == null) {
            return Utils.missingValue();
        }
        DMatrix matrix = instancesToDMatrix(new Instance[]{instance});
        if (matrix == null) {
            return Utils.missingValue();
        }
        float[][] predictions;
        try {
            predictions = this.m_Booster.predict(matrix);
        } finally {
            matrix.dispose();
        }
        float[] prediction = predictions[0];
        if (this.m_Header.classAttribute().isNumeric()) {
            return prediction[0];
        }
        return toClassIndex(prediction);
    }

    /**
     * Maps the raw booster output for one instance onto an index of the nominal class.
     *
     * @param prediction the booster output row for the instance
     * @return a class index in [0, numClasses - 1]
     */
    private double toClassIndex(float[] prediction) {
        int index;
        if (prediction.length > 1) {
            // one score per class (e.g. multi:softprob): pick the highest-scoring class
            index = 0;
            for (int i = 1; i < prediction.length; i++) {
                if (prediction[i] > prediction[index]) {
                    index = i;
                }
            }
        } else if (this.m_Params != null && "binary:logitraw".equals(String.valueOf(this.m_Params.get("objective")))) {
            // raw margin: the decision boundary is 0, not 0.5
            index = prediction[0] >= 0.0f ? 1 : 0;
        } else {
            index = Math.round(prediction[0]);
        }
        // regression-style objectives can overshoot; keep the index within the class range
        return Math.max(0, Math.min(this.m_Header.numClasses() - 1, index));
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("XGBoost\n=======\n\n");
        sb.append("Parameters: ");
        if (this.m_Params != null) {
            sb.append(this.m_Params.toString());
        } else {
            sb.append("No model built yet");
        }
        return sb.toString();
    }

    public static void main(String[] strArr) {
        runClassifier(new XGBoost(), strArr);
    }
}
