
Conversation

@mori360 mori360 (Contributor) commented Feb 12, 2025

Fix issue #809

The current `self._data.skip(self._sample_idx)` does not return the correct data for the c4 dataset.
Thus we switch to `next()` as a workaround until a proper fix lands.
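A minimal sketch of the workaround, where `resume_iter` and its arguments are illustrative names rather than the PR's actual code:

```python
# Illustrative sketch of the workaround: instead of dataset.skip(n), which
# can return the wrong samples for some streaming datasets, advance a
# fresh iterator by calling next() n times.
def resume_iter(dataset, sample_idx):
    it = iter(dataset)
    for _ in range(sample_idx):  # O(sample_idx): slower than skip()
        next(it)
    return it

it = resume_iter(range(10), 3)
print(next(it))  # -> 3
```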

Test plan:
We reproduce #809 by resuming from a checkpoint at step 500, then compare the loss curves under 3 conditions:

  1. the original curve, running from step 0 to 750
  2. the resumed curve, keeping `.skip()`
  3. the resumed curve, switched to `next()` with this PR's change
[Screenshot: loss curve comparison of the three conditions, 2025-02-12]

Warning
For the c4 dataset, if we resume from a large enough step, we call `next()` `self._sample_idx` times, so resuming from a checkpoint is much slower than using `.skip()`.

Next step:
Add unit tests:

  1. test that the state_dict matches between dcp.save/load and torch.save/load
  2. test the difference between `next()` and `.skip()`
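Test (2) could start from a stdlib-only sketch like the following, where `.skip()` is stood in for by `itertools.islice`. For a plain in-memory sequence the two resume strategies agree; the reported bug is that they disagree for the streaming c4 dataset:

```python
import itertools

data = list(range(10))
n = 4  # number of already-consumed samples

# resume via next(): consume n elements, then continue
via_next = iter(data)
for _ in range(n):
    next(via_next)

# resume via skip-style slicing
via_skip = itertools.islice(iter(data), n, None)

# for an in-memory sequence, both strategies yield the same tail
assert list(via_next) == list(via_skip) == data[n:]
```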

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 12, 2025
@mori360 mori360 marked this pull request as ready for review February 12, 2025 20:34
@mori360 mori360 requested review from fegin and tianyu-l February 12, 2025 20:34
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
    # map-style dataset fully consumed: nothing left to iterate
    return iter([])

# skip() fast-forwards past already-consumed samples when resuming
return iter(self._data.skip(self._sample_idx))
Contributor

I think we need to understand whether `skip` causes errors for both map-style and iterable datasets, or only in the newly added IterableDataset case.
If it's the latter, we should just revert #521 rather than universally use `next()` for both, because that would make the healthy case slow too.
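One possible shape for that split, sketched with a hypothetical `is_map_style` flag rather than the PR's actual code: keep the fast indexed path for map-style datasets and only pay the `next()` cost for iterable ones.

```python
import itertools

def make_resume_iter(data, sample_idx, is_map_style):
    if is_map_style:
        # map-style datasets are indexable, so skipping ahead stays cheap
        return (data[i] for i in range(sample_idx, len(data)))
    # iterable datasets: must actually consume sample_idx elements
    return itertools.islice(iter(data), sample_idx, None)

print(list(make_resume_iter([10, 11, 12, 13], 2, True)))         # -> [12, 13]
print(list(make_resume_iter(iter([10, 11, 12, 13]), 2, False)))  # -> [12, 13]
```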

Contributor

I would suggest that we land this PR first: a slower checkpoint resume is better than a silent accuracy failure, and this is blocking several accuracy verifications. At a minimum, we should make the default C4 dataset work for now.

@tianyu-l tianyu-l linked an issue Feb 13, 2025 that may be closed by this pull request
@tianyu-l tianyu-l (Contributor) left a comment

Stamp to unblock, but we should follow up with more robust tests.

@mori360 mori360 merged commit 0b0931c into pytorch:main Feb 13, 2025
6 checks passed
@mariosasko mariosasko mentioned this pull request Apr 9, 2025
tianyu-l pushed a commit that referenced this pull request May 16, 2025
This PR makes resuming dataset iteration from a checkpoint fast again.

This performance regression comes from
#838. In that PR, `.skip` is
removed for both map-style and iterable-style datasets for correctness
reasons. However, `.skip` works as expected for map-style datasets, so
the change can be reverted for that case. On the other hand, for
iterable-style datasets, calling `.skip` after `split_dataset_by_node`
splits the number of elements to skip **across the ranks** (e.g. calling
`.skip(10)` after `split_dataset_by_node(<rank>, 2)` effectively skips 5
(`10 // 2 = 5`) elements on each rank), which isn't what we want/expect,
so removing `.skip` was justified there. Still, we can make the whole
thing much faster using the [`state_dict`
API](https://huggingface.co/docs/datasets/v3.5.0/en/stream#save-a-dataset-checkpoint-and-resume-iteration)
for iterable-style datasets, which avoids re-iterating past shards/files
when resuming.
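The resume pattern can be illustrated with a stdlib-only stand-in; `StatefulSource` below is a hypothetical class mimicking the `state_dict()`/`load_state_dict()` shape of the Hugging Face `datasets` API, not the real implementation:

```python
# Hypothetical stand-in for an iterable dataset that supports fast resume:
# save a small iteration state and restore it directly, instead of
# re-consuming sample_idx elements with next().
class StatefulSource:
    def __init__(self, data):
        self.data = list(data)
        self.pos = 0

    def __iter__(self):
        while self.pos < len(self.data):
            item = self.data[self.pos]
            self.pos += 1
            yield item

    def state_dict(self):
        return {"pos": self.pos}

    def load_state_dict(self, state):
        self.pos = state["pos"]

src = StatefulSource(range(6))
it = iter(src)
next(it); next(it)             # consume two samples
state = src.state_dict()       # checkpoint: {"pos": 2}

resumed = StatefulSource(range(6))
resumed.load_state_dict(state)   # O(1) resume, no re-iteration
print(next(iter(resumed)))       # -> 2
```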


Development

Successfully merging this pull request may close these issues.

Loss metrics dramatically change after resuming from checkpoint
