[BUG] v2 BaseModel stores logging_metrics as plain list instead of nn.ModuleList (breaks GPU training) #2197

@StrikerEureka34

Description

Describe the bug

The v2 BaseModel stores logging_metrics as a plain Python list instead of wrapping it in nn.ModuleList. This means PyTorch doesn't register the metrics as submodules, so they don't move to GPU when the model does. On GPU training, this causes a device mismatch crash when log_metrics() runs.

v1 handles this correctly in _base_model.py by wrapping the metrics in nn.ModuleList. But v2's _base_model_v2.py just assigns the raw list:

# v2 (broken)
self.logging_metrics = logging_metrics if logging_metrics is not None else []

# v1 (correct)
self.logging_metrics = nn.ModuleList([...])

The metrics themselves (SMAPE, MAE, MAPE, etc.) all inherit from MultiHorizonMetric, which registers internal state buffers via add_state(). Because the metrics are not registered submodules, these buffers stay on CPU when the model moves to GPU.

During a training/validation step, log_metrics() calls metric(y_hat, y), which triggers update() -> _update_losses_and_lengths(), hitting:

self.losses = self.losses + losses  # CPU buffer + GPU tensor -> RuntimeError

This affects all 5 existing v2 models (DLinear, TiDE, TimeXer, SAMformer, TFT) since they all inherit from BaseModel.
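The registration gap can be demonstrated without a GPU. Below is a minimal sketch (ToyMetric, BrokenModel, and FixedModel are hypothetical stand-ins, not the pytorch_forecasting classes) that uses a dtype conversion as a CPU-friendly proxy for `.to("cuda")` — `nn.Module.to()` touches registered submodules' buffers either way:

```python
import torch
from torch import nn

# Hypothetical stand-in for a metric that registers internal state via
# add_state(), the way MultiHorizonMetric does.
class ToyMetric(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("losses", torch.zeros(1))

class BrokenModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.logging_metrics = [ToyMetric()]  # plain list: NOT registered

class FixedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.logging_metrics = nn.ModuleList([ToyMetric()])  # registered submodule

# .to() only reaches registered submodules; dtype here mirrors device movement.
broken = BrokenModel().to(torch.float64)
fixed = FixedModel().to(torch.float64)

print(broken.logging_metrics[0].losses.dtype)  # torch.float32 -- left behind
print(fixed.logging_metrics[0].losses.dtype)   # torch.float64 -- moved with model
print(list(broken.state_dict()))               # [] -- metric state not checkpointed
print(list(fixed.state_dict()))                # ['logging_metrics.0.losses']
```

The same mechanism explains the checkpointing side effect mentioned below: unregistered metric buffers never appear in state_dict().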


To Reproduce

import torch
from pytorch_forecasting.metrics import MAE, SMAPE

from pytorch_forecasting.models.dlinear import DLinear

# Create a model with logging_metrics on GPU
model = DLinear(
    loss=MAE(),
    logging_metrics=[SMAPE(), MAE()],
    metadata={...},  # appropriate metadata
)
model = model.to("cuda")

# Check: logging_metrics did NOT move to GPU
for m in model.logging_metrics:
    for name, buf in m._buffers.items():
        print(f"{name}: {buf.device}")  # prints "cpu" -- should be "cuda"

# This will crash during training when log_metrics() is called
# because y_hat (cuda) and metric state buffers (cpu) are on different devices

The existing tests don't catch this because:

  • The v2 integration suite (test_all_estimators_v2.py) only tests fit() and predict() at the pkg level, never calling training_step() directly with logging_metrics active
  • Unit tests that pass logging_metrics (test_dlinear_v2.py, test_timexer_v2.py) only run forward() under torch.no_grad(), never training_step()
  • CI runs on CPU, so the device mismatch never triggers

Expected behavior

logging_metrics should be stored as nn.ModuleList so that metrics are properly registered as submodules, moved to the correct device with the model, and included in state_dict() for checkpointing. This is how v1 handles it.


Suggested fix

In _base_model_v2.py:

# before
self.logging_metrics = logging_metrics if logging_metrics is not None else []

# after
self.logging_metrics = nn.ModuleList(logging_metrics) if logging_metrics else nn.ModuleList()
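The one-liner above covers all three caller inputs (None, [], or a list of metrics) and always yields an nn.ModuleList. A quick sketch of that behavior, with init_logging_metrics as a hypothetical helper just to make the expression testable in isolation:

```python
from torch import nn

def init_logging_metrics(logging_metrics):
    # Hypothetical helper mirroring the suggested fix: normalize the
    # constructor argument into an nn.ModuleList in every case.
    return nn.ModuleList(logging_metrics) if logging_metrics else nn.ModuleList()

# None and [] both collapse to an empty, but registered, ModuleList.
print(type(init_logging_metrics(None)).__name__)  # ModuleList
print(len(init_logging_metrics([])))              # 0

# A non-empty list is wrapped; nn.Identity stands in for metric modules here.
metrics = init_logging_metrics([nn.Identity(), nn.Identity()])
print(type(metrics).__name__, len(metrics))       # ModuleList 2
```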

And add a test that actually exercises log_metrics() during a training step with logging_metrics passed, to prevent regression.
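Because the root cause is missing submodule registration, part of that test can be device-agnostic and run on CPU CI: assert that every entry of logging_metrics is reachable via model.modules(). A sketch against a hypothetical toy model (the real test would construct DLinear and friends):

```python
from torch import nn

class ToyModel(nn.Module):
    # Hypothetical stand-in for a v2 BaseModel subclass, reproducing the bug:
    # metrics stored in a plain Python list instead of nn.ModuleList.
    def __init__(self, logging_metrics):
        super().__init__()
        self.logging_metrics = logging_metrics

model = ToyModel([nn.L1Loss()])  # nn.L1Loss stands in for MAE()

# Registration check that needs no GPU: every metric should appear in
# model.modules(). With a plain list, none of them do.
registered = {id(m) for m in model.modules()}
missing = [m for m in model.logging_metrics if id(m) not in registered]
print(len(missing))  # 1 -> the metric is not a registered submodule
```

Flipping the plain list to nn.ModuleList makes `missing` empty, so this assertion would have failed on CPU before the fix and passed after it.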

Additional context

Came across this while studying the v2 base layer for foundation model integration work #1959.
FMs will train on GPU and need metric logging during fine-tuning, so this is directly in the path.
Related to #1700 (logging redesign for v2) and #1754 (v1 device propagation for loss metrics), though neither covers this specific issue.

Metadata

Labels: bug (Something isn't working)
Status: Reproduced/confirmed