Description
Describe the bug
The v2 `BaseModel` stores `logging_metrics` as a plain Python list instead of wrapping it in `nn.ModuleList`. PyTorch therefore doesn't register the metrics as submodules, so they don't move to GPU when the model does. During GPU training, this causes a device-mismatch crash when `log_metrics()` runs.
v1 handles this correctly in `_base_model.py` by wrapping the metrics in `nn.ModuleList`, but v2 in `_base_model_v2.py` just assigns the raw list:
```python
# v2 (broken)
self.logging_metrics = logging_metrics if logging_metrics is not None else []

# v1 (correct)
self.logging_metrics = nn.ModuleList([...])
```

The metrics themselves (SMAPE, MAE, MAPE, etc.) all inherit from `MultiHorizonMetric`, which registers internal state buffers via `add_state()`. Because the metrics aren't registered submodules, these buffers stay on CPU when the model moves to GPU.
During a training/validation step, `log_metrics()` calls `metric(y_hat, y)`, which triggers `update()` -> `_update_losses_and_lengths()`, hitting:

```python
self.losses = self.losses + losses  # CPU buffer + GPU tensor -> RuntimeError
```

This affects all five existing v2 models (DLinear, TiDE, TimeXer, SAMformer, TFT), since they all inherit from `BaseModel`.
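The registration difference can be demonstrated without pytorch-forecasting or a GPU. A minimal sketch (the `FakeMetric` and `Holder` names are hypothetical stand-ins, not library classes) showing that buffers inside a plain Python list are invisible to `state_dict()`, and therefore also to `.to()` traversal and checkpointing, while `nn.ModuleList` registers them:

```python
import torch
from torch import nn


class FakeMetric(nn.Module):
    """Stand-in for a metric that registers state via add_state()."""

    def __init__(self):
        super().__init__()
        self.register_buffer("losses", torch.zeros(1))


class Holder(nn.Module):
    """Stand-in for BaseModel, with and without the ModuleList wrap."""

    def __init__(self, use_modulelist: bool):
        super().__init__()
        metrics = [FakeMetric(), FakeMetric()]
        self.logging_metrics = nn.ModuleList(metrics) if use_modulelist else metrics


broken = Holder(use_modulelist=False)
fixed = Holder(use_modulelist=True)

# Plain list: the metric buffers never appear in state_dict(),
# so .to(device) and checkpointing both miss them.
print(list(broken.state_dict().keys()))  # []
print(list(fixed.state_dict().keys()))
# ['logging_metrics.0.losses', 'logging_metrics.1.losses']
```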
To Reproduce
```python
import torch

from pytorch_forecasting.metrics import MAE, SMAPE
from pytorch_forecasting.models.dlinear import DLinear

# Create a model with logging_metrics on GPU
model = DLinear(
    loss=MAE(),
    logging_metrics=[SMAPE(), MAE()],
    metadata={...},  # appropriate metadata
)
model = model.to("cuda")

# Check: logging_metrics did NOT move to GPU
for m in model.logging_metrics:
    for name, buf in m._buffers.items():
        print(f"{name}: {buf.device}")  # prints "cpu" -- should be "cuda"

# This will crash during training when log_metrics() is called,
# because y_hat (cuda) and metric state buffers (cpu) are on different devices
```

The existing tests don't catch this because:
- The v2 integration suite (`test_all_estimators_v2.py`) only tests `fit()` and `predict()` at the pkg level, never calling `training_step()` directly with `logging_metrics` active
- Unit tests that pass `logging_metrics` (`test_dlinear_v2.py`, `test_timexer_v2.py`) only run `forward()` under `torch.no_grad()`, never `training_step()`
- CI runs on CPU, so the device mismatch never triggers
Expected behavior
logging_metrics should be stored as nn.ModuleList so that metrics are properly registered as submodules, moved to the correct device with the model, and included in state_dict() for checkpointing. This is how v1 handles it.
Suggested fix
```python
# before
self.logging_metrics = logging_metrics if logging_metrics is not None else []

# after
self.logging_metrics = nn.ModuleList(logging_metrics) if logging_metrics else nn.ModuleList()
```

And add a test that actually exercises `log_metrics()` during a training step with `logging_metrics` passed, to prevent regression.
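Such a regression check does not strictly need a GPU: `nn.Module.to()` reaches buffers through the same submodule traversal whether it converts device or dtype, so a CPU-only dtype conversion exercises the same registration path as `.to("cuda")`. A standalone sketch of that idea, using a hypothetical `StubMetric` in place of the real `MultiHorizonMetric` subclasses:

```python
import torch
from torch import nn


class StubMetric(nn.Module):
    """Stand-in for a MultiHorizonMetric with an add_state()-style buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("losses", torch.zeros(1, dtype=torch.float32))


def wrap_metrics(metrics):
    # the suggested fix, extracted for testing
    return nn.ModuleList(metrics) if metrics else nn.ModuleList()


class Model(nn.Module):
    def __init__(self, metrics):
        super().__init__()
        self.logging_metrics = wrap_metrics(metrics)


model = Model([StubMetric()])
# .to() only visits buffers of *registered* submodules; with the fix
# applied, the metric buffer follows the conversion.
model = model.to(torch.float64)
print(model.logging_metrics[0].losses.dtype)  # torch.float64
```

With the broken plain-list assignment, the buffer would stay `float32` here, just as it stays on CPU under `.to("cuda")`.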
Additional context
Came across this while studying the v2 base layer for foundation model integration work (#1959).
FMs will train on GPU and need metric logging during fine-tuning, so this is directly in the path.
Related to #1700 (logging redesign for v2) and #1754 (v1 device propagation for loss metrics), though neither covers this specific issue.