Skip to content

Commit 8fe3e37

Browse files
authored
Support piecewise cuda graph for dsv3 fp4 (#15531)
1 parent 6014365 commit 8fe3e37

File tree

7 files changed

+148
-16
lines changed

7 files changed

+148
-16
lines changed

python/sglang/srt/layers/attention/trtllm_mla_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import triton
1313
import triton.language as tl
1414

15+
from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph
1516
from sglang.srt.layers.attention.flashinfer_mla_backend import (
1617
FlashInferMLAAttnBackend,
1718
FlashInferMLAMultiStepDraftBackend,
@@ -582,10 +583,11 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
582583
):
583584
# For an extend batch with prefix length > 0, fall back to the ragged kernel implemented in the flashinfer MLA backend
584585
# when chunked prefix cache is disabled.
586+
# Also fall back to the flashinfer MLA backend when in piecewise cuda graph, since piecewise cuda graph only supports the MLA forward mode.
585587
has_prefix = any(forward_batch.extend_prefix_lens_cpu)
586588
fallback_to_flashinfer_impl = (
587589
self.disable_chunked_prefix_cache and has_prefix
588-
)
590+
) or is_in_piecewise_cuda_graph()
589591
if fallback_to_flashinfer_impl:
590592
super().init_forward_metadata(forward_batch)
591593

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,13 @@
4141
StandardDispatcher,
4242
StandardDispatchOutput,
4343
)
44-
from sglang.srt.layers.moe.topk import StandardTopKOutput, TopKOutput, TopKOutputChecker
44+
from sglang.srt.layers.moe.topk import (
45+
BypassedTopKOutput,
46+
StandardTopKOutput,
47+
TopKConfig,
48+
TopKOutput,
49+
TopKOutputChecker,
50+
)
4551
from sglang.srt.layers.moe.utils import RoutingMethodType
4652
from sglang.srt.layers.quantization.base_config import (
4753
FusedMoEMethodBase,
@@ -1210,16 +1216,21 @@ def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor):
12101216
return hs_fp4, hs_sf
12111217

12121218
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
1219+
assert TopKOutputChecker.format_is_bypassed(
1220+
topk_output
1221+
), "Only bypassed topk output is supported for flashinfer fp4 moe"
1222+
12131223
if is_in_piecewise_cuda_graph():
1214-
assert TopKOutputChecker.format_is_standard(
1215-
topk_output
1216-
), "Only standard topk output is supported for piecewise cuda graph"
1217-
return torch.ops.sglang.moe_forward_piecewise_cuda_graph_impl(
1218-
hidden_states,
1219-
topk_output.topk_weights,
1220-
topk_output.topk_ids,
1221-
topk_output.router_logits,
1222-
self.layer_id,
1224+
return (
1225+
torch.ops.sglang.flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl(
1226+
hidden_states,
1227+
topk_output.router_logits,
1228+
topk_output.topk_config.top_k,
1229+
topk_output.topk_config.topk_group,
1230+
topk_output.topk_config.num_expert_group,
1231+
topk_output.topk_config.correction_bias,
1232+
self.layer_id,
1233+
)
12231234
)
12241235
else:
12251236
return self.forward_impl(hidden_states, topk_output)
@@ -1343,9 +1354,52 @@ def moe_forward_piecewise_cuda_graph_impl_fake(
13431354
return torch.empty_like(hidden_states)
13441355

13451356

1357+
def flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    topk_group: Optional[int],
    num_expert_group: Optional[int],
    correction_bias: Optional[torch.Tensor],
    layer_id: int,
) -> torch.Tensor:
    """Run the flashinfer FP4 MoE forward for one layer from flattened arguments.

    The routing information arrives as separate fields rather than as a
    ``BypassedTopKOutput`` object; this function reassembles that object and
    dispatches to the target layer's ``forward_impl``.
    NOTE(review): presumably the flat signature exists because a registered
    custom op can only take tensors/scalars -- confirm against
    ``direct_register_custom_op``.

    Args:
        hidden_states: Input activations handed to the MoE layer.
        router_logits: Router logits stored on the rebuilt top-k output.
        top_k: ``TopKConfig.top_k`` value.
        topk_group: ``TopKConfig.topk_group`` value (may be ``None``).
        num_expert_group: ``TopKConfig.num_expert_group`` value (may be ``None``).
        correction_bias: ``TopKConfig.correction_bias`` value (may be ``None``).
        layer_id: Index into ``moe_layers`` of the current forward context.

    Returns:
        Whatever the selected layer's ``forward_impl`` returns.
    """
    # Rebuild the routing configuration from its individual fields.
    routing_config = TopKConfig(
        top_k=top_k,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        correction_bias=correction_bias,
    )
    bypassed_topk = BypassedTopKOutput(
        hidden_states=hidden_states,
        router_logits=router_logits,
        topk_config=routing_config,
    )

    # The layer object is not passed in; it is looked up by id from the
    # forward context.
    target_layer = get_forward_context().moe_layers[layer_id]
    return target_layer.forward_impl(hidden_states, bypassed_topk)
1379+
1380+
1381+
def flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl_fake(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    topk_group: Optional[int],
    num_expert_group: Optional[int],
    correction_bias: Optional[torch.Tensor],
    layer_id: int,
) -> torch.Tensor:
    """Fake counterpart of ``flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl``.

    Performs no MoE computation: it only allocates an uninitialized tensor
    with the same shape, dtype and device as ``hidden_states``. Every other
    argument is accepted solely so the signature mirrors the real
    implementation and is otherwise ignored.

    Returns:
        An uninitialized tensor shaped like ``hidden_states``.
    """
    output_placeholder = torch.empty_like(hidden_states)
    return output_placeholder
1391+
1392+
13461393
# Register the generic piecewise-cuda-graph MoE forward together with its
# fake implementation. No argument is declared as mutated.
# NOTE(review): callers reach it as
# torch.ops.sglang.moe_forward_piecewise_cuda_graph_impl -- confirm the
# namespace against direct_register_custom_op.
direct_register_custom_op(
    op_name="moe_forward_piecewise_cuda_graph_impl",
    op_func=moe_forward_piecewise_cuda_graph_impl,
    fake_impl=moe_forward_piecewise_cuda_graph_impl_fake,
    mutates_args=[],
)
1399+
1400+
# Register the flashinfer FP4 variant of the piecewise-cuda-graph MoE forward
# together with its fake implementation. No argument is declared as mutated.
# NOTE(review): callers reach it as
# torch.ops.sglang.flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl --
# confirm the namespace against direct_register_custom_op.
direct_register_custom_op(
    op_name="flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl",
    op_func=flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl,
    fake_impl=flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl_fake,
    mutates_args=[],
)

python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
)
1818
from sglang.srt.layers.quantization.modelopt_quant import (
1919
FLASHINFER_FP4_GEMM_BACKEND,
20-
_sglang_fp4_gemm,
2120
enable_flashinfer_fp4_gemm,
2221
fp4_quantize,
2322
)
@@ -154,7 +153,7 @@ def apply_weights(
154153
w = layer.weight_packed.T
155154
w_blockscale = layer.weight_scale.T
156155

157-
out = _sglang_fp4_gemm(
156+
out = torch.ops.sglang.fp4_gemm(
158157
x_fp4,
159158
w,
160159
x_blockscale,

python/sglang/srt/layers/quantization/modelopt_quant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,7 @@ def apply(
12291229
backend = (
12301230
FLASHINFER_FP4_GEMM_BACKEND if FLASHINFER_FP4_GEMM_BACKEND else "cutlass"
12311231
)
1232-
out = _sglang_fp4_gemm(
1232+
out = torch.ops.sglang.fp4_gemm(
12331233
x_fp4,
12341234
w,
12351235
x_scale_interleaved,

python/sglang/srt/models/deepseek_v2.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import concurrent.futures
2121
import logging
2222
import os
23+
from contextlib import nullcontext
2324
from enum import IntEnum, auto
2425
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
2526

@@ -400,6 +401,9 @@ def handle_attention_fa4(attn, forward_batch):
400401

401402

402403
def handle_attention_trtllm_mla(attn, forward_batch):
404+
if is_in_piecewise_cuda_graph():
405+
return AttnForwardMethod.MLA
406+
403407
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
404408
if forward_batch.forward_mode.is_extend_without_speculative() and (
405409
not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
@@ -3188,7 +3192,13 @@ def forward(
31883192
normal_end_layer = normal_start_layer = 0
31893193
aux_hidden_states = []
31903194
for i in range(normal_start_layer, normal_end_layer):
3191-
with get_global_expert_distribution_recorder().with_current_layer(i):
3195+
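# NOTE: torch dynamo does not support graph breaks inside a context manager, so use a no-op context instead of the expert distribution recorder when piecewise cuda graph is enabled.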
3196+
ctx = (
3197+
nullcontext()
3198+
if get_global_server_args().enable_piecewise_cuda_graph
3199+
else get_global_expert_distribution_recorder().with_current_layer(i)
3200+
)
3201+
with ctx:
31923202
if i in self.layers_to_capture:
31933203
aux_hidden_states.append(hidden_states + residual)
31943204
layer = self.layers[i]

test/srt/run_suite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@
163163
TestFile("test_disaggregation_dp_attention.py", 155),
164164
],
165165
"per-commit-4-gpu-b200-stage-b": [
166-
TestFile("test_deepseek_v3_fp4_4gpu.py", 1800), # Stage B test
166+
TestFile("test_deepseek_v3_fp4_4gpu.py", 2000), # Stage B test
167167
],
168168
"per-commit-4-gpu-b200": [
169169
TestFile("test_flash_attention_4.py", 90),

test/srt/test_deepseek_v3_fp4_4gpu.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,5 +176,72 @@ def test_bs_1_speed(self):
176176
self.assertGreater(speed, 150)
177177

178178

179+
class TestDeepseekV3FP4PiecewiseCudaGraph(CustomTestCase):
    """Accuracy and speed checks for DeepSeek-V3 FP4 with piecewise cuda graph.

    One server is launched for the whole class with
    ``--enable-piecewise-cuda-graph`` and shared by both tests.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the shared server process used by every test in the class."""
        cls.model = FULL_DEEPSEEK_V3_FP4_MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=SERVER_LAUNCH_TIMEOUT,
            other_args=[
                "--tp",
                "4",
                "--attention-backend",
                "trtllm_mla",
                "--moe-runner-backend",
                "flashinfer_trtllm",
                "--quantization",
                "modelopt_fp4",
                "--enable-piecewise-cuda-graph",
                "--kv-cache-dtype",
                "fp8_e4m3",
                "--model-loader-extra-config",
                '{"enable_multithread_load": true,"num_threads": 64}',
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process and all of its children."""
        kill_process_tree(cls.process.pid)

    def test_a_gsm8k(self):
        """Run the 8-shot GSM8K eval and require accuracy above 0.935."""
        # NOTE(review): the "a" in the name presumably makes this sort before
        # test_bs_1_speed under unittest's alphabetical ordering -- confirm.
        server_port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=8,
            data_path=None,
            num_questions=1319,
            parallel=1319,
            max_new_tokens=512,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(f"{metrics=}")

        if is_in_ci():
            # Publish the measured accuracy to the GitHub step summary.
            summary = (
                f'### test_gsm8k (deepseek-v3-fp4)\n{metrics["accuracy"]=:.3f}\n'
            )
            write_github_step_summary(summary)

        self.assertGreater(metrics["accuracy"], 0.935)

    def test_bs_1_speed(self):
        """Send a single prompt and require the reported speed to exceed 120."""
        server_port = int(self.base_url.split(":")[-1])
        bench_args = BenchArgs(port=server_port, max_new_tokens=2048)
        _unused, speed = send_one_prompt(bench_args)

        print(f"{speed=:.2f}")

        if is_in_ci():
            # Publish the measured speed to the GitHub step summary.
            summary = f"### test_bs_1_speed (deepseek-v3-fp4)\n{speed=:.2f} token/s\n"
            write_github_step_summary(summary)

        self.assertGreater(speed, 120)
244+
245+
179246
if __name__ == "__main__":
180247
unittest.main()

0 commit comments

Comments
 (0)