Max calculation inside of forward creates CUDA sync #1531

@javak87

Description

In the attention mechanism’s forward pass, we compute max(x_q_lens) and max(x_kv_lens) in order to pass max_seqlen_q and max_seqlen_k to flash_attn_varlen_func. Because these maxima are needed as host-side scalars, the max operation forces a device-to-host transfer, i.e. a CUDA synchronization, which prevents effective use of torch.compile. Since x_q_lens and x_kv_lens are inputs (not intermediates), it is recommended to move this CUDA sync as far up the call stack as possible (i.e., handle it at a higher level) to minimize its impact.

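A minimal sketch of the suggested refactor, assuming the caller still has the per-sequence lengths on the host (the helper and argument names below are hypothetical; only flash_attn_varlen_func and its argument order are the real API):

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func


def attention_forward(q, k, v, cu_seqlens_q, cu_seqlens_k,
                      max_seqlen_q: int, max_seqlen_k: int):
    # max_seqlen_q / max_seqlen_k arrive as plain Python ints, so no .max()
    # on a CUDA tensor (and therefore no device sync) happens inside forward.
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k,
        causal=True,
    )


def prepare_varlen_args(x_q_lens_cpu: torch.Tensor,
                        x_kv_lens_cpu: torch.Tensor,
                        device: torch.device):
    # Hypothetical higher-level helper: the lengths are still host-side here
    # (e.g. straight out of the dataloader), so taking the max is sync-free.
    max_seqlen_q = int(x_q_lens_cpu.max())
    max_seqlen_k = int(x_kv_lens_cpu.max())
    # flash_attn_varlen_func expects int32 cumulative lengths with a leading 0.
    cu_seqlens_q = F.pad(x_q_lens_cpu.cumsum(0, dtype=torch.int32), (1, 0)).to(device)
    cu_seqlens_k = F.pad(x_kv_lens_cpu.cumsum(0, dtype=torch.int32), (1, 0)).to(device)
    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
```

With this split, any device-to-host traffic happens once at batch-preparation time, and the compiled forward pass only ever sees the maxima as constants rather than as data-dependent reads from CUDA tensors.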
Metadata

    Labels

    model: Related to model training or definition (not generic infra)
    performance: Work related to performance improvements
