
[BUG] [Python DSL] BlockScaledMmaOp restricts FP4 operations to sm_100a only, blocks sm_120/sm_121 #2800

@huangyucbr-hub

Description


Which component has the problem?

CuTe DSL

Bug Report

Summary

CUTLASS 4.2+ added SM120 and SM121 kernel support for Blackwell GeForce (RTX 50-series) and DGX Spark (GB10) GPUs, according to the CUTLASS 4.2.1 changelog (https://docs.nvidia.com/cutlass/4.2.1/CHANGELOG.html). However, the Python DSL BlockScaledMmaOp class restricts FP4 operations to sm_100a only, preventing use on sm_120 and sm_121 hardware.

Environment

  • Hardware: NVIDIA DGX Spark GB10 (Compute Capability 12.1, sm_121)
  • Package: nvidia-cutlass-dsl version 4.3.0 (latest from PyPI)
  • Python: 3.13
  • CUDA: 13.0.1
  • PyTorch: 2.10.0.dev20251118+cu130

Bug Location

File: python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py
Class: BlockScaledMmaOp
Lines: 303-305

@dataclass(frozen=True)
class BlockScaledMmaOp(Tcgen05MmaOp):
    # ... other fields ...

    admissible_archs = [
        Arch.sm_100a,  # ← only sm_100a allowed
    ]

    def __post_init__(self) -> None:
        arch = CuTeDSL._get_dsl().get_arch_enum()
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )
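For readers without Blackwell hardware, the gating behavior can be reproduced with a minimal self-contained mock. The `Arch` enum and the class below are stand-ins for the real cutlass types (with `ValueError` standing in for `OpError`), not the installed package:

```python
from dataclasses import dataclass
from enum import Enum


class Arch(Enum):  # stand-in for the real cutlass Arch enum
    sm_100a = "sm_100a"
    sm_121a = "sm_121a"


@dataclass(frozen=True)
class MockBlockScaledMmaOp:
    arch: Arch
    # Unannotated, so dataclass treats it as a class attribute, not a field,
    # mirroring the allow-list in the snippet above.
    admissible_archs = [Arch.sm_100a]

    def __post_init__(self) -> None:
        if self.arch not in self.admissible_archs:
            raise ValueError(
                f"expects arch to be one of {self.admissible_archs}, "
                f"but got {self.arch}"
            )


MockBlockScaledMmaOp(arch=Arch.sm_100a)  # accepted
try:
    MockBlockScaledMmaOp(arch=Arch.sm_121a)  # rejected, as on GB10
except ValueError as e:
    print(e)
```

The check runs in `__post_init__`, so construction itself raises; there is no way to bypass it from user code short of patching the class attribute.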

Reproduction

Minimal Test Case

On sm_121 (GB10) or sm_120 (RTX 5090) hardware:

from cutlass import Float8E8M0FNU, Arch
from cutlass.cute.nvgpu.tcgen05 import MmaMXF4NVF4Op, CtaGroup, OperandSource

mma_op = MmaMXF4NVF4Op(
    sf_dtype=Float8E8M0FNU,
    instruction_shape=(16, 16, 64),
    cta_group=CtaGroup.ONE,
    a_src=OperandSource.TMEM,
)

Error Output

OpError: expects arch to be one of [Arch.sm_100a], but got Arch.sm_121a

Traceback (most recent call last):
  File "test_fp4.py", line 6, in <module>
    mma_op = MmaMXF4NVF4Op(...)
  File ".../mma.py", line 311, in __post_init__
    raise OpError(
        self,
        f"expects arch to be one of {self.admissible_archs}, but got {arch}",
        suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
    )

Expected Behavior

Based on CUTLASS 4.2 changelog:
"Support for Blackwell SM121 kernels for DGX Spark GPUs. Share the major codes with Blackwell SM120 kernels."

The Python DSL should allow FP4 operations on sm_120 (RTX 5090) and sm_121 (GB10) architectures, consistent with C++ API support.

Proposed Fix

admissible_archs = [
    Arch.sm_100a,  # B200/B100 datacenter Blackwell
    Arch.sm_120a,  # RTX 5090 GeForce Blackwell (compute 12.0)
    Arch.sm_121a,  # GB10 DGX Spark Blackwell (compute 12.1)
]

Validation

  • ✅ Hardware verified: GB10 has 5th-generation Tensor Cores with FP4 support (1 PFLOPS peak)
  • ✅ C++ API works: vLLM successfully uses CUTLASS FP4 on sm_120 via the C++ API (vllm-project/vllm#21309, "Support CUTLASS NVFP4 (w4a4) for Blackwell GeForce GPUs (SM120)")
  • ✅ Patch tested: Applying the proposed fix eliminates the architecture error on GB10
  • ⚠️ Kernel availability: Pre-compiled sm_121 kernels may still be unavailable (separate issue)

Impact

Affected Users:

  • ❌ All RTX 5090 users (sm_120)
  • ❌ All DGX Spark GB10 users (sm_121)
  • ❌ Potentially RTX 5080/5070/5060 users (also sm_120)

Workaround:
Users can manually patch the installed package, but this:

  • Requires editing system packages (not ideal)
  • Gets overwritten on package upgrades
  • Isn't discoverable to most users
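An in-process alternative to editing files on disk is a monkey-patch applied before constructing any ops. The sketch below is self-contained: the `Arch` enum and `BlockScaledMmaOp` class are mocks mirroring the gating logic quoted above, not the installed package. Against the real package, the same assignment would target the class in cutlass.cute.nvgpu.tcgen05.mma (untested here):

```python
from dataclasses import dataclass
from enum import Enum


class Arch(Enum):  # stand-in for the real cutlass Arch enum
    sm_100a = "sm_100a"
    sm_120a = "sm_120a"
    sm_121a = "sm_121a"


@dataclass(frozen=True)
class BlockScaledMmaOp:  # stand-in mirroring the gating logic above
    arch: Arch
    # Class-level attribute: instances are frozen, but the class object
    # itself can still be given a new allow-list at runtime.
    admissible_archs = [Arch.sm_100a]

    def __post_init__(self) -> None:
        if self.arch not in self.admissible_archs:
            raise ValueError(f"arch {self.arch} not admissible")


# Apply the patch before constructing any ops.
BlockScaledMmaOp.admissible_archs = [Arch.sm_100a, Arch.sm_120a, Arch.sm_121a]

op = BlockScaledMmaOp(arch=Arch.sm_121a)  # no longer raises
print(op.arch.value)
```

Because the allow-list is a class attribute rather than a frozen instance field, the assignment succeeds despite `frozen=True`. The patch lasts only for the current process and avoids editing site-packages, though it still relies on an internal attribute that may change between releases.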

Additional Context

Related Work

  • vLLM PR #21309: Demonstrates CUTLASS NVFP4 working on sm_120 via C++ API
  • cuBLAS 12.9: Supports FP4 block-scaled operations on Blackwell

Hardware Specifications

  • sm_100: B200/B100 datacenter GPUs
  • sm_120: RTX 5090/5080/5070/5060 consumer GPUs
  • sm_121: GB10 superchip (DGX Spark, Project DIGITS)

All have 5th-generation Tensor Cores with hardware FP4 support.

Request

Please update BlockScaledMmaOp.admissible_archs to include Arch.sm_120a and Arch.sm_121a to match the C++ API's architecture support.

This will enable the Python DSL FP4 functionality on the full Blackwell GPU family, not just datacenter variants.

Thank you for maintaining this excellent library!
