Add option to skip optim steps for 0 grad params (#636)
Conversation
- Adds configuration field `optimizer.record_update_metrics`, which defaults to `False`; when set to `True`, it triggers AdamW to collect the step size norm and absolute max for each parameter.
- Changes the Lion optimizer to record the update cosine similarity only when `optimizer.record_update_metrics` is `True`, for consistency with the API.
olmo/optim.py
Outdated
```python
# Perform step weight decay
mask: Optional[torch.Tensor] = None
if self._selective_updates:
    mask = grad != 0
```
thought: you could instead do `mask = grad != 0 if self._selective_updates else 1`, and assume the mask is always present in subsequent logic.
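The suggestion above works because a scalar `1` broadcasts the same way a boolean mask tensor does, so downstream code never has to branch on `mask is None`. A minimal sketch (the helper name and signature are illustrative, not the actual OLMo code):

```python
import torch


def masked_weight_decay(p, grad, lr, weight_decay, selective_updates):
    # Scalar 1 instead of None: with selective updates off, the math below
    # reduces to plain decoupled weight decay; with it on, the boolean mask
    # restricts the decay to entries whose gradient is nonzero.
    mask = grad != 0 if selective_updates else 1
    p.mul_(1 - mask * (lr * weight_decay))
    return p
```

Broadcasting handles both cases, so the subsequent update logic can be written once without `None` checks.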
olmo/optim.py
Outdated
```python
super().__init__(params, defaults)
for group in self.param_groups:
    group["initial_lr"] = group["lr"]
self._selective_updates = selective_updates
```
nit: Like in the other PR, this could be moved into the parent class
```python
super().__init__(*args, **kwargs)
self._record_step_size = record_update_metrics

# Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
```
Any reason we don't call `Optimizer.__init__` too? Because multiple inheritance is complicated?
yeah, this gets messy because our `Optimizer.__init__()` also calls PyTorch's `Optimizer.__init__()`, which would then get called twice here unless we skipped `torch.optim.AdamW.__init__()` (via `super().__init__()`), but then we'd have to copy over all the other code that happens within `torch.optim.AdamW.__init__()`.
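The double-init problem described here can be shown with a stripped-down diamond; the class names are stand-ins (`BaseOpt` for torch's `Optimizer`, `OurOpt` for the project's `Optimizer` subclass, `TorchAdamW` for `torch.optim.AdamW`), not the actual OLMo hierarchy:

```python
class BaseOpt:
    # Stand-in for torch.optim.Optimizer; counts how often it is initialized.
    init_calls = 0

    def __init__(self):
        BaseOpt.init_calls += 1


class OurOpt(BaseOpt):
    # Stand-in for the project's Optimizer subclass.
    def __init__(self):
        super().__init__()


class TorchAdamW(BaseOpt):
    # Stand-in for torch.optim.AdamW.
    def __init__(self):
        super().__init__()


class AdamWWithMetrics(TorchAdamW, OurOpt):
    def __init__(self):
        # Naively initializing both parents runs BaseOpt.__init__ twice,
        # which is exactly the mess the comment above describes.
        TorchAdamW.__init__(self)
        OurOpt.__init__(self)


opt = AdamWWithMetrics()
# BaseOpt.init_calls is now 2: the shared base got initialized twice.
```

Setting the needed attributes directly, as the diff does, sidesteps the diamond without duplicating `torch.optim.AdamW.__init__()`.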
#605 should be reviewed and merged first.
This PR adds the ability to skip optimizer updates for the parts of parameters that have 0 gradients, such as the embeddings for tokens not present in the current batch (assuming no weight tying).
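To illustrate why this matters: with decoupled weight decay and momentum, a parameter slice whose gradient is all zeros would still drift every step unless the update is masked. A hedged sketch of the idea (simplified AdamW-style first-moment update, not the actual OLMo implementation; names and defaults are illustrative):

```python
import torch


def adamw_like_step(p, grad, exp_avg, lr=1e-3, wd=0.01, beta1=0.9, selective=True):
    # Mask out entries with zero gradient, e.g. embedding rows for tokens
    # absent from the current batch (assuming no weight tying).
    mask = grad != 0 if selective else 1
    # Decoupled weight decay, applied only where the gradient is nonzero.
    p.mul_(1 - mask * (lr * wd))
    # Masked first-moment update: unmasked entries keep their state untouched.
    exp_avg.mul_(1 - mask * (1 - beta1)).add_(mask * grad, alpha=1 - beta1)
    # Masked parameter step.
    p.add_(exp_avg * mask, alpha=-lr)


p = torch.ones(2, 2)
grad = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
m = torch.zeros(2, 2)
adamw_like_step(p, grad, m)
# Row 0 (all-zero grad) is left bit-for-bit unchanged; row 1 takes a step.
```

Without the mask, row 0 would still shrink under weight decay every step even though it received no gradient signal.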