
Custom All Reduce for Piecewise Cuda Graph #15356

Merged
ispobock merged 16 commits into sgl-project:main from Oasis-Git:car
Dec 25, 2025
Conversation


@Oasis-Git Oasis-Git commented Dec 18, 2025

Motivation

Enable Custom All Reduce in Piecewise CUDA Graph. Equal contribution with @ByronHsu.

Modifications

  1. Split the compile phase from the warmup-capture phase to make sure replay works during PCG (piecewise CUDA graph) init.
  2. Enable graph_capture during capture.
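As a rough illustration of change (1), here is a toy sketch of the two-phase flow. The class and method names (`PiecewiseCudaGraphRunner`, `compile_model`, `warmup_and_capture`) are illustrative only, not SGLang's actual API:

```python
# Hypothetical sketch of splitting compilation from warmup/capture.
# The real runner lives in piecewise_cuda_graph_runner.py and is far
# more involved; this only models the ordering constraint.
class PiecewiseCudaGraphRunner:
    def __init__(self):
        self.compiled = False
        self.captured_shapes = []

    def compile_model(self, shapes):
        # Phase 1: torch.compile only -- no graph capture, no replay.
        self.compiled = True

    def warmup_and_capture(self, shapes):
        # Phase 2: warmup + CUDA graph capture, strictly after compile
        # has finished, so replay during PCG init sees allocated buffers.
        assert self.compiled, "must compile before capturing"
        self.captured_shapes = list(shapes)
```

Keeping the two phases strictly ordered is what guarantees the custom all-reduce buffers exist before any graph is replayed.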

Accuracy Tests

# Server
python3 -m sglang.launch_server --model-path Qwen/Qwen3-8B --enable-piecewise-cuda-graph \
     --tp 2 \
     --piecewise-cuda-graph-max-tokens 2048

# Client
python3 sglang/benchmark/gsm8k/bench_sglang.py --num-questions 1319 --parallel 1319 --port 30000
100%|██████████| 1319/1319 [00:10<00:00, 127.84it/s]
Accuracy: 0.908
Invalid: 0.000
Latency: 10.396 s
Output throughput: 15629.471 token/s

For VL Model:

# Server
python -m sglang.launch_server --model Qwen/Qwen2.5-VL-7B-Instruct --tp 4 \
    --enable-piecewise-cuda-graph \
    --disable-radix-cache
    
# Client
python3 -m sglang.bench_serving \
  --backend sglang-oai-chat \
  --dataset-name image \
  --num-prompts 256 \
  --apply-chat-template \
  --random-input-len 128 \
  --random-output-len 32 \
  --image-resolution 560x560 \
  --image-format jpeg \
  --image-count 1 \
  --image-content random \
  --random-range-ratio 0.1 \
  --port 30000 \
  --max-concurrency 32
  
Created 256 random jpeg images with average 316335 bytes per request
Starting warmup with 1 sequences...
Warmup completed with 1 sequences. Starting main benchmark run...
100%|██████████| 256/256 [00:14<00:00, 18.11it/s]

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 32        
Successful requests:                     256       
Benchmark duration (s):                  14.14     
Total input tokens:                      126306    
Total input text tokens:                 23394     
Total input vision tokens:               102912    
Total generated tokens:                  4541      
Total generated tokens (retokenized):    4513      
Request throughput (req/s):              18.10     
Input token throughput (tok/s):          8932.69   
Output token throughput (tok/s):         321.15    
Peak output token throughput (tok/s):    522.00    
Peak concurrent requests:                59        
Total token throughput (tok/s):          9253.84   
Concurrency:                             31.68     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   1750.05   
Median E2E Latency (ms):                 1660.67   
---------------Time to First Token----------------
Mean TTFT (ms):                          870.50    
Median TTFT (ms):                        618.02    
P99 TTFT (ms):                           2365.29   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          50.13     
Median TPOT (ms):                        49.86     
P99 TPOT (ms):                           135.72    
---------------Inter-Token Latency----------------
Mean ITL (ms):                           52.74     
Median ITL (ms):                         4.36      
P95 ITL (ms):                            435.08    
P99 ITL (ms):                            623.28    
Max ITL (ms):                            1451.89   
==================================================

Checklist

Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
@ByronHsu
Collaborator

ref: #14193

class CustomAllreduce:
    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
-   _MAX_CAR_SIZE = 8192 * 1024
+   _MAX_CAR_SIZE = 8192 * 1024 * 4
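For context, the new cap happens to equal exactly one fp16 hidden-state tensor at 2048 tokens for an 8192-dim model. The sketch below shows that arithmetic; the formula is my assumption about how such a maximum could be derived, not code from the PR:

```python
def car_max_size_bytes(max_tokens: int, hidden_size: int, dtype_bytes: int = 2) -> int:
    # Hypothetical derivation: the custom all-reduce buffer must hold one
    # hidden-state tensor of shape (max_tokens, hidden_size) in fp16/bf16.
    return max_tokens * hidden_size * dtype_bytes

# 2048 tokens x 8192 hidden dims x 2 bytes == 8192 * 1024 * 4 bytes (32 MiB)
assert car_max_size_bytes(2048, 8192) == 8192 * 1024 * 4
```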
Collaborator Author

There should be a fix here: we need to calculate the max_size for the all-reduce buffer.

Collaborator

Can we add a comment here explaining how it's calculated?

Collaborator Author

@Oasis-Git Oasis-Git left a comment

Fixed it here:

# if not entry.use_cudagraph or skip_cuda_graphs:
#     return entry.runnable(*args)
if is_in_torch_compile():
    return entry.runnable(*args)
Collaborator

noob question: what is this for?

Collaborator Author

Basically, we have to prevent replay from happening during capture, since the custom all-reduce buffer has not been allocated yet. Now that the capture function does only warmup and capture, we should also avoid counting any compile operations into the piecewise CUDA graph backend. So, in the compile stage, we skip all later processing and directly return the eager run.
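A hedged sketch of how such a compile-phase guard could be implemented with a context variable; SGLang's actual `is_in_torch_compile` may well be implemented differently:

```python
# Illustrative guard: a contextvar flag that is set while torch.compile
# is tracing, so dispatch code can fall back to the eager runnable.
import contextvars
from contextlib import contextmanager

_in_compile = contextvars.ContextVar("in_torch_compile", default=False)

def is_in_torch_compile() -> bool:
    # True only inside a torch_compile_region() block.
    return _in_compile.get()

@contextmanager
def torch_compile_region():
    # Set the flag for the duration of compilation, then restore it,
    # even if compilation raises.
    token = _in_compile.set(True)
    try:
        yield
    finally:
        _in_compile.reset(token)
```

With this, the dispatch path above can cheaply check the flag and return the eager runnable during compilation, without ever touching the not-yet-allocated all-reduce buffers.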

Collaborator Author

Without this protection, the warmup could happen inside the warmup_and_torch_compile function. For example, if we capture shapes from 4 to 4096, then after the first run with 4096 in warmup_and_torch_compile, every other shape from 4 to 3840 would be warmed up as well, since no recompile occurs for dense models.
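A toy model of that dispatch behavior, heavily simplified (the `dispatch` function and the `captured` list are illustrative, not the real piecewise backend):

```python
# During torch.compile, every shape falls through to eager and is NOT
# recorded for capture; only the dedicated capture phase records shapes.
captured = []

def dispatch(num_tokens: int, in_compile: bool) -> str:
    if in_compile:
        # Mirrors the is_in_torch_compile() early return above.
        return "eager"
    captured.append(num_tokens)
    return "capture"

# First run with 4096 inside warmup_and_torch_compile: eager, not captured.
assert dispatch(4096, in_compile=True) == "eager"
assert captured == []

# Later shapes during the real capture phase are recorded as intended.
for n in (4, 8, 16):
    dispatch(n, in_compile=False)
assert captured == [4, 8, 16]
```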

@ByronHsu
Collaborator

could you test qwen vl as well? thanks

@hebiao064 hebiao064 self-assigned this Dec 18, 2025
@Oasis-Git
Collaborator Author

could you test qwen vl as well? thanks

The results for the VL model are updated above. Please check!

@Oasis-Git
Collaborator Author

/tag-run-ci-label

@ispobock
Collaborator

ispobock commented Dec 19, 2025

/tag-and-rerun-ci

@ByronHsu
Collaborator

btw, we can remove use_original_ca_comm and disable_ca_comm from piecewise_cuda_graph_runner.py?

@Oasis-Git
Collaborator Author

btw, we can remove use_original_ca_comm and disable_ca_comm from piecewise_cuda_graph_runner.py?

True

@ByronHsu
Collaborator

can you resolve the conflict?

@Oasis-Git
Collaborator Author

can you resolve the conflict?

Solved.

@Oasis-Git
Collaborator Author

/rerun-failed-ci

@Oasis-Git
Collaborator Author

/rerun-failed-ci

@ispobock
Collaborator

@ispobock ispobock merged commit 5c243ba into sgl-project:main Dec 25, 2025
427 of 448 checks passed
@Oasis-Git Oasis-Git deleted the car branch December 25, 2025 23:14
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
    stream = get_pcg_capture_stream()
    assert stream is not None, "PCG capture stream is not set"
Contributor

Why can't we use "stream is None" here?

YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
5 participants