weka.classifiers.functions
Class SGD

java.lang.Object
  extended by weka.classifiers.AbstractClassifier
      extended by weka.classifiers.RandomizableClassifier
          extended by weka.classifiers.functions.SGD
All Implemented Interfaces:
Serializable, Cloneable, Classifier, UpdateableClassifier, CapabilitiesHandler, OptionHandler, Randomizable, RevisionHandler

public class SGD
extends RandomizableClassifier
implements UpdateableClassifier, OptionHandler

Implements stochastic gradient descent for learning various linear models (binary class SVM, binary class logistic regression and linear regression). Globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data.
For numeric class attributes, the squared loss function (-F 2) must be used.
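
To make the three loss options concrete, here is a minimal sketch in plain Java of hinge, log, and squared loss applied to a linear model's raw output. It uses the textbook definitions and is not taken from Weka's source, so constant factors may differ from what SGD actually minimizes. The class and method names (LossSketch, hinge, logLoss, squared) are hypothetical, and binary class labels are assumed to be encoded as -1/+1.

```java
/** Illustrative only: textbook definitions of the three losses, not Weka's code. */
public class LossSketch {

    /** Hinge loss (SVM): max(0, 1 - y * z), with y in {-1, +1} and z the raw linear output. */
    static double hinge(double y, double z) {
        return Math.max(0.0, 1.0 - y * z);
    }

    /** Log loss (logistic regression): log(1 + exp(-y * z)), with y in {-1, +1}. */
    static double logLoss(double y, double z) {
        return Math.log1p(Math.exp(-y * z));
    }

    /** Squared loss (regression): (y - z)^2 for a numeric target y. */
    static double squared(double y, double z) {
        double d = y - z;
        return d * d;
    }

    public static void main(String[] args) {
        double z = 0.5; // raw linear output for some instance (assumed value)
        System.out.println("hinge   = " + hinge(1.0, z));
        System.out.println("logloss = " + logLoss(1.0, z));
        System.out.println("squared = " + squared(2.0, z));
    }
}
```

Hinge and log loss only make sense for a two-valued class, which is consistent with the requirement above that numeric classes use squared loss.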

Valid options are:

 -F <integer>
  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression),
  2 = squared loss (regression).
  (default = 0)
 -L <double>
  The learning rate. If normalization is turned off (as it is
  automatically for streaming data), then the default learning rate
  will need to be reduced (try 0.0001).
  (default = 0.01)
 -R <double>
  The lambda regularization constant (default = 0.0001)
 -E <integer>
  The number of epochs to perform (batch learning only, default = 500)
 -N
  Don't normalize the data
 -M
  Don't replace missing values
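
As a rough illustration of how the options above fit together, the sketch below assembles them into a String array of the kind that setOptions(String[]) accepts. The helper class SgdOptionsSketch and its buildOptions method are hypothetical and not part of Weka; only the flag names and their meanings come from the list above.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative helper (hypothetical, not part of Weka) that assembles SGD option strings. */
public class SgdOptionsSketch {

    static String[] buildOptions(int loss, double learningRate, double lambda,
                                 int epochs, boolean dontNormalize, boolean dontReplaceMissing) {
        List<String> opts = new ArrayList<>();
        opts.add("-F"); opts.add(Integer.toString(loss));          // 0 = hinge, 1 = log, 2 = squared
        opts.add("-L"); opts.add(Double.toString(learningRate));
        opts.add("-R"); opts.add(Double.toString(lambda));
        opts.add("-E"); opts.add(Integer.toString(epochs));
        if (dontNormalize) opts.add("-N");                          // flag options take no value
        if (dontReplaceMissing) opts.add("-M");
        return opts.toArray(new String[0]);
    }

    public static void main(String[] args) {
        // Log loss, with the documented defaults for the remaining settings.
        String[] opts = buildOptions(1, 0.01, 0.0001, 500, false, false);
        System.out.println(String.join(" ", opts));
    }
}
```

Note that Double.toString may render small values such as 0.0001 in scientific notation; the result still parses back to the same double with Double.parseDouble.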

Version:
$Revision: 8034 $
Author:
Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz), Mark Hall (mhall{[at]}pentaho{[dot]}com)
See Also:
Serialized Form

Field Summary
static int HINGE
          the hinge loss function.
static int LOGLOSS
          the log loss function.
static int SQUAREDLOSS
          the squared loss function.
static Tag[] TAGS_SELECTION
          Loss functions to choose from
 
Constructor Summary
SGD()
           
 
Method Summary
 void buildClassifier(Instances data)
          Method for building the classifier.
 double[] distributionForInstance(Instance inst)
          Computes the distribution for a given instance
 String dontNormalizeTipText()
          Returns the tip text for this property
 String dontReplaceMissingTipText()
          Returns the tip text for this property
 String epochsTipText()
          Returns the tip text for this property
 Capabilities getCapabilities()
          Returns default capabilities of the classifier.
 boolean getDontNormalize()
          Get whether normalization has been turned off.
 boolean getDontReplaceMissing()
          Get whether global replacement of missing values has been disabled.
 int getEpochs()
          Get the current number of epochs
 double getLambda()
          Get the current value of lambda
 double getLearningRate()
          Get the learning rate.
 SelectedTag getLossFunction()
          Get the current loss function.
 String[] getOptions()
          Gets the current settings of the classifier.
 String getRevision()
          Returns the revision string.
 String globalInfo()
          Returns a string describing the classifier
 String lambdaTipText()
          Returns the tip text for this property
 String learningRateTipText()
          Returns the tip text for this property
 Enumeration<Option> listOptions()
          Returns an enumeration describing the available options.
 String lossFunctionTipText()
          Returns the tip text for this property
static void main(String[] args)
          Main method for testing this class.
 void reset()
          Reset the classifier.
 void setDontNormalize(boolean m)
          Turn normalization off/on.
 void setDontReplaceMissing(boolean m)
          Turn global replacement of missing values off/on.
 void setEpochs(int e)
          Set the number of epochs to use
 void setLambda(double lambda)
          Set the value of lambda to use
 void setLearningRate(double lr)
          Set the learning rate.
 void setLossFunction(SelectedTag function)
          Set the loss function to use.
 void setOptions(String[] options)
          Parses a given list of options.
 String toString()
          Prints out the classifier.
 void updateClassifier(Instance instance)
          Updates the classifier with the given instance.
 
Methods inherited from class weka.classifiers.RandomizableClassifier
getSeed, seedTipText, setSeed
 
Methods inherited from class weka.classifiers.AbstractClassifier
classifyInstance, debugTipText, forName, getDebug, makeCopies, makeCopy, runClassifier, setDebug
 
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Field Detail

HINGE

public static final int HINGE
the hinge loss function.

See Also:
Constant Field Values

LOGLOSS

public static final int LOGLOSS
the log loss function.

See Also:
Constant Field Values

SQUAREDLOSS

public static final int SQUAREDLOSS
the squared loss function.

See Also:
Constant Field Values

TAGS_SELECTION

public static final Tag[] TAGS_SELECTION
Loss functions to choose from

Constructor Detail

SGD

public SGD()
Method Detail

getCapabilities

public Capabilities getCapabilities()
Returns default capabilities of the classifier.

Specified by:
getCapabilities in interface Classifier
Specified by:
getCapabilities in interface CapabilitiesHandler
Overrides:
getCapabilities in class AbstractClassifier
Returns:
the capabilities of this classifier
See Also:
Capabilities

lambdaTipText

public String lambdaTipText()
Returns the tip text for this property

Returns:
tip text for this property suitable for displaying in the Explorer/Experimenter GUI

setLambda

public void setLambda(double lambda)
Set the value of lambda to use

Parameters:
lambda - the value of lambda to use

getLambda

public double getLambda()
Get the current value of lambda

Returns:
the current value of lambda

setLearningRate

public void setLearningRate(double lr)
Set the learning rate.

Parameters:
lr - the learning rate to use.

getLearningRate

public double getLearningRate()
Get the learning rate.

Returns:
the learning rate

learningRateTipText

public String learningRateTipText()
Returns the tip text for this property

Returns:
tip text for this property suitable for displaying in the Explorer/Experimenter GUI

epochsTipText

public String epochsTipText()
Returns the tip text for this property

Returns:
tip text for this property suitable for displaying in the Explorer/Experimenter GUI

setEpochs

public void setEpochs(int e)
Set the number of epochs to use

Parameters:
e - the number of epochs to use

getEpochs

public int getEpochs()
Get the current number of epochs

Returns:
the current number of epochs

setDontNormalize

public void setDontNormalize(boolean m)
Turn normalization off/on.

Parameters:
m - true if normalization is to be disabled.

getDontNormalize

public boolean getDontNormalize()
Get whether normalization has been turned off.

Returns:
true if normalization has been disabled.

dontNormalizeTipText

public String dontNormalizeTipText()
Returns the tip text for this property

Returns:
tip text for this property suitable for displaying in the Explorer/Experimenter GUI

setDontReplaceMissing

public void setDontReplaceMissing(boolean m)
Turn global replacement of missing values off/on. If turned off, then missing values are effectively ignored.

Parameters:
m - true if global replacement of missing values is to be turned off.

getDontReplaceMissing

public boolean getDontReplaceMissing()
Get whether global replacement of missing values has been disabled.

Returns:
true if global replacement of missing values has been turned off

dontReplaceMissingTipText

public String dontReplaceMissingTipText()
Returns the tip text for this property

Returns:
tip text for this property suitable for displaying in the Explorer/Experimenter GUI

setLossFunction

public void setLossFunction(SelectedTag function)
Set the loss function to use.

Parameters:
function - the loss function to use.

getLossFunction

public SelectedTag getLossFunction()
Get the current loss function.

Returns:
the current loss function.

lossFunctionTipText

public String lossFunctionTipText()
Returns the tip text for this property

Returns:
tip text for this property suitable for displaying in the Explorer/Experimenter GUI

listOptions

public Enumeration<Option> listOptions()
Returns an enumeration describing the available options.

Specified by:
listOptions in interface OptionHandler
Overrides:
listOptions in class RandomizableClassifier
Returns:
an enumeration of all the available options.

setOptions

public void setOptions(String[] options)
                throws Exception
Parses a given list of options.

Valid options are:

 -F <integer>
  Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression),
  2 = squared loss (regression).
  (default = 0)
 -L <double>
  The learning rate. If normalization is turned off (as it is
  automatically for streaming data), then the default learning rate
  will need to be reduced (try 0.0001).
  (default = 0.01)
 -R <double>
  The lambda regularization constant (default = 0.0001)
 -E <integer>
  The number of epochs to perform (batch learning only, default = 500)
 -N
  Don't normalize the data
 -M
  Don't replace missing values

Specified by:
setOptions in interface OptionHandler
Overrides:
setOptions in class RandomizableClassifier
Parameters:
options - the list of options as an array of strings
Throws:
Exception - if an option is not supported

getOptions

public String[] getOptions()
Gets the current settings of the classifier.

Specified by:
getOptions in interface OptionHandler
Overrides:
getOptions in class RandomizableClassifier
Returns:
an array of strings suitable for passing to setOptions

globalInfo

public String globalInfo()
Returns a string describing the classifier

Returns:
a description suitable for displaying in the Explorer/Experimenter GUI

reset

public void reset()
Reset the classifier.


buildClassifier

public void buildClassifier(Instances data)
                     throws Exception
Method for building the classifier.

Specified by:
buildClassifier in interface Classifier
Parameters:
data - the set of training instances.
Throws:
Exception - if the classifier can't be built successfully.
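
For readers unfamiliar with how the learning rate, lambda, epochs, and seed interact in batch training, the following is a minimal sketch of generic stochastic gradient descent with squared loss and an L2 penalty. It is not Weka's implementation: this page does not state the exact update rule, how the penalty is scaled, or whether instances are reshuffled each epoch, so all of those are assumptions here. The class EpochSketch and its train method are hypothetical names.

```java
import java.util.Random;

/** Illustrative only: generic SGD with squared loss and L2 shrinkage, not Weka's code. */
public class EpochSketch {

    /** Returns {w0, ..., w(d-1), bias} after the given number of passes over the data. */
    static double[] train(double[][] x, double[] y, double learningRate,
                          double lambda, int epochs, long seed) {
        int n = x.length, d = x[0].length;
        double[] w = new double[d + 1];           // last slot holds the bias
        int[] order = new int[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Random rnd = new Random(seed);

        for (int e = 0; e < epochs; e++) {
            // Fisher-Yates shuffle so each epoch visits instances in a new order (an assumption).
            for (int i = n - 1; i > 0; i--) {
                int j = rnd.nextInt(i + 1);
                int t = order[i]; order[i] = order[j]; order[j] = t;
            }
            for (int idx : order) {
                double z = w[d];
                for (int k = 0; k < d; k++) z += w[k] * x[idx][k];
                // Negative gradient of 0.5 * (y - z)^2 with respect to z;
                // the 0.5 factor only rescales the step.
                double err = y[idx] - z;
                for (int k = 0; k < d; k++) {
                    // Step along the loss gradient and shrink the weight (L2 penalty).
                    w[k] += learningRate * (err * x[idx][k] - lambda * w[k]);
                }
                w[d] += learningRate * err;       // bias is left unregularized here
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // Toy data following y = 2x + 1, with inputs already scaled to [0, 1].
        double[][] x = {{0.0}, {0.25}, {0.5}, {0.75}, {1.0}};
        double[] y = {1.0, 1.5, 2.0, 2.5, 3.0};
        double[] w = train(x, y, 0.01, 0.0001, 500, 1L);
        System.out.println("slope = " + w[0] + ", intercept = " + w[1]);
    }
}
```

The toy inputs are kept in [0, 1] because the step size that works on normalized data is usually too large for unscaled attributes, which is the reason the -L option text suggests a smaller learning rate when normalization is off.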

updateClassifier

public void updateClassifier(Instance instance)
                      throws Exception
Updates the classifier with the given instance.

Specified by:
updateClassifier in interface UpdateableClassifier
Parameters:
instance - the new training instance to include in the model
Throws:
Exception - if the instance could not be incorporated in the model.

distributionForInstance

public double[] distributionForInstance(Instance inst)
                                 throws Exception
Computes the distribution for a given instance

Specified by:
distributionForInstance in interface Classifier
Overrides:
distributionForInstance in class AbstractClassifier
Parameters:
inst - the instance for which distribution is computed
Returns:
the distribution
Throws:
Exception - if the distribution can't be computed successfully
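
This page does not say how the returned distribution is derived from the linear model. As a hedged illustration only, the sketch below shows two common ways to turn a raw linear score into a two-class distribution: a logistic link, which suits a model trained with log loss, and a hard assignment by sign, which suits a hinge-loss model. The class DistributionSketch and both method names are hypothetical, and neither mapping is confirmed to be what SGD does.

```java
/** Illustrative only: plausible mappings from a linear score to a two-class distribution. */
public class DistributionSketch {

    /** Logistic link: P(class 1) = 1 / (1 + exp(-z)). */
    static double[] fromLogistic(double z) {
        double p1 = 1.0 / (1.0 + Math.exp(-z));
        return new double[] {1.0 - p1, p1};
    }

    /** Hard assignment by the sign of the score; no calibrated probability is produced. */
    static double[] fromSign(double z) {
        return z > 0 ? new double[] {0.0, 1.0} : new double[] {1.0, 0.0};
    }

    public static void main(String[] args) {
        double[] soft = fromLogistic(0.0);  // score on the decision boundary
        double[] hard = fromSign(-1.3);     // score on the class-0 side
        System.out.println(soft[0] + " " + soft[1]);
        System.out.println(hard[0] + " " + hard[1]);
    }
}
```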

toString

public String toString()
Prints out the classifier.

Overrides:
toString in class Object
Returns:
a description of the classifier as a string

getRevision

public String getRevision()
Returns the revision string.

Specified by:
getRevision in interface RevisionHandler
Overrides:
getRevision in class AbstractClassifier
Returns:
the revision

main

public static void main(String[] args)
Main method for testing this class.



Copyright © 2012 University of Waikato, Hamilton, NZ. All Rights Reserved.