Commit cc2e36c

overlap shared + routed expert computation in kimi linear (#12660)
1 parent d8f7816 commit cc2e36c

File tree

1 file changed (+37, -5 lines)

python/sglang/srt/models/kimi_linear.py

Lines changed: 37 additions & 5 deletions
@@ -32,6 +32,7 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -52,6 +53,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         layer_idx: int = 0,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -63,6 +65,7 @@ def __init__(
         self.routed_scaling_factor = config.routed_scaling_factor
         self.num_shared_experts = config.num_shared_experts
         self.layer_idx = layer_idx
+        self.alt_stream = alt_stream
 
         if config.hidden_act != "silu":
             raise ValueError(
@@ -120,11 +123,34 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_size)
-        if self.num_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
-        router_logits, _ = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+
+        shared_output = None
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+
+        if (
+            self.alt_stream is not None
+            and self.num_shared_experts is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            and get_is_capture_mode()
+        ):
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+
+            shared_output = self.shared_experts(hidden_states.clone())
+
+            with torch.cuda.stream(self.alt_stream):
+                router_logits, _ = self.gate(hidden_states)
+                topk_output = self.topk(hidden_states, router_logits)
+                final_hidden_states = self.experts(hidden_states, topk_output)
+
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            if self.num_shared_experts is not None and hidden_states.shape[0] > 0:
+                shared_output = self.shared_experts(hidden_states)
+            router_logits, _ = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
 
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -334,9 +360,11 @@ def __init__(
         layer_idx: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.alt_stream = alt_stream
 
         self.is_moe = config.is_moe
 
@@ -375,6 +403,7 @@ def __init__(
                 quant_config=quant_config,
                 layer_idx=layer_idx,
                 prefix=f"{prefix}.mlp",
+                alt_stream=self.alt_stream,
             )
             self.mlp = self.block_sparse_moe
         else:
@@ -442,13 +471,16 @@ def __init__(
         else:
             self.embed_tokens = PPMissingLayer()
 
+        self.alt_stream = torch.cuda.Stream()
+
         self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: KimiDecoderLayer(
                 layer_idx=idx,
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=self.alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
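For reference, the change above runs the shared-expert MLP on the main CUDA stream while the gate -> top-k -> routed-expert path runs on a side stream, then rejoins before the two outputs are summed. A minimal, self-contained sketch of that stream choreography is shown below; branch_a and branch_b are placeholder workloads invented for illustration (not the SGLang modules), and only the wait_stream / torch.cuda.stream / clone pattern mirrors the diff.

# Minimal sketch of the dual-stream overlap pattern, under the assumptions above.
import torch


def overlapped_forward(x: torch.Tensor, alt_stream: torch.cuda.Stream) -> torch.Tensor:
    def branch_a(t: torch.Tensor) -> torch.Tensor:
        # Placeholder standing in for the shared-expert MLP.
        return torch.nn.functional.gelu(t)

    def branch_b(t: torch.Tensor) -> torch.Tensor:
        # Placeholder standing in for gate -> top-k -> routed experts.
        return torch.relu(t) * 2.0

    current_stream = torch.cuda.current_stream()
    # The side stream must not start reading `x` before the main stream
    # has finished producing it.
    alt_stream.wait_stream(current_stream)

    # Shared branch on the main stream; clone() gives it a private copy,
    # as the commit does with hidden_states.clone().
    out_a = branch_a(x.clone())

    # Routed branch runs concurrently on the side stream.
    with torch.cuda.stream(alt_stream):
        out_b = branch_b(x)

    # Rejoin before combining the two results on the main stream.
    current_stream.wait_stream(alt_stream)
    return out_a + out_b


if torch.cuda.is_available():
    x = torch.randn(256, 1024, device="cuda")
    y = overlapped_forward(x, torch.cuda.Stream())

Note the guards in the diff: the overlapped path is taken only during CUDA-graph capture (get_is_capture_mode()) and for non-empty batches of at most DUAL_STREAM_TOKEN_THRESHOLD = 1024 tokens; larger batches fall back to the sequential path, presumably because each branch already keeps the GPU busy on its own.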

0 commit comments
