Closed
Description
This issue tracks the progress of adding support for the NVIDIA DGX Spark system (GB10, sm_121a).
The benchmark results published in the LMSYS blog post, along with the Docker image tag for Spark (lmsysorg/sglang:spark), were produced using a custom SGLang snapshot from my personal development branch: main...yvbbrjdr:sglang:spark.
The branch currently includes several temporary workarounds that need to be properly addressed before it can be merged into main. These include:
- Outdated base: The branch is approximately two weeks old, and rebasing it onto `main` may not succeed cleanly or work properly.
- PyTorch compatibility: The official PyTorch release does not yet support CUDA 13.0, so a nightly build was used in a custom Dockerfile.
- Triton issue: Running GPT-OSS models triggers a PTXAS compilation error, `'.tile::gather4 with destination state space as .shared::cluster' not supported on target 'sm_121a'` (triton-lang/triton#8335), which remains unresolved.
- FP8 kernel dispatch: FP8 CUTLASS kernels currently fail to dispatch on GB10 (`sm_121a`). As a temporary workaround, they are disabled on this branch, causing PyTorch to fall back to legacy FP8 inference kernels with reduced performance.
- Dependency status: All external dependencies except `sgl-kernel` (which must be rebuilt for `sm_121a`) have been disabled because their compatibility is unknown. It's unclear which git tags or commits of these dependencies are compatible with the GB10 architecture.
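The FP8 workaround amounts to gating a kernel path on the GPU's compute capability. Below is a minimal sketch of that kind of check, assuming the gate is keyed on the capability PyTorch reports (GB10's `sm_121` corresponds to compute capability 12.1). `should_use_cutlass_fp8`, `CUTLASS_FP8_BLOCKLIST`, `sm_name` and `detect_capability` are hypothetical names for illustration, not SGLang's actual code, and the branch may implement the workaround differently.

```python
from typing import Optional, Tuple

# Hypothetical blocklist: architectures where CUTLASS FP8 kernels are assumed
# to fail dispatch. (12, 1) is GB10 / sm_121.
CUTLASS_FP8_BLOCKLIST = {(12, 1)}

# FP8 tensor-core support begins at compute capability 8.9 (Ada).
MIN_FP8_CAPABILITY = (8, 9)


def sm_name(capability: Tuple[int, int]) -> str:
    """Format a (major, minor) compute capability as an sm_XY string."""
    major, minor = capability
    return f"sm_{major}{minor}"


def should_use_cutlass_fp8(capability: Optional[Tuple[int, int]]) -> bool:
    """Return True if the (hypothetical) CUTLASS FP8 path should be used.

    None means no CUDA device was detected, so the fallback path is chosen.
    """
    if capability is None:
        return False
    if capability < MIN_FP8_CAPABILITY:
        return False  # no FP8 tensor cores on this GPU
    return capability not in CUTLASS_FP8_BLOCKLIST


def detect_capability() -> Optional[Tuple[int, int]]:
    """Query the current GPU's compute capability via PyTorch, if available."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    major, minor = torch.cuda.get_device_capability()
    return (major, minor)


if __name__ == "__main__":
    cap = detect_capability()
    label = sm_name(cap) if cap is not None else "no CUDA device"
    print(label, "-> CUTLASS FP8:", should_use_cutlass_fp8(cap))
```

A blocklist keeps every other FP8-capable architecture on the fast path, so once the dispatch failure on `sm_121a` is fixed, only one entry needs to be removed.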
Additionally, @johnnynunez has opened related PRs for CUDA 13 and FA4 support, which may help resolve some of the above issues: #11299, #11606 (Thank you!)