[rollout] feat: enhancing global load balancer (issue #5442) #6059
EricWuji wants to merge 15 commits into verl-project:main
Conversation
…affinity

This commit introduces several enhancements to the `GlobalRequestLoadBalancer` and agent loop for improved load balancing. Key changes include:

- Added pluggable load balancing strategies: `least_requests`, `least_kv_cache`, `weighted_rr`, and `random`.
- Implemented group-level affinity for routing requests based on `group_id`, allowing for better session management.
- Updated configuration options in the rollout config to support the new load balancing features.

Additionally, tests have been updated to reflect these changes, ensuring robust functionality across the agent loop. This update aims to optimize server resource utilization and enhance the overall performance of the agent loop system.
Code Review
This pull request introduces pluggable load-balancing strategies for the GlobalRequestLoadBalancer in the agent loop, including support for least-requests, least-KV-cache, weighted round-robin, and random routing. It also adds group-level sticky routing and Prometheus-based metric scraping for KV cache usage. While the architectural changes improve flexibility, the current implementation of metric scraping is synchronous and blocking within the Ray actor, which could create a significant performance bottleneck. Additionally, the use of threading locks in a single-threaded Ray actor is currently redundant, and the error logging for failed metric fetches lacks throttling, potentially leading to log flooding.
```python
def _refresh_kv_metrics_if_needed(self) -> None:
    if self._strategy_name != "least_kv_cache":
        return
    if self._metric_name is None:
        return
    now = time.monotonic()
    if now - self._last_metrics_ts < self._kv_refresh_interval:
        return
    self._last_metrics_ts = now
    for sid in self._server_actor_ids:
        url = build_metrics_url(sid, self._metrics_path)
        try:
            text = fetch_prometheus_text(url, self._fetch_timeout)
            val = parse_prometheus_metric_value(text, self._metric_name)
            self._kv_usage[sid] = val
        except Exception as e:
            logger.warning("Failed to refresh KV metrics for %s from %s: %s", sid, url, e)
            self._kv_usage[sid] = None
```
The _refresh_kv_metrics_if_needed method performs synchronous, blocking HTTP requests to every rollout replica within the acquire_server call. In a Ray actor, this blocks the entire actor from processing any other requests (including release_server or other acquire_server calls) for the duration of all HTTP calls combined. If any replica is slow or unresponsive, the global load balancer becomes a critical bottleneck for the entire rollout process.
Consider moving the metrics collection to a background thread that periodically updates self._kv_usage asynchronously, or refactor the actor to be async and use asyncio.gather with an asynchronous HTTP client.
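A minimal sketch of the background-thread approach, using assumed names (`KVMetricsRefresher`, `refresh_fn`); in the real actor the instance would be owned by the load balancer and shared state would be guarded by a lock:

```python
import threading

class KVMetricsRefresher:
    """Runs refresh_fn periodically off the actor's main thread (sketch)."""

    def __init__(self, refresh_fn, interval_s: float):
        self._refresh_fn = refresh_fn  # e.g. scrapes all replicas' /metrics
        self._interval_s = interval_s
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self) -> None:
        self._thread.start()

    def _run(self) -> None:
        while not self._stop.is_set():
            self._refresh_fn()
            # wait() doubles as an interruptible sleep: returns True on stop()
            if self._stop.wait(timeout=self._interval_s):
                break

    def stop(self) -> None:
        self._stop.set()
        self._thread.join()
```

With this shape, `acquire_server` only reads the most recent `self._kv_usage` snapshot and never blocks on HTTP.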
```python
def acquire_server(self, request_id: str) -> str:
    ...

self._group_id_to_server: LRUCache = LRUCache(maxsize=max_cache_size)
self._lock = threading.Lock()
```
There was a problem hiding this comment.
Ray actors are single-threaded by default, processing tasks sequentially. The use of threading.Lock is redundant in this context and adds unnecessary overhead unless the actor explicitly spawns background threads that access shared state. If the metrics collection is moved to a background thread (as suggested in the critical finding), this lock would then become necessary; otherwise, it should be removed.
```python
            val = parse_prometheus_metric_value(text, self._metric_name)
            self._kv_usage[sid] = val
        except Exception as e:
            logger.warning("Failed to refresh KV metrics for %s from %s: %s", sid, url, e)
```
Logging a warning on every failed metric refresh attempt can lead to log flooding, especially since this is called frequently (every 2 seconds by default) and for every replica. If a replica is permanently down or misconfigured, the logs will be overwhelmed. Consider using a throttled logger or only logging the failure once per replica until it recovers.
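One way to implement the throttling, sketched with an assumed helper (`FailureTracker`) that mirrors the `_mark_metrics_failed` / `_mark_metrics_success` pattern this PR later adopts: warn on the healthy-to-failed transition only, and optionally log recovery once.

```python
class FailureTracker:
    """Tracks per-replica failure state so each failure cycle logs once."""

    def __init__(self) -> None:
        self._failed: set[str] = set()

    def mark_failed(self, sid: str) -> bool:
        """Return True only on the transition healthy -> failed."""
        if sid in self._failed:
            return False  # already known bad: stay quiet
        self._failed.add(sid)
        return True

    def mark_success(self, sid: str) -> bool:
        """Return True only on the transition failed -> healthy."""
        if sid in self._failed:
            self._failed.discard(sid)
            return True  # recovered: worth logging once
        return False
```

The caller then wraps the warning in `if tracker.mark_failed(sid): logger.warning(...)`.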
…lobalRequestLoadBalancer
/gemini review
Code Review
This pull request introduces a pluggable load balancing framework for the GlobalRequestLoadBalancer, supporting strategies such as least_requests, least_kv_cache, weighted_rr, and random. It also implements group-level sticky routing and infrastructure for scraping Prometheus metrics to inform routing decisions. Feedback highlights a potential AttributeError caused by the removal of extra_configs in the rollout configuration and suggests an improvement to the metrics scraper's logging logic to prevent log flooding during connectivity issues.
```python
class DiffusionSamplingConfig(SamplingConfig):
    noise_level: float = 0.0
    num_inference_steps: int = 40
    seed: int = 42
```
The removal of extra_configs from DiffusionSamplingConfig and DiffusionRolloutConfig (at line 385) will cause an AttributeError in verl/experimental/agent_loop/diffusion_agent_loop.py. That file still references this field in several places (e.g., lines 122, 155, and 165). If the intention was to replace extra_configs with specific fields like noise_level, then diffusion_agent_loop.py must be updated to use those fields. Otherwise, please restore the extra_configs field to maintain backward compatibility for diffusion models.
```python
        except Exception as e:
            self._mark_metrics_failed(sid)
            logger.warning("Failed to refresh KV metrics for %s from %s: %s", sid, url, e)
```
The background metrics scraper logs a warning on every failure. If a server is unreachable or the metrics endpoint is misconfigured, this will flood the logs every refresh_interval_s. It is recommended to use the return value of _mark_metrics_failed(sid) to ensure the warning is only logged once per failure cycle.
```diff
         except Exception as e:
-            self._mark_metrics_failed(sid)
-            logger.warning("Failed to refresh KV metrics for %s from %s: %s", sid, url, e)
+            if self._mark_metrics_failed(sid):
+                logger.warning("Failed to refresh KV metrics for %s from %s: %s", sid, url, e)
```
@EricWuji Could you do some comparison of the different load balance strategies?
@wuxibin89 Thanks for your reply. I've updated the PR description with comparison sections. You can find them at the end of the description.
/gemini review
Code Review
This pull request introduces a pluggable load-balancing framework for the GlobalRequestLoadBalancer, supporting strategies such as least requests, least KV cache usage, weighted round-robin, and random selection. It also implements group-level sticky routing to enhance prefix caching efficiency across worker groups. The review feedback highlights several critical areas for improvement: the Prometheus metric parser needs to be more robust to handle optional timestamps, the background metrics collection should be parallelized and wrapped in exception handling to prevent silent failures, and the weighted round-robin implementation should pre-compute mappings to avoid unnecessary overhead during request routing.
```python
def _last_float_on_line(line: str) -> Optional[float]:
    matches = _FLOAT_RE.findall(line)
    if not matches:
        return None
    try:
        return float(matches[-1])
    except ValueError:
        return None
```
The _last_float_on_line function is unreliable because the Prometheus text exposition format allows an optional timestamp after the value. Using findall and taking the last match will incorrectly return the timestamp (a large integer) as the metric value. Additionally, labels in Prometheus can contain spaces, which can break simple splitting. The value should be extracted as the first token following the metric name or the closing brace of the labels.
```python
def _last_float_on_line(line: str) -> Optional[float]:
    # Prometheus format: metric_name{labels} value [timestamp]
    if '}' in line:
        rest = line[line.rfind('}') + 1:]
    else:
        parts = line.split(None, 1)
        rest = parts[1] if len(parts) > 1 else ""
    tokens = rest.split()
    if tokens:
        try:
            return float(tokens[0])
        except ValueError:
            pass
    return None
```

```python
        for sid in self._server_actor_ids:
            url = build_metrics_url(sid, self._metrics_path)
            try:
                text = fetch_prometheus_text(url, self._fetch_timeout_s)
                val = parse_prometheus_metric_value(text, metric_name)
                if self._mark_metrics_success(sid, val):
                    logger.info("KV cache metrics recovered for replica %s", sid)
            except Exception as e:
                if self._mark_metrics_failed(sid):
                    logger.warning("Failed to refresh KV metrics for %s from %s: %s", sid, url, e)
```
Metrics for each replica are fetched sequentially in a blocking loop. In large-scale deployments, slow or unresponsive replicas can cause the background refresh thread to stall, leading to stale metrics for all servers. Consider using a thread pool or asynchronous requests to fetch metrics from all replicas in parallel.
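A self-contained sketch of the parallel variant; `refresh_all` and `fetch_one` are illustrative names, and per-replica failures are mapped to `None` rather than propagated out of the pool:

```python
from concurrent.futures import ThreadPoolExecutor

def refresh_all(server_ids, fetch_one, max_workers: int = 8):
    """Fetch metrics for all replicas in parallel.

    fetch_one(sid) returns a metric value or raises; the result is a
    {sid: value_or_None} snapshot, so one slow/down replica only delays
    its own entry up to its timeout instead of serializing the whole loop.
    """
    def safe_fetch(sid):
        try:
            return sid, fetch_one(sid)
        except Exception:
            return sid, None  # failure recorded, never escapes the pool
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return dict(pool.map(safe_fetch, server_ids))
```

A per-replica timeout inside `fetch_one` is still needed; the pool bounds concurrency, not latency.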
```python
    inflight: dict[str, int],
) -> str:
    # Selection ignores inflight; weights define share across replicas.
    pos = {sid: i for i, sid in enumerate(self._ids)}
```
```python
while not self._metrics_stop.is_set():
    self._refresh_metrics_blocking()
    if self._metrics_stop.wait(timeout=self._refresh_interval_s):
        break
```
/gemini review
There was a problem hiding this comment.
Code Review
This pull request implements a pluggable load balancing framework for the agent loop, introducing strategies such as least requests, least KV cache usage, weighted round-robin, and random selection. It also adds support for group-based sticky sessions to facilitate GRPO rollouts and includes Prometheus metric scraping for dynamic load assessment. The review feedback identifies a logic error in the weighted round-robin implementation, a discrepancy between the group sticky implementation and its documentation, an efficiency concern regarding thread pool instantiation, and a mismatch in the fallback behavior for missing metrics.
```python
        self._inflight[self._ids[i]] += self._weights[i]
    best_i = max(indices, key=lambda i: (self._inflight[self._ids[i]], -i))
    self._inflight[self._ids[best_i]] -= total
    return self._ids[best_i]
```
The WeightedRoundRobinStrategy incorrectly uses `self._inflight` as the accumulator for the smooth weighted round-robin algorithm. However, `self._inflight` is also used by the GlobalRequestLoadBalancer to track the actual number of active requests (via `update_inflight` and `release_server`).

This leads to two critical issues:
- The WRR state is corrupted by real-time request counts.
- Since WRR uses `max()` to pick the server with the highest accumulated weight, and GlobalRequestLoadBalancer increments `_inflight` for active requests, this strategy will actually prefer servers that are more heavily loaded, which is the opposite of load balancing.

You should use the `self._current` attribute (which is initialized but unused) for the WRR algorithm state.
```diff
-        self._inflight[self._ids[i]] += self._weights[i]
-    best_i = max(indices, key=lambda i: (self._inflight[self._ids[i]], -i))
-    self._inflight[self._ids[best_i]] -= total
-    return self._ids[best_i]
+    for i in indices:
+        self._current[i] += self._weights[i]
+    best_i = max(indices, key=lambda i: (self._current[i], -i))
+    self._current[best_i] -= total
```
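For reference, a self-contained sketch of smooth (nginx-style) weighted round-robin with the accumulator kept separate from in-flight counts; the class and method names here are illustrative, not the PR's actual API:

```python
class SmoothWeightedRR:
    """Smooth weighted round-robin over a fixed replica list (sketch)."""

    def __init__(self, ids, weights):
        self._ids = list(ids)
        self._weights = list(weights)
        self._current = [0.0] * len(ids)  # WRR accumulator, NOT request counts

    def pick(self) -> str:
        total = sum(self._weights)
        for i in range(len(self._ids)):
            self._current[i] += self._weights[i]
        # highest accumulated weight wins; lower index breaks ties
        best = max(range(len(self._ids)), key=lambda i: (self._current[i], -i))
        self._current[best] -= total
        return self._ids[best]
```

With weights `[5, 1, 1]` over replicas `a, b, c`, seven consecutive picks select `a` five times and `b`, `c` once each, interleaved rather than bursty.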
```python
def acquire_server(self, request_id: str, request_group_id: str | None = None) -> str:
    if request_group_id and request_group_id in self._request_group_to_server:
        server_id = self._request_group_to_server[request_group_id]
        self.update_inflight(server_id, 1)
        return server_id
    server_id = self.strategy.pick_server(list(self._server_actor_ids))
    if request_group_id:
        self._request_group_to_server[request_group_id] = server_id
    self.update_inflight(server_id, 1)
    return server_id
```
The implementation of GroupStickyLoadBalancer does not match the documentation in docs/advance/agent_loop.rst. The documentation states that routing should consult the request_id LRU first, then the request_group_id LRU.
Currently, GroupStickyLoadBalancer only tracks request_group_id. This means that if request_group_id is not provided (or if a request needs stickiness outside of a group context), it will fall back to a fresh pick from the strategy, losing request_id stickiness. To match the intended design, this class should maintain both LRU caches.
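A sketch of the documented lookup order (request_id cache first, then request_group_id cache); `resolve_sticky` is a hypothetical helper, and plain dicts stand in for the LRU caches:

```python
def resolve_sticky(request_id, request_group_id, req_cache, group_cache, pick):
    """Return a server id: request_id LRU first, then group LRU, then strategy."""
    sid = req_cache.get(request_id)
    if sid is None and request_group_id is not None:
        sid = group_cache.get(request_group_id)
    if sid is None:
        sid = pick()  # fall back to the configured strategy
    # record both stickiness levels so either key routes back here
    req_cache[request_id] = sid
    if request_group_id is not None:
        group_cache[request_group_id] = sid
    return sid
```

This preserves request-level stickiness even when a request arrives without a group id.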
```python
with ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="verl-lb-kv") as pool:
    list(pool.map(worker, ids))
```
```python
def key(sid: str) -> tuple[float, int, str]:
    kv = usage.get(sid)
    if kv is None:
        return (float("inf"), self._inflight[sid], sid)
    return (kv, self._inflight[sid], sid)
```
The fallback logic for failed metrics in LeastKVCacheStrategy does not match the documentation. The documentation states that it should fall back to 'least in-flight behavior', but the implementation uses float('inf') as the primary key for servers with missing metrics.
This causes servers with failed scrapes to be avoided even if they have zero in-flight requests and other servers are heavily loaded. To implement a true fallback, you should use a neutral value for missing metrics or adjust the sorting key so that in-flight counts are compared fairly across all servers when metrics are missing.
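One possible shape of that fallback, sketched as a free function with assumed parameters (`usage`, `inflight`): when any metric is missing, compare all servers by in-flight count so un-scraped replicas aren't penalized.

```python
def pick_least_kv(server_ids, usage, inflight):
    """Pick by KV usage; fall back to least in-flight when metrics are missing."""
    if any(usage.get(sid) is None for sid in server_ids):
        # fallback: plain least in-flight; server id as deterministic tie-break
        return min(server_ids, key=lambda sid: (inflight[sid], sid))
    return min(server_ids, key=lambda sid: (usage[sid], inflight[sid], sid))
```

Unlike the `float("inf")` key, a replica whose scrape failed can still win here if it genuinely has the least load.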
/gemini review
Code Review
This pull request introduces a pluggable and configurable global load balancing system for the agent-loop rollout. It adds support for multiple strategies—including least-requests, least-KV-cache (via Prometheus metrics scraping), weighted round-robin, and random—alongside sticky session modes for both individual requests and GRPO rollout groups. The changes include refactoring the load balancer into dedicated modules, updating configuration schemas, and adding comprehensive tests. Feedback was provided to ensure deterministic tie-breaking in the least-KV-cache strategy's fallback logic.
```python
has_none = any(usage.get(sid) is None for sid in server_ids)
if has_none:  # fallback to least in-flight
    return min(server_ids, key=lambda sid: self._inflight[sid])
```
For consistency with the LeastRequestsStrategy and the other branch of pick_server in this class, the fallback logic should include the server ID as a tie-breaker. This ensures deterministic behavior when multiple servers have the same in-flight request count.
```diff
-    return min(server_ids, key=lambda sid: self._inflight[sid])
+    return min(server_ids, key=lambda sid: (self._inflight[sid], sid))
```
/gemini review
Code Review
This pull request introduces a pluggable global load balancing system for the agent loop, supporting strategies such as least requests, least KV cache usage, weighted round-robin, and random selection. It also implements a "group sticky" mode to route related rollout requests to the same server replica. The changes include new strategy implementations, Prometheus metric parsing utilities, and updated configuration schemas. Feedback suggests refining the weight fallback logic in the round-robin strategy and enhancing HTTP error handling during metric scraping.
```python
    w = [1.0] * n
else:
    # if weight is not found, use 1.0 as default
    w = [resolve_load_balance_weight(weights, sid) or 1.0 for sid in self._ids]
```
The or 1.0 fallback here will incorrectly override a weight of 0.0 to 1.0. While RolloutConfig validation currently prevents non-positive weights, this strategy class should ideally respect a zero weight if it were to be used to drain a server, or at least handle it explicitly. If 0.0 is intended to be invalid, it's better to rely on the config validation or raise an error here rather than silently changing it to 1.0 which might be confusing if the user explicitly set a very low weight.
```diff
-    w = [resolve_load_balance_weight(weights, sid) or 1.0 for sid in self._ids]
+    w = [rw if (rw := resolve_load_balance_weight(weights, sid)) is not None else 1.0 for sid in self._ids]
```
```python
def fetch_prometheus_text(url: str, timeout_s: float = 2.0) -> str:
    """HTTP GET returning Prometheus text exposition body."""
    req = urllib.request.Request(url, method="GET")
    with urllib.request.urlopen(req, timeout=timeout_s) as resp:
        return resp.read().decode("utf-8", errors="replace")
```
The fetch_prometheus_text function uses urllib.request.urlopen without handling potential HTTP errors (e.g., 404, 500). While the caller LeastKVCacheStrategy._refresh_one_replica catches all exceptions, it would be more robust to handle urllib.error.HTTPError specifically to provide better error messages or handle specific status codes.
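A hedged sketch of that error handling; `fetch_prometheus_text_checked` is a hypothetical wrapper, not code from this PR. `HTTPError` must be caught before `URLError` since it is a subclass:

```python
import urllib.error
import urllib.request

def fetch_prometheus_text_checked(url: str, timeout_s: float = 2.0) -> str:
    """HTTP GET with explicit messages for HTTP-status vs transport failures."""
    req = urllib.request.Request(url, method="GET")
    try:
        with urllib.request.urlopen(req, timeout=timeout_s) as resp:
            return resp.read().decode("utf-8", errors="replace")
    except urllib.error.HTTPError as e:
        # non-2xx responses: surface the status code explicitly
        raise RuntimeError(f"metrics endpoint {url} returned HTTP {e.code}") from e
    except urllib.error.URLError as e:
        # DNS failure, connection refused, unknown scheme, etc.
        raise RuntimeError(f"metrics endpoint {url} unreachable: {e.reason}") from e
```

The caller's broad `except Exception` still works, but log lines now distinguish a 404 from a refused connection.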
/gemini review
Code Review
This pull request introduces a pluggable load balancing framework for the agent loop, supporting multiple strategies such as least requests, least KV cache usage (via Prometheus scraping), weighted round-robin, and random selection. It also adds a "group sticky" session mode to ensure that repeated rollouts for the same sample are routed to the same server instance, which is particularly useful for GRPO. The changes include new configuration parameters, comprehensive unit tests, and updated documentation. Review feedback suggests improving the encapsulation between the load balancer and its strategies by avoiding direct access to internal attributes like _inflight, and notes that the WeightedRoundRobinStrategy contains redundant state that is not utilized in its routing logic.
```python
def release_server(self, server_id: str) -> None:
    """Release a server after a request completes, decrementing strategy in-flight counts."""
    inflight = getattr(self.strategy, "_inflight", None)
```
Accessing _inflight via getattr on the strategy object breaks encapsulation and makes the plugin system fragile. If a user registers a custom strategy that doesn't use this specific attribute name for tracking load, release_server will silently fail to update the state.
Consider defining an explicit interface in the LoadBalanceStrategy base class (e.g., notify_acquire and notify_release methods) to handle in-flight count updates, or move the in-flight tracking logic entirely into the GlobalRequestLoadBalancer base class if it's intended to be a core feature for all strategies.
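A sketch of such an interface; the hook names `notify_acquire` / `notify_release` come from this review comment, while the base-class shape is assumed:

```python
from abc import ABC, abstractmethod

class LoadBalanceStrategy(ABC):
    @abstractmethod
    def pick_server(self, server_ids: list[str]) -> str: ...

    # default no-op hooks: strategies that track load override these
    def notify_acquire(self, server_id: str) -> None: ...
    def notify_release(self, server_id: str) -> None: ...

class LeastRequestsStrategy(LoadBalanceStrategy):
    def __init__(self, server_ids: list[str]):
        self._inflight = {sid: 0 for sid in server_ids}

    def pick_server(self, server_ids: list[str]) -> str:
        # fewest in-flight wins; server id breaks ties deterministically
        return min(server_ids, key=lambda sid: (self._inflight[sid], sid))

    def notify_acquire(self, server_id: str) -> None:
        self._inflight[server_id] += 1

    def notify_release(self, server_id: str) -> None:
        self._inflight[server_id] -= 1
```

The balancer then calls `self.strategy.notify_release(server_id)` unconditionally, and custom plugins that ignore load simply inherit the no-ops.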
```python
        ]
        self._weights = w
        self._current = [0.0] * n
        self._inflight: dict[str, int] = {sid: 0 for sid in self._ids}
```
The WeightedRoundRobinStrategy initializes self._inflight but never uses it in pick_server. While GlobalRequestLoadBalancer will still update these counts via release_server, the WRR strategy remains purely based on weights and doesn't account for actual server load. If the intention was to implement a 'Weighted Least Requests' strategy, pick_server should be updated to factor in self._inflight. If pure WRR is intended, the _inflight attribute is redundant and potentially misleading.
Summary
This PR is Task 2 (Phase 2) of the routing roadmap ([#5442](#5442)), building on the global `GlobalRequestLoadBalancer` introduced in Phase 1 ([#5399](#5399)). Phase 2 adds pluggable load-balance strategies, optional group-level sticky routing, and rollout-config wiring (strategy name, weights, random seed, KV-cache-aware routing hooks) so routing policy can be selected and tuned without forking core agent-loop code.
What does this PR do?
- Pluggable strategies (`load_balance.py`): `least_requests`, `least_kv_cache`, `weighted_rr`, `random`, with `create_load_balance_strategy` / `register_load_balance_strategy` for extension.
- Updates `GlobalRequestLoadBalancer` to delegate new-request picks to the selected strategy while preserving request-level stickiness; adds optional `group_sticky_routing` and `group_id` on `acquire_server` for affinity across related requests.
- Updates `AsyncLLMServerManager` (and related agent-loop paths) to pass worker / group identity into the load balancer so group affinity works end-to-end.
- New rollout config options: `load_balance_strategy`, `load_balance_weights`, `load_balance_random_seed`, `group_sticky_routing`, `num_load_balance_groups`, plus KV cache metrics config for strategies that use server-side usage (e.g. `least_kv_cache`).
- Extends `tests/experimental/agent_loop/test_basic_agent_loop.py` for new strategies and group/request stickiness behavior.

Related: [#5442](#5442), Phase 1 [#5399](#5399).
Checklist Before Starting
https://github.com/verl-project/verl/issues?q=is%3Apr+load+balance+agent+loop (#5399)

API and Usage Example
Rollout / Hydra overrides (illustrative):
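A minimal illustrative set of overrides, assuming the keys land under the rollout config as listed above (the nesting and exact paths are assumptions, not verified against the schema):

```yaml
# illustrative only: key names from this PR; nesting/paths are assumed
rollout:
  load_balance_strategy: least_kv_cache   # least_requests | weighted_rr | random
  group_sticky_routing: true
  load_balance_random_seed: 0
  load_balance_weights: null              # only used by weighted_rr
```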
Programmatic registration (custom strategy plugin):
Design & Code Changes
- `verl/experimental/agent_loop/load_balance.py`: strategy implementations plus `create_load_balance_strategy`.
- `verl/experimental/agent_loop/agent_loop.py`: `GlobalRequestLoadBalancer` strategy + group sticky; `build_global_load_balancer_remote_kwargs` from `RolloutConfig`; `AsyncLLMServerManager` passes `group_id`.
- `verl/trainer/config/rollout/rollout.yaml` + worker config
- `tests/experimental/agent_loop/test_basic_agent_loop.py`

High level: Phase 1 centralized routing in one Ray actor; Phase 2 makes the policy behind `acquire_server` configurable and adds group affinity without duplicating LB logic per worker.

Load balance strategy comparison
| Strategy | How it routes | Key config |
| --- | --- | --- |
| `least_requests` | picks the replica with the fewest in-flight requests (`inflight` ordering) | — |
| `least_kv_cache` | scrapes KV cache usage from `http://{replica_addr}{metrics_path}`; falls back to the `inflight` list if the metric fetch fails | `metric_name`, `metrics_path`, `refresh_interval_s` |
| `weighted_rr` | weighted round-robin over static weights | `load_balance_weights` (length must match `#servers_address`) |
| `random` | uniform random pick over `server_ids` | `load_balance_random_seed` |

Sticky Strategy Comparison