diff --git a/deepmd/loggers/training.py b/deepmd/loggers/training.py index c7fe94e24d..555ab32622 100644 --- a/deepmd/loggers/training.py +++ b/deepmd/loggers/training.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import datetime +import logging +import math + +log = logging.getLogger(__name__) def format_training_message( @@ -19,7 +23,23 @@ def format_training_message_per_task( task_name: str, rmse: dict[str, float], learning_rate: float | None, + check_total_rmse_nan: bool = True, ) -> str: + """Format training messages for a specific task. + + Parameters + ---------- + batch : int + The batch index + task_name : str + The task name + rmse : dict[str, float] + The root-mean-squared errors. + learning_rate : float | None + The learning rate + check_total_rmse_nan : bool + Whether to throw an error if the total RMSE is NaN + """ if task_name: task_name += ": " if learning_rate is None: @@ -28,8 +48,16 @@ def format_training_message_per_task( lr = f", lr = {learning_rate:8.2e}" # sort rmse rmse = dict(sorted(rmse.items())) - return ( + msg = ( f"batch {batch:7d}: {task_name}" f"{', '.join([f'{kk} = {vv:8.2e}' for kk, vv in rmse.items()])}" f"{lr}" ) + if check_total_rmse_nan and math.isnan(rmse.get("rmse", 0.0)): + log.error(msg) + err_msg = ( + f"NaN detected at batch {batch:7d}: {task_name}. " + "Something went wrong, and it is meaningless to continue." + ) + raise RuntimeError(err_msg) + return msg