
Commit cf1da38

alisonshao authored and tonyluj committed
Revert PR sgl-project#14044: Restore separate memory pool for piecewise CUDA graph (sgl-project#14278)
1 parent 7ac73b8 commit cf1da38

2 files changed, +14 -6 lines

python/sglang/srt/distributed/device_communicators/pynccl_allocator.py

Lines changed: 1 addition & 2 deletions
@@ -71,8 +71,7 @@ def is_symmetric_memory_enabled():
 
 def set_graph_pool_id(graph_pool_id):
     global _graph_pool_id
-    if _graph_pool_id is not None:
-        _graph_pool_id = graph_pool_id
+    _graph_pool_id = graph_pool_id
 
 
 def disable_symmetric_memory_context():
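
After this change, set_graph_pool_id unconditionally overwrites the stored _graph_pool_id, whereas the reverted version only overwrote it when a pool id had already been set. A minimal sketch of how a caller might register a capture pool with this allocator is shown below; the capture code around it is an illustrative assumption, not SGLang's actual capture path.

import torch

from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    set_graph_pool_id,
)

# Illustrative capture setup (assumed, not taken from the repository).
pool = torch.cuda.graph_pool_handle()  # opaque handle to a CUDA graph memory pool
set_graph_pool_id(pool)                # after this commit, always overwrites the stored id

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):  # allocations during capture draw from this pool
    pass  # run the model once under capture here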

python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,6 @@
4545
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
4646
from sglang.srt.layers.pooler import EmbeddingPoolerOutput
4747
from sglang.srt.layers.torchao_utils import save_gemlite_cache
48-
from sglang.srt.model_executor.cuda_graph_runner import (
49-
get_global_graph_memory_pool,
50-
set_global_graph_memory_pool,
51-
)
5248
from sglang.srt.model_executor.forward_batch_info import (
5349
CaptureHiddenMode,
5450
ForwardBatch,
@@ -147,6 +143,19 @@ def patch_model(model: torch.nn.Module, compiler: str):
147143
_to_torch(model, reverse=True, num_tokens=16)
148144

149145

146+
# Reuse this memory pool across all cuda graph runners.
147+
global_graph_memory_pool = None
148+
149+
150+
def get_global_graph_memory_pool():
151+
return global_graph_memory_pool
152+
153+
154+
def set_global_graph_memory_pool(val):
155+
global global_graph_memory_pool
156+
global_graph_memory_pool = val
157+
158+
150159
def set_torch_compile_config():
151160
import torch._dynamo.config
152161
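
This hunk drops the import of the shared pool helpers from cuda_graph_runner and restores a module-local global_graph_memory_pool, so piecewise CUDA graph captures reuse a pool among themselves but no longer share one with the main CUDA graph runner. The sketch below shows how such helpers are typically wired into a torch.cuda.graph capture; the capture_piece wrapper and its argument are hypothetical and only illustrate the intended reuse pattern.

import torch

from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
    get_global_graph_memory_pool,
    set_global_graph_memory_pool,
)


def capture_piece(run_once):
    # Hypothetical helper: capture one piecewise segment into the shared pool.
    graph = torch.cuda.CUDAGraph()
    # On the first capture the pool is None and PyTorch allocates a fresh pool;
    # later captures pass the remembered handle so they share the same memory.
    with torch.cuda.graph(graph, pool=get_global_graph_memory_pool()):
        run_once()
    # Remember this graph's pool for subsequent piecewise captures.
    set_global_graph_memory_pool(graph.pool())
    return graph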
