Skip to content

Commit 0b3d9da

Browse files
authored
[python-package] mark EarlyStopException as part of public API (#6095)
1 parent 1a6e6ff commit 0b3d9da

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

python-package/lightgbm/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pathlib import Path
77

88
from .basic import Booster, Dataset, Sequence, register_logger
9-
from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter
9+
from .callback import EarlyStopException, early_stopping, log_evaluation, record_evaluation, reset_parameter
1010
from .engine import CVBooster, cv, train
1111

1212
try:
@@ -32,5 +32,5 @@
3232
'train', 'cv',
3333
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
3434
'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker',
35-
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
35+
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'EarlyStopException',
3636
'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']

python-package/lightgbm/callback.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .engine import CVBooster
1313

1414
__all__ = [
15+
'EarlyStopException',
1516
'early_stopping',
1617
'log_evaluation',
1718
'record_evaluation',
@@ -30,7 +31,11 @@
3031

3132

3233
class EarlyStopException(Exception):
33-
"""Exception of early stopping."""
34+
"""Exception of early stopping.
35+
36+
Raise this from a callback passed in via keyword argument ``callbacks``
37+
in ``cv()`` or ``train()`` to trigger early stopping.
38+
"""
3439

3540
def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
3641
"""Create early stopping exception.
@@ -39,6 +44,7 @@ def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) ->
3944
----------
4045
best_iteration : int
4146
The best iteration stopped.
47+
0-based... pass ``best_iteration=2`` to indicate that the third iteration was the best one.
4248
best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
4349
Scores for each metric, on each validation set, as of the best iteration.
4450
"""

tests/python_package_test/test_engine.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,33 @@ def test_early_stopping_min_delta(first_only, single_metric, greater_is_better):
10921092
assert np.greater_equal(last_score, best_score - min_delta).any()
10931093

10941094

1095+
def test_early_stopping_can_be_triggered_via_custom_callback():
1096+
X, y = make_synthetic_regression()
1097+
1098+
def _early_stop_after_seventh_iteration(env):
1099+
if env.iteration == 6:
1100+
exc = lgb.EarlyStopException(
1101+
best_iteration=6,
1102+
best_score=[("some_validation_set", "some_metric", 0.708, True)]
1103+
)
1104+
raise exc
1105+
1106+
bst = lgb.train(
1107+
params={
1108+
"objective": "regression",
1109+
"verbose": -1,
1110+
"num_leaves": 2
1111+
},
1112+
train_set=lgb.Dataset(X, label=y),
1113+
num_boost_round=23,
1114+
callbacks=[_early_stop_after_seventh_iteration]
1115+
)
1116+
assert bst.num_trees() == 7
1117+
assert bst.best_score["some_validation_set"]["some_metric"] == 0.708
1118+
assert bst.best_iteration == 7
1119+
assert bst.current_iteration() == 7
1120+
1121+
10951122
def test_continue_train():
10961123
X, y = make_synthetic_regression()
10971124
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

0 commit comments

Comments
 (0)