public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector,LogisticRegressionModel> implements MLWritable
Model produced by LogisticRegression.

| Modifier and Type | Method and Description |
|---|---|
| `void` | `checkThresholdConsistency()` If threshold and thresholds are both set, ensures they are consistent. |
| `Vector` | `coefficients()` |
| `LogisticRegressionModel` | `copy(ParamMap extra)` Creates a copy of this instance with the same UID and some extra params. |
| `Param<String>` | `featuresCol()` Param for features column name. |
| `String` | `getFeaturesCol()` |
| `String` | `getLabelCol()` |
| `String` | `getPredictionCol()` |
| `String` | `getRawPredictionCol()` |
| `double` | `getThreshold()` Get threshold for binary classification. |
| `double[]` | `getThresholds()` Get thresholds for binary or multiclass classification. |
| `boolean` | `hasSummary()` Indicates whether a training summary exists for this model instance. |
| `double` | `intercept()` |
| `Param<String>` | `labelCol()` Param for label column name. |
| `static LogisticRegressionModel` | `load(String path)` |
| `int` | `numClasses()` Number of classes (values which the label can take). |
| `int` | `numFeatures()` Returns the number of features the model was trained on. |
| `Param<String>` | `predictionCol()` Param for prediction column name. |
| `Param<String>` | `rawPredictionCol()` Param for raw prediction (a.k.a. confidence) column name. |
| `static MLReader<LogisticRegressionModel>` | `read()` |
| `LogisticRegressionModel` | `setThreshold(double value)` Set threshold in binary classification, in range [0, 1]. |
| `LogisticRegressionModel` | `setThresholds(double[] value)` Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class. |
| `LogisticRegressionTrainingSummary` | `summary()` Gets summary of model on training set. |
| `String` | `uid()` An immutable unique ID for the object and its derivatives. |
| `StructType` | `validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)` Validates and transforms the input schema with the provided param map. |
| `void` | `validateParams()` |
| `Vector` | `weights()` |
| `MLWriter` | `write()` Returns an MLWriter instance for this ML instance. |
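Methods inherited from class ProbabilisticClassificationModel: normalizeToProbabilitiesInPlace, setProbabilityCol, transform
Methods inherited from class ClassificationModel: setRawPredictionCol
Methods inherited from class PredictionModel: setFeaturesCol, setPredictionCol, transformSchema
Methods inherited from class Transformer: transform, transform, transform
Methods inherited from class Object: equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface Identifiable: toString
Methods inherited from interface MLWritable: save
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public static MLReader<LogisticRegressionModel> read()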
public static LogisticRegressionModel load(String path)
public String uid()
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable
public Vector coefficients()
public double intercept()
public Vector weights()
public LogisticRegressionModel setThreshold(double value)
If the estimated probability of class label 1 is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages the model to predict 1 more often.
Note: Calling this with threshold p is equivalent to calling setThresholds(Array(1-p, p)).
When setThreshold() is called, any user-set value for thresholds will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be
equivalent.
Default is 0.5.
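The binary decision rule above can be sketched in plain Java without Spark. This is a minimal illustration, not Spark's implementation. The class name `BinaryThresholdSketch` and its `predict` method are hypothetical, and the input is assumed to be an already computed probability of class label 1. Predictions are returned as the doubles 0.0 and 1.0.

```java
/**
 * Minimal sketch of the binary rule described for setThreshold():
 * predict 1 if the estimated probability of class 1 is > threshold, else 0.
 * BinaryThresholdSketch is a hypothetical name, not part of Spark's API.
 */
final class BinaryThresholdSketch {

    private BinaryThresholdSketch() {}

    /** Returns 1.0 when probabilityOfOne is strictly greater than threshold, else 0.0. */
    static double predict(double probabilityOfOne, double threshold) {
        // setThreshold() documents the valid range as [0, 1]
        if (threshold < 0.0 || threshold > 1.0) {
            throw new IllegalArgumentException("threshold must be in [0, 1], got " + threshold);
        }
        return probabilityOfOne > threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double p = 0.6; // estimated probability of class label 1 for one example
        System.out.println(predict(p, 0.5)); // default threshold: prints 1.0
        System.out.println(predict(p, 0.7)); // higher threshold favors 0: prints 0.0
        System.out.println(predict(p, 0.3)); // lower threshold favors 1: prints 1.0
    }
}
```

Because the comparison is strict, a probability exactly equal to the threshold predicts 0.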
Parameters: value - (undocumented)
public double getThreshold()
If threshold is set, returns that value.
Otherwise, if thresholds is set with length 2 (i.e., binary classification),
this returns the equivalent threshold:
1 / (1 + thresholds(0) / thresholds(1)).
Otherwise, returns the default value of threshold.
Throws: IllegalArgumentException - if thresholds is set to an array of length other than 2.
public LogisticRegressionModel setThresholds(double[] value)
Note: When setThresholds() is called, any user-set value for threshold will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be
equivalent.
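setThresholds adjusts the probability of predicting each class. The following is a minimal plain-Java sketch of one way this can work. It assumes the predicted class is the one with the largest ratio p/t, where p is the class probability and t is the class threshold. That rule is our reading of Spark's thresholds param; this page does not state it. `MulticlassThresholdSketch` is a hypothetical name. Treating t = 0 as an infinitely large ratio is also an assumption.

```java
/**
 * Sketch of per-class thresholds: scale each class probability p by its
 * threshold t and predict the class with the largest p / t.
 * ASSUMPTION: the p / t rule and the handling of t == 0 are illustrative,
 * not taken from this page. MulticlassThresholdSketch is hypothetical.
 */
final class MulticlassThresholdSketch {

    private MulticlassThresholdSketch() {}

    /** Returns the index of the class with the largest probability / threshold ratio. */
    static int predict(double[] probabilities, double[] thresholds) {
        if (probabilities.length != thresholds.length) {
            throw new IllegalArgumentException("need exactly one threshold per class");
        }
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < probabilities.length; i++) {
            double t = thresholds[i];
            if (t < 0.0) {
                throw new IllegalArgumentException("thresholds must be >= 0, got " + t);
            }
            // ASSUMPTION: a zero threshold makes its class win unless an earlier class also has t == 0
            double score = (t == 0.0) ? Double.POSITIVE_INFINITY : probabilities[i] / t;
            if (score > bestScore) { // strict '>' keeps the lower index on ties
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probs = {0.5, 0.3, 0.2};
        System.out.println(predict(probs, new double[] {1.0, 1.0, 1.0})); // equal thresholds: plain argmax, prints 0
        System.out.println(predict(probs, new double[] {1.0, 0.5, 1.0})); // 0.3 / 0.5 = 0.6 beats 0.5: prints 1
    }
}
```

Under this assumed rule, consider two classes with thresholds (1 - p, p) and 0 < p < 1. Class 1 is predicted exactly when P(1)/p > P(0)/(1 - p). That inequality simplifies to P(1) > p, which is the same decision setThreshold(p) makes.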
Overrides: setThresholds in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters: value - (undocumented)
public double[] getThresholds()
If thresholds is set, return its value.
Otherwise, if threshold is set, return the equivalent thresholds for binary
classification: (1-threshold, threshold).
If neither is set, throw an exception.
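The conversions described for getThreshold() and getThresholds() invert each other. With thresholds (1 - p, p), the formula 1 / (1 + (1 - p) / p) simplifies to p. The following plain-Java sketch shows both directions, plus an equivalence check in the spirit of checkThresholdConsistency(). `ThresholdConversions`, its method names, and the 1e-5 tolerance are hypothetical choices for illustration, not Spark API.

```java
import java.util.Arrays;

/**
 * Sketch of the threshold <-> thresholds conversions documented for
 * getThreshold() and getThresholds(), and a consistency check between them.
 * ThresholdConversions and its methods are hypothetical, not Spark API.
 */
final class ThresholdConversions {

    // ASSUMPTION: tolerance for comparing doubles, chosen only for illustration
    private static final double TOLERANCE = 1e-5;

    private ThresholdConversions() {}

    /** A binary threshold t corresponds to thresholds (1 - t, t). */
    static double[] toThresholds(double threshold) {
        return new double[] {1.0 - threshold, threshold};
    }

    /** Thresholds (t0, t1) correspond to the threshold 1 / (1 + t0 / t1); length must be 2. */
    static double toThreshold(double[] thresholds) {
        if (thresholds.length != 2) {
            throw new IllegalArgumentException(
                "expected 2 thresholds for binary classification, got " + thresholds.length);
        }
        return 1.0 / (1.0 + thresholds[0] / thresholds[1]);
    }

    /** Throws IllegalArgumentException if threshold and thresholds are not equivalent. */
    static void checkConsistency(double threshold, double[] thresholds) {
        double implied = toThreshold(thresholds);
        if (Math.abs(implied - threshold) > TOLERANCE) {
            throw new IllegalArgumentException("threshold " + threshold
                + " is not equivalent to thresholds " + Arrays.toString(thresholds));
        }
    }

    public static void main(String[] args) {
        double[] ts = toThresholds(0.7);
        System.out.println(Arrays.toString(ts));               // roughly (0.3, 0.7)
        System.out.println(toThreshold(ts));                   // recovers approximately 0.7
        System.out.println(toThreshold(new double[] {2.0, 2.0})); // prints 0.5
        checkConsistency(0.5, new double[] {2.0, 2.0});       // passes: equivalent values
    }
}
```

Only the ratio thresholds(0) / thresholds(1) matters, so thresholds need not sum to 1. For example, (2.0, 2.0) is equivalent to a threshold of 0.5.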
public int numFeatures()
Returns the number of features the model was trained on.
Overrides: numFeatures in class PredictionModel<Vector,LogisticRegressionModel>
public int numClasses()
Number of classes (values which the label can take).
Specified by: numClasses in class ClassificationModel<Vector,LogisticRegressionModel>
public LogisticRegressionTrainingSummary summary()
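Gets summary of model on training set. An exception is thrown if trainingSummary == None.
public boolean hasSummary()
Indicates whether a training summary exists for this model instance.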
public LogisticRegressionModel copy(ParamMap extra)
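Creates a copy of this instance with the same UID and some extra params.
Specified by: copy in interface Params
Specified by: copy in class Model<LogisticRegressionModel>
Parameters: extra - (undocumented)
See Also: defaultCopy()
public MLWriter write()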
Returns an MLWriter instance for this ML instance.
For LogisticRegressionModel, this does NOT currently save the training summary.
An option to save the summary may be added in the future.
This also does not currently save the parent.
Specified by: write in interface MLWritable
public void checkThresholdConsistency()
If threshold and thresholds are both set, ensures they are consistent.
Throws: IllegalArgumentException - if threshold and thresholds are not equivalent
public void validateParams()
public Param<String> rawPredictionCol()
public String getRawPredictionCol()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is called during fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<String> labelCol()
public String getLabelCol()
public Param<String> featuresCol()
public String getFeaturesCol()
public Param<String> predictionCol()
public String getPredictionCol()