public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector,LogisticRegressionModel> implements MLWritable
Model produced by LogisticRegression.
Modifier and Type | Method and Description |
---|---|
protected static <T> T |
$(Param<T> param) |
protected static void |
checkThresholdConsistency() |
void |
checkThresholdConsistency()
If threshold and thresholds are both set, ensures they are consistent. |
static Params |
clear(Param<?> param) |
Vector |
coefficients() |
LogisticRegressionModel |
copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
|
protected static <T extends Params> T |
copyValues(T to,
ParamMap extra) |
protected static <T extends Params> |
copyValues$default$2() |
protected static <T extends Params> T |
defaultCopy(ParamMap extra) |
static DoubleParam |
elasticNetParam() |
LogisticRegressionSummary |
evaluate(Dataset<?> dataset)
Evaluates the model on a test dataset.
|
static java.lang.String |
explainParam(Param<?> param) |
static java.lang.String |
explainParams() |
static ParamMap |
extractParamMap() |
static ParamMap |
extractParamMap(ParamMap extra) |
static Param<java.lang.String> |
featuresCol() |
Param<java.lang.String> |
featuresCol()
Param for features column name.
|
protected static DataType |
featuresDataType() |
static BooleanParam |
fitIntercept() |
static <T> scala.Option<T> |
get(Param<T> param) |
static <T> scala.Option<T> |
getDefault(Param<T> param) |
static double |
getElasticNetParam() |
static java.lang.String |
getFeaturesCol() |
java.lang.String |
getFeaturesCol() |
static boolean |
getFitIntercept() |
static java.lang.String |
getLabelCol() |
java.lang.String |
getLabelCol() |
static int |
getMaxIter() |
static <T> T |
getOrDefault(Param<T> param) |
static Param<java.lang.Object> |
getParam(java.lang.String paramName) |
static java.lang.String |
getPredictionCol() |
java.lang.String |
getPredictionCol() |
static java.lang.String |
getProbabilityCol() |
static java.lang.String |
getRawPredictionCol() |
java.lang.String |
getRawPredictionCol() |
static double |
getRegParam() |
static boolean |
getStandardization() |
double |
getThreshold()
Get threshold for binary classification.
|
double[] |
getThresholds()
Get thresholds for binary or multiclass classification.
|
static double |
getTol() |
static java.lang.String |
getWeightCol() |
static <T> boolean |
hasDefault(Param<T> param) |
static boolean |
hasParam(java.lang.String paramName) |
static boolean |
hasParent() |
boolean |
hasSummary()
Indicates whether a training summary exists for this model instance.
|
protected static void |
initializeLogIfNecessary(boolean isInterpreter) |
double |
intercept() |
static boolean |
isDefined(Param<?> param) |
static boolean |
isSet(Param<?> param) |
protected static boolean |
isTraceEnabled() |
static Param<java.lang.String> |
labelCol() |
Param<java.lang.String> |
labelCol()
Param for label column name.
|
static LogisticRegressionModel |
load(java.lang.String path) |
protected static org.slf4j.Logger |
log() |
protected static void |
logDebug(scala.Function0<java.lang.String> msg) |
protected static void |
logDebug(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logError(scala.Function0<java.lang.String> msg) |
protected static void |
logError(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logInfo(scala.Function0<java.lang.String> msg) |
protected static void |
logInfo(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static java.lang.String |
logName() |
protected static void |
logTrace(scala.Function0<java.lang.String> msg) |
protected static void |
logTrace(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
protected static void |
logWarning(scala.Function0<java.lang.String> msg) |
protected static void |
logWarning(scala.Function0<java.lang.String> msg,
java.lang.Throwable throwable) |
static IntParam |
maxIter() |
int |
numClasses() |
int |
numFeatures() |
static Param<?>[] |
params() |
static void |
parent_$eq(Estimator<M> x$1) |
static Estimator<M> |
parent() |
protected double |
predict(Vector features)
Predict label for the given feature vector.
|
static Param<java.lang.String> |
predictionCol() |
Param<java.lang.String> |
predictionCol()
Param for prediction column name.
|
protected static Vector |
predictProbability(FeaturesType features) |
protected Vector |
predictRaw(Vector features) |
protected double |
probability2prediction(Vector probability)
Given a vector of class conditional probabilities, select the predicted label.
|
static Param<java.lang.String> |
probabilityCol() |
protected double |
raw2prediction(Vector rawPrediction)
Given a vector of raw predictions, select the predicted label.
|
protected static Vector |
raw2probability(Vector rawPrediction) |
protected Vector |
raw2probabilityInPlace(Vector rawPrediction)
Estimate the probability of each class given the raw prediction,
doing the computation in-place.
|
static Param<java.lang.String> |
rawPredictionCol() |
Param<java.lang.String> |
rawPredictionCol()
Param for raw prediction (a.k.a. confidence) column name.
|
static MLReader<LogisticRegressionModel> |
read() |
static DoubleParam |
regParam() |
static void |
save(java.lang.String path) |
static <T> Params |
set(Param<T> param,
T value) |
protected static Params |
set(ParamPair<?> paramPair) |
protected static Params |
set(java.lang.String param,
java.lang.Object value) |
protected static <T> Params |
setDefault(Param<T> param,
T value) |
protected static Params |
setDefault(scala.collection.Seq<ParamPair<?>> paramPairs) |
static M |
setFeaturesCol(java.lang.String value) |
static M |
setParent(Estimator<M> parent) |
static M |
setPredictionCol(java.lang.String value) |
static M |
setProbabilityCol(java.lang.String value) |
static M |
setRawPredictionCol(java.lang.String value) |
LogisticRegressionModel |
setThreshold(double value)
Set threshold in binary classification, in range [0, 1].
|
LogisticRegressionModel |
setThresholds(double[] value)
Set thresholds in multiclass (or binary) classification to adjust the probability of
predicting each class.
|
static BooleanParam |
standardization() |
LogisticRegressionTrainingSummary |
summary()
Gets summary of model on training set.
|
static DoubleParam |
threshold() |
static DoubleArrayParam |
thresholds() |
static DoubleParam |
tol() |
static java.lang.String |
toString() |
static Dataset<Row> |
transform(Dataset<?> dataset) |
static Dataset<Row> |
transform(Dataset<?> dataset,
ParamMap paramMap) |
static Dataset<Row> |
transform(Dataset<?> dataset,
ParamPair<?> firstParamPair,
ParamPair<?>... otherParamPairs) |
static Dataset<Row> |
transform(Dataset<?> dataset,
ParamPair<?> firstParamPair,
scala.collection.Seq<ParamPair<?>> otherParamPairs) |
protected static Dataset<Row> |
transformImpl(Dataset<?> dataset) |
static StructType |
transformSchema(StructType schema) |
protected static StructType |
transformSchema(StructType schema,
boolean logging) |
java.lang.String |
uid()
An immutable unique ID for the object and its derivatives.
|
protected static StructType |
validateAndTransformSchema(StructType schema,
boolean fitting,
DataType featuresDataType) |
StructType |
validateAndTransformSchema(StructType schema,
boolean fitting,
DataType featuresDataType) |
StructType |
validateAndTransformSchema(StructType schema,
boolean fitting,
DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
|
static void |
validateParams() |
void |
validateParams()
Validates parameter values stored internally.
|
static Param<java.lang.String> |
weightCol() |
MLWriter |
write()
Returns an MLWriter instance for this ML instance. |
normalizeToProbabilitiesInPlace, predictProbability, raw2probability, setProbabilityCol, transform
setRawPredictionCol
featuresDataType, setFeaturesCol, setPredictionCol, transformImpl, transformSchema
transform, transform, transform
transformSchema
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
toString
save
public static MLReader<LogisticRegressionModel> read()
public static LogisticRegressionModel load(java.lang.String path)
public static java.lang.String toString()
public static Param<?>[] params()
public static java.lang.String explainParam(Param<?> param)
public static java.lang.String explainParams()
public static final boolean isSet(Param<?> param)
public static final boolean isDefined(Param<?> param)
public static boolean hasParam(java.lang.String paramName)
public static Param<java.lang.Object> getParam(java.lang.String paramName)
protected static final Params set(java.lang.String param, java.lang.Object value)
public static final <T> scala.Option<T> get(Param<T> param)
public static final <T> T getOrDefault(Param<T> param)
protected static final <T> T $(Param<T> param)
public static final <T> scala.Option<T> getDefault(Param<T> param)
public static final <T> boolean hasDefault(Param<T> param)
public static final ParamMap extractParamMap()
protected static java.lang.String logName()
protected static org.slf4j.Logger log()
protected static void logInfo(scala.Function0<java.lang.String> msg)
protected static void logDebug(scala.Function0<java.lang.String> msg)
protected static void logTrace(scala.Function0<java.lang.String> msg)
protected static void logWarning(scala.Function0<java.lang.String> msg)
protected static void logError(scala.Function0<java.lang.String> msg)
protected static void logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static boolean isTraceEnabled()
protected static void initializeLogIfNecessary(boolean isInterpreter)
protected static StructType transformSchema(StructType schema, boolean logging)
public static Dataset<Row> transform(Dataset<?> dataset, ParamPair<?> firstParamPair, scala.collection.Seq<ParamPair<?>> otherParamPairs)
public static Dataset<Row> transform(Dataset<?> dataset, ParamPair<?> firstParamPair, ParamPair<?>... otherParamPairs)
public static Estimator<M> parent()
public static void parent_$eq(Estimator<M> x$1)
public static M setParent(Estimator<M> parent)
public static boolean hasParent()
public static final Param<java.lang.String> labelCol()
public static final java.lang.String getLabelCol()
public static final Param<java.lang.String> featuresCol()
public static final java.lang.String getFeaturesCol()
public static final Param<java.lang.String> predictionCol()
public static final java.lang.String getPredictionCol()
public static M setFeaturesCol(java.lang.String value)
public static M setPredictionCol(java.lang.String value)
protected static DataType featuresDataType()
public static StructType transformSchema(StructType schema)
public static final Param<java.lang.String> rawPredictionCol()
public static final java.lang.String getRawPredictionCol()
public static M setRawPredictionCol(java.lang.String value)
public static final Param<java.lang.String> probabilityCol()
public static final java.lang.String getProbabilityCol()
public static final DoubleArrayParam thresholds()
protected static StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
public static M setProbabilityCol(java.lang.String value)
protected static Vector predictProbability(FeaturesType features)
public static final DoubleParam regParam()
public static final double getRegParam()
public static final DoubleParam elasticNetParam()
public static final double getElasticNetParam()
public static final IntParam maxIter()
public static final int getMaxIter()
public static final BooleanParam fitIntercept()
public static final boolean getFitIntercept()
public static final DoubleParam tol()
public static final double getTol()
public static final BooleanParam standardization()
public static final boolean getStandardization()
public static final Param<java.lang.String> weightCol()
public static final java.lang.String getWeightCol()
public static final DoubleParam threshold()
protected static void checkThresholdConsistency()
public static void validateParams()
public static void save(java.lang.String path) throws java.io.IOException
Throws: java.io.IOException
public java.lang.String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
uid
in interface Identifiable
uid
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
public Vector coefficients()
public double intercept()
public LogisticRegressionModel setThreshold(double value)
If the estimated probability of class label 1 is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages the model to predict 1 more often.
Note: Calling this with threshold p is equivalent to calling setThresholds(Array(1-p, p)).
When setThreshold() is called, any user-set value for thresholds will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
Default is 0.5.
value - the threshold, in range [0, 1]

public double getThreshold()
If threshold is set, returns that value.
Otherwise, if thresholds is set with length 2 (i.e., binary classification), this returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)).
Otherwise, returns the default value of threshold.
Throws: IllegalArgumentException if thresholds is set to an array of length other than 2.

public LogisticRegressionModel setThresholds(double[] value)
Note: When setThresholds() is called, any user-set value for threshold will be cleared.
If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
setThresholds
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
value - thresholds used to adjust the probability of predicting each class

public double[] getThresholds()
If thresholds is set, returns its value.
Otherwise, if threshold is set, returns the equivalent thresholds for binary classification: (1-threshold, threshold).
If neither is set, throws an exception.
getThresholds
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
public int numFeatures()
numFeatures
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
public int numClasses()
numClasses
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
public LogisticRegressionTrainingSummary summary()
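Gets summary of model on training set. An exception is thrown if trainingSummary == None.

public boolean hasSummary()
Indicates whether a training summary exists for this model instance.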
public LogisticRegressionSummary evaluate(Dataset<?> dataset)
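Evaluates the model on a test dataset.
dataset - Test dataset to evaluate model on.

protected double predict(Vector features)
Predict label for the given feature vector. The behavior of this can be adjusted using thresholds.
predict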
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
features - feature vector to predict a label for

protected Vector raw2probabilityInPlace(Vector rawPrediction)
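Description copied from class: ProbabilisticClassificationModel
Estimate the probability of each class given the raw prediction, doing the computation in-place. This internal method is used to implement transform() and output probabilityCol.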
raw2probabilityInPlace
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
rawPrediction - raw prediction vector to convert to class probabilities in place

protected Vector predictRaw(Vector features)
predictRaw
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
public LogisticRegressionModel copy(ParamMap extra)
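Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params.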
copy
in interface Params
copy
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
extra - extra params to apply to the copy
See Also: defaultCopy()
protected double raw2prediction(Vector rawPrediction)
Description copied from class: ClassificationModel
Given a vector of raw predictions, select the predicted label.
raw2prediction
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
rawPrediction - vector of raw predictions

protected double probability2prediction(Vector probability)
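Description copied from class: ProbabilisticClassificationModel
Given a vector of class conditional probabilities, select the predicted label.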
probability2prediction
in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
probability - vector of class conditional probabilities

public MLWriter write()
Returns an MLWriter instance for this ML instance.
For LogisticRegressionModel, this does NOT currently save the training summary. An option to save the summary may be added in the future.
This also does not currently save the parent.
write
in interface MLWritable
public void checkThresholdConsistency()
If threshold and thresholds are both set, ensures they are consistent.
Throws: java.lang.IllegalArgumentException - if threshold and thresholds are not equivalent

public void validateParams()
Description copied from interface: Params
Validates parameter values stored internally. This only needs to check for interactions between parameters. Parameter value checks which do not depend on other parameters are handled by Param.validate(). This method does not handle input/output column parameters; those are checked during schema validation.
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
public Param<java.lang.String> rawPredictionCol()
public java.lang.String getRawPredictionCol()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
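Validates and transforms the input schema with the provided param map.
schema - input schema
fitting - whether this is in fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.

public Param<java.lang.String> labelCol()
Param for label column name.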
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()