diff --git a/python/pyproject.toml b/python/pyproject.toml
index 04b994ea6da4..78ee0041ab8b 100755
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -53,7 +53,7 @@ dependencies = [
     "scipy",
     "sentencepiece",
     "setproctitle",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "soundfile==0.13.1",
     "tiktoken",
     "timm==1.0.16",
diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml
index e4de2303982b..4d20b593b942 100755
--- a/python/pyproject_other.toml
+++ b/python/pyproject_other.toml
@@ -65,7 +65,7 @@ tracing = [
 
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index cbaa0d04a313..7f5d74302ae9 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.14",
+            "0.3.15",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py
index 0ec435d6fb6f..77d8e2eb6a77 100644
--- a/python/sglang/srt/layers/attention/attention_registry.py
+++ b/python/sglang/srt/layers/attention/attention_registry.py
@@ -129,9 +129,6 @@ def create_flashattention_v3_backend(runner):
 
 @register_attention_backend("fa4")
 def create_flashattention_v4_backend(runner):
-    assert (
-        runner.use_mla_backend
-    ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
     from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
 
     return FlashAttentionBackend(runner, fa_impl_ver=4)
diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 1deb9033cce9..279a6dbd5050 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -754,7 +754,6 @@ def forward_extend(
 
         # Use Flash Attention for prefill
        if not self.use_mla:
-            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index e92fe4250f60..b3d2d1e67c61 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1746,16 +1746,10 @@ def init_attention_backend(self):
 
     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.decode_attention_backend
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
         )
+
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 8c955aabc049..39e8bf5a714f 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -456,6 +456,19 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3
 
+    def get_attention_backends(server_args):
+        prefill_attention_backend_str = (
+            server_args.prefill_attention_backend
+            if server_args.prefill_attention_backend
+            else server_args.attention_backend
+        )
+        decode_attention_backend_str = (
+            server_args.decode_attention_backend
+            if server_args.decode_attention_backend
+            else server_args.attention_backend
+        )
+        return prefill_attention_backend_str, decode_attention_backend_str
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
@@ -740,20 +753,28 @@ def _handle_model_specific_adjustments(self):
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
+            if (
+                self.attention_backend is None
+                and self.prefill_attention_backend is None
+                and self.decode_attention_backend is None
+            ):
                 if is_cuda() and is_sm100_supported():
                     self.attention_backend = "trtllm_mha"
                 elif is_cuda() and is_sm90_supported():
                     self.attention_backend = "fa3"
                 else:
                     self.attention_backend = "triton"
-            supported_backends = ["triton", "trtllm_mha", "fa3"]
-            logger.info(
-                f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
-            )
+
+            supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]
+            prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
             assert (
-                self.attention_backend in supported_backends
-            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+                prefill_attn_backend in supported_backends
+                and decode_attn_backend in supported_backends
+            ), (
+                f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got the following backends\n"
+                f"- Prefill: {prefill_attn_backend}\n"
+                f"- Decode: {decode_attn_backend}\n"
+            )
 
             if is_sm100_supported():
                 if not self.enable_dp_attention: