Package moa.classifiers.meta
Class StreamingGradientBoostedTrees
- java.lang.Object
  - moa.AbstractMOAObject
    - moa.options.AbstractOptionHandler
      - moa.classifiers.AbstractClassifier
        - moa.classifiers.meta.StreamingGradientBoostedTrees
All Implemented Interfaces:
Configurable, Serializable, CapabilitiesHandler, Classifier, MultiClassClassifier, Regressor, AWTRenderable, Learner<Example<Instance>>, MOAObject, OptionHandler
public class StreamingGradientBoostedTrees extends AbstractClassifier implements MultiClassClassifier, Regressor
Gradient boosted trees for evolving data streams.
Streaming Gradient Boosted Trees (SGBT) are trained using the weighted squared loss elicited in XGBoost. SGBT exploits trees with a replacement strategy to detect and recover from drifts, which lets the ensemble adapt without sacrificing predictive performance.
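The description says SGBT is trained with the weighted squared loss elicited in XGBoost. The standalone Java sketch below is a hedged illustration, not MOA code, and its class and method names are hypothetical. It shows the standard second-order derivation for binary log loss:

- With p = sigmoid(F), the gradient is g = p - y and the hessian is h = p(1 - p).
- Fitting the next stage by squared loss on the target -g/h, with instance weight h, minimises the second-order approximation of the loss.

It is an assumption that MOA uses exactly this link function and loss in every configuration.

```java
/**
 * Hypothetical sketch: eliciting a weighted squared-loss regression target from
 * binary log loss via a second-order (XGBoost-style) approximation.
 * Assumes labels y in {0, 1} and a sigmoid link on the raw score F.
 */
final class SquaredLossElicitationSketch {

    /** Regression target and instance weight handed to the base regressor. */
    static final class WeightedTarget {
        final double target;
        final double weight;

        WeightedTarget(double target, double weight) {
            this.target = target;
            this.weight = weight;
        }
    }

    /**
     * Log loss: g = p - y and h = p(1 - p), with p = sigmoid(F).
     * Since g*f + 0.5*h*f^2 equals 0.5*h*(f - (-g/h))^2 up to a constant,
     * the step f is fit by squared loss on target -g/h with weight h.
     */
    static WeightedTarget elicit(double rawScore, int label) {
        double p = 1.0 / (1.0 + Math.exp(-rawScore));
        double g = p - label;
        double h = Math.max(p * (1.0 - p), 1e-16); // floor avoids division by zero when p saturates
        return new WeightedTarget(-g / h, h);
    }

    public static void main(String[] args) {
        WeightedTarget wt = elicit(0.0, 1);
        System.out.println("target = " + wt.target + ", weight = " + wt.weight);
        // prints "target = 2.0, weight = 0.25"
    }
}
```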
See details in:
Nuwan Gunasekara, Bernhard Pfahringer, Heitor Murilo Gomes, Albert Bifet. Gradient Boosted Trees for Evolving Data Streams. Machine Learning, Springer, 2024. DOI.
Parameters:
- -l : Classifier to train on instances.
- -s : The number of boosting iterations.
- -m : Percentage (%) of attributes used in each boosting iteration.
- -L : Learning rate.
- -H : Disable one-hot encoding for regressors that support nominal attributes.
- -M : Multiple training iterations, given by ceil(Hessian * M).
- -S : Randomly skip 1/S of the instances during training (S=1: no skipping, all instances are used for training).
- -K : Use squared loss for classification.
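The -M and -S options change how often each instance is used for training. The Java sketch below is one reading of the option text, with hypothetical names:

- With S > 1, an instance is skipped with probability 1/S; S = 1 never skips.
- With M > 0, the instance is used for ceil(h * M) training passes, where h is its hessian.

Two parts are assumptions, not documented behaviour: the fallback to a single pass when M <= 0, and the exact random test.

```java
import java.util.Random;

/**
 * Hypothetical sketch of the per-instance schedule suggested by the -S and -M options.
 * Not taken from the MOA source; semantics are inferred from the option descriptions.
 */
final class TrainingScheduleSketch {

    /** Returns true if the instance should be skipped; S = 1 never skips. */
    static boolean skipInstance(int s, Random random) {
        // For S > 1, nextInt(s) == 0 happens with probability 1/S.
        return s > 1 && random.nextInt(s) == 0;
    }

    /** Training passes for one instance: ceil(hessian * M), or one pass when M <= 0 (assumed). */
    static int trainingIterations(double hessian, int m) {
        if (m <= 0) {
            return 1;
        }
        return (int) Math.ceil(hessian * m);
    }

    public static void main(String[] args) {
        Random random = new Random(1);
        int skipped = 0;
        int n = 10_000;
        for (int i = 0; i < n; i++) {
            if (skipInstance(4, random)) {
                skipped++;
            }
        }
        System.out.println("S = 4 skipped " + skipped + " of " + n + " instances");
        System.out.println("h = 0.25, M = 10 -> " + trainingIterations(0.25, 10) + " passes");
        // the second line prints "h = 0.25, M = 10 -> 3 passes"
    }
}
```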
- Version: $Revision: 1 $
- Author: Nuwan Gunasekara (ng98 at students dot waikato dot ac dot nz)
- See Also: Serialized Form
Nested Class Summary
- static class StreamingGradientBoostedTrees.SGBT
Field Summary
- ClassOption baseLearnerOption
- FlagOption disableOneHotEncoding
- protected double[] lastPrediction
- FloatOption learningRateOption
- IntOption multipleIterationByCeilingOfHessianTimesM
- protected int numberClasses
- IntOption numberOfboostingIterations
- IntOption percentageOfAttributesForEachBoostingIteration
- IntOption randomlySkip1SthOfInstancesAtTraining
- IntOption randomSeedOption
- protected boolean reset
- protected StreamingGradientBoostedTrees.SGBT[] SGBTCommittee
- FlagOption useSquaredLossForClassification
Fields inherited from class moa.classifiers.AbstractClassifier
classifierRandom, modelContext, randomSeed, trainingWeightSeenByModel
Fields inherited from class moa.options.AbstractOptionHandler
config
Constructor Summary
- StreamingGradientBoostedTrees()
Method Summary
- boolean correctlyClassifies(Instance inst): Gets whether this classifier correctly classifies an instance.
- protected void createSGBTs(int numSGBTs)
- Capabilities getCapabilities(): Gets the capabilities of the object.
- void getModelDescription(StringBuilder out, int indent): Returns a string representation of the model.
- protected Measurement[] getModelMeasurementsImpl(): Gets the current measurements of this classifier. The reason for ...Impl methods: ease programmer burden by not requiring them to remember calls to super in overridden methods.
- static double[] getScoresWhenNullTree(int outputSize)
- static Instance getSubInstance(Instance instance, double weight, ArrayList<Integer> subSpaceFeaturesIndexes, boolean setNumericClassAttribute, double numericClassValue, boolean useOneHotEncoding)
- double[] getVotesForInstance(Instance inst): Predicts the class memberships for a given instance.
- boolean isRandomizable(): Gets whether this learner needs a random seed.
- int measureByteSize(): Gets the memory size of this object.
- static Instance newBinaryClassInstance(Instance instance)
- void resetLearningImpl(): Resets this classifier.
- void trainOnInstanceImpl(Instance inst): Trains this classifier incrementally using the given instance. The reason for ...Impl methods: ease programmer burden by not requiring them to remember calls to super in overridden methods.
Methods inherited from class moa.classifiers.AbstractClassifier
contextIsCompatible, copy, defineImmutableCapabilities, getAttributeNameString, getAWTRenderer, getClassLabelString, getClassNameString, getDescription, getModel, getModelContext, getModelMeasurements, getNominalValueString, getPredictionForInstance, getPredictionForInstance, getPurposeString, getSubClassifiers, getSublearners, getVotesForInstance, modelAttIndexToInstanceAttIndex, modelAttIndexToInstanceAttIndex, prepareForUseImpl, resetLearning, setModelContext, setRandomSeed, trainingHasStarted, trainingWeightSeenByModel, trainOnInstance, trainOnInstance
Methods inherited from class moa.options.AbstractOptionHandler
getCLICreationString, getOptions, getPreparedClassOption, prepareClassOptions, prepareForUse, prepareForUse
Methods inherited from class moa.AbstractMOAObject
copy, measureByteSize, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface moa.options.OptionHandler
getCLICreationString, getOptions, prepareForUse, prepareForUse
Field Detail
- public ClassOption baseLearnerOption
- public IntOption numberOfboostingIterations
- public IntOption percentageOfAttributesForEachBoostingIteration
- public FloatOption learningRateOption
- public FlagOption disableOneHotEncoding
- public IntOption multipleIterationByCeilingOfHessianTimesM
- public IntOption randomlySkip1SthOfInstancesAtTraining
- public FlagOption useSquaredLossForClassification
- public IntOption randomSeedOption
- protected StreamingGradientBoostedTrees.SGBT[] SGBTCommittee
- protected boolean reset
- protected int numberClasses
- protected double[] lastPrediction
Method Detail
resetLearningImpl
public void resetLearningImpl()
Description copied from class: AbstractClassifier
Resets this classifier. It must be similar to starting a new classifier from scratch.
The ...Impl methods exist to ease the programmer's burden: overriding methods do not have to remember to call super. Note that this will produce compiler errors if not overridden.
- Specified by: resetLearningImpl in class AbstractClassifier
getModelMeasurementsImpl
protected Measurement[] getModelMeasurementsImpl()
Description copied from class: AbstractClassifier
Gets the current measurements of this classifier.
The ...Impl methods exist to ease the programmer's burden: overriding methods do not have to remember to call super. Note that this will produce compiler errors if not overridden.
- Specified by: getModelMeasurementsImpl in class AbstractClassifier
- Returns: an array of measurements to be used in evaluation tasks
getModelDescription
public void getModelDescription(StringBuilder out, int indent)
Description copied from class: AbstractClassifier
Returns a string representation of the model.
- Specified by: getModelDescription in class AbstractClassifier
- Parameters:
  - out - the StringBuilder to add the description to
  - indent - the number of characters to indent
isRandomizable
public boolean isRandomizable()
Description copied from interface: Learner
Gets whether this learner needs a random seed. Examples of methods that need a random seed are bagging and boosting.
- Specified by: isRandomizable in interface Learner<Example<Instance>>
- Returns: true if the learner needs a random seed.
getCapabilities
public Capabilities getCapabilities()
Description copied from interface: CapabilitiesHandler
Gets the capabilities of the object. Should be overridden if the object's capabilities can change.
- Specified by: getCapabilities in interface CapabilitiesHandler
- Returns: the capabilities of the object.
correctlyClassifies
public boolean correctlyClassifies(Instance inst)
Description copied from interface: Classifier
Gets whether this classifier correctly classifies an instance. Uses getVotesForInstance to obtain the prediction and the instance to obtain its true class.
- Specified by: correctlyClassifies in interface Classifier
- Overrides: correctlyClassifies in class AbstractClassifier
- Parameters:
  - inst - the instance to be classified
- Returns: true if the instance is correctly classified
measureByteSize
public int measureByteSize()
Description copied from interface: MOAObject
Gets the memory size of this object.
- Specified by: measureByteSize in interface MOAObject
- Overrides: measureByteSize in class AbstractMOAObject
- Returns: the memory size of this object
trainOnInstanceImpl
public void trainOnInstanceImpl(Instance inst)
Description copied from class: AbstractClassifier
Trains this classifier incrementally using the given instance.
The ...Impl methods exist to ease the programmer's burden: overriding methods do not have to remember to call super. Note that this will produce compiler errors if not overridden.
- Specified by: trainOnInstanceImpl in class AbstractClassifier
- Parameters:
  - inst - the instance to be used for training
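The standalone Java below is a hedged sketch of the stagewise structure behind trainOnInstanceImpl; it is not the MOA implementation. On each instance it trains a fixed number of stages in sequence:

- Each stage fits the target -g/h with weight h.
- g and h are computed from the raw score of the stages before it.
- Predictions are combined as F(x) = sum of learningRate * f_i(x), mirroring the -s and -L options.

The IncrementalRegressor interface and the WeightedMean stand-in are hypothetical. SGBT itself uses streaming trees with a drift-triggered replacement strategy, which this sketch omits.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch of stagewise boosting on a stream for binary log loss.
 * A weighted running mean stands in for the streaming regression tree.
 */
final class StagewiseBoostingSketch {

    /** Hypothetical minimal incremental regressor. */
    interface IncrementalRegressor {
        void train(double[] x, double target, double weight);

        double predict(double[] x);
    }

    /** Stand-in base learner: a weighted running mean of the targets (ignores x). */
    static final class WeightedMean implements IncrementalRegressor {
        private double sumWeight;
        private double sumWeightedTarget;

        @Override
        public void train(double[] x, double target, double weight) {
            sumWeight += weight;
            sumWeightedTarget += weight * target;
        }

        @Override
        public double predict(double[] x) {
            return sumWeight == 0.0 ? 0.0 : sumWeightedTarget / sumWeight;
        }
    }

    private final List<IncrementalRegressor> stages = new ArrayList<>();
    private final double learningRate;

    StagewiseBoostingSketch(int numStages, double learningRate) {
        this.learningRate = learningRate;
        for (int i = 0; i < numStages; i++) {
            stages.add(new WeightedMean());
        }
    }

    /** Raw score F(x): the learning-rate-scaled sum of all stage predictions. */
    double rawScore(double[] x) {
        double f = 0.0;
        for (IncrementalRegressor stage : stages) {
            f += learningRate * stage.predict(x);
        }
        return f;
    }

    /** Trains each stage on the gradient and hessian of the stages before it. */
    void train(double[] x, int label) {
        double f = 0.0; // raw score accumulated from the earlier stages
        for (IncrementalRegressor stage : stages) {
            double p = 1.0 / (1.0 + Math.exp(-f));
            double g = p - label;                      // gradient of log loss w.r.t. f
            double h = Math.max(p * (1.0 - p), 1e-16); // hessian, floored to avoid dividing by zero
            double stagePrediction = stage.predict(x); // taken before this stage learns x
            stage.train(x, -g / h, h);
            f += learningRate * stagePrediction;
        }
    }

    public static void main(String[] args) {
        StagewiseBoostingSketch model = new StagewiseBoostingSketch(20, 0.5);
        double[] x = {0.0};
        for (int i = 0; i < 20_000; i++) {
            model.train(x, i % 4 == 3 ? 0 : 1); // 75% positive labels
        }
        double p = 1.0 / (1.0 + Math.exp(-model.rawScore(x)));
        System.out.println("estimated P(y = 1) = " + p);
    }
}
```

Constant-only stages can learn nothing but the base rate, so on this 75%-positive stream the estimate should approach 0.75. A real base learner would make each stage depend on x.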
getVotesForInstance
public double[] getVotesForInstance(Instance inst)
Description copied from interface: Classifier
Predicts the class memberships for a given instance. If an instance is unclassified, the returned array elements must be all zero.
- Specified by: getVotesForInstance in interface Classifier
- Specified by: getVotesForInstance in class AbstractClassifier
- Parameters:
  - inst - the instance to be classified
- Returns: an array containing the estimated membership probabilities of the test instance in each class
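For multiclass problems, the SGBTCommittee and numberClasses fields suggest one raw boosted score per class. The Java sketch below assumes, without confirmation from this documentation, that those scores become membership votes through a numerically stable softmax. The class name is hypothetical. An untrained model should instead return all-zero votes, as the contract above requires.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: normalising one raw boosted score per class into votes
 * with a softmax. Subtracting the maximum score keeps Math.exp from overflowing.
 */
final class VotesSketch {

    static double[] softmax(double[] rawScores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : rawScores) {
            max = Math.max(max, s);
        }
        double[] votes = new double[rawScores.length];
        double sum = 0.0;
        for (int k = 0; k < rawScores.length; k++) {
            votes[k] = Math.exp(rawScores[k] - max);
            sum += votes[k];
        }
        for (int k = 0; k < votes.length; k++) {
            votes[k] /= sum;
        }
        return votes;
    }

    public static void main(String[] args) {
        double[] votes = softmax(new double[] {2.0, 0.0, -1.0});
        System.out.println(Arrays.toString(votes));
    }
}
```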
getSubInstance
public static Instance getSubInstance(Instance instance, double weight, ArrayList<Integer> subSpaceFeaturesIndexes, boolean setNumericClassAttribute, double numericClassValue, boolean useOneHotEncoding)
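getSubInstance takes subSpaceFeaturesIndexes and useOneHotEncoding, which relate to the -m and -H options. The Java sketch below is a hedged illustration, not the MOA code, of building such a random-subspace view of one instance. It does two things:

- Draw a percentage of the non-class attribute indexes.
- Copy those values, expanding nominal attributes into one-hot columns when requested.

The array-based instance representation is hypothetical, and so are the rounding rule (round, at least one attribute) and all names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical sketch of a random-subspace sub-instance. An instance is a double[]
 * of values plus an int[] holding the number of nominal values per attribute
 * (0 for numeric); nominal values are stored as their index.
 */
final class SubInstanceSketch {

    /** Picks percentage% of the non-class attribute indexes (at least one, when any exist). */
    static ArrayList<Integer> sampleSubspace(int numAttributes, int classIndex, int percentage, Random random) {
        List<Integer> candidates = new ArrayList<>();
        for (int i = 0; i < numAttributes; i++) {
            if (i != classIndex) {
                candidates.add(i);
            }
        }
        int k = Math.min(candidates.size(),
                Math.max(1, (int) Math.round(candidates.size() * percentage / 100.0)));
        Collections.shuffle(candidates, random);
        ArrayList<Integer> chosen = new ArrayList<>(candidates.subList(0, k));
        Collections.sort(chosen);
        return chosen;
    }

    /** Copies the chosen attributes, expanding nominal ones to one-hot columns if requested. */
    static double[] project(double[] values, int[] numNominalValues, List<Integer> chosen, boolean useOneHotEncoding) {
        List<Double> out = new ArrayList<>();
        for (int index : chosen) {
            int n = numNominalValues[index];
            if (useOneHotEncoding && n > 0) {
                for (int v = 0; v < n; v++) {
                    out.add(v == (int) values[index] ? 1.0 : 0.0);
                }
            } else {
                out.add(values[index]);
            }
        }
        double[] result = new double[out.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = out.get(i);
        }
        return result;
    }

    public static void main(String[] args) {
        // Attribute 1 is nominal with 3 values; attribute 3 is the class.
        double[] values = {0.5, 2.0, 7.0, 1.0};
        int[] numNominalValues = {0, 3, 0, 2};
        List<Integer> chosen = sampleSubspace(values.length, 3, 67, new Random(42));
        System.out.println("chosen = " + chosen);
        System.out.println(Arrays.toString(project(values, numNominalValues, chosen, true)));
    }
}
```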
getScoresWhenNullTree
public static double[] getScoresWhenNullTree(int outputSize)
createSGBTs
protected void createSGBTs(int numSGBTs)
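Several names suggest, but do not confirm, a one-vs-rest layout in which each committee is a binary learner for one class:

- createSGBTs(int numSGBTs)
- newBinaryClassInstance
- the SGBTCommittee array

The Java sketch below is an inference, not documented behaviour. It shows only the label mapping such a layout would use: committee k receives target 1 when the true class is k, and 0 otherwise. The class and method names are hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of one-vs-rest targets: committee k learns "is the class k?".
 */
final class OneVsRestSketch {

    /** Binary target for every committee, given the true class index. */
    static double[] binaryTargets(int trueClass, int numClasses) {
        if (trueClass < 0 || trueClass >= numClasses) {
            throw new IllegalArgumentException("class index out of range: " + trueClass);
        }
        double[] targets = new double[numClasses]; // all zeros by default
        targets[trueClass] = 1.0;
        return targets;
    }

    public static void main(String[] args) {
        // One labelled instance of class 2 in a 4-class problem feeds 4 binary committees.
        System.out.println(Arrays.toString(binaryTargets(2, 4)));
        // prints "[0.0, 0.0, 1.0, 0.0]"
    }
}
```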