public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector,LogisticRegressionModel> implements MLWritable
Model produced by LogisticRegression.

| Modifier and Type | Method and Description |
|---|---|
| void | checkThresholdConsistency(): If threshold and thresholds are both set, ensures they are consistent. |
| Vector | coefficients() |
| LogisticRegressionModel | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
| Param<java.lang.String> | featuresCol(): Param for features column name. |
| java.lang.String | getFeaturesCol() |
| java.lang.String | getLabelCol() |
| java.lang.String | getPredictionCol() |
| java.lang.String | getRawPredictionCol() |
| double | getThreshold(): Get threshold for binary classification. |
| double[] | getThresholds(): Get thresholds for binary or multiclass classification. |
| boolean | hasSummary(): Indicates whether a training summary exists for this model instance. |
| double | intercept() |
| Param<java.lang.String> | labelCol(): Param for label column name. |
| static LogisticRegressionModel | load(java.lang.String path) |
| int | numClasses(): Number of classes (values which the label can take). |
| int | numFeatures(): Returns the number of features the model was trained on. |
| protected double | predict(Vector features): Predict label for the given feature vector. |
| Param<java.lang.String> | predictionCol(): Param for prediction column name. |
| protected Vector | predictRaw(Vector features): Raw prediction for each possible label. |
| protected double | probability2prediction(Vector probability): Given a vector of class conditional probabilities, select the predicted label. |
| protected double | raw2prediction(Vector rawPrediction): Given a vector of raw predictions, select the predicted label. |
| protected Vector | raw2probabilityInPlace(Vector rawPrediction): Estimate the probability of each class given the raw prediction, doing the computation in-place. |
| Param<java.lang.String> | rawPredictionCol(): Param for raw prediction (a.k.a. confidence) column name. |
| static MLReader<LogisticRegressionModel> | read() |
| LogisticRegressionModel | setThreshold(double value): Set threshold in binary classification, in range [0, 1]. |
| LogisticRegressionModel | setThresholds(double[] value): Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class. |
| LogisticRegressionTrainingSummary | summary(): Gets summary of model on training set. |
| java.lang.String | uid(): An immutable unique ID for the object and its derivatives. |
| StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType): Validates and transforms the input schema with the provided param map. |
| void | validateParams() |
| Vector | weights() |
| MLWriter | write(): Returns an MLWriter instance for this ML instance. |

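Methods inherited from class ProbabilisticClassificationModel: normalizeToProbabilitiesInPlace, predictProbability, raw2probability, setProbabilityCol, transform
Methods inherited from class ClassificationModel: setRawPredictionCol
Methods inherited from class PredictionModel: featuresDataType, setFeaturesCol, setPredictionCol, transformImpl, transformSchema
Methods inherited from class Transformer: transform, transform, transform
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface Identifiable: toString
Methods inherited from interface MLWritable: save
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public static MLReader<LogisticRegressionModel> read()
public static LogisticRegressionModel load(java.lang.String path)
public java.lang.String uid()
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable
public Vector coefficients()
public double intercept()
public Vector weights()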
public LogisticRegressionModel setThreshold(double value)
Set threshold in binary classification, in range [0, 1].
If the estimated probability of class label 1 is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages the model to predict 1 more often.
Note: Calling this with threshold p is equivalent to calling setThresholds(Array(1-p, p)).
When setThreshold() is called, any user-set value for thresholds will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
Default is 0.5.
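
The decision rule described for setThreshold can be sketched in plain Java, without Spark on the classpath. The class and method names below are hypothetical and are not part of the Spark API. The per-class rule (pick the class with the largest p / t) is an assumption taken from the documentation of the thresholds parameter in ProbabilisticClassificationModel; it is not stated on this page. Under that assumption, the sketch illustrates why threshold p and thresholds (1-p, p) lead to the same prediction.

```java
public class ThresholdRuleSketch {

    // Binary rule from the text: predict 1 if P(label = 1) > threshold, else 0.
    public static double predictWithThreshold(double probabilityOfOne, double threshold) {
        return probabilityOfOne > threshold ? 1.0 : 0.0;
    }

    // Assumed per-class rule: predict the class with the largest p / t,
    // where p is the class probability and t is that class's threshold.
    // Ties go to the lower class index in this sketch.
    public static double predictWithThresholds(double[] probabilities, double[] thresholds) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < probabilities.length; i++) {
            double score = probabilities[i] / thresholds[i];
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double p1 = 0.6; // estimated probability of class label 1
        double[] probabilities = {1.0 - p1, p1};

        // Default threshold 0.5: 0.6 > 0.5, so label 1 is predicted.
        System.out.println(predictWithThreshold(p1, 0.5));
        // A higher threshold of 0.7 pushes the prediction to label 0.
        System.out.println(predictWithThreshold(p1, 0.7));
        // thresholds = (1 - 0.7, 0.7) gives the same answer as threshold = 0.7.
        System.out.println(predictWithThresholds(probabilities, new double[] {0.3, 0.7}));
    }
}
```

Raising the threshold from 0.5 to 0.7 flips the prediction for a probability of 0.6 from label 1 to label 0, which is the "predict 0 more often" effect described above.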
Parameters: value - (undocumented)
public double getThreshold()
Get threshold for binary classification.
If threshold is set, returns that value.
Otherwise, if thresholds is set with length 2 (i.e., binary classification), this returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)).
Otherwise, returns the default value of threshold.
Throws: IllegalArgumentException - if thresholds is set to an array of length other than 2.
public LogisticRegressionModel setThresholds(double[] value)
Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class.
Note: When setThresholds() is called, any user-set value for threshold will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
Overrides: setThresholds in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters: value - (undocumented)
public double[] getThresholds()
Get thresholds for binary or multiclass classification.
If thresholds is set, returns its value.
Otherwise, if threshold is set, returns the equivalent thresholds for binary classification: (1-threshold, threshold).
If neither is set, throws an exception.
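
The two conversions described for getThreshold and getThresholds can be checked with a small standalone sketch. The class and method names are hypothetical and only mirror the formulas quoted above; this is not Spark's own code. Algebraically, 1 / (1 + t0 / t1) equals t1 / (t0 + t1), so the pair (1-p, p) maps back to p, and a pair that does not sum to 1 is effectively normalized.

```java
public class ThresholdConversionSketch {

    // threshold -> thresholds for binary classification: (1 - threshold, threshold).
    public static double[] toThresholds(double threshold) {
        return new double[] {1.0 - threshold, threshold};
    }

    // thresholds (length 2) -> equivalent threshold: 1 / (1 + t0 / t1).
    public static double toThreshold(double[] thresholds) {
        if (thresholds.length != 2) {
            throw new IllegalArgumentException(
                "expected exactly 2 thresholds, got " + thresholds.length);
        }
        return 1.0 / (1.0 + thresholds[0] / thresholds[1]);
    }

    public static void main(String[] args) {
        // Round trip: 0.7 -> (0.3, 0.7) -> approximately 0.7 (floating-point rounding applies).
        double roundTrip = toThreshold(toThresholds(0.7));
        System.out.println(Math.abs(roundTrip - 0.7) < 1e-12);

        // Thresholds that do not sum to 1: (1, 3) corresponds to 3 / (1 + 3).
        double fromUnnormalized = toThreshold(new double[] {1.0, 3.0});
        System.out.println(Math.abs(fromUnnormalized - 0.75) < 1e-12);
    }
}
```

The length check mirrors the documented IllegalArgumentException for a thresholds array whose length is not 2.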
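public int numFeatures()
Returns the number of features the model was trained on.
Overrides: numFeatures in class PredictionModel<Vector,LogisticRegressionModel>
public int numClasses()
Number of classes (values which the label can take).
Specified by: numClasses in class ClassificationModel<Vector,LogisticRegressionModel>
public LogisticRegressionTrainingSummary summary()
Gets summary of model on training set. An exception is thrown if trainingSummary == None.
public boolean hasSummary()
Indicates whether a training summary exists for this model instance.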
protected double predict(Vector features)
Predict label for the given feature vector. The behavior of this can be adjusted using thresholds.
Overrides: predict in class ClassificationModel<Vector,LogisticRegressionModel>
Parameters: features - (undocumented)
protected Vector raw2probabilityInPlace(Vector rawPrediction)
Estimate the probability of each class given the raw prediction, doing the computation in-place.
This internal method is used to implement transform() and output probabilityCol.
Specified by: raw2probabilityInPlace in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters: rawPrediction - (undocumented)
protected Vector predictRaw(Vector features)
Raw prediction for each possible label.
This internal method is used to implement transform() and output rawPredictionCol.
Specified by: predictRaw in class ClassificationModel<Vector,LogisticRegressionModel>
Parameters: features - (undocumented)
public LogisticRegressionModel copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
Specified by: copy in interface Params
Specified by: copy in class Model<LogisticRegressionModel>
Parameters: extra - (undocumented)
See Also: defaultCopy()
protected double raw2prediction(Vector rawPrediction)
Given a vector of raw predictions, select the predicted label.
Overrides: raw2prediction in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters: rawPrediction - (undocumented)
protected double probability2prediction(Vector probability)
Given a vector of class conditional probabilities, select the predicted label.
Overrides: probability2prediction in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters: probability - (undocumented)
public MLWriter write()
Returns an MLWriter instance for this ML instance.
For LogisticRegressionModel, this does NOT currently save the training summary.
An option to save summary may be added in the future.
This also does not save the parent currently.
Specified by: write in interface MLWritable
public void checkThresholdConsistency()
If threshold and thresholds are both set, ensures they are consistent.
Throws: java.lang.IllegalArgumentException - if threshold and thresholds are not equivalent
public void validateParams()
public Param<java.lang.String> rawPredictionCol()
public java.lang.String getRawPredictionCol()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is in fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<java.lang.String> labelCol()
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()