public class GradientDescent extends java.lang.Object implements Optimizer
Modifier and Type | Method | Description |
---|---|---|
protected static void | initializeLogIfNecessary(boolean isInterpreter) | |
protected static boolean | isTraceEnabled() | |
protected static org.slf4j.Logger | log() | |
protected static void | logDebug(scala.Function0<java.lang.String> msg) | |
protected static void | logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) | |
protected static void | logError(scala.Function0<java.lang.String> msg) | |
protected static void | logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) | |
protected static void | logInfo(scala.Function0<java.lang.String> msg) | |
protected static void | logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) | |
protected static java.lang.String | logName() | |
protected static void | logTrace(scala.Function0<java.lang.String> msg) | |
protected static void | logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) | |
protected static void | logWarning(scala.Function0<java.lang.String> msg) | |
protected static void | logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) | |
Vector | optimize(RDD<scala.Tuple2<java.lang.Object,Vector>> data, Vector initialWeights) | :: DeveloperApi :: Runs gradient descent on the given training data. |
static scala.Tuple2<Vector,double[]> | runMiniBatchSGD(RDD<scala.Tuple2<java.lang.Object,Vector>> data, Gradient gradient, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, Vector initialWeights) | Alias of runMiniBatchSGD with convergenceTol set to its default value of 0.001. |
static scala.Tuple2<Vector,double[]> | runMiniBatchSGD(RDD<scala.Tuple2<java.lang.Object,Vector>> data, Gradient gradient, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, Vector initialWeights, double convergenceTol) | Run stochastic gradient descent (SGD) in parallel using mini-batches. |
GradientDescent | setConvergenceTol(double tolerance) | Set the convergence tolerance. |
GradientDescent | setGradient(Gradient gradient) | Set the gradient function (of the loss function of a single data example) to be used for SGD. |
GradientDescent | setMiniBatchFraction(double fraction) | :: Experimental :: Set the fraction of data to be used for each SGD iteration. |
GradientDescent | setNumIterations(int iters) | Set the number of iterations for SGD. |
GradientDescent | setRegParam(double regParam) | Set the regularization parameter. |
GradientDescent | setStepSize(double step) | Set the initial step size of SGD for the first step. |
GradientDescent | setUpdater(Updater updater) | Set the updater function that performs a gradient step in a given direction. |
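The summary describes runMiniBatchSGD as mini-batch SGD whose gradient comes from the loss of a single data example (setGradient), evaluated over a fraction of the data per iteration (setMiniBatchFraction). The following is a minimal, self-contained sketch of the gradient part of one such iteration, under assumptions that are ours and not stated on this page: a least-squares loss 0.5 * (w.x - y)^2, each example kept independently with probability miniBatchFraction, and plain double[] arrays instead of MLlib's Vector and RDD types. The class MiniBatchGradientSketch is hypothetical and runs sequentially, whereas the documented method runs in parallel over an RDD.

```java
import java.util.Random;

/**
 * Hypothetical sketch (not MLlib code): averaged least-squares gradient over a
 * mini-batch drawn by keeping each example with probability miniBatchFraction.
 */
final class MiniBatchGradientSketch {
    private MiniBatchGradientSketch() {}

    /** Gradient of 0.5 * (w.x - y)^2 with respect to w for ONE example: (w.x - y) * x. */
    static double[] leastSquaresGradient(double[] w, double[] x, double y) {
        double residual = dot(w, x) - y;
        double[] g = new double[x.length];
        for (int j = 0; j < x.length; j++) {
            g[j] = residual * x[j];
        }
        return g;
    }

    /**
     * Sums per-example gradients over the sampled examples and divides by the
     * sample count. Returns null when no example is sampled, in which case a
     * caller would skip the update for that iteration.
     */
    static double[] miniBatchGradient(double[][] xs, double[] ys, double[] w,
                                      double miniBatchFraction, Random rng) {
        double[] sum = new double[w.length];
        int count = 0;
        for (int i = 0; i < xs.length; i++) {
            // Bernoulli sampling: with fraction >= 1.0 every example is used.
            if (miniBatchFraction >= 1.0 || rng.nextDouble() < miniBatchFraction) {
                double[] g = leastSquaresGradient(w, xs[i], ys[i]);
                for (int j = 0; j < sum.length; j++) {
                    sum[j] += g[j];
                }
                count++;
            }
        }
        if (count == 0) {
            return null;
        }
        for (int j = 0; j < sum.length; j++) {
            sum[j] /= count;
        }
        return sum;
    }

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int j = 0; j < a.length; j++) {
            s += a[j] * b[j];
        }
        return s;
    }
}
```

Dividing by the number of examples actually sampled, rather than by miniBatchFraction times the data size, keeps the result an average over the drawn mini-batch even when the sample size fluctuates.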
public static scala.Tuple2<Vector,double[]> runMiniBatchSGD(RDD<scala.Tuple2<java.lang.Object,Vector>> data, Gradient gradient, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, Vector initialWeights, double convergenceTol)
data - Input data for SGD: an RDD of data examples, each of the form (label, [feature values]).
gradient - Gradient object, used to compute the gradient of the loss function of a single data example.
updater - Updater function that performs a gradient step in a given direction.
stepSize - Initial step size for the first step.
numIterations - Number of iterations that SGD should run.
regParam - Regularization parameter.
miniBatchFraction - Fraction of the input data set to use for one iteration of SGD. Default value 1.0.
initialWeights - (undocumented)
convergenceTol - Mini-batch iteration ends before numIterations if the relative difference between the current weights and the previous weights is less than this value. Convergence is measured with the L2 norm. Default value 0.001. Must be between 0.0 and 1.0, inclusive.
public static scala.Tuple2<Vector,double[]> runMiniBatchSGD(RDD<scala.Tuple2<java.lang.Object,Vector>> data, Gradient gradient, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, Vector initialWeights)
Alias of runMiniBatchSGD with convergenceTol set to its default value of 0.001.
data - (undocumented)
gradient - (undocumented)
updater - (undocumented)
stepSize - (undocumented)
numIterations - (undocumented)
regParam - (undocumented)
miniBatchFraction - (undocumented)
initialWeights - (undocumented)
protected static java.lang.String logName()
protected static org.slf4j.Logger log()
protected static void logInfo(scala.Function0<java.lang.String> msg)
protected static void logDebug(scala.Function0<java.lang.String> msg)
protected static void logTrace(scala.Function0<java.lang.String> msg)
protected static void logWarning(scala.Function0<java.lang.String> msg)
protected static void logError(scala.Function0<java.lang.String> msg)
protected static void logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static boolean isTraceEnabled()
protected static void initializeLogIfNecessary(boolean isInterpreter)
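The summary describes stepSize as the initial step size for the first step, the updater as the function that performs a gradient step in a given direction, and regParam as the regularization parameter. One way these pieces can fit together is sketched below. The sketch assumes, as our own choice rather than anything stated on this page, a step size that decays as stepSize / sqrt(t) for iteration t = 1, 2, ..., so the first step uses exactly stepSize, and a squared-L2 penalty (regParam / 2) * ||w||^2, whose gradient is regParam * w. The class L2UpdaterSketch is hypothetical and is not MLlib's Updater interface.

```java
/**
 * Hypothetical updater sketch (not MLlib code): one gradient step with a step
 * size decaying as stepSize / sqrt(t) and a squared-L2 penalty (regParam / 2) * ||w||^2.
 */
final class L2UpdaterSketch {
    private L2UpdaterSketch() {}

    /** Returns w - eta_t * (g + regParam * w), where eta_t = stepSize / sqrt(t) and t >= 1. */
    static double[] step(double[] w, double[] g, double stepSize, int t, double regParam) {
        if (t < 1) {
            throw new IllegalArgumentException("iteration t must be >= 1, got " + t);
        }
        // At t = 1 the effective step equals stepSize, i.e. the initial step size.
        double eta = stepSize / Math.sqrt(t);
        double[] next = new double[w.length];
        for (int j = 0; j < w.length; j++) {
            // regParam * w[j] is the gradient of the squared-L2 penalty term.
            next[j] = w[j] - eta * (g[j] + regParam * w[j]);
        }
        return next;
    }
}
```

With regParam = 0.0 this reduces to a plain decaying-step gradient step; a positive regParam additionally shrinks each weight toward zero by a factor of (1 - eta * regParam) before the gradient is applied.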
public GradientDescent setStepSize(double step)
step - (undocumented)
public GradientDescent setMiniBatchFraction(double fraction)
fraction - (undocumented)
public GradientDescent setNumIterations(int iters)
iters - (undocumented)
public GradientDescent setRegParam(double regParam)
regParam - (undocumented)
public GradientDescent setConvergenceTol(double tolerance)
- If the norm of the new solution vector is greater than 1, the difference between solution vectors is compared against a relative tolerance, i.e. it is normalized by the norm of the new solution vector.
- If the norm of the new solution vector is less than or equal to 1, the difference between solution vectors is compared against an absolute tolerance, i.e. it is not normalized.
The tolerance must be between 0.0 and 1.0, inclusive.
tolerance - (undocumented)
public GradientDescent setGradient(Gradient gradient)
gradient - (undocumented)
public GradientDescent setUpdater(Updater updater)
updater - (undocumented)
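The convergence rule described for setConvergenceTol (relative tolerance when the new solution vector has norm greater than 1, absolute tolerance otherwise, measured with the L2 norm) can be written as one comparison: ||previous - current|| < tol * max(||current||, 1.0). The sketch below is a plain-Java illustration under our own assumptions: weights are double[] arrays, the comparison is strict, and the 0.0 to 1.0 range check on the tolerance is left to the caller. The class ConvergenceSketch is hypothetical.

```java
/** Hypothetical sketch of the convergence test described for setConvergenceTol. */
final class ConvergenceSketch {
    private ConvergenceSketch() {}

    static double l2Norm(double[] v) {
        double s = 0.0;
        for (double x : v) {
            s += x * x;
        }
        return Math.sqrt(s);
    }

    /**
     * True when ||previous - current||_2 < tol * max(||current||_2, 1.0):
     * a relative test when ||current|| > 1, an absolute test otherwise.
     */
    static boolean isConverged(double[] previous, double[] current, double tol) {
        double[] diff = new double[current.length];
        for (int j = 0; j < current.length; j++) {
            diff[j] = previous[j] - current[j];
        }
        return l2Norm(diff) < tol * Math.max(l2Norm(current), 1.0);
    }
}
```

With tol = 0.001, a change of 0.05 in the weights passes when the new weights have norm about 100 (threshold about 0.1) but fails when their norm is below 1 (threshold 0.001).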