|
41 | 41 | StandardDispatcher, |
42 | 42 | StandardDispatchOutput, |
43 | 43 | ) |
44 | | -from sglang.srt.layers.moe.topk import StandardTopKOutput, TopKOutput, TopKOutputChecker |
| 44 | +from sglang.srt.layers.moe.topk import ( |
| 45 | + BypassedTopKOutput, |
| 46 | + StandardTopKOutput, |
| 47 | + TopKConfig, |
| 48 | + TopKOutput, |
| 49 | + TopKOutputChecker, |
| 50 | +) |
45 | 51 | from sglang.srt.layers.moe.utils import RoutingMethodType |
46 | 52 | from sglang.srt.layers.quantization.base_config import ( |
47 | 53 | FusedMoEMethodBase, |
@@ -1210,16 +1216,21 @@ def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor): |
1210 | 1216 | return hs_fp4, hs_sf |
1211 | 1217 |
|
1212 | 1218 | def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): |
| 1219 | + assert TopKOutputChecker.format_is_bypassed( |
| 1220 | + topk_output |
| 1221 | + ), "Only bypassed topk output is supported for flashinfer fp4 moe" |
| 1222 | + |
1213 | 1223 | if is_in_piecewise_cuda_graph(): |
1214 | | - assert TopKOutputChecker.format_is_standard( |
1215 | | - topk_output |
1216 | | - ), "Only standard topk output is supported for piecewise cuda graph" |
1217 | | - return torch.ops.sglang.moe_forward_piecewise_cuda_graph_impl( |
1218 | | - hidden_states, |
1219 | | - topk_output.topk_weights, |
1220 | | - topk_output.topk_ids, |
1221 | | - topk_output.router_logits, |
1222 | | - self.layer_id, |
| 1224 | + return ( |
| 1225 | + torch.ops.sglang.flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl( |
| 1226 | + hidden_states, |
| 1227 | + topk_output.router_logits, |
| 1228 | + topk_output.topk_config.top_k, |
| 1229 | + topk_output.topk_config.topk_group, |
| 1230 | + topk_output.topk_config.num_expert_group, |
| 1231 | + topk_output.topk_config.correction_bias, |
| 1232 | + self.layer_id, |
| 1233 | + ) |
1223 | 1234 | ) |
1224 | 1235 | else: |
1225 | 1236 | return self.forward_impl(hidden_states, topk_output) |
@@ -1343,9 +1354,52 @@ def moe_forward_piecewise_cuda_graph_impl_fake( |
1343 | 1354 | return torch.empty_like(hidden_states) |
1344 | 1355 |
|
1345 | 1356 |
|
def flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    topk_group: Optional[int],
    num_expert_group: Optional[int],
    correction_bias: Optional[torch.Tensor],
    layer_id: int,
) -> torch.Tensor:
    """Custom-op entry point for the flashinfer FP4 MoE forward under
    piecewise CUDA graph capture.

    Custom ops cannot take arbitrary Python objects, so the TopK
    configuration arrives flattened into scalars/tensors and is re-packed
    into a ``BypassedTopKOutput`` here before dispatching to the layer's
    ``forward_impl``. ``layer_id`` indexes the MoE layer registered on the
    current forward context.
    """
    topk_config = TopKConfig(
        top_k=top_k,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        correction_bias=correction_bias,
    )
    topk_output = BypassedTopKOutput(
        hidden_states=hidden_states,
        router_logits=router_logits,
        topk_config=topk_config,
    )
    # NOTE(review): assumes get_forward_context().moe_layers is populated for
    # this layer_id during capture — confirm against the registration path.
    layer = get_forward_context().moe_layers[layer_id]
    return layer.forward_impl(hidden_states, topk_output)
| 1379 | + |
| 1380 | + |
def flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl_fake(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    topk_group: Optional[int],
    num_expert_group: Optional[int],
    correction_bias: Optional[torch.Tensor],
    layer_id: int,
) -> torch.Tensor:
    """Fake (meta) implementation used for tracing/shape inference.

    Mirrors the real op's output: a tensor matching ``hidden_states`` in
    shape, dtype and device. The routing arguments are accepted only so
    the signature matches the real op; they are not used.
    """
    return torch.empty_like(hidden_states)
| 1391 | + |
| 1392 | + |
# Register the generic MoE piecewise-CUDA-graph forward as a torch custom op
# so it can be captured/compiled; the fake impl supplies shape inference.
direct_register_custom_op(
    op_name="moe_forward_piecewise_cuda_graph_impl",
    op_func=moe_forward_piecewise_cuda_graph_impl,
    fake_impl=moe_forward_piecewise_cuda_graph_impl_fake,
    mutates_args=[],
)
| 1399 | + |
# Expose the flashinfer FP4 variant as a torch custom op; the fake impl
# provides shape/dtype inference for graph tracing.
direct_register_custom_op(
    op_name="flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl",
    op_func=flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl,
    fake_impl=flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl_fake,
    mutates_args=[],
)
0 commit comments