Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
054ed81
feat: mtp support dp-attention with cuda-graph (#6080)
May 12, 2025
a602a29
fix dp+mtp bugs
May 27, 2025
ed6b060
Merge branch 'main' into feature_mtp_support_dp_attention
u4lr451 May 28, 2025
6cc38e7
fix: MTP+cudagraph+DPAtten and fa3
TianQiLin666666 May 30, 2025
a526032
Merge remote-tracking branch 'github/main' into feature_mtp_support_d…
May 31, 2025
672d6be
feat:Enable CUDA Graph for draft_extend while supporting dp-attention…
May 31, 2025
b130867
fix: Adjust the init_cuda_graph_state and fixbug (#6081)
May 31, 2025
35fe3df
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 1, 2025
3ceedbe
Performance: Eliminate performance impact in non-dp-attention+mtp sce…
Jun 3, 2025
04ede24
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 3, 2025
96b7209
fix bugs for mtp (#6081)
Jun 4, 2025
54dd1f7
fix enable cuda graph for draft_extend stage while supporting dp-atte…
Jun 6, 2025
b01de94
Merge branch 'main' into feature_mtp_support_dp_attention
u4lr451 Jun 6, 2025
5805662
Merge branch 'main' into feature_mtp_support_dp_attention
ch-wan Jun 6, 2025
990fe38
Merge branch 'main' into feature_mtp_support_dp_attention
Qiaolin-Yu Jun 7, 2025
658fd39
Added test cases for dp-attention + mtp (#6081)
Jun 7, 2025
8e47432
Merge commit '60fdad7cf343333e956a3889c12956396a1516bf' into u4lr451:…
Jun 9, 2025
57e8f1c
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 9, 2025
e15db54
Update mtp+dp-attention test cases (#6081)
Jun 9, 2025
64cc457
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 9, 2025
ed7d4e2
Merge branch 'main' into feature_mtp_support_dp_attention
u4lr451 Jun 9, 2025
5cba657
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 10, 2025
b54f934
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 10, 2025
c336c53
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 10, 2025
76f6cde
compatibility for fa3 (#6081)
Jun 10, 2025
23f82db
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 10, 2025
9dff016
fix
Qiaolin-Yu Jun 11, 2025
cc124fb
fix
Qiaolin-Yu Jun 11, 2025
55aefb7
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 11, 2025
7d44df1
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 11, 2025
767ff45
Merge branch 'main' into feature_mtp_support_dp_attention
ch-wan Jun 11, 2025
4982404
Merge branch 'main' into feature_mtp_support_dp_attention
u4lr451 Jun 11, 2025
9be85b7
Remove redundant code (#6081)
Jun 11, 2025
6690410
Merge branch 'main' into feature_mtp_support_dp_attention
Qiaolin-Yu Jun 11, 2025
d4ec8c8
nit update
ch-wan Jun 12, 2025
1218312
nit fix (#6081)
Jun 12, 2025
42d2403
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 12, 2025
6f9478a
Merge branch 'main' into feature_mtp_support_dp_attention
ch-wan Jun 12, 2025
4e54751
Merge branch 'main' into feature_mtp_support_dp_attention
zhyncs Jun 13, 2025
ec987fc
update scheduler and eagle worker
ch-wan Jun 13, 2025
9c86afe
update eagle_worker (#6081)
Jun 13, 2025
b0cb235
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 13, 2025
973edde
update forward_batch_speculative_generation
Jun 14, 2025
4f299ae
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 14, 2025
d2a162f
Merge commit '55561e25533f195e6d6b11e1c3d2449bc9908495' into pr/u4lr4…
ch-wan Jun 15, 2025
6e7c69e
polish global sync
ch-wan Jun 15, 2025
37af1a2
refactor eagle_worker.py
ch-wan Jun 15, 2025
64cc292
fix
ch-wan Jun 15, 2025
5c6b93e
Merge branch 'main' into feature_mtp_support_dp_attention
ch-wan Jun 15, 2025
3744a72
Merge remote-tracking branch 'origin/HEAD' into pr/u4lr451/6081
ch-wan Jun 15, 2025
ab26c11
format
ch-wan Jun 15, 2025
c07ba77
fix refactor bug
Jun 15, 2025
ff07187
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 15, 2025
97f531b
fix enable_dp_lm_head when dp-size == tp-size
Jun 16, 2025
3f686b1
Performance: Support enabling CUDA graph when idle batches exist
Jun 16, 2025
5ae3c3d
Merge remote-tracking branch 'github/main' into u4lrssh.feature_mtp_s…
Jun 16, 2025
c27fa92
transfer hidden_states but bug
Atream Jun 16, 2025
f3854ee
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 16, 2025
841defa
refine code for dp lm head
ch-wan Jun 16, 2025
2f64ad7
Merge branch 'main' into feature_mtp_support_dp_attention
zhyncs Jun 16, 2025
a279680
Revert "Performance: Support enabling CUDA graph when idle batches ex…
ch-wan Jun 16, 2025
038ca0f
add a note
ch-wan Jun 17, 2025
3bc16e4
Merge commit '873ae12cee348dcb579a4c7456d789ef4441f3bf' into pr/u4lr4…
ch-wan Jun 17, 2025
16f8a63
Merge branch 'main' into feature_mtp_support_dp_attention
zhyncs Jun 17, 2025
e4bf571
fix merge error
ch-wan Jun 17, 2025
3a5b9d5
clean code and add comments
ch-wan Jun 17, 2025
f40bdb2
Merge branch 'pr-6081' into mtp-dp-pd
Atream Jun 17, 2025
a5feca1
remove hard code, some code is copied from bytedance-iaas/ayrnb
Atream Jun 17, 2025
d30108f
Merge branch 'main' into mtp-dp-pd
zhyncs Jun 17, 2025
5c49666
remove hard code
Atream Jun 18, 2025
b9509bd
Merge branch 'main' into mtp-dp-pd
ShangmingCai Jun 18, 2025
decdd3a
fix None
Atream Jun 18, 2025
b7eeef4
Merge branch 'zbx' into mtp-dp-pd
Atream Jun 18, 2025
34f5019
get spec algorithm from model runner instead of import
Atream Jun 18, 2025
efbbb50
format
Atream Jun 18, 2025
ddbf506
fix getting spec_algorithm from scheduler
Atream Jun 18, 2025
726d64c
Merge branch 'main' into mtp-dp-pd
zhyncs Jun 18, 2025
91273ba
format
Atream Jun 19, 2025
064f518
fix bug
Atream Jun 19, 2025
40f6a51
fix no mtp
Atream Jun 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def __init__(
self.metadata_buffers = metadata_buffers
self.scheduler = scheduler
self.tree_cache = tree_cache
self.spec_algorithm = scheduler.spec_algorithm

def add(self, decode_req: DecodeRequest) -> None:
self.queue.append(decode_req)
Expand Down Expand Up @@ -581,14 +582,16 @@ def pop_transferred(self) -> List[Req]:
idx = decode_req.metadata_buffer_index
(
output_id,
output_hidden_states,
output_token_logprobs_val,
output_token_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
) = self.metadata_buffers.get_buf(idx)

decode_req.req.output_ids.append(output_id[0].item())

if not self.spec_algorithm.is_none():
decode_req.req.hidden_states_tensor = output_hidden_states
if decode_req.req.return_logprob:
decode_req.req.output_token_logprobs_val.append(
output_token_logprobs_val[0].item()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,16 @@ def process_prebuilt_extend(
)
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

# local import to avoid circular import
from sglang.srt.speculative.eagle_utils import EagleDraftInput

spec_info = EagleDraftInput(
topk_p=topk_p,
topk_index=topk_index,
hidden_states=torch.ones(
(b, model_config.hidden_size), device=self.device
),
hidden_states=hidden_states,
verified_id=self.output_ids,
)
spec_info.prepare_for_extend(self)
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def process_batch_result_disagg_prefill(
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)

hidden_state_offset = 0
for i, (req, next_token_id) in enumerate(
zip(batch.reqs, next_token_ids, strict=True)
):
Expand All @@ -402,6 +404,16 @@ def process_batch_result_disagg_prefill(
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
self.disagg_prefill_inflight_queue.append(req)
if logits_output.hidden_states is not None:
last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1
)
req.hidden_states_tensor = (
logits_output.hidden_states[last_hidden_index].cpu().clone()
)
hidden_state_offset += extend_input_len_per_req[i]
else:
req.hidden_states_tensor = None
if req.return_logprob:
assert extend_logprob_start_len_per_req is not None
assert extend_input_len_per_req is not None
Expand Down
14 changes: 14 additions & 0 deletions python/sglang/srt/disaggregation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class MetadataBuffers:
def __init__(
self,
size: int,
hidden_size: int,
dtype: torch.dtype,
max_top_logprobs_num: int = 128,
custom_mem_pool: torch.cuda.MemPool = None,
):
Expand All @@ -104,6 +106,10 @@ def __init__(
# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)

self.output_hidden_states = torch.zeros(
(size, hidden_size), dtype=dtype, device=device
)
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
Expand All @@ -120,20 +126,23 @@ def __init__(
def get_buf_infos(self):
ptrs = [
self.output_ids.data_ptr(),
self.output_hidden_states.data_ptr(), # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(),
self.output_top_logprobs_idx.data_ptr(),
]
data_lens = [
self.output_ids.nbytes,
self.output_hidden_states.nbytes,
self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes,
self.output_top_logprobs_idx.nbytes,
]
item_lens = [
self.output_ids[0].nbytes,
self.output_hidden_states[0].nbytes,
self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes,
Expand All @@ -144,6 +153,7 @@ def get_buf_infos(self):
def get_buf(self, idx: int):
return (
self.output_ids[idx],
self.output_hidden_states[idx],
self.output_token_logprobs_val[idx],
self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx],
Expand All @@ -153,6 +163,10 @@ def get_buf(self, idx: int):
def set_buf(self, req: Req):

self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
if req.hidden_states_tensor is not None:
self.output_hidden_states[req.metadata_buffer_index].copy_(
req.hidden_states_tensor
)
if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,7 @@ def __init__(
self.output_token_ids_logprobs_idx
) = None
self.hidden_states: List[List[float]] = []
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP

# Embedding (return values)
self.embedding = None
Expand Down
10 changes: 8 additions & 2 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,8 @@ def init_disaggregation(self):
)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)

Expand Down Expand Up @@ -678,6 +680,8 @@ def init_disaggregation(self):
)
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)

Expand Down Expand Up @@ -1692,13 +1696,15 @@ def run_batch(
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
# we can use the correct values in output processing.
if batch.return_logprob:
if batch.return_logprob or self.spec_algorithm.is_eagle():
extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
else:
extend_input_len_per_req = None
if batch.return_logprob:
extend_logprob_start_len_per_req = [
req.extend_logprob_start_len for req in batch.reqs
]
else:
extend_input_len_per_req = None
extend_logprob_start_len_per_req = None

ret = GenerationBatchResult(
Expand Down
Loading