Add option to skip optim steps for 0 grad params#636

Merged
epwalsh merged 13 commits into main from epwalsh/selective-wd on Jul 9, 2024
Conversation

@epwalsh (Member) commented Jun 28, 2024

#605 should be reviewed and merged first.

This PR adds the ability to skip optimizer updates for the slices of parameters that have zero gradients, such as the embedding rows for tokens not present in the current batch (assuming no weight tying).
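The idea can be sketched as follows. This is a minimal illustration using plain SGD with decoupled weight decay, not the PR's actual AdamW implementation; `masked_update_sketch` is a hypothetical helper, not code from this PR:

```python
import torch


def masked_update_sketch(param: torch.Tensor, grad: torch.Tensor, lr: float, wd: float) -> None:
    """Sketch of a selective update: elements of `param` whose gradient is
    exactly zero (e.g. embedding rows for tokens absent from the batch) are
    left untouched by both the gradient step and the weight decay."""
    mask = grad != 0  # True only where the parameter actually received gradient
    # Apply decoupled weight decay only where the mask is set.
    param.mul_(1 - mask.float() * lr * wd)
    # Gradient step (grad is already zero everywhere the mask is False).
    param.add_(grad, alpha=-lr)


emb = torch.ones(4, 2)
g = torch.zeros(4, 2)
g[1] = 0.5  # only token 1 appeared in the batch
masked_update_sketch(emb, g, lr=0.1, wd=0.1)
# Rows 0, 2, 3 are unchanged; only row 1 was decayed and stepped.
```

Without the mask, weight decay would shrink every embedding row on every step, including rows for tokens the batch never touched.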

epwalsh and others added 9 commits May 28, 2024 13:53
- Adds configuration field `optimizer.record_update_metrics`, which
  defaults to `False`, but when set to `True` will trigger AdamW to
  collect the step size norm and absolute max for each parameter.
- Changes the behavior of the Lion optimizer to only record the update cosine
  similarity when `optimizer.record_update_metrics` is `True` in order to be
  consistent with the API.
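The per-parameter step-size metrics described above amount to a norm and an absolute max of each update tensor. A minimal sketch (the function name `update_metrics` is hypothetical; the real collection happens inside the PR's AdamW step):

```python
import torch


def update_metrics(update: torch.Tensor) -> dict:
    """Step-size statistics for one parameter's update tensor."""
    return {
        "update_norm": torch.linalg.vector_norm(update).item(),
        "update_abs_max": update.abs().max().item(),
    }


stats = update_metrics(torch.tensor([3.0, -4.0]))
# stats["update_norm"] == 5.0, stats["update_abs_max"] == 4.0
```

Gating this behind `optimizer.record_update_metrics` makes sense because computing these statistics for every parameter on every step has a real cost at scale.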
olmo/optim.py (outdated):

```python
# Perform step weight decay
mask: Optional[torch.Tensor] = None
if self._selective_updates:
    mask = grad != 0
```
Collaborator commented:
Thought: you could instead do `mask = grad != 0 if self._selective_updates else 1`, and assume the mask is always present in subsequent logic.
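The suggestion works because multiplying by the Python scalar `1` is a no-op that broadcasts against any tensor, so both modes flow through the same expression. A small sketch of the trick:

```python
import torch

grad = torch.tensor([0.0, 0.5, 0.0])
update = torch.tensor([0.1, 0.1, 0.1])

# Selective mode: a bool mask zeroes the update wherever no gradient flowed.
mask = grad != 0
masked = mask * update    # tensor([0.0, 0.1, 0.0])

# Non-selective mode: the scalar 1 leaves the update untouched,
# so downstream code never needs an `if mask is not None` branch.
mask = 1
unmasked = mask * update  # tensor([0.1, 0.1, 0.1])
```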

epwalsh (Member, author) replied:

good call: 1024122

olmo/optim.py (outdated):

```python
super().__init__(params, defaults)
for group in self.param_groups:
    group["initial_lr"] = group["lr"]
self._selective_updates = selective_updates
```
Collaborator commented:

Nit: like in the other PR, this could be moved into the parent class.

epwalsh (Member, author) replied:

done: e597e5f

```python
super().__init__(*args, **kwargs)
self._record_step_size = record_update_metrics

# Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
```
Collaborator commented:

Any reason we don't call `Optimizer.__init__` too? Because multiple inheritance is complicated?

epwalsh (Member, author) replied:

Yeah, this gets messy because our `Optimizer.__init__()` also calls PyTorch's `Optimizer.__init__()`, which would then get called twice here unless we skipped `torch.optim.AdamW.__init__()` (via `super().__init__()`) entirely; but then we'd have to copy over all the other setup that happens inside `torch.optim.AdamW.__init__()`.
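The double-initialization hazard described above can be demonstrated with stand-in classes (all names here are hypothetical stand-ins for `torch.optim.Optimizer`, `torch.optim.AdamW`, and the project's own `Optimizer`, not the real classes):

```python
calls = []


class TorchOptimizer:
    def __init__(self):
        calls.append("torch.Optimizer")


class OurOptimizer(TorchOptimizer):
    def __init__(self):
        TorchOptimizer.__init__(self)  # our base also initializes torch's Optimizer
        calls.append("our.Optimizer")


class TorchAdamW(TorchOptimizer):
    def __init__(self):
        TorchOptimizer.__init__(self)
        calls.append("torch.AdamW")


class AdamW(TorchAdamW, OurOptimizer):
    def __init__(self):
        # Calling both parents' __init__ explicitly runs
        # TorchOptimizer.__init__ twice, once through each path.
        TorchAdamW.__init__(self)
        OurOptimizer.__init__(self)


AdamW()
# calls == ["torch.Optimizer", "torch.AdamW", "torch.Optimizer", "our.Optimizer"]
```

Hence the PR's choice: call only `torch.optim.AdamW.__init__()` and set the attributes our base class would have set by hand, rather than untangle the shared-ancestor initialization.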

@epwalsh epwalsh merged commit bc60b8a into main Jul 9, 2024
@epwalsh epwalsh deleted the epwalsh/selective-wd branch July 9, 2024 17:57