Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2115,13 +2115,37 @@ OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4317
# Maximum concurrent health checks per worker (default: 10)
# MAX_CONCURRENT_HEALTH_CHECKS=10

# Enable automatic tools/prompts/resources refresh from the mcp servers during health checks (default: false)
# If the tools/prompts/resources in the mcp servers are not updated frequently, it is recommended to keep this disabled to reduce load on the servers
# AUTO_REFRESH_SERVERS=false
# -----------------------------------------------------------------------------
# Auto-Refresh / Polling (requires health checks above)
# -----------------------------------------------------------------------------
# Re-fetch tools, prompts, and resources from downstream MCP servers on each
# health-check cycle. Enabled by default; set AUTO_REFRESH_SERVERS=false to
# disable and reduce load on servers whose tool lists are static.
#
AUTO_REFRESH_SERVERS=true
GATEWAY_AUTO_REFRESH_INTERVAL=300  # interval in seconds (minimum 60 recommended for production)

# Default refresh interval in seconds for gateway tools/resources/prompts sync
# Minimum: 60 seconds
# GATEWAY_AUTO_REFRESH_INTERVAL=300
# =============================================================================
# Hot/Cold Server Classification
# =============================================================================
# Enable hot/cold server classification to stagger polling based on upstream
# MCP session usage. Classifies servers into hot (top 20% by recent usage) and
# cold (remaining 80%), with independent polling intervals for each group.
# Requires Redis for multi-worker coordination. Falls back gracefully in
# single-worker mode (make dev).
#
# Default: true (enabled)
# HOT_COLD_CLASSIFICATION_ENABLED=true

# DEPRECATED: Hot/cold server intervals are now auto-derived from GATEWAY_AUTO_REFRESH_INTERVAL
# Hot servers (top 20%): polled at 1x GATEWAY_AUTO_REFRESH_INTERVAL
# Cold servers (80%): polled at 3x GATEWAY_AUTO_REFRESH_INTERVAL
# HOT_SERVER_CHECK_INTERVAL=300 # REMOVED - auto-derived (1x base interval)
# COLD_SERVER_CHECK_INTERVAL=300 # REMOVED - auto-derived (3x base interval)

# DEPRECATED: Classification refresh now uses GATEWAY_AUTO_REFRESH_INTERVAL
# (No separate config needed - classification runs at same interval as gateway refresh)
# SERVER_CLASSIFICATION_REFRESH_INTERVAL=120 # REMOVED - use GATEWAY_AUTO_REFRESH_INTERVAL

# File lock name for gateway service leader election
# Used to coordinate multiple gateway instances when running in cluster mode
Expand Down
11 changes: 11 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@ services:
- GATEWAY_VALIDATION_TIMEOUT=5
# Max concurrent health checks per worker (default: 10)
- MAX_CONCURRENT_HEALTH_CHECKS=10
# Auto-refresh tools/resources/prompts from MCP servers during health checks
# SSE gateways use polling only; StreamableHTTP gateways with session pooling
# also receive push notifications via notifications/tools/list_changed
- AUTO_REFRESH_SERVERS=true
- GATEWAY_AUTO_REFRESH_INTERVAL=300
# Hot/cold server classification - Stagger polling based on upstream MCP session usage
# Classifies servers into hot (top 20% by recent usage) and cold (remaining 80%)
# Hot servers polled at GATEWAY_AUTO_REFRESH_INTERVAL (1x), cold at 3x (auto-derived)
- HOT_COLD_CLASSIFICATION_ENABLED=true
# HOT_SERVER_CHECK_INTERVAL and COLD_SERVER_CHECK_INTERVAL are auto-derived (no config needed)
# Classification refresh uses GATEWAY_AUTO_REFRESH_INTERVAL (no separate config needed)
# JWT Configuration - Choose ONE approach:
# Option 1: HMAC (Default - Simple deployments)
- JWT_ALGORITHM=HS256
Expand Down
79 changes: 76 additions & 3 deletions mcpgateway/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,7 +1650,7 @@ def parse_issuers(cls, v: Any) -> list[str]:

# Health Checks
# Interval in seconds between health checks (aligned with mcp_session_pool_health_check_interval)
health_check_interval: int = 60
health_check_interval: int = 300
# Timeout in seconds for each health check request
health_check_timeout: int = 5
# Per-check timeout (seconds) to bound total time of one gateway health check
Expand All @@ -1663,11 +1663,25 @@ def parse_issuers(cls, v: Any) -> list[str]:

