use correct PG when collecting metrics with HYBRID shard #551
Merged
Conversation
2015aroras reviewed Apr 18, 2024
olmo/optim.py
Outdated
```diff
  if per_param_min_metrics:
      all_mins = torch.cat(per_param_min_metrics).to(device)
-     dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
+     dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN, group=process_group)
```
Collaborator
This was the original approach I went with, but I had a few concerns:

- Rank 0 refers to the global rank, so only the 'first' group will perform the reduce. The code says it will warn for the other groups, and I vaguely remember my runs crashing instead. If it doesn't crash then I guess it's not a real problem. You can get around this by getting each group's rank 0 using `dist.get_global_rank(group, 0)` (see the sketch below).
- Using multiple process groups on the same stream without some sort of synchronization can lead to deadlocks (https://pytorch.org/docs/stable/distributed.html#torch.distributed.new_group). I don't fully understand it myself. The way torch seems to get around this is that it puts the ops on different streams AND (if dynamo compiling is off) has the streams wait on each other. For us, it may be that we have to either pass the process group to all `dist.*` calls or synchronize when we make distributed calls over different process groups.
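For reference, a minimal sketch of what addressing point (1) could look like. The helper name and signature are illustrative rather than the repo's actual code, and it assumes a PyTorch version that exposes `dist.get_global_rank`:

```python
import torch
import torch.distributed as dist


def reduce_min_metrics(per_param_min_metrics, device, process_group=None):
    # Illustrative helper, not the PR's exact code.
    if not per_param_min_metrics:
        return None
    all_mins = torch.cat(per_param_min_metrics).to(device)
    if process_group is not None:
        # dist.reduce's `dst` is a *global* rank, so translate "rank 0 of this
        # group" into its global rank; a literal 0 is only valid for whichever
        # group happens to contain global rank 0.
        dst = dist.get_global_rank(process_group, 0)
    else:
        dst = 0
    dist.reduce(all_mins, dst, op=dist.ReduceOp.MIN, group=process_group)
    return all_mins
```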
Member
Author
@2015aroras thanks, 6911bfb should fix point (1). I'll see what happens now about (2).
Member
Author
Seems to be working fine on a single node test at least.
Member
Author
@2015aroras I'm going to merge and give this a try with the 70B. If I run into other issues I'll rethink this strategy.
Followup to #540. Fixes how we collect per-param optim metrics when using hybrid sharding. The process group we're using is the same process group that FSDP uses during hybrid sharding when reducing the grad norms, for example, so it should be the right one. See https://github.com/pytorch/pytorch/blob/cb17721899d4d6a55d66d4f7188e36c20a078231/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1149.
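As a rough illustration of the calling side, the group might be plumbed through like this; the wrapper function, the `global_step` argument, and the exact optimizer method and FSDP `process_group` attribute usage are assumptions for the sketch, not necessarily the repo's exact wiring:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def collect_optim_metrics(optimizer, fsdp_model: FSDP, global_step: int):
    # Reuse the FSDP module's sharding process group so the per-param metric
    # reductions run over the same group FSDP uses when reducing grad norms
    # (see the linked fully_sharded_data_parallel.py).
    pg = fsdp_model.process_group
    return optimizer.clip_grads_and_collect_metrics(global_step, process_group=pg)
```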