Skip to content

Commit 104fc36

Browse files
authored
Perf: use fused Adam optimizer (#4463)
This PR sets the Adam optimizer to use the `fused=True` parameter. For the profiling result shown below, this modification brings an 2.75x improvement on optimizer update (22ms vs. 8ms) and ~3% improvement for total speed up (922ms vs. 892ms). The benchmark case is training a DPA-2 Q3 release model. Please note that the absolute time may differs between steps. <details><summary>Before</summary> <p> ![image](https://github.com/user-attachments/assets/d6b05a1d-6e6c-478d-921f-c497718bc551) </p> </details> <details><summary>After</summary> <p> ![image](https://github.com/user-attachments/assets/b216b919-094c-441f-96a7-146e1e3db483) </p> </details> [Ref](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html): > The foreach and fused implementations are typically faster than the for-loop, single-tensor implementation, with **fused being theoretically fastest** with both vertical and horizontal fusion. As such, if the user has not specified either flag (i.e., when foreach = fused = None), we will attempt defaulting to the foreach implementation when the tensors are all on CUDA. Why not fused? Since the fused implementation is relatively new, we want to give it sufficient bake-in time. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved optimizer performance during training by modifying the initialization of the Adam optimizer. - **Documentation** - Updated method signature for clarity in the `Trainer` class. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent e8167ce commit 104fc36

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

deepmd/pt/train/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def warm_up_linear(step, warmup_steps):
579579
# author: iProzd
580580
if self.opt_type == "Adam":
581581
self.optimizer = torch.optim.Adam(
582-
self.wrapper.parameters(), lr=self.lr_exp.start_lr
582+
self.wrapper.parameters(), lr=self.lr_exp.start_lr, fused=True
583583
)
584584
if optimizer_state_dict is not None and self.restart_training:
585585
self.optimizer.load_state_dict(optimizer_state_dict)

0 commit comments

Comments
 (0)