[ENH] Add NNLossWrapper to support standard PyTorch nn.Module losses in pytorch-forecasting #2214

@vinitjain2005

Description

Problem

Currently, base_metrics.py provides TorchMetricWrapper to wrap
torchmetrics metrics so they satisfy the LightningMetric interface.
However, there is no equivalent adapter for standard PyTorch nn.Module
losses such as:

  • torch.nn.HuberLoss
  • torch.nn.MSELoss
  • torch.nn.SmoothL1Loss
  • torch.nn.L1Loss

If a user passes one of these directly to a pytorch-forecasting model, it fails because these losses:

  1. Don't implement update() / compute() required by LightningMetric
  2. Don't handle the (target, weight) tuple that pytorch-forecasting
    passes internally (see MultiHorizonMetric.update() in base_metrics.py)
  3. Don't implement to_prediction() which models call at inference time
  4. Can't be composed with MultiLoss or CompositeMetric
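Point 2 is the most common failure in practice. A pure-Python illustration, with no torch dependency (`bare_mse` is an illustrative stand-in, not a real API): pytorch-forecasting hands metrics a (target, weight) tuple, which a plain loss callable cannot digest.

```python
# Illustrative stand-in for a bare loss callable (no torch required)
def bare_mse(y_pred, target):
    return sum((p - t) ** 2 for p, t in zip(y_pred, target)) / len(y_pred)

y_pred = [1.0, 2.0]
y_actual = ([1.5, 2.5], None)  # (target, weight) as passed internally

try:
    bare_mse(y_pred, y_actual)  # treats the tuple itself as the target
    ok = True
except TypeError:
    ok = False  # float minus list raises TypeError

# a wrapper must unpack the tuple first
target, _ = y_actual
loss = bare_mse(y_pred, target)
print(ok, loss)  # → False 0.25
```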

Proposed Solution

Add NNLossWrapper in pytorch_forecasting/metrics/base_metrics.py as a sibling
to the existing TorchMetricWrapper, following the same pattern:

class NNLossWrapper(Metric):
    """
    Wrap a standard PyTorch nn.Module loss for use with pytorch-forecasting.

    Example
    -------
    >>> loss = NNLossWrapper(nn.HuberLoss(delta=1.5))
    >>> combined = NNLossWrapper(nn.HuberLoss()) + MAE()
    """

    def __init__(self, loss_fn: torch.nn.Module, **kwargs):
        if not isinstance(loss_fn, torch.nn.Module):
            raise TypeError(f"loss_fn must be an nn.Module, got {type(loss_fn)}")
        super().__init__(**kwargs)
        self.loss_fn = loss_fn

    def update(self, y_pred: torch.Tensor, y_actual) -> None:
        # pytorch-forecasting passes targets as a (target, weight) tuple;
        # unpack the target, but leave PackedSequence inputs intact
        if isinstance(y_actual, (list, tuple)) and not isinstance(
            y_actual, rnn.PackedSequence
        ):
            target, _ = y_actual
        else:
            target = y_actual
        y_pred = self.to_prediction(y_pred)
        self._loss = self.loss_fn(y_pred, target)

    def compute(self) -> torch.Tensor:
        return self._loss

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        # reduce a (batch, time, n_outputs) network output to a point
        # prediction by selecting the first output dimension
        if y_pred.ndim == 3:
            y_pred = y_pred[..., 0]
        return y_pred

    def __repr__(self):
        return f"NNLossWrapper({repr(self.loss_fn)})"
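To make the intended control flow concrete without requiring torch or pytorch-forecasting, here is a dependency-free sketch of the same wrapping logic (MiniLossWrapper and mean_squared_error are illustrative names, not the proposed API, and the PackedSequence branch is omitted):

```python
# Dependency-free sketch of the wrapper pattern proposed above
class MiniLossWrapper:
    def __init__(self, loss_fn):
        if not callable(loss_fn):
            raise TypeError(f"loss_fn must be callable, got {type(loss_fn)}")
        self.loss_fn = loss_fn

    def update(self, y_pred, y_actual):
        # unpack the (target, weight) tuple that the framework passes
        if isinstance(y_actual, (list, tuple)):
            target, _ = y_actual
        else:
            target = y_actual
        self._loss = self.loss_fn(y_pred, target)

    def compute(self):
        return self._loss


def mean_squared_error(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


wrapper = MiniLossWrapper(mean_squared_error)
wrapper.update([1.0, 2.0], ([1.0, 3.0], None))  # tuple target handled
print(wrapper.compute())  # → 0.5
```

Note that, as in the proposed code, compute() returns the loss of the most recent batch rather than an accumulated average; that matches how MultiHorizonMetric-style losses are consumed per step.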

Files to Change

  • pytorch_forecasting/metrics/base_metrics.py — add NNLossWrapper
  • pytorch_forecasting/metrics/__init__.py — export NNLossWrapper

Related

  • 2026 Roadmap #1993 — explicitly lists "Add adapters for nn losses"
  • Existing TorchMetricWrapper in base_metrics.py — the parallel
    class this follows as a pattern

Github: @vinitjain2005
