@@ -377,6 +377,7 @@ def init_forward_metadata_replay_cuda_graph(
377377 ]
378378 page_indices //= self .page_size
379379 metadata .page_table [:, :max_seq_pages ].copy_ (page_indices )
380+ metadata .max_seq_len_q = self .speculative_num_draft_tokens
380381 elif forward_mode .is_draft_extend ():
381382 metadata = self .draft_extend_metadata [bs ]
382383 metadata .cache_seqlens_int32 .copy_ (seq_lens )
@@ -614,24 +615,42 @@ def forward_extend(
614615 bmm1_scale = q_scale * k_scale * layer .scaling
615616 bmm2_scale = 1.0
616617
617- o = flashinfer .prefill .trtllm_batch_context_with_kv_cache (
618- query = q ,
619- kv_cache = kv_cache ,
620- workspace_buffer = self .workspace_buffer ,
621- block_tables = self .forward_metadata .page_table ,
622- seq_lens = self .forward_metadata .cache_seqlens_int32 ,
623- max_q_len = self .forward_metadata .max_seq_len_q ,
624- max_kv_len = self .max_context_len ,
625- bmm1_scale = bmm1_scale ,
626- bmm2_scale = bmm2_scale ,
627- batch_size = forward_batch .batch_size ,
628- cum_seq_lens_q = self .forward_metadata .cu_seqlens_q ,
629- cum_seq_lens_kv = self .forward_metadata .cu_seqlens_k ,
630- window_left = layer .sliding_window_size ,
631- # TODO: add attention_sink operation or nvfp4 scale factor if needed
632- sinks = attention_sink ,
633- out_dtype = self .q_data_type , # model_runner.dtype
634- )
618+ if forward_batch .forward_mode .is_target_verify ():
619+ o = flashinfer .decode .trtllm_batch_decode_with_kv_cache (
620+ query = q ,
621+ kv_cache = kv_cache ,
622+ workspace_buffer = self .workspace_buffer ,
623+ block_tables = self .forward_metadata .page_table ,
624+ seq_lens = self .forward_metadata .cache_seqlens_int32 ,
625+ max_seq_len = self .max_context_len ,
626+ bmm1_scale = bmm1_scale ,
627+ bmm2_scale = bmm2_scale ,
628+ window_left = layer .sliding_window_size ,
629+ # TODO: add attention_sink operation or nvfp4 scale factor if needed
630+ sinks = attention_sink ,
631+ out_dtype = self .q_data_type , # model_runner.dtype
632+ q_len_per_req = self .forward_metadata .max_seq_len_q ,
633+ )
634+ else :
635+
636+ o = flashinfer .prefill .trtllm_batch_context_with_kv_cache (
637+ query = q ,
638+ kv_cache = kv_cache ,
639+ workspace_buffer = self .workspace_buffer ,
640+ block_tables = self .forward_metadata .page_table ,
641+ seq_lens = self .forward_metadata .cache_seqlens_int32 ,
642+ max_q_len = self .forward_metadata .max_seq_len_q ,
643+ max_kv_len = self .max_context_len ,
644+ bmm1_scale = bmm1_scale ,
645+ bmm2_scale = bmm2_scale ,
646+ batch_size = forward_batch .batch_size ,
647+ cum_seq_lens_q = self .forward_metadata .cu_seqlens_q ,
648+ cum_seq_lens_kv = self .forward_metadata .cu_seqlens_k ,
649+ window_left = layer .sliding_window_size ,
650+ # TODO: add attention_sink operation or nvfp4 scale factor if needed
651+ sinks = attention_sink ,
652+ out_dtype = self .q_data_type , # model_runner.dtype
653+ )
635654
636655 return o .view (- 1 , layer .tp_q_head_num * layer .head_dim )
637656
0 commit comments