     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -52,6 +53,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         layer_idx: int = 0,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -63,6 +65,7 @@ def __init__(
         self.routed_scaling_factor = config.routed_scaling_factor
         self.num_shared_experts = config.num_shared_experts
         self.layer_idx = layer_idx
+        self.alt_stream = alt_stream
 
         if config.hidden_act != "silu":
             raise ValueError(
@@ -120,11 +123,34 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_size)
-        if self.num_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
-        router_logits, _ = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+
+        shared_output = None
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+
+        if (
+            self.alt_stream is not None
+            and self.num_shared_experts is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            and get_is_capture_mode()
+        ):
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+
+            shared_output = self.shared_experts(hidden_states.clone())
+
+            with torch.cuda.stream(self.alt_stream):
+                router_logits, _ = self.gate(hidden_states)
+                topk_output = self.topk(hidden_states, router_logits)
+                final_hidden_states = self.experts(hidden_states, topk_output)
+
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            if self.num_shared_experts is not None and hidden_states.shape[0] > 0:
+                shared_output = self.shared_experts(hidden_states)
+            router_logits, _ = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
 
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
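
Note on the hunk above: when an alternate stream is available, the batch is non-empty but at most 1024 tokens, and CUDA graph capture is active, the shared-expert computation stays on the current stream while gating, top-k selection, and the routed experts are enqueued on `alt_stream`, with `wait_stream` calls fencing the fork and the join; otherwise the original sequential path runs. The following is a minimal, self-contained sketch of that stream-overlap pattern only, not the KimiMoE code; `shared_branch` and `routed_branch` are hypothetical stand-ins for the two expert paths.

```python
import torch


def overlapped_forward(
    x: torch.Tensor,
    shared_branch: torch.nn.Module,
    routed_branch: torch.nn.Module,
    alt_stream: torch.cuda.Stream,
) -> torch.Tensor:
    current_stream = torch.cuda.current_stream()
    # Fork: the alternate stream must not start before work already queued on the
    # current stream (e.g. whatever produced `x`) has been ordered ahead of it.
    alt_stream.wait_stream(current_stream)

    # Branch 1 stays on the current stream; `.clone()` gives it a private copy of
    # the activations so the two streams never touch the same buffer.
    shared_out = shared_branch(x.clone())

    # Branch 2 is enqueued on the alternate stream and can run concurrently.
    with torch.cuda.stream(alt_stream):
        routed_out = routed_branch(x)

    # Join: the current stream must not consume `routed_out` before alt_stream is done.
    current_stream.wait_stream(alt_stream)
    return routed_out + shared_out


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(256, 1024, device="cuda")
    shared = torch.nn.Linear(1024, 1024, device="cuda")
    routed = torch.nn.Linear(1024, 1024, device="cuda")
    y = overlapped_forward(x, shared, routed, torch.cuda.Stream())
    torch.cuda.synchronize()
    print(y.shape)
```
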
@@ -334,9 +360,11 @@ def __init__(
         layer_idx: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.alt_stream = alt_stream
 
         self.is_moe = config.is_moe
 
@@ -375,6 +403,7 @@ def __init__(
                 quant_config=quant_config,
                 layer_idx=layer_idx,
                 prefix=f"{prefix}.mlp",
+                alt_stream=self.alt_stream,
             )
             self.mlp = self.block_sparse_moe
         else:
@@ -442,13 +471,16 @@ def __init__(
         else:
             self.embed_tokens = PPMissingLayer()
 
+        self.alt_stream = torch.cuda.Stream()
+
         self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: KimiDecoderLayer(
                 layer_idx=idx,
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=self.alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
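
Design note on the hunk above: a single `alt_stream` is created once in the model constructor and threaded through every `KimiDecoderLayer` into its MoE block, so all layers reuse the same secondary CUDA stream rather than each allocating its own; the capture-mode and token-count checks in the MoE `forward` then decide per call whether that stream is actually used.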