50 commits
009db01  Do eagle_worker_v2.py (merrymercy, Jul 22, 2025)
93c09d9  Simplify (merrymercy, Jul 22, 2025)
97765bc  simplify eagle worker v2 (merrymercy, Jul 22, 2025)
9b0bc9e  Simplify (merrymercy, Jul 22, 2025)
700f16d  move spec_info.prepare_for_verify (merrymercy, Jul 22, 2025)
f030b93  Split sample (merrymercy, Jul 22, 2025)
18552d6  Refactor verify (merrymercy, Jul 22, 2025)
5b78345  draft_extend_v2 (merrymercy, Jul 22, 2025)
5a70cd6  simplify (merrymercy, Jul 22, 2025)
c3b7dde  Fix eagle worker v2 (merrymercy, Jul 22, 2025)
393a442  Verify done (merrymercy, Jul 22, 2025)
96634f4  Split draft_cuda_graph prepartion (merrymercy, Jul 22, 2025)
9040f1f  Move plan_stream_ctx (merrymercy, Jul 22, 2025)
5171923  Faster multi step triton draft plan (merrymercy, Jul 22, 2025)
86fa056  Ready ckpt: simplify build_tree_kernel_efficient_preprocess (merrymercy, Jul 22, 2025)
12d96e6  support deepseek (merrymercy, Jul 23, 2025)
24511f7  update triton backend (merrymercy, Jul 23, 2025)
e6e7f5c  Move move_kv_cache to an earlier point (merrymercy, Jul 23, 2025)
dd3900a  Ready checkpoint: works for llama, deepseek (merrymercy, Jul 24, 2025)
4bc2cc7  bs1 overlap profile works (hanming-lu, Aug 11, 2025)
4f12127  bs>1 overlap profile work for send one (hanming-lu, Aug 11, 2025)
197766e  minor (hanming-lu, Aug 11, 2025)
5826cf8  minor (hanming-lu, Aug 11, 2025)
91e0ab4  keep a reference to avoid memory cleanup (hanming-lu, Aug 12, 2025)
4552d90  minor cleanup (hanming-lu, Aug 12, 2025)
9f71d1d  proper fix (hanming-lu, Aug 13, 2025)
d249e4e  bs4 send one passes; bs1 race (hanming-lu, Aug 13, 2025)
0a7c8d6  wait for copy_done before resolve (hanming-lu, Aug 13, 2025)
861c6b7  bs64 gsm8k passing; TODO: optimize resolve future (hanming-lu, Aug 14, 2025)
907540c  working version but needs to hide resolve latency (hanming-lu, Aug 14, 2025)
08bc92c  minor opt (hanming-lu, Aug 14, 2025)
f079e0e  minor cleanup; overlap spec dec works; TODO: non-overlap sd + overlap… (hanming-lu, Aug 14, 2025)
996de63  non-overlap sd done (hanming-lu, Aug 14, 2025)
b68219d  non-overlap no-sd passes; overlap no-sd runs but low acc (hanming-lu, Aug 14, 2025)
73cc358  overlap non-sd gsm8k passing (hanming-lu, Aug 15, 2025)
a496885  cleanup; ready for review (hanming-lu, Aug 15, 2025)
509e5ec  lint (hanming-lu, Aug 15, 2025)
2877ccb  Merge branch 'main' into lsyin/poc-overlap-spec-fix (hnyls2002, Sep 20, 2025)
e60f45a  fix conflicts (hnyls2002, Sep 20, 2025)
89ae3ca  fix missing key `is_prefill_only` in forward batch (hnyls2002, Sep 20, 2025)
1fa196f  completely remove `launch_done` (hnyls2002, Sep 20, 2025)
27ab91f  minor comment TODO (hnyls2002, Sep 20, 2025)
7b34ab3  fix triton launch sync (hnyls2002, Sep 20, 2025)
540b926  fix non radix cache (hnyls2002, Sep 20, 2025)
8b93864  add todo here (hnyls2002, Sep 20, 2025)
da4ba09  move future map (hnyls2002, Sep 20, 2025)
deea418  remove sync in future && fix kv loc data race (hnyls2002, Sep 20, 2025)
3b911d7  upadte comments (hnyls2002, Sep 21, 2025)
793393a  Merge branch 'main' into lsyin/poc-overlap-spec (hnyls2002, Sep 21, 2025)
adc453d  fix import error (hnyls2002, Sep 21, 2025)
python/sglang/srt/disaggregation/prefill.py (2 additions & 3 deletions)
@@ -366,7 +366,6 @@ def process_batch_result_disagg_prefill(
self: Scheduler,
batch: ScheduleBatch,
result: GenerationBatchResult,
launch_done: Optional[threading.Event] = None,
) -> None:
"""
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -388,8 +387,8 @@ def process_batch_result_disagg_prefill(
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
if self.enable_overlap:
# wait
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
launch_done
logits_output, next_token_ids, _ = (
self.tp_worker.resolve_last_batch_result()
)
else:
next_token_ids = result.next_token_ids.tolist()
python/sglang/srt/entrypoints/http_server.py (1 addition & 1 deletion)
@@ -1428,7 +1428,7 @@ def _execute_server_warmup(
return False

# Debug print
# logger.info(f"warmup request returns: {res.json()=}")
logger.info(f"warmup request returns: {res.json()=}")
return success


python/sglang/srt/layers/attention/base_attn_backend.py (20 additions & 0 deletions)
@@ -9,6 +9,7 @@
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInfo


class AttentionBackend(ABC):
@@ -54,6 +55,25 @@ def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
raise NotImplementedError()

def get_verify_buffers_to_fill_after_draft(self):
"""
Return buffers of verify attention kernels that need to be filled after the draft.

Typically, these are tree mask and position buffers.
"""
return [None, None]

def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInfo, cuda_graph_bs: Optional[int]
):
"""
Update the buffers returned by get_verify_buffers_to_fill_after_draft if needed.

Here, we need to recompute all attention-backend metadata that depends on the
tree mask and position buffers.
"""
raise NotImplementedError()

def forward(
self,
q: torch.Tensor,
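
The two hooks added to AttentionBackend define a small contract for overlapped speculative decoding: after the draft pass produces the tree mask and positions, the caller can ask the verify backend which device buffers must be filled, write into them, and then let the backend recompute any metadata derived from those buffers. Below is a minimal sketch of a backend implementing that contract; the class, buffer names, and sizes are illustrative assumptions, not part of this PR.

from typing import Optional

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.speculative.spec_info import SpecInfo


class SketchVerifyBackend(AttentionBackend):
    """Illustrative partial backend; only the two new hooks are shown."""

    def __init__(self, max_num_tokens: int, device: str = "cuda"):
        # Pre-allocated CUDA-graph buffers that the draft stage fills in place.
        self.cuda_graph_custom_mask = torch.zeros(
            max_num_tokens, dtype=torch.bool, device=device
        )
        self.cuda_graph_positions = torch.zeros(
            max_num_tokens, dtype=torch.int64, device=device
        )

    def get_verify_buffers_to_fill_after_draft(self):
        # [tree_mask_buffer, position_buffer]; use None for a slot the backend
        # does not need (the Triton backend below returns None for positions).
        return [self.cuda_graph_custom_mask, self.cuda_graph_positions]

    def update_verify_buffers_to_fill_after_draft(
        self, spec_info: SpecInfo, cuda_graph_bs: Optional[int]
    ):
        # Recompute metadata (e.g., mask indptr) that depends on the freshly
        # written tree mask / positions. A no-op is valid when the kernels
        # read the buffers directly at launch time.
        pass
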
python/sglang/srt/layers/attention/triton_backend.py (76 additions & 19 deletions)
@@ -18,6 +18,7 @@
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInfo


def logit_capping_mod(logit_capping_method, logit_cap):
@@ -139,6 +140,13 @@ def __init__(
# Initialize forward metadata
self.forward_metadata: ForwardMetadata = None

self.max_context_len = model_runner.model_config.context_len

self.cuda_graph_custom_mask = None

self.device = model_runner.device
self.device_core_count = get_device_core_count(model_runner.gpu_id)

def get_num_kv_splits(
self,
num_kv_splits: torch.Tensor,
@@ -251,6 +259,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
mask_indptr = None
max_extend_len = None
elif forward_batch.forward_mode.is_target_verify():
# TODO: Support sliding window in spec inference
bs = len(forward_batch.req_pool_indices)
qo_indptr = torch.arange(
0,
@@ -329,7 +338,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
forward_batch.extend_prefix_lens.sum().item(),
kv_indptr[-1].item(),
dtype=torch.int64,
device=self.device,
)
@@ -388,6 +397,7 @@ def init_cuda_graph_state(
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
):
self.cuda_graph_attn_logits = torch.zeros(
(max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
@@ -399,9 +409,17 @@
dtype=torch.float32,
device=self.device,
)
self.cuda_graph_num_kv_splits = torch.full(
(max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
)

if cuda_graph_num_kv_splits_buf is None:
self.cuda_graph_num_kv_splits = torch.full(
(max_num_tokens,),
self.max_kv_splits,
dtype=torch.int32,
device=self.device,
)
else:
self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf

if kv_indices_buf is None:
self.cuda_graph_kv_indices = torch.zeros(
(max_num_tokens * self.max_context_len),
@@ -693,7 +711,7 @@ def init_forward_metadata_replay_cuda_graph(
)
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
# custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
mask_indptr = self.mask_indptr[: bs + 1]
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -722,6 +740,19 @@ def get_cuda_graph_seq_len_fill_value(self):
def get_cuda_graph_seq_len_fill_value(self):
return 1

def get_verify_buffers_to_fill_after_draft(self):
"""
Return buffers for verify attention kernels that need to be filled after the draft.

Typically, these are tree mask and position buffers.
"""
return [self.cuda_graph_custom_mask, None]

def update_verify_buffers_to_fill_after_draft(
self, spec_info: SpecInfo, cuda_graph_bs: Optional[int]
):
pass

def forward_extend(
self,
q: torch.Tensor,
@@ -907,7 +938,7 @@ def common_template(
self.page_size,
)

for i in range(self.speculative_num_steps):
for i in range(self.speculative_num_steps - 1):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
: seq_lens_sum * self.topk + bs * (i + 1)
@@ -941,9 +972,18 @@ def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
dtype=torch.int64,
device=self.device,
)
self.cuda_graph_num_kv_splits = torch.full(
(max_num_tokens,),
self.attn_backends[0].max_kv_splits,
dtype=torch.int32,
device=self.device,
)
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
max_bs,
max_num_tokens,
kv_indices_buf=self.cuda_graph_kv_indices[i],
cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
)

def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
Expand All @@ -963,19 +1003,36 @@ def call_fn(i, forward_batch):
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=None,
)
old_bs = bs
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs

self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
# Generate kv indices
# Directly write to cuda graph buffers
self.generate_draft_decode_kv_indices[
(self.speculative_num_steps, num_seqs, self.topk)
](
forward_batch.req_pool_indices,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.seq_lens,
self.cuda_graph_kv_indices,
self.kv_indptr,
forward_batch.positions,
self.pool_len,
self.cuda_graph_kv_indices.shape[1],
self.kv_indptr.shape[1],
next_power_of_2(num_seqs),
next_power_of_2(self.speculative_num_steps),
next_power_of_2(bs),
self.page_size,
)

# Set num_kv_splits only once because cuda_graph_num_kv_splits is shared across steps
num_token = bs
self.attn_backends[-1].get_num_kv_splits(
self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],
forward_batch.seq_lens[:old_bs],
)


@triton.jit
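
One design choice worth noting in the multi-step draft path: init_cuda_graph_state now allocates a single cuda_graph_num_kv_splits tensor and hands the same tensor to every per-step TritonAttnBackend through the new cuda_graph_num_kv_splits_buf argument, so the replay path only has to call get_num_kv_splits once on the last backend. A toy sketch of that aliasing, using plain CPU tensors and a hypothetical StepBackend class:

import torch

# All draft steps alias one num_kv_splits buffer, so a single write updates every step.
max_num_tokens, max_kv_splits = 8, 4
shared_num_kv_splits = torch.full((max_num_tokens,), max_kv_splits, dtype=torch.int32)

class StepBackend:
    def __init__(self, num_kv_splits_buf: torch.Tensor):
        # Store a reference, not a copy, mirroring cuda_graph_num_kv_splits_buf.
        self.cuda_graph_num_kv_splits = num_kv_splits_buf

backends = [StepBackend(shared_num_kv_splits) for _ in range(3)]

# "Set num_kv_splits only once": writing through the last backend's view
# is visible to all steps because they share storage.
backends[-1].cuda_graph_num_kv_splits[:4] = 2
assert all(int(b.cuda_graph_num_kv_splits[0]) == 2 for b in backends)
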
python/sglang/srt/layers/logits_processor.py (6 additions & 22 deletions)
@@ -129,11 +129,7 @@ class LogitsMetadata:

@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
if (
forward_batch.forward_mode.is_extend()
and forward_batch.return_logprob
and not forward_batch.forward_mode.is_target_verify()
):
if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
extend_return_top_logprob = any(
x > 0 for x in forward_batch.top_logprobs_nums
)
@@ -260,6 +256,7 @@ def forward(
if (
logits_metadata.forward_mode.is_decode_or_idle()
or logits_metadata.forward_mode.is_target_verify()
or logits_metadata.forward_mode.is_draft_extend_v2()
):
pruned_states = hidden_states
if aux_hidden_states is not None:
@@ -270,22 +267,7 @@
logits_metadata.forward_mode.is_extend()
and not logits_metadata.extend_return_logprob
):
# Prefill without input logprobs.
if logits_metadata.padded_static_len < 0:
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
else:
# If padding_static length is 5 and extended_seq_lens is [2, 3],
# then our batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p]
# and this retrieves t01 and t12, which are the valid last tokens
idx = torch.arange(
len(logits_metadata.extend_seq_lens),
device=logits_metadata.extend_seq_lens.device,
)
last_index = (
idx * logits_metadata.padded_static_len
+ logits_metadata.extend_seq_lens
- 1
)
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
pruned_states = hidden_states[last_index]
if aux_hidden_states is not None:
aux_pruned_states = [hidden[last_index] for hidden in aux_hidden_states]
@@ -375,7 +357,9 @@ def forward(
else pruned_states
)
else:
assert False, "Should never reach"
raise ValueError(
f"Invalid capture hidden mode: {logits_metadata.capture_hidden_mode=}"
)

if not logits_metadata.extend_return_logprob:
# Decode mode or extend mode without return_logprob.
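
The simplified extend path in the logits processor keeps only the cumulative-sum indexing: with extend_seq_lens = [2, 3] the packed hidden states are [t00, t01, t10, t11, t12], and cumsum(extend_seq_lens) - 1 yields indices [1, 4], i.e., the last token of each request. A small standalone check with illustrative shapes:

import torch

# Illustrative check of the last-token selection used in the extend path.
extend_seq_lens = torch.tensor([2, 3])
hidden_states = torch.arange(5 * 4, dtype=torch.float32).reshape(5, 4)  # 5 packed tokens, hidden size 4

last_index = torch.cumsum(extend_seq_lens, dim=0) - 1  # tensor([1, 4])
pruned_states = hidden_states[last_index]              # rows for t01 and t12

assert last_index.tolist() == [1, 4]
assert pruned_states.shape == (2, 4)
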