Problem
Currently, `base_metrics.py` provides `TorchMetricWrapper` to wrap
torchmetrics (Lightning) metrics. However, there is no equivalent for
standard PyTorch `nn.Module` losses such as:
- `torch.nn.HuberLoss`
- `torch.nn.MSELoss`
- `torch.nn.SmoothL1Loss`
- `torch.nn.L1Loss`
If a user tries to pass these directly to a pytorch-forecasting model, it fails because they:
- don't implement `update()`/`compute()`, as required by `LightningMetric`
- don't handle the `(target, weight)` tuple that pytorch-forecasting passes internally (see `MultiHorizonMetric.update()` in `base_metrics.py`)
- don't implement `to_prediction()`, which models call at inference time
- can't be composed with `MultiLoss` or `CompositeMetric`
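The first incompatibility is easy to demonstrate: a plain `nn.Module` loss simply does not expose the metric interface that pytorch-forecasting's training loop calls (a minimal check, requiring only `torch`):

```python
import torch

# A stock nn.Module loss lacks the LightningMetric-style interface
# that pytorch-forecasting invokes during training and inference:
loss = torch.nn.MSELoss()
for required in ("update", "compute", "to_prediction"):
    print(required, hasattr(loss, required))  # all False
```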
Proposed Solution
Add NNLossWrapper in pytorch_forecasting/metrics/base_metrics.py as a sibling to the existing
TorchMetricWrapper, following the same
pattern:
```python
import torch
from torch.nn.utils import rnn

from pytorch_forecasting.metrics import Metric


class NNLossWrapper(Metric):
    """
    Wrap a standard PyTorch nn.Module loss for use with pytorch-forecasting.

    Example
    -------
    >>> loss = NNLossWrapper(nn.HuberLoss(delta=1.5))
    >>> combined = NNLossWrapper(nn.HuberLoss()) + MAE()
    """

    def __init__(self, loss_fn: torch.nn.Module, **kwargs):
        if not isinstance(loss_fn, torch.nn.Module):
            raise TypeError(f"loss_fn must be an nn.Module, got {type(loss_fn)}")
        super().__init__(**kwargs)
        self.loss_fn = loss_fn

    def update(self, y_pred: torch.Tensor, y_actual) -> None:
        # pytorch-forecasting passes (target, weight) tuples internally;
        # unpack the target before handing it to the wrapped loss
        if isinstance(y_actual, (list, tuple)) and not isinstance(
            y_actual, rnn.PackedSequence
        ):
            target, _ = y_actual
        else:
            target = y_actual
        y_pred = self.to_prediction(y_pred)
        self._loss = self.loss_fn(y_pred, target)

    def compute(self) -> torch.Tensor:
        return self._loss

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        # point forecasts arrive as (batch, time, 1); drop the last axis
        if y_pred.ndim == 3:
            y_pred = y_pred[..., 0]
        return y_pred

    def __repr__(self):
        return f"NNLossWrapper({repr(self.loss_fn)})"
```

Files to Change
- `pytorch_forecasting/metrics/base_metrics.py` — add `NNLossWrapper`
- `pytorch_forecasting/metrics/__init__.py` — export `NNLossWrapper`
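The tensor handling the wrapper performs can be sketched standalone, without pytorch-forecasting installed. The helper names `unpack_target` and `to_prediction` below are illustrative, mirroring the `update()` and `to_prediction()` logic in the proposal:

```python
import torch

def unpack_target(y_actual):
    # pytorch-forecasting hands metrics a (target, weight) tuple;
    # a bare nn.Module loss would choke on it, so unpack first
    if isinstance(y_actual, (list, tuple)):
        target, _ = y_actual
        return target
    return y_actual

def to_prediction(y_pred):
    # point forecasts arrive as (batch, time, 1); drop the last axis
    if y_pred.ndim == 3:
        y_pred = y_pred[..., 0]
    return y_pred

y_pred = torch.randn(4, 10, 1)   # model output: batch x horizon x 1
target = torch.randn(4, 10)      # ground truth: batch x horizon
value = torch.nn.HuberLoss()(to_prediction(y_pred), unpack_target((target, None)))
print(value)  # scalar loss tensor
```

With these two adapters in place, any reduction-style `nn.Module` loss produces a scalar that the training loop can backpropagate through.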
Related
- 2026 Roadmap #1993 — explicitly lists "Add adapters for nn losses"
- Existing `TorchMetricWrapper` in `base_metrics.py` — the parallel class this follows as a pattern
Github: @vinitjain2005