diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index 3353aa2ea471..3d7e3a56b795 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -4,29 +4,42 @@
 import torch
 
-from sglang.srt.utils import is_hip, is_hpu, is_npu
+from sglang.srt.utils import is_cuda, is_hip
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
+_is_hip = is_hip()
 
-if not is_hpu():
-    try:
-        import sgl_kernel
-    except ImportError as e:
+IS_CUSTOM_AR_AVAILABLE = _is_cuda or _is_hip
+IS_QUICK_AR_AVAILABLE = _is_hip
+# TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
+IS_MSCCLPP_AR_AVAILABLE = _is_cuda
+
+try:
+    import sgl_kernel.allreduce as _custom_ar
+except ImportError as e:
+    if _is_cuda or _is_hip:
         logger.warning("Failed to import from custom_ar with %r", e)
+    IS_CUSTOM_AR_AVAILABLE = False
+    IS_QUICK_AR_AVAILABLE = False
+    IS_MSCCLPP_AR_AVAILABLE = False
+
+# region IS_CUSTOM_AR_AVAILABLE
+if not IS_CUSTOM_AR_AVAILABLE:
+    pass
 
-if not is_hip() and not is_npu():
-    custom_op = sgl_kernel.allreduce
+elif _is_cuda:
+    # CUDA custom allreduce
 
-    # custom allreduce
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],
         rank_data: torch.Tensor,
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return custom_op.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)
+        return _custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)
 
     def all_reduce(
         fa: int,
@@ -35,26 +48,26 @@ def all_reduce(
         reg_buffer: int,
         reg_buffer_sz_bytes: int,
     ) -> None:
-        custom_op.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+        _custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
 
     def dispose(fa: int) -> None:
-        custom_op.dispose(fa)
+        _custom_ar.dispose(fa)
 
     def meta_size() -> int:
-        return custom_op.meta_size()
+        return _custom_ar.meta_size()
 
     def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-        return custom_op.register_buffer(fa, ipc_tensors)
+        return _custom_ar.register_buffer(fa, ipc_tensors)
 
     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-        return custom_op.get_graph_buffer_ipc_meta(fa)
+        return _custom_ar.get_graph_buffer_ipc_meta(fa)
 
     def register_graph_buffers(
         fa: int, handles: List[List[int]], offsets: List[List[int]]
     ) -> None:
-        custom_op.register_graph_buffers(fa, handles, offsets)
+        _custom_ar.register_graph_buffers(fa, handles, offsets)
 
-else:
+elif _is_hip:
     # ROCM custom allreduce
 
     def init_custom_ar(
@@ -65,55 +78,64 @@ def init_custom_ar(
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return sgl_kernel.allreduce.init_custom_ar(
+        return _custom_ar.init_custom_ar(
             meta, rank_data, handles, offsets, rank, full_nvlink
         )
 
     def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
+        _custom_ar.all_reduce_reg(fa, inp, out)
 
     def all_reduce_unreg(
         fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
     ) -> None:
-        sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
+        _custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
 
     def dispose(fa: int) -> None:
-        sgl_kernel.allreduce.dispose(fa)
+        _custom_ar.dispose(fa)
 
     def meta_size() -> int:
-        return sgl_kernel.allreduce.meta_size()
+        return _custom_ar.meta_size()
 
     def register_buffer(
         fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
     ) -> None:
-        return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
+        return _custom_ar.register_buffer(fa, t, handles, offsets)
 
     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
+        return _custom_ar.get_graph_buffer_ipc_meta(fa)
 
     def register_graph_buffers(
         fa: int, handles: List[str], offsets: List[List[int]]
     ) -> None:
-        sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
+        _custom_ar.register_graph_buffers(fa, handles, offsets)
 
     def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return sgl_kernel.allreduce.allocate_meta_buffer(size)
+        return _custom_ar.allocate_meta_buffer(size)
 
     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
+        return _custom_ar.get_meta_buffer_ipc_handle(inp)
+
+
+# endregion
+
+# region IS_QUICK_AR_AVAILABLE
+if not IS_QUICK_AR_AVAILABLE:
+    pass
+
+elif _is_hip:
     # ROCM custom quick allreduce
 
     def init_custom_qr(
         rank: int, world_size: int, qr_max_size: Optional[int] = None
     ) -> int:
-        return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)
+        return _custom_ar.init_custom_qr(world_size, rank, qr_max_size)
 
     def qr_get_handle(fa: int) -> torch.Tensor:
-        return sgl_kernel.allreduce.qr_get_handle(fa)
+        return _custom_ar.qr_get_handle(fa)
 
     def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
-        sgl_kernel.allreduce.qr_open_handles(fa, handles)
+        _custom_ar.qr_open_handles(fa, handles)
 
     def qr_all_reduce(
         fa: int,
@@ -122,44 +144,54 @@ def qr_all_reduce(
         quant_level: int,
         cast_bf2half: bool,
     ) -> None:
-        sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
+        _custom_ar.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
 
     def qr_destroy(fa: int) -> None:
-        sgl_kernel.allreduce.qr_destroy(fa)
+        _custom_ar.qr_destroy(fa)
 
     def qr_max_size() -> int:
-        return sgl_kernel.allreduce.qr_max_size()
-
-
-def mscclpp_generate_unique_id() -> bytes:
-    return sgl_kernel.allreduce.mscclpp_generate_unique_id()
-
-
-def mscclpp_init_context(
-    unique_id: bytes,
-    rank: int,
-    world_size: int,
-    scratch: torch.Tensor,
-    put_buffer: torch.Tensor,
-    nranks_per_node: int,
-    rank_to_node: List[int],
-    rank_to_ib: List[int],
-    context_selection: int,
-) -> int:
-    return sgl_kernel.allreduce.mscclpp_init_context(
-        unique_id,
-        rank,
-        world_size,
-        scratch,
-        put_buffer,
-        nranks_per_node,
-        rank_to_node,
-        rank_to_ib,
-        context_selection,
-    )
-
-
-def mscclpp_allreduce(
-    context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
-) -> None:
-    return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
+        return _custom_ar.qr_max_size()
+
+
+# endregion
+
+# region IS_MSCCLPP_AR_AVAILABLE
+
+if not IS_MSCCLPP_AR_AVAILABLE:
+    pass
+
+elif _is_cuda:
+
+    def mscclpp_generate_unique_id() -> bytes:
+        return _custom_ar.mscclpp_generate_unique_id()
+
+    def mscclpp_init_context(
+        unique_id: bytes,
+        rank: int,
+        world_size: int,
+        scratch: torch.Tensor,
+        put_buffer: torch.Tensor,
+        nranks_per_node: int,
+        rank_to_node: List[int],
+        rank_to_ib: List[int],
+        context_selection: int,
+    ) -> int:
+        return _custom_ar.mscclpp_init_context(
+            unique_id,
+            rank,
+            world_size,
+            scratch,
+            put_buffer,
+            nranks_per_node,
+            rank_to_node,
+            rank_to_ib,
+            context_selection,
+        )
+
+    def mscclpp_allreduce(
+        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
+    ) -> None:
+        return _custom_ar.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
+
+
+# endregion
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index b768d46bcea0..5fecab5e6087 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -21,16 +21,6 @@
 from sglang.srt.environ import envs
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, log_info_on_rank0
 
-try:
-    # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
-    import sgl_kernel  # noqa: F401
-
-    custom_ar = True
-except ImportError:
-    # For CPUs
-    custom_ar = False
-
-
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 
@@ -79,7 +69,7 @@ def __init__(
         self.disabled = True
         # This can be modified in-place by context manager in piecewise cuda graph runner
         self.original_disabled = True  # To store the original state
-        if not custom_ar:
+        if not ops.IS_CUSTOM_AR_AVAILABLE:
             # disable because of missing custom allreduce library
             # e.g. in a non-cuda environment
             return
diff --git a/python/sglang/srt/distributed/device_communicators/pymscclpp.py b/python/sglang/srt/distributed/device_communicators/pymscclpp.py
index 5d7511c2c2a9..78e1318fad14 100644
--- a/python/sglang/srt/distributed/device_communicators/pymscclpp.py
+++ b/python/sglang/srt/distributed/device_communicators/pymscclpp.py
@@ -11,25 +11,12 @@
 from torch.distributed import ProcessGroup, ReduceOp
 
 from sglang.srt import _custom_ops as ops
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_hip
 
 logger = logging.getLogger(__name__)
 
-_is_cuda = is_cuda()
 _is_hip = is_hip()
 
-mscclpp_is_available = False
-if _is_hip:
-    # TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
-    mscclpp_is_available = False
-if _is_cuda:
-    try:
-        import sgl_kernel  # noqa: F401
-
-        mscclpp_is_available = True
-    except:
-        mscclpp_is_available = False
-
 
 class MscclContextSelection(IntEnum):
     MSCCL1SHOT1NODELL = 1
@@ -127,7 +114,7 @@ def __init__(
         self._IS_CAPTURING = False
         self.disabled = True
 
-        if not mscclpp_is_available:
+        if not ops.IS_MSCCLPP_AR_AVAILABLE:
             # disable because of missing mscclpp library
             # e.g. in a non-cuda environment
             return
diff --git a/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py b/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py
index de97af8168a5..0113f02c3a30 100644
--- a/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/quick_all_reduce.py
@@ -24,14 +24,6 @@
 _is_hip = is_hip()
 
 
-try:
-    ops.qr_max_size()
-    quick_ar = True
-except Exception:
-    # For CPUs and CUDA
-    quick_ar = False
-
-
 @cache
 def qr_rocm_arch_available():
     if not _is_hip:
@@ -101,7 +93,7 @@ def __init__(
             )
             return
 
-        if not quick_ar:
+        if not ops.IS_QUICK_AR_AVAILABLE:
             # disable because of missing quick reduce library
             # e.g. in a cuda environment
             logger.info(
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index bd6e5b332cb1..e759c386d91f 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -122,6 +122,7 @@ def get_or_create_event_loop():
 
 
 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
+@lru_cache(maxsize=1)
 def is_hip() -> bool:
     return torch.version.hip is not None
 
@@ -137,18 +138,22 @@ def is_hip() -> bool:
 builtins.FP8_E4M3_MIN = FP8_E4M3_MIN
 
 
+@lru_cache(maxsize=1)
 def is_cuda():
     return torch.cuda.is_available() and torch.version.cuda
 
 
+@lru_cache(maxsize=1)
 def is_cuda_alike():
     return is_cuda() or is_hip()
 
 
+@lru_cache(maxsize=1)
 def is_hpu() -> bool:
     return hasattr(torch, "hpu") and torch.hpu.is_available()
 
 
+@lru_cache(maxsize=1)
 def is_xpu() -> bool:
     return hasattr(torch, "xpu") and torch.xpu.is_available()
 
@@ -158,6 +163,7 @@ def is_npu() -> bool:
     return hasattr(torch, "npu") and torch.npu.is_available()
 
 
+@lru_cache(maxsize=1)
 def is_host_cpu_x86() -> bool:
     machine = platform.machine().lower()
     return (
@@ -167,6 +173,7 @@ def is_host_cpu_x86() -> bool:
     )
 
 
+@lru_cache(maxsize=1)
 def is_cpu() -> bool:
     return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
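
Note for reviewers: the patch replaces per-module `try: import sgl_kernel` probes with three flags exported from `sglang.srt._custom_ops` (`IS_CUSTOM_AR_AVAILABLE`, `IS_QUICK_AR_AVAILABLE`, `IS_MSCCLPP_AR_AVAILABLE`), computed once at import time and downgraded if the kernel import fails. The sketch below is a minimal, self-contained illustration of that pattern under assumed names; `fake_kernel` and `Communicator` are hypothetical stand-ins, not sglang APIs.

```python
# Minimal sketch (not sglang code) of the capability-flag pattern used in
# _custom_ops.py: decide availability from the platform first, downgrade the
# flag if the optional kernel import fails, and let consumers read the flag
# instead of re-probing. "fake_kernel" and "Communicator" are hypothetical.
import logging

logger = logging.getLogger(__name__)

KERNEL_AVAILABLE = True  # stands in for e.g. IS_CUSTOM_AR_AVAILABLE = _is_cuda or _is_hip

try:
    import fake_kernel  # noqa: F401  # hypothetical optional dependency
except ImportError as e:
    if KERNEL_AVAILABLE:
        # Only warn on platforms where the kernel was expected to exist.
        logger.warning("Failed to import fake_kernel with %r", e)
    KERNEL_AVAILABLE = False


class Communicator:
    """Consumer that reads the module-level flag, mirroring CustomAllreduce."""

    def __init__(self) -> None:
        self.disabled = True
        if not KERNEL_AVAILABLE:
            # Same early-return style as the communicators touched by this patch.
            return
        self.disabled = False


if __name__ == "__main__":
    print("kernel available:", KERNEL_AVAILABLE)
    print("communicator disabled:", Communicator().disabled)
```

The `@lru_cache(maxsize=1)` decorators added to the platform probes in `utils/common.py` fit the same direction: `is_cuda()`, `is_hip()`, and friends are memoized, so repeated callers presumably avoid re-querying torch on every check.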