diff --git a/sgl-model-gateway/e2e_test/infra/constants.py b/sgl-model-gateway/e2e_test/infra/constants.py index f7f2c6fc3c34..86fffe166511 100644 --- a/sgl-model-gateway/e2e_test/infra/constants.py +++ b/sgl-model-gateway/e2e_test/infra/constants.py @@ -61,7 +61,9 @@ class Runtime(str, Enum): # Model loading configuration INITIAL_GRACE_PERIOD = 30 # Wait before first health check (model loading time) -LAUNCH_STAGGER_DELAY = 5 # Delay between launching multiple workers +LAUNCH_STAGGER_DELAY = ( + 10 # Delay between launching multiple workers (avoid I/O contention) +) # Retry configuration MAX_RETRY_ATTEMPTS = ( diff --git a/sgl-model-gateway/e2e_test/infra/model_pool.py b/sgl-model-gateway/e2e_test/infra/model_pool.py index c1ef702b43c4..d5d083ec536a 100644 --- a/sgl-model-gateway/e2e_test/infra/model_pool.py +++ b/sgl-model-gateway/e2e_test/infra/model_pool.py @@ -4,6 +4,7 @@ import logging import os +import signal import subprocess import threading import time @@ -242,19 +243,39 @@ def _grpc_health_check(self, timeout: float = 5.0) -> bool: return False def terminate(self, timeout: float = 10.0) -> None: - """Terminate the model server process.""" + """Terminate the model server process and all child processes. + + Since workers are started with start_new_session=True, they run in their + own process group. We must kill the entire process group to ensure child + processes (e.g., TP workers) are also terminated and GPU memory is freed. + """ if self.process.poll() is not None: return # Already terminated - logger.info("Terminating %s (PID %d)", self.key, self.process.pid) + pid = self.process.pid + logger.info("Terminating %s (PID %d)", self.key, pid) + + # Try graceful shutdown of the entire process group first + try: + pgid = os.getpgid(pid) + os.killpg(pgid, signal.SIGTERM) + except (ProcessLookupError, OSError) as e: + logger.debug("Could not send SIGTERM to process group: %s", e) + # Fall back to terminating just the main process + self.process.terminate() - # Try graceful shutdown first - self.process.terminate() try: self.process.wait(timeout=timeout) except subprocess.TimeoutExpired: - logger.warning("%s did not terminate, killing", self.key) - self.process.kill() + logger.warning("%s did not terminate, killing process group", self.key) + # Force kill the entire process group + try: + pgid = os.getpgid(pid) + os.killpg(pgid, signal.SIGKILL) + except (ProcessLookupError, OSError) as e: + logger.debug("Could not send SIGKILL to process group: %s", e) + self.process.kill() + try: self.process.wait(timeout=5) # Brief timeout after kill except subprocess.TimeoutExpired: