benchmark/mmmu/bench_sglang.py (2 changes: 1 addition & 1 deletion)
@@ -86,8 +86,8 @@ def eval_mmmu(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
-    args = add_common_sglang_args_and_parse(parser)
EvalArgs.add_cli_args(parser)
+    args = add_common_sglang_args_and_parse(parser)
args = parser.parse_args()

eval_mmmu(args)
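
This reordering matters because the common helper parses as part of its call: any arguments registered after it would never be seen. A minimal sketch of the pattern, assuming (as the name and this fix suggest) that add_common_sglang_args_and_parse calls parser.parse_args() internally; the helper below is a hypothetical stand-in, not the real SGLang function:

    import argparse

    def add_common_args_and_parse(parser):
        # Hypothetical stand-in: registers shared flags, then parses immediately.
        parser.add_argument("--port", type=int, default=30000)
        return parser.parse_args()

    parser = argparse.ArgumentParser()
    # Register eval-specific flags BEFORE the helper that parses.
    parser.add_argument("--eval-split", default="validation")
    args = add_common_args_and_parse(parser)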
python/sglang/srt/layers/attention/flashattention_backend.py (118 changes: 111 additions & 7 deletions)
@@ -42,6 +42,16 @@ class FlashAttentionMetadata:
# Page table, the index of KV Cache Tables/Blocks
page_table: torch.Tensor = None

+    # Encoder metadata
+    # Cumulative sequence lengths for encoder key
+    encoder_cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for encoder key
+    encoder_max_seq_len_k: int = 0
+    # Sequence lengths for the forward batch
+    encoder_lens_int32: torch.Tensor = None
+    # Page table for the encoder
+    encoder_page_table: torch.Tensor = None

@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor = None # cu_seqlens_q for local attention
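
The four new fields describe the encoder side of cross attention. A worked sketch of how the cumulative-length fields relate (illustrative values, not taken from the PR; 576 might be one image's worth of vision tokens):

    import torch

    encoder_lens_int32 = torch.tensor([576], dtype=torch.int32)
    # Prepend a zero to the running sum to get key offsets [0, 576].
    encoder_cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(encoder_lens_int32, dim=0, dtype=torch.int32), (1, 0)
    )
    encoder_max_seq_len_k = encoder_lens_int32.max().item()  # 576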
@@ -435,6 +445,30 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
)
metadata.local_attn_metadata = local_metadata

+        # Encoder metadata for cross attention
+        if forward_batch.encoder_lens is not None:
+            assert (
+                forward_batch.encoder_lens.numel() == 1
+            ), "Only encoder size 1 is supported for now"
+
+            metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+            metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+            metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+            ]
+
+            # Currently only support forward_batch.encoder_lens.numel() == 1
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]

# Convert the page table to a strided format which is needed by FA3 API
if self.page_size > 1:
self.strided_indices = torch.arange(
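
The slicing above implies a per-request KV layout in which the encoder's slots come first and the decoder's slots follow, so the regular page table starts at offset encoder_max_seq_len_k. A minimal sketch with toy sizes (dummy slot values, one request):

    import torch

    encoder_len, decoder_len = 4, 3
    req_to_token = torch.arange(16).unsqueeze(0)  # one request, 16 KV slots
    req_pool_indices = torch.tensor([0])

    encoder_page_table = req_to_token[req_pool_indices, :encoder_len]
    decoder_page_table = req_to_token[
        req_pool_indices, encoder_len : encoder_len + decoder_len
    ]
    # encoder_page_table -> tensor([[0, 1, 2, 3]])
    # decoder_page_table -> tensor([[4, 5, 6]])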
@@ -486,6 +520,7 @@ def forward_extend(
if layer.sliding_window_size is not None
else (-1, -1)
)
+        causal = not layer.is_cross_attention

# Check if we should use local attention
use_local_attn = (
@@ -521,6 +556,12 @@
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)

o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
@@ -531,7 +572,7 @@
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
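
For cross-attention layers the decoder's queries attend over the full encoder sequence, so the causal mask and any sliding window are dropped and the encoder-side metadata is used instead. A sketch of the per-layer selection the extend and decode paths now share (the helper is illustrative, not part of the PR; field names follow the metadata dataclass above):

    def select_kv_metadata(layer, metadata, window_size):
        if layer.is_cross_attention:
            # Bidirectional attention over all encoder keys.
            return (
                metadata.encoder_page_table,
                metadata.encoder_lens_int32,
                metadata.encoder_cu_seqlens_k,
                (-1, -1),  # no sliding window across the encoder
                False,     # no causal mask
            )
        return (
            metadata.page_table,
            metadata.cache_seqlens_int32,
            metadata.cu_seqlens_k,
            window_size,
            True,
        )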
@@ -614,6 +655,7 @@ def forward_decode(
if layer.sliding_window_size is not None
else (-1, -1)
)
+        causal = not layer.is_cross_attention

if not self.use_mla:
# Do multi-head attention
@@ -627,17 +669,27 @@
)

q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+            else:
+                page_table = metadata.page_table
+                cache_seqlens = metadata.cache_seqlens_int32
+                cu_seqlens_k = metadata.cu_seqlens_k

o = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
-                page_table=metadata.page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
+                page_table=page_table,
+                cache_seqlens=cache_seqlens,
cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
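
During decode each request contributes exactly one new query token, which is why max_seqlen_q is pinned to 1; the query offsets then reduce to a simple range. An illustrative sketch (assuming, as is typical for this backend, that decode-time cu_seqlens_q is the range 0..bs):

    import torch

    bs = 4
    cu_seqlens_q = torch.arange(0, bs + 1, dtype=torch.int32)  # [0, 1, 2, 3, 4]
    max_seqlen_q = 1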
@@ -733,6 +785,21 @@ def init_cuda_graph_state(self, max_bs: int):
),
}

+        self.encoder_metadata = {
+            "encoder_page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "encoder_lens_int32": torch.zeros(
+                max_bs, dtype=torch.int32, device=self.device
+            ),
+            "encoder_cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+        }

def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
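
The encoder buffers are allocated once at maximum size because CUDA graph replay reuses exactly the memory that was captured: per-batch metadata must be written in place rather than freshly allocated. A minimal sketch of that pattern (shapes and names are illustrative):

    import torch

    max_bs, max_len = 8, 1024
    page_table_buf = torch.zeros(max_bs, max_len, dtype=torch.int32, device="cuda")

    def update_for_replay(new_table: torch.Tensor) -> None:
        bs, seq_len = new_table.shape
        # In-place copy keeps the buffer's storage (and the captured graph's
        # pointers) stable across replays.
        page_table_buf[:bs, :seq_len].copy_(new_table)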
@@ -818,6 +885,19 @@ def init_forward_metadata_capture_cuda_graph(

self.target_verify_metadata[bs] = metadata

+        if encoder_lens is not None:
+            encoder_bs = encoder_lens.numel()
+            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+                :encoder_bs
+            ]
+            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+                "encoder_cu_seqlens_k"
+            ][: (encoder_bs + 1)]
+
+            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+                req_pool_indices, :
+            ]

self.forward_metadata = metadata

def init_forward_metadata_replay_cuda_graph(
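
One subtlety in the capture path: fields such as encoder_lens_int32 are basic slices of the preallocated buffers, so the graph records stable storage addresses and later replays only rewrite the contents. A small sketch of that invariant (illustrative shapes):

    import torch

    encoder_metadata = {
        "encoder_lens_int32": torch.zeros(8, dtype=torch.int32),
        "encoder_cu_seqlens_k": torch.zeros(9, dtype=torch.int32),
    }
    encoder_bs = 1
    lens_view = encoder_metadata["encoder_lens_int32"][:encoder_bs]
    # A basic slice shares storage with the buffer it was taken from.
    assert lens_view.data_ptr() == encoder_metadata["encoder_lens_int32"].data_ptr()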
@@ -903,6 +983,30 @@ def init_forward_metadata_replay_cuda_graph(
page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

+        if encoder_lens is not None:
+            # Only support encoder size 1 for now
+            metadata.encoder_max_seq_len_k = encoder_lens[0]
+            metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+            metadata.encoder_cu_seqlens_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                    (1, 0),
+                )
+            )
+
+            metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+                self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
+            )
+
+            # Update the regular page table
+            page_table = self.req_to_token[
+                req_pool_indices,
+                metadata.encoder_max_seq_len_k : (
+                    metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                ),
+            ]
+            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

self.forward_metadata = metadata

def get_cuda_graph_seq_len_fill_value(self):
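
Putting the replay update together with toy sizes (one request, encoder length 3, decoder length 2, buffer padded to length 8; values are illustrative):

    import torch

    req_to_token = torch.arange(1, 9).view(1, 8)           # slots 1..8
    encoder_lens = torch.tensor([3], dtype=torch.int32)
    max_seq_len_k = 2

    page_table_buf = torch.zeros(1, 8, dtype=torch.int64)  # preallocated graph buffer
    enc_max = int(encoder_lens[0])
    page_table_buf[:, :max_seq_len_k].copy_(
        req_to_token[torch.tensor([0]), enc_max : enc_max + max_seq_len_k]
    )
    # page_table_buf -> [[4, 5, 0, 0, 0, 0, 0, 0]]: decoder slots follow the
    # encoder's, and the padded tail of the captured buffer is left untouched.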
@@ -956,7 +1060,7 @@ def init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
@@ -973,7 +1077,7 @@
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=forward_batch.seq_lens_cpu,
python/sglang/srt/model_executor/model_runner.py (2 changes: 1 addition & 1 deletion)
@@ -886,7 +886,7 @@ def init_attention_backend(self):
"Please use `--attention-backend flashinfer`."
)
logger.warning(
"FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
"FlashAttention v3 Backend is in Beta. FP8 is not supported."
)
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,