diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index afc558f1cc8d..17d8c87fd2a8 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -35,6 +35,7 @@ metrics = "0.24.2" metrics-exporter-prometheus = "0.17.0" # Added for request tracing uuid = { version = "1.10", features = ["v4", "serde"] } +thiserror = "2.0.12" [profile.release] lto = "thin" codegen-units = 1 diff --git a/sgl-router/README.md b/sgl-router/README.md index 88637bd0542f..47349a380327 100644 --- a/sgl-router/README.md +++ b/sgl-router/README.md @@ -95,38 +95,217 @@ python -m sglang_router.launch_router \ ### Kubernetes Service Discovery -SGL Router supports automatic service discovery for worker nodes in Kubernetes environments. When enabled, the router will automatically: +SGL Router supports automatic service discovery for worker nodes in Kubernetes environments. This feature works with both regular (single-server) routing and PD (Prefill-Decode) routing modes. When enabled, the router will automatically: - Discover and add worker pods with matching labels - Remove unhealthy or deleted worker pods - Dynamically adjust the worker pool based on pod health and availability +- For PD mode: distinguish between prefill and decode servers based on labels -#### Command Line Usage +#### Regular Mode Service Discovery + +For traditional single-server routing: ```bash python -m sglang_router.launch_router \ --service-discovery \ --selector app=sglang-worker role=inference \ - --service-discovery-port 8000 \ --service-discovery-namespace default ``` +#### PD Mode Service Discovery + +For PD (Prefill-Decode) disaggregated routing, service discovery can automatically discover and classify pods as either prefill or decode servers based on their labels: + +```bash +python -m sglang_router.launch_router \ + --pd-disaggregation \ + --policy cache_aware \ + --service-discovery \ + --prefill-selector app=sglang component=prefill \ + --decode-selector app=sglang component=decode \ + 
--service-discovery-namespace sglang-system +``` + +You can also specify initial prefill and decode servers and let service discovery add more: + +```bash +python -m sglang_router.launch_router \ + --pd-disaggregation \ + --policy cache_aware \ + --prefill http://prefill-1:8000 8001 \ + --decode http://decode-1:8000 \ + --service-discovery \ + --prefill-selector app=sglang component=prefill \ + --decode-selector app=sglang component=decode \ + --service-discovery-namespace sglang-system +``` + +#### Kubernetes Pod Configuration for PD Mode + +When using PD service discovery, your Kubernetes pods need specific labels to be classified as prefill or decode servers: + +**Prefill Server Pod:** +```yaml +apiVersion: v1 +kind: Pod +metadata: + name: sglang-prefill-1 + labels: + app: sglang + component: prefill + annotations: + sglang.ai/bootstrap-port: "9001" # Optional: Bootstrap port for Mooncake prefill coordination +spec: + containers: + - name: sglang + image: lmsys/sglang:latest + ports: + - containerPort: 8000 # Main API port + - containerPort: 9001 # Optional: Bootstrap coordination port + # ... rest of configuration +``` + +**Decode Server Pod:** +```yaml +apiVersion: v1 +kind: Pod +metadata: + name: sglang-decode-1 + labels: + app: sglang + component: decode +spec: + containers: + - name: sglang + image: lmsys/sglang:latest + ports: + - containerPort: 8000 # Main API port + # ... 
rest of configuration +``` + +**Key Requirements:** +- Prefill pods must have labels matching your `--prefill-selector` +- Decode pods must have labels matching your `--decode-selector` +- Prefill pods can optionally include bootstrap port in annotations using `sglang.ai/bootstrap-port` (defaults to None if not specified) + #### Service Discovery Arguments +**General Arguments:** - `--service-discovery`: Enable Kubernetes service discovery feature -- `--selector`: One or more label key-value pairs for pod selection (format: key1=value1 key2=value2) -- `--service-discovery-port`: Port to use when generating worker URLs (default: 80) +- `--service-discovery-port`: Port to use when generating worker URLs (default: 8000) - `--service-discovery-namespace`: Optional. Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions) +- `--selector`: One or more label key-value pairs for pod selection in regular mode (format: key1=value1 key2=value2) + +**PD Mode Arguments:** +- `--pd-disaggregation`: Enable PD (Prefill-Decode) disaggregated mode +- `--prefill`: Specify initial prefill server URL and bootstrap port (format: URL BOOTSTRAP_PORT, can be used multiple times) +- `--decode`: Specify initial decode server URL (can be used multiple times) +- `--prefill-selector`: Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2) +- `--decode-selector`: Label selector for decode server pods in PD mode (format: key1=value1 key2=value2) +- `--policy`: Routing policy (cache_aware, random, power_of_two - note: power_of_two only works in PD mode) + +**Notes:** +- Bootstrap port annotation is automatically set to `sglang.ai/bootstrap-port` for Mooncake deployments +- Advanced cache tuning parameters use sensible defaults and are not exposed via CLI #### RBAC Requirements When using service discovery, you must configure proper Kubernetes RBAC permissions: -- **If using namespace-scoped discovery** (with 
`--service-discovery-namespace`): - Set up a ServiceAccount, Role, and RoleBinding +**Namespace-scoped (recommended):** +```yaml +apiVersion: v1 +kind: ServiceAccount +metadata: + name: sglang-router + namespace: sglang-system +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + namespace: sglang-system + name: sglang-router +rules: +- apiGroups: [""] + resources: ["pods"] + verbs: ["get", "list", "watch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: sglang-router + namespace: sglang-system +subjects: +- kind: ServiceAccount + name: sglang-router + namespace: sglang-system +roleRef: + kind: Role + name: sglang-router + apiGroup: rbac.authorization.k8s.io +``` + +**Cluster-wide (if watching all namespaces):** +```yaml +apiVersion: v1 +kind: ServiceAccount +metadata: + name: sglang-router + namespace: sglang-system +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: sglang-router +rules: +- apiGroups: [""] + resources: ["pods"] + verbs: ["get", "list", "watch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: sglang-router +subjects: +- kind: ServiceAccount + name: sglang-router + namespace: sglang-system +roleRef: + kind: ClusterRole + name: sglang-router + apiGroup: rbac.authorization.k8s.io +``` + +#### Complete Example: PD Mode with Service Discovery + +Here's a complete example of running SGLang Router with PD mode and service discovery: + +```bash +# Start the router with PD mode and automatic prefill/decode discovery +python -m sglang_router.launch_router \ + --pd-disaggregation \ + --policy cache_aware \ + --service-discovery \ + --prefill-selector app=sglang component=prefill environment=production \ + --decode-selector app=sglang component=decode environment=production \ + --service-discovery-namespace production \ + --host 0.0.0.0 \ + --port 8080 \ + --prometheus-host 0.0.0.0 \ + --prometheus-port 9090 +``` + +This setup 
will: +1. Enable PD (Prefill-Decode) disaggregated routing mode with automatic pod classification +2. Watch for pods in the `production` namespace +3. Automatically add prefill servers with labels `app=sglang`, `component=prefill`, `environment=production` +4. Automatically add decode servers with labels `app=sglang`, `component=decode`, `environment=production` +5. Extract bootstrap ports from the `sglang.ai/bootstrap-port` annotation on prefill pods +6. Use cache-aware load balancing for optimal performance +7. Expose the router API on port 8080 and metrics on port 9090 -- **If watching all namespaces** (without specifying namespace): - Set up a ServiceAccount, ClusterRole, and ClusterRoleBinding with permissions to list/watch pods at the cluster level +**Note:** In PD mode with service discovery, pods MUST match either the prefill or decode selector to be added. Pods that don't match either selector are ignored. ### Troubleshooting diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index 74000ccbe0c6..82a56eec9009 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -32,7 +32,7 @@ class RouterArgs: port: int = 30000 # PD-specific configuration - pd_disaggregated: bool = False # Enable PD disaggregated mode + pd_disaggregation: bool = False # Enable PD disaggregated mode prefill_urls: List[tuple] = dataclasses.field( default_factory=list ) # List of (url, bootstrap_port) @@ -55,6 +55,10 @@ class RouterArgs: selector: Dict[str, str] = dataclasses.field(default_factory=dict) service_discovery_port: int = 80 service_discovery_namespace: Optional[str] = None + # PD service discovery configuration + prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict) + decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict) + bootstrap_port_annotation: str = "sglang.ai/bootstrap-port" # Prometheus configuration 
prometheus_port: Optional[int] = None prometheus_host: Optional[str] = None @@ -108,7 +112,7 @@ def add_cli_args( # PD-specific arguments parser.add_argument( - f"--{prefix}pd-disaggregated", + f"--{prefix}pd-disaggregation", action="store_true", help="Enable PD (Prefill-Decode) disaggregated mode", ) @@ -207,6 +211,18 @@ def add_cli_args( type=str, help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)", ) + parser.add_argument( + f"--{prefix}prefill-selector", + type=str, + nargs="+", + help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)", + ) + parser.add_argument( + f"--{prefix}decode-selector", + type=str, + nargs="+", + help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)", + ) # Prometheus configuration parser.add_argument( f"--{prefix}prometheus-port", @@ -243,7 +259,7 @@ def from_cli_args( worker_urls=worker_urls, host=args.host, port=args.port, - pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False), + pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False), prefill_urls=prefill_urls, decode_urls=decode_urls, policy=getattr(args, f"{prefix}policy"), @@ -267,6 +283,13 @@ def from_cli_args( service_discovery_namespace=getattr( args, f"{prefix}service_discovery_namespace", None ), + prefill_selector=cls._parse_selector( + getattr(args, f"{prefix}prefill_selector", None) + ), + decode_selector=cls._parse_selector( + getattr(args, f"{prefix}decode_selector", None) + ), + bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation prometheus_port=getattr(args, f"{prefix}prometheus_port", None), prometheus_host=getattr(args, f"{prefix}prometheus_host", None), ) @@ -355,17 +378,20 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: router_args = args # Validate configuration based on mode - if router_args.pd_disaggregated: - # Validate PD configuration - if 
not router_args.prefill_urls: - raise ValueError("PD disaggregated mode requires --prefill") - if not router_args.decode_urls: - raise ValueError("PD disaggregated mode requires --decode") + if router_args.pd_disaggregation: + # Validate PD configuration - skip URL requirements if using service discovery + if not router_args.service_discovery: + if not router_args.prefill_urls: + raise ValueError("PD disaggregation mode requires --prefill") + if not router_args.decode_urls: + raise ValueError("PD disaggregation mode requires --decode") # Create router with unified constructor router = Router( worker_urls=( - router_args.worker_urls if not router_args.pd_disaggregated else [] + [] + if router_args.service_discovery or router_args.pd_disaggregation + else router_args.worker_urls ), host=router_args.host, port=router_args.port, @@ -384,14 +410,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: selector=router_args.selector, service_discovery_port=router_args.service_discovery_port, service_discovery_namespace=router_args.service_discovery_namespace, + prefill_selector=router_args.prefill_selector, + decode_selector=router_args.decode_selector, prometheus_port=router_args.prometheus_port, prometheus_host=router_args.prometheus_host, - pd_disaggregated=router_args.pd_disaggregated, + pd_disaggregation=router_args.pd_disaggregation, prefill_urls=( - router_args.prefill_urls if router_args.pd_disaggregated else None + router_args.prefill_urls if router_args.pd_disaggregation else None ), decode_urls=( - router_args.decode_urls if router_args.pd_disaggregated else None + router_args.decode_urls if router_args.pd_disaggregation else None ), ) @@ -425,7 +453,7 @@ def parse_router_args(args: List[str]) -> RouterArgs: python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 # PD disaggregated mode - python -m sglang_router.launch_router --pd-disaggregated \\ + python -m sglang_router.launch_router --pd-disaggregation \\ 
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\ --decode http://decode1:8001 --decode http://decode2:8001 \\ --policy cache_aware diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index 5fd5d878877c..7708e393aea5 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -41,9 +41,13 @@ class Router: worker URLs using this port. Default: 80 service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided, watches pods across all namespaces (requires cluster-wide permissions). Default: None + prefill_selector: Dictionary mapping of label keys to values for Kubernetes pod selection + for prefill servers (PD mode only). Default: {} + decode_selector: Dictionary mapping of label keys to values for Kubernetes pod selection + for decode servers (PD mode only). Default: {} prometheus_port: Port to expose Prometheus metrics. Default: None prometheus_host: Host address to bind the Prometheus metrics server. Default: None - pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False + pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. 
Default: False prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only) decode_urls: List of URLs for decode servers (PD mode only) """ @@ -68,14 +72,20 @@ def __init__( selector: Dict[str, str] = None, service_discovery_port: int = 80, service_discovery_namespace: Optional[str] = None, + prefill_selector: Dict[str, str] = None, + decode_selector: Dict[str, str] = None, prometheus_port: Optional[int] = None, prometheus_host: Optional[str] = None, - pd_disaggregated: bool = False, + pd_disaggregation: bool = False, prefill_urls: Optional[List[tuple]] = None, decode_urls: Optional[List[str]] = None, ): if selector is None: selector = {} + if prefill_selector is None: + prefill_selector = {} + if decode_selector is None: + decode_selector = {} self._router = _Router( worker_urls=worker_urls, @@ -96,9 +106,11 @@ def __init__( selector=selector, service_discovery_port=service_discovery_port, service_discovery_namespace=service_discovery_namespace, + prefill_selector=prefill_selector, + decode_selector=decode_selector, prometheus_port=prometheus_port, prometheus_host=prometheus_host, - pd_disaggregated=pd_disaggregated, + pd_disaggregation=pd_disaggregation, prefill_urls=prefill_urls, decode_urls=decode_urls, ) diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index 26b3c33d90ce..884109e67ecb 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -45,7 +45,7 @@ def setUp(self): prometheus_port=None, prometheus_host=None, # PD-specific attributes - pd_disaggregated=False, + pd_disaggregation=False, prefill=None, decode=None, # Keep worker_urls for regular mode @@ -119,7 +119,7 @@ def test_launch_router_pd_mode_basic(self): # Test RouterArgs parsing for PD mode # Simulate the parsed args structure from argparse with action="append" args = self.create_router_args( - pd_disaggregated=True, + pd_disaggregation=True, policy="power_of_two", # PowerOfTwo 
is only valid in PD mode prefill=[ ["http://prefill1:8080", "9000"], @@ -133,7 +133,7 @@ def test_launch_router_pd_mode_basic(self): ) router_args = RouterArgs.from_cli_args(args) - self.assertTrue(router_args.pd_disaggregated) + self.assertTrue(router_args.pd_disaggregation) self.assertEqual(router_args.policy, "power_of_two") self.assertEqual(len(router_args.prefill_urls), 2) self.assertEqual(len(router_args.decode_urls), 2) @@ -147,7 +147,7 @@ def test_launch_router_pd_mode_basic(self): # Test Router creation in PD mode router = Router( worker_urls=[], # Empty for PD mode - pd_disaggregated=True, + pd_disaggregation=True, prefill_urls=[ ("http://prefill1:8080", 9000), ("http://prefill2:8080", None), @@ -165,7 +165,7 @@ def test_policy_validation(self): # Test 1: PowerOfTwo is only valid in PD mode args = self.create_router_args( - pd_disaggregated=False, + pd_disaggregation=False, policy="power_of_two", worker_urls=["http://localhost:8000"], ) @@ -180,7 +180,7 @@ def test_policy_validation(self): # Test 2: RoundRobin is not valid in PD mode args = self.create_router_args( - pd_disaggregated=True, + pd_disaggregation=True, policy="round_robin", prefill=[["http://prefill1:8080", "9000"]], decode=[["http://decode1:8081"]], @@ -198,7 +198,7 @@ def test_policy_validation(self): # Test 3: Valid combinations should not raise errors # Regular mode with RoundRobin args = self.create_router_args( - pd_disaggregated=False, + pd_disaggregation=False, policy="round_robin", worker_urls=["http://localhost:8000"], ) @@ -206,7 +206,7 @@ def test_policy_validation(self): # PD mode with PowerOfTwo args = self.create_router_args( - pd_disaggregated=True, + pd_disaggregation=True, policy="power_of_two", prefill=[["http://prefill1:8080", "9000"]], decode=[["http://decode1:8081"]], @@ -214,6 +214,79 @@ def test_policy_validation(self): ) # This should not raise (though it may fail to connect) + def test_pd_service_discovery_args_parsing(self): + """Test PD service discovery CLI 
argument parsing.""" + import argparse + + from sglang_router.launch_router import RouterArgs + + parser = argparse.ArgumentParser() + RouterArgs.add_cli_args(parser) + + args = parser.parse_args( + [ + "--pd-disaggregation", + "--service-discovery", + "--prefill-selector", + "app=sglang", + "component=prefill", + "--decode-selector", + "app=sglang", + "component=decode", + "--service-discovery-port", + "8000", + "--service-discovery-namespace", + "production", + "--policy", + "cache_aware", + ] + ) + + router_args = RouterArgs.from_cli_args(args) + + self.assertTrue(router_args.pd_disaggregation) + self.assertTrue(router_args.service_discovery) + self.assertEqual( + router_args.prefill_selector, {"app": "sglang", "component": "prefill"} + ) + self.assertEqual( + router_args.decode_selector, {"app": "sglang", "component": "decode"} + ) + self.assertEqual(router_args.service_discovery_port, 8000) + self.assertEqual(router_args.service_discovery_namespace, "production") + + def test_regular_service_discovery_args_parsing(self): + """Test regular mode service discovery CLI argument parsing.""" + import argparse + + from sglang_router.launch_router import RouterArgs + + parser = argparse.ArgumentParser() + RouterArgs.add_cli_args(parser) + + args = parser.parse_args( + [ + "--service-discovery", + "--selector", + "app=sglang-worker", + "environment=staging", + "--service-discovery-port", + "8000", + "--policy", + "round_robin", + ] + ) + + router_args = RouterArgs.from_cli_args(args) + + self.assertFalse(router_args.pd_disaggregation) + self.assertTrue(router_args.service_discovery) + self.assertEqual( + router_args.selector, {"app": "sglang-worker", "environment": "staging"} + ) + self.assertEqual(router_args.prefill_selector, {}) + self.assertEqual(router_args.decode_selector, {}) + if __name__ == "__main__": unittest.main() diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 439db1c4f838..dfe114f650e1 100644 --- a/sgl-router/src/lib.rs +++ 
b/sgl-router/src/lib.rs @@ -42,12 +42,16 @@ struct Router { selector: HashMap, service_discovery_port: u16, service_discovery_namespace: Option, + // PD service discovery fields + prefill_selector: HashMap, + decode_selector: HashMap, + bootstrap_port_annotation: String, prometheus_port: Option, prometheus_host: Option, request_timeout_secs: u64, // PD mode flag - pd_disaggregated: bool, - // PD-specific fields (only used when pd_disaggregated is true) + pd_disaggregation: bool, + // PD-specific fields (only used when pd_disaggregation is true) prefill_urls: Option)>>, decode_urls: Option>, } @@ -74,10 +78,13 @@ impl Router { selector = HashMap::new(), service_discovery_port = 80, service_discovery_namespace = None, + prefill_selector = HashMap::new(), + decode_selector = HashMap::new(), + bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"), prometheus_port = None, prometheus_host = None, request_timeout_secs = 600, // Add configurable request timeout - pd_disaggregated = false, // New flag for PD mode + pd_disaggregation = false, // New flag for PD mode prefill_urls = None, decode_urls = None ))] @@ -100,10 +107,13 @@ impl Router { selector: HashMap, service_discovery_port: u16, service_discovery_namespace: Option, + prefill_selector: HashMap, + decode_selector: HashMap, + bootstrap_port_annotation: String, prometheus_port: Option, prometheus_host: Option, request_timeout_secs: u64, - pd_disaggregated: bool, + pd_disaggregation: bool, prefill_urls: Option)>>, decode_urls: Option>, ) -> PyResult { @@ -126,17 +136,20 @@ impl Router { selector, service_discovery_port, service_discovery_namespace, + prefill_selector, + decode_selector, + bootstrap_port_annotation, prometheus_port, prometheus_host, request_timeout_secs, - pd_disaggregated, + pd_disaggregation, prefill_urls, decode_urls, }) } fn start(&self) -> PyResult<()> { - let policy_config = if self.pd_disaggregated { + let policy_config = if self.pd_disaggregation { // PD mode - map PolicyType to 
PDSelectionPolicy let pd_selection_policy = match &self.policy { PolicyType::Random => pd_types::PDSelectionPolicy::Random, @@ -207,6 +220,11 @@ impl Router { check_interval: std::time::Duration::from_secs(60), port: self.service_discovery_port, namespace: self.service_discovery_namespace.clone(), + // PD mode configuration + pd_mode: self.pd_disaggregation, + prefill_selector: self.prefill_selector.clone(), + decode_selector: self.decode_selector.clone(), + bootstrap_port_annotation: self.bootstrap_port_annotation.clone(), }) } else { None diff --git a/sgl-router/src/pd_router.rs b/sgl-router/src/pd_router.rs index e06fa371a3d7..dc1f1d74cc37 100644 --- a/sgl-router/src/pd_router.rs +++ b/sgl-router/src/pd_router.rs @@ -1,7 +1,9 @@ // PD (Prefill-Decode) Router Implementation // This module handles routing for disaggregated prefill-decode systems -use crate::pd_types::{Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDSelectionPolicy}; +use crate::pd_types::{ + Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy, +}; use crate::tree::Tree; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; @@ -65,12 +67,145 @@ impl Drop for LoadGuard<'_> { } impl PDRouter { - // TODO: Add methods for dynamic worker management to support /register endpoint: - // - add_prefill_server(url: String, bootstrap_port: Option) - // - add_decode_server(url: String) - // - remove_prefill_server(url: &str) - // - remove_decode_server(url: &str) - // These methods will be used when service discovery is implemented for PD mode + // Dynamic worker management methods for service discovery + pub async fn add_prefill_server( + &self, + url: String, + bootstrap_port: Option, + ) -> Result { + // Create EngineInfo for the new prefill server + let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port); + + // Wait for the new server to be healthy + crate::router::Router::wait_for_healthy_workers( + 
&[url.clone()], + self.timeout_secs, + self.interval_secs, + ) + .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?; + + // Add to prefill workers list + let mut workers = self + .prefill_workers + .write() + .map_err(|_| PDRouterError::LockError { + operation: "prefill_workers write".to_string(), + })?; + + // Check if already exists + if workers.iter().any(|w| w.url == url) { + return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() }); + } + + workers.push(engine_info); + + // Initialize load tracking + self.load_tracking + .insert(url.clone(), Arc::new(AtomicUsize::new(0))); + + // Add to cache tree if using cache-aware policy + if let Some(ref tree) = self.prefill_tree { + tree.lock().unwrap().insert("", &url); + } + + info!("Added prefill server: {}", url); + Ok(format!("Successfully added prefill server: {}", url)) + } + + pub async fn add_decode_server(&self, url: String) -> Result { + // Create EngineInfo for the new decode server + let engine_info = EngineInfo::new_decode(url.clone()); + + // Wait for the new server to be healthy + crate::router::Router::wait_for_healthy_workers( + &[url.clone()], + self.timeout_secs, + self.interval_secs, + ) + .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?; + + // Add to decode workers list + let mut workers = self + .decode_workers + .write() + .map_err(|_| PDRouterError::LockError { + operation: "decode_workers write".to_string(), + })?; + + // Check if already exists + if workers.iter().any(|w| w.url == url) { + return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() }); + } + + workers.push(engine_info); + + // Initialize load tracking + self.load_tracking + .insert(url.clone(), Arc::new(AtomicUsize::new(0))); + + info!("Added decode server: {}", url); + Ok(format!("Successfully added decode server: {}", url)) + } + + pub async fn remove_prefill_server(&self, url: &str) -> Result { + let mut workers = self + .prefill_workers + .write() + .map_err(|_| 
PDRouterError::LockError { + operation: "prefill_workers write".to_string(), + })?; + + // Find and remove the server + let initial_len = workers.len(); + workers.retain(|w| w.url != url); + + if workers.len() == initial_len { + return Err(PDRouterError::WorkerNotFound { + url: url.to_string(), + }); + } + + // Remove from load tracking + self.load_tracking.remove(url); + + // Remove from cache tree if using cache-aware policy + if let Some(ref tree) = self.prefill_tree { + // Note: Tree doesn't have a remove method, so we rebuild it + let mut tree_guard = tree.lock().unwrap(); + *tree_guard = Tree::new(); + for worker in workers.iter() { + tree_guard.insert("", &worker.url); + } + } + + info!("Removed prefill server: {}", url); + Ok(format!("Successfully removed prefill server: {}", url)) + } + + pub async fn remove_decode_server(&self, url: &str) -> Result { + let mut workers = self + .decode_workers + .write() + .map_err(|_| PDRouterError::LockError { + operation: "decode_workers write".to_string(), + })?; + + // Find and remove the server + let initial_len = workers.len(); + workers.retain(|w| w.url != url); + + if workers.len() == initial_len { + return Err(PDRouterError::WorkerNotFound { + url: url.to_string(), + }); + } + + // Remove from load tracking + self.load_tracking.remove(url); + + info!("Removed decode server: {}", url); + Ok(format!("Successfully removed decode server: {}", url)) + } pub fn new( prefill_urls: Vec<(String, Option)>, diff --git a/sgl-router/src/pd_types.rs b/sgl-router/src/pd_types.rs index 98b1043862dc..16dc18267de5 100644 --- a/sgl-router/src/pd_types.rs +++ b/sgl-router/src/pd_types.rs @@ -3,6 +3,31 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; +// Custom error type for PD router operations +#[derive(Debug, thiserror::Error)] +pub enum PDRouterError { + #[error("Worker already exists: {url}")] + WorkerAlreadyExists { url: String }, + + #[error("Worker not found: {url}")] + WorkerNotFound { url: String }, + + 
#[error("Lock acquisition failed: {operation}")] + LockError { operation: String }, + + #[error("Health check failed for worker: {url}")] + HealthCheckFailed { url: String }, + + #[error("Invalid worker configuration: {reason}")] + InvalidConfiguration { reason: String }, + + #[error("Network error: {message}")] + NetworkError { message: String }, + + #[error("Timeout waiting for worker: {url}")] + Timeout { url: String }, +} + #[derive(Debug, Clone)] pub enum EngineType { Prefill, diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs index 9e40963113ae..3b45ecde7fa6 100644 --- a/sgl-router/src/router.rs +++ b/sgl-router/src/router.rs @@ -1045,6 +1045,55 @@ impl Router { } } + /// Add a worker with PD mode support + pub async fn add_pd_worker( + &self, + worker_url: &str, + pod_type: crate::service_discovery::PodType, + bootstrap_port: Option, + ) -> Result { + match self { + Router::PrefillDecode { pd_router } => match pod_type { + crate::service_discovery::PodType::Prefill => pd_router + .add_prefill_server(worker_url.to_string(), bootstrap_port) + .await + .map_err(|e| e.to_string()), + crate::service_discovery::PodType::Decode => pd_router + .add_decode_server(worker_url.to_string()) + .await + .map_err(|e| e.to_string()), + crate::service_discovery::PodType::Regular => { + Err("Regular pod type not supported in PD mode".to_string()) + } + }, + _ => Err("add_pd_worker only supported in PD mode".to_string()), + } + } + + /// Remove a worker with PD mode support + pub async fn remove_pd_worker( + &self, + worker_url: &str, + pod_type: crate::service_discovery::PodType, + ) -> Result { + match self { + Router::PrefillDecode { pd_router } => match pod_type { + crate::service_discovery::PodType::Prefill => pd_router + .remove_prefill_server(worker_url) + .await + .map_err(|e| e.to_string()), + crate::service_discovery::PodType::Decode => pd_router + .remove_decode_server(worker_url) + .await + .map_err(|e| e.to_string()), + 
crate::service_discovery::PodType::Regular => { + Err("Regular pod type not supported in PD mode".to_string()) + } + }, + _ => Err("remove_pd_worker only supported in PD mode".to_string()), + } + } + async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option { match client.get(&format!("{}/get_load", worker_url)).send().await { Ok(res) if res.status().is_success() => match res.bytes().await { @@ -1174,3 +1223,108 @@ impl Router { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::service_discovery::PodType; + + fn create_test_regular_router() -> Router { + Router::Random { + worker_urls: Arc::new(RwLock::new(vec![ + "http://worker1:8080".to_string(), + "http://worker2:8080".to_string(), + ])), + timeout_secs: 5, + interval_secs: 1, + } + } + + #[test] + fn test_router_get_worker_urls_regular() { + let router = create_test_regular_router(); + let worker_urls = router.get_worker_urls(); + let urls = worker_urls.read().unwrap(); + + assert_eq!(urls.len(), 2); + assert!(urls.contains(&"http://worker1:8080".to_string())); + assert!(urls.contains(&"http://worker2:8080".to_string())); + } + + // #[test] + // fn test_router_get_worker_urls_pd_mode() { + // // For PD mode, get_worker_urls returns empty list + // // Note: PDRouter::new requires health checks which fail in tests + // // This test would need a mock server or different test setup + // } + + #[tokio::test] + async fn test_add_pd_worker_with_regular_router() { + let router = create_test_regular_router(); + + let result = router + .add_pd_worker("http://new-worker:8080", PodType::Prefill, Some(8081)) + .await; + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .contains("add_pd_worker only supported in PD mode")); + } + + #[tokio::test] + async fn test_remove_pd_worker_with_regular_router() { + let router = create_test_regular_router(); + + let result = router + .remove_pd_worker("http://worker:8080", PodType::Decode) + .await; + + assert!(result.is_err()); + 
assert!(result + .unwrap_err() + .contains("remove_pd_worker only supported in PD mode")); + } + + // #[tokio::test] + // async fn test_add_pd_worker_with_pd_router_regular_type() { + // // Note: PDRouter::new requires health checks which fail in tests + // // This test would need a mock server or different test setup + // } + + // #[tokio::test] + // async fn test_remove_pd_worker_with_pd_router_regular_type() { + // // Note: PDRouter::new requires health checks which fail in tests + // // This test would need a mock server or different test setup + // } + + #[test] + fn test_select_first_worker_regular() { + let router = create_test_regular_router(); + let result = router.select_first_worker(); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "http://worker1:8080"); + } + + // #[test] + // fn test_select_first_worker_pd_mode() { + // // Note: PDRouter::new requires health checks which fail in tests + // // This test would need a mock server or different test setup + // } + + #[test] + fn test_wait_for_healthy_workers_empty_list() { + let result = Router::wait_for_healthy_workers(&[], 1, 1); + assert!(result.is_ok()); + } + + #[test] + fn test_wait_for_healthy_workers_invalid_urls() { + // This test will timeout quickly since the URLs are invalid + let result = + Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Timeout")); + } +} diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 103551891664..b7104de11084 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -24,6 +24,12 @@ pub struct ServiceDiscoveryConfig { pub check_interval: Duration, pub port: u16, pub namespace: Option, + // PD mode specific configuration + pub pd_mode: bool, + pub prefill_selector: HashMap, + pub decode_selector: HashMap, + // Bootstrap port annotation specific to mooncake implementation + pub 
bootstrap_port_annotation: String, } impl Default for ServiceDiscoveryConfig { @@ -32,12 +38,25 @@ impl Default for ServiceDiscoveryConfig { enabled: false, selector: HashMap::new(), check_interval: Duration::from_secs(60), - port: 80, // Default port to connect to pods + port: 8000, // Standard port for modern services namespace: None, // None means watch all namespaces + // PD mode defaults + pd_mode: false, + prefill_selector: HashMap::new(), + decode_selector: HashMap::new(), + bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(), } } } +/// Pod type for PD mode service discovery +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum PodType { + Prefill, + Decode, + Regular, +} + /// Represents a Kubernetes pod's information used for worker management #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct PodInfo { @@ -45,10 +64,47 @@ pub struct PodInfo { pub ip: String, pub status: String, pub is_ready: bool, + pub pod_type: Option, + pub bootstrap_port: Option, } impl PodInfo { - pub fn from_pod(pod: &Pod) -> Option { + /// Check if a pod matches any of the given selectors + fn matches_selector(pod: &Pod, selector: &HashMap) -> bool { + if selector.is_empty() { + return false; + } + + pod.metadata.labels.as_ref().map_or(false, |labels| { + selector + .iter() + .all(|(k, v)| labels.get(k).map_or(false, |label_value| label_value == v)) + }) + } + + /// Check if a pod should be included in service discovery + pub fn should_include(pod: &Pod, config: &ServiceDiscoveryConfig) -> bool { + if config.pd_mode { + // In PD mode, at least one selector must be non-empty + if config.prefill_selector.is_empty() && config.decode_selector.is_empty() { + warn!("PD mode enabled but both prefill_selector and decode_selector are empty"); + return false; + } + // In PD mode, pod must match either prefill or decode selector + Self::matches_selector(pod, &config.prefill_selector) + || Self::matches_selector(pod, &config.decode_selector) + } else { + // In regular 
mode, pod must match the general selector + if config.selector.is_empty() { + warn!("Regular mode enabled but selector is empty"); + return false; + } + Self::matches_selector(pod, &config.selector) + } + } + + /// Unified PodInfo creation with optional PD configuration + pub fn from_pod(pod: &Pod, config: Option<&ServiceDiscoveryConfig>) -> Option { let name = pod.metadata.name.clone()?; let status = pod.status.clone()?; let pod_ip = status.pod_ip?; @@ -63,11 +119,47 @@ impl PodInfo { let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string()); + // Determine pod type based on labels if config is provided and in PD mode + let pod_type = if let Some(config) = config { + if config.pd_mode { + // Use simplified helper methods for cleaner logic + if Self::matches_selector(pod, &config.prefill_selector) { + Some(PodType::Prefill) + } else if Self::matches_selector(pod, &config.decode_selector) { + Some(PodType::Decode) + } else { + Some(PodType::Regular) + } + } else { + Some(PodType::Regular) + } + } else { + // No config provided, default to None (for backwards compatibility) + None + }; + + // Extract bootstrap port from annotations for prefill pods + let bootstrap_port = if matches!(pod_type, Some(PodType::Prefill)) { + if let Some(config) = config { + pod.metadata + .annotations + .as_ref() + .and_then(|annotations| annotations.get(&config.bootstrap_port_annotation)) + .and_then(|port_str| port_str.parse::().ok()) + } else { + None + } + } else { + None + }; + Some(PodInfo { name, ip: pod_ip, status: pod_status, is_ready, + pod_type, + bootstrap_port, }) } @@ -100,18 +192,39 @@ pub async fn start_service_discovery( // Initialize Kubernetes client let client = Client::try_default().await?; - // Construct label selector string from map - let label_selector = config - .selector - .iter() - .map(|(k, v)| format!("{}={}", k, v)) - .collect::>() - .join(","); + // Log the appropriate selectors based on mode + if config.pd_mode { + let prefill_selector = 
config + .prefill_selector + .iter() + .map(|(k, v)| format!("{}={}", k, v)) + .collect::>() + .join(","); + + let decode_selector = config + .decode_selector + .iter() + .map(|(k, v)| format!("{}={}", k, v)) + .collect::>() + .join(","); - info!( - "Starting Kubernetes service discovery with selector: {}", - label_selector - ); + info!( + "Starting Kubernetes service discovery in PD mode with prefill_selector: '{}', decode_selector: '{}'", + prefill_selector, decode_selector + ); + } else { + let label_selector = config + .selector + .iter() + .map(|(k, v)| format!("{}={}", k, v)) + .collect::>() + .join(","); + + info!( + "Starting Kubernetes service discovery with selector: '{}'", + label_selector + ); + } // Create the task that will run in the background let handle = task::spawn(async move { @@ -127,33 +240,30 @@ pub async fn start_service_discovery( info!("Kubernetes service discovery initialized successfully"); - // Create an Arc for the selector map - let selector = Arc::new(config.selector); + // Create Arcs for configuration data + let config_arc = Arc::new(config.clone()); let port = config.port; + let mut retry_delay = Duration::from_secs(1); + const MAX_RETRY_DELAY: Duration = Duration::from_secs(300); // 5 minutes max + loop { // Create a watcher with the proper parameters according to the kube-rs API let watcher_config = Config::default(); let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects(); // Clone Arcs for the closures - let selector_clone = Arc::clone(&selector); + let config_clone = Arc::clone(&config_arc); let tracked_pods_clone = Arc::clone(&tracked_pods); - // Apply label selector filter separately since we can't do it directly with the watcher anymore + // Simplified label selector filter using helper method let filtered_stream = watcher_stream.filter_map(move |obj_res| { - let selector_inner = Arc::clone(&selector_clone); + let config_inner = Arc::clone(&config_clone); async move { match obj_res { Ok(pod) => { - // 
Only process pods matching our label selector - if pod.metadata.labels.as_ref().map_or(false, |labels| { - // Check if the pod has all the labels from our selector - selector_inner.iter().all(|(k, v)| { - labels.get(k).map_or(false, |label_value| label_value == v) - }) - }) { + if PodInfo::should_include(&pod, &config_inner) { Some(Ok(pod)) } else { None @@ -167,25 +277,36 @@ pub async fn start_service_discovery( // Clone again for the next closure let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone); let router_clone = Arc::clone(&router); + let config_clone2 = Arc::clone(&config_arc); match filtered_stream .try_for_each(move |pod| { let tracked_pods_inner = Arc::clone(&tracked_pods_clone2); let router_inner = Arc::clone(&router_clone); + let config_inner = Arc::clone(&config_clone2); async move { - if let Some(pod_info) = PodInfo::from_pod(&pod) { + let pod_info = PodInfo::from_pod(&pod, Some(&config_inner)); + + if let Some(pod_info) = pod_info { if pod.metadata.deletion_timestamp.is_some() { handle_pod_deletion( &pod_info, tracked_pods_inner, router_inner, port, + config_inner.pd_mode, ) .await; } else { - handle_pod_event(&pod_info, tracked_pods_inner, router_inner, port) - .await; + handle_pod_event( + &pod_info, + tracked_pods_inner, + router_inner, + port, + config_inner.pd_mode, + ) + .await; } } Ok(()) @@ -193,20 +314,29 @@ pub async fn start_service_discovery( }) .await { - Ok(_) => {} + Ok(_) => { + // Reset retry delay on success + retry_delay = Duration::from_secs(1); + } Err(err) => { error!("Error in Kubernetes watcher: {}", err); - // Wait a bit before retrying - time::sleep(Duration::from_secs(5)).await; + warn!( + "Retrying in {} seconds with exponential backoff", + retry_delay.as_secs() + ); + time::sleep(retry_delay).await; + + // Exponential backoff with jitter + retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY); } } // If the watcher exits for some reason, wait a bit before restarting warn!( "Kubernetes watcher exited, 
restarting in {} seconds", - config.check_interval.as_secs() + config_arc.check_interval.as_secs() ); - time::sleep(config.check_interval).await; + time::sleep(config_arc.check_interval).await; } }); @@ -218,34 +348,64 @@ async fn handle_pod_event( tracked_pods: Arc>>, router: Arc, port: u16, + pd_mode: bool, ) { let worker_url = pod_info.worker_url(port); - // Check if pod is already tracked - let already_tracked = { - let tracker = tracked_pods.lock().unwrap(); - tracker.contains(pod_info) - }; - - // If pod is healthy and not already tracked, add it + // If pod is healthy, try to add it (with atomic check-and-insert) if pod_info.is_healthy() { - if !already_tracked { + // Atomic check-and-insert to prevent race conditions + let should_add = { + let mut tracker = match tracked_pods.lock() { + Ok(tracker) => tracker, + Err(e) => { + error!("Failed to acquire tracked_pods lock: {}", e); + return; + } + }; + + if tracker.contains(pod_info) { + false // Already tracked + } else { + // Reserve the spot to prevent other threads from adding the same pod + tracker.insert(pod_info.clone()); + true + } + }; + + if should_add { info!( - "Healthy pod found: {}. Adding worker: {}", - pod_info.name, worker_url + "Healthy pod found: {} (type: {:?}). 
Adding worker: {}", + pod_info.name, pod_info.pod_type, worker_url ); - match router.add_worker(&worker_url).await { + + let result = if pd_mode && pod_info.pod_type.is_some() { + // Use PD-aware worker management + if let Some(pod_type) = &pod_info.pod_type { + router + .add_pd_worker(&worker_url, pod_type.clone(), pod_info.bootstrap_port) + .await + } else { + Err("Pod type is None in PD mode".to_string()) + } + } else { + // Fallback to regular worker management + router.add_worker(&worker_url).await + }; + + match result { Ok(msg) => { - info!("Router add_worker: {}", msg); - let mut tracker = tracked_pods.lock().unwrap(); - tracker.insert(pod_info.clone()); + info!("Successfully added worker: {}", msg); + } + Err(e) => { + error!("Failed to add worker {} to router: {}", worker_url, e); + // Remove from tracking since addition failed + if let Ok(mut tracker) = tracked_pods.lock() { + tracker.remove(pod_info); + } } - Err(e) => error!("Failed to add worker {} to router: {}", worker_url, e), } } - } else if already_tracked { - // If pod was healthy before but not anymore, remove it - handle_pod_deletion(pod_info, tracked_pods, router, port).await; } } @@ -254,22 +414,47 @@ async fn handle_pod_deletion( tracked_pods: Arc>>, router: Arc, port: u16, + pd_mode: bool, ) { let worker_url = pod_info.worker_url(port); - let mut tracked = tracked_pods.lock().unwrap(); - if tracked.remove(pod_info) { + let was_tracked = { + let mut tracked = match tracked_pods.lock() { + Ok(tracked) => tracked, + Err(e) => { + error!("Failed to acquire tracked_pods lock during deletion: {}", e); + return; + } + }; + tracked.remove(pod_info) + }; + + if was_tracked { info!( - "Pod deleted: {}. Removing worker: {}", - pod_info.name, worker_url + "Pod deleted: {} (type: {:?}). 
Removing worker: {}", + pod_info.name, pod_info.pod_type, worker_url ); - router.remove_worker(&worker_url); + + if pd_mode && pod_info.pod_type.is_some() { + // Use PD-aware worker removal + if let Some(pod_type) = &pod_info.pod_type { + if let Err(e) = router.remove_pd_worker(&worker_url, pod_type.clone()).await { + error!( + "Failed to remove PD worker {} from router: {}", + worker_url, e + ); + } + } + } else { + // Fallback to regular worker removal + router.remove_worker(&worker_url); + } } else { // This case might occur if a pod is deleted before it was ever marked healthy and added. // Or if the event is duplicated. No action needed on the router if it wasn't tracked (and thus not added). debug!( - "Pod deletion event for untracked/already removed pod: {}. Worker URL: {}", - pod_info.name, worker_url + "Pod deletion event for untracked/already removed pod: {} (type: {:?}). Worker URL: {}", + pod_info.name, pod_info.pod_type, worker_url ); } } @@ -325,6 +510,41 @@ mod tests { pod } + // Helper function to create a Pod with PD-specific labels and annotations + fn create_pd_k8s_pod(name: &str, ip: &str, pod_type: &str, bootstrap_port: Option) -> Pod { + let mut labels = std::collections::BTreeMap::new(); + labels.insert("app".to_string(), "sglang".to_string()); + labels.insert("component".to_string(), pod_type.to_string()); + + let mut annotations = std::collections::BTreeMap::new(); + if let Some(port) = bootstrap_port { + annotations.insert("sglang.ai/bootstrap-port".to_string(), port.to_string()); + } + + Pod { + metadata: ObjectMeta { + name: Some(name.to_string()), + labels: Some(labels), + annotations: Some(annotations), + ..Default::default() + }, + spec: Some(PodSpec::default()), + status: Some(PodStatus { + pod_ip: Some(ip.to_string()), + phase: Some("Running".to_string()), + conditions: Some(vec![PodCondition { + type_: "Ready".to_string(), + status: "True".to_string(), + last_probe_time: None, + last_transition_time: None, + message: None, + 
reason: None, + }]), + ..Default::default() + }), + } + } + // Helper to create a Router instance for testing event handlers fn create_test_router() -> Arc { let worker_urls = Arc::new(RwLock::new(Vec::new())); @@ -335,14 +555,80 @@ mod tests { }) } + + // Helper to create a PD config for testing + fn create_pd_config() -> ServiceDiscoveryConfig { + let mut prefill_selector = HashMap::new(); + prefill_selector.insert("app".to_string(), "sglang".to_string()); + prefill_selector.insert("component".to_string(), "prefill".to_string()); + + let mut decode_selector = HashMap::new(); + decode_selector.insert("app".to_string(), "sglang".to_string()); + decode_selector.insert("component".to_string(), "decode".to_string()); + + ServiceDiscoveryConfig { + enabled: true, + selector: HashMap::new(), + check_interval: Duration::from_secs(60), + port: 8080, + namespace: None, + pd_mode: true, + prefill_selector, + decode_selector, + bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(), + } + } + + #[test] + fn test_pod_info_should_include() { + let config = create_pd_config(); + + // Test prefill pod should be included + let prefill_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081)); + assert!(PodInfo::should_include(&prefill_pod, &config)); + + // Test decode pod should be included + let decode_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None); + assert!(PodInfo::should_include(&decode_pod, &config)); + + // Test unmatched pod should not be included + let unmatched_pod = create_pd_k8s_pod("other-pod", "10.0.0.3", "other", None); + assert!(!PodInfo::should_include(&unmatched_pod, &config)); + + // Test regular mode + let mut regular_config = ServiceDiscoveryConfig::default(); + regular_config + .selector + .insert("app".to_string(), "sglang".to_string()); + regular_config.pd_mode = false; + + let regular_pod = create_pd_k8s_pod("worker-pod", "10.0.0.4", "worker", None); + assert!(PodInfo::should_include(&regular_pod, &regular_config)); 
+ } + #[test] fn test_service_discovery_config_default() { let config = ServiceDiscoveryConfig::default(); assert!(!config.enabled); assert!(config.selector.is_empty()); assert_eq!(config.check_interval, Duration::from_secs(60)); - assert_eq!(config.port, 80); + assert_eq!(config.port, 8000); assert!(config.namespace.is_none()); + assert!(!config.pd_mode); + assert!(config.prefill_selector.is_empty()); + assert!(config.decode_selector.is_empty()); + assert_eq!(config.bootstrap_port_annotation, "sglang.ai/bootstrap-port"); + } + + #[test] + fn test_pod_type_enum() { + // Test that PodType enum has expected variants + let prefill = PodType::Prefill; + let decode = PodType::Decode; + let regular = PodType::Regular; + + assert_eq!(format!("{:?}", prefill), "Prefill"); + assert_eq!(format!("{:?}", decode), "Decode"); + assert_eq!(format!("{:?}", regular), "Regular"); } #[test] @@ -354,11 +640,85 @@ mod tests { Some("True"), None, ); - let pod_info = PodInfo::from_pod(&k8s_pod).unwrap(); + let pod_info = PodInfo::from_pod(&k8s_pod, None).unwrap(); assert_eq!(pod_info.name, "test-pod"); assert_eq!(pod_info.ip, "10.0.0.1"); assert_eq!(pod_info.status, "Running"); assert!(pod_info.is_ready); + assert!(pod_info.pod_type.is_none()); + assert!(pod_info.bootstrap_port.is_none()); + } + + #[test] + fn test_pod_info_from_pod_with_pd_config_prefill() { + let k8s_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081)); + let config = create_pd_config(); + + let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap(); + assert_eq!(pod_info.name, "prefill-pod"); + assert_eq!(pod_info.ip, "10.0.0.1"); + assert_eq!(pod_info.status, "Running"); + assert!(pod_info.is_ready); + assert_eq!(pod_info.pod_type, Some(PodType::Prefill)); + assert_eq!(pod_info.bootstrap_port, Some(8081)); + } + + #[test] + fn test_pod_info_from_pod_with_pd_config_decode() { + let k8s_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None); + let config = create_pd_config(); + + 
let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap(); + assert_eq!(pod_info.name, "decode-pod"); + assert_eq!(pod_info.ip, "10.0.0.2"); + assert_eq!(pod_info.status, "Running"); + assert!(pod_info.is_ready); + assert_eq!(pod_info.pod_type, Some(PodType::Decode)); + assert!(pod_info.bootstrap_port.is_none()); + } + + #[test] + fn test_pod_info_from_pod_with_pd_config_regular_mode() { + let k8s_pod = create_pd_k8s_pod("regular-pod", "10.0.0.3", "worker", None); + let mut config = create_pd_config(); + config.pd_mode = false; // Set to regular mode + + let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap(); + assert_eq!(pod_info.name, "regular-pod"); + assert_eq!(pod_info.ip, "10.0.0.3"); + assert_eq!(pod_info.status, "Running"); + assert!(pod_info.is_ready); + assert_eq!(pod_info.pod_type, Some(PodType::Regular)); + assert!(pod_info.bootstrap_port.is_none()); + } + + #[test] + fn test_pod_info_from_pod_with_pd_config_unmatched_labels() { + let k8s_pod = create_pd_k8s_pod("unknown-pod", "10.0.0.4", "unknown", None); + let config = create_pd_config(); + + let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap(); + assert_eq!(pod_info.name, "unknown-pod"); + assert_eq!(pod_info.ip, "10.0.0.4"); + assert_eq!(pod_info.status, "Running"); + assert!(pod_info.is_ready); + assert_eq!(pod_info.pod_type, Some(PodType::Regular)); + assert!(pod_info.bootstrap_port.is_none()); + } + + #[test] + fn test_pod_info_from_pod_with_pd_config_invalid_bootstrap_port() { + let mut pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", None); + // Add invalid bootstrap port annotation + pod.metadata.annotations.as_mut().unwrap().insert( + "sglang.ai/bootstrap-port".to_string(), + "invalid".to_string(), + ); + let config = create_pd_config(); + + let pod_info = PodInfo::from_pod(&pod, Some(&config)).unwrap(); + assert_eq!(pod_info.pod_type, Some(PodType::Prefill)); + assert!(pod_info.bootstrap_port.is_none()); // Should be None for invalid port } 
#[test] @@ -370,7 +730,7 @@ mod tests { Some("False"), None, ); - let pod_info = PodInfo::from_pod(&k8s_pod).unwrap(); + let pod_info = PodInfo::from_pod(&k8s_pod, None).unwrap(); assert!(!pod_info.is_ready); } @@ -383,26 +743,26 @@ mod tests { None, None, ); - let pod_info = PodInfo::from_pod(&k8s_pod).unwrap(); + let pod_info = PodInfo::from_pod(&k8s_pod, None).unwrap(); assert!(!pod_info.is_ready); } #[test] fn test_pod_info_from_pod_missing_name() { let k8s_pod = create_k8s_pod(None, Some("10.0.0.1"), Some("Running"), Some("True"), None); - assert!(PodInfo::from_pod(&k8s_pod).is_none()); + assert!(PodInfo::from_pod(&k8s_pod, None).is_none()); } #[test] fn test_pod_info_from_pod_missing_ip() { let k8s_pod = create_k8s_pod(Some("test-pod"), None, Some("Running"), Some("True"), None); - assert!(PodInfo::from_pod(&k8s_pod).is_none()); + assert!(PodInfo::from_pod(&k8s_pod, None).is_none()); } #[test] fn test_pod_info_from_pod_missing_status_phase() { let k8s_pod = create_k8s_pod(Some("test-pod"), Some("10.0.0.1"), None, Some("True"), None); - let pod_info = PodInfo::from_pod(&k8s_pod).unwrap(); + let pod_info = PodInfo::from_pod(&k8s_pod, None).unwrap(); assert_eq!(pod_info.status, "Unknown"); } @@ -410,7 +770,7 @@ mod tests { fn test_pod_info_from_pod_no_status_object() { let mut k8s_pod = create_k8s_pod(Some("test-pod"), None, None, None, None); k8s_pod.status = None; - assert!(PodInfo::from_pod(&k8s_pod).is_none()); + assert!(PodInfo::from_pod(&k8s_pod, None).is_none()); } #[test] @@ -420,6 +780,8 @@ mod tests { ip: "1.1.1.1".into(), status: "Running".into(), is_ready: true, + pod_type: None, + bootstrap_port: None, }; assert!(healthy_pod.is_healthy()); @@ -428,6 +790,8 @@ mod tests { ip: "1.1.1.2".into(), status: "Running".into(), is_ready: false, + pod_type: None, + bootstrap_port: None, }; assert!(!not_ready_pod.is_healthy()); @@ -436,6 +800,8 @@ mod tests { ip: "1.1.1.3".into(), status: "Pending".into(), is_ready: true, + pod_type: None, + bootstrap_port: 
None, }; assert!(!not_running_pod.is_healthy()); } @@ -447,10 +813,45 @@ mod tests { ip: "1.2.3.4".into(), status: "Running".into(), is_ready: true, + pod_type: None, + bootstrap_port: None, }; assert_eq!(pod_info.worker_url(8080), "http://1.2.3.4:8080"); } + #[test] + fn test_pod_info_equality_with_pod_type() { + let pod1 = PodInfo { + name: "pod1".into(), + ip: "1.2.3.4".into(), + status: "Running".into(), + is_ready: true, + pod_type: Some(PodType::Prefill), + bootstrap_port: Some(8081), + }; + + let pod2 = PodInfo { + name: "pod1".into(), + ip: "1.2.3.4".into(), + status: "Running".into(), + is_ready: true, + pod_type: Some(PodType::Prefill), + bootstrap_port: Some(8081), + }; + + let pod3 = PodInfo { + name: "pod1".into(), + ip: "1.2.3.4".into(), + status: "Running".into(), + is_ready: true, + pod_type: Some(PodType::Decode), + bootstrap_port: None, + }; + + assert_eq!(pod1, pod2); + assert_ne!(pod1, pod3); + } + #[tokio::test] async fn test_handle_pod_event_add_unhealthy_pod() { let router = create_test_router(); @@ -460,6 +861,8 @@ mod tests { ip: "1.2.3.4".into(), status: "Pending".into(), is_ready: false, + pod_type: None, + bootstrap_port: None, }; let port = 8080u16; @@ -468,6 +871,7 @@ mod tests { Arc::clone(&tracked_pods), Arc::clone(&router), port, + false, // pd_mode = false ) .await; @@ -488,6 +892,8 @@ mod tests { ip: "1.2.3.4".into(), status: "Running".into(), is_ready: true, + pod_type: None, + bootstrap_port: None, }; let port = 8080u16; @@ -496,10 +902,221 @@ mod tests { Arc::clone(&tracked_pods), Arc::clone(&router), port, + false, // pd_mode = false ) .await; assert!(tracked_pods.lock().unwrap().is_empty()); assert!(router.get_worker_urls().read().unwrap().is_empty()); } + + #[tokio::test] + async fn test_handle_pd_pod_event_prefill_pod() { + let router = create_test_router(); + let tracked_pods = Arc::new(Mutex::new(HashSet::new())); + let pod_info = PodInfo { + name: "prefill-pod".into(), + ip: "1.2.3.4".into(), + status: "Running".into(), 
+ is_ready: true, + pod_type: Some(PodType::Prefill), + bootstrap_port: Some(8081), + }; + let port = 8080u16; + + // This test validates the structure but won't actually add workers since + // we're using a regular router instead of PD router + handle_pod_event( + &pod_info, + Arc::clone(&tracked_pods), + Arc::clone(&router), + port, + false, // pd_mode = false, so it should fallback to regular handling + ) + .await; + + // Pod should not be tracked since router.add_worker will fail for non-running server + assert!(!tracked_pods.lock().unwrap().contains(&pod_info)); + } + + #[tokio::test] + async fn test_handle_pd_pod_event_decode_pod() { + let router = create_test_router(); + let tracked_pods = Arc::new(Mutex::new(HashSet::new())); + let pod_info = PodInfo { + name: "decode-pod".into(), + ip: "1.2.3.5".into(), + status: "Running".into(), + is_ready: true, + pod_type: Some(PodType::Decode), + bootstrap_port: None, + }; + let port = 8080u16; + + handle_pod_event( + &pod_info, + Arc::clone(&tracked_pods), + Arc::clone(&router), + port, + false, // pd_mode = false, so it should fallback to regular handling + ) + .await; + + // Pod should not be tracked since router.add_worker will fail for non-running server + assert!(!tracked_pods.lock().unwrap().contains(&pod_info)); + } + + #[tokio::test] + async fn test_handle_pd_pod_deletion_tracked_pod() { + let router = create_test_router(); + let tracked_pods = Arc::new(Mutex::new(HashSet::new())); + let pod_info = PodInfo { + name: "test-pod".into(), + ip: "1.2.3.4".into(), + status: "Running".into(), + is_ready: true, + pod_type: Some(PodType::Prefill), + bootstrap_port: Some(8081), + }; + + // Add pod to tracked set first + { + let mut tracked = tracked_pods.lock().unwrap(); + tracked.insert(pod_info.clone()); + } + + let port = 8080u16; + + handle_pod_deletion( + &pod_info, + Arc::clone(&tracked_pods), + Arc::clone(&router), + port, + false, // pd_mode = false + ) + .await; + + // Pod should be removed from tracking + 
assert!(!tracked_pods.lock().unwrap().contains(&pod_info)); + } + + #[tokio::test] + async fn test_handle_pd_pod_deletion_untracked_pod() { + let router = create_test_router(); + let tracked_pods = Arc::new(Mutex::new(HashSet::new())); + let pod_info = PodInfo { + name: "untracked-pod".into(), + ip: "1.2.3.4".into(), + status: "Running".into(), + is_ready: true, + pod_type: Some(PodType::Decode), + bootstrap_port: None, + }; + let port = 8080u16; + + // Don't add pod to tracked set + + handle_pod_deletion( + &pod_info, + Arc::clone(&tracked_pods), + Arc::clone(&router), + port, + true, // pd_mode = true + ) + .await; + + // Tracked set should remain empty + assert!(tracked_pods.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn test_unified_handler_regular_mode() { + let router = create_test_router(); + let tracked_pods = Arc::new(Mutex::new(HashSet::new())); + let pod_info = PodInfo { + name: "regular-pod".into(), + ip: "1.2.3.4".into(), + status: "Running".into(), + is_ready: true, + pod_type: Some(PodType::Regular), + bootstrap_port: None, + }; + let port = 8080u16; + + // Test that unified handler works for regular mode + handle_pod_event( + &pod_info, + Arc::clone(&tracked_pods), + Arc::clone(&router), + port, + false, // pd_mode = false + ) + .await; + + // Pod should not be tracked since router.add_worker will fail for non-running server + assert!(!tracked_pods.lock().unwrap().contains(&pod_info)); + } + + #[tokio::test] + async fn test_unified_handler_pd_mode_with_prefill() { + let router = create_test_router(); + let tracked_pods = Arc::new(Mutex::new(HashSet::new())); + let pod_info = PodInfo { + name: "prefill-pod".into(), + ip: "1.2.3.4".into(), + status: "Running".into(), + is_ready: true, + pod_type: Some(PodType::Prefill), + bootstrap_port: Some(8081), + }; + let port = 8080u16; + + // Test that unified handler works for PD mode with prefill + handle_pod_event( + &pod_info, + Arc::clone(&tracked_pods), + Arc::clone(&router), + port, + 
true, // pd_mode = true + ) + .await; + + // Pod should not be tracked since router.add_pd_worker will fail for regular router + assert!(!tracked_pods.lock().unwrap().contains(&pod_info)); + } + + #[tokio::test] + async fn test_unified_handler_deletion_with_pd_mode() { + let router = create_test_router(); + let tracked_pods = Arc::new(Mutex::new(HashSet::new())); + let pod_info = PodInfo { + name: "decode-pod".into(), + ip: "1.2.3.4".into(), + status: "Running".into(), + is_ready: true, + pod_type: Some(PodType::Decode), + bootstrap_port: None, + }; + + // Add pod to tracked set first + { + let mut tracked = tracked_pods.lock().unwrap(); + tracked.insert(pod_info.clone()); + } + + let port = 8080u16; + + // Test that unified handler works for deletion in PD mode + handle_pod_deletion( + &pod_info, + Arc::clone(&tracked_pods), + Arc::clone(&router), + port, + true, // pd_mode = true + ) + .await; + + // Pod should be removed from tracking + assert!(!tracked_pods.lock().unwrap().contains(&pod_info)); + } } diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 6d9019b431f6..5a1e65790355 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -5,7 +5,7 @@ //! - Phase 2: Bootstrap injection and request handling //! - Phase 3: Cache-aware selection (when implemented) //! -//! Note: PD mode is enabled via the pd_disaggregated flag, not as a policy type. +//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type. //! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode. 
#[cfg(test)] @@ -90,7 +90,7 @@ mod test_pd_routing { #[test] fn test_pd_selection_policies() { // Test all PD selection policy variants - // Note: These policies are only used when pd_disaggregated=true + // Note: These policies are only used when pd_disaggregation=true let policies = vec![ PDSelectionPolicy::Random, PDSelectionPolicy::PowerOfTwo, @@ -122,7 +122,7 @@ mod test_pd_routing { #[test] fn test_pd_router_configuration() { // Test PrefillDecodeConfig creation with various policies - // This config is used when pd_disaggregated=true + // This config is used when pd_disaggregation=true let configs = vec![ PolicyConfig::PrefillDecodeConfig { selection_policy: PDSelectionPolicy::Random, @@ -878,7 +878,7 @@ mod test_pd_routing { #[test] fn test_policy_type_to_pd_selection_policy_mapping() { // Document the mapping from PolicyType to PDSelectionPolicy - // This mapping happens in lib.rs when pd_disaggregated=true + // This mapping happens in lib.rs when pd_disaggregation=true // PolicyType::Random -> PDSelectionPolicy::Random // PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo