TslibDataModule.setup() re-entrancy guard checks wrong attribute name, causing full dataset recreation on every call #2218

@lomesh2312

Description

Describe the bug

While reviewing TslibDataModule, I noticed that setup() contains a guard on line 723 that is meant to avoid re-creating the training and validation datasets if they have already been built. The intent is clear — this kind of guard is standard practice to prevent redundant, expensive work when setup() is called more than once (e.g., by a PyTorch Lightning Trainer internally).

However, the guard is checking the wrong attribute names and never actually triggers.

Relevant code (_tslib_data_module.py, line 722–725):

if stage == "fit" or stage is None:
    if not hasattr(self, "_train_dataset") or not hasattr(self, "_val_dataset"):
        self._train_windows = self._create_windows(self._train_indices)
        ...

The attributes are stored on the instance as self.train_dataset and self.val_dataset (no leading underscore), but the guard checks for self._train_dataset and self._val_dataset (with a leading underscore). Because those private names are never assigned anywhere, hasattr always returns False, so the entire dataset setup — including _create_windows(), large Dataset construction, and random index shuffling — runs unconditionally on every setup() call.
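The failure mode can be reproduced in isolation with a minimal stand-in class (hypothetical names, mirroring the pattern in `setup()` rather than copied from the codebase):

```python
class DataModuleSketch:
    """Minimal stand-in for TslibDataModule illustrating the guard bug."""

    def __init__(self):
        self.setup_calls = 0

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            # The guard probes "_train_dataset" / "_val_dataset" (underscored),
            # but the attributes below are assigned WITHOUT the underscore,
            # so hasattr() always returns False and the body always runs.
            if not hasattr(self, "_train_dataset") or not hasattr(self, "_val_dataset"):
                self.setup_calls += 1  # stands in for the expensive dataset build
                self.train_dataset = object()
                self.val_dataset = object()


dm = DataModuleSketch()
dm.setup("fit")
dm.setup("fit")
print(dm.setup_calls)  # 2 -- the guard never skipped the rebuild
```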

To Reproduce

dm.setup("fit")
dm.setup("fit")  # Should skip, but rebuilds everything from scratch

# Confirm the guard never fires:
print(hasattr(dm, "_train_dataset"))  # Always prints False
print(hasattr(dm, "train_dataset"))   # Prints True after setup

Expected behavior

Calling setup("fit") a second time should detect the already-existing datasets and return early, both for performance and for determinism. The current behavior also introduces a reproducibility issue: since torch.randperm() is called every time, repeated setup() invocations silently produce different train/val/test splits, which can cause confusing results during training.
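The nondeterminism is easy to see with a stdlib analogue of the unseeded `torch.randperm()` call: each invocation draws a fresh permutation, so a rebuilt split will generally differ from the first one (sketch only; the real code permutes window indices):

```python
import random

# Analogue of torch.randperm(n) in the dataset setup: two unseeded calls
# produce independent permutations, so re-running setup() reassigns
# samples to train/val/test differently each time.
n = 10_000
perm_first = random.sample(range(n), k=n)
perm_second = random.sample(range(n), k=n)
print(perm_first == perm_second)  # almost certainly False
```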

Proposed fix

The fix is a single-line change — correct the attribute names in the hasattr guard:

# Before (line 723, buggy):
if not hasattr(self, "_train_dataset") or not hasattr(self, "_val_dataset"):

# After (fixed):
if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"):
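A slightly more defensive variant (a sketch with hypothetical names, not code from the repository) initializes the attributes to `None` in `__init__` and checks state explicitly, which avoids `hasattr` typo hazards like this one altogether:

```python
class GuardedModule:
    """Sketch of an explicit-state guard instead of hasattr() probing."""

    def __init__(self):
        # Declared up front, so the guard is an explicit state check
        # rather than a test for attribute existence.
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            if self.train_dataset is None or self.val_dataset is None:
                # Expensive dataset construction would happen here.
                self.train_dataset = "built"
                self.val_dataset = "built"
                return True  # work was done
        return False  # skipped: datasets already exist


dm = GuardedModule()
print(dm.setup("fit"))  # True  -- first call builds
print(dm.setup("fit"))  # False -- second call skips
```

A misspelled attribute in this style raises `AttributeError` immediately instead of silently disabling the guard.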

Location

pytorch_forecasting/data/_tslib_data_module.py, line 723

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    Status

    Needs triage & validation

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests