diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 235c9b07..026bdcd0 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -17,9 +17,7 @@ import logging import os import threading -from functools import wraps -from typing import Any, Callable, Optional -from uuid import uuid4 +from typing import Any, Optional import torch import zmq @@ -38,8 +36,7 @@ ZMQMessage, ZMQRequestType, ZMQServerInfo, - create_zmq_socket, - format_zmq_address, + dynamic_zmq_socket, ) logger = logging.getLogger(__name__) @@ -53,6 +50,13 @@ TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8)) +# Pre-bound decorator for controller socket operations. +_controller_socket = dynamic_zmq_socket( + "request_handle_socket", + owner_id_attr="client_id", + server_attr="_controller", +) + class AsyncTransferQueueClient: """Asynchronous client for interacting with TransferQueue controller and storage systems. @@ -99,63 +103,8 @@ def initialize_storage_manager( manager_type, controller_info=self._controller, config=config ) - # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. - @staticmethod - def dynamic_socket(socket_name: str): - """Decorator to auto-manage ZMQ sockets for Controller/Storage servers. - - Handles socket lifecycle: create -> connect -> inject -> close. - - Args: - socket_name: Port name from server config to use for ZMQ connection (e.g., "data_req_port") - - Decorated Function Requirements: - 1. Must be an async class method (needs `self`) - 2. `self` must have: - - `_controller`: Server registry - - `client_id`: Unique client ID for socket identity - 3. Receives ZMQ socket via `socket` keyword argument (injected by decorator) - """ - - def decorator(func: Callable): - @wraps(func) - async def wrapper(self, *args, **kwargs): - server_info = self._controller - if not server_info: - raise RuntimeError("No controller registered") - - context = zmq.asyncio.Context() - address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) - identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip) - - try: - sock.connect(address) - logger.debug( - f"[{self.client_id}]: Connected to Controller {server_info.id} at {address} " - f"with identity {identity.decode()}" - ) - - kwargs["socket"] = sock - return await func(self, *args, **kwargs) - except Exception as e: - logger.error(f"[{self.client_id}]: Error in socket operation with Controller {server_info.id}: {e}") - raise - finally: - try: - if not sock.closed: - sock.close(linger=-1) - except Exception as e: - logger.warning(f"[{self.client_id}]: Error closing socket to Controller {server_info.id}: {e}") - - context.term() - - return wrapper - - return decorator - # ==================== Basic API ==================== - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_get_meta( self, data_fields: list[str], @@ -245,7 +194,7 @@ async def async_get_meta( f"{response_msg.body.get('message', 'Unknown error')}" ) - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_set_custom_meta( self, metadata: BatchMeta, @@ -534,7 +483,7 @@ async def async_clear_samples(self, metadata: BatchMeta): except Exception as e: raise RuntimeError(f"Error in clear_samples operation: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): """Clear metadata in the controller. @@ -560,7 +509,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None): if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE: raise RuntimeError("Failed to clear samples metadata in controller.") - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta: """Get metadata required for the whole partition from controller. @@ -590,7 +539,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta return response_msg.body["metadata"] - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def _clear_partition_in_controller(self, partition_id, socket=None): """Clear the whole partition in the controller. @@ -617,7 +566,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None): raise RuntimeError(f"Failed to clear partition {partition_id} in controller.") # ==================== Status Query API ==================== - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_get_consumption_status( self, task_name: str, @@ -680,7 +629,7 @@ async def async_get_consumption_status( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_get_production_status( self, data_fields: list[str], @@ -812,7 +761,7 @@ async def async_check_production_status( return False return torch.all(production_status == 1).item() - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_reset_consumption( self, partition_id: str, @@ -874,7 +823,7 @@ async def async_reset_consumption( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_get_partition_list( self, socket: Optional[zmq.asyncio.Socket] = None, @@ -920,7 +869,7 @@ async def async_get_partition_list( raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e # ==================== KV Interface API ==================== - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_kv_retrieve_meta( self, keys: list[str] | str, @@ -986,7 +935,7 @@ async def async_kv_retrieve_meta( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_kv_retrieve_keys( self, global_indexes: list[int] | int, @@ -1049,7 +998,7 @@ async def async_kv_retrieve_keys( except Exception as e: raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_indexes: {str(e)}") from e - @dynamic_socket(socket_name="request_handle_socket") + @_controller_socket async def async_kv_list( self, partition_id: Optional[str] = None, diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 27e173c7..0eb21057 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -19,10 +19,8 @@ import warnings from collections import defaultdict from collections.abc import Mapping -from functools import wraps from operator import itemgetter -from typing import Any, Callable, NamedTuple -from uuid import uuid4 +from typing import Any, NamedTuple import torch import zmq @@ -36,8 +34,7 @@ ZMQMessage, ZMQRequestType, ZMQServerInfo, - create_zmq_socket, - format_zmq_address, + dynamic_zmq_socket, ) logger = logging.getLogger(__name__) @@ -51,6 +48,15 @@ TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds +# Pre-bound decorator for storage-unit socket operations. +_storage_unit_socket = dynamic_zmq_socket( + "put_get_socket", + owner_id_attr="storage_manager_id", + server_attr="storage_unit_infos", + target_kwarg="target_storage_unit", + timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT, +) + class RoutingGroup(NamedTuple): """Routing result for a single storage unit.""" @@ -114,78 +120,6 @@ def _register_servers(self, server_infos: "ZMQServerInfo | dict[Any, ZMQServerIn return server_infos_transform - # TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong. - @staticmethod - def dynamic_storage_manager_socket(socket_name: str, timeout: int): - """Decorator to auto-manage ZMQ sockets for Controller/Storage servers (create -> connect -> inject -> close). - - Args: - socket_name (str): Port name (from server config) to use for ZMQ connection (e.g., "data_req_port"). - timeout (float): Timeout in seconds for ZMQ connection (in seconds). - - Decorated Function Rules: - 1. Must be an async class method (needs `self`). - 2. `self` requires: - - `storage_unit_infos: storage unit infos (ZMQServerInfo | dict[Any, ZMQServerInfo]). - 3. Specify target server via: - - `target_storage_unit` arg. - 4. Receives ZMQ socket via `socket` keyword arg (injected by decorator). - """ - - def decorator(func: Callable): - @wraps(func) - async def wrapper(self, *args, **kwargs): - server_key = kwargs.get("target_storage_unit") - if server_key is None: - for arg in args: - if isinstance(arg, str) and arg in self.storage_unit_infos.keys(): - server_key = arg - break - - server_info = self.storage_unit_infos.get(server_key) - - if not server_info: - raise RuntimeError(f"Server {server_key} not found in registered servers") - - context = zmq.asyncio.Context() - address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name)) - identity = f"{self.storage_manager_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() - sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity) - - try: - sock.connect(address) - # Timeouts to avoid indefinite await on recv/send - sock.setsockopt(zmq.RCVTIMEO, timeout * 1000) - sock.setsockopt(zmq.SNDTIMEO, timeout * 1000) - logger.debug( - f"[{self.storage_manager_id}]: Connected to StorageUnit {server_info.id} at {address} " - f"with identity {identity.decode()}" - ) - - kwargs["socket"] = sock - return await func(self, *args, **kwargs) - except Exception as e: - logger.error( - f"[{self.storage_manager_id}]: Error in socket operation with " - f"StorageUnit {server_info.id} at {address}: " - f"{type(e).__name__}: {e}" - ) - raise - finally: - try: - if not sock.closed: - sock.close(linger=-1) - except Exception as e: - logger.warning( - f"[{self.storage_manager_id}]: Error closing socket to StorageUnit {server_info.id}: {e}" - ) - - context.term() - - return wrapper - - return decorator - def _group_by_hash(self, global_indexes: list[int]) -> dict[str, RoutingGroup]: """Group samples by global_idx % num_su, return {storage_id: RoutingGroup}. @@ -335,7 +269,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: field_schema, ) - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) + @_storage_unit_socket async def _put_to_single_storage_unit( self, global_indexes: list[int], @@ -456,7 +390,7 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: return TensorDict(tensor_data, batch_size=len(metadata)) - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) + @_storage_unit_socket async def _get_from_single_storage_unit( self, global_indexes: list[int], @@ -528,7 +462,7 @@ async def clear_data(self, metadata: BatchMeta) -> None: if isinstance(result, Exception): logger.error(f"[{self.storage_manager_id}]: Error in clear operation task {i}: {result}") - @dynamic_storage_manager_socket(socket_name="put_get_socket", timeout=TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT) + @_storage_unit_socket async def _clear_single_storage_unit(self, global_indexes, target_storage_unit=None, socket=None): try: request_msg = ZMQMessage.create( diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 8afbb480..17d749c6 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -17,13 +17,16 @@ import os import socket import time +from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Optional, TypeAlias +from functools import wraps +from typing import Any, Callable, Optional, TypeAlias from uuid import uuid4 import psutil import ray import zmq +import zmq.asyncio from ray.util import get_node_ip_address from transfer_queue.utils.enum_utils import ExplicitEnum, TransferQueueRole @@ -301,6 +304,95 @@ def create_zmq_socket( return socket +def dynamic_zmq_socket( + socket_name: str, + *, + owner_id_attr: str, + server_attr: str, + target_kwarg: Optional[str] = None, + timeout: Optional[int] = None, +): + """Create a reusable async decorator for request sockets. + + This decorator encapsulates the common socket lifecycle used by both + client-side and storage-manager-side request paths: + create context/socket -> connect -> inject socket -> close/term. + + Args: + socket_name: Socket port key in ``ZMQServerInfo.ports``. + owner_id_attr: Attribute name on ``self`` used in identity/log prefix + (e.g., ``client_id`` or ``storage_manager_id``). + server_attr: Attribute name on ``self`` that stores server info. + - ``ZMQServerInfo`` for single-target calls. + - ``Mapping[str, ZMQServerInfo]`` for multi-target calls. + target_kwarg: Optional kwarg name that provides target server id when + ``server_attr`` is a mapping. + timeout: Optional timeout (seconds) for both send/recv operations. + """ + + def decorator(func: Callable): + @wraps(func) + async def wrapper(self, *args, **kwargs): + owner_id = getattr(self, owner_id_attr, None) + if owner_id is None: + raise RuntimeError(f"Missing owner id attribute: {owner_id_attr}") + + server_obj = getattr(self, server_attr, None) + if server_obj is None: + raise RuntimeError(f"Missing server registry attribute: {server_attr}") + + target_name: Optional[str] = None + if target_kwarg is not None: + target_name = kwargs.get(target_kwarg) + if target_name is None: + for arg in args: + if isinstance(arg, str): + target_name = arg + break + + if isinstance(server_obj, ZMQServerInfo): + if target_name is not None and target_name != server_obj.id: + raise RuntimeError( + f"Target mismatch: target '{target_name}' does not match registered server '{server_obj.id}'" + ) + server_info = server_obj + elif isinstance(server_obj, Mapping): + if target_name is None: + raise RuntimeError(f"Missing target server identifier via '{target_kwarg}'") + server_info = server_obj.get(target_name) + if server_info is None: + raise RuntimeError(f"Server '{target_name}' not found in registered servers") + else: + raise RuntimeError(f"Unsupported server registry type for '{server_attr}': {type(server_obj).__name__}") + + port = server_info.ports.get(socket_name) + if port is None: + raise RuntimeError(f"Socket '{socket_name}' not configured for server '{server_info.id}'") + + context = zmq.asyncio.Context() + address = format_zmq_address(server_info.ip, port) + identity = f"{owner_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode() + sock = create_zmq_socket(context, zmq.DEALER, server_info.ip, identity=identity) + + try: + sock.connect(address) + if timeout is not None: + sock.setsockopt(zmq.RCVTIMEO, timeout * 1000) + sock.setsockopt(zmq.SNDTIMEO, timeout * 1000) + kwargs["socket"] = sock + return await func(self, *args, **kwargs) + finally: + try: + if not sock.closed: + sock.close(linger=-1) + finally: + context.term() + + return wrapper + + return decorator + + def process_zmq_server_info( handlers: dict[Any, Any] | Any, ): # noqa: UP007