[CuTeDSL] Add BF16 grouped GEMM example for Hopper SM90#3060

Open
vruga wants to merge 1 commit into NVIDIA:main from vruga:feat/hopper-bf16-grouped-gemm
Conversation

@vruga vruga commented Feb 23, 2026

Closes #3040

Adds examples/python/CuTeDSL/hopper/grouped_gemm.py — a Python/CuTeDSL
grouped GEMM kernel for Hopper SM90 with BF16 (and Float16) support.

This is the CuTeDSL equivalent of examples/57_hopper_grouped_gemm,
which exists only in C++ and supports only FP8.

Design

  • WGMMA + TMA with warp specialization (DMA warp group + MMA warp groups)
  • Register accumulators (SM90 has no TMEM)
  • PipelineTmaAsync for the A/B mainloop
  • TensorMapManager for per-group TMA descriptor patching (SMEM and GMEM modes)
  • StaticPersistentGroupTileScheduler for persistent multi-group scheduling
  • Supports BFloat16/Float16 inputs and Float16/BFloat16/Float32 outputs
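To illustrate the persistent multi-group scheduling idea from the list above, here is a small Python sketch of how a static group tile scheduler can flatten every group's output tile grid into one linear sequence that persistent CTAs stride through. The function names, tile order, and shapes are hypothetical and chosen for illustration; this is not the CuTeDSL `StaticPersistentGroupTileScheduler` implementation.

```python
# Hedged sketch: map a linear tile index onto (group, tile_m, tile_n)
# coordinates, the core idea behind a static persistent group tile scheduler.
from math import ceil

def build_tile_map(problem_sizes, tile_m, tile_n):
    """Concatenate every group's output tile grid into one flat list."""
    tiles = []
    for g, (m, n, _k) in enumerate(problem_sizes):
        for tm in range(ceil(m / tile_m)):
            for tn in range(ceil(n / tile_n)):
                tiles.append((g, tm, tn))
    return tiles

def tiles_for_cta(tile_map, cta, num_ctas):
    """A persistent CTA with id `cta` processes every num_ctas-th tile."""
    return tile_map[cta::num_ctas]

# Two groups with unequal sizes and a 128x128 output tile:
# group 0 contributes 2*2 = 4 tiles, group 1 contributes 1*3 = 3 tiles.
tile_map = build_tile_map([(256, 256, 64), (128, 384, 64)], 128, 128)
```

Because the tile map is computed statically from the group problem sizes, each persistent CTA can walk its slice without any runtime work-stealing; the per-group TMA descriptor only needs to be re-patched when the group index changes between consecutive tiles.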

Tested on H100 (SM90) across BF16/Float16 dtypes, SMEM/GMEM tensormap modes, cluster shapes (1,1) and (2,1), up to 8 groups with unequal sizes.
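For readers unfamiliar with the operation itself, a grouped GEMM is simply a batch of independent GEMMs whose per-group problem sizes may differ. The NumPy sketch below mirrors that semantics as a host-side reference (the kind of check the tests above would compare against); the group shapes are made up, and BF16 is emulated by rounding float32, since NumPy has no native bfloat16 dtype.

```python
# Hedged reference sketch: grouped GEMM = one independent GEMM per group,
# with per-group (M_i, N_i, K_i) sizes. Shapes here are illustrative only.
import numpy as np

rng = np.random.default_rng(0)

# Unequal group sizes, as exercised by the PR's tests (up to 8 groups).
problem_sizes = [(128, 256, 64), (64, 128, 64), (192, 64, 64)]

def to_bf16(x):
    # Emulate BF16 by rounding float32 to the nearest bfloat16 value:
    # keep the top 16 bits of the float32 representation.
    xi = x.astype(np.float32).view(np.uint32)
    return ((xi + 0x8000) & 0xFFFF0000).view(np.float32)

outputs = []
for m, n, k in problem_sizes:
    a = to_bf16(rng.standard_normal((m, k)))
    b = to_bf16(rng.standard_normal((k, n)))
    outputs.append(a @ b)  # each group is an independent GEMM
```

A device kernel producing `outputs[i]` for each group from the same BF16 inputs can be validated against this loop with an elementwise tolerance appropriate for BF16 accumulation.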


Successfully merging this pull request may close these issues.

[FEA] CuTeDSL BF16 Hopper MoE GEMM/Grouped GEMM example

1 participant