
Feat: Integrate FBGEMM into EPMoE#7123

Open
yuan-luo wants to merge 7 commits into sgl-project:main from yuan-luo:integrate_fbgemm

Conversation


@yuan-luo yuan-luo commented Jun 12, 2025

Motivation

Background

There are three categories of MoE modes in SGLang, and each mode dispatches to a different GEMM kernel.
1. TP-only mode
    TP-only mode uses the fused kernel. The implementation is in fused_moe().
    select_experts()
    fused_experts()-->
             inplace_fused_experts()
             outplace_fused_experts()
                     fused_experts_impl()-->fused_moe_kernel()
    We plan to integrate FBGEMM into TP-only mode as well.
    
2. EPMoE mode
    EPMoE uses GroupGemmRunner, which currently has two GEMM implementations;
    this PR introduces a third:
        a. FlashInfer
        b. group_gemm_triton (SGLang Triton)
        c. FBGEMM
    This PR covers only the normal (BF16) path and does not consider fp8;
    fp8 will be covered in follow-up PRs.

3. DeepEPMoE mode
     DeepEPMoE involves several types of GEMMs:
     if deepep_mode == normal:
         if deepgemm is enabled:
             use forward_deepgemm_contiguous
         else:
             use forward_normal
     elif deepep_mode == low_latency:
         use forward_deepgemm_mask
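The dispatch logic above can be sketched in plain Python. Note that the mode strings and the helper function itself are illustrative shorthand for this description, not SGLang's actual API:

```python
# Hypothetical sketch of the three MoE dispatch paths described above.
# Mode strings and this helper are illustrative, not SGLang's real API.
def select_gemm_path(mode, deepep_mode=None, enable_deepgemm=False):
    if mode == "tp_only":
        # fused_experts() -> fused_experts_impl() -> fused_moe_kernel()
        return "fused_moe_kernel"
    if mode == "epmoe":
        # GroupGemmRunner picks FlashInfer, group_gemm_triton, or FBGEMM
        return "grouped_gemm"
    if mode == "deepep":
        if deepep_mode == "normal":
            return (
                "forward_deepgemm_contiguous"
                if enable_deepgemm
                else "forward_normal"
            )
        return "forward_deepgemm_mask"  # low_latency
    raise ValueError(f"unknown MoE mode: {mode}")
```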

Thanks to @BBuf for #6924, which introduced Meta's FBGEMM into SGLang.
This PR integrates FBGEMM BF16 into EPMoE. Follow-up PRs will introduce FBGEMM FP8 and warp-specialization features into SGLang EPMoE.

Modifications

This PR handles several issues:

  1. GEMM tensor reshape: the tensor shapes used by SGLang's Triton GEMM and by FBGEMM differ.
    SGLang tensor shapes:
    a.shape (activation): [M, K]
    b.shape (weight): [G, N, K]
    c.shape (output of each layer): [M, N]
    FBGEMM tensor shapes:
    a.shape (activation): [M, K]
    b.shape (weight): [G/tp_size, N, K]
    c.shape (output of each layer): [M, N]
  2. Handling the different signatures of Triton GEMM and FBGEMM.
    Triton GEMM uses seg_indptr and weight_indices to map a tile-id to an expert-id and a token index range.
    FBGEMM has no weight_indices/seg_indptr mapping; the expert-id is determined by m_sizes. FBGEMM flattens the b tensor into a [G * N, K] matrix, which requires each expert's token group to stay aligned with the flattened weight tensor. With TP > 1, this becomes even harder.
  3. CUDA graph capture breakage caused by dynamic tensor shapes in TP cases.
    Because the A tensor's base pointer must be adjusted when TP > 1, A has to be sliced for partition-wise GEMM calculation, which breaks CUDA graph capture. We introduce a new approach: a newly added Triton kernel adjusts the valid tile-id, and the reduce-sum phase copies only the valid portions back into the C tensor, so A's shape does not change with TP partitioning. This resolves the CUDA graph capture error.
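A minimal sketch of the metadata bridge in point 2, with illustrative names (seg_indptr here follows the usual CSR prefix-sum convention; this helper is not SGLang's actual code):

```python
# Hedged sketch: derive FBGEMM-style m_sizes from Triton-GEMM-style
# seg_indptr. seg_indptr[i] is the start offset of expert i's token
# segment, so each expert's token count is a difference of adjacent
# offsets.
def seg_indptr_to_m_sizes(seg_indptr):
    return [seg_indptr[i + 1] - seg_indptr[i] for i in range(len(seg_indptr) - 1)]

# Example: 3 local experts holding 4, 0, and 2 tokens respectively.
m_sizes = seg_indptr_to_m_sizes([0, 4, 4, 6])  # -> [4, 0, 2]

# FBGEMM then views the [G, N, K] weight tensor as a flattened
# [G * N, K] matrix, so the token segments above must stay aligned
# with that flattened layout.
```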

Checklist

  • Format your code according to the Code Formatting with Pre-Commit.
  • Add unit tests as outlined in the Running Unit Tests.
  • Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
  • Please feel free to join our Slack channel at https://slack.sglang.ai to discuss your PR.


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request is a work-in-progress effort to integrate the FBGEMM library into the Expert Parallel Mixture of Experts (EPMoE) layer. The goal is to leverage FBGEMM for potentially faster grouped matrix multiplication operations, including support for FP8 quantized weights. This adds an alternative backend for the core GEMM operation within the MoE layer.

Highlights

  • FBGEMM Integration: Integrated FBGEMM as an alternative grouped GEMM implementation within the EPMoE layer, specifically in the GroupedGemmRunner.
  • FP8 Support: Added support for FP8 rowwise quantization when using the FBGEMM grouped GEMM path.
  • Configuration: Introduced a use_fbgemm boolean flag in GroupedGemmRunner to enable or disable the FBGEMM implementation.
  • Typo Fixes: Corrected a consistent typo (preproess to preprocess) in the function name run_moe_ep_preprocess across multiple files.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR integrates FBGEMM into the EPMoE layer. The changes include adding new import paths, updating the GroupedGemmRunner to use FBGEMM, and enabling this path in the EPMoE layer. Typos related to preprocess and wrapper have also been corrected.


We should assert block_shape is None now.


yuan-luo commented Jun 12, 2025

The e2e test did not pass. Investigating.

$python3 -m sglang.launch_server --model /home/admin/Qwen3-30B-A3B --enable-ep-moe --tp-size 2 --port 30000
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[2].
[2025-06-12 21:28:59] server_args=ServerArgs(model_path='/home/admin/Qwen3-30B-A3B', tokenizer_path='/home/admin/Qwen3-30B-A3B', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='/home/admin/Qwen3-30B-A3B', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, impl='auto', host='127.0.0.1', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=2, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=922262220, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=2, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, 
speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=True, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, 
pdlb_url=None)
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
Failed to import ceil_div from deep_gemm.
Failed to import ceil_div from deep_gemm.
[2025-06-12 21:29:06 TP0] Attention backend not set. Use fa3 backend by default.
[2025-06-12 21:29:06 TP0] Init torch distributed begin.
[2025-06-12 21:29:07 TP0] sglang is using nccl==2.21.5
[2025-06-12 21:29:08 TP0] Init torch distributed ends. mem usage=0.81 GB
TMA benchmarks will be running with experimental grid constant TMA descriptor.
TMA benchmarks will be running with experimental grid constant TMA descriptor.
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
[2025-06-12 21:29:09 TP0] Load weight begin. avail mem=94.09 GB
Loading safetensors checkpoint shards:   0% Completed | 0/16 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:   6% Completed | 1/16 [00:00<00:07,  1.96it/s]
Loading safetensors checkpoint shards:  12% Completed | 2/16 [00:01<00:07,  1.90it/s]
......
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
Capturing batches (avail_mem=9.84 GB):   4%|██████▊                                                                                                                                                     | 1/23 [00:12<04:39, 12.71s/it]<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
Capturing batches (avail_mem=8.02 GB):  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎      | 22/23 [00:48<00:03,  3.25s/it]/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
Capturing batches (avail_mem=8.02 GB): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:55<00:00,  2.40s/it]
[2025-06-12 21:30:13 TP0] Registering 2231 cuda graph addresses
[2025-06-12 21:30:13 TP1] Registering 2231 cuda graph addresses
[2025-06-12 21:30:13 TP0] Capture cuda graph end. Time elapsed: 55.56 s. mem usage=2.43 GB. avail mem=8.00 GB.
/opt/conda/lib/python3.10/site-packages/sglang/srt/utils.py:951: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
  tensor_data = torch.ByteTensor(
[2025-06-12 21:30:13 TP0] max_total_num_tokens=1184668, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=4096, context_len=40960, available_gpu_mem=8.00 GB
/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py:2490: ResourceWarning: Unclosed context <zmq.Context() at 0x7f406f281e90>
  scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
ResourceWarning: Enable tracemalloc to get the object allocation traceback
[2025-06-12 21:30:14] INFO:     Started server process [413653]
[2025-06-12 21:30:14] INFO:     Waiting for application startup.
[2025-06-12 21:30:14] INFO:     Application startup complete.
[2025-06-12 21:30:14] INFO:     Uvicorn running on http://127.0.0.1:30000 (Press CTRL+C to quit)
[2025-06-12 21:30:15] INFO:     127.0.0.1:40290 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-06-12 21:30:15 TP0] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
[2025-06-12 21:30:18] INFO:     127.0.0.1:40296 - "POST /generate HTTP/1.1" 200 OK
[2025-06-12 21:30:18] The server is fired up and ready to roll!
[2025-06-12 21:30:21] Detected chat template content format: string
[2025-06-12 21:30:21 TP0] Prefill batch. #new-seq: 1, #new-token: 33, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
[2025-06-12 21:30:26 TP0] Decode batch. #running-req: 1, #token: 66, token usage: 0.00, cuda graph: True, gen throughput (token/s): 3.06, #queue-req: 0
[2025-06-12 21:30:33 TP0] Decode batch. #running-req: 1, #token: 106, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.31, #queue-req: 0
[2025-06-12 21:30:39 TP0] Decode batch. #running-req: 1, #token: 146, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.31, #queue-req: 0
[2025-06-12 21:30:45 TP0] Decode batch. #running-req: 1, #token: 186, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.31, #queue-req: 0
[2025-06-12 21:30:52 TP0] Decode batch. #running-req: 1, #token: 226, token usage: 0.00, cuda graph: True, gen throughput (token/s): 6.31, #queue-req: 0
[2025-06-12 21:30:53] INFO:     127.0.0.1:38452 - "POST /v1/chat/completions HTTP/1.1" 200 OK


Test script:
$cat test_openai.py 
import openai
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals. Tell me how you rank them"},
    ],
    temperature=0,
    max_tokens=200,
)
print(response)
[root  /home/root/luoyuan.luo] Thu Jun 12 21:24:56
$python test_openai.py 
ChatCompletion(id='592f218b862747a69ffbc9cad431dfae', choices=[Choice(finish_reason='length', index=0, logprobs=None, message=ChatCompletionMessage(content='\n\nカンophobiaacusassador (!_ home NutHookøre Home principalTable fieldlineno...聍øre大力支持运カンoodsカン GrahamotalSimon Nut home emergingカンophobiaカンoodsカン驷 matched(sigøre home boozeカンocks discrim大力支持\xa0\xa0 home\xa0\xa0 HomeoodsHookoodsTextStyle随 Adapthm/></Hookages大力支持 principalTable大力支持 principalTable pan随成熟 driversoods Home principalTable Home principalTable joint Girls tunerNET proven.Join\xa0\xa0 home principalTable人力</ lệoods Home泡!.SIG\xa0\xa0 home principalTableabinet菰 jointhtags complete home!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=None)], created=1749735021, model='default', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=200, prompt_tokens=33, total_tokens=233, completion_tokens_details=None, prompt_tokens_details=None))

Expected:

[root  /home/root/luoyuan.luo] Thu Jun 12 21:19:09
$python test_openai.py 
ChatCompletion(id='d590d369d0c34cfdb5773f5d85748c8f', choices=[Choice(finish_reason='length', index=0, logprobs=None, message=ChatCompletionMessage(content="<think>\nOkay, the user asked for three countries and their capitals, and then how I rank them. Let me start by picking three countries. Maybe the US, Japan, and Brazil. Their capitals are Washington, D.C., Tokyo, and Brasília. Now, how to rank them? The user didn't specify the criteria, so I need to think of possible ways. Maybe by population, economic size, or cultural influence. Let me check the population. The US has around 330 million, Japan about 125 million, Brazil 215 million. So US first, Brazil second, Japan third. But if I consider GDP, the US is the largest, then Japan, then Brazil. Alternatively, cultural influence: Japan has a strong cultural impact, maybe higher than Brazil. But the user might not have a specific criteria. I should mention that the ranking depends on the criteria and provide examples. Also, make sure the capitals are correct. Washington, D.C", refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=None)], created=1749734664, model='default', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=200, prompt_tokens=33, total_tokens=233, completion_tokens_details=None, prompt_tokens_details=None))

@yuan-luo
Collaborator Author

It's really strange, because the benchmark test with verify_data passed.

[root  /home/root/luoyuan.luo/sglang] Thu Jun 12 21:37:45
$python ./benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --verify-correctness
TMA benchmarks will be running with experimental grid constant TMA descriptor.
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
Failed to get model config: We couldn't connect to 'https://huggingface.co' to load the files, and couldn't find them in the cached files.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.
Using default configuration...
Running benchmark with:
  num_groups: 8
  hidden_size: 4096
  intermediate_size: 14336
  use_fp8_w8a8: False
Verifying correctness...
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
✓ BF16 Correctness verification passed!
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
Benchmarking fbgemm_grouped_gemm with batch_size=1
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
Benchmarking sglang_grouped_gemm with batch_size=1
Benchmarking fbgemm_grouped_gemm with batch_size=2
......

@jwfromm

jwfromm commented Jun 12, 2025

Just wanted to note again that another way to use FBGEMM kernels is through its pip package: pip install fbgemm-gpu-genai. Doing it this way will let us collaborate on kernel improvements rather than copy-pasting, which leaves us with split implementations. Either way, looking forward to seeing the results of this work.

@yuan-luo
Collaborator Author

yuan-luo commented Jun 13, 2025

Currently, the tricky part is that inside m_sizes some elements can be 0, which makes FBGEMM work incorrectly: the tiling mechanism in FBGEMM is pre-configured, and skipping a tile causes a tiling-iterator mismatch.
So the FBGEMM kernel may need revision; that's why I put the FBGEMM kernel code into SGLang. As soon as the problem is resolved and the FBGEMM main branch is updated accordingly, we can adopt the following way in SGLang:

pip install fbgemm-gpu-genai

and inside the code:

from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm as fbgemm_grouped_gemm,
    grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
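One possible interim workaround for the zero-sized groups (purely illustrative, not necessarily the fix this PR adopts) is to compact away empty expert groups before invoking the grouped GEMM, remembering the surviving indices for the scatter-back:

```python
# Hedged sketch: FBGEMM's pre-configured tiling mis-iterates when m_sizes
# contains zeros, so drop empty expert groups first. Names are illustrative.
def compact_groups(m_sizes):
    keep = [i for i, m in enumerate(m_sizes) if m > 0]
    return keep, [m_sizes[i] for i in keep]

# Example: experts 1 and 3 received no tokens this step.
keep, compact = compact_groups([4, 0, 2, 0])  # -> [0, 2], [4, 2]
# The grouped GEMM would then run over `compact`, and `keep` tells us
# which weight-tensor rows and output slots the results map back to.
```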

@yuan-luo
Collaborator Author

yuan-luo commented Jun 13, 2025

Made some progress. Now with tp=1 and --disable-cuda-graph, the result is correct.

$python3 -m sglang.launch_server --model /home/admin/Qwen3-30B-A3B --enable-ep-moe --tp-size 1 --port 30000 --disable-cuda-graph
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[1].
[2025-06-13 14:14:21] server_args=ServerArgs(model_path='/home/admin/Qwen3-30B-A3B', tokenizer_path='/home/admin/Qwen3-30B-A3B', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='/home/admin/Qwen3-30B-A3B', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, impl='auto', host='127.0.0.1', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=1, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=554732470, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, 
speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=True, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=True, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, 
pdlb_url=None)
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
[2025-06-13 14:14:27] Attention backend not set. Use fa3 backend by default.
[2025-06-13 14:14:27] Init torch distributed begin.
[2025-06-13 14:14:28] Init torch distributed ends. mem usage=0.00 GB
TMA benchmarks will be running with experimental grid constant TMA descriptor.
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
[2025-06-13 14:14:28] Load weight begin. avail mem=94.83 GB
Loading safetensors checkpoint shards:   0% Completed | 0/16 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:   6% Completed | 1/16 [00:00<00:13,  1.12it/s]
Loading safetensors checkpoint shards:  12% Completed | 2/16 [00:01<00:13,  1.03it/s]
Loading safetensors checkpoint shards:  19% Completed | 3/16 [00:03<00:13,  1.03s/it]
Loading safetensors checkpoint shards:  25% Completed | 4/16 [00:03<00:12,  1.01s/it]
Loading safetensors checkpoint shards:  31% Completed | 5/16 [00:04<00:11,  1.01s/it]
Loading safetensors checkpoint shards:  38% Completed | 6/16 [00:06<00:10,  1.02s/it]
Loading safetensors checkpoint shards:  44% Completed | 7/16 [00:07<00:09,  1.01s/it]
Loading safetensors checkpoint shards:  50% Completed | 8/16 [00:07<00:07,  1.01it/s]
Loading safetensors checkpoint shards:  56% Completed | 9/16 [00:08<00:06,  1.03it/s]
Loading safetensors checkpoint shards:  62% Completed | 10/16 [00:09<00:05,  1.05it/s]
Loading safetensors checkpoint shards:  69% Completed | 11/16 [00:10<00:04,  1.07it/s]
Loading safetensors checkpoint shards:  75% Completed | 12/16 [00:11<00:03,  1.09it/s]
Loading safetensors checkpoint shards:  81% Completed | 13/16 [00:12<00:02,  1.10it/s]
Loading safetensors checkpoint shards:  88% Completed | 14/16 [00:13<00:01,  1.10it/s]
Loading safetensors checkpoint shards:  94% Completed | 15/16 [00:14<00:00,  1.10it/s]
Loading safetensors checkpoint shards: 100% Completed | 16/16 [00:14<00:00,  1.39it/s]
Loading safetensors checkpoint shards: 100% Completed | 16/16 [00:14<00:00,  1.10it/s]
[2025-06-13 14:14:43] Load weight end. type=Qwen3MoeForCausalLM, dtype=torch.bfloat16, avail mem=37.94 GB, mem usage=56.89 GB.
[2025-06-13 14:14:43] KV Cache is allocated. #tokens: 290005, K size: 13.28 GB, V size: 13.28 GB
[2025-06-13 14:14:43] Memory pool end. avail mem=10.69 GB
/opt/conda/lib/python3.10/site-packages/sglang/srt/utils.py:951: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
  tensor_data = torch.ByteTensor(
[2025-06-13 14:14:44] max_total_num_tokens=290005, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=3625, context_len=40960, available_gpu_mem=10.60 GB
[2025-06-13 14:14:44] INFO:     Started server process [454867]
[2025-06-13 14:14:44] INFO:     Waiting for application startup.
[2025-06-13 14:14:44] INFO:     Application startup complete.
[2025-06-13 14:14:44] INFO:     Uvicorn running on http://127.0.0.1:30000 (Press CTRL+C to quit)
[2025-06-13 14:14:45] INFO:     127.0.0.1:54174 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-06-13 14:14:45] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
[2025-06-13 14:15:29] INFO:     127.0.0.1:54182 - "POST /generate HTTP/1.1" 200 OK
[2025-06-13 14:15:29] The server is fired up and ready to roll!
[2025-06-13 14:15:52] Detected chat template content format: string
[2025-06-13 14:15:52] Prefill batch. #new-seq: 1, #new-token: 33, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-06-13 14:17:06] Decode batch. #running-req: 1, #token: 66, token usage: 0.00, cuda graph: False, gen throughput (token/s): 0.28, #queue-req: 0
[2025-06-13 14:17:10] Decode batch. #running-req: 1, #token: 106, token usage: 0.00, cuda graph: False, gen throughput (token/s): 12.18, #queue-req: 0
[2025-06-13 14:17:13] Decode batch. #running-req: 1, #token: 146, token usage: 0.00, cuda graph: False, gen throughput (token/s): 12.36, #queue-req: 0
[2025-06-13 14:17:16] Decode batch. #running-req: 1, #token: 186, token usage: 0.00, cuda graph: False, gen throughput (token/s): 12.26, #queue-req: 0
[2025-06-13 14:17:19] Decode batch. #running-req: 1, #token: 226, token usage: 0.00, cuda graph: False, gen throughput (token/s): 12.27, #queue-req: 0
[2025-06-13 14:17:20] INFO:     127.0.0.1:52022 - "POST /v1/chat/completions HTTP/1.1" 200 OK

The prompt test result:

[root  /home/root/luoyuan.luo] Fri Jun 13 14:15:48
$python test_openai.py 
ChatCompletion(id='5a12942fc41449f7b0871d4432fb396e', choices=[Choice(finish_reason='length', index=0, logprobs=None, message=ChatCompletionMessage(content="<think>\nOkay, the user asked for three countries and their capitals, and then how I rank them. Let me start by picking three countries. Maybe the US, Japan, and Brazil. Their capitals are Washington D.C., Tokyo, and Brasília. Now, how to rank them? The user didn't specify the criteria, so I need to think of possible ways. Maybe by population, economic size, or cultural influence. Let me check the population. The US has around 330 million, Japan about 125 million, and Brazil around 215 million. So by population, US first, Brazil second, Japan third. But if I consider GDP, the US is the largest, then Japan, then Brazil. Alternatively, cultural influence could be different. But since the user didn't specify, I should mention that the ranking depends on the criteria. I should explain that without a specific metric, it's hard to rank, but offer examples based on common factors.", refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=None)], created=1749795352, model='default', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=200, prompt_tokens=33, total_tokens=233, completion_tokens_details=None, prompt_tokens_details=None))

But if tp > 1 or cuda graph is enabled, it hangs at the following step:

$python3 -m sglang.launch_server --model /home/admin/Qwen3-30B-A3B --enable-ep-moe --tp-size 2 --port 30000
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[2].
[2025-06-13 14:29:46] server_args=ServerArgs(model_path='/home/admin/Qwen3-30B-A3B', tokenizer_path='/home/admin/Qwen3-30B-A3B', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, quantization_param_path=None, context_length=None, device='cuda', served_model_name='/home/admin/Qwen3-30B-A3B', chat_template=None, completion_template=None, is_embedding=False, enable_multimodal=None, revision=None, impl='auto', host='127.0.0.1', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', schedule_conservativeness=1.0, cpu_offload_gb=0, page_size=1, tp_size=2, pp_size=1, max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=83323060, constrained_json_whitespace_pattern=None, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, log_level='info', log_level_http=None, log_requests=False, log_requests_level=0, show_time_cost=False, enable_metrics=False, bucket_time_to_first_token=None, bucket_e2e_request_latency=None, bucket_inter_token_latency=None, collect_tokens_histogram=False, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, api_key=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, dp_size=1, load_balance_method='round_robin', ep_size=2, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, lora_paths=None, max_loras_per_batch=8, lora_backend='triton', attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', speculative_algorithm=None, speculative_draft_model_path=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, 
speculative_token_map=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, disable_radix_cache=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_nccl_nvls=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_ep_moe=True, enable_deepep_moe=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=None, cuda_graph_bs=None, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, allow_auto_truncate=False, enable_custom_logit_processor=False, tool_call_parser=None, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through_selective', flashinfer_mla_disable_ragged=False, warmups=None, moe_dense_tp_size=None, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, mm_attention_backend=None, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_bootstrap_port=8998, disaggregation_transfer_backend='mooncake', disaggregation_ib_device=None, 
pdlb_url=None)
Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.
Failed to import ceil_div from deep_gemm.
[2025-06-13 14:29:52 TP0] Attention backend not set. Use fa3 backend by default.
[2025-06-13 14:29:52 TP0] Init torch distributed begin.
[2025-06-13 14:29:53 TP0] sglang is using nccl==2.21.5
[2025-06-13 14:29:54 TP0] Init torch distributed ends. mem usage=0.81 GB
TMA benchmarks will be running with experimental grid constant TMA descriptor.
TMA benchmarks will be running with experimental grid constant TMA descriptor.
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
[2025-06-13 14:29:55 TP0] Load weight begin. avail mem=94.09 GB
Loading safetensors checkpoint shards:   0% Completed | 0/16 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:   6% Completed | 1/16 [00:00<00:07,  1.97it/s]
Loading safetensors checkpoint shards:  12% Completed | 2/16 [00:01<00:07,  1.82it/s]
Loading safetensors checkpoint shards:  19% Completed | 3/16 [00:01<00:06,  1.89it/s]
Loading safetensors checkpoint shards:  25% Completed | 4/16 [00:02<00:06,  1.91it/s]
Loading safetensors checkpoint shards:  31% Completed | 5/16 [00:02<00:05,  1.92it/s]
Loading safetensors checkpoint shards:  38% Completed | 6/16 [00:03<00:05,  1.92it/s]
Loading safetensors checkpoint shards:  44% Completed | 7/16 [00:03<00:04,  1.99it/s]
Loading safetensors checkpoint shards:  50% Completed | 8/16 [00:04<00:04,  1.99it/s]
Loading safetensors checkpoint shards:  56% Completed | 9/16 [00:04<00:03,  2.03it/s]
Loading safetensors checkpoint shards:  62% Completed | 10/16 [00:05<00:02,  2.07it/s]
Loading safetensors checkpoint shards:  69% Completed | 11/16 [00:05<00:02,  2.15it/s]
Loading safetensors checkpoint shards:  75% Completed | 12/16 [00:05<00:01,  2.20it/s]
Loading safetensors checkpoint shards:  81% Completed | 13/16 [00:06<00:01,  2.16it/s]
Loading safetensors checkpoint shards:  88% Completed | 14/16 [00:06<00:00,  2.12it/s]
Loading safetensors checkpoint shards:  94% Completed | 15/16 [00:07<00:00,  2.11it/s]
Loading safetensors checkpoint shards: 100% Completed | 16/16 [00:07<00:00,  2.15it/s]
[2025-06-13 14:30:02 TP0] Load weight end. type=Qwen3MoeForCausalLM, dtype=torch.bfloat16, avail mem=65.52 GB, mem usage=28.57 GB.
[2025-06-13 14:30:02 TP1] KV Cache is allocated. #tokens: 1184668, K size: 27.11 GB, V size: 27.11 GB
[2025-06-13 14:30:02 TP0] KV Cache is allocated. #tokens: 1184668, K size: 27.11 GB, V size: 27.11 GB
[2025-06-13 14:30:02 TP0] Memory pool end. avail mem=10.52 GB
[2025-06-13 14:30:03 TP0] Capture cuda graph begin. This can take up to several minutes. avail mem=10.43 GB
[2025-06-13 14:30:03 TP0] Capture cuda graph bs [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160]
Capturing batches (avail_mem=10.35 GB):   0%|                                                                                                                                                                   | 0/23 [00:00<?, ?it/s]<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
[2025-06-13 14:30:05 TP1] Registering 0 cuda graph addresses

@yuan-luo
Collaborator Author

The culprit is c.shape, but the underlying cause is in the FBGEMM input preparation, specifically the handling of K.

[2025-06-13 23:57:35] INFO:     127.0.0.1:34514 - "GET /get_model_info HTTP/1.1" 200 OK
[2025-06-13 23:57:35] Prefill batch. #new-seq: 1, #new-token: 6, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
<frozen importlib._bootstrap>:914: ImportWarning: _SixMetaPathImporter.find_spec() not found; falling back to find_module()
[DEBUG] Before: a.shape: torch.Size([48, 2048]), b.shape: torch.Size([128, 1536, 2048])
[DEBUG][BEFORE PREPARE] a id=140114453314704, data_ptr=140026702876672, shape=torch.Size([48, 2048]), is_contig=True
[DEBUG] After: a.shape: torch.Size([48, 2048]), b.shape: torch.Size([128, 1536, 2048])
[DEBUG][AFTER PREPARE] a id=140114453314704, data_ptr=140026702876672, shape=torch.Size([48, 2048]), is_contig=True
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
[DEBUG] L0 output c.shape: torch.Size([48, 420]) <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< culprit
[DEBUG] Before: a.shape: torch.Size([48, 210]), b.shape: torch.Size([128, 2048, 768])
[DEBUG][BEFORE PREPARE] a id=140114453314704, data_ptr=140026702803456, shape=torch.Size([48, 210]), is_contig=True
[2025-06-13 23:57:38] TpModelWorkerClient hit an exception: Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tp_worker_overlap_thread.py", line 118, in forward_thread_func
    self.forward_thread_func_()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tp_worker_overlap_thread.py", line 151, in forward_thread_func_
    self.worker.forward_batch_generation(

@yuan-luo
Collaborator Author

For a Qwen3-MoE model, the output shape of each layer should look like the following:

[DEBUG] L0 output c.shape: torch.Size([48, 1536])
[DEBUG] L0 output c.shape: torch.Size([48, 2048])
[DEBUG] L0 output c.shape: torch.Size([48, 1536])
[DEBUG] L0 output c.shape: torch.Size([48, 2048])
[DEBUG] L0 output c.shape: torch.Size([48, 1536])
[DEBUG] L0 output c.shape: torch.Size([48, 2048])
[DEBUG] L0 output c.shape: torch.Size([48, 1536])
[DEBUG] L0 output c.shape: torch.Size([48, 2048])
......
[DEBUG] Before: a.shape: torch.Size([8, 2048]), b.shape: torch.Size([128, 1536, 2048])
[DEBUG] L0 output c.shape: torch.Size([8, 1536])
[DEBUG] Before: a.shape: torch.Size([8, 768]), b.shape: torch.Size([128, 2048, 768])
[DEBUG] L0 output c.shape: torch.Size([8, 2048])
[DEBUG] Before: a.shape: torch.Size([8, 2048]), b.shape: torch.Size([128, 1536, 2048])
[DEBUG] L0 output c.shape: torch.Size([8, 1536])
[DEBUG] Before: a.shape: torch.Size([8, 768]), b.shape: torch.Size([128, 2048, 768])
[DEBUG] L0 output c.shape: torch.Size([8, 2048])

But in my case a.shape[-1] is 210, which is incorrect: it is half of the previous layer's wrong c.shape[-1] of 420 (the gated activation halves the last dimension), so the real problem is the previous layer's incorrect c.shape.

Still investigating the final root cause.

@jwfromm

jwfromm commented Jun 13, 2025

Let us know how we can help, happy to add support if needed. I think the 0-sized groups you've described should be supported. We've tested similar workloads and actually have a few cool optimizations we do for them in the cutlass version of the kernel. If they still are causing problems after your debug the shape issue we can take a look.

@yuan-luo
Collaborator Author

Let us know how we can help, happy to add support if needed. I think the 0-sized groups you've described should be supported. We've tested similar workloads and actually have a few cool optimizations we do for them in the cutlass version of the kernel. If they still are causing problems after your debug the shape issue we can take a look.

Thanks @jwfromm. Per the current investigation, the actual root cause is that the handling of m_sizes and b in prepare_fbgemm_inputs disrupts the expected tensor layout across multiple forward layers, especially in scenarios involving a shared expert pool (such as Qwen3, where multiple MoE layers share b).

The Triton version does not encounter this issue because it performs no reshaping or filtering of invalid experts. Instead, it uses seg_indptr and weight_indices internally to mask or skip invalid experts. Expert selection is determined dynamically at runtime, so there is no need to reshape b in advance; specifically, it uses compute_m_range() to map the token index range to the expert id.
As a result, the shared expert weights keep a consistent structure, can be reused across layers, and the forward output shapes remain consistent.

Some more details about my design. I introduced prepare_fbgemm_inputs because the original FBGEMM input preparation cannot pass with tp > 1 and cuda graph enabled, so I pad the filtered m_sizes back to a fixed length:

        if use_cuda_graph:
            max_groups = seg_indptr.numel() - 1  # fallback to original G
            pad_len = max_groups - m_sizes.shape[0]
            if pad_len > 0:
                m_sizes = torch.cat(
                    [m_sizes, torch.zeros(pad_len, dtype=m_sizes.dtype, device=device)]
                )
                if scale_b is not None:
                    scale_b = torch.cat(
                        [
                            scale_b,
                            torch.ones(pad_len, dtype=scale_b.dtype, device=device),
                        ]
                    )
        return b_fbgemm, m_sizes, scale_b

But I believe the m_sizes calculation breaks the original b, which causes the problem.

        if weight_indices is not None:
            weight_indices = weight_indices.to(torch.int64)
            m_sizes = m_sizes[weight_indices]
            b = b.index_select(0, weight_indices)
            if scale_b is not None:
                scale_b = scale_b.index_select(0, weight_indices)

I think FBGEMM supports m_sizes[i] == 0; otherwise the following change (my first version) would not work with tp=1 and cuda graph disabled:

        elif self.use_fbgemm:
            m_sizes = seg_indptr[1:] - seg_indptr[:-1]
            non_zero_mask = m_sizes > 0

            filtered_b = b[non_zero_mask]
            m_sizes = m_sizes[non_zero_mask]

            assert seg_indptr is not None, "FBGemm needs seg_indptr"
            if use_fp8_w8a8:
                ......
            else:
                c = fbgemm_grouped_gemm(
                    x=a.to(torch.bfloat16),
                    w=filtered_b.reshape(-1, filtered_b.shape[2]).contiguous(),
                    m_sizes=m_sizes,
                    use_fast_accum=True,
                )
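The zero-group filtering above can be illustrated standalone. This is a minimal sketch of the same mask-and-filter pattern, not the PR's exact code; the tensor sizes are made up for the example:

```python
import torch

seg_indptr = torch.tensor([0, 2, 2, 5, 5])   # 4 experts; experts 1 and 3 got no tokens
b = torch.randn(4, 6, 8)                     # [G, N, K] expert weights

m_sizes = seg_indptr[1:] - seg_indptr[:-1]   # tokens per expert: [2, 0, 3, 0]
non_zero_mask = m_sizes > 0

filtered_b = b[non_zero_mask]                # keep only experts with tokens
filtered_m_sizes = m_sizes[non_zero_mask]

# Only the 2 active experts remain, and their token counts are preserved.
assert filtered_b.shape == (2, 6, 8)
assert filtered_m_sizes.tolist() == [2, 3]
```

Note that this filtering is exactly what changes the group-to-weight alignment across layers when b is shared, which is the issue discussed above.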

@yuan-luo
Collaborator Author

yuan-luo commented Jun 14, 2025

The current solution does not work for TP > 1, because the SGLang Triton kernel and the FBGEMM kernel work differently.
In the current SGLang Triton kernel (adapted from TensorRT-LLM), the shapes are as follows:
a.shape (activation): [M, K]
b.shape (weight): [G, N, K]
c.shape (output of each layer): [M, N]
compute_m_range leverages seg_indptr and weight_indices to map each tile-id to its expert-id and to m_range_start/end (the token index range).

  1. When an expert has no tokens in this layer, the kernel exits early.
  2. Offsets outside the valid range are masked out:
    if m_range_end - m_range_start == 0:
        return
    ...
    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
    ...
    accumulator += tl.dot(a_tile, b_tile.T)
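The mapping that compute_m_range performs can be sketched host-side. This is an illustrative Python analogue (the function name `expert_ranges` and the exact return format are hypothetical, not the kernel's API); it shows how seg_indptr and weight_indices together give each non-empty group its weight row and token range:

```python
import torch

def expert_ranges(seg_indptr: torch.Tensor, weight_indices: torch.Tensor):
    """Map each group to (weight_row, m_start, m_end), skipping empty groups.

    Illustrative host-side analogue of the per-tile logic in compute_m_range;
    the real kernel does this arithmetic on the fly per tile-id.
    """
    ranges = []
    for g in range(seg_indptr.numel() - 1):
        m_start = int(seg_indptr[g])
        m_end = int(seg_indptr[g + 1])
        if m_end - m_start == 0:  # expert received no tokens: fast exit
            continue
        ranges.append((int(weight_indices[g]), m_start, m_end))
    return ranges

seg_indptr = torch.tensor([0, 1, 1, 3, 3])   # 4 groups; groups 1 and 3 are empty
weight_indices = torch.tensor([0, 1, 2, 3])
ranges = expert_ranges(seg_indptr, weight_indices)
assert ranges == [(0, 0, 1), (2, 1, 3)]
```

Because empty groups are simply skipped at runtime, b never needs to be reshaped or filtered in advance.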

With TP >= 2 and --enable-ep-moe (EP is set equal to TP, so here EP=2), the SGLang Triton kernel's shapes are:
a.shape (activation): [M, K]
b.shape (weight): [G/tp_size, N, K]
c.shape (output of each layer): [M, N]
The core idea is that each tile-id computes its own share correctly, so each rank only needs to hold its shard of the G experts' activations and weights. The framework uses torch.distributed.all_reduce to accumulate each rank's K_local into K_full and forwards the result to the next layer.
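The accumulation step can be simulated single-process. This sketch stands in for the real torch.distributed.all_reduce(SUM) across ranks (the token-to-rank routing here is made up for illustration): each rank's partial output has zeros for tokens whose experts live on the other rank, so the sum recovers the full output:

```python
import torch

torch.manual_seed(0)
M, K = 4, 8
# Illustrative routing: rank 0 owns the experts for tokens 0 and 2,
# rank 1 owns the experts for tokens 1 and 3.
mask0 = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
partial_rank0 = torch.randn(M, K) * mask0          # zeros for rank 1's tokens
partial_rank1 = torch.randn(M, K) * (1.0 - mask0)  # zeros for rank 0's tokens

# In the real TP path this sum is torch.distributed.all_reduce(SUM).
full_output = partial_rank0 + partial_rank1

# Each token's row comes entirely from the rank that owned its expert.
assert torch.equal(full_output[0], partial_rank0[0])
assert torch.equal(full_output[1], partial_rank1[1])
```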

The FBGEMM kernel works differently.
FBGEMM does not support dynamically fetching an expert-id's weight; it has no mapping via weight_indices and seg_indptr. Expert ids are determined solely by m_sizes. FBGEMM flattens the b tensor into a [G * N, K] matrix, and it requires the per-expert token groups to be aligned with the flattened weight tensor. The following code in _grouped_gemm() shows the logic:

    G = m_sizes.shape[0]

    M, K = x.shape
    N = w.shape[0] // G
    assert K == w.shape[1]
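A reference (pure-PyTorch, non-Triton) sketch of that shape contract may make the alignment requirement concrete. This is not FBGEMM's implementation, only a loop that obeys the same contract: w is flattened to [G*N, K], and m_sizes decides which consecutive rows of x each expert's N-row weight block multiplies:

```python
import torch

def grouped_gemm_ref(x, w, m_sizes):
    """Reference grouped GEMM matching FBGEMM's shape contract:
    x: [M, K], w: [G*N, K] (flattened), m_sizes: [G] with sum(m_sizes) == M.
    Returns c: [M, N]. Group order in m_sizes must match the flattened w."""
    G = m_sizes.numel()
    M, K = x.shape
    N = w.shape[0] // G
    assert K == w.shape[1] and int(m_sizes.sum()) == M
    c = torch.empty(M, N, dtype=x.dtype)
    row = 0
    for g in range(G):
        m = int(m_sizes[g])
        if m == 0:  # zero-sized groups consume no rows of x
            continue
        c[row:row + m] = x[row:row + m] @ w[g * N:(g + 1) * N].T
        row += m
    return c

x = torch.randn(5, 8)
w = torch.randn(3 * 4, 8)          # G=3 experts, N=4, K=8
m_sizes = torch.tensor([2, 0, 3])  # expert 1 receives no tokens
c = grouped_gemm_ref(x, w, m_sizes)
assert c.shape == (5, 4)
```

This is why filtering or reordering m_sizes without reordering x and the flattened w in lockstep silently pairs tokens with the wrong expert's weights.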

I use the following approach to calculate the m_sizes that FBGEMM requires:

m_sizes = seg_indptr[1:] - seg_indptr[:-1]
[DEBUG] m_sizes: tensor([1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 2, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1, 1, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 2, 0, 0, 0, 3, 1, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 0, 0, 2, 1, 2, 0, 2, 0, 0, 0, 0, 0,
        2, 0, 1, 0, 1, 0, 0, 0], device='cuda:0')

But with TP >= 2 and weight_column_major=True, the weight tensor needs to be reordered:

if weight_indices is not None:
    weight_order = weight_indices.to(torch.int64)
    b = b.index_select(0, weight_order)  # [G', N, K]

Still striving to fix it. Any comments are welcome.

@yuan-luo
Collaborator Author

Made some progress; TP2 EP2 now fails in empty-tensor handling.

  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tp_worker_overlap_thread.py", line 118, in forward_thread_func
    self.forward_thread_func_()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tp_worker_overlap_thread.py", line 151, in forward_thread_func_
    self.worker.forward_batch_generation(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 202, in forward_batch_generation
    logits_output, can_run_cuda_graph = self.model_runner.forward(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 1194, in forward
    output = self._forward_raw(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 1221, in _forward_raw
    ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 1142, in forward_decode
    return self.model.forward(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_moe.py", line 701, in forward
    hidden_states = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen2_moe.py", line 464, in forward
    hidden_states, residual = layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_moe.py", line 584, in forward
    hidden_states = self.mlp(hidden_states, forward_batch)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_moe.py", line 157, in forward
    return self.forward_normal(hidden_states)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_moe.py", line 174, in forward_normal
    final_hidden_states = self.experts(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/layer.py", line 470, in forward
    gateup_output.shape[0],
AttributeError: 'NoneType' object has no attribute 'shape'
 438         # GroupGemm-0
 439         gateup_output = self.grouped_gemm_runner(
 440             a=gateup_input,
 441             b=self.w13_weight,
 442             c=None,
 443             c_dtype=hidden_states_dtype,
 444             batch_size=self.num_experts_per_partition,
 445             weight_column_major=True,
 446             seg_indptr=seg_indptr_cur_rank,
 447             weight_indices=weight_indices_cur_rank,
 448             use_fp8_w8a8=self.use_fp8_w8a8,
 449             scale_a=self.w13_input_scale,
 450             scale_b=(
 451                 self.w13_weight_scale_inv
 452                 if self.use_block_quant
 453                 else self.w13_weight_scale
 454             ),
 455             block_shape=self.block_shape,
 456         )
 457         del gateup_input
 458 
 459         # Act
 460         if self.activation_scheme == "dynamic" and not self.use_block_quant:
 461             self.w2_input_scale = None
 462             down_input = torch.empty(
 463                 gateup_output.shape[0],
 464                 gateup_output.shape[1] // 2,
 465                 device=gateup_output.device,
 466                 dtype=hidden_states_dtype,
 467             )
 468         else:
 469             down_input = torch.empty(
 470                 gateup_output.shape[0],
 471                 gateup_output.shape[1] // 2,
 472                 device=gateup_output.device,
 473                 dtype=(
 474                     self.fp8_dtype
 475                     if (self.use_fp8_w8a8 and not self.use_block_quant)
 476                     else hidden_states_dtype
 477                 ),
 478             )

@yuan-luo
Collaborator Author

The new prepare function adds expert-index reshuffling logic. I will push the new prepare function once TP2 is fully supported.

@yuan-luo
Collaborator Author

Update on progress: I found the root cause of this issue. The "a" tensor's base pointer must be updated under TP2.
The Triton kernel doesn't rely on the "a" tensor's base pointer, because it computes the base address on demand.
But FBGEMM's m_sizes only reflects the "b" tensor's partition, so the "a" tensor's base pointer needs to be adjusted according to the TP partition.
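To make the pointer adjustment concrete, here is a minimal pure-Python sketch (fbgemm_sizes_from_seg_indptr is an illustrative name, not an actual SGLang or FBGEMM helper): the Triton path consumes cumulative row offsets in seg_indptr, while FBGEMM wants per-expert row counts (m_sizes) plus the base row at which this rank's slice of the "a" tensor starts.

```python
def fbgemm_sizes_from_seg_indptr(seg_indptr):
    """Convert cumulative segment offsets (Triton-style) into the
    per-expert row counts (m_sizes) that FBGEMM's grouped GEMM expects,
    plus the base row offset used to re-base the activation tensor."""
    base = seg_indptr[0]  # first activation row owned by this TP rank
    m_sizes = [seg_indptr[i + 1] - seg_indptr[i]
               for i in range(len(seg_indptr) - 1)]
    return base, m_sizes

# A rank whose experts own global activation rows [4, 10):
base, m_sizes = fbgemm_sizes_from_seg_indptr([4, 7, 7, 10])
# base == 4, m_sizes == [3, 0, 3]; the kernel then reads a[base:]
```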

Now:
tp1 disable-cuda-graph no crash, result correct, passed.
tp1 enable-cuda-graph no crash, result correct, passed.

tp2 disable-cuda-graph no crash, result correct, passed.
tp2 enable-cuda-graph crashed, result n/a, failed.

With the fix, the TP2 result (disable-cuda-graph) is correct now.

[root  /home/root/luoyuan.luo] Tue Jun 17 16:22:06
$python test_openai.py
ChatCompletion(id='29aa0956d4784f83b2fdd9908895ad17', choices=[Choice(finish_reason='length', index=0, logprobs=None, message=ChatCompletionMessage(content="<think>\nOkay, the user asked for three countries and their capitals, and then how I rank them. Let me start by picking three countries. Maybe the US, Japan, and Brazil. Their capitals are Washington, D.C., Tokyo, and Brasília. Now, how to rank them? The user didn't specify the criteria, so I need to think of possible ways. Maybe by population, economic size, or cultural influence. Let me check the population. The US has around 33", refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=None)], created=1750149604, model='default', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=100, prompt_tokens=33, total_tokens=133, completion_tokens_details=None, prompt_tokens_details=None))

The remaining issue is that
base = (seg_indptr[0]).to(torch.int)
a_aligned = a[base:]
cannot pass CUDA graph capture, since data-dependent slicing is not allowed while a graph is being captured. Working on it.
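For background on why the a[base:] slice breaks capture: CUDA graph capture forbids host-side work such as fresh allocations and data-dependent reshaping inside the captured region. The pure-Python stand-in below (illustrative only, not the PR's actual Triton-based fix) shows the usual graph-safe pattern: allocate a max-sized buffer once outside capture and vary only a row count at replay.

```python
# Pure-Python stand-in for torch buffers; MAX_TOKENS, HIDDEN, and
# replay_step are illustrative names, not SGLang code.
MAX_TOKENS, HIDDEN = 8, 4

# One-time allocation, done outside the captured region.
down_input = [[0.0] * HIDDEN for _ in range(MAX_TOKENS)]

def replay_step(gateup_rows):
    """Each replay writes into the fixed buffer; only the row count
    varies, so no allocation or reshape happens during 'capture'."""
    n = len(gateup_rows)
    assert n <= MAX_TOKENS, "graph was captured for at most MAX_TOKENS rows"
    for i in range(n):
        down_input[i] = gateup_rows[i][:HIDDEN]
    return n  # real kernels would read n from a device-side tensor

used = replay_step([[1.0] * HIDDEN, [2.0] * HIDDEN])
# used == 2; rows [0, 2) of down_input now hold the new data
```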

@yuan-luo
Collaborator Author

[2025-06-17 18:46:15 TP0] Registering 2 cuda graph addresses
[2025-06-17 18:46:15 TP1] Registering 2 cuda graph addresses
[2025-06-17 18:46:15 TP1] Scheduler hit an exception: Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 505, in capture_one_batch_size
    out = run_once()
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 489, in run_once
    logits_output_or_pp_proxy_tensors = forward(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_moe.py", line 701, in forward
    hidden_states = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen2_moe.py", line 464, in forward
    hidden_states, residual = layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_moe.py", line 584, in forward
    hidden_states = self.mlp(hidden_states, forward_batch)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_moe.py", line 157, in forward
    return self.forward_normal(hidden_states)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_moe.py", line 174, in forward_normal
    final_hidden_states = self.experts(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/layer.py", line 490, in forward
    gateup_output = self.grouped_gemm_runner(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/layer.py", line 182, in forward
    a_aligned, b_fbgemm, m_sizes, base, scale_b = prepare_fbgemm_inputs_ep(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/layer.py", line 100, in prepare_fbgemm_inputs_ep
    a_aligned = make_a_aligned(a, base_tensor)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/kernels.py", line 1131, in make_a_aligned
    out = torch.empty((M - base_val, K),
RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

@yuan-luo
Collaborator Author

yuan-luo commented Jun 17, 2025

I added several Triton kernels to fix the CUDA graph capture issues introduced by the "a" tensor shape change.
Now the problem is that when m_sizes.sum() equals 0, CUDA graph capture hangs in FBGEMM.
I believe there is a bug in FBGEMM on this part, and fixing it is out of this PR's scope, so perhaps the FBGEMM team could offer some help.
@jwfromm @xiaobochen123 @BBuf @ispobock @Alcanderian @zhyncs

@yuan-luo
Collaborator Author


The latest issue with "a" tensor slicing has been fixed.

@yuan-luo yuan-luo changed the title WIP: Integrate fbgemm into EPMoE Integrate fbgemm into EPMoE Jun 17, 2025
@yuan-luo yuan-luo changed the title Integrate fbgemm into EPMoE Feat: Integrate FBGEMM into EPMoE Jun 17, 2025
@BBuf
Collaborator

BBuf commented Jun 17, 2025

Now the problem is when m_sizes.sum() equals to 0, the CUDA graph capture will hang in FBGEMM. I believe there is a bug in FBGEMM on this part, and fixing this bug is out of this PR's scope.

@jianyuh Hi, could you please provide help about the bug, thanks!

@yuan-luo
Collaborator Author

The CUDA graph capture problem has been fixed without modifying FBGEMM kernel.

tp1 disable-cuda-graph no crash, result correct, passed.
tp1 enable-cuda-graph no crash, result correct, passed.

tp2 disable-cuda-graph no crash, result correct, passed.
tp2 enable-cuda-graph no crash, result correct, passed.

tp4 disable-cuda-graph no crash, result correct, passed.
tp4 enable-cuda-graph no crash, result correct, passed.

tp8 disable-cuda-graph no crash, result correct, passed.
tp8 enable-cuda-graph no crash, result correct, passed.

@yuan-luo
Collaborator Author


@jianyuh The FBGEMM kernel is stable even when m_sizes.sum() == 0. I modified the host-side logic and introduced some Triton kernels to handle the "TiledCopy" for the C tensor. The problem is now resolved.
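As an aside, a minimal host-side guard for the all-empty case could look like the sketch below (purely illustrative function names; the actual fix in this PR instead handles the C-tensor "TiledCopy" with Triton kernels):

```python
def launch_grouped_gemm(m_sizes, run_gemm, make_zero_output):
    """Skip the kernel entirely when no expert on this rank received
    tokens, returning a zeroed output instead of handing the grouped
    GEMM an all-zero m_sizes."""
    if sum(m_sizes) == 0:
        return make_zero_output()
    return run_gemm()

# With no routed tokens the GEMM is never launched:
out = launch_grouped_gemm([0, 0, 0],
                          run_gemm=lambda: "gemm ran",
                          make_zero_output=lambda: "zeros")
# out == "zeros"
```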

@jianyuh

jianyuh commented Jun 18, 2025

@yuan-luo Thanks for the workaround for m_sizes.sum() == 0 vs. cuda graph compatibility issues! Will follow up in FBGEMM side. cc @levendlee

@yuan-luo
Collaborator Author

Benchmarked SGLang Triton against FBGEMM end-to-end in SGLang.
The FBGEMM TTFT is far worse than expected:

Mean TTFT (ms):                          4298.99   
Median TTFT (ms):                        4400.08   
P99 TTFT (ms):                           7056.42  

Many warnings are printed during the FBGEMM benchmark; they appear to be related to FBGEMM's JIT compilation and auto-tuning, which may be what makes the TTFT slow.

$python3 -m sglang.launch_server --model /home/admin/Qwen3-30B-A3B --enable-ep-moe --tp-size 2 --port 30000 --use-fbgemm-grouped-gemm
......
/opt/conda/lib/python3.10/site-packages/sglang/srt/layers/moe/ep_moe/fbgemm_grouped_gemm.py:1095: UserWarning: Warp specialization is disabled as the Triton build in current environment doesn't have such support. Please build from https://github.com/facebookexperimental/triton/tree/ws-3.2.x to enable it for best performance on Nvidia's SM90 GPUs.
  warnings.warn(
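One common mitigation, sketched here under the assumption that the TTFT cost comes from first-call autotuning (all names below are illustrative, standing in for Triton's per-shape tuning cache and the grouped GEMM entry point), is to warm up the kernel over the expected batch-size buckets at server startup so no user request pays the tuning cost:

```python
# Illustrative sketch only: _tuned_shapes, run_grouped_gemm, and
# warmup_buckets are hypothetical names, not SGLang or FBGEMM APIs.
_tuned_shapes = set()

def run_grouped_gemm(num_tokens):
    """Stand-in kernel: the first call for a given shape pays the
    (expensive) autotune cost; later calls hit the tuning cache."""
    first_time = num_tokens not in _tuned_shapes
    _tuned_shapes.add(num_tokens)
    return first_time  # True means this call paid the autotune cost

def warmup_buckets(buckets):
    # Absorb the autotune cost for each expected batch-size bucket
    # at server startup, before any user request arrives.
    for m in buckets:
        run_grouped_gemm(m)

warmup_buckets([1, 8, 64, 256])
cold = run_grouped_gemm(64)  # cold == False: no autotune on request path
```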

Triton:

============ Serving Benchmark Result ============
Backend:                                 sglang-oai
Traffic request rate:                    10.0      
Max request concurrency:                 not set   
Successful requests:                     100       
Benchmark duration (s):                  172.34    
Total input tokens:                      25600     
Total generated tokens:                  409600    
Total generated tokens (retokenized):    409523    
Request throughput (req/s):              0.58      
Input token throughput (tok/s):          148.54    
Output token throughput (tok/s):         2376.66   
Total token throughput (tok/s):          2525.20   
Concurrency:                             95.55     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   164665.65 
Median E2E Latency (ms):                 164919.86 
---------------Time to First Token----------------
Mean TTFT (ms):                          130.26    
Median TTFT (ms):                        129.64    
P99 TTFT (ms):                           322.87    
---------------Inter-Token Latency----------------
Mean ITL (ms):                           40.25     
Median ITL (ms):                         40.10     
P95 ITL (ms):                            49.85     
P99 ITL (ms):                            63.89     
Max ITL (ms):                            425.87    
==================================================

FBGEMM:

============ Serving Benchmark Result ============
Backend:                                 sglang-oai
Traffic request rate:                    10.0      
Max request concurrency:                 not set   
Successful requests:                     100       
Benchmark duration (s):                  201.20    
Total input tokens:                      25600     
Total generated tokens:                  409600    
Total generated tokens (retokenized):    409525    
Request throughput (req/s):              0.50      
Input token throughput (tok/s):          127.24    
Output token throughput (tok/s):         2035.81   
Total token throughput (tok/s):          2163.05   
Concurrency:                             97.79     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   196759.36 
Median E2E Latency (ms):                 196466.76 
---------------Time to First Token----------------
Mean TTFT (ms):                          4298.99   
Median TTFT (ms):                        4400.08   
P99 TTFT (ms):                           7056.42   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           47.02     
Median ITL (ms):                         45.43     
P95 ITL (ms):                            54.82     
P99 ITL (ms):                            56.32     
Max ITL (ms):                            15382.14  
==================================================

@jianyuh @jwfromm Is there any configuration wrong in the steps?

@yuan-luo
Collaborator Author

Profiling shows scatter_row_kernel at 15.88%, slice_row_kernel at 12%, and fbgemm_grouped_gemm at 12.9%.

Trying to tune the kernels.

@zhyncs
Collaborator

zhyncs commented Oct 13, 2025

@yuan-luo Quick question: do we still need FBGEMM to support the EP MoE implementation by @ch-wan using DeepGEMM?

@yuan-luo
Collaborator Author

@zhyncs The current EP MoE (CUDA) supports the deepgemm_contiguous, deepgemm_mask, and flashinfer_cutedsl types. FBGEMM can be another alternative; I'll check whether it is feasible to fit it into the new architecture.
