Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions transfer_queue/storage/managers/simple_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from transfer_queue.storage.managers.base import TransferQueueStorageManager
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
from transfer_queue.storage.simple_backend import StorageMetaGroup
from transfer_queue.utils.utils import get_env_bool
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket

logger = logging.getLogger(__name__)
Expand All @@ -44,6 +45,8 @@
TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_RECV_TIMEOUT", 200)) # seconds
TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_MANAGER_SEND_TIMEOUT", 200)) # seconds

TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False)


@TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager")
class AsyncSimpleStorageManager(TransferQueueStorageManager):
Expand Down Expand Up @@ -236,7 +239,7 @@ async def _put_to_single_storage_unit(
"""

request_msg = ZMQMessage.create(
request_type=ZMQRequestType.PUT_DATA,
request_type=ZMQRequestType.PUT_DATA, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
receiver_id=target_storage_unit,
body={"local_indexes": local_indexes, "data": storage_data},
Expand Down Expand Up @@ -331,7 +334,7 @@ async def _get_from_single_storage_unit(
fields = storage_meta_group.get_field_names()

request_msg = ZMQMessage.create(
request_type=ZMQRequestType.GET_DATA,
request_type=ZMQRequestType.GET_DATA, # type: ignore[arg-type]
sender_id=self.storage_manager_id,
receiver_id=target_storage_unit,
body={"local_indexes": local_indexes, "fields": fields},
Expand Down Expand Up @@ -452,6 +455,10 @@ def _filter_storage_data(storage_meta_group: StorageMetaGroup, data: TensorDict)
result = (result,)
results[fname] = list(result)

if not TQ_ZERO_COPY_SERIALIZATION:
# Explicitly clone tensor slices: pickling a slice (view) serializes the parent tensor's whole storage, once per storage unit.
# .contiguous() cannot be used for this, because a slice that is already contiguous is returned as-is (no copy) and still shares the parent's storage.
results[fname] = [item.clone() if isinstance(item, torch.Tensor) else item for item in results[fname]]
return results


Expand Down
52 changes: 31 additions & 21 deletions transfer_queue/utils/zmq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import os
import pickle
import socket
import time
from dataclasses import dataclass
Expand All @@ -28,6 +29,7 @@
from transfer_queue.utils.utils import (
ExplicitEnum,
TransferQueueRole,
get_env_bool,
)

logger = logging.getLogger(__name__)
Expand All @@ -42,6 +44,8 @@

bytestr: TypeAlias = bytes | bytearray | memoryview

TQ_ZERO_COPY_SERIALIZATION = get_env_bool("TQ_ZERO_COPY_SERIALIZATION", default=False)


class ZMQRequestType(ExplicitEnum):
"""
Expand Down Expand Up @@ -155,36 +159,42 @@ def create(

def serialize(self) -> list:
"""
Serialize message using unified MsgpackEncoder.
Returns: list[bytestr] - [msgpack_header, *tensor_buffers]
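Serialize message using the unified MsgpackEncoder (when TQ_ZERO_COPY_SERIALIZATION is enabled) or pickle (otherwise).
Returns: list[bytestr] - [msgpack_header, *tensor_buffers] in msgpack mode, or a single-frame [pickled_message] in pickle mode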
"""
msg_dict = {
"request_type": self.request_type.value, # Enum -> str for msgpack
"sender_id": self.sender_id,
"receiver_id": self.receiver_id,
"request_id": self.request_id,
"timestamp": self.timestamp,
"body": self.body,
}
return list(_encoder.encode(msg_dict))
if TQ_ZERO_COPY_SERIALIZATION:
msg_dict = {
"request_type": self.request_type.value, # Enum -> str for msgpack
"sender_id": self.sender_id,
"receiver_id": self.receiver_id,
"request_id": self.request_id,
"timestamp": self.timestamp,
"body": self.body,
}
return list(_encoder.encode(msg_dict))
else:
return [pickle.dumps(self)]
Comment on lines +165 to +176
Copy link

Copilot AI Jan 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR reintroduces the TQ_ZERO_COPY_SERIALIZATION environment variable switch with a default value of False, switching the serialization from msgpack to pickle. However, there are no tests in this PR to verify that the pickle serialization path works correctly.

Given that the PR description mentions "minor bugs in current implementation" and that this is a fallback mechanism, it's critical to have test coverage for:

  1. Serialization and deserialization using pickle (when TQ_ZERO_COPY_SERIALIZATION=False)
  2. The tensor cloning behavior in _filter_storage_data when using pickle mode
  3. End-to-end scenarios with pickle serialization

The existing tests in test_serial_utils_on_cpu.py test ZMQMessage serialization, but they don't appear to test with TQ_ZERO_COPY_SERIALIZATION=False. Consider adding test cases that explicitly set this environment variable to ensure both code paths are tested.

Copilot uses AI. Check for mistakes.

@classmethod
def deserialize(cls, frames: list) -> "ZMQMessage":
"""
Deserialize message using unified MsgpackDecoder.
Deserialize message using the unified MsgpackDecoder (when TQ_ZERO_COPY_SERIALIZATION is enabled) or by unpickling the first frame (otherwise).
"""
if not frames:
raise ValueError("Empty frames received")

msg_dict = _decoder.decode(frames)
return cls(
request_type=ZMQRequestType(msg_dict["request_type"]),
sender_id=msg_dict["sender_id"],
receiver_id=msg_dict["receiver_id"],
body=msg_dict["body"],
request_id=msg_dict["request_id"],
timestamp=msg_dict["timestamp"],
)
if TQ_ZERO_COPY_SERIALIZATION:
msg_dict = _decoder.decode(frames)
return cls(
request_type=ZMQRequestType(msg_dict["request_type"]),
sender_id=msg_dict["sender_id"],
receiver_id=msg_dict["receiver_id"],
body=msg_dict["body"],
request_id=msg_dict["request_id"],
timestamp=msg_dict["timestamp"],
)
else:
return pickle.loads(frames[0])
Comment on lines 156 to +197
Copy link

Copilot AI Jan 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The serialization and deserialization paths now depend on a module-level constant TQ_ZERO_COPY_SERIALIZATION that is evaluated at import time. This creates a potential issue if the environment variable is changed after the modules are loaded, or if different processes have different values for this variable.

If one process serializes with TQ_ZERO_COPY_SERIALIZATION=True and another deserializes with TQ_ZERO_COPY_SERIALIZATION=False (or vice versa), the deserialization will fail. Consider one of the following approaches:

  1. Include a serialization format marker in the serialized data (e.g., a magic byte or header) so that the deserializer can automatically detect which format was used.

  2. Ensure processes are always started with consistent environment variable settings and document this requirement clearly.

  3. Make both code paths compatible by attempting to detect the format during deserialization (e.g., try msgpack first, fall back to pickle on error).

Copilot uses AI. Check for mistakes.


def get_free_port() -> str:
Expand Down