diff --git a/sgl-router/README.md b/sgl-router/README.md index baa894e1fc97..6cc10c159117 100644 --- a/sgl-router/README.md +++ b/sgl-router/README.md @@ -93,6 +93,19 @@ python -m sglang_router.launch_router \ --prometheus-port 9000 ``` +### Request ID Tracking + +Track requests across distributed systems with configurable headers: + +```bash +# Use custom request ID headers +python -m sglang_router.launch_router \ + --worker-urls http://localhost:8080 \ + --request-id-headers x-trace-id x-request-id +``` + +Default headers: `x-request-id`, `x-correlation-id`, `x-trace-id`, `request-id` + ## Advanced Features ### Kubernetes Service Discovery diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index af1ce392c0b6..9337c4eaa0ac 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -64,6 +64,8 @@ class RouterArgs: # Prometheus configuration prometheus_port: Optional[int] = None prometheus_host: Optional[str] = None + # Request ID headers configuration + request_id_headers: Optional[List[str]] = None @staticmethod def add_cli_args( @@ -255,6 +257,12 @@ def add_cli_args( default="127.0.0.1", help="Host address to bind the Prometheus metrics server", ) + parser.add_argument( + f"--{prefix}request-id-headers", + type=str, + nargs="*", + help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.", + ) @classmethod def from_cli_args( @@ -313,6 +321,7 @@ def from_cli_args( bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation prometheus_port=getattr(args, f"{prefix}prometheus_port", None), prometheus_host=getattr(args, f"{prefix}prometheus_host", None), + request_id_headers=getattr(args, f"{prefix}request_id_headers", None), ) @staticmethod @@ -481,6 +490,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: if router_args.decode_policy else None ), + request_id_headers=router_args.request_id_headers, ) router.start() diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index cd10e8e69e3f..7b85f77673a7 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -54,6 +54,9 @@ class Router: If not specified, uses the main policy. Default: None decode_policy: Specific load balancing policy for decode nodes (PD mode only). If not specified, uses the main policy. Default: None + request_id_headers: List of HTTP headers to check for request IDs. If not specified, + uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id']. + Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None """ def __init__( @@ -85,6 +88,7 @@ def __init__( decode_urls: Optional[List[str]] = None, prefill_policy: Optional[PolicyType] = None, decode_policy: Optional[PolicyType] = None, + request_id_headers: Optional[List[str]] = None, ): if selector is None: selector = {} @@ -121,6 +125,7 @@ def __init__( decode_urls=decode_urls, prefill_policy=prefill_policy, decode_policy=decode_policy, + request_id_headers=request_id_headers, ) def start(self) -> None: diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs index 84075de4c991..537e2a11997a 100644 --- a/sgl-router/src/config/types.rs +++ b/sgl-router/src/config/types.rs @@ -29,6 +29,8 @@ pub struct RouterConfig { pub log_dir: Option, /// Log level (None = info) pub log_level: Option, + /// Custom request ID headers to check (defaults to common headers) + pub request_id_headers: Option>, } /// Routing mode configuration @@ -207,6 +209,7 @@ impl Default for RouterConfig { metrics: None, log_dir: None, log_level: None, + request_id_headers: None, } } } @@ -312,6 +315,7 @@ mod tests { metrics: Some(MetricsConfig::default()), log_dir: Some("/var/log".to_string()), log_level: Some("debug".to_string()), + request_id_headers: None, }; let json = serde_json::to_string(&config).unwrap(); @@ -734,6 +738,7 @@ mod tests { }), log_dir: Some("/var/log/sglang".to_string()), log_level: Some("info".to_string()), + request_id_headers: None, }; assert!(config.mode.is_pd_mode()); @@ -780,6 +785,7 @@ mod tests { metrics: Some(MetricsConfig::default()), log_dir: None, log_level: Some("debug".to_string()), + request_id_headers: None, }; assert!(!config.mode.is_pd_mode()); @@ -822,6 +828,7 @@ mod tests { }), log_dir: Some("/opt/logs/sglang".to_string()), log_level: Some("trace".to_string()), + request_id_headers: None, }; assert!(config.has_service_discovery()); diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 1aa6766c1886..fc91b1f5e6ce 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -411,7 +411,7 @@ pub fn start_health_checker( // Check for shutdown signal if shutdown_clone.load(Ordering::Acquire) { - tracing::info!("Health checker shutting down"); + tracing::debug!("Health checker shutting down"); break; } @@ -439,6 +439,9 @@ pub fn start_health_checker( Err(e) => { if was_healthy { tracing::warn!("Worker {} health check failed: {}", worker_url, e); + } else { + // Worker was already unhealthy, log at debug level + tracing::debug!("Worker {} remains unhealthy: {}", worker_url, e); } } } diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 0c03bd497bc7..ede058f8731c 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -4,6 +4,7 @@ pub mod logging; use std::collections::HashMap; pub mod core; pub mod metrics; +pub mod middleware; pub mod openai_api_types; pub mod policies; pub mod routers; @@ -49,6 +50,7 @@ struct Router { prometheus_port: Option, prometheus_host: Option, request_timeout_secs: u64, + request_id_headers: Option>, // PD mode flag pd_disaggregation: bool, // PD-specific fields (only used when pd_disaggregation is true) @@ -138,6 +140,7 @@ impl Router { metrics, log_dir: self.log_dir.clone(), log_level: self.log_level.clone(), + request_id_headers: self.request_id_headers.clone(), }) } } @@ -170,6 +173,7 @@ impl Router { prometheus_port = None, prometheus_host = None, request_timeout_secs = 600, // Add configurable request timeout + request_id_headers = None, // Custom request ID headers pd_disaggregation = false, // New flag for PD mode prefill_urls = None, decode_urls = None, @@ -201,6 +205,7 @@ impl Router { prometheus_port: Option, prometheus_host: Option, request_timeout_secs: u64, + request_id_headers: Option>, pd_disaggregation: bool, prefill_urls: Option)>>, decode_urls: Option>, @@ -232,6 +237,7 @@ impl Router { prometheus_port, prometheus_host, request_timeout_secs, + request_id_headers, pd_disaggregation, prefill_urls, decode_urls, @@ -297,6 +303,7 @@ impl Router { service_discovery_config, prometheus_config, request_timeout_secs: self.request_timeout_secs, + request_id_headers: self.request_id_headers.clone(), }) .await .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string())) diff --git a/sgl-router/src/middleware.rs b/sgl-router/src/middleware.rs new file mode 100644 index 000000000000..76c48f413654 --- /dev/null +++ b/sgl-router/src/middleware.rs @@ -0,0 +1,111 @@ +use actix_web::{ + dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, + Error, HttpMessage, HttpRequest, +}; +use futures_util::future::LocalBoxFuture; +use std::future::{ready, Ready}; + +/// Generate OpenAI-compatible request ID based on endpoint +fn generate_request_id(path: &str) -> String { + let prefix = if path.contains("/chat/completions") { + "chatcmpl-" + } else if path.contains("/completions") { + "cmpl-" + } else if path.contains("/generate") { + "gnt-" + } else { + "req-" + }; + + // Generate a random string similar to OpenAI's format + let random_part: String = (0..24) + .map(|_| { + let chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + chars + .chars() + .nth(rand::random::() % chars.len()) + .unwrap() + }) + .collect(); + + format!("{}{}", prefix, random_part) +} + +/// Extract request ID from request extensions or generate a new one +pub fn get_request_id(req: &HttpRequest) -> String { + req.extensions() + .get::() + .cloned() + .unwrap_or_else(|| generate_request_id(req.path())) +} + +/// Middleware for injecting request ID into request extensions +pub struct RequestIdMiddleware { + headers: Vec, +} + +impl RequestIdMiddleware { + pub fn new(headers: Vec) -> Self { + Self { headers } + } +} + +impl Transform for RequestIdMiddleware +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = RequestIdMiddlewareService; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(RequestIdMiddlewareService { + service, + headers: self.headers.clone(), + })) + } +} + +pub struct RequestIdMiddlewareService { + service: S, + headers: Vec, +} + +impl Service for RequestIdMiddlewareService +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Response = ServiceResponse; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + forward_ready!(service); + + fn call(&self, req: ServiceRequest) -> Self::Future { + // Extract request ID from headers or generate new one + let mut request_id = None; + + for header_name in &self.headers { + if let Some(header_value) = req.headers().get(header_name) { + if let Ok(value) = header_value.to_str() { + request_id = Some(value.to_string()); + break; + } + } + } + + let request_id = request_id.unwrap_or_else(|| generate_request_id(req.path())); + + // Insert request ID into request extensions + req.extensions_mut().insert(request_id); + + let fut = self.service.call(req); + Box::pin(async move { fut.await }) + } +} diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs index bfbe4b93a003..8d83505f6cae 100644 --- a/sgl-router/src/policies/cache_aware.rs +++ b/sgl-router/src/policies/cache_aware.rs @@ -66,7 +66,7 @@ use crate::tree::Tree; use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; -use tracing::{debug, info}; +use tracing::debug; /// Cache-aware routing policy /// @@ -164,10 +164,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy { .map(|w| (w.url().to_string(), w.load())) .collect(); - info!( - "Load balancing triggered due to workload imbalance:\n\ - Max load: {}, Min load: {}\n\ - Current worker loads: {:?}", + debug!( + "Load balancing triggered | max: {} | min: {} | workers: {:?}", max_load, min_load, worker_loads ); diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index 507ac1f4250f..4bc224fcf1bb 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -5,6 +5,7 @@ use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRou use super::request_adapter::ToPdRequest; use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; use crate::metrics::RouterMetrics; +use crate::middleware::get_request_id; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::policies::LoadBalancingPolicy; use crate::tree::Tree; @@ -16,7 +17,6 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex, RwLock}; use std::time::{Duration, Instant}; use tracing::{debug, error, info, warn}; -use uuid::Uuid; #[derive(Debug)] pub struct PDRouter { @@ -307,8 +307,8 @@ impl PDRouter { mut typed_req: GenerateReqInput, route: &str, ) -> HttpResponse { + let request_id = get_request_id(req); let start = Instant::now(); - let _request_id = Uuid::new_v4(); // Get stream flag and return_logprob flag before moving the request let is_stream = typed_req.stream; @@ -328,7 +328,10 @@ impl PDRouter { let (prefill, decode) = match self.select_pd_pair(client, request_text).await { Ok(pair) => pair, Err(e) => { - error!("Failed to select PD pair: {}", e); + error!( + request_id = %request_id, + "Failed to select PD pair error={}", e + ); RouterMetrics::record_pd_error("server_selection"); return HttpResponse::ServiceUnavailable() .body(format!("No available servers: {}", e)); @@ -337,15 +340,17 @@ impl PDRouter { // Log routing decision info!( - "PD routing: {} -> prefill={}, decode={}", - route, - prefill.url(), - decode.url() + request_id = %request_id, + "PD routing decision route={} prefill_url={} decode_url={}", + route, prefill.url(), decode.url() ); // Add bootstrap info using the trait method if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { - error!("Failed to add bootstrap info: {}", e); + error!( + request_id = %request_id, + "Failed to add bootstrap info error={}", e + ); RouterMetrics::record_pd_error("bootstrap_injection"); return HttpResponse::InternalServerError() .body(format!("Bootstrap injection failed: {}", e)); @@ -355,7 +360,10 @@ impl PDRouter { let json_with_bootstrap = match serde_json::to_value(&typed_req) { Ok(json) => json, Err(e) => { - error!("Failed to serialize request: {}", e); + error!( + request_id = %request_id, + "Failed to serialize request error={}", e + ); return HttpResponse::InternalServerError().body("Failed to serialize request"); } }; @@ -383,6 +391,7 @@ impl PDRouter { mut typed_req: ChatReqInput, route: &str, ) -> HttpResponse { + let request_id = get_request_id(req); let start = Instant::now(); // Get stream flag and return_logprob flag before moving the request @@ -406,7 +415,10 @@ impl PDRouter { let (prefill, decode) = match self.select_pd_pair(client, request_text).await { Ok(pair) => pair, Err(e) => { - error!("Failed to select PD pair: {}", e); + error!( + request_id = %request_id, + "Failed to select PD pair error={}", e + ); RouterMetrics::record_pd_error("server_selection"); return HttpResponse::ServiceUnavailable() .body(format!("No available servers: {}", e)); @@ -415,15 +427,17 @@ impl PDRouter { // Log routing decision info!( - "PD routing: {} -> prefill={}, decode={}", - route, - prefill.url(), - decode.url() + request_id = %request_id, + "PD routing decision route={} prefill_url={} decode_url={}", + route, prefill.url(), decode.url() ); // Add bootstrap info using the trait method if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { - error!("Failed to add bootstrap info: {}", e); + error!( + request_id = %request_id, + "Failed to add bootstrap info error={}", e + ); RouterMetrics::record_pd_error("bootstrap_injection"); return HttpResponse::InternalServerError() .body(format!("Bootstrap injection failed: {}", e)); @@ -433,7 +447,10 @@ impl PDRouter { let json_with_bootstrap = match serde_json::to_value(&typed_req) { Ok(json) => json, Err(e) => { - error!("Failed to serialize request: {}", e); + error!( + request_id = %request_id, + "Failed to serialize request error={}", e + ); return HttpResponse::InternalServerError().body("Failed to serialize request"); } }; @@ -461,6 +478,7 @@ impl PDRouter { mut typed_req: CompletionRequest, route: &str, ) -> HttpResponse { + let request_id = get_request_id(req); let start = Instant::now(); // Get stream flag and return_logprob flag before moving the request @@ -477,7 +495,10 @@ impl PDRouter { let (prefill, decode) = match self.select_pd_pair(client, request_text).await { Ok(pair) => pair, Err(e) => { - error!("Failed to select PD pair: {}", e); + error!( + request_id = %request_id, + "Failed to select PD pair error={}", e + ); RouterMetrics::record_pd_error("server_selection"); return HttpResponse::ServiceUnavailable() .body(format!("No available servers: {}", e)); @@ -486,15 +507,17 @@ impl PDRouter { // Log routing decision info!( - "PD routing: {} -> prefill={}, decode={}", - route, - prefill.url(), - decode.url() + request_id = %request_id, + "PD routing decision route={} prefill_url={} decode_url={}", + route, prefill.url(), decode.url() ); // Add bootstrap info using the trait method if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { - error!("Failed to add bootstrap info: {}", e); + error!( + request_id = %request_id, + "Failed to add bootstrap info error={}", e + ); RouterMetrics::record_pd_error("bootstrap_injection"); return HttpResponse::InternalServerError() .body(format!("Bootstrap injection failed: {}", e)); @@ -504,7 +527,10 @@ impl PDRouter { let json_with_bootstrap = match serde_json::to_value(&typed_req) { Ok(json) => json, Err(e) => { - error!("Failed to serialize request: {}", e); + error!( + request_id = %request_id, + "Failed to serialize request error={}", e + ); return HttpResponse::InternalServerError().body("Failed to serialize request"); } }; @@ -538,6 +564,7 @@ impl PDRouter { return_logprob: bool, start_time: Instant, ) -> HttpResponse { + let request_id = get_request_id(req); // Update load tracking for both workers let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]); @@ -578,9 +605,9 @@ impl PDRouter { if !status.is_success() { RouterMetrics::record_pd_decode_error(decode.url()); error!( - "Decode server {} returned error status: {}", - decode.url(), - status + request_id = %request_id, + "Decode server returned error status decode_url={} status={}", + decode.url(), status ); // Return the error response from decode server @@ -598,9 +625,9 @@ impl PDRouter { // Log prefill errors for debugging if let Err(e) = &prefill_result { error!( - "Prefill server {} failed (non-critical): {}", - prefill.url(), - e + request_id = %request_id, + "Prefill server failed (non-critical) prefill_url={} error={}", + prefill.url(), e ); RouterMetrics::record_pd_prefill_error(prefill.url()); } @@ -684,7 +711,12 @@ impl PDRouter { } } Err(e) => { - error!("Decode request failed: {}", e); + error!( + request_id = %request_id, + decode_url = %decode.url(), + error = %e, + "Decode request failed" + ); RouterMetrics::record_pd_decode_error(decode.url()); HttpResponse::BadGateway().body(format!("Decode server error: {}", e)) } diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index 84bb28fb58e8..b065afafed9c 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -1,5 +1,6 @@ use crate::core::{HealthChecker, Worker, WorkerFactory}; use crate::metrics::RouterMetrics; +use crate::middleware::get_request_id; use crate::policies::LoadBalancingPolicy; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; @@ -134,32 +135,26 @@ impl Router { match sync_client.get(&format!("{}/health", url)).send() { Ok(res) => { if !res.status().is_success() { - let msg = format!( - "Worker heatlh check is pending with status {}", - res.status() - ); - info!("{}", msg); all_healthy = false; - unhealthy_workers.push((url, msg)); + unhealthy_workers.push((url, format!("status: {}", res.status()))); } } Err(_) => { - let msg = format!("Worker is not ready yet"); - info!("{}", msg); all_healthy = false; - unhealthy_workers.push((url, msg)); + unhealthy_workers.push((url, "not ready".to_string())); } } } if all_healthy { - info!("All workers are healthy"); + info!("All {} workers are healthy", worker_urls.len()); return Ok(()); } else { - info!("Initializing workers:"); - for (url, reason) in &unhealthy_workers { - info!(" {} - {}", url, reason); - } + debug!( + "Waiting for {} workers to become healthy ({} unhealthy)", + worker_urls.len(), + unhealthy_workers.len() + ); thread::sleep(Duration::from_secs(interval_secs)); } } @@ -181,6 +176,7 @@ impl Router { route: &str, req: &HttpRequest, ) -> HttpResponse { + let request_id = get_request_id(req); let start = Instant::now(); let mut request_builder = client.get(format!("{}{}", worker_url, route)); @@ -202,14 +198,32 @@ impl Router { match res.bytes().await { Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(e) => HttpResponse::InternalServerError() - .body(format!("Failed to read response body: {}", e)), + Err(e) => { + error!( + request_id = %request_id, + worker_url = %worker_url, + route = %route, + error = %e, + "Failed to read response body" + ); + HttpResponse::InternalServerError() + .body(format!("Failed to read response body: {}", e)) + } } } - Err(e) => HttpResponse::InternalServerError().body(format!( - "Failed to send request to worker {}: {}", - worker_url, e - )), + Err(e) => { + error!( + request_id = %request_id, + worker_url = %worker_url, + route = %route, + error = %e, + "Failed to send request to worker" + ); + HttpResponse::InternalServerError().body(format!( + "Failed to send request to worker {}: {}", + worker_url, e + )) + } }; // Record request metrics @@ -231,6 +245,7 @@ impl Router { route: &str, req: &HttpRequest, ) -> HttpResponse { + let request_id = get_request_id(req); const MAX_REQUEST_RETRIES: u32 = 3; const MAX_TOTAL_RETRIES: u32 = 6; let mut total_retries = 0; @@ -260,17 +275,23 @@ impl Router { } warn!( - "Request to {} failed (attempt {}/{})", - worker_url, - request_retries + 1, - MAX_REQUEST_RETRIES + request_id = %request_id, + route = %route, + worker_url = %worker_url, + attempt = request_retries + 1, + max_attempts = MAX_REQUEST_RETRIES, + "Request failed" ); request_retries += 1; total_retries += 1; if request_retries == MAX_REQUEST_RETRIES { - warn!("Removing failed worker: {}", worker_url); + warn!( + request_id = %request_id, + worker_url = %worker_url, + "Removing failed worker" + ); self.remove_worker(&worker_url); break; } @@ -293,6 +314,7 @@ impl Router { typed_req: &T, route: &str, ) -> HttpResponse { + let request_id = get_request_id(req); // Handle retries like the original implementation let start = Instant::now(); const MAX_REQUEST_RETRIES: u32 = 3; @@ -357,17 +379,19 @@ impl Router { } warn!( - "Generate request to {} failed (attempt {}/{})", - worker_url, - request_retries + 1, - MAX_REQUEST_RETRIES + request_id = %request_id, + "Generate request failed route={} worker_url={} attempt={} max_attempts={}", + route, worker_url, request_retries + 1, MAX_REQUEST_RETRIES ); request_retries += 1; total_retries += 1; if request_retries == MAX_REQUEST_RETRIES { - warn!("Removing failed worker: {}", worker_url); + warn!( + request_id = %request_id, + "Removing failed worker after typed request failures worker_url={}", worker_url + ); self.remove_worker(&worker_url); break; } @@ -402,13 +426,9 @@ impl Router { is_stream: bool, load_incremented: bool, // Whether load was incremented for this request ) -> HttpResponse { + let request_id = get_request_id(req); let start = Instant::now(); - // Debug: Log what we're sending - if let Ok(json_str) = serde_json::to_string_pretty(typed_req) { - debug!("Sending request to {}: {}", route, json_str); - } - let mut request_builder = client .post(format!("{}{}", worker_url, route)) .json(typed_req); // Use json() directly with typed request @@ -424,7 +444,11 @@ impl Router { let res = match request_builder.send().await { Ok(res) => res, Err(e) => { - error!("Failed to send request to {}: {}", worker_url, e); + error!( + request_id = %request_id, + "Failed to send typed request worker_url={} route={} error={}", + worker_url, route, e + ); // Decrement load on error if it was incremented if load_incremented { @@ -497,7 +521,6 @@ impl Router { &worker_url, worker.load(), ); - debug!("Streaming is done!!") } } } @@ -536,7 +559,6 @@ impl Router { match client.get(&format!("{}/health", worker_url)).send().await { Ok(res) => { if res.status().is_success() { - info!("Worker {} health check passed", worker_url); let mut workers_guard = self.workers.write().unwrap(); if workers_guard.iter().any(|w| w.url() == worker_url) { return Err(format!("Worker {} already exists", worker_url)); @@ -560,8 +582,8 @@ impl Router { return Ok(format!("Successfully added worker: {}", worker_url)); } else { - info!( - "Worker {} health check is pending with status: {}.", + debug!( + "Worker {} health check pending - status: {}", worker_url, res.status() ); @@ -576,10 +598,7 @@ impl Router { } } Err(e) => { - info!( - "Worker {} health check is pending with error: {}", - worker_url, e - ); + debug!("Worker {} health check pending - error: {}", worker_url, e); // if the url does not have http or https prefix, warn users if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { @@ -611,7 +630,6 @@ impl Router { .downcast_ref::() { cache_aware.remove_worker(worker_url); - info!("Removed worker from tree: {}", worker_url); } } @@ -675,7 +693,6 @@ impl Router { for url in &worker_urls { if let Some(load) = Self::get_worker_load_static(&client, url).await { loads.insert(url.clone(), load); - debug!("Worker {} load: {}", url, load); } } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 83774f172a35..acbc9d9e9e14 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -1,6 +1,7 @@ use crate::config::RouterConfig; use crate::logging::{self, LoggingConfig}; use crate::metrics::{self, PrometheusConfig}; +use crate::middleware::{get_request_id, RequestIdMiddleware}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::routers::{RouterFactory, RouterTrait}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; @@ -46,13 +47,13 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result Error { - error!("JSON payload error: {:?}", err); +fn json_error_handler(err: error::JsonPayloadError, req: &HttpRequest) -> Error { + let request_id = get_request_id(req); match &err { error::JsonPayloadError::OverflowKnownLength { length, limit } => { error!( - "Payload too large: {} bytes exceeds limit of {} bytes", - length, limit + request_id = %request_id, + "Payload too large length={} limit={}", length, limit ); error::ErrorPayloadTooLarge(format!( "Payload too large: {} bytes exceeds limit of {} bytes", @@ -60,10 +61,19 @@ fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error )) } error::JsonPayloadError::Overflow { limit } => { - error!("Payload overflow: exceeds limit of {} bytes", limit); + error!( + request_id = %request_id, + "Payload overflow limit={}", limit + ); error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit)) } - _ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)), + _ => { + error!( + request_id = %request_id, + "Invalid JSON payload error={}", err + ); + error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)) + } } } @@ -108,8 +118,20 @@ async fn generate( body: web::Json, state: web::Data, ) -> Result { - let json_body = serde_json::to_value(body.into_inner()) - .map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; + let request_id = get_request_id(&req); + info!( + request_id = %request_id, + "Received generate request method=\"POST\" path=\"/generate\"" + ); + + let json_body = serde_json::to_value(body.into_inner()).map_err(|e| { + error!( + request_id = %request_id, + "Failed to parse generate request body error={}", e + ); + error::ErrorBadRequest(format!("Invalid JSON: {}", e)) + })?; + Ok(state .router .route_generate(&state.client, &req, json_body) @@ -122,8 +144,20 @@ async fn v1_chat_completions( body: web::Json, state: web::Data, ) -> Result { - let json_body = serde_json::to_value(body.into_inner()) - .map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; + let request_id = get_request_id(&req); + info!( + request_id = %request_id, + "Received chat completion request method=\"POST\" path=\"/v1/chat/completions\"" + ); + + let json_body = serde_json::to_value(body.into_inner()).map_err(|e| { + error!( + request_id = %request_id, + "Failed to parse chat completion request body error={}", e + ); + error::ErrorBadRequest(format!("Invalid JSON: {}", e)) + })?; + Ok(state .router .route_chat(&state.client, &req, json_body) @@ -136,8 +170,20 @@ async fn v1_completions( body: web::Json, state: web::Data, ) -> Result { - let json_body = serde_json::to_value(body.into_inner()) - .map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; + let request_id = get_request_id(&req); + info!( + request_id = %request_id, + "Received completion request method=\"POST\" path=\"/v1/completions\"" + ); + + let json_body = serde_json::to_value(body.into_inner()).map_err(|e| { + error!( + request_id = %request_id, + "Failed to parse completion request body error={}", e + ); + error::ErrorBadRequest(format!("Invalid JSON: {}", e)) + })?; + Ok(state .router .route_completion(&state.client, &req, json_body) @@ -146,20 +192,48 @@ async fn v1_completions( #[post("/add_worker")] async fn add_worker( + req: HttpRequest, query: web::Query>, data: web::Data, ) -> impl Responder { + let request_id = get_request_id(&req); + let worker_url = match query.get("url") { Some(url) => url.to_string(), None => { + warn!( + request_id = %request_id, + "Add worker request missing URL parameter" + ); return HttpResponse::BadRequest() - .body("Worker URL required. Provide 'url' query parameter") + .body("Worker URL required. Provide 'url' query parameter"); } }; + info!( + request_id = %request_id, + worker_url = %worker_url, + "Adding worker" + ); + match data.router.add_worker(&worker_url).await { - Ok(message) => HttpResponse::Ok().body(message), - Err(error) => HttpResponse::BadRequest().body(error), + Ok(message) => { + info!( + request_id = %request_id, + worker_url = %worker_url, + "Successfully added worker" + ); + HttpResponse::Ok().body(message) + } + Err(error) => { + error!( + request_id = %request_id, + worker_url = %worker_url, + error = %error, + "Failed to add worker" + ); + HttpResponse::BadRequest().body(error) + } } } @@ -171,13 +245,29 @@ async fn list_workers(data: web::Data) -> impl Responder { #[post("/remove_worker")] async fn remove_worker( + req: HttpRequest, query: web::Query>, data: web::Data, ) -> impl Responder { + let request_id = get_request_id(&req); + let worker_url = match query.get("url") { Some(url) => url.to_string(), - None => return HttpResponse::BadRequest().finish(), + None => { + warn!( + request_id = %request_id, + "Remove worker request missing URL parameter" + ); + return HttpResponse::BadRequest().finish(); + } }; + + info!( + request_id = %request_id, + worker_url = %worker_url, + "Removing worker" + ); + data.router.remove_worker(&worker_url); HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url)) } @@ -202,6 +292,7 @@ pub struct ServerConfig { pub service_discovery_config: Option, pub prometheus_config: Option, pub request_timeout_secs: u64, + pub request_id_headers: Option>, } pub async fn startup(config: ServerConfig) -> std::io::Result<()> { @@ -233,31 +324,18 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { // Initialize prometheus metrics exporter if let Some(prometheus_config) = config.prometheus_config { - info!( - "🚧 Initializing Prometheus metrics on {}:{}", - prometheus_config.host, prometheus_config.port - ); metrics::start_prometheus(prometheus_config); - } else { - info!("🚧 Prometheus metrics disabled"); } - info!("🚧 Initializing router on {}:{}", config.host, config.port); - info!("🚧 Router mode: {:?}", config.router_config.mode); - info!("🚧 Policy: {:?}", config.router_config.policy); info!( - "🚧 Max payload size: {} MB", + "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB", + config.host, + config.port, + config.router_config.mode, + config.router_config.policy, config.max_payload_size / (1024 * 1024) ); - // Log service discovery status - if let Some(service_discovery_config) = &config.service_discovery_config { - info!("🚧 Service discovery enabled"); - info!("🚧 Selector: {:?}", service_discovery_config.selector); - } else { - info!("🚧 Service discovery disabled"); - } - let client = Client::builder() .pool_idle_timeout(Some(Duration::from_secs(50))) .timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout @@ -272,11 +350,9 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { // Start the service discovery if enabled if let Some(service_discovery_config) = config.service_discovery_config { if service_discovery_config.enabled { - info!("🚧 Initializing Kubernetes service discovery"); - // Pass the Arc directly match start_service_discovery(service_discovery_config, router_arc).await { Ok(handle) => { - info!("✅ Service discovery started successfully"); + info!("Service discovery started"); // Spawn a task to handle the service discovery thread spawn(async move { if let Err(e) = handle.await { @@ -292,14 +368,26 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { } } - info!("✅ Serving router on {}:{}", config.host, config.port); info!( - "✅ Serving workers on {:?}", + "Router ready | workers: {:?}", app_state.router.get_worker_urls() ); + // Configure request ID headers + let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| { + vec![ + "x-request-id".to_string(), + "x-correlation-id".to_string(), + "x-trace-id".to_string(), + "request-id".to_string(), + ] + }); + HttpServer::new(move || { + let request_id_middleware = RequestIdMiddleware::new(request_id_headers.clone()); + App::new() + .wrap(request_id_middleware) .app_data(app_state.clone()) .app_data( web::JsonConfig::default() diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 72d78b490951..fae09896d432 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -209,7 +209,7 @@ pub async fn start_service_discovery( .join(","); info!( - "Starting Kubernetes service discovery in PD mode with prefill_selector: '{}', decode_selector: '{}'", + "Starting K8s service discovery | PD mode | prefill: '{}' | decode: '{}'", prefill_selector, decode_selector ); } else { @@ -221,7 +221,7 @@ pub async fn start_service_discovery( .join(","); info!( - "Starting Kubernetes service discovery with selector: '{}'", + "Starting K8s service discovery | selector: '{}'", label_selector ); } @@ -238,7 +238,7 @@ pub async fn start_service_discovery( Api::all(client) }; - info!("Kubernetes service discovery initialized successfully"); + debug!("K8s service discovery initialized"); // Create Arcs for configuration data let config_arc = Arc::new(config.clone()); @@ -375,7 +375,7 @@ async fn handle_pod_event( if should_add { info!( - "Healthy pod found: {} (type: {:?}). Adding worker: {}", + "Adding pod: {} | type: {:?} | url: {}", pod_info.name, pod_info.pod_type, worker_url ); @@ -409,8 +409,8 @@ async fn handle_pod_event( }; match result { - Ok(msg) => { - info!("Successfully added worker: {}", msg); + Ok(_) => { + debug!("Worker added: {}", worker_url); } Err(e) => { error!("Failed to add worker {} to router: {}", worker_url, e); @@ -446,7 +446,7 @@ async fn handle_pod_deletion( if was_tracked { info!( - "Pod deleted: {} (type: {:?}). Removing worker: {}", + "Removing pod: {} | type: {:?} | url: {}", pod_info.name, pod_info.pod_type, worker_url ); diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index 12e8dd2d2b88..bf86d776b1e0 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -35,6 +35,7 @@ impl TestContext { metrics: None, log_dir: None, log_level: None, + request_id_headers: None, }; Self::new_with_config(config, worker_configs).await @@ -953,6 +954,7 @@ mod error_tests { metrics: None, log_dir: None, log_level: None, + request_id_headers: None, }; let ctx = TestContext::new_with_config( diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs index 34467cd0885a..62c99a46bbae 100644 --- a/sgl-router/tests/common/mod.rs +++ b/sgl-router/tests/common/mod.rs @@ -20,6 +20,7 @@ pub fn create_test_config(worker_urls: Vec) -> RouterConfig { metrics: None, log_dir: None, log_level: None, + request_id_headers: None, } } @@ -40,6 +41,7 @@ pub fn create_test_config_no_workers() -> RouterConfig { metrics: None, log_dir: None, log_level: None, + request_id_headers: None, } } diff --git a/sgl-router/tests/request_formats_test.rs b/sgl-router/tests/request_formats_test.rs index 40045a0f7b15..d265d10309e2 100644 --- a/sgl-router/tests/request_formats_test.rs +++ b/sgl-router/tests/request_formats_test.rs @@ -46,6 +46,7 @@ impl RequestTestContext { metrics: None, log_dir: None, log_level: None, + request_id_headers: None, }; let client = Client::builder() diff --git a/sgl-router/tests/streaming_tests.rs b/sgl-router/tests/streaming_tests.rs index 47a1326ae575..ada8b7e4554e 100644 --- a/sgl-router/tests/streaming_tests.rs +++ b/sgl-router/tests/streaming_tests.rs @@ -50,6 +50,7 @@ impl StreamingTestContext { metrics: None, log_dir: None, log_level: None, + request_id_headers: None, }; let client = Client::builder() diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 24571eb243f6..a6cb8d02d572 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -173,6 +173,7 @@ mod test_pd_routing { metrics: None, log_dir: None, log_level: None, + request_id_headers: None, }; // Router creation will fail due to health checks, but config should be valid