diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 463b924fa79b..2696a96e93d6 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -75,13 +75,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()"); m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); - m.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + m.def("silu_and_mul(Tensor! out, Tensor input, bool enable_pdl) -> ()"); m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, bool enable_pdl) -> ()"); m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); - m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + m.def("gelu_and_mul(Tensor! out, Tensor input, bool enable_pdl) -> ()"); m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); m.def( diff --git a/sgl-kernel/csrc/common_extension_rocm.cc b/sgl-kernel/csrc/common_extension_rocm.cc index 94c13fdad985..7cc732192099 100644 --- a/sgl-kernel/csrc/common_extension_rocm.cc +++ b/sgl-kernel/csrc/common_extension_rocm.cc @@ -22,13 +22,13 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { /* * From csrc/elementwise */ - m.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + m.def("silu_and_mul(Tensor! out, Tensor input, bool enable_pdl) -> ()"); m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); - m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, bool enable_pdl) -> ()"); m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); - m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + m.def("gelu_and_mul(Tensor! out, Tensor input, bool enable_pdl) -> ()"); m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); m.def("gelu_quick(Tensor! 
out, Tensor input) -> ()"); diff --git a/sgl-kernel/csrc/elementwise/activation.cu b/sgl-kernel/csrc/elementwise/activation.cu index 43617f87f318..b823c37cc559 100644 --- a/sgl-kernel/csrc/elementwise/activation.cu +++ b/sgl-kernel/csrc/elementwise/activation.cu @@ -82,7 +82,7 @@ __device__ __forceinline__ T gelu_tanh(const T& x) { return detail::from_f32(f32_val * cdf); } -void silu_and_mul(at::Tensor& out, at::Tensor& input) { +void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); dim3 grid(num_tokens); @@ -97,52 +97,99 @@ void silu_and_mul(at::Tensor& out, at::Tensor& input) { sgl_hip::activation::act_and_mul_kernel <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); #else - flashinfer::activation::act_and_mul_kernel - <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); + cudaLaunchConfig_t config; + config.gridDim = grid; + config.blockDim = block; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + + auto kernel = flashinfer::activation::act_and_mul_kernel; + + cudaLaunchKernelEx( + &config, kernel, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); + + cudaError_t err = cudaGetLastError(); + TORCH_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err)); #endif return true; }); } -void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) { +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); - dim3 grid(num_tokens); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const 
c10::cuda::OptionalCUDAGuard device_guard(device_of(input)); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); - dim3 block(std::min(d / vec_size, 1024U)); #if USE_ROCM + dim3 grid(num_tokens); + dim3 block(std::min(d / vec_size, 1024U)); sgl_hip::activation::act_and_mul_kernel <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); #else - flashinfer::activation::act_and_mul_kernel - <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); + cudaLaunchConfig_t config; + config.gridDim = num_tokens; + config.blockDim = std::min(d / vec_size, 1024U); + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + + auto kernel = flashinfer::activation::act_and_mul_kernel; + + cudaLaunchKernelEx( + &config, kernel, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); + + cudaError_t err = cudaGetLastError(); + TORCH_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err)); #endif + return true; }); } -void gelu_and_mul(at::Tensor& out, at::Tensor& input) { +void gelu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) { int d = input.size(-1) / 2; int64_t num_tokens = input.numel() / input.size(-1); - dim3 grid(num_tokens); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const c10::cuda::OptionalCUDAGuard device_guard(device_of(input)); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); - dim3 block(std::min(d / vec_size, 1024U)); #if USE_ROCM + dim3 grid(num_tokens); + dim3 block(std::min(d / vec_size, 1024U)); sgl_hip::activation::act_and_mul_kernel 
<<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); #else - flashinfer::activation::act_and_mul_kernel - <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); + cudaLaunchConfig_t config; + config.gridDim = num_tokens; + config.blockDim = std::min(d / vec_size, 1024U); + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + + auto kernel = flashinfer::activation::act_and_mul_kernel; + + cudaLaunchKernelEx( + &config, kernel, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); + + cudaError_t err = cudaGetLastError(); + TORCH_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err)); #endif return true; diff --git a/sgl-kernel/csrc/elementwise/rope.cu b/sgl-kernel/csrc/elementwise/rope.cu index b6abd1cce122..c2712d49d549 100644 --- a/sgl-kernel/csrc/elementwise/rope.cu +++ b/sgl-kernel/csrc/elementwise/rope.cu @@ -23,7 +23,6 @@ #include "utils.h" using namespace flashinfer; - void apply_rope_pos_ids_cos_sin_cache( at::Tensor q, at::Tensor k, @@ -92,6 +91,7 @@ void apply_rope_pos_ids_cos_sin_cache( size_t k_rope_stride_n = k_rope.stride(0); size_t k_rope_stride_h = k_rope.stride(1); + const c10::cuda::OptionalCUDAGuard device_guard(q.device()); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(q.scalar_type(), c_type, [&] { // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache diff --git a/sgl-kernel/csrc/gemm/bmm_fp8.cu b/sgl-kernel/csrc/gemm/bmm_fp8.cu index cef85a7de8ee..642b80741d41 100644 --- a/sgl-kernel/csrc/gemm/bmm_fp8.cu +++ b/sgl-kernel/csrc/gemm/bmm_fp8.cu @@ -50,6 +50,7 @@ void bmm_fp8( auto n = B.size(2); auto lt_handle = reinterpret_cast(cublas_handle); + const 
c10::cuda::OptionalCUDAGuard device_guard(A.device()); auto stream = at::cuda::getCurrentCUDAStream(); auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 5e3cf24f9036..eefcd022eef4 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -131,10 +131,9 @@ void sgl_fused_add_rmsnorm( torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl); void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl); void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl); -void silu_and_mul(at::Tensor& out, at::Tensor& input); -void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input); -void gelu_and_mul(at::Tensor& out, at::Tensor& input); - +void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl); +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl); +void gelu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl); void apply_rope_pos_ids_cos_sin_cache( at::Tensor q, at::Tensor k, diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 68dc221d1a1e..f981b9caafe4 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -12,7 +12,7 @@ def rmsnorm( weight: torch.Tensor, eps: float = 1e-6, out: Optional[torch.Tensor] = None, - enable_pdl: Optional[bool] = None, + enable_pdl: Optional[bool] = is_arch_support_pdl(), ) -> torch.Tensor: r"""Root mean square normalization. @@ -31,7 +31,7 @@ def rmsnorm( enable_pdl: Optional[bool] Whether to enable `programmatic dependent launch `_ - If None, will be automatically enabled on Hopper architecture. + Enabled by default on Hopper or later architectures. 
Returns ------- @@ -51,7 +51,7 @@ def fused_add_rmsnorm( residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, - enable_pdl: Optional[bool] = None, + enable_pdl: Optional[bool] = is_arch_support_pdl(), ) -> None: r"""Fused add root mean square normalization. @@ -74,7 +74,7 @@ def fused_add_rmsnorm( enable_pdl: Optional[bool] Whether to enable `programmatic dependent launch `_ - If None, will be automatically enabled on Hopper architecture. + Enabled by default on Hopper or later architectures. """ if enable_pdl is None: enable_pdl = is_arch_support_pdl() @@ -88,7 +88,7 @@ def gemma_rmsnorm( weight: torch.Tensor, eps: float = 1e-6, out: Optional[torch.Tensor] = None, - enable_pdl: Optional[bool] = None, + enable_pdl: Optional[bool] = is_arch_support_pdl(), ) -> torch.Tensor: r"""Gemma-style root mean square normalization. @@ -107,7 +107,7 @@ def gemma_rmsnorm( enable_pdl: Optional[bool] Whether to enable `programmatic dependent launch `_ - If None, will be automatically enabled on Hopper architecture. + Enabled by default on Hopper or later architectures. Returns ------- @@ -127,7 +127,7 @@ def gemma_fused_add_rmsnorm( residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6, - enable_pdl: Optional[bool] = None, + enable_pdl: Optional[bool] = is_arch_support_pdl(), ) -> None: r"""Gemma-style fused add root mean square normalization. @@ -150,7 +150,7 @@ def gemma_fused_add_rmsnorm( enable_pdl: Optional[bool] Whether to enable `programmatic dependent launch `_ - If None, will be automatically enabled on Hopper architecture. + Enabled by default on Hopper or later architectures. 
""" if enable_pdl is None: enable_pdl = is_arch_support_pdl() @@ -169,7 +169,11 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: ), f"{input.shape[-1]} != {2 * output.shape[-1]}" -def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: +def silu_and_mul( + input: torch.Tensor, + out: torch.Tensor = None, + enable_pdl: bool = is_arch_support_pdl(), +) -> torch.Tensor: if input.shape[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: @@ -180,11 +184,15 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernel.silu_and_mul.default(out, input) + torch.ops.sgl_kernel.silu_and_mul.default(out, input, enable_pdl) return out -def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: +def gelu_tanh_and_mul( + input: torch.Tensor, + out: torch.Tensor = None, + enable_pdl: bool = is_arch_support_pdl(), +) -> torch.Tensor: if input.shape[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: @@ -195,11 +203,15 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te device=input.device, dtype=input.dtype, ) - torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input) + torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, enable_pdl) return out -def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: +def gelu_and_mul( + input: torch.Tensor, + out: torch.Tensor = None, + enable_pdl: bool = is_arch_support_pdl(), +) -> torch.Tensor: if input.shape[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: @@ -210,7 +222,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: device=input.device, dtype=input.dtype, ) - 
torch.ops.sgl_kernel.gelu_and_mul.default(out, input) + torch.ops.sgl_kernel.gelu_and_mul.default(out, input, enable_pdl) return out diff --git a/sgl-kernel/python/sgl_kernel/utils.py b/sgl-kernel/python/sgl_kernel/utils.py index d03476eff05a..647f411acd4c 100644 --- a/sgl-kernel/python/sgl_kernel/utils.py +++ b/sgl-kernel/python/sgl_kernel/utils.py @@ -39,6 +39,8 @@ def _to_tensor_scalar_tuple(x): @functools.lru_cache(maxsize=1) def is_arch_support_pdl() -> bool: + if not torch.cuda.is_available(): + return False # Hopper arch's compute capability == 9.0 device = torch.cuda.current_device() major, minor = torch.cuda.get_device_capability(device)