Commit 9482a10

Use trtllm mha decode kernel for target_verify in speculative decoding (sgl-project#13976)
Co-authored-by: Kangyan-Zhou <[email protected]>
1 parent d5a76fa · commit 9482a10

File tree

1 file changed: +37 −18

python/sglang/srt/layers/attention/trtllm_mha_backend.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ def init_forward_metadata_replay_cuda_graph(
377377
]
378378
page_indices //= self.page_size
379379
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
380+
metadata.max_seq_len_q = self.speculative_num_draft_tokens
380381
elif forward_mode.is_draft_extend():
381382
metadata = self.draft_extend_metadata[bs]
382383
metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -614,24 +615,42 @@ def forward_extend(
614615
bmm1_scale = q_scale * k_scale * layer.scaling
615616
bmm2_scale = 1.0
616617

617-
o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
618-
query=q,
619-
kv_cache=kv_cache,
620-
workspace_buffer=self.workspace_buffer,
621-
block_tables=self.forward_metadata.page_table,
622-
seq_lens=self.forward_metadata.cache_seqlens_int32,
623-
max_q_len=self.forward_metadata.max_seq_len_q,
624-
max_kv_len=self.max_context_len,
625-
bmm1_scale=bmm1_scale,
626-
bmm2_scale=bmm2_scale,
627-
batch_size=forward_batch.batch_size,
628-
cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
629-
cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
630-
window_left=layer.sliding_window_size,
631-
# TODO: add attention_sink operation or nvfp4 scale factor if needed
632-
sinks=attention_sink,
633-
out_dtype=self.q_data_type, # model_runner.dtype
634-
)
618+
if forward_batch.forward_mode.is_target_verify():
619+
o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
620+
query=q,
621+
kv_cache=kv_cache,
622+
workspace_buffer=self.workspace_buffer,
623+
block_tables=self.forward_metadata.page_table,
624+
seq_lens=self.forward_metadata.cache_seqlens_int32,
625+
max_seq_len=self.max_context_len,
626+
bmm1_scale=bmm1_scale,
627+
bmm2_scale=bmm2_scale,
628+
window_left=layer.sliding_window_size,
629+
# TODO: add attention_sink operation or nvfp4 scale factor if needed
630+
sinks=attention_sink,
631+
out_dtype=self.q_data_type, # model_runner.dtype
632+
q_len_per_req=self.forward_metadata.max_seq_len_q,
633+
)
634+
else:
635+
636+
o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
637+
query=q,
638+
kv_cache=kv_cache,
639+
workspace_buffer=self.workspace_buffer,
640+
block_tables=self.forward_metadata.page_table,
641+
seq_lens=self.forward_metadata.cache_seqlens_int32,
642+
max_q_len=self.forward_metadata.max_seq_len_q,
643+
max_kv_len=self.max_context_len,
644+
bmm1_scale=bmm1_scale,
645+
bmm2_scale=bmm2_scale,
646+
batch_size=forward_batch.batch_size,
647+
cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
648+
cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
649+
window_left=layer.sliding_window_size,
650+
# TODO: add attention_sink operation or nvfp4 scale factor if needed
651+
sinks=attention_sink,
652+
out_dtype=self.q_data_type, # model_runner.dtype
653+
)
635654

636655
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
637656
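Why the decode kernel fits here: during target_verify, every request in the batch verifies exactly `speculative_num_draft_tokens` draft tokens, so the query layout is uniform and can be described by the single scalar `q_len_per_req`, whereas the prefill kernel needs ragged per-request lengths via `cum_seq_lens_q`/`cum_seq_lens_kv`. The sketch below illustrates that layout difference; it is not from the commit, and the names `bs` and `num_draft_tokens` are illustrative assumptions.

import torch

bs = 4                # batch size (assumed for illustration)
num_draft_tokens = 3  # draft tokens verified per request, uniform across the batch
num_heads, head_dim = 8, 128

# Target-verify queries: each request contributes exactly num_draft_tokens
# query tokens, so the flattened tensor is regular and the decode kernel
# only needs a scalar q_len_per_req to recover per-request boundaries.
q = torch.randn(bs * num_draft_tokens, num_heads, head_dim)

# The prefill path, by contrast, must describe ragged per-request lengths
# with cumulative offsets (cu_seqlens), e.g. for variable-length extends:
extend_lens = torch.tensor([5, 1, 7, 2], dtype=torch.int32)
cu_seqlens_q = torch.nn.functional.pad(extend_lens.cumsum(0), (1, 0)).int()
print(cu_seqlens_q)  # tensor([ 0,  5,  6, 13, 15], dtype=torch.int32)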
