Commit f57ef6f

[python][sklearn] respect parameters for predictions in init() and set_params() methods (#4822)

* in predict(), respect params set via `set_params()` after fit()
* continue
* add test
* fix return name
* hotfix
* simplify

1 parent: b31d5a4

File tree: 3 files changed (+119 -36 lines)
python-package/lightgbm/basic.py (10 additions & 0 deletions)
```diff
@@ -423,6 +423,16 @@ def get(cls, *args):
             ret |= cls.aliases.get(i, {i})
         return ret
 
+    @classmethod
+    def get_by_alias(cls, *args):
+        ret = set(args)
+        for arg in args:
+            for aliases in cls.aliases.values():
+                if arg in aliases:
+                    ret |= aliases
+                    break
+        return ret
+
 
 def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_value: Any) -> Dict[str, Any]:
     """Get a single parameter value, accounting for aliases.
```

python-package/lightgbm/sklearn.py (72 additions & 34 deletions)
```diff
@@ -2,7 +2,7 @@
 """Scikit-learn wrapper interface for LightGBM."""
 import copy
 from inspect import signature
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
```
```diff
@@ -582,38 +582,49 @@ def set_params(self, **params):
                 self._other_params[key] = value
         return self
 
-    def fit(self, X, y,
-            sample_weight=None, init_score=None, group=None,
-            eval_set=None, eval_names=None, eval_sample_weight=None,
-            eval_class_weight=None, eval_init_score=None, eval_group=None,
-            eval_metric=None, early_stopping_rounds=None,
-            feature_name='auto', categorical_feature='auto',
-            callbacks=None, init_model=None):
-        """Docstring is set after definition, using a template."""
+    def _process_params(self, stage: str) -> Dict[str, Any]:
+        """Process the parameters of this estimator based on its type, parameter aliases, etc.
+
+        Parameters
+        ----------
+        stage : str
+            Name of the stage (can be ``fit`` or ``predict``) this method is called from.
+
+        Returns
+        -------
+        processed_params : dict
+            Processed parameter names mapped to their values.
+        """
+        assert stage in {"fit", "predict"}
         params = self.get_params()
 
         params.pop('objective', None)
         for alias in _ConfigAliases.get('objective'):
             if alias in params:
-                self._objective = params.pop(alias)
+                obj = params.pop(alias)
                 _log_warning(f"Found '{alias}' in params. Will use it instead of 'objective' argument")
-        if self._objective is None:
-            if isinstance(self, LGBMRegressor):
-                self._objective = "regression"
-            elif isinstance(self, LGBMClassifier):
-                if self._n_classes > 2:
-                    self._objective = "multiclass"
+                if stage == "fit":
+                    self._objective = obj
+        if stage == "fit":
+            if self._objective is None:
+                if isinstance(self, LGBMRegressor):
+                    self._objective = "regression"
+                elif isinstance(self, LGBMClassifier):
+                    if self._n_classes > 2:
+                        self._objective = "multiclass"
+                    else:
+                        self._objective = "binary"
+                elif isinstance(self, LGBMRanker):
+                    self._objective = "lambdarank"
                 else:
-                    self._objective = "binary"
-            elif isinstance(self, LGBMRanker):
-                self._objective = "lambdarank"
-            else:
-                raise ValueError("Unknown LGBMModel type.")
+                    raise ValueError("Unknown LGBMModel type.")
         if callable(self._objective):
-            self._fobj = _ObjectiveFunctionWrapper(self._objective)
+            if stage == "fit":
+                self._fobj = _ObjectiveFunctionWrapper(self._objective)
             params['objective'] = 'None'  # objective = nullptr for unknown objective
         else:
-            self._fobj = None
+            if stage == "fit":
+                self._fobj = None
             params['objective'] = self._objective
 
         params.pop('importance_type', None)
```
```diff
@@ -634,16 +645,6 @@ def fit(self, X, y,
                     eval_at = params.pop(alias)
             params['eval_at'] = eval_at
 
-        # Do not modify original args in fit function
-        # Refer to https://github.com/microsoft/LightGBM/pull/2619
-        eval_metric_list = copy.deepcopy(eval_metric)
-        if not isinstance(eval_metric_list, list):
-            eval_metric_list = [eval_metric_list]
-
-        # Separate built-in from callable evaluation metrics
-        eval_metrics_callable = [_EvalFunctionWrapper(f) for f in eval_metric_list if callable(f)]
-        eval_metrics_builtin = [m for m in eval_metric_list if isinstance(m, str)]
-
         # register default metric for consistency with callable eval_metric case
         original_metric = self._objective if isinstance(self._objective, str) else None
         if original_metric is None:
```
```diff
@@ -658,6 +659,28 @@
         # overwrite default metric by explicitly set metric
         params = _choose_param_value("metric", params, original_metric)
 
+        return params
+
+    def fit(self, X, y,
+            sample_weight=None, init_score=None, group=None,
+            eval_set=None, eval_names=None, eval_sample_weight=None,
+            eval_class_weight=None, eval_init_score=None, eval_group=None,
+            eval_metric=None, early_stopping_rounds=None,
+            feature_name='auto', categorical_feature='auto',
+            callbacks=None, init_model=None):
+        """Docstring is set after definition, using a template."""
+        params = self._process_params(stage="fit")
+
+        # Do not modify original args in fit function
+        # Refer to https://github.com/microsoft/LightGBM/pull/2619
+        eval_metric_list = copy.deepcopy(eval_metric)
+        if not isinstance(eval_metric_list, list):
+            eval_metric_list = [eval_metric_list]
+
+        # Separate built-in from callable evaluation metrics
+        eval_metrics_callable = [_EvalFunctionWrapper(f) for f in eval_metric_list if callable(f)]
+        eval_metrics_builtin = [m for m in eval_metric_list if isinstance(m, str)]
+
         # concatenate metric from params (or default if not provided in params) and eval_metric
         params['metric'] = [params['metric']] if isinstance(params['metric'], (str, type(None))) else params['metric']
         params['metric'] = [e for e in eval_metrics_builtin if e not in params['metric']] + params['metric']
```
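To summarize the refactoring: the parameter handling that used to live at the top of `fit()` now sits in `_process_params()`, which both `fit()` and `predict()` call, and only the `fit` stage may mutate persistent estimator state such as the resolved objective. A minimal sketch of this pattern with a hypothetical class, not LightGBM internals:

```python
from typing import Any, Dict


class SketchEstimator:
    """Hypothetical estimator illustrating the shared _process_params() pattern."""

    def __init__(self, **params: Any) -> None:
        self._params = dict(params)  # values from __init__() / set_params()
        self._objective = None

    def _process_params(self, stage: str) -> Dict[str, Any]:
        assert stage in {"fit", "predict"}
        params = dict(self._params)
        obj = params.pop("objective", None)
        if stage == "fit":
            # only the fit stage is allowed to update persistent state
            self._objective = obj or "regression"
        params["objective"] = self._objective
        return params

    def fit(self) -> "SketchEstimator":
        params = self._process_params(stage="fit")
        print("training with:", params)
        return self

    def predict(self, **kwargs: Any) -> Dict[str, Any]:
        # re-derive stored params, then let direct kwargs take priority
        predict_params = self._process_params(stage="predict")
        predict_params.update(kwargs)
        return predict_params


est = SketchEstimator(objective="binary", pred_early_stop=True)
print("predicting with:", est.fit().predict(pred_early_stop=False))
```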
```diff
@@ -799,8 +822,23 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
             raise ValueError("Number of features of the model must "
                              f"match the input. Model n_features_ is {self._n_features} and "
                              f"input n_features is {n_features}")
+        # retrieve original params that possibly can be used in both training and prediction
+        # and then overwrite them (considering aliases) with params that were passed directly in prediction
+        predict_params = self._process_params(stage="predict")
+        for alias in _ConfigAliases.get_by_alias(
+            "data",
+            "X",
+            "raw_score",
+            "start_iteration",
+            "num_iteration",
+            "pred_leaf",
+            "pred_contrib",
+            *kwargs.keys()
+        ):
+            predict_params.pop(alias, None)
+        predict_params.update(kwargs)
         return self._Booster.predict(X, raw_score=raw_score, start_iteration=start_iteration, num_iteration=num_iteration,
-                                     pred_leaf=pred_leaf, pred_contrib=pred_contrib, **kwargs)
+                                     pred_leaf=pred_leaf, pred_contrib=pred_contrib, **predict_params)
 
     predict.__doc__ = _lgbmmodel_doc_predict.format(
         description="Return the predicted value for each sample.",
```
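The user-visible effect: prediction-time parameters such as `pred_early_stop` set in the constructor or via `set_params()` now reach `predict()`, and keyword arguments passed to `predict()` directly still win. A short usage sketch mirroring the new tests below:

```python
import lightgbm as lgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)

# pred_early_stop passed to __init__() is now respected by predict()
clf = lgb.LGBMClassifier(verbose=-1, pred_early_stop=True).fit(X_train, y_train)
scores_early_stop = clf.predict(X_test, raw_score=True)

# a kwarg passed directly to predict() overrides the stored value
scores_plain = clf.predict(X_test, raw_score=True, pred_early_stop=False)
```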

tests/python_package_test/test_sklearn.py (37 additions & 2 deletions)
```diff
@@ -612,8 +612,8 @@ def test_pandas_sparse():
 def test_predict():
     # With default params
     iris = load_iris(return_X_y=False)
-    X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target,
-                                                        test_size=0.2, random_state=42)
+    X_train, X_test, y_train, _ = train_test_split(iris.data, iris.target,
+                                                   test_size=0.2, random_state=42)
 
     gbm = lgb.train({'objective': 'multiclass',
                      'num_class': 3,
```
```diff
@@ -689,6 +689,41 @@ def test_predict():
     np.testing.assert_allclose(res_engine, res_sklearn_params)
 
 
+def test_predict_with_params_from_init():
+    X, y = load_iris(return_X_y=True)
+    X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)
+
+    predict_params = {
+        'pred_early_stop': True,
+        'pred_early_stop_margin': 1.0
+    }
+
+    y_preds_no_params = lgb.LGBMClassifier(verbose=-1).fit(X_train, y_train).predict(
+        X_test, raw_score=True)
+
+    y_preds_params_in_predict = lgb.LGBMClassifier(verbose=-1).fit(X_train, y_train).predict(
+        X_test, raw_score=True, **predict_params)
+    with pytest.raises(AssertionError):
+        np.testing.assert_allclose(y_preds_no_params, y_preds_params_in_predict)
+
+    y_preds_params_in_set_params_before_fit = lgb.LGBMClassifier(verbose=-1).set_params(
+        **predict_params).fit(X_train, y_train).predict(X_test, raw_score=True)
+    np.testing.assert_allclose(y_preds_params_in_predict, y_preds_params_in_set_params_before_fit)
+
+    y_preds_params_in_set_params_after_fit = lgb.LGBMClassifier(verbose=-1).fit(X_train, y_train).set_params(
+        **predict_params).predict(X_test, raw_score=True)
+    np.testing.assert_allclose(y_preds_params_in_predict, y_preds_params_in_set_params_after_fit)
+
+    y_preds_params_in_init = lgb.LGBMClassifier(verbose=-1, **predict_params).fit(X_train, y_train).predict(
+        X_test, raw_score=True)
+    np.testing.assert_allclose(y_preds_params_in_predict, y_preds_params_in_init)
+
+    # test that params passed in predict have higher priority
+    y_preds_params_overwritten = lgb.LGBMClassifier(verbose=-1, **predict_params).fit(X_train, y_train).predict(
+        X_test, raw_score=True, pred_early_stop=False)
+    np.testing.assert_allclose(y_preds_no_params, y_preds_params_overwritten)
+
+
 def test_evaluate_train_set():
     X, y = load_boston(return_X_y=True)
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
```