# Auto-refresh tools/resources/prompts from gateways during health checks
# When enabled, tools/resources/prompts are fetched and synced with DB during health checks
auto_refresh_servers: bool = Field(default=False, description="Enable automatic tool/resource/prompt refresh during gateway health checks")
auto_refresh_servers: bool = Field(default=True, description="Enable automatic tool/resource/prompt refresh during gateway health checks")

# Per-gateway refresh configuration (used when auto_refresh_servers is True)
# Gateways can override this with their own refresh_interval_seconds
gateway_auto_refresh_interval: int = Field(default=300, ge=60, description="Default refresh interval in seconds for gateway tools/resources/prompts sync (minimum 60 seconds)")
gateway_auto_refresh_interval: int = Field(
default=300, ge=1, description="Default refresh interval in seconds for gateway tools/resources/prompts sync (minimum 60 seconds recommended for production)"
)

# Staggered polling configuration (default behavior - replaces naive "fire all at once" health checks)
# Gateways are assigned deterministic offsets using index-based linear distribution: offset = (i/N) × interval
# This spreads load uniformly: 2000 gateways @ 600s interval = one poll every 0.3s (no spikes)
# tick_interval and tolerance are auto-derived from gateway_auto_refresh_interval (see @property methods below)
staggered_polling_enabled: bool = Field(default=True, description="Enable staggered polling (default: True). Set to False only for emergency rollback to naive polling.")

# Hot/Cold Server Classification (staggered polling optimization)
# Classify servers by usage (hot = active sessions, cold = inactive) for optimized polling
# Poll intervals auto-derived: hot = gateway_auto_refresh_interval (1x), cold = 3x
# Classification refresh uses gateway_auto_refresh_interval (no separate config needed)
hot_cold_classification_enabled: bool = Field(default=True, description="Enable hot/cold server classification for staggered polling (requires Redis for multi-worker)")

# Validation Gateway URL
gateway_validation_timeout: int = 5 # seconds
Expand Down Expand Up @@ -1879,6 +1893,65 @@ def custom_well_known_files(self) -> Dict[str, str]:
logger.error(f"Invalid JSON in WELL_KNOWN_CUSTOM_FILES: {self.well_known_custom_files}")
return {}

@property
def staggered_polling_tick_interval(self) -> float:
    """Auto-scale the polling loop tick from ``gateway_auto_refresh_interval``.

    The tick is 5% of the refresh interval (``interval / 20``), clamped to
    the inclusive range [1.0, 30.0] seconds. Short intervals therefore get
    fine wake-up granularity while very long intervals never sleep more
    than 30 seconds between scheduling checks.

    Examples:
        - 60s interval   -> 3.0s tick
        - 300s interval  -> 15.0s tick
        - 600s interval  -> 30.0s tick (upper clamp)
        - 3600s interval -> 30.0s tick (upper clamp)

    Returns:
        float: Tick interval in seconds, between 1.0 and 30.0.
    """
    raw_tick = self.gateway_auto_refresh_interval / 20
    if raw_tick < 1.0:
        return 1.0
    if raw_tick > 30.0:
        return 30.0
    return raw_tick

@property
def staggered_polling_tolerance(self) -> float:
    """Width of the ±window used to decide that a gateway is due for polling.

    Kept equal to the tick interval so every gateway's scheduled slot is hit
    on some tick: a wider window batches more gateways per wake-up, trading
    timing precision for batch efficiency.

    Returns:
        float: Tolerance window in seconds (same value as
        ``staggered_polling_tick_interval``).
    """
    tolerance_window = self.staggered_polling_tick_interval
    return tolerance_window

@property
def hot_server_check_interval(self) -> float:
    """Polling interval for hot servers, derived from the gateway refresh interval.

    Hot servers (the top 20% by recent usage) are checked at exactly the
    gateway tool-refresh rate, i.e. 1x ``gateway_auto_refresh_interval``.

    Returns:
        float: Seconds between hot-server checks (equals
        ``gateway_auto_refresh_interval``).
    """
    base_interval = self.gateway_auto_refresh_interval
    return float(base_interval)

