Skip to content
26 changes: 21 additions & 5 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def __init__(

# Draft workers are looked up via `SpeculativeAlgorithm` registry; new
# algorithms should register their factory instead of patching this code.
if self.spec_algorithm.name in {"EAGLE", "EAGLE3"}:
if self.spec_algorithm.is_eagle():
draft_worker_kwargs["enable_overlap"] = self.enable_overlap
self.draft_worker = self.spec_algorithm.create_draft_worker(
**draft_worker_kwargs
Expand Down Expand Up @@ -864,8 +864,16 @@ def init_disaggregation(self):
)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
hidden_states_dtype=self.model_config.dtype,
hidden_size=(
self.draft_worker.model_config.hidden_size
if self.spec_algorithm.is_eagle()
else 64 # For safety reasons, won't be used
),
hidden_states_dtype=(
self.draft_worker.model_config.dtype
if self.spec_algorithm.is_eagle()
else torch.float32
),
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)

Expand Down Expand Up @@ -909,8 +917,16 @@ def init_disaggregation(self):
)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
hidden_states_dtype=self.model_config.dtype,
hidden_size=(
self.draft_worker.model_config.hidden_size
if self.spec_algorithm.is_eagle()
else 64 # For safety reasons, won't be used
),
hidden_states_dtype=(
self.draft_worker.model_config.dtype
if self.spec_algorithm.is_eagle()
else torch.float32
),
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)

Expand Down
Loading