fuse allreduce and residual_rmsnorm#8731

Merged
zhyncs merged 23 commits into main from cache_fuse_allreduce_residual_rmsnorm_judge on Aug 11, 2025

Conversation

@BBuf
Collaborator

@BBuf BBuf commented Aug 3, 2025

Motivation

  • Cache the result of the fuse allreduce_residual_rmsnorm eligibility check.
  • Use a clearer name for the fused allreduce + residual_rmsnorm path.
  • Fix a typo in server_args.
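For context, the pattern being fused is the residual add followed by RMSNorm that runs after each tensor-parallel all_reduce. A minimal unfused reference of that pattern, in plain PyTorch (function name and shapes are illustrative, not sglang's actual code):

```python
import torch

def residual_rmsnorm(hidden, residual, weight, eps=1e-6):
    # Unfused reference for the pattern this PR fuses with the
    # preceding tensor-parallel all_reduce: residual add + RMSNorm.
    hidden = hidden + residual
    new_residual = hidden  # fed into the next layer's residual stream
    variance = hidden.pow(2).mean(-1, keepdim=True)
    hidden = hidden * torch.rsqrt(variance + eps) * weight
    return hidden, new_residual
```

In the fused path, the all_reduce, the residual add, and the normalization are performed in one kernel, avoiding extra global-memory round trips between the three steps.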

Acc

python3 -m sglang.launch_server --model-path /dev/shm/DeepSeek-R1-0528-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 8 --enable-flashinfer-allreduce-fusion --attention-backend cutlass_mla

➜  sglang git:(cache_fuse_allreduce_residual_rmsnorm_judge) ✗ python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:51<00:00, 25.67it/s]
Accuracy: 0.958
Invalid: 0.000
Latency: 51.636 s
Output throughput: 2783.583 token/s

Performance

 python3 -m sglang.launch_server --model-path /dev/shm/DeepSeek-R1-0528-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 8 --enable-flashinfer-allreduce-fusion --attention-backend cutlass_mla

curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 5 --max-concurrency 1 --warmup-requests 5 --output-file ds_v3_fp4_pr.jsonl
curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 20 --max-concurrency 4 --warmup-requests 5 --output-file ds_v3_fp4_pr.jsonl
curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 80 --max-concurrency 16 --warmup-requests 5 --output-file ds_v3_fp4_pr.jsonl
curl http://127.0.0.1:30000/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 1000 --random-output-len 1000 --random-range-ratio 1 --num-prompts 160 --max-concurrency 32 --warmup-requests 5 --output-file ds_v3_fp4_pr.jsonl
[image]


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @BBuf, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the fusion of allreduce and residual RMSNorm operations within the SGLang framework. My primary goal is to reduce computational overhead by caching the results of checks that determine if this fusion is possible. This change also includes renaming a parameter for better clarity and fixing a minor typo in the server arguments.

Highlights

  • Improved Fusion Logic: I've refactored the logic for determining whether to fuse allreduce and residual RMSNorm operations. The previous can_fuse_mlp_allreduce parameter has been renamed to should_fuse_allreduce_residual_rmsnorm for clearer intent.
  • Performance Optimization via Caching: To reduce redundant computations, I've implemented a caching mechanism for the _should_fuse_mlp_allreduce_with_next_layer check. This cache stores results based on layer ID and batch size, significantly reducing overhead for repeated checks.
  • Typo Correction: I've fixed a minor typo in the help message for the --enable-flashinfer-allreduce-fusion command-line argument, changing "Add_RMSNorm" to "Residual RMSNorm" for accuracy.
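The caching idea in the second highlight can be sketched as a memo table keyed by (layer_id, batch_size). Class and method names here are hypothetical, not the PR's actual identifiers:

```python
class FusionDecisionCache:
    """Memoizes a fusion-eligibility predicate keyed by (layer_id, batch_size).

    Sketch of the caching idea only; the real check in sglang also
    considers runtime state beyond these two keys.
    """

    def __init__(self, check_fn, max_entries=1024):
        self._check_fn = check_fn
        self._cache = {}
        self._max_entries = max_entries

    def should_fuse(self, layer_id, batch_size):
        key = (layer_id, batch_size)
        if key not in self._cache:
            if len(self._cache) >= self._max_entries:
                # Evict the oldest entry so the cache stays bounded.
                self._cache.pop(next(iter(self._cache)))
            self._cache[key] = self._check_fn(layer_id, batch_size)
        return self._cache[key]
```

Because the key space is small (layer count × observed batch sizes), repeated forward passes hit the cache and skip the predicate entirely.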

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a caching mechanism to optimize the check for fusing allreduce with residual RMSNorm, which should improve performance. It also improves code clarity by renaming can_fuse_mlp_allreduce to a more descriptive should_fuse_allreduce_residual_rmsnorm across several files. While the caching logic is a good idea, I've found a couple of issues with its implementation: redundant caching and a potential memory leak due to misplaced eviction logic. I've provided suggestions to centralize the caching logic and fix these issues. The other changes, including the variable renames and a typo fix in server arguments, look good.
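One way to address both reviewer points (centralize the caching and bound its memory) is a single `functools.lru_cache`-wrapped predicate; the predicate body below is a stand-in for illustration, not sglang's actual logic:

```python
from functools import lru_cache

@lru_cache(maxsize=4096)
def should_fuse_allreduce_residual_rmsnorm(layer_id: int, batch_size: int) -> bool:
    # Stand-in predicate: the real check inspects layer position,
    # batch size, and backend support. lru_cache provides centralized
    # memoization with built-in LRU eviction, so no hand-rolled
    # eviction logic (and no unbounded growth) is needed.
    return 0 < batch_size <= 64
```

`should_fuse_allreduce_residual_rmsnorm.cache_info()` exposes hit/miss counts, which makes it easy to verify the cache is actually being used.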

@ispobock
Collaborator

ispobock commented Aug 3, 2025

cc: @merrymercy

@BBuf BBuf requested a review from Edwardf0t1 as a code owner August 8, 2025 14:07
@BBuf
Collaborator Author

BBuf commented Aug 11, 2025

For gpt-oss-20b, latency drops from 9.5 us to 6 us for every allreduce + add_rms_norm pattern.
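The per-pattern numbers above come from GPU profiling. As a rough illustration of measuring the unfused pattern's cost, here is a minimal wall-clock harness (helper and tensor sizes are arbitrary assumptions, not how the PR measured):

```python
import time
import torch

def avg_latency(fn, iters=200):
    # Coarse wall-clock timing; the kernel-level numbers above should
    # come from torch.profiler or nsys, not a helper like this.
    fn()  # warmup
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters

x = torch.randn(8, 1024)
res = torch.randn(8, 1024)
w = torch.ones(1024)

def unfused_pattern():
    h = x + res
    v = h.pow(2).mean(-1, keepdim=True)
    return h * torch.rsqrt(v + 1e-6) * w

per_call_s = avg_latency(unfused_pattern)
```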

 CUDA_VISIBLE_DEVICES=0,1 python3 -m sglang.launch_server --model-path openai/gpt-oss-20b --tp-size 2 --port 30001

 SIMPLE_EVALS_DIR_OUTPUT=/home/yineng/bbuf/output OPENAI_BASE_URL=http://localhost:30001/v1 OPENAI_API_KEY=dummy python -m simple-evals.simple_evals --model o4-mini-with-chat-completion-and-4k-gen --eval mmlu --examples 1000

python3 -m sglang.bench_serving --model openai/gpt-oss-20b --dataset-name random --backend sglang-oai --random-range-ratio 1 --random-input-len 1200 --random-output-len 20 --max-concurrency 1 --num-prompts 5 --profile --port 30001

All results: 
| model_name                                              |   ('metric', 'mmlu') |
|:--------------------------------------------------------|---------------------:|
| o4-mini-with-chat-completion-and-4k-gen_20250811_025644 |                0.836 |
[image]
 CUDA_VISIBLE_DEVICES=0,1 python3 -m sglang.launch_server --model-path openai/gpt-oss-20b --tp-size 2 --port 30001 --enable-flashinfer-allreduce-fusion

 SIMPLE_EVALS_DIR_OUTPUT=/home/yineng/bbuf/output OPENAI_BASE_URL=http://localhost:30001/v1 OPENAI_API_KEY=dummy python -m simple-evals.simple_evals --model o4-mini-with-chat-completion-and-4k-gen --eval mmlu --examples 1000

All results: 
| model_name                                              |   ('metric', 'mmlu') |
|:--------------------------------------------------------|---------------------:|
| o4-mini-with-chat-completion-and-4k-gen_20250811_041142 |                0.839 |

 python3 -m sglang.bench_serving --model openai/gpt-oss-20b --dataset-name random --backend sglang-oai --random-range-ratio 1 --random-input-len 1200 --random-output-len 20 --max-concurrency 1 --num-prompts 5 --profile --port 30001

[image]

@BBuf
Collaborator Author

BBuf commented Aug 11, 2025

curl http://127.0.0.1:30001/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 512 --random-output-len 1024 --random-range-ratio 1 --num-prompts 20 --max-concurrency 1 --output-file 1.jsonl --port 30001 --warmup-requests 10
curl http://127.0.0.1:30001/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 512 --random-output-len 1024 --random-range-ratio 1 --num-prompts 200 --max-concurrency 32 --output-file 1.jsonl --port 30001 --warmup-requests 10
curl http://127.0.0.1:30001/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 512 --random-output-len 1024 --random-range-ratio 1 --num-prompts 300 --max-concurrency 64 --output-file 1.jsonl --port 30001 --warmup-requests 10
curl http://127.0.0.1:30001/flush_cache
python3 -m sglang.bench_serving --backend sglang-oai  --dataset-name random --random-input-len 512 --random-output-len 1024 --random-range-ratio 1 --num-prompts 400 --max-concurrency 128 --output-file 1.jsonl --port 30001 --warmup-requests 10



CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m sglang.launch_server --model-path openai/gpt-oss-120b --tp-size 4 --port 30001 


+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|    |   max_concurrency |   input_throughput |   output_throughput |   mean_ttft_ms |   median_ttft_ms |   p99_ttft_ms |   mean_tpot_ms |   median_tpot_ms |   p99_tpot_ms |   per_user_throughput |
+====+===================+====================+=====================+================+==================+===============+================+==================+===============+=======================+
|  0 |             1.000 |            116.732 |             233.463 |         67.612 |           63.121 |       120.138 |          4.220 |            4.252 |         4.798 |               233.463 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  1 |            32.000 |           2147.463 |            4294.927 |        252.945 |          198.376 |       548.624 |          6.655 |            6.725 |         6.871 |               134.216 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  2 |            64.000 |           4085.339 |            8170.677 |        315.211 |          293.034 |       423.917 |          7.014 |            6.985 |         7.294 |               127.667 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  3 |           128.000 |           5996.815 |           11993.630 |        540.879 |          542.956 |       869.169 |          8.485 |            8.565 |         8.902 |                93.700 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+


CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m sglang.launch_server --model-path openai/gpt-oss-120b --tp-size 4 --port 30001 --enable-flashinfer-allreduce-fusion

+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|    |   max_concurrency |   input_throughput |   output_throughput |   mean_ttft_ms |   median_ttft_ms |   p99_ttft_ms |   mean_tpot_ms |   median_tpot_ms |   p99_tpot_ms |   per_user_throughput |
+====+===================+====================+=====================+================+==================+===============+================+==================+===============+=======================+
|  0 |             1.000 |            116.441 |             232.883 |         73.033 |           67.702 |       119.717 |          4.225 |            4.318 |         4.462 |               232.883 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  1 |            32.000 |           2245.470 |            4490.940 |        241.326 |          189.548 |       440.724 |          6.347 |            6.397 |         6.573 |               140.342 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  2 |            64.000 |           4204.064 |            8408.128 |        323.082 |          322.495 |       390.187 |          6.805 |            6.620 |         7.455 |               131.377 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+
|  3 |           128.000 |           6497.915 |           12995.830 |        507.289 |          463.732 |       960.363 |          7.776 |            7.812 |         8.231 |               101.530 |
+----+-------------------+--------------------+---------------------+----------------+------------------+---------------+----------------+------------------+---------------+-----------------------+

@BBuf
Collaborator Author

BBuf commented Aug 11, 2025

[image]

@BBuf BBuf changed the title from "[reduce overhead] Cache fuse allreduce_residual_rmsnorm check result" to "[gpt-oss fuse allreduce_residual_rmsnorm] fuse allreduce and residual_rmsnorm in gpt-oss model to improve performance" on Aug 11, 2025
@zhyncs zhyncs changed the title from "[gpt-oss fuse allreduce_residual_rmsnorm] fuse allreduce and residual_rmsnorm in gpt-oss model to improve performance" to "fuse allreduce and residual_rmsnorm" on Aug 11, 2025
@zhyncs
Collaborator

zhyncs commented Aug 11, 2025

@zhyncs zhyncs mentioned this pull request Aug 11, 2025
)
return output

def _build_fuse_allreduce_lookup_table(self):
Collaborator


this is duplicate

@zhyncs zhyncs merged commit 44e8648 into main Aug 11, 2025
35 of 62 checks passed
@zhyncs zhyncs deleted the cache_fuse_allreduce_residual_rmsnorm_judge branch August 11, 2025 20:50
narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 17, 2025
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025