@property
def cold_server_check_interval(self) -> float:
    """Polling interval for cold servers (3x the gateway refresh interval).

    Cold servers (the remaining 80% by recent usage) are checked a third as
    often as hot ones to reduce load on rarely-used upstreams.

    Examples:
        - gateway_auto_refresh_interval=300s -> 900s (15 minutes)
        - gateway_auto_refresh_interval=60s  -> 180s (3 minutes)

    Returns:
        float: Seconds between cold-server checks
        (3x ``gateway_auto_refresh_interval``).
    """
    return 3.0 * self.gateway_auto_refresh_interval

@field_validator("well_known_security_txt_enabled", mode="after")
@classmethod
def _auto_enable_security_txt(cls, v: Any, info: ValidationInfo) -> bool:
Expand Down
161 changes: 161 additions & 0 deletions mcpgateway/services/gateway_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ def __init__(self) -> None:
self._active_gateways: Set[str] = set() # Track active gateway URLs
self._stream_response = None
self._pending_responses = {}
# Hot/cold server classification service (initialized in initialize())
self._classification_service: Optional[Any] = None
# Prefer using the globally-initialized singletons from the service modules
# so events propagate via their initialized EventService/Redis clients.
# Import lazily and fall back to creating local instances when the module-level
Expand Down Expand Up @@ -592,6 +594,15 @@ async def initialize(self) -> None:
# Always create the health check task in filelock mode; leader check is handled inside.
self._health_check_task = asyncio.create_task(self._run_health_checks(user_email))

# Initialize hot/cold classification service (if enabled)
if settings.hot_cold_classification_enabled:
# First-Party
from mcpgateway.services.server_classification_service import ServerClassificationService

self._classification_service = ServerClassificationService(redis_client=self._redis_client)
await self._classification_service.start()
logger.info("Hot/cold classification service initialized")

