Description
Hi SpargeAttn team — first, thank you for releasing this library and making state-of-the-art sparse attention publicly available! While integrating SpargeAttn into a vision model, I encountered a reproducible issue that appears to originate inside the CUDA kernels. I wanted to report it in case it helps improve the project, or to find out if something is wrong with my installation/environment.
Calling `spas_sage2_attn_meansim_topk_cuda` twice in the same Python process consistently corrupts the CUDA runtime:

- First call → always succeeds
- Second call → always fails, even with identical tensors
- Restarting Python restores normal behavior

This strongly suggests a kernel-level issue (illegal memory access, device-side assert, or resource leak).
```python
import torch
from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda

device = "cuda"
B, H, L, D = 2, 8, 128, 64  # valid per documentation

q = torch.randn(B, H, L, D, device=device, dtype=torch.float16)
k = torch.randn(B, H, L, D, device=device, dtype=torch.float16)
v = torch.randn(B, H, L, D, device=device, dtype=torch.float16)

# First call — always works
o1 = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False)
print("First call OK")

# Second call — always crashes
o2 = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False)
print("Second call OK")
```
Error output:

```text
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/root/data/miniforge3/envs/spargeattn_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
  File "SpargeAttn/spas_sage_attn/core.py", line 114, in spas_sage2_attn_meansim_topk_cuda
    km = k.mean(dim=-2, keepdim=True)
torch.AcceleratorError: CUDA error: unspecified launch failure
Search for `cudaErrorLaunchFailure' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
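Following the hint in the error message, the repro can also be run with synchronous kernel launches so the failure surfaces at the actual call site rather than at a later API call. (`repro.py` here is just a stand-in name for the script above.)

```shell
# Synchronous launches make the failing kernel report at its real call site.
CUDA_LAUNCH_BLOCKING=1 python repro.py
```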
This appears to be a library bug to me (please enlighten me if that's not the case!):

- Shapes satisfy all requirements (seq_len ≥ 128, head_dim ∈ {64, 128})
- Using fp16 exactly as recommended
- No padding, slicing, or shape transforms
- Correct installation (the build succeeded and the import succeeds)
- The first call always works, meaning the inputs and installation are valid
- Only the second call fails, even with identical tensors
- Restarting Python makes it work again
When convenient, could you please clarify:

- Is this a known issue with the current kernels?
- Is there a recommended workaround (e.g., stream sync, manual context reset)?
- Are there plans to support repeated calls within a single process?
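In the meantime I am considering a process-isolation workaround, since a fresh interpreter always works: run each attention call in its own short-lived Python subprocess so every call gets a clean CUDA context. This is only a sketch under that assumption — the inline script below is a placeholder; in the real version it would build the tensors, call `spas_sage2_attn_meansim_topk_cuda`, and persist the output (e.g. with `torch.save`) for the parent to load.

```python
import subprocess
import sys

# Placeholder for the real worker script, which would import torch and
# spas_sage_attn, run the kernel once, and save the result to disk.
PLACEHOLDER_SCRIPT = """\
print("attention call finished")
"""

def isolated_call():
    # Each invocation gets a brand-new process, and hence a fresh CUDA context.
    result = subprocess.run(
        [sys.executable, "-c", PLACEHOLDER_SCRIPT],
        capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()

print(isolated_call())  # first call
print(isolated_call())  # second call, also in a clean context
```

Obviously the per-call process spawn is expensive, so this is only a stopgap, not a substitute for fixing the kernels.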
Environment details:

```python
import torch
print("PyTorch:", torch.__version__)                 # 2.9.1+cu128
print("CUDA runtime in torch:", torch.version.cuda)  # 12.8
print("CUDA available:", torch.cuda.is_available())  # True
print("GPU name:", torch.cuda.get_device_name(0))    # NVIDIA A100-SXM4-40GB
print("Compute capability:", torch.cuda.get_device_capability(0))  # (8, 0)
```
Repo git head for the repro:

```shell
$ git rev-parse HEAD
e6b8d1c76167edc4147f9574096dc700a9d57f38
```
Thanks a lot, and again, thank you for your awesome work!