[PD] feat: refactor custom mem pool and add barex pd support#12332
[PD] feat: refactor custom mem pool and add barex pd support#12332ShangmingCai merged 22 commits intosgl-project:mainfrom
Conversation
Summary of ChangesHello @stmatengss, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates support for Barex, a specialized point-to-point transfer library, into the system to facilitate GDR-style PD disaggregation. It refactors the custom memory pool configuration mechanism, transitioning the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request adds support for the BAREX custom memory pool, which is a valuable addition. The implementation correctly handles the new allocator type and maintains backward compatibility. However, I've identified significant code duplication in handling the SGLANG_MOONCAKE_CUSTOM_MEM_POOL environment variable and creating the memory pool. This logic is repeated across four different places. Additionally, the list of supported custom memory pool types is duplicated and, in some cases, hardcoded. I recommend refactoring the duplicated logic into a utility function and defining the list of supported types as a central constant to improve maintainability and consistency.
| logger = logging.getLogger(__name__) | ||
|
|
||
| # Global constants for custom memory pool types | ||
| SUPPORTED_CUSTOM_MEM_POOL_TYPES = ["NVLINK", "BAREX"] |
There was a problem hiding this comment.
| custom_mem_pool_type = os.getenv("SGLANG_MOONCAKE_CUSTOM_MEM_POOL") | ||
| if custom_mem_pool_type is not None: | ||
| # Handle boolean True as NVLINK | ||
| if custom_mem_pool_type.lower() == "true": | ||
| custom_mem_pool_type = "NVLINK" | ||
| self.enable_custom_mem_pool = ( | ||
| custom_mem_pool_type in SUPPORTED_CUSTOM_MEM_POOL_TYPES | ||
| ) | ||
| else: | ||
| self.enable_custom_mem_pool = False |
There was a problem hiding this comment.
There's significant code duplication in handling the SGLANG_MOONCAKE_CUSTOM_MEM_POOL environment variable and creating the custom memory pool. This logic is repeated in:
python/sglang/srt/disaggregation/mooncake/conn.py(L206-215)python/sglang/srt/mem_cache/memory_pool.pyinMambaPool(L152-181)python/sglang/srt/mem_cache/memory_pool.pyinKVCache(L454-481)python/sglang/srt/mem_cache/memory_pool.pyinSWAKVPool(L980-1007)
To improve maintainability, I recommend refactoring this into utility functions in a shared module like sglang.srt.utils.
One function to get the pool info:
def get_custom_mem_pool_info() -> (bool, Optional[str]):
custom_mem_pool_type = os.getenv("SGLANG_MOONCAKE_CUSTOM_MEM_POOL")
if custom_mem_pool_type is None:
return False, None
if custom_mem_pool_type.lower() == "true":
custom_mem_pool_type = "NVLINK"
if custom_mem_pool_type not in SUPPORTED_CUSTOM_MEM_POOL_TYPES:
return False, None
return True, custom_mem_pool_typeAnd another to create the pool:
def create_custom_mem_pool(pool_type: str, device: str) -> torch.cuda.MemPool:
if pool_type == "NVLINK":
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(device)
elif pool_type == "BAREX":
from mooncake.allocator import BarexAllocator
allocator = BarexAllocator.get_allocator(device)
else:
raise ValueError(f"Unsupported custom mem pool type: {pool_type}")
return torch.cuda.MemPool(allocator.allocator())This would centralize the logic and make future changes much easier.
|
|
||
| GB = 1024 * 1024 * 1024 | ||
|
|
||
| SUPPORTED_CUSTOM_MEM_POOL_TYPES = ["NVLINK", "BAREX"] |
There was a problem hiding this comment.
These lines should not be placed in memory_pool this file.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
|
@ShangmingCai All these changes are similar to NVLinkAllocator. PTAL. |
| custom_mem_pool_type = os.getenv("SGLANG_MOONCAKE_CUSTOM_MEM_POOL") | ||
| if custom_mem_pool_type is not None: | ||
| # Handle boolean True as NVLINK | ||
| if custom_mem_pool_type.lower() == "true": | ||
| custom_mem_pool_type = "NVLINK" | ||
| self.enable_custom_mem_pool = ( | ||
| custom_mem_pool_type in SUPPORTED_CUSTOM_MEM_POOL_TYPES | ||
| ) | ||
| else: | ||
| self.enable_custom_mem_pool = False | ||
|
|
There was a problem hiding this comment.
Try using the attributes of Envs.SGLANG_MOONCAKE_CUSTOM_MEM_POOL instead of os.getenv.
| self.enable_custom_mem_pool = get_bool_env_var( | ||
| "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" | ||
| ) | ||
| custom_mem_pool_type = os.getenv("SGLANG_MOONCAKE_CUSTOM_MEM_POOL") | ||
| if custom_mem_pool_type is not None: | ||
| # Handle boolean True as NVLINK | ||
| if custom_mem_pool_type.lower() == "true": | ||
| custom_mem_pool_type = "NVLINK" | ||
| self.enable_custom_mem_pool = custom_mem_pool_type in SUPPORTED_CUSTOM_MEM_POOL_TYPES | ||
| else: | ||
| self.enable_custom_mem_pool = False | ||
|
|
| self.enable_custom_mem_pool = get_bool_env_var( | ||
| "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" | ||
| ) | ||
| custom_mem_pool_type = os.getenv("SGLANG_MOONCAKE_CUSTOM_MEM_POOL") | ||
| if custom_mem_pool_type is not None: | ||
| # Handle boolean True as NVLINK | ||
| if custom_mem_pool_type.lower() == "true": | ||
| custom_mem_pool_type = "NVLINK" | ||
| self.enable_custom_mem_pool = custom_mem_pool_type in SUPPORTED_CUSTOM_MEM_POOL_TYPES | ||
| else: | ||
| self.enable_custom_mem_pool = False | ||
|
|
ShangmingCai
left a comment
There was a problem hiding this comment.
Maybe I should abstract custom allocator class first, so that we don't need to replicate the code if custom_mem_pool_type == xxxx: in every memory pool.
Sure. We can abstract the allocator class to support more in-house GPU link protocols. Should this be implemented in Mooncake or SGLang? |
We better put it in the Mooncake side, and use env var to switch allocator at the mooncake side with the default value set to nvlink. Then we can unified the code in the sglang. |
I believe some logic isn't common on the mooncake side. Therefore, I've extracted the switching allocator's logic into |
Co-authored-by: Shangming Cai <csmthu@gmail.com>
| import triton.language as tl | ||
|
|
||
| from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE | ||
| from sglang.srt.disaggregation.mooncake.utils import init_mooncake_custom_mem_pool |
There was a problem hiding this comment.
| from sglang.srt.disaggregation.mooncake.utils import init_mooncake_custom_mem_pool | |
| from sglang.srt.utils import maybe_init_custom_mem_pool |
Then we define a function in sglang.srt.utils:
def maybe_init_custom_mem_pool(
device: str,
) -> Tuple[bool, Optional[Any], Optional[str]]:
# This function can be modified to support more features that require a custom memory pool.
enable_custom_mem_pool = True if envs.SGLANG_MOONCAKE_CUSTOM_MEM_POOL.get() is not None else False
if enable_custom_mem_pool:
# Currently, only mooncake requires a custom mem pool for MNNVL PD disaggregation
from sglang.srt.disaggregation.mooncake.utils import init_mooncake_custom_mem_pool
return init_mooncake_custom_mem_pool(device)
else:
return False, None, NoneThere was a problem hiding this comment.
Just this minor suggestion, others LGTM.
There was a problem hiding this comment.
Remember to fix lint.
|
PTAL when it's convenient. Thanks! @xiezhq-hermann |
| Tuple of (enable_custom_mem_pool, custom_mem_pool, custom_mem_pool_type) | ||
| """ | ||
| enable_custom_mem_pool = ( | ||
| True if envs.SGLANG_MOONCAKE_CUSTOM_MEM_POOL.get() is not None else False |
There was a problem hiding this comment.
does this support other transfer engine as well? if so, might be a good practice to have a more general set of statements.
There was a problem hiding this comment.
Unsure if other transfer engines support these transports; if so, add them to utils.py.
Motivation
Usage
Modifications
SGLANG_MOONCAKE_CUSTOM_MEM_POOLfromBooltoStringtype, and keep forward compatibility (still supportSGLANG_MOONCAKE_CUSTOM_MEM_POOL=true|false).BarexAllocatorfrom mooncake, and make it as defaulttorch.cuda.mempool.Accuracy Tests
Benchmarking and Profiling
Checklist