Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion deepmd/loggers/training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import datetime
import logging
import math

log = logging.getLogger(__name__)


def format_training_message(
Expand All @@ -19,7 +23,23 @@ def format_training_message_per_task(
task_name: str,
rmse: dict[str, float],
learning_rate: float | None,
check_total_rmse_nan: bool = True,
) -> str:
"""Format training messages for a specific task.

Parameters
----------
batch : int
The batch index
task_name : str
The task name
rmse : dict[str, float]
The root-mean-squared errors.
learning_rate : float | None
The learning rate
check_total_rmse_nan : bool
Whether to log the message and raise a RuntimeError if the total RMSE
(the "rmse" entry) is NaN
"""
if task_name:
task_name += ": "
if learning_rate is None:
Expand All @@ -28,8 +48,16 @@ def format_training_message_per_task(
lr = f", lr = {learning_rate:8.2e}"
# sort rmse
rmse = dict(sorted(rmse.items()))
return (
msg = (
f"batch {batch:7d}: {task_name}"
f"{', '.join([f'{kk} = {vv:8.2e}' for kk, vv in rmse.items()])}"
f"{lr}"
)
if check_total_rmse_nan and math.isnan(rmse.get("rmse", 0.0)):
Comment thread
njzjz marked this conversation as resolved.
log.error(msg)
err_msg = (
f"NaN detected at batch {batch:7d}: {task_name}. "
"Something went wrong, and it is meaningless to continue."
)
raise RuntimeError(err_msg)
Comment thread
njzjz marked this conversation as resolved.
return msg