
Support hdimQK != hdimV backward #1604

Merged

tridao merged 8 commits into Dao-AILab:main from shcho1118:feat/d_dv_bwd_public on Apr 24, 2025

Conversation

@shcho1118
Contributor

@shcho1118 shcho1118 commented Apr 21, 2025

Recently, there are cases where hdimQK != hdimV, such as DeepSeek (#1487).
So we worked on a simple trick (rounding hdimQK and hdimV to the same value) to support that configuration in the backward kernel as well.
Of course, it would be best to tile K and V with different rounding values (which we're working on internally), but this implementation lets us skip explicit padding and unpadding of hdimV, yielding a performance improvement of about 15%.
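The rounding trick described above can be illustrated with a minimal sketch (hypothetical helper names, not the actual kernel code): round both head dimensions up to one shared value so a single tile shape serves K and V, which is what avoids the explicit padding/unpadding of hdimV inside the kernel.

```python
# Hypothetical sketch of the rounding trick: pick one shared (rounded)
# head dimension for both K and V tiles. The rounding multiple of 32 is
# an illustrative assumption, not the exact value in the codebase.

def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple."""
    return ((x + multiple - 1) // multiple) * multiple

def shared_hdim(hdim_qk: int, hdim_v: int, multiple: int = 32) -> int:
    """Shared head dimension used to tile both K and V."""
    return round_up(max(hdim_qk, hdim_v), multiple)

# DeepSeek-style configuration: d = 192, d_v = 128.
print(shared_hdim(192, 128))  # → 192
```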

This feature was developed for internal use at @kakao.

@shcho1118 shcho1118 marked this pull request as ready for review April 21, 2025 05:50
@shcho1118 shcho1118 force-pushed the feat/d_dv_bwd_public branch 2 times, most recently from 1d7a74b to 1923fa7 on April 22, 2025 02:29
@tridao
Member

tridao commented Apr 22, 2025

Thank you! Could you rebase on the latest in main and then I'll merge?

@shcho1118 shcho1118 force-pushed the feat/d_dv_bwd_public branch from dd2960d to 572a0dc on April 23, 2025 04:30
@shcho1118
Contributor Author

@tridao
We've done the rebase as you requested, and I think we're ready to merge.

@tridao
Member

tridao commented Apr 23, 2025

The hdim dispatching should be based on max(params.d, params.d_v) instead of based on params.d?
E.g. if d = 128, d_v = 192, then it should dispatch to kHeadDim=192 instead of kHeadDim=128?
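tridao's point can be sketched as follows (a hypothetical illustration, not the actual dispatch code; the bucket values are assumptions): the kHeadDim template bucket should be chosen from max(d, d_v), not from d alone.

```python
# Hypothetical sketch of head-dim dispatching based on max(d, d_v).
# The supported kHeadDim buckets below are illustrative assumptions.

def dispatch_head_dim(d: int, d_v: int) -> int:
    """Return the smallest kHeadDim bucket covering both head dims."""
    needed = max(d, d_v)
    for k_head_dim in (64, 96, 128, 192, 256):
        if needed <= k_head_dim:
            return k_head_dim
    raise ValueError("max(d, d_v) > 256 is unsupported")

# tridao's example: d = 128, d_v = 192 must dispatch to kHeadDim = 192.
print(dispatch_head_dim(128, 192))  # → 192
```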

@shcho1118
Contributor Author

I was initially thinking about the case where d = 192, d_v = 128, like DeepSeek, but max(params.d, params.d_v) would be more universal.

@shcho1118 shcho1118 changed the title from "Support hdimQK >= hdimV backward" to "Support hdimQK != hdimV backward" on Apr 23, 2025
@shcho1118
Contributor Author

@tridao
I took your advice and extended the support a bit more; now if max(params.d, params.d_v) <= 256 it should work fine.

@shcho1118 shcho1118 force-pushed the feat/d_dv_bwd_public branch from e094679 to 69a84a3 on April 23, 2025 06:46
@tridao tridao merged commit 37c816a into Dao-AILab:main Apr 24, 2025
@ehuaa

ehuaa commented Apr 27, 2025

Hi @shcho1118 @tridao, I noticed that this PR implements hdimQK != hdimV backward for both the sm80 and sm90 archs, but for the forward pass (the mha_forward function) there is a check:

TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");

so only the Hopper arch can run DeepSeek-series models through the forward function. Do you have any ideas on how to bypass this check and run DeepSeek on sm80 for the forward pass?

@shcho1118
Contributor Author

shcho1118 commented Apr 28, 2025

@ehuaa
After I finished implementing the sm80 backward, I discovered that if-statement, so the sm80 backward kernel is not fully validated in this PR.
If you want to enable the same functionality in the sm80 forward kernel, the fastest way is to apply the same hdim rounding as this PR.
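Until in-kernel rounding lands in the sm80 forward path, one host-side workaround (an assumption on my part, not code from this PR) is to zero-pad V from hdimV up to hdimQK before calling a kernel that requires equal head dims, then slice the output back; the rounding trick in this PR effectively automates away exactly this padding/unpadding cost.

```python
import numpy as np

# Hypothetical host-side workaround (not code from this PR): zero-pad V
# up to hdimQK so the equal-head-dim forward kernel accepts it, then
# slice the result back down to hdimV.

def pad_v_to_qk(v: np.ndarray, hdim_qk: int) -> np.ndarray:
    """Pad the last (head) dimension of v with zeros up to hdim_qk."""
    pad = hdim_qk - v.shape[-1]
    assert pad >= 0, "expected hdim_qk >= hdimV"
    widths = [(0, 0)] * (v.ndim - 1) + [(0, pad)]
    return np.pad(v, widths)

v = np.ones((2, 4, 128))        # (batch, seqlen, hdimV=128)
v_padded = pad_v_to_qk(v, 192)  # now (2, 4, 192), zeros in the tail
out = v_padded[..., :128]       # slice back to hdimV after the kernel
print(v_padded.shape)  # → (2, 4, 192)
```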

@ehuaa

ehuaa commented Apr 28, 2025

> @ehuaa After I finished implementing sm80 backward, I discovered that if-statement, so sm80 backward kernel is not fully validated in this PR. If you want to enable the same functionality in sm80 forward kernel, the fastest way to do it is to touch rounding hdim like this PR.

Thanks for your quick reply, I'll try it later.

@Graham1025

> Recently, there are cases where hdimQK != hdimV, such as DeepSeek (#1487). So we worked on a simple trick (rounding hdimQK and hdimV to the same value) to support that configuration in the backward kernel as well. Of course, it would be best to tile K and V with different rounding values (which we're working on internally), but this implementation allows us to skip explicit padding & unpadding hdimV, resulting in a performance improvement of about 15%.
>
> This feature was developed for internal use at @kakao.

Thanks for your great work! So does it use the TMA OOB feature for rounding hdimQK and hdimV to the same value?

@shcho1118
Contributor Author

shcho1118 commented Jun 11, 2025

@Graham1025
Yes; for sm90, it was already using that feature for rounding, so I just modified it slightly.

playerzer0x pushed a commit to Liqhtworks/flash-attention that referenced this pull request Jul 24, 2025
* separate d & dv (interface)

* separate d & dv (api)

* separate d & dv (template)

* separate d & dv (mainloop)

* separate d & dv (epilogue)

* update test

* disable backward test when attention_chunk != 0

* extend backward d > dv to d != dv

---------

Co-authored-by: monk.ey <monk.ey@kakaocorp.com>
elewarr pushed a commit to elewarr/flash-attention that referenced this pull request Feb 4, 2026
