Skip to content

Commit 6012b4d

Browse files
njzjzpre-commit-ci[bot]Copilot
authored
feat: add NaN detection during training (#5135)
Fix #4985. This implementation is much simpler than #4986. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved training-metric validation to detect NaN total RMSE, logging a clear error and halting runs to avoid silent failures. * **Documentation** * Added documentation for the new option that controls NaN checking so users can enable or disable the validation as needed. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Signed-off-by: Jinzhe Zeng <njzjz@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 5f73113 commit 6012b4d

1 file changed

Lines changed: 29 additions & 1 deletion

File tree

deepmd/loggers/training.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import datetime
3+
import logging
4+
import math
5+
6+
log = logging.getLogger(__name__)
37

48

59
def format_training_message(
@@ -19,7 +23,23 @@ def format_training_message_per_task(
1923
task_name: str,
2024
rmse: dict[str, float],
2125
learning_rate: float | None,
26+
check_total_rmse_nan: bool = True,
2227
) -> str:
28+
"""Format training messages for a specific task.
29+
30+
Parameters
31+
----------
32+
batch : int
33+
The batch index
34+
task_name : str
35+
The task name
36+
rmse : dict[str, float]
37+
The root-mean-squared errors.
38+
learning_rate : float | None
39+
The learning rate
40+
check_total_rmse_nan : bool
41+
Whether to throw an error if the total RMSE is NaN
42+
"""
2343
if task_name:
2444
task_name += ": "
2545
if learning_rate is None:
@@ -28,8 +48,16 @@ def format_training_message_per_task(
2848
lr = f", lr = {learning_rate:8.2e}"
2949
# sort rmse
3050
rmse = dict(sorted(rmse.items()))
31-
return (
51+
msg = (
3252
f"batch {batch:7d}: {task_name}"
3353
f"{', '.join([f'{kk} = {vv:8.2e}' for kk, vv in rmse.items()])}"
3454
f"{lr}"
3555
)
56+
if check_total_rmse_nan and math.isnan(rmse.get("rmse", 0.0)):
57+
log.error(msg)
58+
err_msg = (
59+
f"NaN detected at batch {batch:7d}: {task_name}. "
60+
"Something went wrong, and it is meaningless to continue."
61+
)
62+
raise RuntimeError(err_msg)
63+
return msg

0 commit comments

Comments
 (0)