Package weka.classifiers.functions
Class SGD
java.lang.Object
weka.classifiers.AbstractClassifier
weka.classifiers.RandomizableClassifier
weka.classifiers.functions.SGD
- All Implemented Interfaces:
Serializable, Cloneable, Classifier, UpdateableClassifier, Aggregateable<SGD>, BatchPredictor, CapabilitiesHandler, CapabilitiesIgnorer, CommandlineRunnable, OptionHandler, Randomizable, RevisionHandler
public class SGD
extends RandomizableClassifier
implements UpdateableClassifier, OptionHandler, Aggregateable<SGD>
Implements stochastic gradient descent for learning various linear models (binary class SVM, binary class logistic regression, squared loss, Huber loss and epsilon-insensitive loss linear regression). Globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data.
For numeric class attributes, the squared, Huber or epsilon-insensitive loss function must be used. Epsilon-insensitive and Huber loss may require a much higher learning rate. Valid options are:
-F Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression), 2 = squared loss (regression), 3 = epsilon insensitive loss (regression), 4 = Huber loss (regression). (default = 0)
-L The learning rate. If normalization is turned off (as it is automatically for streaming data), then the default learning rate will need to be reduced (try 0.0001). (default = 0.01).
-R <double> The lambda regularization constant (default = 0.0001)
-E <integer> The number of epochs to perform (batch learning only, default = 500)
-C <double> The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
-N Don't normalize the data
-M Don't replace missing values
-S <num> Random number seed. (default 1)
-output-debug-info If set, classifier is run in debug mode and may output additional info to the console
-do-not-check-capabilities If set, classifier capabilities are not checked before classifier is built (use with caution).
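The five loss functions selectable with -F can be written out as plain formulas. The sketch below uses the textbook definitions and has no Weka dependency; the class name LossSketch, the method names, the -1/+1 label encoding and the 0.5 factor in the squared loss are assumptions made for illustration and are not taken from Weka's source.

```java
// Textbook loss functions matching the -F choices (illustrative only, not Weka's internal code).
final class LossSketch {

    // y is the class label encoded as -1 or +1 (an assumption); z is the linear output w.x + b.
    static double hinge(double y, double z) {
        return Math.max(0.0, 1.0 - y * z);
    }

    // Log loss for logistic regression with the same -1/+1 label encoding.
    static double logLoss(double y, double z) {
        return Math.log1p(Math.exp(-y * z));
    }

    // Squared loss on a numeric target; the 0.5 factor is a common convention.
    static double squared(double target, double z) {
        double r = target - z;
        return 0.5 * r * r;
    }

    // Residuals smaller than epsilon cost nothing; larger ones grow linearly.
    static double epsilonInsensitive(double target, double z, double epsilon) {
        return Math.max(0.0, Math.abs(target - z) - epsilon);
    }

    // Quadratic for small residuals, linear beyond the epsilon threshold.
    static double huber(double target, double z, double epsilon) {
        double r = Math.abs(target - z);
        return r <= epsilon ? 0.5 * r * r : epsilon * (r - 0.5 * epsilon);
    }

    public static void main(String[] args) {
        System.out.println("hinge   = " + hinge(1.0, 0.5));
        System.out.println("logLoss = " + logLoss(1.0, 0.0));
        System.out.println("squared = " + squared(3.0, 1.0));
        System.out.println("epsIns  = " + epsilonInsensitive(2.0, 0.0, 0.5));
        System.out.println("huber   = " + huber(2.0, 0.0, 1.0));
    }
}
```

The epsilon-insensitive and Huber losses both take the threshold that -C sets, which is why that option only applies to those two choices.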
- Version:
- $Revision: 15519 $
- Author:
- Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz), Mark Hall (mhall{[at]}pentaho{[dot]}com)
-
Field Summary
Modifier and Type    Field                 Description
static final int     EPSILON_INSENSITIVE   The epsilon-insensitive loss function.
static final int     HINGE                 The hinge loss function.
static final int     HUBER                 The Huber loss function.
static final int     LOGLOSS               The log loss function.
static final int     SQUAREDLOSS           The squared loss function.
static final Tag[]   TAGS_SELECTION        Loss functions to choose from.
Fields inherited from class weka.classifiers.AbstractClassifier
BATCH_SIZE_DEFAULT, NUM_DECIMAL_PLACES_DEFAULT
-
Constructor Summary
-
Method Summary
Modifier and Type    Method                                  Description
SGD                  aggregate(SGD toAggregate)              Aggregate an object with this one.
void                 buildClassifier(Instances data)         Method for building the classifier.
double[]             distributionForInstance(Instance inst)  Computes the distribution for a given instance.
String               dontNormalizeTipText()                  Returns the tip text for this property.
String               dontReplaceMissingTipText()             Returns the tip text for this property.
String               epochsTipText()                         Returns the tip text for this property.
String               epsilonTipText()                        Returns the tip text for this property.
void                 finalizeAggregation()                   Call to complete the aggregation process.
Capabilities         getCapabilities()                       Returns default capabilities of the classifier.
boolean              getDontNormalize()                      Get whether normalization has been turned off.
boolean              getDontReplaceMissing()                 Get whether global replacement of missing values has been disabled.
int                  getEpochs()                             Get current number of epochs.
double               getEpsilon()                            Get the epsilon threshold on the error for the epsilon-insensitive and Huber loss functions.
double               getLambda()                             Get the current value of lambda.
double               getLearningRate()                       Get the learning rate.
SelectedTag          getLossFunction()                       Get the current loss function.
String[]             getOptions()                            Gets the current settings of the classifier.
String               getRevision()                           Returns the revision string.
double[]             getWeights()
String               globalInfo()                            Returns a string describing classifier.
String               lambdaTipText()                         Returns the tip text for this property.
String               learningRateTipText()                   Returns the tip text for this property.
Enumeration<Option>  listOptions()                           Returns an enumeration describing the available options.
String               lossFunctionTipText()                   Returns the tip text for this property.
static void          main(String[] args)                     Main method for testing this class.
void                 reset()                                 Reset the classifier.
void                 setDontNormalize(boolean m)             Turn normalization off/on.
void                 setDontReplaceMissing(boolean m)        Turn global replacement of missing values off/on.
void                 setEpochs(int e)                        Set the number of epochs to use.
void                 setEpsilon(double e)                    Set the epsilon threshold on the error for the epsilon-insensitive and Huber loss functions.
void                 setLambda(double lambda)                Set the value of lambda to use.
void                 setLearningRate(double lr)              Set the learning rate.
void                 setLossFunction(SelectedTag function)   Set the loss function to use.
void                 setOptions(String[] options)            Parses a given list of options.
String               toString()                              Prints out the classifier.
void                 updateClassifier(Instance instance)     Updates the classifier with the given instance.
Methods inherited from class weka.classifiers.RandomizableClassifier
getSeed, seedTipText, setSeed
Methods inherited from class weka.classifiers.AbstractClassifier
batchSizeTipText, classifyInstance, debugTipText, distributionsForInstances, doNotCheckCapabilitiesTipText, forName, getBatchSize, getDebug, getDoNotCheckCapabilities, getNumDecimalPlaces, implementsMoreEfficientBatchPrediction, makeCopies, makeCopy, numDecimalPlacesTipText, postExecution, preExecution, run, runClassifier, setBatchSize, setDebug, setDoNotCheckCapabilities, setNumDecimalPlaces
-
Field Details
-
HINGE
public static final int HINGE
The hinge loss function.
-
LOGLOSS
public static final int LOGLOSS
The log loss function.
-
SQUAREDLOSS
public static final int SQUAREDLOSS
The squared loss function.
-
EPSILON_INSENSITIVE
public static final int EPSILON_INSENSITIVE
The epsilon-insensitive loss function.
-
HUBER
public static final int HUBER
The Huber loss function.
-
TAGS_SELECTION
public static final Tag[] TAGS_SELECTION
Loss functions to choose from.
-
-
Constructor Details
-
SGD
public SGD()
-
-
Method Details
-
getCapabilities
public Capabilities getCapabilities()
Returns default capabilities of the classifier.
- Specified by:
getCapabilities in interface CapabilitiesHandler
- Specified by:
getCapabilities in interface Classifier
- Overrides:
getCapabilities in class AbstractClassifier
- Returns:
- the capabilities of this classifier
-
epsilonTipText
public String epsilonTipText()
Returns the tip text for this property.
- Returns:
- tip text for this property suitable for displaying in the explorer/experimenter gui
-
setEpsilon
public void setEpsilon(double e)
Set the epsilon threshold on the error for the epsilon-insensitive and Huber loss functions.
- Parameters:
e - the value of epsilon to use
-
getEpsilon
public double getEpsilon()
Get the epsilon threshold on the error for the epsilon-insensitive and Huber loss functions.
- Returns:
- the value of epsilon to use
-
lambdaTipText
public String lambdaTipText()
Returns the tip text for this property.
- Returns:
- tip text for this property suitable for displaying in the explorer/experimenter gui
-
setLambda
public void setLambda(double lambda)
Set the value of lambda to use.
- Parameters:
lambda - the value of lambda to use
-
getLambda
public double getLambda()
Get the current value of lambda.
- Returns:
- the current value of lambda
-
setLearningRate
public void setLearningRate(double lr)
Set the learning rate.
- Parameters:
lr - the learning rate to use
-
getLearningRate
public double getLearningRate()
Get the learning rate.
- Returns:
- the learning rate
-
learningRateTipText
public String learningRateTipText()
Returns the tip text for this property.
- Returns:
- tip text for this property suitable for displaying in the explorer/experimenter gui
-
epochsTipText
public String epochsTipText()
Returns the tip text for this property.
- Returns:
- tip text for this property suitable for displaying in the explorer/experimenter gui
-
setEpochs
public void setEpochs(int e)
Set the number of epochs to use.
- Parameters:
e - the number of epochs to use
-
getEpochs
public int getEpochs()
Get the current number of epochs.
- Returns:
- the current number of epochs
-
setDontNormalize
public void setDontNormalize(boolean m)
Turn normalization off/on.
- Parameters:
m - true if normalization is to be disabled
-
getDontNormalize
public boolean getDontNormalize()
Get whether normalization has been turned off.
- Returns:
- true if normalization has been disabled.
-
dontNormalizeTipText
public String dontNormalizeTipText()
Returns the tip text for this property.
- Returns:
- tip text for this property suitable for displaying in the explorer/experimenter gui
-
setDontReplaceMissing
public void setDontReplaceMissing(boolean m)
Turn global replacement of missing values off/on. If turned off, then missing values are effectively ignored.
- Parameters:
m - true if global replacement of missing values is to be turned off
-
getDontReplaceMissing
public boolean getDontReplaceMissing()
Get whether global replacement of missing values has been disabled.
- Returns:
- true if global replacement of missing values has been turned off
-
dontReplaceMissingTipText
public String dontReplaceMissingTipText()
Returns the tip text for this property.
- Returns:
- tip text for this property suitable for displaying in the explorer/experimenter gui
-
setLossFunction
public void setLossFunction(SelectedTag function)
Set the loss function to use.
- Parameters:
function - the loss function to use
-
getLossFunction
public SelectedTag getLossFunction()
Get the current loss function.
- Returns:
- the current loss function.
-
lossFunctionTipText
public String lossFunctionTipText()
Returns the tip text for this property.
- Returns:
- tip text for this property suitable for displaying in the explorer/experimenter gui
-
listOptions
public Enumeration<Option> listOptions()
Returns an enumeration describing the available options.
- Specified by:
listOptions in interface OptionHandler
- Overrides:
listOptions in class RandomizableClassifier
- Returns:
- an enumeration of all the available options.
-
setOptions
public void setOptions(String[] options) throws Exception
Parses a given list of options. Valid options are:
-F Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression), 2 = squared loss (regression), 3 = epsilon insensitive loss (regression), 4 = Huber loss (regression). (default = 0)
-L The learning rate. If normalization is turned off (as it is automatically for streaming data), then the default learning rate will need to be reduced (try 0.0001). (default = 0.01).
-R <double> The lambda regularization constant (default = 0.0001)
-E <integer> The number of epochs to perform (batch learning only, default = 500)
-C <double> The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
-N Don't normalize the data
-M Don't replace missing values
-S <num> Random number seed. (default 1)
-output-debug-info If set, classifier is run in debug mode and may output additional info to the console
-do-not-check-capabilities If set, classifier capabilities are not checked before classifier is built (use with caution).
- Specified by:
setOptions in interface OptionHandler
- Overrides:
setOptions in class RandomizableClassifier
- Parameters:
options - the list of options as an array of strings
- Throws:
Exception - if an option is not supported
-
getOptions
public String[] getOptions()
Gets the current settings of the classifier.
- Specified by:
getOptions in interface OptionHandler
- Overrides:
getOptions in class RandomizableClassifier
- Returns:
- an array of strings suitable for passing to setOptions
-
globalInfo
public String globalInfo()
Returns a string describing the classifier.
- Returns:
- a description suitable for displaying in the explorer/experimenter gui
-
reset
public void reset()
Reset the classifier.
-
buildClassifier
public void buildClassifier(Instances data) throws Exception
Method for building the classifier.
- Specified by:
buildClassifier in interface Classifier
- Parameters:
data - the set of training instances
- Throws:
Exception - if the classifier can't be built successfully
-
updateClassifier
public void updateClassifier(Instance instance) throws Exception
Updates the classifier with the given instance.
- Specified by:
updateClassifier in interface UpdateableClassifier
- Parameters:
instance - the new training instance to include in the model
- Throws:
Exception - if the instance could not be incorporated in the model
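An incremental update of this kind can be pictured as one stochastic gradient step on a single training instance. The sketch below is plain Java with no Weka dependency and uses the hinge loss with labels encoded as -1 or +1. The class HingeSgdSketch, its fields and the shrink-then-step ordering are illustrative assumptions, not Weka's internal implementation; the two constructor arguments after the attribute count play the roles of the -L and -R options.

```java
// One textbook SGD step for an L2-regularized linear model with hinge loss
// (illustrative only; not Weka's internal code).
final class HingeSgdSketch {
    final double[] weights;
    double bias;
    final double learningRate; // plays the role of -L
    final double lambda;       // plays the role of -R

    HingeSgdSketch(int numAttributes, double learningRate, double lambda) {
        this.weights = new double[numAttributes];
        this.learningRate = learningRate;
        this.lambda = lambda;
    }

    // Linear output w.x + b for one attribute vector.
    double output(double[] x) {
        double z = bias;
        for (int i = 0; i < weights.length; i++) {
            z += weights[i] * x[i];
        }
        return z;
    }

    // Update the model with one example; y must be -1 or +1.
    void update(double[] x, double y) {
        double z = output(x); // prediction before any change
        // Regularization step: shrink every weight towards zero.
        double shrink = 1.0 - learningRate * lambda;
        for (int i = 0; i < weights.length; i++) {
            weights[i] *= shrink;
        }
        // Loss step: the hinge loss has a non-zero gradient only inside the margin.
        if (y * z < 1.0) {
            for (int i = 0; i < weights.length; i++) {
                weights[i] += learningRate * y * x[i];
            }
            bias += learningRate * y;
        }
    }

    public static void main(String[] args) {
        HingeSgdSketch model = new HingeSgdSketch(2, 0.1, 0.0);
        model.update(new double[] {1.0, 2.0}, 1.0);
        System.out.println(java.util.Arrays.toString(model.weights) + " bias=" + model.bias);
    }
}
```

An example that is already classified with a margin of at least 1 leaves the weights untouched apart from the regularization shrink, which is why the second kind of update is much cheaper than the first.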
-
distributionForInstance
public double[] distributionForInstance(Instance inst) throws Exception
Computes the distribution for a given instance.
- Specified by:
distributionForInstance in interface Classifier
- Overrides:
distributionForInstance in class AbstractClassifier
- Parameters:
inst - the instance for which the distribution is computed
- Returns:
- the distribution
- Throws:
Exception - if the distribution can't be computed successfully
-
getWeights
public double[] getWeights()
-
toString
public String toString()
Prints out the classifier.
-
getRevision
public String getRevision()
Returns the revision string.
- Specified by:
getRevision in interface RevisionHandler
- Overrides:
getRevision in class AbstractClassifier
- Returns:
- the revision
-
aggregate
public SGD aggregate(SGD toAggregate) throws Exception
Aggregate an object with this one.
- Specified by:
aggregate in interface Aggregateable<SGD>
- Parameters:
toAggregate - the object to aggregate
- Returns:
- the result of aggregation
- Throws:
Exception - if the supplied object can't be aggregated for some reason
-
finalizeAggregation
public void finalizeAggregation() throws Exception
Call to complete the aggregation process. Allows implementers to do any final processing based on how many objects were aggregated.
- Specified by:
finalizeAggregation in interface Aggregateable<SGD>
- Throws:
Exception - if the aggregation can't be finalized for some reason
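The two-phase pattern of aggregate followed by finalizeAggregation can be illustrated with weight averaging, a common way to combine linear models trained on separate data partitions. This page does not state how SGD combines models, so the averaging rule, the class WeightAverager and its fields are assumptions made for the sketch, which is plain Java with no Weka dependency.

```java
// Accumulate weight vectors, then divide by the model count on finalization
// (an assumed combination rule, shown only to illustrate the two-phase pattern).
final class WeightAverager {
    private final double[] sum;
    private int numModels = 1; // counts the model this averager starts from

    WeightAverager(double[] initialWeights) {
        this.sum = initialWeights.clone();
    }

    // Phase 1: add another model's weights; returns this so calls can be chained.
    WeightAverager aggregate(double[] otherWeights) {
        if (otherWeights.length != sum.length) {
            throw new IllegalArgumentException("Weight vectors differ in length");
        }
        for (int i = 0; i < sum.length; i++) {
            sum[i] += otherWeights[i];
        }
        numModels++;
        return this;
    }

    // Phase 2: the final processing that depends on how many models were seen.
    double[] finalizeAggregation() {
        double[] average = new double[sum.length];
        for (int i = 0; i < sum.length; i++) {
            average[i] = sum[i] / numModels;
        }
        return average;
    }

    public static void main(String[] args) {
        double[] avg = new WeightAverager(new double[] {1.0, 2.0})
            .aggregate(new double[] {3.0, 4.0})
            .finalizeAggregation();
        System.out.println(java.util.Arrays.toString(avg));
    }
}
```

Deferring the division to the final call means models can be added in any order and in any number before the result is read.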
-
main
public static void main(String[] args)
Main method for testing this class.
-