Package moa.classifiers.functions
Class SGDMultiClass
java.lang.Object
  moa.AbstractMOAObject
    moa.options.AbstractOptionHandler
      moa.classifiers.AbstractClassifier
        moa.classifiers.functions.SGDMultiClass
- All Implemented Interfaces:
Configurable, Serializable, CapabilitiesHandler, Classifier, MultiClassClassifier, Regressor, AWTRenderable, Learner<Example<Instance>>, MOAObject, OptionHandler
public class SGDMultiClass extends AbstractClassifier implements MultiClassClassifier, Regressor
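Implements stochastic gradient descent for learning various linear models (binary class SVM, binary class logistic regression and linear regression). The per-class m_weights and m_bias arrays and the trainOnInstanceImpl(Instance instance, int classLabel) overload suggest that, for multi-class problems, one such binary model is trained per class value (one-vs-rest).
- See Also:
- Serialized Form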
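The three models named above differ mainly in the loss being minimized. As a minimal sketch, assuming the usual textbook definitions, each loss can be written as a function of a scalar z: the margin y(w.x + b) with y in {-1, +1} for the two classifiers, and the residual y - (w.x + b) for linear regression. The class LossSketch and its integer codes are hypothetical; the actual values of HINGE, LOGLOSS and SQUAREDLOSS are listed on MOA's Constant Field Values page.

```java
// Hypothetical sketch of the three losses, as functions of a scalar z.
// Assumption: z is the margin y * (w.x + b) for HINGE and LOGLOSS, and the
// residual y - (w.x + b) for SQUAREDLOSS. The codes 0, 1, 2 are made up and
// need not match MOA's HINGE, LOGLOSS and SQUAREDLOSS constants.
final class LossSketch {
    static final int HINGE = 0;       // linear SVM
    static final int LOGLOSS = 1;     // logistic regression
    static final int SQUAREDLOSS = 2; // linear regression

    private LossSketch() {}

    static double loss(int type, double z) {
        switch (type) {
            case HINGE:
                // max(0, 1 - z): zero once the margin reaches 1
                return Math.max(0.0, 1.0 - z);
            case LOGLOSS:
                // log(1 + e^{-z}), rearranged for z <= 0 to avoid overflow
                return z > 0 ? Math.log1p(Math.exp(-z)) : -z + Math.log1p(Math.exp(z));
            case SQUAREDLOSS:
                // z^2 / 2 on the residual
                return 0.5 * z * z;
            default:
                throw new IllegalArgumentException("unknown loss: " + type);
        }
    }
}
```

The hinge loss is flat once the margin reaches 1, which is why an SVM-style learner only moves its weights on examples that are misclassified or too close to the boundary.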
Field Summary
Modifier and Type         Field                       Description
protected static int      HINGE
FloatOption               lambdaRegularizationOption
FloatOption               learningRateOption
protected static int      LOGLOSS
MultiChoiceOption         lossFunctionOption
protected double[]        m_bias
protected double          m_lambda                    The regularization parameter.
protected double          m_learningRate              The learning rate.
protected int             m_loss                      The current loss function to minimize.
protected double          m_numInstances              The number of training instances.
protected double          m_t                         Holds the current iteration number.
protected DoubleVector[]  m_weights                   Stores the weights (+ bias in the last element).
protected static int      SQUAREDLOSS
-
-
Fields inherited from class moa.classifiers.AbstractClassifier
classifierRandom, modelContext, randomSeed, randomSeedOption, trainingWeightSeenByModel
-
Fields inherited from class moa.options.AbstractOptionHandler
config
-
Constructor Summary
Constructor       Description
SGDMultiClass()
-
Method Summary
Modifier and Type        Method                                                 Description
protected double         dloss(double z)
protected static double  dotProd(Instance inst1, DoubleVector weights, int classIndex)
double                   getLambda()                                            Get the current value of lambda.
double                   getLearningRate()                                      Get the learning rate.
int                      getLossFunction()                                      Get the current loss function.
void                     getModelDescription(StringBuilder result, int indent)  Returns a string representation of the model.
protected Measurement[]  getModelMeasurementsImpl()                             Gets the current measurements of this classifier.
String                   getPurposeString()                                     Gets the purpose of this object.
double[]                 getVotesForInstance(Instance inst)                     Calculates the class membership probabilities for the given test instance.
boolean                  isRandomizable()                                       Gets whether this learner needs a random seed.
void                     reset()                                                Reset the classifier.
void                     resetLearningImpl()                                    Resets this classifier.
void                     setLambda(double lambda)                               Set the value of lambda to use.
void                     setLearningRate(double lr)                             Set the learning rate.
void                     setLossFunction(int function)                          Set the loss function to use.
String                   toString()                                             Prints out the classifier.
void                     trainOnInstanceImpl(Instance instance)                 Trains the classifier with the given instance.
void                     trainOnInstanceImpl(Instance instance, int classLabel)
-
Methods inherited from class moa.classifiers.AbstractClassifier
contextIsCompatible, copy, correctlyClassifies, defineImmutableCapabilities, getAttributeNameString, getAWTRenderer, getClassLabelString, getClassNameString, getDescription, getModel, getModelContext, getModelMeasurements, getNominalValueString, getPredictionForInstance, getPredictionForInstance, getSubClassifiers, getSublearners, getVotesForInstance, modelAttIndexToInstanceAttIndex, modelAttIndexToInstanceAttIndex, prepareForUseImpl, resetLearning, setModelContext, setRandomSeed, trainingHasStarted, trainingWeightSeenByModel, trainOnInstance, trainOnInstance
-
Methods inherited from class moa.options.AbstractOptionHandler
getCLICreationString, getOptions, getPreparedClassOption, prepareClassOptions, prepareForUse, prepareForUse
-
Methods inherited from class moa.AbstractMOAObject
copy, measureByteSize, measureByteSize
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface moa.capabilities.CapabilitiesHandler
getCapabilities
-
Methods inherited from interface moa.MOAObject
measureByteSize
-
Methods inherited from interface moa.options.OptionHandler
getCLICreationString, getOptions, prepareForUse, prepareForUse
-
Field Detail
-
m_lambda
protected double m_lambda
The regularization parameter
-
lambdaRegularizationOption
public FloatOption lambdaRegularizationOption
-
m_learningRate
protected double m_learningRate
The learning rate
-
learningRateOption
public FloatOption learningRateOption
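The roles of m_lambda (set through lambdaRegularizationOption) and m_learningRate (set through learningRateOption) can be illustrated with the standard L2-regularized objective for one per-class linear model and its SGD step. This is a sketch assuming the textbook formulation, with lambda the regularization parameter, eta the learning rate, L the chosen loss and y_{ik} in {-1, +1} the binary target for class k; MOA's implementation may scale the regularization term differently (this page does not say). For regression the squared loss is applied to the residual instead of the margin.

```latex
% Per-class L2-regularized objective (sketch, textbook form)
\min_{\mathbf{w}_k,\, b_k}\; \frac{\lambda}{2}\,\lVert \mathbf{w}_k \rVert^2
  + \frac{1}{n}\sum_{i=1}^{n} L\bigl(z_{ik}\bigr),
\qquad z_{ik} = y_{ik}\,\bigl(\mathbf{w}_k^\top \mathbf{x}_i + b_k\bigr)

% One SGD step on example i with learning rate \eta (bias not regularized)
\mathbf{w}_k \leftarrow (1 - \eta\lambda)\,\mathbf{w}_k - \eta\, L'(z_{ik})\, y_{ik}\, \mathbf{x}_i,
\qquad b_k \leftarrow b_k - \eta\, L'(z_{ik})\, y_{ik}
```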
-
m_weights
protected DoubleVector[] m_weights
Stores the weights (+ bias in the last element)
-
m_bias
protected double[] m_bias
-
m_t
protected double m_t
Holds the current iteration number
-
m_numInstances
protected double m_numInstances
The number of training instances
-
HINGE
protected static final int HINGE
- See Also:
- Constant Field Values
-
LOGLOSS
protected static final int LOGLOSS
- See Also:
- Constant Field Values
-
SQUAREDLOSS
protected static final int SQUAREDLOSS
- See Also:
- Constant Field Values
-
m_loss
protected int m_loss
The current loss function to minimize
-
lossFunctionOption
public MultiChoiceOption lossFunctionOption
-
Method Detail
-
getPurposeString
public String getPurposeString()
Description copied from class: AbstractOptionHandler
Gets the purpose of this object.
- Specified by:
getPurposeString in interface OptionHandler
- Overrides:
getPurposeString in class AbstractClassifier
- Returns:
the string with the purpose of this object
-
setLambda
public void setLambda(double lambda)
Set the value of lambda (the regularization parameter) to use.
- Parameters:
lambda - the value of lambda to use
-
getLambda
public double getLambda()
Get the current value of lambda.
- Returns:
the current value of lambda
-
setLossFunction
public void setLossFunction(int function)
Set the loss function to use.
- Parameters:
function - the loss function to use (one of the constants HINGE, LOGLOSS or SQUAREDLOSS)
-
getLossFunction
public int getLossFunction()
Get the current loss function.
- Returns:
the current loss function
-
setLearningRate
public void setLearningRate(double lr)
Set the learning rate.
- Parameters:
lr - the learning rate to use
-
getLearningRate
public double getLearningRate()
Get the learning rate.
- Returns:
the learning rate
-
reset
public void reset()
Reset the classifier.
-
dloss
protected double dloss(double z)
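This page gives no description for dloss(double z). Going by its name and signature, it most likely returns the loss derivative that scales each gradient step. The hypothetical sketch below (DlossSketch, with made-up loss codes) shows the standard values of that factor g(z), defined so that the unregularized update reads w <- w + eta * g(z) * y * x. For the margin-based losses g(z) is -dL/dz; for the squared loss on the residual z = y - f(x), with y then taken as 1 in the update, g(z) is z itself.

```java
// Hypothetical sketch of a dloss-style helper returning the step factor g(z)
// used in w <- w + eta * g(z) * y * x (regularization omitted).
// Assumption: z is the margin for HINGE/LOGLOSS and the residual for SQUAREDLOSS.
// The codes 0, 1, 2 are made up and need not match MOA's constants.
final class DlossSketch {
    static final int HINGE = 0, LOGLOSS = 1, SQUAREDLOSS = 2;

    private DlossSketch() {}

    static double dloss(int type, double z) {
        switch (type) {
            case HINGE:
                // L = max(0, 1 - z): slope -1 inside the margin, 0 outside
                return z < 1.0 ? 1.0 : 0.0;
            case LOGLOSS: {
                // L = log(1 + e^{-z}): -dL/dz = 1 / (1 + e^{z}), evaluated without overflow
                if (z < 0) {
                    return 1.0 / (Math.exp(z) + 1.0);
                }
                double t = Math.exp(-z);
                return t / (t + 1.0);
            }
            case SQUAREDLOSS:
                // L = z^2 / 2 on the residual: the step is proportional to the residual
                return z;
            default:
                throw new IllegalArgumentException("unknown loss: " + type);
        }
    }
}
```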
-
dotProd
protected static double dotProd(Instance inst1, DoubleVector weights, int classIndex)
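dotProd(Instance inst1, DoubleVector weights, int classIndex) is also undocumented. Its parameters suggest it computes the inner product of an instance's attribute values with one class's weight vector while skipping the class attribute, so the label never contributes to its own prediction. The sketch below is hypothetical: it replaces MOA's Instance and DoubleVector with plain double[] arrays, and treats weights missing past the end of the array as zero.

```java
// Hypothetical sketch of a dotProd-style helper on plain arrays.
// Assumption: the instance is a dense double[] of attribute values, and the
// weight array may be shorter than the instance (missing weights count as 0).
final class DotProdSketch {
    private DotProdSketch() {}

    static double dotProd(double[] attributeValues, double[] weights, int classIndex) {
        double sum = 0.0;
        int n = Math.min(attributeValues.length, weights.length);
        for (int i = 0; i < n; i++) {
            if (i == classIndex) {
                continue; // the class attribute is the target, not a feature
            }
            sum += attributeValues[i] * weights[i];
        }
        return sum;
    }
}
```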
-
resetLearningImpl
public void resetLearningImpl()
Description copied from class: AbstractClassifier
Resets this classifier. The result must be equivalent to starting a new classifier from scratch.
The ...Impl methods exist to ease the programmer's burden: overriding methods do not need to remember to call super. Because this method is abstract, a subclass that does not override it will not compile.
- Specified by:
resetLearningImpl in class AbstractClassifier
-
trainOnInstanceImpl
public void trainOnInstanceImpl(Instance instance)
Trains the classifier with the given instance.
- Specified by:
trainOnInstanceImpl in class AbstractClassifier
- Parameters:
instance - the new training instance to include in the model
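The per-class m_weights (DoubleVector[]) and m_bias (double[]) fields suggest one-vs-rest bookkeeping: a nominal class with K values gets K binary models, a numeric class needs a single regression model, and each incoming example is relabelled +1 or -1 for every binary model before that model is updated. The sketch below shows only that bookkeeping; OneVsRestSketch and its methods are hypothetical names, not part of MOA's API.

```java
// Hypothetical sketch of one-vs-rest bookkeeping for a multi-class linear learner.
final class OneVsRestSketch {
    private OneVsRestSketch() {}

    /** Number of linear models to allocate: one per class value, or one for regression. */
    static int numModels(boolean classIsNominal, int numClasses) {
        return classIsNominal ? numClasses : 1;
    }

    /** Binary target for model k: +1 if the example belongs to class k, otherwise -1. */
    static double binaryTarget(int classValue, int k) {
        return classValue == k ? 1.0 : -1.0;
    }

    /** Targets for all K binary models, in model order. */
    static double[] binaryTargets(int classValue, int numClasses) {
        double[] targets = new double[numClasses];
        for (int k = 0; k < numClasses; k++) {
            targets[k] = binaryTarget(classValue, k);
        }
        return targets;
    }
}
```

Training on one example then amounts to one binary update per model, each with its own target.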
-
trainOnInstanceImpl
public void trainOnInstanceImpl(Instance instance, int classLabel)
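This overload is undocumented; its classLabel parameter suggests it performs the update for a single class's binary model. The hypothetical sketch below shows one such SGD step for the hinge loss, combining L2 shrinkage (lambda, as in m_lambda) with a gradient step (eta, as in m_learningRate). It assumes dense double[] weights and a plain (1 - eta * lambda) shrinkage factor; MOA may scale the regularization term differently.

```java
// Hypothetical sketch of one per-class SGD step with hinge loss and L2 shrinkage.
// Assumptions: dense weights, x.length >= w.length, y in {-1, +1},
// margin z = y * (w.x + b), and shrinkage factor (1 - eta * lambda).
final class SgdStepSketch {
    private SgdStepSketch() {}

    /** Updates w in place and returns the new bias for one labelled example. */
    static double hingeStep(double[] w, double b, double[] x, double y,
                            double eta, double lambda) {
        double wx = 0.0;
        for (int i = 0; i < w.length; i++) {
            wx += w[i] * x[i];
        }
        double z = y * (wx + b); // margin computed with the pre-update weights

        // L2 regularization: shrink every weight toward zero (bias is not shrunk).
        double shrink = 1.0 - eta * lambda;
        for (int i = 0; i < w.length; i++) {
            w[i] *= shrink;
        }

        // max(0, 1 - z) has a non-zero gradient only inside the margin.
        if (z < 1.0) {
            double factor = eta * y;
            for (int i = 0; i < w.length; i++) {
                w[i] += factor * x[i];
            }
            b += factor;
        }
        return b;
    }
}
```

Skipping the gradient term when z >= 1 is what keeps hinge-loss updates cheap: correctly classified examples outside the margin only trigger the shrinkage.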
-
getVotesForInstance
public double[] getVotesForInstance(Instance inst)
Calculates the class membership probabilities for the given test instance.
- Specified by:
getVotesForInstance in interface Classifier
- Specified by:
getVotesForInstance in class AbstractClassifier
- Parameters:
inst - the instance to be classified
- Returns:
predicted class probability distribution
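This page does not say how the per-class linear scores become the returned vote array. The sketch below is one plausible reading of "class membership probabilities", not MOA's documented behaviour: each one-vs-rest model k scores the instance as w_k.x + b_k, each score is passed through a logistic function, and the results are normalised to sum to 1. VotesSketch is a hypothetical name; MOA's votes may be unnormalised or computed differently.

```java
// Hypothetical sketch: turn K one-vs-rest linear scores into normalised votes.
// Assumptions: dense weights with weights[c].length >= x.length, and logistic
// squashing followed by normalisation (one possible choice among several).
final class VotesSketch {
    private VotesSketch() {}

    static double[] votes(double[][] weights, double[] bias, double[] x) {
        int k = weights.length;
        double[] votes = new double[k];
        double total = 0.0;
        for (int c = 0; c < k; c++) {
            double score = bias[c];
            for (int i = 0; i < x.length; i++) {
                score += weights[c][i] * x[i];
            }
            votes[c] = 1.0 / (1.0 + Math.exp(-score)); // logistic squashing into (0, 1)
            total += votes[c];
        }
        if (total > 0.0) {
            for (int c = 0; c < k; c++) {
                votes[c] /= total; // normalise so the votes form a distribution
            }
        }
        return votes;
    }
}
```

The predicted class is the index of the largest vote; normalisation does not change that choice, it only makes the array readable as a distribution.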
-
getModelDescription
public void getModelDescription(StringBuilder result, int indent)
Description copied from class: AbstractClassifier
Returns a string representation of the model.
- Specified by:
getModelDescription in class AbstractClassifier
- Parameters:
result - the StringBuilder to which the description is added
indent - the number of characters to indent
-
toString
public String toString()
Prints out the classifier.
- Overrides:
toString in class AbstractMOAObject
- Returns:
a description of the classifier as a string
-
getModelMeasurementsImpl
protected Measurement[] getModelMeasurementsImpl()
Description copied from class: AbstractClassifier
Gets the current measurements of this classifier.
The ...Impl methods exist to ease the programmer's burden: overriding methods do not need to remember to call super. Because this method is abstract, a subclass that does not override it will not compile.
- Specified by:
getModelMeasurementsImpl in class AbstractClassifier
- Returns:
an array of measurements to be used in evaluation tasks
-
isRandomizable
public boolean isRandomizable()
Description copied from interface: Learner
Gets whether this learner needs a random seed. Examples of methods that need a random seed are bagging and boosting.
- Specified by:
isRandomizable in interface Learner<Example<Instance>>
- Returns:
true if the learner needs a random seed