Skip to content

use fast stream instead of torch.cuda.current_stream in llama 4 shared experts overlap#12811

Merged
Fridge003 merged 1 commit intosgl-project:mainfrom
bzhng-development:llama-4-register-fast-stream
Nov 8, 2025
Merged

use fast stream instead of torch.cuda.current_stream in llama 4 shared experts overlap#12811
Fridge003 merged 1 commit intosgl-project:mainfrom
bzhng-development:llama-4-register-fast-stream

Conversation

@b8zhong
Copy link
Collaborator

@b8zhong b8zhong commented Nov 7, 2025

Don't use torch.cuda.current_stream, use get_current_device_stream_fast(), just like #12524

python3 -m sglang.launch_server \
  --model-path=/opt/dlami/nvme/models/Llama-4-Scout-17B-16E-Instruct/ \
  --tp=8 \
  --trust-remote-code \
  --mem-fraction-static=0.7 \
  --context-length=131072 \
  --kv-cache-dtype=fp8_e4m3 \
  --attention-backend=fa3 \
  --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}'
python3 -m sglang.bench_serving --backend sglang --num-prompts 64 --dataset-name random --random-input-len 1024 --random-output-len 1024 --random-range-ratio 1 --max-concurrency=8 --flush-cache
Metric Before After Δ Gain
Request Throughput (req/s) 0.89 0.89 0.00 0.0%
Input Token Throughput (tok/s) 911.59 911.67 +0.08 +0.01%
Output Token Throughput (tok/s) 911.59 911.67 +0.08 +0.01%
Total Token Throughput (tok/s) 1823.19 1823.35 +0.16 +0.01%
Mean E2E Latency (ms) 8983.28 8983.05 −0.23 −0.0%
Median E2E Latency (ms) 8958.14 8965.86 +7.72 +0.1%
Mean TTFT (ms) 206.48 198.57 −7.91 −3.8%
Median TTFT (ms) 198.58 199.17 +0.59 +0.3%
P99 TTFT (ms) 288.02 214.91 −73.11 −25.4%
Mean ITL (ms) 8.58 8.59 +0.01 +0.1%
P95 ITL (ms) 8.90 8.90 0.00 0.0%
P99 ITL (ms) 9.23 9.18 −0.05 −0.5%

Generally, it doesn't make such a great difference.

root@ip-10-40-0-228:/sgl-workspace/sglang# python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 500
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:31<00:00, 42.38it/s]
Accuracy: 0.920
Invalid: 0.000
Latency: 31.533 s
Output throughput: 4312.445 token/s

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @b8zhong, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refines the CUDA stream management within the Llama 4 model's shared experts overlap implementation. By switching to a more optimized method for retrieving the current device stream, the change aims to enhance the efficiency of GPU operations and synchronization. Although the general performance metrics show minor differences, the improvement in tail latency for the first token suggests a positive impact on user experience for certain workloads.

Highlights

  • Stream Optimization: Replaced self.device_module.current_stream() with get_current_device_stream_fast() for more efficient CUDA stream handling in the Llama 4 model's shared experts overlap logic.
  • Performance Impact: While overall throughput metrics remained largely unchanged, the P99 Time To First Token (TTFT) saw a significant improvement of 25.4%, indicating better responsiveness for tail latencies.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@b8zhong b8zhong added the run-ci label Nov 7, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a performance optimization by replacing torch.cuda.current_stream() with a faster, cached version get_current_device_stream_fast() within the Llama 4 shared experts overlap implementation. This change is consistent with similar optimizations made elsewhere in the codebase and aims to reduce the overhead of CUDA API calls. The implementation appears correct and the change is well-contained. My review did not identify any issues.

@b8zhong b8zhong requested a review from Fridge003 November 8, 2025 21:31
@Fridge003 Fridge003 merged commit 49653c8 into sgl-project:main Nov 8, 2025
170 of 183 checks passed
@b8zhong b8zhong deleted the llama-4-register-fast-stream branch December 7, 2025 07:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

Comments