java.lang.Object
  weka.classifiers.AbstractClassifier
    weka.classifiers.RandomizableClassifier
      weka.classifiers.functions.SGD
public class SGD
Implements stochastic gradient descent for learning various linear models (binary class SVM, binary class logistic regression and linear regression). Globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data.
For numeric class attributes, the squared loss function (2) must be used.
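The three loss functions named above can be written down compactly. The following is a minimal sketch in plain Java using the standard textbook definitions, not code taken from Weka: the class name `LossSketch` and its method names are hypothetical, class labels for the two classification losses are assumed to be encoded as -1/+1, and `z` is assumed to be the linear output of the model. Some formulations scale the squared loss by 1/2; that factor is omitted here.

```java
/** Illustrative only: textbook loss definitions, not Weka's implementation. */
final class LossSketch {

    /** Hinge loss (SVM): max(0, 1 - y*z), with y in {-1, +1} (assumed encoding). */
    static double hinge(double y, double z) {
        return Math.max(0.0, 1.0 - y * z);
    }

    /** Log loss (logistic regression): log(1 + exp(-y*z)), with y in {-1, +1} (assumed encoding). */
    static double logLoss(double y, double z) {
        return Math.log1p(Math.exp(-y * z));
    }

    /** Squared loss (regression): (y - z)^2 for a numeric target y. */
    static double squared(double y, double z) {
        double diff = y - z;
        return diff * diff;
    }

    public static void main(String[] args) {
        // Evaluate each loss on one sample prediction.
        System.out.println("hinge   = " + hinge(1.0, 0.5));
        System.out.println("logLoss = " + logLoss(1.0, 0.0));
        System.out.println("squared = " + squared(3.0, 1.0));
    }
}
```

The hinge loss is zero once the margin `y*z` reaches 1, whereas the log loss only approaches zero, which is one reason the two classifiers can produce different coefficients on the same data.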
-F Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression), 2 = squared loss (regression). (default = 0)
-L The learning rate. If normalization is turned off (as it is automatically for streaming data), then the default learning rate will need to be reduced (try 0.0001). (default = 0.01).
-R <double> The lambda regularization constant (default = 0.0001)
-E <integer> The number of epochs to perform (batch learning only, default = 500)
-N Don't normalize the data
-M Don't replace missing values
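As an illustration of how these flags fit together, the sketch below assembles an options array of the kind that `setOptions(String[] options)` accepts. It uses only the flags listed above. The helper class `SgdOptionsSketch` and its `build` method are hypothetical, and the array is built in plain Java so that the example does not depend on the Weka jar.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper (not part of Weka): assembles an options array for SGD. */
final class SgdOptionsSketch {

    /**
     * Builds the option list in the order -F, -L, -R, -E, then the optional
     * switches -N and -M.
     *
     * @param lossFunction 0 = hinge loss, 1 = log loss, 2 = squared loss (documented for -F)
     */
    static String[] build(int lossFunction, double learningRate, double lambda,
                          int epochs, boolean dontNormalize, boolean dontReplaceMissing) {
        if (lossFunction < 0 || lossFunction > 2) {
            throw new IllegalArgumentException("loss function must be 0, 1 or 2");
        }
        List<String> opts = new ArrayList<>();
        opts.add("-F");
        opts.add(Integer.toString(lossFunction));
        opts.add("-L");
        opts.add(Double.toString(learningRate));
        opts.add("-R");
        opts.add(Double.toString(lambda));
        opts.add("-E");
        opts.add(Integer.toString(epochs));
        if (dontNormalize) {
            opts.add("-N");
        }
        if (dontReplaceMissing) {
            opts.add("-M");
        }
        return opts.toArray(new String[0]);
    }

    public static void main(String[] args) {
        // Log loss with the documented default values for the other options.
        String[] options = build(1, 0.01, 0.0001, 500, false, false);
        System.out.println(String.join(" ", options));
    }
}
```

With the Weka jar on the classpath, the resulting array could be passed to `setOptions` on an `SGD` instance, which is declared to throw `Exception` if an option is not supported.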
| Field Summary | |
|---|---|
| static int | HINGE - the hinge loss function. |
| static int | LOGLOSS - the log loss function. |
| static int | SQUAREDLOSS - the squared loss function. |
| static Tag[] | TAGS_SELECTION - Loss functions to choose from |
| Constructor Summary | |
|---|---|
| SGD() | |
| Method Summary | |
|---|---|
| void | buildClassifier(Instances data) - Method for building the classifier. |
| double[] | distributionForInstance(Instance inst) - Computes the distribution for a given instance |
| String | dontNormalizeTipText() - Returns the tip text for this property |
| String | dontReplaceMissingTipText() - Returns the tip text for this property |
| String | epochsTipText() - Returns the tip text for this property |
| Capabilities | getCapabilities() - Returns default capabilities of the classifier. |
| boolean | getDontNormalize() - Get whether normalization has been turned off. |
| boolean | getDontReplaceMissing() - Get whether global replacement of missing values has been disabled. |
| int | getEpochs() - Get current number of epochs |
| double | getLambda() - Get the current value of lambda |
| double | getLearningRate() - Get the learning rate. |
| SelectedTag | getLossFunction() - Get the current loss function. |
| String[] | getOptions() - Gets the current settings of the classifier. |
| String | getRevision() - Returns the revision string. |
| String | globalInfo() - Returns a string describing classifier |
| String | lambdaTipText() - Returns the tip text for this property |
| String | learningRateTipText() - Returns the tip text for this property |
| Enumeration<Option> | listOptions() - Returns an enumeration describing the available options. |
| String | lossFunctionTipText() - Returns the tip text for this property |
| static void | main(String[] args) - Main method for testing this class. |
| void | reset() - Reset the classifier. |
| void | setDontNormalize(boolean m) - Turn normalization off/on. |
| void | setDontReplaceMissing(boolean m) - Turn global replacement of missing values off/on. |
| void | setEpochs(int e) - Set the number of epochs to use |
| void | setLambda(double lambda) - Set the value of lambda to use |
| void | setLearningRate(double lr) - Set the learning rate. |
| void | setLossFunction(SelectedTag function) - Set the loss function to use. |
| void | setOptions(String[] options) - Parses a given list of options. |
| String | toString() - Prints out the classifier. |
| void | updateClassifier(Instance instance) - Updates the classifier with the given instance. |
| Methods inherited from class weka.classifiers.RandomizableClassifier |
|---|
getSeed, seedTipText, setSeed |
| Methods inherited from class weka.classifiers.AbstractClassifier |
|---|
classifyInstance, debugTipText, forName, getDebug, makeCopies, makeCopy, runClassifier, setDebug |
| Methods inherited from class java.lang.Object |
|---|
equals, getClass, hashCode, notify, notifyAll, wait, wait, wait |
| Field Detail |
|---|
public static final int HINGE
public static final int LOGLOSS
public static final int SQUAREDLOSS
public static final Tag[] TAGS_SELECTION
| Constructor Detail |
|---|
public SGD()
| Method Detail |
|---|
public Capabilities getCapabilities()
Specified by: getCapabilities in interface Classifier
Specified by: getCapabilities in interface CapabilitiesHandler
Overrides: getCapabilities in class AbstractClassifier
See Also: Capabilities
public String lambdaTipText()
public void setLambda(double lambda)
Parameters: lambda - the value of lambda to use
public double getLambda()
public void setLearningRate(double lr)
Parameters: lr - the learning rate to use.
public double getLearningRate()
public String learningRateTipText()
public String epochsTipText()
public void setEpochs(int e)
Parameters: e - the number of epochs to use
public int getEpochs()
public void setDontNormalize(boolean m)
Parameters: m - true if normalization is to be disabled.
public boolean getDontNormalize()
public String dontNormalizeTipText()
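The class description notes that, with normalization enabled, the reported coefficients refer to the normalized data. To make that concrete, here is a minimal sketch of one common normalization, min-max scaling of a single attribute to [0, 1]. Whether SGD applies exactly this scheme is an assumption and is not stated on this page; `NormalizeSketch` and `minMax` are hypothetical names.

```java
/** Illustrative only: min-max scaling of one attribute column (assumed scheme). */
final class NormalizeSketch {

    /** Rescales the values to [0, 1]; a constant column maps to all zeros. */
    static double[] minMax(double[] values) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        double range = max - min;
        double[] scaled = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            // Guard against division by zero when all values are equal.
            scaled[i] = (range == 0.0) ? 0.0 : (values[i] - min) / range;
        }
        return scaled;
    }

    public static void main(String[] args) {
        double[] scaled = minMax(new double[] {2.0, 4.0, 6.0});
        System.out.println(java.util.Arrays.toString(scaled));
    }
}
```

Under this assumption, a weight w learned for the scaled attribute corresponds to a slope of w / (max - min) per unit of the original attribute, which is why coefficients from a normalized run are not directly comparable to those from a run with normalization turned off.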
public void setDontReplaceMissing(boolean m)
Parameters: m - true if global replacement of missing values is to be turned off.
public boolean getDontReplaceMissing()
public String dontReplaceMissingTipText()
public void setLossFunction(SelectedTag function)
Parameters: function - the loss function to use.
public SelectedTag getLossFunction()
public String lossFunctionTipText()
public Enumeration<Option> listOptions()
Specified by: listOptions in interface OptionHandler
Overrides: listOptions in class RandomizableClassifier
public void setOptions(String[] options) throws Exception
-F Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression), 2 = squared loss (regression). (default = 0)
-L The learning rate. If normalization is turned off (as it is automatically for streaming data), then the default learning rate will need to be reduced (try 0.0001). (default = 0.01).
-R <double> The lambda regularization constant (default = 0.0001)
-E <integer> The number of epochs to perform (batch learning only, default = 500)
-N Don't normalize the data
-M Don't replace missing values
Specified by: setOptions in interface OptionHandler
Overrides: setOptions in class RandomizableClassifier
Parameters: options - the list of options as an array of strings
Throws: Exception - if an option is not supported
public String[] getOptions()
Specified by: getOptions in interface OptionHandler
Overrides: getOptions in class RandomizableClassifier
public String globalInfo()
public void reset()
public void buildClassifier(Instances data) throws Exception
Specified by: buildClassifier in interface Classifier
Parameters: data - the set of training instances.
Throws: Exception - if the classifier can't be built successfully.
public void updateClassifier(Instance instance) throws Exception
Specified by: updateClassifier in interface UpdateableClassifier
Parameters: instance - the new training instance to include in the model
Throws: Exception - if the instance could not be incorporated in the model.
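For readers unfamiliar with incremental learning, the sketch below shows what a single stochastic gradient step for a linear model with hinge loss and L2 regularization typically looks like. It is a generic textbook update, not a transcription of Weka's `updateClassifier`. The class `SgdStepSketch`, the -1/+1 label encoding, and the omission of a bias term are assumptions made for brevity.

```java
/** Illustrative only: one generic SGD step, not Weka's update rule. */
final class SgdStepSketch {

    /** Dot product of two vectors of equal length. */
    static double dot(double[] w, double[] x) {
        double sum = 0.0;
        for (int i = 0; i < w.length; i++) {
            sum += w[i] * x[i];
        }
        return sum;
    }

    /**
     * Updates the weights in place for one example (x, y), y in {-1, +1}.
     * Per-example objective (assumed): (lambda / 2) * ||w||^2 + max(0, 1 - y * w.x)
     */
    static void hingeStep(double[] w, double[] x, double y,
                          double learningRate, double lambda) {
        // The margin is computed once, before any weight is modified.
        double margin = y * dot(w, x);
        for (int i = 0; i < w.length; i++) {
            double grad = lambda * w[i];      // gradient of the regularization term
            if (margin < 1.0) {
                grad -= y * x[i];             // hinge subgradient, active inside the margin
            }
            w[i] -= learningRate * grad;
        }
    }

    public static void main(String[] args) {
        double[] w = {0.0, 0.0};
        hingeStep(w, new double[] {1.0, 2.0}, 1.0, 0.1, 0.0);
        System.out.println(java.util.Arrays.toString(w));
    }
}
```

In this formulation, `learningRate` and `lambda` play the roles of the `-L` and `-R` options: the first scales every step, the second shrinks the weights toward zero even when the example is classified with a sufficient margin.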
public double[] distributionForInstance(Instance inst) throws Exception
Specified by: distributionForInstance in interface Classifier
Overrides: distributionForInstance in class AbstractClassifier
Parameters: inst - the instance for which the distribution is computed
Throws: Exception - if the distribution can't be computed successfully
public String toString()
Overrides: toString in class Object
public String getRevision()
Specified by: getRevision in interface RevisionHandler
Overrides: getRevision in class AbstractClassifier
public static void main(String[] args)
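`distributionForInstance` returns the class distribution as a `double[]`. For a binary logistic model, such a distribution is conventionally obtained by passing the linear score through the logistic function, and the sketch below shows that mapping. It illustrates the general technique rather than Weka's code; `DistributionSketch`, `fromScore`, and the index order (class 0 first, class 1 second) are illustrative assumptions.

```java
/** Illustrative only: maps a linear score to a two-class distribution. */
final class DistributionSketch {

    /** Returns {P(class 0), P(class 1)} for linear score z using the logistic function. */
    static double[] fromScore(double z) {
        double p1 = 1.0 / (1.0 + Math.exp(-z));
        return new double[] {1.0 - p1, p1};
    }

    public static void main(String[] args) {
        // A score of zero lies on the decision boundary.
        double[] dist = fromScore(0.0);
        System.out.println(dist[0] + " " + dist[1]);
    }
}
```

A hinge-loss model has no comparable probabilistic reading of its score, so the distribution it returns may carry less information than this one.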