Open
Labels
model — Related to model training or definition (not generic infra)
performance — Work related to performance improvements
Description
In the attention mechanism’s forward pass, we compute max(x_q_lens) and max(x_kv_lens) to obtain the max_seqlen arguments required by flash_attn_varlen_func. Because these arguments must be Python ints but the length tensors live on the GPU, the max operation triggers a device-to-host synchronization, which prevents effective use of torch.compile. Since x_q_lens and x_kv_lens are model inputs (not intermediates), the sync should be hoisted as far up the call stack as possible, ideally to where the inputs are first prepared, so it happens once per step rather than once per attention call.
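
A minimal sketch of the change being requested. The helper names `attention_forward` and `prepare_seqlen_metadata` are hypothetical; `flash_attn_varlen_func`, `x_q_lens`, and `x_kv_lens` come from the issue:

```python
import torch
from flash_attn import flash_attn_varlen_func

# Before (per the issue): the max is computed inside every attention forward.
# Converting a CUDA tensor to the Python ints that flash_attn_varlen_func
# expects forces a device-to-host sync on each call.
def attention_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, x_q_lens, x_kv_lens):
    max_seqlen_q = int(x_q_lens.max())   # CUDA sync here
    max_seqlen_k = int(x_kv_lens.max())  # and here
    return flash_attn_varlen_func(
        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
    )

# After (proposed): hoist the sync to the top of the call stack, where
# x_q_lens / x_kv_lens first become available, and thread plain ints down.
def prepare_seqlen_metadata(x_q_lens: torch.Tensor, x_kv_lens: torch.Tensor):
    # One sync per step instead of one per attention layer; if the length
    # tensors still live on CPU at this point, there is no sync at all.
    return int(x_q_lens.max()), int(x_kv_lens.max())
```

With the ints computed once up front and passed down, the attention forwards contain no data-dependent host syncs, so torch.compile should be able to capture them without graph breaks.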