diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 4738b032f8dd..d994bb867740 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -836,6 +836,7 @@ def event_loop_normal_disagg_decode(self: Scheduler): while True: # Receive requests + self.iter_start_time = time.perf_counter() recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) # polling and allocating kv cache @@ -863,6 +864,7 @@ def event_loop_overlap_disagg_decode(self: Scheduler): while True: # Receive requests + self.iter_start_time = time.perf_counter() recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) # polling and allocating kv cache @@ -980,6 +982,10 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: for req in can_run_list: req.time_stats.forward_entry_time = time.perf_counter() + if self.enable_metrics: + self.metrics_collector.observe_request_waiting_time( + req.time_stats.get_request_waiting_time(), + ) # construct a schedule batch with those requests and mark as decode new_batch = ScheduleBatch.init_new( diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 39a824c3a81b..27a90d6d13bd 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -439,6 +439,15 @@ def process_batch_result_disagg_prefill( logits_output.input_token_logprobs.tolist() ) + if self.enable_metrics: + self.iter_forward_finish_time = time.time() + run_batch_time = ( + self.iter_forward_finish_time - self.iter_forward_start_time + ) + self.stats.run_batch_time = run_batch_time + self.metrics_collector.log_stats(self.stats) + + hidden_state_offset = 0 for i, (req, next_token_id) in enumerate( zip(batch.reqs, next_token_ids, strict=True) ): @@ -519,6 +528,9 @@ def process_batch_result_disagg_prefill( RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True ) + # Log DP-level prefill load-balancing metrics + if self.current_scheduler_metrics_enabled: + self.log_prefill_dp_balance_stats(batch) self.maybe_send_health_check_signal() def process_disagg_prefill_inflight_queue( @@ -577,6 +589,10 @@ def process_disagg_prefill_inflight_queue( for req in done_reqs: req.time_stats.completion_time = time.perf_counter() + if self.enable_metrics: + self.metrics_collector.observe_request_first_token_forward_time( + req.time_stats.get_request_first_token_forward_time() + ) # Stream requests which have finished transfer self.stream_output( diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 0d9e4e43352b..3cb8c4874472 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -1206,8 +1206,10 @@ async def continue_generation(obj: ContinueGenerationReqInput, request: Request) @app.post("/v1/completions", dependencies=[Depends(validate_json_request)]) async def openai_v1_completions(request: CompletionRequest, raw_request: Request): """OpenAI-compatible text completion endpoint.""" + # Timestamp when the HTTP request is received and handed off to the tokenizer + tokenizer_rev_request_time = time.time() return await raw_request.app.state.openai_serving_completion.handle_request( - request, raw_request + request, raw_request, tokenizer_rev_request_time ) @@ -1216,8 +1218,10 @@ async def openai_v1_chat_completions( request: ChatCompletionRequest, raw_request: Request ): 
"""OpenAI-compatible chat completion endpoint.""" + # Timestamp when the HTTP request is received and handed off to the tokenizer + tokenizer_rev_request_time = time.time() return await raw_request.app.state.openai_serving_chat.handle_request( - request, raw_request + request, raw_request, tokenizer_rev_request_time ) diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index 6e01d2fd053f..46ab39b8cb6c 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -84,7 +84,10 @@ def _validate_lora_enabled(self, adapter_name: str) -> None: ) async def handle_request( - self, request: OpenAIServingRequest, raw_request: Request + self, + request: OpenAIServingRequest, + raw_request: Request, + tokenizer_rev_request_time: Optional[float] = None, ) -> Union[Any, StreamingResponse, ErrorResponse]: """Handle the specific request type with common pattern If you want to override this method, you should be careful to record the validation time. @@ -114,11 +117,17 @@ async def handle_request( # Note(Xinyuan): raw_request below is only used for detecting the connection of the client if hasattr(request, "stream") and request.stream: return await self._handle_streaming_request( - adapted_request, processed_request, raw_request + adapted_request, + processed_request, + raw_request, + tokenizer_rev_request_time, ) else: return await self._handle_non_streaming_request( - adapted_request, processed_request, raw_request + adapted_request, + processed_request, + raw_request, + tokenizer_rev_request_time, ) except HTTPException as e: return self.create_error_response( diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 73dbc6d942ab..e57ed873aba3 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -517,10 +517,13 @@ async def _handle_streaming_request( adapted_request: GenerateReqInput, request: ChatCompletionRequest, raw_request: Request, + tokenizer_rev_request_time: Optional[float] = None, ) -> StreamingResponse: """Handle streaming chat completion request""" return StreamingResponse( - self._generate_chat_stream(adapted_request, request, raw_request), + self._generate_chat_stream( + adapted_request, request, raw_request, tokenizer_rev_request_time + ), media_type="text/event-stream", background=self.tokenizer_manager.create_abort_task(adapted_request), ) @@ -530,6 +533,7 @@ async def _generate_chat_stream( adapted_request: GenerateReqInput, request: ChatCompletionRequest, raw_request: Request, + tokenizer_rev_request_time: Optional[float] = None, ) -> AsyncGenerator[str, None]: """Generate streaming chat completion response""" # Parsers for tool calls and reasoning @@ -551,7 +555,7 @@ async def _generate_chat_stream( try: async for content in self.tokenizer_manager.generate_request( - adapted_request, raw_request + adapted_request, raw_request, tokenizer_rev_request_time ): index = content.get("index", 0) @@ -769,11 +773,12 @@ async def _handle_non_streaming_request( adapted_request: GenerateReqInput, request: ChatCompletionRequest, raw_request: Request, + tokenizer_rev_request_time: Optional[float] = None, ) -> Union[ChatCompletionResponse, ErrorResponse, ORJSONResponse]: """Handle non-streaming chat completion request""" try: ret = await self.tokenizer_manager.generate_request( - adapted_request, raw_request + 
adapted_request, raw_request, tokenizer_rev_request_time ).__anext__() except ValueError as e: return self.create_error_response(str(e)) diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 8229de122ba6..038155dd7fc1 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -177,10 +177,13 @@ async def _handle_streaming_request( adapted_request: GenerateReqInput, request: CompletionRequest, raw_request: Request, + tokenizer_rev_request_time: Optional[float] = None, ) -> StreamingResponse: """Handle streaming completion request""" return StreamingResponse( - self._generate_completion_stream(adapted_request, request, raw_request), + self._generate_completion_stream( + adapted_request, request, raw_request, tokenizer_rev_request_time + ), media_type="text/event-stream", background=self.tokenizer_manager.create_abort_task(adapted_request), ) @@ -190,6 +193,7 @@ async def _generate_completion_stream( adapted_request: GenerateReqInput, request: CompletionRequest, raw_request: Request, + tokenizer_rev_request_time: Optional[float] = None, ) -> AsyncGenerator[str, None]: """Generate streaming completion response""" created = int(time.time()) @@ -206,7 +210,7 @@ async def _generate_completion_stream( try: async for content in self.tokenizer_manager.generate_request( - adapted_request, raw_request + adapted_request, raw_request, tokenizer_rev_request_time ): index = content.get("index", 0) @@ -341,11 +345,12 @@ async def _handle_non_streaming_request( adapted_request: GenerateReqInput, request: CompletionRequest, raw_request: Request, + tokenizer_rev_request_time: Optional[float] = None, ) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]: """Handle non-streaming completion request""" try: generator = self.tokenizer_manager.generate_request( - adapted_request, raw_request + adapted_request, raw_request, tokenizer_rev_request_time ) ret = await generator.__anext__() except ValueError as e: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 18cc3d2aa636..8a93f477f9d3 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -716,6 +716,9 @@ class TokenizedGenerateReqInput(BaseReq): # Session info for continual prompting session_params: Optional[SessionParams] = None + # Timestamp when tokenizer dispatches the request to the scheduler + dispatch_to_scheduler_time: Optional[float] = None + # LoRA related lora_id: Optional[str] = None # None means just use the base model @@ -924,6 +927,8 @@ class TokenizedEmbeddingReqInput(BaseReq): priority: Optional[int] = None # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings. 
dimensions: Optional[int] = None + # Timestamp when tokenizer dispatches the request to the scheduler + dispatch_to_scheduler_time: Optional[float] = None @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 22e173192d68..717a18ff8ec7 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -512,6 +512,7 @@ def __init__( return_hidden_states: bool = False, return_routed_experts: bool = False, eos_token_ids: Optional[Set[int]] = None, + dispatch_to_scheduler_time: Optional[float] = None, bootstrap_host: Optional[str] = None, bootstrap_port: Optional[int] = None, bootstrap_room: Optional[int] = None, @@ -745,6 +746,12 @@ def __init__( self.has_log_time_stats: bool = False self.last_tic = time.monotonic() + # Timestamp when tokenizer dispatches the request to the scheduler + self.dispatch_to_scheduler_time = dispatch_to_scheduler_time + # TODO (suhang): Move the dispatch_to_scheduler_time synchronization into Req’s own initializer: + # Once dispatch_to_scheduler_time is passed into Req + # TimeStats can synchronize it automatically, so the scheduler no longer needs that extra getattr check. + # For disaggregation self.bootstrap_host: str = bootstrap_host self.bootstrap_port: Optional[int] = bootstrap_port @@ -1214,11 +1221,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): inner_idle_batch: Optional[ScheduleBatch] = None global_num_tokens: Optional[List[int]] = None global_num_tokens_for_logprob: Optional[List[int]] = None + dp_global_num_tokens_for_metric: Optional[List[int]] = None is_extend_in_batch: bool = False can_run_dp_cuda_graph: bool = False tbo_split_seq_index: Optional[int] = None global_forward_mode: Optional[ForwardMode] = None + # DP all_gather latency for this batch + all_gather_latency: float = 0.0 + # For processing logprobs return_logprob: bool = False top_logprobs_nums: Optional[List[int]] = None @@ -2195,6 +2206,8 @@ def copy(self): spec_algorithm=self.spec_algorithm, global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, + dp_global_num_tokens_for_metric=self.dp_global_num_tokens_for_metric, + all_gather_latency=self.all_gather_latency, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, is_extend_in_batch=self.is_extend_in_batch, is_prefill_only=self.is_prefill_only, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1bf294973df1..48026e02656b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -319,12 +319,6 @@ def __init__( # Init model configs self.init_model_config() - # Init metrics stats - self.init_metrics(tp_rank, pp_rank, dp_rank) - - # Init inter-process communication - self.init_ipc_channels(port_args) - # Init PD-multiplexing context if self.enable_pdmux: self.init_pdmux() @@ -338,6 +332,12 @@ def __init__( # Launch a model worker and draft model worker if using speculative decoding self.init_model_worker() + # Init metrics stats + self.init_metrics(tp_rank, pp_rank, dp_rank) + + # Init inter-process communication + self.init_ipc_channels(port_args) + if (t := envs.SGLANG_TEST_STUCK_SCHEDULER_INIT.get()) > 0: time.sleep(t) @@ -1056,6 +1056,7 @@ def event_loop_normal(self): """A normal scheduler loop.""" while True: # Receive requests + self.iter_start_time = time.perf_counter() recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) if self._engine_paused: 
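(For orientation: a condensed, hypothetical sketch of the request-level timestamp chain this patch threads from the HTTP entrypoint down to the scheduler, and of the durations the new histograms observe. The consolidated class below is not part of the patch; the real fields live on ReqState and TimeStats, and the patch records the first two timestamps with time.time() and the remaining ones with time.perf_counter().)

from dataclasses import dataclass
from typing import Optional


@dataclass
class RequestTimeline:
    # Hypothetical consolidated view of the timestamps added or used by this patch.
    http_received: Optional[float] = None    # tokenizer_rev_request_time (HTTP entrypoint)
    tokenizer_ready: Optional[float] = None  # tokenization finished in TokenizerManager
    dispatched: Optional[float] = None       # dispatch_to_scheduler_time (ZMQ send)
    arrived: Optional[float] = None          # arrive_scheduler_time (ZMQ receive)
    queue_entry: Optional[float] = None      # wait_queue_entry_time (bootstrap/transfer queue in PD modes)
    forward_entry: Optional[float] = None    # forward_entry_time (request joins its first forward batch)
    completed: Optional[float] = None        # completion_time (first-token forward finished)

    def time_to_tokenizer_ready(self) -> float:
        # Observed by sglang:request_time_to_tokenizer_ready_seconds
        return self.tokenizer_ready - self.http_received

    def zmq_time(self) -> float:
        # Observed by sglang:request_zmq_time_seconds
        return self.arrived - self.dispatched

    def waiting_time(self) -> float:
        # Observed by sglang:request_waiting_time
        return self.forward_entry - self.queue_entry

    def first_token_forward_time(self) -> float:
        # Observed by sglang:request_first_token_forward_time (measured from queue entry)
        return self.completed - self.queue_entry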
@@ -1092,6 +1093,7 @@ def pop_and_process(): while True: # Receive requests + self.iter_start_time = time.perf_counter() recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) if self._engine_paused: @@ -1429,6 +1431,7 @@ def handle_generate_request( return_hidden_states=recv_req.return_hidden_states, return_routed_experts=recv_req.return_routed_experts, eos_token_ids=self.model_config.hf_eos_token_id, + dispatch_to_scheduler_time=recv_req.dispatch_to_scheduler_time, bootstrap_host=recv_req.bootstrap_host, bootstrap_port=recv_req.bootstrap_port, bootstrap_room=recv_req.bootstrap_room, @@ -1444,6 +1447,12 @@ def handle_generate_request( dllm_config=self.dllm_config, ) req.tokenizer = self.tokenizer + if getattr(recv_req, "dispatch_to_scheduler_time", 0.0): + # Keep dispatch timestamp only when present, clamp to zero to avoid negative values + req.time_stats.dispatch_to_scheduler_time = max( + 0.0, recv_req.dispatch_to_scheduler_time + ) + req.time_stats.arrive_scheduler_time = time.perf_counter() if self.disaggregation_mode != DisaggregationMode.NULL: # Invalid request for disaggregated mode @@ -1674,6 +1683,7 @@ def handle_embedding_request( recv_req.input_text, recv_req.input_ids, recv_req.sampling_params, + dispatch_to_scheduler_time=recv_req.dispatch_to_scheduler_time, token_type_ids=recv_req.token_type_ids, priority=recv_req.priority, dimensions=recv_req.dimensions, @@ -2001,9 +2011,15 @@ def _get_new_batch_prefill_raw( if req.time_stats.forward_entry_time == 0: req.time_stats.forward_entry_time = time.perf_counter() if self.enable_metrics: + self.metrics_collector.observe_request_zmq_time( + req.time_stats.get_request_zmq_time(), + ) self.metrics_collector.observe_queue_time( req.time_stats.get_queueing_time(), ) + self.metrics_collector.observe_request_waiting_time( + req.time_stats.get_request_waiting_time(), + ) # Create a new batch new_batch = ScheduleBatch.init_new( @@ -2141,6 +2157,7 @@ def run_batch( ) -> Union[GenerationBatchResult, EmbeddingBatchResult]: """Run a batch.""" self.forward_ct += 1 + self.iter_forward_start_time = time.time() # Whether to run the profiler self._profile_batch_predicate(batch) diff --git a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py index 9c92cd9c383b..60a0d72b4027 100644 --- a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py +++ b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Optional @@ -102,7 +103,7 @@ def _update_gather_batch( def prepare_mlp_sync_batch_raw( - local_batch: ScheduleBatch, + local_batch: Optional[ScheduleBatch], dp_size: int, attn_tp_size: int, tp_group: GroupCoordinator, @@ -153,6 +154,9 @@ def prepare_mlp_sync_batch_raw( group = tp_group.cpu_group device = "cpu" + # Start timing DP all_gather + start_time = time.perf_counter() + local_can_run_tbo, local_forward_mode = tbo_preparer.prepare_all_gather(local_batch) mlp_sync_info = MLPSyncBatchInfo( @@ -169,11 +173,18 @@ def prepare_mlp_sync_batch_raw( if not skip_all_gather: mlp_sync_info.all_gather(device=device, group=group) - mlp_sync_info.tbo_split_seq_index, mlp_sync_info.global_forward_mode = ( - tbo_preparer.compute_output( - mlp_sync_info.tp0_info[:, 4:6], - ) + # DP all_gather latency (seconds) + all_gather_latency = time.perf_counter() - start_time + + if local_batch is not None: + local_batch.dp_global_num_tokens_for_metric = 
mlp_sync_info.global_num_tokens
+        local_batch.all_gather_latency = all_gather_latency
+
+    mlp_sync_info.tbo_split_seq_index, mlp_sync_info.global_forward_mode = (
+        tbo_preparer.compute_output(
+            mlp_sync_info.tp0_info[:, 4:6],
         )
+    )

     need_idle_batch = skip_all_gather or max(mlp_sync_info.global_num_tokens) > 0
     if need_idle_batch:
diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py
index c943d1886cce..4bb6350eb08e 100644
--- a/python/sglang/srt/managers/scheduler_metrics_mixin.py
+++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py
@@ -113,6 +113,20 @@ def init_metrics(
         if self.enable_kv_cache_events:
             self.init_kv_events(self.server_args.kv_events_config)

+        # Record per-iteration timing
+        self.iter_start_time = 0
+        self.iter_forward_start_time = 0
+        self.iter_forward_finish_time = 0
+        self.iter_finish_time = 0
+
+        if self.enable_metrics:
+            self.stats.num_max_batchs = (
+                self.server_args.max_running_requests
+                if self.server_args.max_running_requests is not None
+                else 0
+            )
+            self.stats.max_total_num_tokens = self.max_total_num_tokens
+
         self.scheduler_status_logger = SchedulerStatusLogger.maybe_create()

     def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
@@ -145,6 +159,11 @@ def log_prefill_stats(
         self.last_prefill_stats_tic = time.perf_counter()
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens
+        self.last_prefill_cache_tokens = adder.log_hit_tokens
+        # In PREFILL disaggregation, `self.running_batch` is decode-only;
+        # use the current prefill batch size to compute running_bs.
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            running_bs = len(can_run_list)

         assert self.temp_prefill_info is None
         self.temp_prefill_info = dict(
@@ -278,6 +297,62 @@ def log_prefill_stats_late(self: Scheduler, batch: Optional[ScheduleBatch]):
                 dp_cooperation_info=batch.dp_cooperation_info,
             )

+    def log_prefill_dp_balance_stats(self: Scheduler, batch: ScheduleBatch) -> None:
+        """Log DP-level load-balancing metrics for the prefill stage."""
+        tokens_list = None
+        total_dp = None
+        total_tokens = None
+        dp_balance = 0.0
+        idle_batch_ratio = 1.0
+        prefill_chunk_util = 0.0
+
+        if (
+            batch is not None
+            and self.dp_rank == 0
+            and self.chunked_prefill_size is not None
+        ):
+            # Prepare per-DP worker token counts
+            tokens_list = batch.dp_global_num_tokens_for_metric
+
+        if tokens_list:
+            total_dp = len(tokens_list)
+            total_tokens = sum(tokens_list)
+            token_sorted = sorted(tokens_list)
+
+            # Compute idle ratio and utilization metrics in a unified way
+            idle_batch_ratio = tokens_list.count(0) / total_dp
+            prefill_chunk_util = total_tokens / total_dp / self.chunked_prefill_size
+
+            if total_dp > 1 and total_tokens > 0:
+                # Compute Gini coefficient and DP balance in the general case
+                acc = 0
+                for i, val in enumerate(token_sorted, start=1):
+                    acc += (2 * i - total_dp - 1) * val
+                gini = acc / (total_dp * total_tokens)
+                # Derive DP balance score from Gini coefficient
+                dp_balance = 1.0 - gini
+            else:
+                # When there is only one DP or no tokens, use a fixed DP balance value
+                # while keeping idle_batch_ratio and prefill_chunk_util as computed above
+                dp_balance = 1.0 if total_dp == 1 else 0.0
+
+        logger.info(
+            f"Prefill tokens_list: {tokens_list}, "
+            f"#total_dp: {total_dp}, "
+            f"total_tokens: {total_tokens}, "
+            f"#dp_balance: {dp_balance:.2f}, "
+            f"#idle_batch_ratio: {idle_batch_ratio:.2f}, "
+            f"#prefill_chunk_util: {prefill_chunk_util:.2f}, "
+        )
+
if self.enable_metrics: + # DP balance + self.stats.dp_balance = dp_balance + self.stats.idle_batch_ratio = idle_batch_ratio + self.stats.prefill_chunk_util = prefill_chunk_util + # Others + self.metrics_collector.log_stats(self.stats) + def log_decode_stats( self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None ): @@ -291,6 +366,97 @@ def log_decode_stats( num_running_reqs = len(batch.reqs) num_running_reqs_offline_batch = 0 + if RECORD_STEP_TIME: + self.step_time_dict[num_running_reqs].append( + gap_latency / self.server_args.decode_log_interval + ) + + iter_msg = f" [{self.forward_ct}]" if LOG_FORWARD_ITERS else "" + + _ = self.log_decode_run_batch_stats(batch) + + dp_balance_msg = self.log_decode_dp_balance_stats(batch) + token_usage_msg = self.log_decode_token_usage(batch) + spec_msg = self.log_decode_spec_stats(batch) + disagg_queue_msg = self.log_decode_disagg_queue_stats(batch) + + msg = f"Decode batch{iter_msg}, #running-req: {num_running_reqs}, {token_usage_msg}" + msg += spec_msg + msg += disagg_queue_msg + msg += ( + f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, " + f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}, " + ) + msg += dp_balance_msg + + logger.info(msg) + if self.enable_metrics: + # Basics + self.stats.num_running_reqs = num_running_reqs + self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch + + self.stats.decode_sum_seq_lens = batch.seq_lens_cpu.sum().item() + self.stats.gen_throughput = self.last_gen_throughput + self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.num_grammar_queue_reqs = len(self.grammar_manager) + self.stats.cache_hit_rate = 0 + + self.stats.max_total_num_tokens = self.max_total_num_tokens + + # Retract + self.stats.num_retracted_reqs = self.num_retracted_reqs + self.stats.num_paused_reqs = self.num_paused_reqs + self.num_retracted_reqs = self.num_paused_reqs = 0 + + # Others + self.calculate_utilization() + self.update_lora_metrics() + self.metrics_collector.log_stats(self.stats) + self._emit_kv_metrics() + self._publish_kv_events() + + def log_decode_spec_stats(self: Scheduler, _: ScheduleBatch) -> str: + """Log speculative decoding metrics.""" + if self.spec_algorithm.is_none(): + spec_accept_length = 0 + spec_accept_rate = 0 + msg = "" + else: + spec_accept_length = ( + self.spec_num_accepted_tokens / self.spec_num_forward_ct + if self.spec_num_forward_ct > 0 + else 0 + ) + # Calculate acceptance rate: accepted tokens / total draft tokens + draft_tokens_fallback = (self.server_args.speculative_num_steps or 0) + 1 + num_draft_tokens = ( + self.server_args.speculative_num_draft_tokens or draft_tokens_fallback + ) + total_draft_tokens = self.spec_num_forward_ct * num_draft_tokens + + spec_accept_rate = ( + self.spec_num_accepted_tokens / total_draft_tokens + if total_draft_tokens > 0 + else 0 + ) + self.spec_total_num_accepted_tokens += self.spec_num_accepted_tokens + self.spec_total_num_forward_ct += self.spec_num_forward_ct + self.spec_num_accepted_tokens = self.spec_num_forward_ct = 0 + msg = f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, " + + if self.enable_metrics: + # Speculative decoding + self.stats.spec_accept_rate = spec_accept_rate + self.stats.spec_accept_length = spec_accept_length + + return msg + + def log_decode_token_usage(self: Scheduler, _: ScheduleBatch) -> str: + """Log token usage statistics during decode stage.""" + mamba_usage = 0 + 
swa_token_usage = 0 + # TODO: generalize this for various memory pools if self.is_hybrid_swa: ( @@ -334,79 +500,26 @@ def log_decode_stats( num_used, token_usage, _, _ = self._get_token_info() token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, " - if RECORD_STEP_TIME: - self.step_time_dict[num_running_reqs].append( - gap_latency / self.server_args.decode_log_interval - ) - - iter_msg = f" [{self.forward_ct}]" if LOG_FORWARD_ITERS else "" - msg = f"Decode batch{iter_msg}, #running-req: {num_running_reqs}, {token_usage_msg}" - - if self.spec_algorithm.is_none(): - spec_accept_length = 0 - spec_accept_rate = 0 - else: - spec_accept_length = ( - self.spec_num_accepted_tokens / self.spec_num_forward_ct - ) - # Calculate acceptance rate: accepted tokens / total draft tokens - draft_tokens_fallback = (self.server_args.speculative_num_steps or 0) + 1 - num_draft_tokens = ( - self.server_args.speculative_num_draft_tokens or draft_tokens_fallback - ) - total_draft_tokens = self.spec_num_forward_ct * num_draft_tokens - - spec_accept_rate = ( - self.spec_num_accepted_tokens / total_draft_tokens - if total_draft_tokens > 0 - else 0 - ) - self.spec_total_num_accepted_tokens += self.spec_num_accepted_tokens - self.spec_total_num_forward_ct += self.spec_num_forward_ct - self.spec_num_accepted_tokens = self.spec_num_forward_ct = 0 - msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, " - cache_hit_rate = 0.0 - - if self.disaggregation_mode == DisaggregationMode.DECODE: - msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, " - msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, " - msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, " - msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, " - - msg += ( - f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, " - f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}, " - ) - - logger.info(msg) if self.enable_metrics: - # Basics - self.stats.num_running_reqs = num_running_reqs - self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch self.stats.num_used_tokens = num_used self.stats.token_usage = token_usage if self.is_hybrid_swa: self.stats.swa_token_usage = swa_token_usage if self.is_hybrid_ssm: self.stats.mamba_usage = mamba_usage - self.stats.decode_sum_seq_lens = batch.seq_lens_cpu.sum().item() - self.stats.gen_throughput = self.last_gen_throughput - self.stats.num_queue_reqs = len(self.waiting_queue) - self.stats.num_grammar_queue_reqs = len(self.grammar_manager) - self.stats.cache_hit_rate = cache_hit_rate - self.stats.max_total_num_tokens = self.max_total_num_tokens + return token_usage_msg - # Speculative decoding - self.stats.spec_accept_rate = spec_accept_rate - self.stats.spec_accept_length = spec_accept_length - - # Retract - self.stats.num_retracted_reqs = self.num_retracted_reqs - self.stats.num_paused_reqs = self.num_paused_reqs - self.num_retracted_reqs = self.num_paused_reqs = 0 + def log_decode_disagg_queue_stats(self: Scheduler, batch: ScheduleBatch) -> str: + """Log disaggregation queue statistics during decode stage.""" + msg = "" + if self.disaggregation_mode == DisaggregationMode.DECODE: + msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, " + msg += f"#prealloc-req: 
{len(self.disagg_decode_prealloc_queue.queue)}, " + msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, " + msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, " + if self.enable_metrics: # PD disaggregation if self.disaggregation_mode == DisaggregationMode.PREFILL: self.stats.num_prefill_prealloc_queue_reqs = len( @@ -432,13 +545,84 @@ def log_decode_stats( _, self.stats.routing_key_all_req_counts = compute_routing_key_stats( running_routing_keys + waiting_routing_keys ) + return msg - # Others - self.calculate_utilization() - self.update_lora_metrics() - self.metrics_collector.log_stats(self.stats) - self._emit_kv_metrics() - self._publish_kv_events() + def log_decode_run_batch_stats(self: Scheduler, _: ScheduleBatch) -> str: + """Log runtime stats of running a single iteration""" + self.iter_finish_time = time.perf_counter() + generation_time = self.iter_finish_time - self.iter_start_time + run_batch_time = self.iter_forward_finish_time - self.iter_forward_start_time + iter_token_process_time = generation_time - run_batch_time + + if self.enable_metrics: + # Run batch + self.stats.generation_time = generation_time + self.stats.run_batch_time = run_batch_time + self.stats.iter_token_process_time = iter_token_process_time + + return "" + + def log_decode_dp_balance_stats(self: Scheduler, batch: ScheduleBatch) -> str: + """Log DP-level load-balancing metrics for the decode stage.""" + num_tokens_list: List[int] = [] + + num_total_tokens: int = 0 + # DP all_gather latency + all_gather_latency_us = 0.0 + + if batch is None or self.dp_rank != 0: + return "" + + assert batch.dp_global_num_tokens_for_metric is not None + num_tokens_list = batch.dp_global_num_tokens_for_metric + all_gather_latency_us = batch.all_gather_latency + + if num_tokens_list: + total_dp = len(num_tokens_list) + num_total_tokens = sum(num_tokens_list) + + # Compute idle ratio and utilization metrics in a unified way. + idle_batch_ratio = num_tokens_list.count(0) / total_dp + decode_bs_util: float = ( + num_total_tokens / total_dp / self.max_running_requests + ) + + if total_dp > 1 and num_total_tokens > 0: + # Compute Gini coefficient and DP balance in the general case. + acc = 0 + for i, val in enumerate(sorted(num_tokens_list), start=1): + acc += (2 * i - total_dp - 1) * val + gini = acc / (total_dp * num_total_tokens) + # Derive DP balance score from Gini coefficient + dp_balance = 1.0 - gini + else: + # When there is only one DP or no tokens, use a fixed DP balance value + # while keeping idle_batch_ratio and decode_bs_util as computed above. 
+ dp_balance = 1.0 if total_dp == 1 else 0.0 + else: + total_dp = 0 + dp_balance = 0.0 + idle_batch_ratio = 1.0 + decode_bs_util = 0.0 + + msg = ( + f"Decode tokens_list: {num_tokens_list}, " + f"#total_dp: {total_dp}, " + f"total_tokens: {num_total_tokens}, " + f"#dp_balance: {dp_balance:.2f}, " + f"#idle_batch_ratio: {idle_batch_ratio:.2f}, " + f"#decode_bs_util: {decode_bs_util:.2f}, " + ) + + if self.enable_metrics: + self.stats.dp_balance = dp_balance + self.stats.idle_batch_ratio = idle_batch_ratio + self.stats.decode_bs_util = decode_bs_util + + # DP all_gather latency + self.stats.all_gather_latency_us = all_gather_latency_us + + return msg def log_decode_stats_every_iteration( self: Scheduler, batch: ScheduleBatch, num_accepted_tokens: int diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index ed614fea9d35..e124cf8063d9 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -117,6 +117,14 @@ def process_batch_result_prefill( logits_output.input_token_logprobs.tolist() ) + if self.enable_metrics: + self.iter_forward_finish_time = time.time() + run_batch_time = ( + self.iter_forward_finish_time - self.iter_forward_start_time + ) + self.stats.run_batch_time = run_batch_time + self.metrics_collector.log_stats(self.stats) + hidden_state_offset = 0 # Check finish conditions @@ -139,6 +147,10 @@ def process_batch_result_prefill( self.maybe_collect_routed_experts(req) release_kv_cache(req, self.tree_cache) req.time_stats.completion_time = time.perf_counter() + if self.enable_metrics: + self.metrics_collector.observe_request_first_token_forward_time( + req.time_stats.get_request_first_token_forward_time() + ) elif not batch.decoding_reqs or req not in batch.decoding_reqs: # This updates radix so others can match self.tree_cache.cache_unfinished_req(req) @@ -284,6 +296,9 @@ def process_batch_result_prefill( auto_next_anon=not req.finished(), thread_finish_flag=req.finished(), ) + # Log DP-level prefill load-balancing metrics + if self.current_scheduler_metrics_enabled: + self.log_prefill_dp_balance_stats(batch) self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) @@ -381,6 +396,7 @@ def process_batch_result_decode( if self.enable_metrics: self.metrics_collector.increment_cuda_graph_pass(value=can_run_cuda_graph) + self.iter_forward_finish_time = time.time() self.token_to_kv_pool_allocator.free_group_begin() # NOTE: in any case, we should check finish here diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a433a0597bf1..2a605bf16e6b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -476,6 +476,8 @@ async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, + tokenizer_rev_request_time: Optional[float] = None, + trace_parent: Optional[str] = None, ): created_time = obj.received_time if obj.received_time else time.time() self.auto_create_handle_loop() @@ -501,7 +503,9 @@ async def generate_request( # Tokenize the request and send it to the scheduler if obj.is_single: - tokenized_obj = await self._tokenize_one_request(obj) + tokenized_obj = await self._tokenize_one_request( + obj, tokenizer_rev_request_time + ) state = self._send_one_request(obj, tokenized_obj, created_time) async for response in 
self._wait_one_response(obj, state, request): yield response @@ -652,6 +656,7 @@ async def _tokenize_texts( async def _tokenize_one_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], + tokenizer_rev_request_time: Optional[float] = None, ): """Tokenize one request.""" # Tokenize @@ -684,6 +689,17 @@ async def _tokenize_one_request( input_text, is_cross_encoder_request ) + # Record tokenize request ready time + tokenizer_ready_time = time.time() + if ( + tokenizer_rev_request_time is not None + and tokenizer_rev_request_time > 0 + and self.enable_metrics + ): + self.metrics_collector.observe_request_time_to_tokenizer_ready( + tokenizer_ready_time - tokenizer_rev_request_time + ) + if self.mm_processor and obj.contains_mm_input(): if obj.image_data is not None and not isinstance(obj.image_data, list): obj.image_data = [obj.image_data] @@ -1047,6 +1063,7 @@ def _send_one_request( ): trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid) tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid) + tokenized_obj.dispatch_to_scheduler_time = time.perf_counter() self.send_to_scheduler.send_pyobj(tokenized_obj) state = ReqState([], False, asyncio.Event(), obj, created_time=created_time) state.request_sent_to_scheduler_ts = time.time() diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index 316cc177a627..f1b44244fbff 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -59,6 +59,8 @@ class TimeStats: disagg_mode: DisaggregationMode = DisaggregationMode.NULL lb_entry_time: float = 0.0 + dispatch_to_scheduler_time: float = 0.0 + arrive_scheduler_time: float = 0.0 wait_queue_entry_time: float = 0.0 forward_entry_time: float = 0.0 completion_time: float = 0.0 @@ -81,6 +83,13 @@ class TimeStats: # maintain unit consistency with other timestamp fields tracked by the `ReqState` class. 
     prefill_finished_ts: float = 0.0

+    def get_request_zmq_time(self) -> float:
+        """Time from tokenizer dispatch to scheduler arrival (ZMQ transfer)."""
+        # Avoid pushing uninitialized values into metrics
+        if self.dispatch_to_scheduler_time <= 0.0 or self.arrive_scheduler_time <= 0.0:
+            return 0.0
+        return max(0.0, self.arrive_scheduler_time - self.dispatch_to_scheduler_time)
+
     def get_queueing_time(self) -> float:
         return self.forward_entry_time - self.wait_queue_entry_time

@@ -99,6 +108,22 @@ def get_prefill_finished_ts(self) -> Optional[float]:
             return self.prefill_finished_ts
         return None

+    def get_request_waiting_time(self) -> float:
+        """Time from queue entry to forward entry, per disaggregation mode."""
+        if self.disagg_mode == DisaggregationMode.NULL:
+            return self.forward_entry_time - self.wait_queue_entry_time
+        elif self.disagg_mode == DisaggregationMode.PREFILL:
+            return self.forward_entry_time - self.prefill_bootstrap_queue_entry_time
+        elif self.disagg_mode == DisaggregationMode.DECODE:
+            return self.forward_entry_time - self.decode_transfer_queue_entry_time
+
+    def get_request_first_token_forward_time(self) -> float:
+        """Time from queue entry to first-token completion."""
+        if self.disagg_mode == DisaggregationMode.NULL:
+            return self.completion_time - self.wait_queue_entry_time
+        elif self.disagg_mode == DisaggregationMode.PREFILL:
+            return self.completion_time - self.prefill_bootstrap_queue_entry_time
+
     def convert_to_duration(self) -> str:
         if self.disagg_mode == DisaggregationMode.NULL:
             queue_duration = self.forward_entry_time - self.wait_queue_entry_time
@@ -207,10 +232,25 @@ class SchedulerStats:
     num_queue_reqs: int = 0
     num_grammar_queue_reqs: int = 0
     num_running_reqs_offline_batch: int = 0
+    num_max_batchs: int = 0
     cache_hit_rate: float = 0.0

     max_total_num_tokens: int = 0

+    # Run batch
+    generation_time: float = 0.0
+    run_batch_time: float = 0.0
+    iter_token_process_time: float = 0.0
+
+    # DP balance
+    dp_balance: float = 0.0
+    idle_batch_ratio: float = 0.0
+    decode_bs_util: float = 0.0
+    prefill_chunk_util: float = 0.0
+
+    # DP all_gather latency
+    all_gather_latency_us: float = 0.0
+
     # Speculative decoding
     spec_accept_length: float = 0.0
     spec_accept_rate: float = 0.0
@@ -363,6 +403,12 @@ def __init__(
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
+        self.num_max_batchs = Gauge(
+            name="sglang:num_max_batchs",
+            documentation="The configured maximum number of running requests",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
         self.cache_hit_rate = Gauge(
             name="sglang:cache_hit_rate",
             documentation="The prefix cache hit rate.",
@@ -377,6 +423,56 @@ def __init__(
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )

+        # Run batch
+        self.generation_time = Gauge(
+            name="sglang:generation_time",
+            documentation="Wall-clock time of one scheduler loop iteration in seconds",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+        self.run_batch_time = Gauge(
+            name="sglang:run_batch_time",
+            documentation="Forward (run_batch) time of one scheduler iteration in seconds",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+        self.iter_token_process_time = Gauge(
+            name="sglang:iter_token_process_time",
+            documentation="Per-iteration time spent outside the forward pass (input/output token processing)",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
+        # DP balance
+        self.dp_balance = Gauge(
+            name="sglang:dp_balance",
+            documentation="DP load-balance score: 1 - Gini coefficient of per-DP token counts (1.0 means perfectly balanced)",
+            labelnames=labels.keys(),
+        )
+        self.idle_batch_ratio = Gauge(
+            name="sglang:idle_batch_ratio",
+            documentation="Idle batch ratio across all DPs",
+            labelnames=labels.keys(),
+        )
+        self.decode_bs_util = Gauge(
+            name="sglang:decode_bs_util",
+            documentation="Active decode requests ratio per DP worker.",
+            labelnames=labels.keys(),
+        )
+        self.prefill_chunk_util = Gauge(
+            name="sglang:prefill_chunk_util",
+            documentation="Used prefill tokens ratio per DP worker.",
+            labelnames=labels.keys(),
+        )
+
+        # DP all_gather latency
+        self.all_gather_latency_us = Gauge(
+            name="sglang:all_gather_latency_us",
+            documentation="The DP all_gather latency during MLP sync batch preparation",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
         # Speculative decoding
         self.spec_accept_length = Gauge(
             name="sglang:spec_accept_length",
@@ -514,6 +610,30 @@ def __init__(
             multiprocess_mode="mostrecent",
         )

+        # Additional zmq time histogram
+        self.request_zmq_time = Histogram(
+            name="sglang:request_zmq_time_seconds",
+            documentation="Histogram of tokenizer-to-scheduler ZMQ transfer time in seconds.",
+            labelnames=labels.keys(),
+            buckets=[
+                0.01,
+                0.05,
+                0.1,
+                0.25,
+                0.5,
+                0.75,
+                1,
+                2,
+                3,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+            ],
+        )
         # Additional queueing time histogram
         self.queue_time = Histogram(
             name="sglang:queue_time_seconds",
@@ -558,6 +678,52 @@ def __init__(
                 3000,
             ],
         )
+        self.request_waiting_time = Histogram(
+            name="sglang:request_waiting_time",
+            documentation="Histogram of request waiting time in seconds",
+            labelnames=labels.keys(),
+            buckets=[
+                0.01,
+                0.05,
+                0.1,
+                0.25,
+                0.5,
+                0.75,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
+            ],
+        )
+        self.histogram_request_first_token_forward_time = Histogram(
+            name="sglang:request_first_token_forward_time",
+            documentation="Histogram of request first token forward time in seconds",
+            labelnames=labels.keys(),
+            buckets=[
+                0.01,
+                0.05,
+                0.1,
+                0.25,
+                0.5,
+                0.75,
+                1,
+                2,
+                3,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+            ],
+        )

         # Grammar metrics
         self.grammar_compilation_time = Histogram(
@@ -858,6 +1024,14 @@ def observe_per_stage_req_latency(self, stage: str, latency: float) -> None:
         labels_with_stage = {**self.labels, "stage": stage}
         self.per_stage_req_latency_seconds.labels(**labels_with_stage).observe(latency)

+    def observe_request_first_token_forward_time(self, value: Union[float, int]):
+        """Observe the queue-entry-to-first-token-completion latency."""
+        self._log_histogram(self.histogram_request_first_token_forward_time, value)
+
+    def observe_request_zmq_time(self, latency: float) -> None:
+        """Observe the tokenizer-to-scheduler ZMQ transfer latency."""
+        self._log_histogram(self.request_zmq_time, latency)
+
     def observe_queue_time(self, latency: float) -> None:
         self._log_histogram(self.queue_time, latency)

@@ -945,6 +1119,10 @@ def increment_gpu_execution_seconds(
             **dp_cooperation_info.to_labels(),
         ).inc(t)

+    def observe_request_waiting_time(self, latency: Union[float, int]) -> None:
+        """Observe the queue-entry-to-forward-entry waiting latency."""
+        self._log_histogram(self.request_waiting_time, latency)
+
     def log_stats(self, stats: SchedulerStats) -> None:
         self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
         self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
@@ -961,10 +1139,25 @@ def log_stats(self, stats: SchedulerStats) -> None:
         self._log_gauge(
             self.num_running_reqs_offline_batch, stats.num_running_reqs_offline_batch
         )
+        self._log_gauge(self.num_max_batchs, stats.num_max_batchs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)

         self._log_gauge(self.max_total_num_tokens, stats.max_total_num_tokens)

+        # Run batch
+        self._log_gauge(self.generation_time, stats.generation_time)
+        self._log_gauge(self.run_batch_time, stats.run_batch_time)
+        self._log_gauge(self.iter_token_process_time, stats.iter_token_process_time)
+
+        # DP balance
+        self._log_gauge(self.dp_balance, stats.dp_balance)
+        self._log_gauge(self.idle_batch_ratio, stats.idle_batch_ratio)
+        self._log_gauge(self.prefill_chunk_util, stats.prefill_chunk_util)
+        self._log_gauge(self.decode_bs_util, stats.decode_bs_util)
+
+        # DP all_gather latency
+        self._log_gauge(self.all_gather_latency_us, stats.all_gather_latency_us)
+
         # Speculative decoding
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
         self._log_gauge(self.spec_accept_rate, stats.spec_accept_rate)
@@ -1282,6 +1475,28 @@ def __init__(
             ],
         )

+        self.histogram_request_time_to_tokenizer_ready = Histogram(
+            name="sglang:request_time_to_tokenizer_ready_seconds",
+            documentation="Histogram of time from HTTP receipt to tokenizer-ready input in seconds",
+            labelnames=labels.keys(),
+            buckets=[
+                0.01,
+                0.02,
+                0.05,
+                0.1,
+                0.2,
+                0.5,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                80,
+                120,
+            ],
+        )
+
     def observe_one_finished_request(
         self,
         labels: Dict[str, str],
@@ -1341,6 +1556,12 @@ def observe_inter_token_latency(
     def observe_one_aborted_request(self, labels: Dict[str, str]):
         self.num_aborted_requests_total.labels(**labels).inc(1)

+    def observe_request_time_to_tokenizer_ready(self, latency: float) -> None:
+        """Observe the time from HTTP receipt to tokenizer-ready input."""
+        self.histogram_request_time_to_tokenizer_ready.labels(**self.labels).observe(
+            latency
+        )
+

 @dataclass
 class StorageMetrics:
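Both log_prefill_dp_balance_stats and log_decode_dp_balance_stats above derive their balance score from the Gini coefficient of the per-DP token counts. A minimal standalone sketch of that calculation (hypothetical helper name, plain Python; the utilization denominators, chunked_prefill_size and max_running_requests, are left out):

from typing import List, Tuple


def dp_balance_stats(tokens_per_dp: List[int]) -> Tuple[float, float]:
    """Return (dp_balance, idle_batch_ratio) for one batch of per-DP-worker token counts."""
    n = len(tokens_per_dp)
    if n == 0:
        return 0.0, 1.0
    total = sum(tokens_per_dp)
    idle_batch_ratio = tokens_per_dp.count(0) / n
    if n == 1 or total == 0:
        # Degenerate cases: a single DP rank counts as balanced; an all-idle step does not.
        return (1.0 if n == 1 else 0.0), idle_batch_ratio
    # Gini coefficient over the sorted token counts, then balance = 1 - Gini.
    acc = 0
    for i, val in enumerate(sorted(tokens_per_dp), start=1):
        acc += (2 * i - n - 1) * val
    gini = acc / (n * total)
    return 1.0 - gini, idle_batch_ratio


# Example: [0, 100] -> dp_balance = 0.5, idle_batch_ratio = 0.5;
#          [50, 50] -> dp_balance = 1.0, idle_batch_ratio = 0.0.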
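The three per-iteration gauges (generation_time, run_batch_time, iter_token_process_time) are related by a simple decomposition of one scheduler loop iteration. A condensed sketch, assuming a single monotonic clock throughout (the patch itself uses time.time() for the forward timestamps and time.perf_counter() for the iteration boundaries):

import time


class IterationTimer:
    """Hypothetical condensed version of the iter_* timestamps kept on the Scheduler."""

    def start_iteration(self) -> None:
        # Set at the top of the event loop, before recv_requests().
        self.iter_start_time = time.perf_counter()

    def start_forward(self) -> None:
        # Set at the beginning of run_batch().
        self.iter_forward_start_time = time.perf_counter()

    def finish_forward(self) -> None:
        # Set in process_batch_result_*() once the forward output is back.
        self.iter_forward_finish_time = time.perf_counter()

    def finish_iteration(self) -> dict:
        # Evaluated when decode stats are logged (log_decode_run_batch_stats).
        self.iter_finish_time = time.perf_counter()
        generation_time = self.iter_finish_time - self.iter_start_time
        run_batch_time = self.iter_forward_finish_time - self.iter_forward_start_time
        return {
            "generation_time": generation_time,
            "run_batch_time": run_batch_time,
            "iter_token_process_time": generation_time - run_batch_time,
        }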