Support hdimQK != hdimV backward#1604
Conversation
1d7a74b to
1923fa7
Compare
|
Thank you! Could you rebase on the latest in |
dd2960d to
572a0dc
Compare
|
@tridao |
|
The hdim dispatching should be based on |
|
I was initially thinking about the case where d = 192, d_v = 128 like deepseek, but |
|
@tridao |
e094679 to
69a84a3
Compare
|
Hi @shcho1118 @tridao , I noticed that this pr has implemented hdimQK != hdimV for backward for both sm80 and sm90 arch, but for forward (mha_forward function) flash-attention/hopper/flash_api.cpp Line 759 in a9a3170 TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); there's a if-statement that only hopper arch can run deepseek series models in foward function Do you have any ideas of how to bypass this if-statement and run deepseek with sm80 arch for forward function |
|
@ehuaa |
Thanks for your quick reply, i'll try it later. |
|
|
@Graham1025 |
* separate d & dv (interface) * separate d & dv (api) * separate d & dv (template) * separate d & dv (mainloop) * separate d & dv (epilogue) * update test * disable backward test when attention_chunk != 0 * extend backward d > dv to d != dv --------- Co-authored-by: monk.ey <monk.ey@kakaocorp.com>
* separate d & dv (interface) * separate d & dv (api) * separate d & dv (template) * separate d & dv (mainloop) * separate d & dv (epilogue) * update test * disable backward test when attention_chunk != 0 * extend backward d > dv to d != dv --------- Co-authored-by: monk.ey <monk.ey@kakaocorp.com>
Recently, there are cases where
hdimQK >= hdimVhdimQK != hdimV, such as DeepSeek (#1487).So we worked on a simple trick (rounding hdimQK and hdimV to the same value) to support that configuration in the backward kernel as well.
Of course, it would be best to tile K and V with different rounding values (which we're working on internally), but this implementation allows us to skip explicit padding & unpadding hdimV, resulting in a performance improvement of about 15%.
This feature was developed for internal use at @kakao.