Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions sgl-router/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ python -m sglang_router.launch_router \
--prometheus-port 9000
```

### Request ID Tracking

Track requests across distributed systems with configurable headers:

```bash
# Use custom request ID headers
python -m sglang_router.launch_router \
--worker-urls http://localhost:8080 \
--request-id-headers x-trace-id x-request-id
```

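Default headers: `x-request-id`, `x-correlation-id`, `x-trace-id`, `request-id`

Headers are checked in the order given, and the first one found with a readable value is used as the request ID. If none is found, the router generates an ID.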

## Advanced Features

### Kubernetes Service Discovery
Expand Down
10 changes: 10 additions & 0 deletions sgl-router/py_src/sglang_router/launch_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class RouterArgs:
# Prometheus configuration
prometheus_port: Optional[int] = None
prometheus_host: Optional[str] = None
# Request ID headers configuration
request_id_headers: Optional[List[str]] = None

@staticmethod
def add_cli_args(
Expand Down Expand Up @@ -255,6 +257,12 @@ def add_cli_args(
default="127.0.0.1",
help="Host address to bind the Prometheus metrics server",
)
parser.add_argument(
f"--{prefix}request-id-headers",
type=str,
nargs="*",
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
)

@classmethod
def from_cli_args(
Expand Down Expand Up @@ -313,6 +321,7 @@ def from_cli_args(
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
)

@staticmethod
Expand Down Expand Up @@ -481,6 +490,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
if router_args.decode_policy
else None
),
request_id_headers=router_args.request_id_headers,
)

router.start()
Expand Down
5 changes: 5 additions & 0 deletions sgl-router/py_src/sglang_router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class Router:
If not specified, uses the main policy. Default: None
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
If not specified, uses the main policy. Default: None
request_id_headers: List of HTTP headers to check for request IDs. If not specified,
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
"""

def __init__(
Expand Down Expand Up @@ -85,6 +88,7 @@ def __init__(
decode_urls: Optional[List[str]] = None,
prefill_policy: Optional[PolicyType] = None,
decode_policy: Optional[PolicyType] = None,
request_id_headers: Optional[List[str]] = None,
):
if selector is None:
selector = {}
Expand Down Expand Up @@ -121,6 +125,7 @@ def __init__(
decode_urls=decode_urls,
prefill_policy=prefill_policy,
decode_policy=decode_policy,
request_id_headers=request_id_headers,
)

def start(self) -> None:
Expand Down
7 changes: 7 additions & 0 deletions sgl-router/src/config/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub struct RouterConfig {
pub log_dir: Option<String>,
/// Log level (None = info)
pub log_level: Option<String>,
/// Custom request ID headers to check (defaults to common headers)
pub request_id_headers: Option<Vec<String>>,
}

/// Routing mode configuration
Expand Down Expand Up @@ -207,6 +209,7 @@ impl Default for RouterConfig {
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
}
}
}
Expand Down Expand Up @@ -312,6 +315,7 @@ mod tests {
metrics: Some(MetricsConfig::default()),
log_dir: Some("/var/log".to_string()),
log_level: Some("debug".to_string()),
request_id_headers: None,
};

let json = serde_json::to_string(&config).unwrap();
Expand Down Expand Up @@ -734,6 +738,7 @@ mod tests {
}),
log_dir: Some("/var/log/sglang".to_string()),
log_level: Some("info".to_string()),
request_id_headers: None,
};

assert!(config.mode.is_pd_mode());
Expand Down Expand Up @@ -780,6 +785,7 @@ mod tests {
metrics: Some(MetricsConfig::default()),
log_dir: None,
log_level: Some("debug".to_string()),
request_id_headers: None,
};

assert!(!config.mode.is_pd_mode());
Expand Down Expand Up @@ -822,6 +828,7 @@ mod tests {
}),
log_dir: Some("/opt/logs/sglang".to_string()),
log_level: Some("trace".to_string()),
request_id_headers: None,
};

assert!(config.has_service_discovery());
Expand Down
5 changes: 4 additions & 1 deletion sgl-router/src/core/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ pub fn start_health_checker(

// Check for shutdown signal
if shutdown_clone.load(Ordering::Acquire) {
tracing::info!("Health checker shutting down");
tracing::debug!("Health checker shutting down");
break;
}

Expand Down Expand Up @@ -439,6 +439,9 @@ pub fn start_health_checker(
Err(e) => {
if was_healthy {
tracing::warn!("Worker {} health check failed: {}", worker_url, e);
} else {
// Worker was already unhealthy, log at debug level
tracing::debug!("Worker {} remains unhealthy: {}", worker_url, e);
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions sgl-router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod logging;
use std::collections::HashMap;
pub mod core;
pub mod metrics;
pub mod middleware;
pub mod openai_api_types;
pub mod policies;
pub mod routers;
Expand Down Expand Up @@ -49,6 +50,7 @@ struct Router {
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
request_id_headers: Option<Vec<String>>,
// PD mode flag
pd_disaggregation: bool,
// PD-specific fields (only used when pd_disaggregation is true)
Expand Down Expand Up @@ -138,6 +140,7 @@ impl Router {
metrics,
log_dir: self.log_dir.clone(),
log_level: self.log_level.clone(),
request_id_headers: self.request_id_headers.clone(),
})
}
}
Expand Down Expand Up @@ -170,6 +173,7 @@ impl Router {
prometheus_port = None,
prometheus_host = None,
request_timeout_secs = 600, // Add configurable request timeout
request_id_headers = None, // Custom request ID headers
pd_disaggregation = false, // New flag for PD mode
prefill_urls = None,
decode_urls = None,
Expand Down Expand Up @@ -201,6 +205,7 @@ impl Router {
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
request_id_headers: Option<Vec<String>>,
pd_disaggregation: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
Expand Down Expand Up @@ -232,6 +237,7 @@ impl Router {
prometheus_port,
prometheus_host,
request_timeout_secs,
request_id_headers,
pd_disaggregation,
prefill_urls,
decode_urls,
Expand Down Expand Up @@ -297,6 +303,7 @@ impl Router {
service_discovery_config,
prometheus_config,
request_timeout_secs: self.request_timeout_secs,
request_id_headers: self.request_id_headers.clone(),
})
.await
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
Expand Down
111 changes: 111 additions & 0 deletions sgl-router/src/middleware.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use actix_web::{
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
Error, HttpMessage, HttpRequest,
};
use futures_util::future::LocalBoxFuture;
use std::future::{ready, Ready};

/// Generate an OpenAI-style request ID whose prefix depends on the endpoint.
///
/// The prefix is chosen by substring match on `path`:
/// `/chat/completions` -> `chatcmpl-`, `/completions` -> `cmpl-`,
/// `/generate` -> `gnt-`, anything else -> `req-`. The prefix is followed by
/// 24 random ASCII alphanumeric characters.
fn generate_request_id(path: &str) -> String {
    /// Alphabet the random suffix is drawn from (ASCII letters and digits).
    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
    /// Number of random characters appended after the prefix.
    const RANDOM_LEN: usize = 24;

    // `/chat/completions` must be tested first: it also contains `/completions`.
    let prefix = if path.contains("/chat/completions") {
        "chatcmpl-"
    } else if path.contains("/completions") {
        "cmpl-"
    } else if path.contains("/generate") {
        "gnt-"
    } else {
        "req-"
    };

    let mut request_id = String::with_capacity(prefix.len() + RANDOM_LEN);
    request_id.push_str(prefix);
    // Index the byte alphabet directly instead of walking a `chars()` iterator
    // for every generated character; the index is always < CHARSET.len().
    request_id.extend(
        (0..RANDOM_LEN).map(|_| char::from(CHARSET[rand::random::<usize>() % CHARSET.len()])),
    );
    request_id
}

/// Return the request ID associated with `req`.
///
/// Looks up the `String` stored in the request's extensions (the value
/// `RequestIdMiddlewareService::call` inserts). When no such value exists, a
/// new ID is generated from the request path; the generated ID is returned
/// but not stored back into the extensions.
pub fn get_request_id(req: &HttpRequest) -> String {
    let stored_id = req.extensions().get::<String>().cloned();

    match stored_id {
        Some(id) => id,
        None => generate_request_id(req.path()),
    }
}

/// Middleware factory that injects a request ID into request extensions.
///
/// Holds the list of header names that are checked, in order, for an
/// incoming request ID. Each service built from this factory receives its
/// own copy of the list.
#[derive(Debug, Clone)]
pub struct RequestIdMiddleware {
    /// Header names to inspect for a request ID, in priority order.
    headers: Vec<String>,
}

impl RequestIdMiddleware {
    /// Create a middleware that checks `headers`, in the given order, for a
    /// request ID.
    pub fn new(headers: Vec<String>) -> Self {
        Self { headers }
    }
}

/// Factory side of the middleware: wraps an inner service `S` in a
/// [`RequestIdMiddlewareService`].
impl<S, B> Transform<S, ServiceRequest> for RequestIdMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type InitError = ();
    type Transform = RequestIdMiddlewareService<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Build the wrapping service. Construction cannot fail, so the result is
    /// returned as an already-completed future.
    fn new_transform(&self, service: S) -> Self::Future {
        // Each wrapped service gets its own copy of the header list.
        let wrapped = RequestIdMiddlewareService {
            service,
            headers: self.headers.clone(),
        };
        ready(Ok(wrapped))
    }
}

/// Service wrapper produced by `RequestIdMiddleware::new_transform`.
///
/// For every request it stores a request ID `String` in the request's
/// extensions and then forwards the request to the wrapped service.
pub struct RequestIdMiddlewareService<S> {
// Inner service that receives the request after the ID has been attached.
service: S,
// Header names checked, in order, for a caller-supplied request ID.
headers: Vec<String>,
}

/// Request-handling side of the middleware: attaches a request ID to each
/// request before delegating to the wrapped service.
impl<S, B> Service<ServiceRequest> for RequestIdMiddlewareService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    /// Resolve the request ID, store it in the request extensions as a
    /// `String`, and forward the request to the inner service.
    ///
    /// The configured headers are checked in order. A header is skipped when
    /// it is absent, when its value is not valid visible ASCII, or when its
    /// value is empty. If no header yields an ID, one is generated from the
    /// request path.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        let request_id = self
            .headers
            .iter()
            .find_map(|header_name| {
                req.headers()
                    .get(header_name)
                    .and_then(|value| value.to_str().ok())
                    // An empty ID is useless for tracing; fall through to the
                    // next header (or to a generated ID) instead.
                    .filter(|value| !value.is_empty())
                    .map(str::to_owned)
            })
            .unwrap_or_else(|| generate_request_id(req.path()));

        // Downstream handlers read the ID back via `get_request_id`.
        req.extensions_mut().insert(request_id);

        // The inner future already has the right output type; box it directly
        // rather than wrapping it in another async block.
        Box::pin(self.service.call(req))
    }
}
8 changes: 3 additions & 5 deletions sgl-router/src/policies/cache_aware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ use crate::tree::Tree;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use tracing::{debug, info};
use tracing::debug;

/// Cache-aware routing policy
///
Expand Down Expand Up @@ -164,10 +164,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
.map(|w| (w.url().to_string(), w.load()))
.collect();

info!(
"Load balancing triggered due to workload imbalance:\n\
Max load: {}, Min load: {}\n\
Current worker loads: {:?}",
debug!(
"Load balancing triggered | max: {} | min: {} | workers: {:?}",
max_load, min_load, worker_loads
);

Expand Down
Loading
Loading