diff --git a/python/sglang/srt/managers/overlap_utils.py b/python/sglang/srt/managers/overlap_utils.py
index 1a011717ac1b..f91212b75160 100644
--- a/python/sglang/srt/managers/overlap_utils.py
+++ b/python/sglang/srt/managers/overlap_utils.py
@@ -1,3 +1,6 @@
+from dataclasses import dataclass
+from typing import Optional
+
 import torch
 
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
@@ -13,6 +16,12 @@ def _resolve_future_token_ids(input_ids, future_token_ids_map):
     )
 
 
+@dataclass
+class FutureIndices:
+    indices: torch.Tensor
+    interval: Optional[slice] = None
+
+
 class FutureMap:
     def __init__(
         self,
@@ -30,23 +39,17 @@ def __init__(
             (self.future_buffer_len,), dtype=torch.int64, device=self.device
         )
 
-    def update_ct(self, bs: int) -> int:
-        """Update the circular buffer pointer and return the current pointer."""
+    def alloc_future_indices(self, bs: int) -> FutureIndices:
+        """Update the circular buffer pointer and allocate future indices."""
         cur_future_ct = self.future_ct
         self.future_ct = (cur_future_ct + bs) % self.future_limit
-        return cur_future_ct
+        start = cur_future_ct + 1
+        end = cur_future_ct + 1 + bs
+        indices = torch.arange(start, end, dtype=torch.int64, device=self.device)
+        return FutureIndices(indices=indices, interval=slice(start, end))
 
     def resolve_future(self, model_worker_batch: ModelWorkerBatch):
         _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
 
-    def update_next_future(self, future_ct: int, bs: int):
-        return torch.arange(
-            -(future_ct + 1),
-            -(future_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
-        )
-
-    def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
-        self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
+    def store_to_map(self, future_indices: FutureIndices, next_token_ids: torch.Tensor):
+        self.token_ids_buf[future_indices.interval] = next_token_ids
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 68203d51e988..eedb28e790af 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -114,7 +114,7 @@
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
-from sglang.srt.managers.overlap_utils import FutureMap
+from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     ModelWorkerBatch,
@@ -217,7 +217,7 @@ class GenerationBatchResult:
     copy_done: Optional[torch.cuda.Event] = None
     delay_sample_launch: bool = False
     forward_batch: Optional[ForwardBatch] = None
-    future_map_ct: Optional[int] = None
+    future_indices: Optional[FutureIndices] = None
 
     def copy_to_cpu(self, return_logprob: bool = False):
         """Copy tensors to CPU in overlap scheduling.
@@ -2092,7 +2092,7 @@ def run_batch(
             )
 
             bs = len(model_worker_batch.seq_lens)
-            cur_future_map_ct = self.future_map.update_ct(bs)
+            future_indices = self.future_map.alloc_future_indices(bs)
 
             with self.forward_stream_ctx:
                 self.forward_stream.wait_stream(self.default_stream)
@@ -2108,22 +2108,19 @@
                 ).Event()
                 if not model_worker_batch.delay_sample_launch:
                     self.future_map.store_to_map(
-                        cur_future_map_ct, bs, batch_result.next_token_ids
+                        future_indices, batch_result.next_token_ids
                     )
                     batch_result.copy_to_cpu()
                 else:
-                    batch_result.future_map_ct = cur_future_map_ct
+                    batch_result.future_indices = future_indices
 
                 # FIXME(lsyin): move this assignment elsewhere
-                maybe_future_next_token_ids = self.future_map.update_next_future(
-                    cur_future_map_ct, bs
-                )
+                maybe_future_next_token_ids = -future_indices.indices
         else:
             batch_result = self.model_worker.forward_batch_generation(
                 batch_or_worker_batch
            )
             maybe_future_next_token_ids = batch_result.next_token_ids
-            copy_done = None
 
         if not self.spec_algorithm.is_none():
             # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
@@ -2182,8 +2179,8 @@ def launch_last_batch_sample_if_needed(
                 tmp_result.logits_output,
                 tmp_result.forward_batch,
             )
-            ct, bs = tmp_result.future_map_ct, len(tmp_batch.reqs)
-            self.future_map.store_to_map(ct, bs, tmp_result.next_token_ids)
+            future_indices = tmp_result.future_indices
+            self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
             tmp_result.copy_to_cpu()
             self.result_queue.appendleft((tmp_batch, tmp_result))
 