python/pyproject.toml (1 addition, 1 deletion)

@@ -53,7 +53,7 @@ dependencies = [
     "scipy",
     "sentencepiece",
     "setproctitle",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "soundfile==0.13.1",
     "tiktoken",
     "timm==1.0.16",
python/pyproject_other.toml (1 addition, 1 deletion)

@@ -65,7 +65,7 @@ tracing = [

 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.14.post1",
+    "sgl-kernel==0.3.15",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
python/sglang/srt/entrypoints/engine.py (1 addition, 1 deletion)

@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.14",
+            "0.3.15",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

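For reference, the snippet below is a minimal sketch of the kind of version gate this bump interacts with. The check_min_pkg_version helper is hypothetical (sglang's actual assert_pkg_version may behave differently), and it relies on the third-party packaging library for version comparison.

# Hypothetical sketch of a minimum-version gate like the one used above;
# this is NOT sglang's assert_pkg_version, only an illustration.
# Requires the third-party `packaging` package.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def check_min_pkg_version(pkg: str, min_version: str, hint: str) -> None:
    """Raise if `pkg` is missing or older than `min_version`."""
    try:
        installed = Version(version(pkg))
    except PackageNotFoundError as exc:
        raise RuntimeError(f"{pkg} is not installed. {hint}") from exc
    if installed < Version(min_version):
        raise RuntimeError(f"{pkg}=={installed} is older than {min_version}. {hint}")


check_min_pkg_version(
    "sgl-kernel",
    "0.3.15",
    "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)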
python/sglang/srt/layers/attention/attention_registry.py (0 additions, 3 deletions)

@@ -129,9 +129,6 @@ def create_flashattention_v3_backend(runner):

 @register_attention_backend("fa4")
 def create_flashattention_v4_backend(runner):
-    assert (
-        runner.use_mla_backend
-    ), "FlashAttention v4 Support is at an early stage, only MLA model supported now"
     from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend

     return FlashAttentionBackend(runner, fa_impl_ver=4)
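The @register_attention_backend decorator in the attention_registry.py hunk above maps a backend name to a factory function. The sketch below is a hedged, self-contained illustration of that registration pattern; the ATTENTION_BACKENDS dictionary name and the toy factory body are assumptions, not sglang's real implementation.

# Illustrative name-to-factory registry; the names and bodies here are assumptions.
from typing import Callable, Dict

ATTENTION_BACKENDS: Dict[str, Callable] = {}


def register_attention_backend(name: str):
    """Register a factory under `name` so it can be looked up by backend string."""

    def decorator(factory: Callable) -> Callable:
        ATTENTION_BACKENDS[name] = factory
        return factory

    return decorator


@register_attention_backend("fa4")
def create_toy_fa4_backend(runner):
    # Toy stand-in: with the assert removed above, the real fa4 factory no longer
    # restricts FlashAttention v4 to MLA models at registration time.
    return f"FlashAttentionBackend(runner={runner!r}, fa_impl_ver=4)"


print(ATTENTION_BACKENDS["fa4"]("dummy_runner"))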
@@ -754,7 +754,6 @@ def forward_extend(

         # Use Flash Attention for prefill
         if not self.use_mla:
-            assert self.fa_impl_ver in [3], "Only FA3 support here"
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
python/sglang/srt/model_executor/model_runner.py (3 additions, 9 deletions)

@@ -1746,16 +1746,10 @@ def init_attention_backend(self):

     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.decode_attention_backend
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
         )

         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
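The refactor above delegates the per-phase resolution to ServerArgs.get_attention_backends(). The self-contained sketch below uses a stand-in Args class (not the real ServerArgs) to show the fallback rule and when the hybrid path is taken.

# Stand-in sketch (hypothetical Args class) of the resolution rule that
# _get_attention_backend now delegates to: each phase falls back to the
# global attention_backend unless a phase-specific override is set.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    attention_backend: Optional[str] = None
    prefill_attention_backend: Optional[str] = None
    decode_attention_backend: Optional[str] = None

    def get_attention_backends(self):
        prefill = self.prefill_attention_backend or self.attention_backend
        decode = self.decode_attention_backend or self.attention_backend
        return prefill, decode


args = Args(attention_backend="fa3", decode_attention_backend="trtllm_mha")
prefill, decode = args.get_attention_backends()
print(prefill, decode)    # fa3 trtllm_mha
# When the two strings differ, model_runner wraps them in HybridAttnBackend;
# when they match, a single backend serves both prefill and decode.
print(prefill != decode)  # True -> hybrid prefill/decode backends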
python/sglang/srt/server_args.py (28 additions, 7 deletions)

@@ -456,6 +456,19 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3

+    def get_attention_backends(server_args):
+        prefill_attention_backend_str = (
+            server_args.prefill_attention_backend
+            if server_args.prefill_attention_backend
+            else server_args.attention_backend
+        )
+        decode_attention_backend_str = (
+            server_args.decode_attention_backend
+            if server_args.decode_attention_backend
+            else server_args.attention_backend
+        )
+        return prefill_attention_backend_str, decode_attention_backend_str
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
@@ -740,20 +753,28 @@ def _handle_model_specific_adjustments(self):
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
+            if (
+                self.attention_backend is None
+                and self.prefill_attention_backend is None
+                and self.decode_attention_backend is None
+            ):
                 if is_cuda() and is_sm100_supported():
                     self.attention_backend = "trtllm_mha"
                 elif is_cuda() and is_sm90_supported():
                     self.attention_backend = "fa3"
                 else:
                     self.attention_backend = "triton"
-            supported_backends = ["triton", "trtllm_mha", "fa3"]
             logger.info(
                 f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
             )

+            supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]
+            prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
             assert (
-                self.attention_backend in supported_backends
-            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+                prefill_attn_backend in supported_backends
+                and decode_attn_backend in supported_backends
+            ), (
+                f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got the following backends\n"
+                f"- Prefill: {prefill_attn_backend}\n"
+                f"- Decode: {decode_attn_backend}\n"
+            )

             if is_sm100_supported():
                 if not self.enable_dp_attention:
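A small usage-style sketch of the tightened GptOss check above: both resolved backends must land in the supported set, which now includes "fa4". The plain string values below are illustrative inputs, not a real ServerArgs configuration.

# Illustrative check mirroring the GptOssForCausalLM assertion above; the
# backend strings here are example values chosen only to exercise the check.
supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]

prefill_attn_backend = "fa4"
decode_attn_backend = "trtllm_mha"

assert (
    prefill_attn_backend in supported_backends
    and decode_attn_backend in supported_backends
), (
    f"GptOssForCausalLM requires one of {supported_backends} attention backend, "
    f"but got prefill={prefill_attn_backend}, decode={decode_attn_backend}"
)
print("backend combination accepted")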