
MLA with cudnn_flash_jax leads to double scaling #3138

@Angelogeb

Description


Bug report

attention_mla.MLA scales the queries after projection:

query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale

However, cudnn_jax_flash_attention (the implementation used when attention=cudnn_flash_jax) also hardcodes the scale internally:

scale=1.0 / math.sqrt(head_dim),

The softmax scale is therefore applied twice, producing incorrect attention results that do not match attention=dot_product and the other implementations.
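A minimal repro sketch of the mismatch, using a plain softmax-attention reference in place of the real MaxText/cuDNN code paths (function and variable names here are hypothetical, not from the repo): pre-scaling the queries and then passing the same scale to the kernel applies 1/sqrt(head_dim) twice, so the output diverges from the singly-scaled reference.

```python
import math
import jax
import jax.numpy as jnp

def attention(q, k, v, scale):
    # Plain softmax attention; `scale` is applied once to the logits,
    # mimicking a kernel that scales internally.
    logits = jnp.einsum("qhd,khd->hqk", q, k) * scale
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("hqk,khd->qhd", probs, v)

head_dim = 64
q = jax.random.normal(jax.random.PRNGKey(0), (8, 4, head_dim))
k = jax.random.normal(jax.random.PRNGKey(1), (8, 4, head_dim))
v = jax.random.normal(jax.random.PRNGKey(2), (8, 4, head_dim))

softmax_scale = 1.0 / math.sqrt(head_dim)

# Correct path: the scale is applied exactly once.
ref = attention(q, k, v, scale=softmax_scale)

# Buggy path: queries are pre-scaled (as in MLA), and the kernel
# scales again with its hardcoded 1/sqrt(head_dim) -> double scaling.
buggy = attention(q * softmax_scale, k, v, scale=softmax_scale)

print(bool(jnp.allclose(ref, buggy, atol=1e-5)))
```

With double scaling the effective temperature of the softmax changes (1/head_dim instead of 1/sqrt(head_dim)), so the two outputs disagree.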

Logs/Output

No response

Environment Information

No response

Additional Context

No response


Labels: bug
