
Calling SpargeAttn Twice Corrupts CUDA Context (unspecified launch failure) #97

@sajalmaheshwari624

Description


Hi SpargeAttn team! First, thank you for releasing this library and making state-of-the-art sparse attention publicly available. While integrating SpargeAttn into a vision model, I hit a reproducible issue that appears to originate inside the CUDA kernels. I am reporting it in case it helps improve the project, or to find out whether something is wrong with my installation or environment.

Calling spas_sage2_attn_meansim_topk_cuda twice in the same Python process consistently corrupts the CUDA runtime.

- First call → always succeeds
- Second call → always fails, even with identical tensors
- Restarting Python restores normal behavior
This strongly suggests a kernel-level issue (illegal memory access, device-side assert, or resource leak).

import torch
from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda

device = "cuda"

B, H, L, D = 2, 8, 128, 64 # valid per documentation

q = torch.randn(B, H, L, D, device=device, dtype=torch.float16)
k = torch.randn(B, H, L, D, device=device, dtype=torch.float16)
v = torch.randn(B, H, L, D, device=device, dtype=torch.float16)

# First call: always works
o1 = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False)
print("First call OK")

# Second call: always crashes
o2 = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False)
print("Second call OK")  # never reached
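Because CUDA errors can be reported asynchronously, the exception raised at `k.mean(...)` in the second call may actually originate in the first call's kernel. As a diagnostic sketch of my own (the `checked` helper is hypothetical, not part of SpargeAttn), forcing a device sync after each call should make the error surface at the call that actually launched the failing kernel:

```python
import os
# Force synchronous kernel launches; must be set before CUDA is initialized.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

import torch

def checked(fn, *args, **kwargs):
    # Run fn, then synchronize so an asynchronous launch failure is
    # raised here instead of at a later, unrelated CUDA API call.
    out = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return out
```

Wrapping both `spas_sage2_attn_meansim_topk_cuda` calls in `checked(...)` should show whether the first launch is the one corrupting the context.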

Error output:

Traceback (most recent call last):
File "", line 1, in
File "/root/data/miniforge3/envs/spargeattn_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
return fn(*args, **kwargs)
File "SpargeAttn/spas_sage_attn/core.py", line 114, in spas_sage2_attn_meansim_topk_cuda
km = k.mean(dim=-2, keepdim=True)

torch.AcceleratorError: CUDA error: unspecified launch failure
Search for `cudaErrorLaunchFailure` in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information. CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1. Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

This appears to be a library bug (please correct me if I am wrong):
- Shapes satisfy all requirements (seq_len ≥ 128, head_dim ∈ {64, 128})
- fp16 is used exactly as recommended
- No padding, slicing, or shape transforms
- Installation is correct (build succeeded, import succeeds)
- The first call always works, so the inputs and installation are valid
- Only the second call fails, even when the tensors are identical
- Restarting Python makes it work again

When convenient, could you please clarify:

  1. Is this a known issue with the current kernels?

  2. Is there a recommended workaround (e.g., stream sync, manual context reset)?

  3. Are there plans to support repeated calls within a single process?
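Until repeated in-process calls are supported, one generic workaround I can think of (my own sketch, untested against SpargeAttn; `run_isolated` is a hypothetical helper) is to run each attention call in a freshly spawned worker process, so a corrupted CUDA context dies with the worker instead of poisoning the parent:

```python
import multiprocessing as mp

def _worker(fn, args, kwargs, q):
    # Child process: initializes its own CUDA context from scratch,
    # so any context corruption is discarded when the child exits.
    q.put(fn(*args, **kwargs))

def run_isolated(fn, *args, **kwargs):
    # Use "spawn" (not "fork"): forking after CUDA init is unsafe,
    # and spawn guarantees a clean context in the child.
    # fn, args, and the result must be picklable; error handling omitted.
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    p = ctx.Process(target=_worker, args=(fn, args, kwargs, q))
    p.start()
    out = q.get()
    p.join()
    return out
```

Tensors would need to cross the process boundary as CPU tensors (moved to the GPU inside the worker), which adds a host round-trip per call, so this is only a debugging stopgap, not a performance fix.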

Environment details:

>>> import torch
>>> print("PyTorch:", torch.__version__)
PyTorch: 2.9.1+cu128
>>> print("CUDA runtime in torch:", torch.version.cuda)
CUDA runtime in torch: 12.8
>>> print("Compiled with CUDA:", torch.cuda.is_available())
Compiled with CUDA: True
>>> print("GPU name:", torch.cuda.get_device_name(0))
GPU name: NVIDIA A100-SXM4-40GB
>>> print("Compute capability:", torch.cuda.get_device_capability(0))
Compute capability: (8, 0)

Repo git HEAD for repro:

$ git rev-parse HEAD
e6b8d1c76167edc4147f9574096dc700a9d57f38

Thanks a lot, and keep up the awesome work!
