use correct PG when collecting metrics with HYBRID shard #551

Merged: epwalsh merged 3 commits into main from epwalsh/hybrid-shard on Apr 19, 2024
Conversation

epwalsh (Member) commented Apr 18, 2024

Follow-up to #540. This fixes how we collect per-param optimizer metrics when using hybrid sharding. The process group we're now using is the same one FSDP uses during hybrid sharding when reducing grad norms, for example, so it should be the right one. See https://github.com/pytorch/pytorch/blob/cb17721899d4d6a55d66d4f7188e36c20a078231/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1149.
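For context, a minimal sketch of the intended wiring, assuming the trainer holds the FSDP-wrapped model and that the optimizer's metric collection accepts a process_group argument; the helper name and the commented call site below are illustrative, not the exact OLMo API.

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def metrics_process_group(fsdp_model: FSDP) -> dist.ProcessGroup:
    # FSDP exposes the process group it runs its own sharded collectives over
    # (e.g. when reducing grad norms). Under HYBRID_SHARD this should be the
    # sharding group, which is the group we want the per-param metric reduces
    # to use instead of the default world group.
    return fsdp_model.process_group

# Illustrative call site (the real OLMo signature may differ):
# optim_metrics = optimizer.clip_grads_and_collect_metrics(
#     global_step, process_group=metrics_process_group(fsdp_model)
# )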

@epwalsh epwalsh requested a review from 2015aroras April 18, 2024 17:37
olmo/optim.py (Outdated)

      if per_param_min_metrics:
          all_mins = torch.cat(per_param_min_metrics).to(device)
-         dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
+         dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN, group=process_group)
2015aroras (Collaborator) commented Apr 18, 2024
This was the original approach I went with, but I had a few concerns:

  1. The destination rank 0 refers to the global rank, so only the group containing global rank 0 will actually perform the reduce. The PyTorch code says it will warn for the other groups, though I vaguely remember my runs crashing instead; if it doesn't crash, then I guess it's not a real problem. You can get around this by resolving each group's rank 0 to a global rank with torch.distributed.get_global_rank(group, 0) (see the sketch after this list).
  2. Using multiple process groups on the same CUDA stream without some sort of synchronization can lead to deadlocks (https://pytorch.org/docs/stable/distributed.html#torch.distributed.new_group). I don't fully understand it myself. The way torch seems to get around this is that it puts the ops on different streams and (if dynamo compiling is off) has the streams wait on each other. For us, it may mean we either have to pass the process group to all dist.* calls or have to synchronize whenever we make distributed calls over different process groups.
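A minimal sketch of the workaround described in point (1), assuming a PyTorch version where torch.distributed.get_global_rank is available; the function name and argument shapes here are illustrative rather than the exact OLMo code.

import torch
import torch.distributed as dist

def reduce_mins_to_group_leader(per_param_min_metrics, device, process_group):
    # `dst` in dist.reduce is a *global* rank, so resolve the global rank of
    # rank 0 within this group instead of hard-coding 0; otherwise only the
    # group that contains global rank 0 performs a meaningful reduce.
    dst = dist.get_global_rank(process_group, 0)
    all_mins = torch.cat(per_param_min_metrics).to(device)
    dist.reduce(all_mins, dst, op=dist.ReduceOp.MIN, group=process_group)
    return all_mins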

epwalsh (Member, Author)
@2015aroras thanks, 6911bfb should fix point (1). I'll see what happens with (2).

epwalsh (Member, Author)
Seems to be working fine on a single-node test, at least.

epwalsh (Member, Author)
@2015aroras I'm going to merge and give this a try with the 70B. If I run into other issues, I'll rethink this strategy.

epwalsh merged commit 7be71cd into main on Apr 19, 2024
epwalsh deleted the epwalsh/hybrid-shard branch on Apr 19, 2024 at 15:56