Support TP overlap #9058

Open

artetaout wants to merge 33 commits into sgl-project:main from bytedance-iaas:feat/overlap

Conversation

@artetaout (Contributor) commented Aug 11, 2025

Motivation

We overlap GEMM with AllReduce during Tensor Parallelism via Triton-distributed to speed it up.

https://github.com/ByteDance-Seed/Triton-distributed

Modifications

  • Add two GEMM+AllReduce ops in ModelRunner, one for ATTN and the other for MLP (see the sketch after this list).
  • Add a Dockerfile to set up all dependencies.
  • Optimize for the decode stage. (Update: other methods are being considered.)
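
To make the first item concrete, here is a rough sketch of the idea for a row-parallel linear layer; `fused_gemm_ar` is an illustrative stand-in for a Triton-distributed GEMM+AllReduce operator, not the PR's actual API:

```python
# Conceptual sketch only, not the PR's actual code: a row-parallel linear layer
# normally runs a local GEMM and then a blocking all-reduce; the overlapped path
# replaces both with a single fused GEMM+AllReduce operator.
import torch
import torch.distributed as dist


def row_parallel_forward(x: torch.Tensor, w_shard: torch.Tensor, fused_gemm_ar=None) -> torch.Tensor:
    if fused_gemm_ar is not None:
        # fused_gemm_ar stands in for a Triton-distributed GEMM+AllReduce op;
        # the PR builds one instance for attention (o_proj) and one for MLP (down_proj).
        return fused_gemm_ar(x, w_shard)
    partial = torch.matmul(x, w_shard.t())  # GEMM on the local weight shard
    dist.all_reduce(partial)                # communication fully exposed on the critical path
    return partial
```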

Accuracy Tests

$ python3 bench_sglang.py
100%|██| 200/200 [00:46<00:00,  4.29it/s]
Accuracy: 0.950
Invalid: 0.000
Latency: 48.500 s

Besides, every layer's hidden_states were checked to be allclose to those of the original implementation.
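
As an illustration of that check (a minimal sketch, not the exact script used; the file layout and tolerances are assumptions):

```python
# Illustrative sketch: compare per-layer hidden_states dumped by a baseline run
# and an overlapped run, assuming each run saved {layer_name: tensor} to disk.
import torch


def compare_layer_outputs(baseline_path: str, overlap_path: str,
                          atol: float = 1e-2, rtol: float = 1e-2) -> None:
    baseline = torch.load(baseline_path)
    overlap = torch.load(overlap_path)
    for name, ref in baseline.items():
        out = overlap[name]
        assert torch.allclose(ref.float(), out.float(), atol=atol, rtol=rtol), f"mismatch at {name}"
```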

Benchmarking and Profiling

bench_one_batch

| GPU | Model | TP | Batch size | In/Out tokens | Latency (s) | Throughput (tok/s) | Version |
|-----|-------|----|------------|---------------|-------------|--------------------|---------|
| H20 | Qwen2.5-72B-Instruct | 8 | 1 | 1000/1 | 0.164 | 6087.71 | origin |
| H20 | Qwen2.5-72B-Instruct | 8 | 1 | 2000/1 | 0.33266 | 6012.11 | origin |
| H20 | Qwen2.5-72B-Instruct | 8 | 1 | 4000/1 | 0.6356 | 6293.22 | origin |
| H20 | Qwen2.5-72B-Instruct | 8 | 1 | 8000/1 | 1.27645 | 6267.38 | origin |
| H20 | Qwen2.5-72B-Instruct | 8 | 1 | 1000/1 | 0.15954 | 6267.09 | origin+overlap |
| H20 | Qwen2.5-72B-Instruct | 8 | 1 | 2000/1 | 0.31418 | 6366.73 | origin+overlap |
| H20 | Qwen2.5-72B-Instruct | 8 | 1 | 4000/1 | 0.59648 | 6706.05 | origin+overlap |
| H20 | Qwen2.5-72B-Instruct | 8 | 1 | 8000/1 | 1.19078 | 6718.26 | origin+overlap |

bench_serving

  • 8*H20 Qwen2.5-72B-Instruct
  • origin
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 20
Successful requests:                     100
Benchmark duration (s):                  57.89
Total input tokens:                      374403
Total generated tokens:                  130
Total generated tokens (retokenized):    130
Request throughput (req/s):              1.73
Input token throughput (tok/s):          6466.94
Output token throughput (tok/s):         2.25
Total token throughput (tok/s):          6469.19
Concurrency:                             19.76
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   11442.26
Median E2E Latency (ms):                 7730.02
---------------Time to First Token----------------
Mean TTFT (ms):                          6927.27
Median TTFT (ms):                        6302.95
P99 TTFT (ms):                           13944.67
  • ours (overlap)
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 20
Successful requests:                     100
Benchmark duration (s):                  53.80
Total input tokens:                      374403
Total generated tokens:                  130
Total generated tokens (retokenized):    130
Request throughput (req/s):              1.86
Input token throughput (tok/s):          6959.78
Output token throughput (tok/s):         2.42
Total token throughput (tok/s):          6962.20
Concurrency:                             19.77
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   10633.49
Median E2E Latency (ms):                 7194.52
---------------Time to First Token----------------
Mean TTFT (ms):                          6424.30
Median TTFT (ms):                        5838.61
P99 TTFT (ms):                           12906.79

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @artetaout, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've implemented a significant optimization to our Tensor Parallelism strategy by enabling the overlap of GEMM computations with AllReduce communication. This change introduces specialized operators for attention and MLP layers, designed to reduce communication overhead and enhance the overall efficiency of distributed model inference. The new functionality is integrated into the model's linear layers and can be activated via an environment variable, providing a flexible way to leverage this performance improvement.

Highlights

  • Tensor Parallelism Overlap: I've introduced support for overlapping GEMM (General Matrix Multiply) operations with AllReduce communication during Tensor Parallelism. This aims to improve performance by hiding communication latency.
  • New GEMM+AllReduce Operators: I've added two new GEMM+AllReduce operators, specifically designed for attention (ATTN) and Multi-Layer Perceptron (MLP) operations, leveraging the triton_dist library.
  • Distributed State Management: The system now initializes an NVSHMEM-enabled GLOO process group for Tensor Parallelism, which is crucial for the new overlap functionality.
  • Conditional Operator Execution: The Linear layer's forward pass has been modified to conditionally use these new overlapped operators based on the SGL_USE_TP_OVERLAP environment variable, allowing for easy activation of this feature.
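
A minimal sketch of that environment-variable gate (the PR's exact flag parsing may differ):

```python
# Minimal sketch of the SGL_USE_TP_OVERLAP gate described above;
# the real check lives in the linear layers' forward path.
import os


def tp_overlap_enabled() -> bool:
    return os.environ.get("SGL_USE_TP_OVERLAP", "0").lower() in ("1", "true")
```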

@artetaout changed the title from "Support TP overlap" to "[WIP] Support TP overlap" on Aug 11, 2025
@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces support for overlapping Tensor Parallelism communication (AllReduce) with computation (GEMM) to improve performance. This is achieved by adding new GEMM+AllReduce operators for attention and MLP layers, controlled by an environment variable. The changes look promising, but I've found a critical typo that needs to be fixed, along with some leftover debug code and unused imports that should be cleaned up.

@artetaout requested a review from ByronHsu as a code owner on August 11, 2025 07:39
@FlamingoPg self-assigned this on Aug 13, 2025
@hubertlu-tw (Collaborator) commented:

@artetaout Nice work!
I am wondering if you have any plans to support this on AMD GPUs. Please feel free to reach out to me if there is anything I can help with. Thanks!

@artetaout (Contributor, Author) commented Aug 14, 2025

@artetaout Nice work! I am wondering if you have any plans to support this on AMD GPUs. Please feel free to reach out to me if there is anything I can help with. Thanks!

Thanks! We lack AMD GPUs to develop and debug on. However, Triton-distributed itself supports AMD, so you could try adding support for it!

@jasonlizhengjian commented:

Hi! I'm interested in using this feature. Is there any update on the status of this PR or what's blocking it from being merged? @artetaout @FlamingoPg

@FlamingoPg (Collaborator) commented:

Hi! I'm interested in using this feature. Is there any update on the status of this PR or what's blocking it from being merged? @artetaout @FlamingoPg

Great work! By the way, we may need to discuss how to install Triton-distributed for use in sgl. Should we compile it within sgl-kernel, or can we install it directly via pip?

@artetaout (Contributor, Author) commented:

Hi! I'm interested in using this feature. Is there any update on the status of this PR or what's blocking it from being merged? @artetaout @FlamingoPg

Great work, btw we may need to discuss how to install Triton-distributed for use in sgl. Should we compile it within sgl-kernel, or can we install it directly via pip?

It requires uninstalling Triton and installing Triton-distributed from scratch.
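
A small, illustrative way to verify that Triton-distributed has replaced stock Triton before enabling the overlap path (not part of the PR; it only checks that the triton_dist module mentioned above is importable):

```python
# Illustrative check: the fused ops import triton_dist, which only exists
# when Triton-distributed has replaced the stock Triton installation.
import importlib.util


def triton_distributed_installed() -> bool:
    return importlib.util.find_spec("triton_dist") is not None


if __name__ == "__main__":
    print("Triton-distributed available:", triton_distributed_installed())
```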

@merrymercy mentioned this pull request on Oct 23, 2025
artetaout and others added 4 commits October 27, 2025 18:45
@hebais commented Jan 6, 2026

Hello, I have a question about the shmem allocation in this implementation: the shmem tensors appear to be created per layer during op init with max_M = max_token_embeddings, which is quite memory-consuming and somewhat wasteful.

I am wondering if there is a good way to share the allocated shmem buffer across layers of the same type (MLP and o_proj layers).

@HanHan009527 restored the feat/overlap branch on January 7, 2026 13:15
@HanHan009527 reopened this on Jan 7, 2026
@hebais commented Jan 9, 2026

It seems the PR test errors also indicate that memory allocation is insufficient because the shared-memory allocation takes too much memory.

@artetaout changed the title from "[WIP] Support TP overlap" to "Support TP overlap" on Jan 15, 2026
@artetaout (Contributor, Author) commented:

Hello, I have a question about the shmem allocation in this implementation: the shmem tensors appear to be created per layer during op init with max_M = max_token_embeddings, which is quite memory-consuming and somewhat wasteful.

I am wondering if there is a good way to share the allocated shmem buffer across layers of the same type (MLP and o_proj layers).

Agreed. This max_M can differ across scenarios, so we will add an argument so users can configure it.
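
One possible direction is sketched below, assuming buffers of the same dtype and output width can be reused across layers; `SharedShmemPool` and `allocate_fn` are hypothetical names, not the PR's API:

```python
# Hypothetical sketch: share one symmetric-memory buffer per (dtype, output width)
# across layers of the same type, sized by a user-configurable max_M instead of
# allocating max_token_embeddings rows for every layer. allocate_fn is assumed to
# wrap whatever symmetric/NVSHMEM allocator the fused ops require.
import torch


class SharedShmemPool:
    def __init__(self, max_m: int, allocate_fn):
        self.max_m = max_m
        self.allocate_fn = allocate_fn
        self._buffers: dict = {}

    def get(self, dtype: torch.dtype, n: int) -> torch.Tensor:
        key = (dtype, n)
        if key not in self._buffers:
            # All layers with the same output width and dtype reuse this buffer.
            self._buffers[key] = self.allocate_fn((self.max_m, n), dtype)
        return self._buffers[key]
```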
