
Conversation

@hubertlu-tw (Collaborator)

Co-authors: @b8zhong, @kkHuang-amd, Alan Kao.

Motivation

Continues the work from #11484.

Modifications

Accuracy Tests

$ SGLANG_USE_AITER=1 NCCL_MIN_NCHANNELS=112 SGLANG_INT4_WEIGHT=0 SGLANG_MOE_PADDING=1 \
  SGLANG_USE_ROCM700A=1 SGLANG_SET_CPU_AFFINITY=1 SGLANG_ROCM_FUSED_DECODE_MLA=1 SGLANG_AITER_AR=1 \
  python3 -m sglang.launch_server \
    --model-path deepseek-ai/DeepSeek-R1-MXFP4-Preview/ \
    --tensor-parallel-size 8 \
    --trust-remote-code \
    --chunked-prefill-size 131072 \
    --host 0.0.0.0 \
    --port 8000 \
    --mem-fraction-static 0.95 \
    --speculative-algorithm EAGLE \
    --speculative-num-steps 3 \
    --speculative-eagle-topk 1 \
    --speculative-num-draft-tokens 4

$ python3 benchmark/gsm8k/bench_sglang.py --parallel 1400 --num-questions 1400

Accuracy: 0.942
Invalid: 0.000

Benchmarking and Profiling

TP=8 results from torchrun --nproc_per_node=8 benchmark/kernels/all_reduce/benchmark_aiter.py (a minimal sketch of such a timing loop follows the table):

    Size    SGLang(ms)    Aiter(ms)
-----------------------------------
     32K         0.038        0.045
     64K         0.046        0.058
    128K         0.042        0.042
    256K         0.053        0.053
    512K         0.046        0.044
      1M         0.058        0.053
      2M         0.056        0.050
      4M         0.082        0.069
      8M         0.099        0.081
     16M         0.167        0.130
     32M         0.270        0.201
     64M         0.508        0.365
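
For reference, a timing loop of the kind benchmark_aiter.py presumably runs could look like the sketch below. This is a hypothetical illustration using plain torch.distributed, not the actual benchmark script; the function name and size sweep are illustrative only.

# Hypothetical sketch of an all-reduce latency measurement; NOT the actual
# benchmark_aiter.py from this PR, just the general pattern such a script follows.
import torch
import torch.distributed as dist

def time_all_reduce_ms(numel: int, iters: int = 100) -> float:
    x = torch.randn(numel, dtype=torch.float16, device="cuda")
    for _ in range(10):  # warm-up to exclude allocation/launch overheads
        dist.all_reduce(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        dist.all_reduce(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call

if __name__ == "__main__":
    # Launched via: torchrun --nproc_per_node=8 this_script.py
    dist.init_process_group("nccl")  # RCCL backend on ROCm builds of PyTorch
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    for size_kib in (32, 64, 128, 256, 512):
        numel = size_kib * 1024 // 2  # fp16 elements in a size_kib-KiB message
        ms = time_all_reduce_ms(numel)
        if dist.get_rank() == 0:
            print(f"{size_kib}K: {ms:.3f} ms")
    dist.destroy_process_group()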

@gemini-code-assist (Contributor)

Summary of Changes

Hello @hubertlu-tw, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances SGLang's distributed capabilities by integrating AITER's custom all-reduce implementation for AMD GPUs. The primary goal is to leverage AITER's optimized kernels to improve the performance of all-reduce operations on ROCm-enabled systems. It provides a configurable way to enable or disable this new functionality and includes a dedicated benchmark to assess its impact.

Highlights

  • AITER Custom All-reduce Integration: Introduces support for AITER's custom all-reduce kernels, specifically for ROCm (AMD GPUs), allowing for potentially optimized distributed operations.
  • New Environment Variable: Adds SGLANG_AITER_AR (defaulting to true) to control whether the AITER custom all-reduce kernels are used. Setting it to 0 will disable them.
  • Dynamic All-reduce Dispatch: Implements a dispatch mechanism to conditionally select between SGLang's native custom all-reduce and AITER's implementation based on the SGLANG_AITER_AR environment variable and the hardware platform (ROCm); see the sketch after this list.
  • Benchmarking Script: A new benchmark script (benchmark_aiter.py) has been added to compare the performance of SGLang's custom all-reduce against AITER's across various message sizes.
  • Test Suite Enhancements: The custom all-reduce test suite has been updated to include larger message sizes and to address potential "Global server args is not set yet!" errors by setting dummy server arguments.
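
A minimal sketch of that dispatch, assuming hypothetical helper and import names; only SGLANG_AITER_AR, its default, and the ROCm check come from this PR's description:

# Hypothetical sketch of the dispatch described above. The helper names and
# the aiter import path are assumptions, not verified against this PR's diff.
import os
import torch

def _use_aiter_all_reduce() -> bool:
    # SGLANG_AITER_AR defaults to true; setting it to 0 disables AITER's kernels.
    return torch.version.hip is not None and os.environ.get("SGLANG_AITER_AR", "1") != "0"

def create_all_reduce_communicator(group, device):
    if _use_aiter_all_reduce():
        from aiter import CustomAllreduce  # import path is a guess
    else:
        from sglang.srt.distributed.device_communicators.custom_all_reduce import CustomAllreduce
    return CustomAllreduce(group=group, device=device)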

@hubertlu-tw changed the title from "[AMD] Add AITER Custom All-reduce" to "[AMD] Add AITER Custom All-Reduce" on Nov 12, 2025
@gemini-code-assist (bot) left a comment:

Code Review

This pull request introduces support for AITER's custom all-reduce on AMD platforms, enabled via the SGLANG_AITER_AR environment variable. The changes are well-structured, including a dispatch mechanism to select the appropriate all-reduce implementation and a comprehensive benchmark script to compare performance. The tests have also been updated accordingly. My main feedback concerns a reduction in type safety in parallel_state.py where Optional[Any] is used. I've suggested an improvement to enhance maintainability.

In parallel_state.py:

-    self.ca_comm: Optional[CustomAllreduce] = None
+    self.ca_comm: Optional[Any] = None
@gemini-code-assist (Contributor) commented (severity: medium):

Changing the type hint for ca_comm from Optional[CustomAllreduce] to Optional[Any] reduces type safety and maintainability. While this works because both sglang.CustomAllreduce and aiter.CustomAllreduce are expected to have a compatible interface, it would be better to define a Protocol or an abstract base class that both communicators implement. This would make the code more robust and easier to understand.

For example, you could define a protocol:

from typing import Protocol, Union, Optional
from torch.distributed import ProcessGroup
import torch

class AllReduceCommunicator(Protocol):
    """Structural interface shared by both sglang's and aiter's CustomAllreduce."""

    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], **kwargs):
        ...

    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
        ...

    def should_custom_ar(self, inp: torch.Tensor) -> bool:
        ...

    def close(self) -> None:
        ...

    # Add other common methods like capture if needed

# Then use it as:
self.ca_comm: Optional[AllReduceCommunicator] = None

This would enforce that any class assigned to ca_comm has the required methods, improving static analysis and preventing potential runtime errors.
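
As a hypothetical usage sketch (the function name and fallback path are illustrative, not from this PR), a call site could then be typed against the protocol rather than a concrete class:

import torch
import torch.distributed as dist

def all_reduce_with_fallback(comm: Optional[AllReduceCommunicator], inp: torch.Tensor) -> torch.Tensor:
    # Use the custom kernel only when the communicator accepts this input.
    if comm is not None and comm.should_custom_ar(inp):
        out = comm.custom_all_reduce(inp)
        if out is not None:
            return out
    # Otherwise fall back to the default torch.distributed all-reduce.
    dist.all_reduce(inp)
    return inp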

@HaiShaw added the run-ci label on Nov 12, 2025
@HaiShaw merged commit e4b2937 into sgl-project:main on Nov 13, 2025
76 of 79 checks passed