
Conversation

@haoyangli-amd (Contributor) commented on Oct 14, 2025

When the tensor parallelism (TP) degree is set to 4 or 8, frequent changes in the input shape can cause QuickReduce to hang (this issue has been observed with the gpt_oss model).
We have identified that the root cause is overlapping flag memory addresses between consecutive AllReduce operations.

For most models, the hidden size remains relatively stable, so this issue does not occur.

Our current solution is to allocate separate memory regions for the flags and data of the two AllReduce phases in each operation.
(Note: The data region must also be separated, as overlapping would lead to correctness issues.)
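A minimal sketch of the idea, for illustration only (the function name, region names, and sizes below are assumptions, not the actual kernel code): with fixed, shape-independent offsets, the phase 1 and phase 2 flag/data regions can never alias each other, no matter how the input shape changes between calls.

# Illustration only: a toy model of "separate regions per phase"; offsets and
# sizes are made up, and the real fix lives in the QuickReduce kernels.
def carve_regions(base, max_payload_bytes, flag_bytes):
    """Carve one rank's scratch buffer into non-overlapping regions so that
    phase 1 and phase 2 of an allreduce never touch the same addresses."""
    regions, offset = {}, base
    for phase in ("phase1", "phase2"):
        regions[phase + "_data"] = (offset, offset + max_payload_bytes)
        offset += max_payload_bytes
        regions[phase + "_flags"] = (offset, offset + flag_bytes)
        offset += flag_bytes
    return regions

# Because the offsets do not depend on the current input shape, a slow phase 2
# reader of the n-th allreduce cannot be overwritten by the phase 1 writer of
# the (n+1)-th allreduce.
print(carve_regions(base=0, max_payload_bytes=2 * 1024 * 1024, flag_bytes=4096))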

To reproduce the error

1. Install sglang
2. Run: python3 this_script.py

import torch
import multiprocessing
import argparse
import torch.distributed as dist
from sglang.srt import _custom_ops as ops
def worker(rank, world_size, comm_handles, comm_handle_dict):
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    qr_max_size = None # MB
    _ptr = ops.init_custom_qr(rank, world_size, qr_max_size)
    ranks = list(range(world_size))
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size
    )
    cpu_group = torch.distributed.new_group(ranks, backend="nccl")

    handle = ops.qr_get_handle(_ptr)
    world_size = dist.get_world_size(group=cpu_group)
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=cpu_group)
    ops.qr_open_handles(_ptr, handles)
    num = 1
    s1 = 1024
    while s1 > 0:
        dtype = torch.float16
        # Shrink the first dimension every 60 iterations and alternate the
        # second dimension every iteration, so consecutive allreduces use
        # different input shapes (the pattern that triggers the hang).
        if num % 60 == 0:
            s1 = s1 // 2
        if num % 2 == 0:
            s2 = 1024
            inp1 = torch.zeros((s1, s2),
                dtype=dtype,
                device=torch.cuda.current_device())
        else:
            s2 = 2048
            inp1 = torch.ones((s1, s2),
                dtype=dtype,
                device=torch.cuda.current_device())
        print(f"num:{num}, rank:{rank}, shape:{inp1.shape}")
        result = torch.empty_like(inp1)
        # quant level passed below: FP = 0, INT8 = 1, INT6 = 2, INT4 = 3, NONE = 4
        ops.qr_all_reduce(_ptr, inp1, result, 3, cast_bf2half=True)
        try:
            if inp1[0,0] == 0:
                assert torch.all(result == 0)
            else:
                assert torch.all(result == world_size)
        except AssertionError:
            torch.save(result, "result_failed.pth")
            print("Assertion failed! Saved result to result_failed.pth")
            raise
        # Uncommenting this barrier serializes the ranks and hides the hang:
        # dist.barrier(group=cpu_group)
        num += 1
        if s1 < 100:
            # Reset the first dimension so the shape keeps cycling indefinitely.
            s1 = 8 * 1024
    print("done")
def run_multiprocessing(world_size):
    with multiprocessing.Manager() as manager:
        comm_handle_dict = manager.dict()
        comm_handles = manager.Barrier(world_size)

        processes = []
        for rank in range(world_size):
            p = multiprocessing.Process(
                target=worker,
                args=(rank, world_size, comm_handles, comm_handle_dict)
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--world_size",
        type=int,
        default=4,      
        help="number of processes / GPUs to use"
    )
    args = parser.parse_args()

    multiprocessing.set_start_method("spawn")
    run_multiprocessing(world_size=args.world_size)

We can also use this script to check whether the hang issue is resolved and whether the results are correct.
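For example, assuming the script above is saved as this_script.py (the placeholder name from the repro steps):

python3 this_script.py --world_size 4   # TP=4 by default; use --world_size 8 for the other affected degree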

A more detailed explanation
Why does a frequently changing input shape cause problems?
1. I collected some logs (shown below). The program does not appear to be stuck at [256, 2048], but rather at the previous iteration, [256, 1024].
I suspect that the n-th allreduce and the (n+1)-th allreduce overlap in time: when the (n+1)-th allreduce executes its phase 1, it modifies the flag that the n-th allreduce's phase 2 is still waiting on.
For [256, 2048], the phase 1 address range completely overlaps with the phase 1 + phase 2 address range of [256, 1024] (see the sketch after this list).

num:6451, rank:3, shape:torch.Size([256, 1024])
flag_color1:6451
flag_color1:6451
num_blocks:16
flag_color1:6451
flag_color2:6452
flag_color2:6452
flag_color2:6452
num:6452, rank:0, shape:torch.Size([256, 2048])
num_blocks:32
flag_color1:6452
num:6452, rank:1, shape:torch.Size([256, 2048])
num:6452, rank:3, shape:torch.Size([256, 2048])
num_blocks:32
num_blocks:32
flag_color1:6452
flag_color1:6452
2, block:12, thread:0, flag_color:6451, flag_ptr:6452
2, block:14, thread:0, flag_color:6451, flag_ptr:6452

2. If we use dist.barrier(group=cpu_group) to guarantee that all ranks block at the same point, the program does not hang even after running for an hour.

3. As in vLLM's communication reduction (CR) implementation, isolated addresses must be used to distinguish the different phases of different allreduce batches, so that the n-th and (n+1)-th allreduce operations cannot interfere with each other.
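To make the overlap described in point 1 concrete, here is a rough model of a shape-dependent layout (the back-to-back packing rule and fp16 element size are assumptions for illustration, not the exact kernel layout):

# Assumed for illustration: phase 1 and phase 2 regions are sized by the
# payload and packed back to back starting at the same base address.
def shape_dependent_regions(shape, elem_bytes=2, base=0):
    payload = shape[0] * shape[1] * elem_bytes  # fp16 -> 2 bytes per element
    phase1 = (base, base + payload)
    phase2 = (base + payload, base + 2 * payload)
    return phase1, phase2

p1_small, p2_small = shape_dependent_regions((256, 1024))  # n-th allreduce
p1_big, _ = shape_dependent_regions((256, 2048))           # (n+1)-th allreduce

# Phase 1 of [256, 2048] spans exactly phase 1 + phase 2 of [256, 1024], so a
# fast rank starting allreduce n+1 can clobber the flags a slow rank is still
# polling on in phase 2 of allreduce n.
assert (p1_big[0], p1_big[1]) == (p1_small[0], p2_small[1])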

Why don't other models have this issue?
For typical models, the hidden size is fixed, so the input shape does not change frequently, and phase 2 of the n-th allreduce and phase 1 of the (n+1)-th allreduce do not share addresses. However, for models like GPT-OSS with variable-length inputs, conflicts can occur. What we need to do is completely isolate the addresses so that no conflict is possible.

Signed-off-by: Haoyang Li <[email protected]>
@haoyangli-amd (Contributor, Author) commented:

Hi @HaiShaw,
could you please help review this PR? Thank you so much.

@HaiShaw added the run-ci label on Oct 15, 2025

@haoyangli-amd (Contributor, Author) commented on Oct 28, 2025

Hi @HaiShaw,
could you leave some suggestions? Is it possible to merge this PR?
Thank you so much.

@HaiShaw merged commit ea10a9d into sgl-project:main on Nov 11, 2025
75 of 81 checks passed