Merged
olmo/optim.py (33 changes: 24 additions & 9 deletions)

@@ -39,7 +39,10 @@ def _clean_param_name(self, name: str) -> str:
 
     @torch.no_grad()
     def clip_grads_and_collect_metrics(
-        self, global_step: int, collect_param_metrics: bool = True
+        self,
+        global_step: int,
+        collect_param_metrics: bool = True,
+        process_group: Optional[dist.ProcessGroup] = None,
     ) -> Dict[str, torch.Tensor]:
         """
         Clips gradients for every group that has the field `max_grad_norm`.
@@ -69,6 +72,10 @@ def clip_grads_and_collect_metrics(
         per_param_avg_metric_names: List[str] = []
         per_param_norm_metric_names: List[str] = []
 
+        dst_rank = 0
+        if process_group is not None:
+            dst_rank = dist.get_global_rank(process_group, 0)
+
         # Collect metrics locally.
         for group in self.param_groups:
             if is_distributed():
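
A note on the dst_rank translation above: torch.distributed collectives take their destination as a global rank even when a group is passed, so group-local rank 0 must be mapped through dist.get_global_rank. A minimal sketch, assuming a hypothetical 8-rank job whose ranks 4-7 form one sharding group:

import torch.distributed as dist

# Hypothetical split: global ranks 4-7 form one sharding group.
group = dist.new_group(ranks=[4, 5, 6, 7])

# Group-local rank 0 is global rank 4. Passing dst=0 to dist.reduce with
# this group would be wrong, since global rank 0 is not a member of it.
dst = dist.get_global_rank(group, 0)
assert dst == 4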
@@ -144,12 +151,12 @@ def is_grad_norm_metric(metric_name: str) -> bool:
             # Reduce mins.
             if per_param_min_metrics:
                 all_mins = torch.cat(per_param_min_metrics).to(device)
-                dist.reduce(all_mins, 0, op=dist.ReduceOp.MIN)
+                dist.reduce(all_mins, dst_rank, op=dist.ReduceOp.MIN, group=process_group)
                 per_param_min_metrics = all_mins.split(1)
             # Reduce maxs.
             if per_param_max_metrics:
                 all_maxs = torch.cat(per_param_max_metrics).to(device)
-                dist.reduce(all_maxs, 0, op=dist.ReduceOp.MAX)
+                dist.reduce(all_maxs, dst_rank, op=dist.ReduceOp.MAX, group=process_group)
                 per_param_max_metrics = all_maxs.split(1)
             # Reduce sums or just norms.
             all_norms = torch.cat(per_param_norm_metrics).to(device) ** 2.0
@@ -159,13 +166,13 @@ def is_grad_norm_metric(metric_name: str) -> bool:
                 all_sums_norms_numels = torch.cat(
                     [all_sums.unsqueeze(0), all_norms.unsqueeze(0), all_numels.unsqueeze(0)], dim=0
                 )
-                dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM)
+                dist.all_reduce(all_sums_norms_numels, op=dist.ReduceOp.SUM, group=process_group)
                 all_sums, all_norms, all_numels = all_sums_norms_numels.split(1)
                 # Get averages.
                 # NOTE: could get infs for non-rank0 processes but that's okay.
                 per_param_avg_metrics = (all_sums / all_numels).squeeze(0).split(1)
             else:
-                dist.all_reduce(all_norms, op=dist.ReduceOp.SUM)
+                dist.all_reduce(all_norms, op=dist.ReduceOp.SUM, group=process_group)
                 grad_norm_metric_mask = torch.tensor(
                     [float(is_grad_norm_metric(n)) for n in per_param_norm_metric_names], device=all_norms.device
                 )
@@ -325,8 +332,10 @@ def _do_global_fixed_clipping(
             p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))
         return num_grads_clipped
 
-    def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
-        del module
+    def get_post_step_metrics(
+        self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
+    ) -> Dict[str, torch.Tensor]:
+        del module, process_group
         return {}
 
     def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
@@ -356,7 +365,9 @@ def __init__(
         self._update_total_norm: Optional[torch.Tensor] = None
         self._signed_update_total_norm: Optional[torch.Tensor] = None
 
-    def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
+    def get_post_step_metrics(
+        self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
+    ) -> Dict[str, torch.Tensor]:
         update_total_dot_prod = self._update_total_dot_prod
         update_total_norm = self._update_total_norm
         signed_update_total_norm = self._signed_update_total_norm
@@ -370,7 +381,11 @@ def get_post_step_metrics(self, module: nn.Module) -> Dict[str, torch.Tensor]:
             # Reduce all together to avoid multiple communication calls.
             all_together = torch.stack([update_total_dot_prod, update_total_norm, signed_update_total_norm])
             # Only need the final result on rank0, since that's where we log from.
-            dist.reduce(all_together, 0)
+            dist.reduce(
+                all_together,
+                0 if process_group is None else dist.get_global_rank(process_group, 0),
+                group=process_group,
+            )
             update_total_dot_prod, update_total_norm, signed_update_total_norm = all_together
             update_total_norm = update_total_norm**0.5
             signed_update_total_norm = signed_update_total_norm**0.5
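
The stack-then-reduce pattern above batches three scalars into a single collective, keeping per-step communication overhead down. A minimal sketch of the same trick, assuming an initialized default process group:

import torch
import torch.distributed as dist

a = torch.tensor(1.0, device="cuda")
b = torch.tensor(2.0, device="cuda")
c = torch.tensor(3.0, device="cuda")

packed = torch.stack([a, b, c])
dist.reduce(packed, 0)  # one collective instead of three
a, b, c = packed        # unpack the 0-dim slices after the reduction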
olmo/train.py (10 changes: 8 additions & 2 deletions)

@@ -704,7 +704,11 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
         # Clip gradient norms and collect param/gradient/optim metrics.
         should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
         optim_metrics = self.optim.clip_grads_and_collect_metrics(
-            self.global_step, collect_param_metrics=should_log_optim_metrics_this_step
+            self.global_step,
+            collect_param_metrics=should_log_optim_metrics_this_step,
+            # Passing this process group here ensures metrics are reduced correctly when we're
+            # using HYBRID sharding.
+            process_group=self.fsdp_model.process_group,
         )
 
         # Adjust the learning rate.
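
For context on where that process group comes from (the setup below is a sketch, not part of this PR): with ShardingStrategy.HYBRID_SHARD, FSDP shards parameters within each device-mesh row and replicates across rows, and the wrapper's process_group attribute refers to the intra-replica sharding group, which is exactly the group the metric reductions need.

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

mesh = init_device_mesh("cuda", (2, 4))  # 2 replicas x 4 shards on 8 GPUs
fsdp_model = FSDP(
    nn.Linear(1024, 1024),
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_mesh=mesh,
)
# fsdp_model.process_group spans the 4 ranks sharding this replica's copy.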
@@ -742,7 +746,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
 
         # Maybe collect post-step optimizer-specific metrics.
         if should_log_optim_metrics_this_step:
-            optim_metrics = self.optim.get_post_step_metrics(self.fsdp_model)
+            optim_metrics = self.optim.get_post_step_metrics(
+                self.fsdp_model, process_group=self.fsdp_model.process_group
+            )
             for key, value in optim_metrics.items():
                 metrics[f"optim/{key}"] = value.item()
 
scripts/train.py (2 changes: 1 addition & 1 deletion)

@@ -154,7 +154,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must be a positive integer")
 
     num_nodes = get_world_size() // get_local_world_size()
-    if num_nodes % num_model_replicas != 0:
+    if num_nodes > 1 and num_nodes % num_model_replicas != 0:
         raise OLMoConfigurationError("fsdp.hybrid_sharding_num_model_replicas must divide number of nodes")
 
     device_mesh = init_device_mesh("cuda", (num_model_replicas, get_world_size() // num_model_replicas))
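
A worked example of the mesh arithmetic and what the relaxed check permits (numbers are hypothetical): on 2 nodes of 8 GPUs with num_model_replicas=2 the mesh is (2, 8), so each replica shards across one node. On a single 8-GPU node, num_nodes == 1 and the divisibility check is now skipped, allowing e.g. two replicas of 4-way shards for local testing, a case the old check rejected.

# Single-node case that the old check rejected:
world_size, local_world_size = 8, 8
num_model_replicas = 2
num_nodes = world_size // local_world_size  # 1
assert not (num_nodes > 1 and num_nodes % num_model_replicas != 0)
mesh_shape = (num_model_replicas, world_size // num_model_replicas)  # (2, 4)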