async def shutdown(self) -> None:
"""Shutdown the service.

Expand All @@ -615,6 +626,11 @@ async def shutdown(self) -> None:
except asyncio.CancelledError:
pass

# Stop classification service
if self._classification_service:
await self._classification_service.stop()
logger.info("Classification service stopped")

# Cancel leader heartbeat task if running
if getattr(self, "_leader_heartbeat_task", None):
self._leader_heartbeat_task.cancel()
Expand Down Expand Up @@ -3333,6 +3349,17 @@ async def _check_single_gateway_health(self, gateway: DbGateway, user_email: Opt
# Sanitize URL for logging/telemetry (redacts sensitive query params)
gateway_url_sanitized = sanitize_url_for_logging(gateway_url, auth_query_params_decrypted)

# Hot/cold classification: Check if this server should be health-checked now
if self._classification_service:
try:
should_check = await self._classification_service.should_poll_server(gateway_url, "health")
if not should_check:
logger.debug(f"Skipping health check for {SecurityValidator.sanitize_log_message(gateway_name)}: " f"not yet due based on hot/cold classification")
return
except Exception as e:
# Fail open: proceed with health check if classification check fails
logger.warning(f"Classification check failed for {gateway_name}, proceeding with health check (fail-open): {e}")

# Create span for individual gateway health check
with create_span(
"gateway.health_check",
Expand Down Expand Up @@ -3529,7 +3556,22 @@ def get_httpx_client_factory(
logger.warning(f"Failed to update last_seen for gateway {gateway_name}: {update_error}")

# Auto-refresh tools/resources/prompts if enabled
should_auto_refresh = False
if settings.auto_refresh_servers:
# Hot/cold classification: Check if this server should have tools refreshed now
if self._classification_service:
try:
should_auto_refresh = await self._classification_service.should_poll_server(gateway_url, "tools")
if not should_auto_refresh:
logger.debug(f"Skipping auto-refresh for {SecurityValidator.sanitize_log_message(gateway_name)}: " f"not yet due based on hot/cold classification")
except Exception as e:
# Fail open: proceed with auto-refresh if classification check fails
logger.warning(f"Classification check failed for {gateway_name}, proceeding with auto-refresh (fail-open): {e}")
should_auto_refresh = True
else:
should_auto_refresh = True

if should_auto_refresh:
try:
# Throttling: Check if refresh is needed based on last_refresh_at
refresh_needed = True
Expand Down Expand Up @@ -3930,6 +3972,125 @@ async def _run_leader_heartbeat(self) -> None:
logger.warning(f"Leader heartbeat error: {e}")
# Continue trying - the main health check loop will handle leadership loss

def _calculate_gateway_poll_offset(self, gateway_id: uuid.UUID) -> float:
    """Deterministic poll offset for a gateway within the health-check interval.

    The gateway's ``UUID.int`` value (a stable 128-bit number, independent of
    PYTHONHASHSEED) is reduced modulo the interval, so the same gateway always
    lands on the same slot and a population of gateways spreads uniformly
    across the interval.

    Args:
        gateway_id: UUID of the gateway to place in the schedule.

    Returns:
        float: Offset in seconds, in ``[0, health_check_interval)``.

    Examples:
        >>> service = GatewayService()
        >>> service._health_check_interval = 60.0
        >>> gw_id = uuid.UUID('12345678-1234-5678-1234-567812345678')
        >>> offset = service._calculate_gateway_poll_offset(gw_id)
        >>> 0 <= offset < 60.0
        True
        >>> # Deterministic: same gateway always gets the same offset
        >>> offset == service._calculate_gateway_poll_offset(gw_id)
        True
    """
    # UUID.int gives a stable integer; modulo maps it into the interval.
    slot = gateway_id.int % int(self._health_check_interval)
    return float(slot)

def _should_poll_gateway_now(self, gateway: "DbGateway", current_time: float) -> bool:
    """Check whether a gateway is due for polling at ``current_time``.

    In naive mode (``settings.staggered_polling_enabled`` is False) every
    gateway is polled on every cycle, so this always returns True. In
    staggered mode a gateway is due only when ``current_time`` lies within
    ``settings.staggered_polling_tolerance`` seconds of the gateway's
    deterministic offset inside the repeating health-check interval.

    Bug fix: the distance to the scheduled slot is now computed circularly
    (``min(d, interval - d)``) instead of linearly. The linear form missed
    gateways whose tolerance window straddles an interval boundary — e.g.
    a gateway with offset 1s checked 1s before the cycle rolls over saw a
    distance of ``interval - 2`` instead of 2 and was skipped for a tick.

    Also removed the former doctest that assigned to
    ``settings.staggered_polling_tolerance``: that attribute is a read-only
    @property, so the example raised AttributeError when executed.

    Args:
        gateway: Gateway database object; its ``gateway_id`` determines the
            scheduled offset.
        current_time: Current timestamp (seconds since epoch).

    Returns:
        bool: True if the gateway should be polled now, False otherwise.
    """
    # Naive mode: poll all gateways on every cycle.
    if not settings.staggered_polling_enabled:
        return True

    # Deterministic per-gateway offset within [0, interval).
    offset = self._calculate_gateway_poll_offset(gateway.gateway_id)

    # Position of the current time within the repeating interval.
    interval = self._health_check_interval
    time_into_cycle = current_time % interval

    # Circular distance between "now" and the gateway's slot; taking the
    # shorter of the two arc lengths handles windows that wrap the boundary.
    linear_diff = abs(time_into_cycle - offset)
    circular_diff = min(linear_diff, interval - linear_diff)
    return circular_diff <= settings.staggered_polling_tolerance

async def _check_is_leader(self) -> bool:
    """Determine whether this instance currently leads health checking.

    Consolidates leader election across the three deployment modes:
    Redis (compare the stored leader id with our own), single-worker
    (``cache_type == "none"``: trivially the leader), and FileLock
    (a non-blocking acquire attempt decides).

    Returns:
        bool: True if this instance should run health checks, False otherwise.
    """
    # Redis mode: the stored leader id (decode_responses=True yields str)
    # must match our own instance id.
    if self._redis_client and settings.cache_type == "redis":
        leader_id = await self._redis_client.get(self._leader_key)
        return leader_id == self._instance_id

    # No cache backend configured: single worker, always the leader.
    if settings.cache_type == "none":
        return True

    # FileLock mode: a non-blocking acquire (timeout=0) succeeds only for
    # the leader; Timeout means another instance holds the lock.
    try:
        self._file_lock.acquire(timeout=0)
    except Timeout:
        return False
    return True

async def _run_health_checks(self, user_email: str) -> None:
"""Run health checks periodically,
Uses Redis or FileLock - for multiple workers.
Expand Down
Loading
Loading