Bug report
attention_mla.MLA scales queries after projection:

```python
query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale
```
however, cudnn_jax_flash_attention (the implementation used when `attention=cudnn_flash_jax`) also hardcodes the scale:

```python
scale=1.0 / math.sqrt(head_dim),
```
The queries are therefore scaled twice (once by `self.softmax_scale` and once inside the kernel), which produces incorrect attention results that do not match `attention=dot_product` and the other implementations.
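A minimal numpy sketch of the effect (hypothetical shapes and a toy `softmax`, not the actual MaxText code): applying the `1/sqrt(head_dim)` scale twice shrinks the logits and flattens the attention distribution relative to the single-scale reference.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

head_dim = 64  # hypothetical head dimension
scale = 1.0 / np.sqrt(head_dim)

rng = np.random.default_rng(0)
q = rng.normal(size=(4, head_dim))  # 4 query positions
k = rng.normal(size=(8, head_dim))  # 8 key positions

# Reference behavior: the scale is applied exactly once.
probs_once = softmax((q * scale) @ k.T)

# Buggy path: queries are pre-scaled in MLA, then the kernel scales again.
probs_twice = softmax(((q * scale) @ k.T) * scale)

print(np.allclose(probs_once, probs_twice))  # False: attention weights diverge
```

With `head_dim=64` the logits in the buggy path are an extra 8x smaller, so the resulting attention weights are noticeably more uniform than the reference.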
Logs/Output
No response
Environment Information
No response
Additional Context
No response