Description
Describe the bug
While reviewing TslibDataModule, I noticed that setup() contains a guard on line 723 that is meant to avoid re-creating the training and validation datasets if they have already been built. The intent is clear — this kind of guard is standard practice to prevent redundant, expensive work when setup() is called more than once (e.g., by a PyTorch Lightning Trainer internally).
However, the guard is checking the wrong attribute names and never actually triggers.
Relevant code (`_tslib_data_module.py`, lines 722–725):

```python
if stage == "fit" or stage is None:
    if not hasattr(self, "_train_dataset") or not hasattr(self, "_val_dataset"):
        self._train_windows = self._create_windows(self._train_indices)
        ...
```

The attributes are stored on the instance as `self.train_dataset` and `self.val_dataset` (no leading underscore), but the guard checks for `self._train_dataset` and `self._val_dataset` (with a leading underscore). Because those private names are never assigned anywhere, `hasattr` always returns `False`, so the entire dataset setup, including `_create_windows()`, large `Dataset` construction, and random index shuffling, runs unconditionally on every `setup()` call.
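The failure mode can be shown in isolation. Below is a minimal, self-contained sketch (a toy class mirroring the names in this report, not the actual `TslibDataModule`): the guard tests `_train_dataset`, but the attribute is only ever assigned as `train_dataset`, so the expensive build path runs on every call.

```python
class BuggyModule:
    """Toy stand-in for the data module; build_count stands in for the
    expensive _create_windows() / Dataset-construction work."""

    def __init__(self):
        self.build_count = 0

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            # Guard never fires: "_train_dataset" is never assigned anywhere,
            # so hasattr() is always False and the branch always runs.
            if not hasattr(self, "_train_dataset") or not hasattr(self, "_val_dataset"):
                self.build_count += 1
                self.train_dataset = ["window"]  # stored WITHOUT the underscore
                self.val_dataset = ["window"]


dm = BuggyModule()
dm.setup("fit")
dm.setup("fit")
print(dm.build_count)  # 2 -- the datasets were rebuilt on the second call
```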
To Reproduce

```python
dm.setup("fit")
dm.setup("fit")  # Should skip, but rebuilds everything from scratch

# Confirm the guard never fires:
print(hasattr(dm, "_train_dataset"))  # Always prints False
print(hasattr(dm, "train_dataset"))   # Prints True after setup
```

Expected behavior
Calling setup("fit") a second time should detect the already-existing datasets and return early, both for performance and for determinism. The current behavior also introduces a reproducibility issue: since torch.randperm() is called every time, repeated setup() invocations silently produce different train/val/test splits, which can cause confusing results during training.
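The determinism point can be illustrated without torch (a sketch using the standard library's `random.shuffle` as a stand-in for the `torch.randperm()` call; `make_split` is a hypothetical helper, not code from the module): any unseeded shuffle that reruns on every `setup()` call produces a different split each time, so a guard that never fires silently changes the train/val membership between invocations.

```python
import random

def make_split(n=10):
    # Stand-in for the randperm-based index shuffle in setup():
    # shuffle all indices, then cut into train and validation halves.
    indices = list(range(n))
    random.shuffle(indices)
    return indices[: n // 2], indices[n // 2:]

first_train, first_val = make_split()
second_train, second_val = make_split()
# Both calls yield a valid split of the same indices, but the two splits
# will usually differ, so metrics computed after a second setup("fit")
# no longer refer to the same validation data.
```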
Proposed fix
The fix is a single-line change: correct the attribute names in the `hasattr` guard.

```python
# Before (line 723, buggy):
if not hasattr(self, "_train_dataset") or not hasattr(self, "_val_dataset"):

# After (fixed):
if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"):
```

Location

`pytorch_forecasting/data/_tslib_data_module.py`, line 723
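Applied to the toy sketch from above, the corrected guard makes `setup()` idempotent (again a hypothetical stand-in class, not the real module):

```python
class FixedModule:
    def __init__(self):
        self.build_count = 0

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            # Corrected guard: checks the names that are actually assigned
            # below, so the second call finds them and skips the rebuild.
            if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"):
                self.build_count += 1
                self.train_dataset = ["window"]
                self.val_dataset = ["window"]


dm = FixedModule()
dm.setup("fit")
dm.setup("fit")  # detects existing datasets, returns without rebuilding
print(dm.build_count)  # 1
```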