python/sglang/srt/managers/overlap_utils.py (17 additions, 14 deletions)
@@ -1,3 +1,6 @@
+from dataclasses import dataclass
+from typing import Optional
+
 import torch
 
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
@@ -13,6 +16,12 @@ def _resolve_future_token_ids(input_ids, future_token_ids_map):
     )
 
 
+@dataclass
+class FutureIndices:
+    indices: torch.Tensor
+    interval: Optional[slice] = None
+
+
 class FutureMap:
     def __init__(
         self,
@@ -30,23 +39,17 @@ def __init__(
             (self.future_buffer_len,), dtype=torch.int64, device=self.device
         )
 
-    def update_ct(self, bs: int) -> int:
-        """Update the circular buffer pointer and return the current pointer."""
+    def alloc_future_indices(self, bs: int) -> FutureIndices:
+        """Update the circular buffer pointer and allocate future indices."""
         cur_future_ct = self.future_ct
         self.future_ct = (cur_future_ct + bs) % self.future_limit
-        return cur_future_ct
+        start = cur_future_ct + 1
+        end = cur_future_ct + 1 + bs
+        indices = torch.arange(start, end, dtype=torch.int64, device=self.device)
+        return FutureIndices(indices=indices, interval=slice(start, end))
 
     def resolve_future(self, model_worker_batch: ModelWorkerBatch):
         _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
 
-    def update_next_future(self, future_ct: int, bs: int):
-        return torch.arange(
-            -(future_ct + 1),
-            -(future_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
-        )
-
-    def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
-        self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
+    def store_to_map(self, future_indices: FutureIndices, next_token_ids: torch.Tensor):
+        self.token_ids_buf[future_indices.interval] = next_token_ids
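
To make the refactor easier to follow, here is a minimal, self-contained sketch of the round trip the new API enables. It mirrors the diff above but is not the sglang implementation: the buffer size, the CPU device, and the body of `resolve_future` (which in the real file goes through `_resolve_future_token_ids` on a `ModelWorkerBatch`) are assumptions; the negative-placeholder encoding is inferred from `maybe_future_next_token_ids = -future_indices.indices` in the scheduler change below.

```python
# Hedged sketch of the FutureIndices / FutureMap round trip (not the real sglang code).
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FutureIndices:
    indices: torch.Tensor
    interval: Optional[slice] = None


class FutureMap:
    def __init__(self, future_buffer_len: int = 16, device: str = "cpu"):
        # Sizes and device are arbitrary here; the real class derives them elsewhere.
        self.device = device
        self.future_ct = 0
        self.future_limit = future_buffer_len
        self.future_buffer_len = future_buffer_len
        self.token_ids_buf = torch.empty(
            (self.future_buffer_len,), dtype=torch.int64, device=self.device
        )

    def alloc_future_indices(self, bs: int) -> FutureIndices:
        """Advance the circular pointer and hand out slots [ct + 1, ct + 1 + bs)."""
        cur_future_ct = self.future_ct
        self.future_ct = (cur_future_ct + bs) % self.future_limit
        start = cur_future_ct + 1
        end = cur_future_ct + 1 + bs
        indices = torch.arange(start, end, dtype=torch.int64, device=self.device)
        return FutureIndices(indices=indices, interval=slice(start, end))

    def store_to_map(self, future_indices: FutureIndices, next_token_ids: torch.Tensor):
        self.token_ids_buf[future_indices.interval] = next_token_ids

    def resolve_future(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Assumed resolve step: negative placeholders index into token_ids_buf.
        neg = input_ids < 0
        out = input_ids.clone()
        out[neg] = self.token_ids_buf[-input_ids[neg]]
        return out


future_map = FutureMap()
fi = future_map.alloc_future_indices(bs=3)                # reserve 3 slots
placeholders = -fi.indices                                # tensor([-1, -2, -3])
future_map.store_to_map(fi, torch.tensor([11, 22, 33]))   # sampled tokens arrive later
print(future_map.resolve_future(placeholders))            # tensor([11, 22, 33])
```

The visible design change is that `store_to_map` now consumes the same `FutureIndices` handle that `alloc_future_indices` produced, instead of re-deriving the slot range from a counter and a batch size, so the allocation and the write can no longer drift apart.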
python/sglang/srt/managers/scheduler.py (8 additions, 11 deletions)
@@ -114,7 +114,7 @@
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
-from sglang.srt.managers.overlap_utils import FutureMap
+from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     ModelWorkerBatch,
@@ -217,7 +217,7 @@ class GenerationBatchResult:
     copy_done: Optional[torch.cuda.Event] = None
     delay_sample_launch: bool = False
     forward_batch: Optional[ForwardBatch] = None
-    future_map_ct: Optional[int] = None
+    future_indices: Optional[FutureIndices] = None
 
     def copy_to_cpu(self, return_logprob: bool = False):
         """Copy tensors to CPU in overlap scheduling.
@@ -2092,7 +2092,7 @@ def run_batch(
             )
 
             bs = len(model_worker_batch.seq_lens)
-            cur_future_map_ct = self.future_map.update_ct(bs)
+            future_indices = self.future_map.alloc_future_indices(bs)
 
             with self.forward_stream_ctx:
                 self.forward_stream.wait_stream(self.default_stream)
@@ -2108,22 +2108,19 @@ def run_batch(
                 ).Event()
                 if not model_worker_batch.delay_sample_launch:
                     self.future_map.store_to_map(
-                        cur_future_map_ct, bs, batch_result.next_token_ids
+                        future_indices, batch_result.next_token_ids
                     )
                     batch_result.copy_to_cpu()
                 else:
-                    batch_result.future_map_ct = cur_future_map_ct
+                    batch_result.future_indices = future_indices
 
                 # FIXME(lsyin): move this assignment elsewhere
-                maybe_future_next_token_ids = self.future_map.update_next_future(
-                    cur_future_map_ct, bs
-                )
+                maybe_future_next_token_ids = -future_indices.indices
         else:
             batch_result = self.model_worker.forward_batch_generation(
                 batch_or_worker_batch
             )
             maybe_future_next_token_ids = batch_result.next_token_ids
-            copy_done = None
 
         if not self.spec_algorithm.is_none():
             # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
@@ -2182,8 +2179,8 @@ def launch_last_batch_sample_if_needed(
                 tmp_result.logits_output,
                 tmp_result.forward_batch,
             )
-            ct, bs = tmp_result.future_map_ct, len(tmp_batch.reqs)
-            self.future_map.store_to_map(ct, bs, tmp_result.next_token_ids)
+            future_indices = tmp_result.future_indices
+            self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
             tmp_result.copy_to_cpu()
             self.result_queue.appendleft((tmp_batch, tmp_result))
 
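Continuing the sketch above, this is roughly how the scheduler threads a `FutureIndices` through the delayed-sample path: `run_batch` allocates the slots and either fills them immediately or stashes the handle on the result, and `launch_last_batch_sample_if_needed` fills them once sampling finally runs. The worker and sampler stubs and the trimmed-down `GenerationBatchResult` are placeholders for illustration, not the real sglang classes.

```python
# Hedged control-flow sketch; reuses FutureMap / FutureIndices from the sketch above.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class GenerationBatchResult:
    next_token_ids: Optional[torch.Tensor] = None
    delay_sample_launch: bool = False
    future_indices: Optional[FutureIndices] = None  # replaces the old future_map_ct int


def run_batch(future_map, worker, batch, bs: int):
    future_indices = future_map.alloc_future_indices(bs)
    result = worker(batch)                       # forward pass (stub)
    if not result.delay_sample_launch:
        # Tokens were sampled in this step: publish them for the next batch right away.
        future_map.store_to_map(future_indices, result.next_token_ids)
    else:
        # Sampling is deferred: remember which slots to fill later.
        result.future_indices = future_indices
    placeholder_ids = -future_indices.indices    # negative placeholders for the next batch
    return result, placeholder_ids


def launch_last_batch_sample_if_needed(future_map, sampler, result):
    if result.delay_sample_launch and result.next_token_ids is None:
        result.next_token_ids = sampler(result)  # late sampling (stub)
        future_map.store_to_map(result.future_indices, result.next_token_ids)


fm = FutureMap()
result, placeholders = run_batch(
    fm,
    worker=lambda b: GenerationBatchResult(delay_sample_launch=True),
    batch=None,
    bs=2,
)
launch_last_batch_sample_if_needed(
    fm, sampler=lambda r: torch.tensor([7, 8]), result=result
)
print(fm.resolve_future(placeholders))  # tensor([7, 8])
```

Passing the handle on the result object is what lets the deferred store in `launch_last_batch_sample_if_needed` drop the old `len(tmp_batch.reqs)` recomputation: the slot range is fixed at allocation time.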