3 changes: 3 additions & 0 deletions python/pyproject.toml
@@ -89,6 +89,8 @@ srt_hpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
# To install vllm for CPU, please follow the instruction here:
# https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html
srt_cpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11", "torch"]
# https://vllm-ascend.readthedocs.io/en/latest/installation.html
srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]

openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
@@ -107,6 +109,7 @@ all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[lit
all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
all_npu = ["sglang[srt_npu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]

dev = ["sglang[all]", "sglang[test]"]
dev_hip = ["sglang[all_hip]", "sglang[test]"]
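With these extras in place, installing on Ascend would presumably follow the vllm-ascend guide linked above and then use something like pip install "sglang[all_npu]" (or "sglang[srt_npu]" for the runtime only); the actual prerequisites (CANN toolkit, torch_npu) come from that guide rather than from this diff.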
2 changes: 1 addition & 1 deletion python/sglang/srt/configs/device_config.py
@@ -10,7 +10,7 @@ class DeviceConfig:
device: Optional[torch.device]

def __init__(self, device: str = "cuda") -> None:
if device in ["cuda", "xpu", "hpu", "cpu"]:
if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
self.device_type = device
else:
raise RuntimeError(f"Not supported device type: {device}")
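A minimal usage sketch of the widened check (assumes torch_npu has been imported so torch recognizes the "npu" device string; on other hosts only the pre-existing strings are usable):

# Hypothetical usage, not part of the diff.
from sglang.srt.configs.device_config import DeviceConfig

cfg = DeviceConfig("npu")   # accepted after this change (on an Ascend setup)
try:
    DeviceConfig("tpu")     # anything outside the list still raises
except RuntimeError as err:
    print(err)              # Not supported device type: tpu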
39 changes: 39 additions & 0 deletions python/sglang/srt/distributed/device_communicators/npu_communicator.py
@@ -0,0 +1,39 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.utils import is_npu


class NpuCommunicator:

def __init__(self, group: ProcessGroup):
if not is_npu():
self.disabled = True
return
self.disabled = False
self.group = group
self.world_size = dist.get_world_size(self.group)

def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
dist.all_reduce(x, group=self.group)
return x

def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
if dim < 0:
# Convert negative dim to positive.
dim += x.dim()
input_size = x.size()
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size, dtype=x.dtype, device=x.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor, x, group=self.group)
# Reshape
output_tensor = output_tensor.reshape((world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(
input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
)
return output_tensor
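The dim handling in all_gather is the subtle part: dist.all_gather_into_tensor always stacks the per-rank shards along the leading dimension, so the reshape/movedim/reshape sequence is what moves the concatenation onto the requested dim. A single-process sketch of that shape arithmetic (no process group involved; two fake "ranks" stand in for the gathered data):

import torch

world_size = 2
x = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # per-rank shard, shape (4, 2)
dim = -1

# Simulate what all_gather_into_tensor produces: shards stacked along dim 0.
gathered = torch.cat([x, x + 100], dim=0)                # shape (8, 2)

if dim < 0:
    dim += x.dim()                                        # dim becomes 1
input_size = x.size()

out = gathered.reshape((world_size,) + input_size)        # (2, 4, 2)
out = out.movedim(0, dim)                                 # (4, 2, 2)
out = out.reshape(
    input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1:]
)                                                         # (4, 4): shards concatenated on dim 1
print(out.shape)  # torch.Size([4, 4]); row i is [x[i, 0], x[i, 1], (x+100)[i, 0], (x+100)[i, 1]]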
23 changes: 22 additions & 1 deletion python/sglang/srt/distributed/parallel_state.py
@@ -42,6 +42,7 @@
from sglang.srt.utils import (
direct_register_custom_op,
is_cuda_alike,
is_npu,
supports_custom_op,
)

@@ -206,6 +207,7 @@ def __init__(
use_custom_allreduce: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
use_npu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
@@ -244,6 +246,7 @@ def __init__(
self.use_custom_allreduce = use_custom_allreduce
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
self.use_npu_communicator = use_npu_communicator
self.use_message_queue_broadcaster = use_message_queue_broadcaster

# lazy import to avoid documentation build error
@@ -291,6 +294,14 @@ def __init__(
if use_xpu_communicator and self.world_size > 1:
self.xpu_communicator = XpuCommunicator(group=self.device_group)

from sglang.srt.distributed.device_communicators.npu_communicator import (
NpuCommunicator,
)

self.npu_communicator: Optional[NpuCommunicator] = None
if use_npu_communicator and self.world_size > 1:
self.npu_communicator = NpuCommunicator(group=self.device_group)

from sglang.srt.distributed.device_communicators.shm_broadcast import (
MessageQueue,
)
@@ -418,6 +429,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
return self.xpu_communicator.all_reduce(input_)

if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)

if (
self.ca_comm is not None
and not self.ca_comm.disabled
@@ -497,6 +511,11 @@ def all_gather(
if hpu_comm is not None and not hpu_comm.disabled:
return hpu_comm.all_gather(input_, dim)

# For NPUs, use NPU communicator.
npu_comm = self.npu_communicator
if npu_comm is not None and not npu_comm.disabled:
return npu_comm.all_gather(input_, dim)

if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
@@ -941,6 +960,7 @@ def init_world_group(
use_custom_allreduce=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
use_npu_communicator=False,
group_name="world",
)

@@ -959,10 +979,11 @@ def init_model_parallel_group(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
use_pynccl=not is_npu(),
use_custom_allreduce=use_custom_allreduce,
use_hpu_communicator=True,
use_xpu_communicator=True,
use_npu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
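One detail worth noting for the guards added above: on hosts where is_npu() is False, NpuCommunicator marks itself disabled before ever touching the process group, so constructing it is harmless and all_reduce/all_gather fall through to the existing CUDA/pynccl path. A quick check that runs on non-Ascend hardware (assuming this patched sglang is importable; passing None for the group is only for illustration):

from sglang.srt.distributed.device_communicators.npu_communicator import NpuCommunicator

comm = NpuCommunicator(group=None)  # group is never dereferenced when is_npu() is False
print(comm.disabled)                # True on a CUDA/CPU-only machine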
2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -361,6 +361,8 @@ def init_torch_distributed(self):
backend = "hccl"
elif self.device == "cpu":
backend = "gloo"
elif self.device == "npu":
backend = "hccl"

before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
if not self.server_args.enable_p2p_check:
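For reference, "hccl" here is the backend name that torch_npu registers with torch.distributed for Ascend's Huawei Collective Communication Library (the same string the HPU branch above uses for Habana's library). A hedged sketch of what the branch resolves to on an Ascend host (torch_npu assumed installed; the rendezvous values are placeholders, real runs get them from the launcher):

import torch
import torch_npu  # noqa: F401  # assumption: registers the "npu" device and "hccl" backend
import torch.distributed as dist

dist.init_process_group(
    backend="hccl",
    init_method="tcp://127.0.0.1:29500",
    world_size=1,
    rank=0,
)
print(dist.get_backend())  # "hccl"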
2 changes: 1 addition & 1 deletion python/sglang/srt/server_args.py
@@ -554,7 +554,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"--device",
type=str,
default=ServerArgs.device,
help="The device to use ('cuda', 'xpu', 'hpu', 'cpu'). Defaults to auto-detection if not specified.",
help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
)
parser.add_argument(
"--served-model-name",
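In practice the new value would presumably be selected with something like python -m sglang.launch_server --model-path <model> --device npu (the launch entry point and --model-path flag are the usual ones, not part of this diff); whether auto-detection picks up NPUs when --device is omitted is not shown here.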
24 changes: 24 additions & 0 deletions python/sglang/srt/utils.py
@@ -145,6 +145,10 @@ def is_xpu() -> bool:
return hasattr(torch, "xpu") and torch.xpu.is_available()


def is_npu() -> bool:
return hasattr(torch, "npu") and torch.npu.is_available()


def is_flashinfer_available():
"""
Check whether flashinfer is available.
@@ -328,6 +332,16 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
elif device == "cpu":
# TODO: rename the variables in the current function to be not GPU specific
free_gpu_memory = psutil.virtual_memory().available
elif device == "npu":
num_gpus = torch.npu.device_count()
assert gpu_id < num_gpus

if torch.npu.current_device() != gpu_id:
print(
f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ",
"which may cause useless memory allocation for torch NPU context.",
)
free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

if distributed:
tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
@@ -1348,6 +1362,9 @@ def get_device_name(device_id: int = 0) -> str:
if hasattr(torch, "hpu") and torch.hpu.is_available():
return torch.hpu.get_device_name(device_id)

if hasattr(torch, "npu") and torch.npu.is_available():
return torch.npu.get_device_name(device_id)


@lru_cache(maxsize=1)
def is_habana_available() -> bool:
@@ -1444,6 +1461,13 @@ def get_compiler_backend() -> str:
if hasattr(torch, "hpu") and torch.hpu.is_available():
return "hpu_backend"

if hasattr(torch, "npu") and torch.npu.is_available():
import torchair

config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)
return npu_backend

return "inductor"


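Taken together, the utils additions give the rest of the code a uniform way to ask about Ascend hardware. Note that on the NPU path get_compiler_backend() returns a torchair backend object rather than a string, which torch.compile accepts directly. A hedged sketch of how these helpers would be exercised on an Ascend host (torch_npu and torchair assumed installed; on other hardware is_npu() is False and none of this runs):

import torch
from sglang.srt.utils import (
    get_available_gpu_memory,
    get_compiler_backend,
    get_device_name,
    is_npu,
)

if is_npu():
    print(get_device_name(0))                          # Ascend device name via torch.npu
    print(get_available_gpu_memory("npu", gpu_id=0))   # free memory from torch.npu.mem_get_info()
    fn = torch.compile(lambda x: x * 2, backend=get_compiler_backend())
    print(fn(torch.ones(4, device="npu")))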