-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathserial_utils.py
More file actions
390 lines (312 loc) · 15.3 KB
/
serial_utils.py
File metadata and controls
390 lines (312 loc) · 15.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
# Copyright 2025 The vLLM project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This implementation is inspired by https://github.com/vllm-project/vllm/blob/main/vllm/v1/serial_utils.py
import logging
import os
import pickle
import warnings
from collections.abc import Sequence
from contextvars import ContextVar
from types import FunctionType
from typing import Any, TypeAlias
import cloudpickle
import numpy as np
import torch
import zmq
from msgspec import msgpack
from tensordict import TensorDictBase
# Ext type codes tagging custom payloads inside the msgpack stream.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_TENSOR = 3  # For tensor with buffer reference
CUSTOM_TYPE_NESTED_TENSOR = 4  # For nested tensor (strided or jagged)
CUSTOM_TYPE_NUMPY = 5  # For numpy ndarray with buffer reference
# 0xC1 is permanently reserved (invalid) in msgpack spec — safe to use as pickle fallback sentinel.
_PICKLE_FALLBACK_SENTINEL = b"\xc1\xfe\xed"
# Any bytes-like payload we may receive from / hand to ZMQ.
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
logger = logging.getLogger(__name__)
# setLevel accepts a level name string (e.g. "DEBUG") or an int; the env var wins.
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
# Ignore warnings about non-writable buffers from torch.frombuffer. Upper codes will ensure
# the tensors are writable to users.
# NOTE: `message` is a regex matched against the start of the warning text. The previous
# pattern "writable*" made `*` quantify the trailing "e"; a plain prefix is what was meant.
warnings.filterwarnings(action="ignore", message=r"The given buffer is not writable", category=UserWarning)
# ContextVar for thread/coroutine-safe buffer storage during serialization/deserialization
# This enables the global _encoder/_decoder instances to be safely used across threads
_encoder_aux_buffers: ContextVar[list[bytestr] | None] = ContextVar("encoder_aux_buffers", default=None)
_decoder_aux_buffers: ContextVar[Sequence[bytestr] | None] = ContextVar("decoder_aux_buffers", default=None)
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    This implementation uses ContextVar for thread-safe buffer storage,
    allowing the global encoder instance to be safely used across multiple
    threads and async coroutines.
    """

    def __init__(self):
        # enc_hook is invoked by msgspec for every object it cannot encode natively.
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)

    @property
    def aux_buffers(self) -> list[bytestr]:
        """Get the current context's aux_buffers.

        Raises (via assert) if accessed outside of an active encode() call,
        since the buffer list only exists for the duration of one encode.
        """
        buffers = _encoder_aux_buffers.get()
        assert buffers is not None, "aux_buffers accessed outside of encode() context"
        return buffers

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode a given object to a byte array.

        Returns a list of frames: frame 0 is the msgpack-encoded top-level
        object, frames 1..N are zero-copy views onto tensor/ndarray backing
        storage referenced by index from within frame 0.
        """
        bufs: list[bytestr] = [b""]
        token = _encoder_aux_buffers.set(bufs)
        try:
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            _encoder_aux_buffers.reset(token)

    def enc_hook(self, obj: Any) -> Any:
        """Custom encoding hook for types msgspec doesn't natively support.

        For zero-copy tensor serialization, we need to handle:
        - torch.Tensor: Extract buffer, store metadata
        - TensorDict: Convert to dict structure for recursive processing
        - numpy.ndarray: Convert to tensor for unified handling
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)
        # Handle TensorDict explicitly for recursive zero-copy
        if isinstance(obj, TensorDictBase):
            return self._encode_tensordict(obj)
        # Numpy arrays: serialize natively unless the dtype contains Python objects.
        if isinstance(obj, np.ndarray):
            if obj.dtype.kind != "O" and not obj.dtype.hasobject:
                try:
                    return self._encode_numpy(obj)
                except (TypeError, RuntimeError, ValueError):
                    # Fallback to pickle for platforms that don't support the view
                    pass
            # Only true object arrays (or structured dtypes with object fields) reach here
            return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
        if isinstance(obj, FunctionType):
            # cloudpickle for functions/methods
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
        # Fallback to pickle for unknown types
        return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_tensordict(self, obj: Any) -> dict:
        """Convert TensorDict to a dict structure for recursive msgpack processing.

        This allows msgpack to recursively call enc_hook for each tensor inside,
        enabling zero-copy serialization of nested tensors.
        """
        # Convert to dict, preserving structure
        data_dict = dict(obj.items())
        # Return a marked dict that the decoder will recognize
        return {
            "__tq_tensordict__": True,
            "batch_size": list(obj.batch_size),  # torch.Size -> list for msgpack
            "data": data_dict,
        }

    def _encode_tensor(self, obj: torch.Tensor) -> msgpack.Ext:
        """Encode tensor with zero-copy buffer extraction (handles GPU, non-contiguous, nested)."""
        assert len(self.aux_buffers) > 0
        # Handle nested tensors (strided or jagged) via unbind
        if obj.is_nested:
            return self._encode_nested_tensor(obj)
        return self._encode_regular_tensor(obj)

    def _encode_nested_tensor(self, obj: torch.Tensor) -> msgpack.Ext:
        """Encode nested tensor by unbinding into sub-tensors for zero-copy."""
        # Unbind nested tensor into list of regular tensors
        sub_tensors = obj.unbind()
        # Encode each sub-tensor with zero-copy
        encoded_sub_tensors = []
        for t in sub_tensors:
            # Get tensor metadata (dtype, shape, buffer_idx)
            meta = self._encode_regular_tensor_meta(t)
            encoded_sub_tensors.append(meta)
        # Pack: layout type + list of tensor metas
        layout = "jagged" if obj.layout == torch.jagged else "strided"
        nested_meta = {
            "layout": layout,
            "tensors": encoded_sub_tensors,
        }
        return msgpack.Ext(CUSTOM_TYPE_NESTED_TENSOR, pickle.dumps(nested_meta, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_regular_tensor_meta(self, obj: torch.Tensor) -> tuple:
        """Encode a regular (dense) tensor's bytes into aux_buffers and return
        its (dtype, shape, buffer_idx) metadata tuple."""
        # Handle non-contiguous tensors
        if not obj.is_contiguous():
            obj = obj.contiguous()
        # Handle GPU tensors
        if obj.device.type != "cpu":
            obj = obj.cpu()
        # Zero-copy buffer extraction via uint8 view; view(uint8) is a byte-level
        # reinterpretation, NOT a value conversion.
        arr = obj.flatten().view(torch.uint8).numpy()
        buf = memoryview(arr)
        idx = len(self.aux_buffers)
        self.aux_buffers.append(buf)
        dtype = str(obj.dtype).removeprefix("torch.")
        return (dtype, tuple(obj.shape), idx)

    def _encode_regular_tensor(self, obj: torch.Tensor) -> msgpack.Ext:
        """Encode a regular (non-nested) tensor with zero-copy.

        Sparse tensors fall back to pickle.
        """
        # FIX: check sparsity FIRST. Sparse COO tensors do not support
        # is_contiguous()/contiguous(), so the previous ordering (contiguity and
        # device handling before the sparse check) raised before the intended
        # pickle fallback could run.
        if obj.is_sparse:
            if obj.device.type != "cpu":
                # Move to CPU so the pickled payload can be loaded without a GPU.
                obj = obj.cpu()
            return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
        # Dense path shares the buffer-extraction logic with nested sub-tensors.
        meta = self._encode_regular_tensor_meta(obj)
        return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(meta, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_numpy(self, obj: np.ndarray) -> msgpack.Ext:
        """Encode numpy array with zero-copy buffer extraction."""
        # Ensure C-contiguous layout; no-op when already contiguous
        if not obj.flags["C_CONTIGUOUS"]:
            obj = np.ascontiguousarray(obj)
        # Byte-level view as uint8 then ravel → 1-D C-contiguous raw-bytes array
        buf = memoryview(obj.view(np.uint8).ravel())
        idx = len(self.aux_buffers)
        self.aux_buffers.append(buf)
        meta = (str(obj.dtype), tuple(obj.shape), idx)
        return msgpack.Ext(CUSTOM_TYPE_NUMPY, pickle.dumps(meta, protocol=pickle.HIGHEST_PROTOCOL))
class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    This implementation uses ContextVar for thread-safe buffer storage,
    allowing the global decoder instance to be safely used across multiple
    threads and async coroutines.
    """

    def __init__(self):
        # ext_hook is invoked by msgspec for every Ext-typed value in the stream.
        self.decoder = msgpack.Decoder(ext_hook=self.ext_hook)

    @property
    def aux_buffers(self) -> Sequence[bytestr]:
        """Get the current context's aux_buffers."""
        current = _decoder_aux_buffers.get()
        assert current is not None, "aux_buffers accessed outside of decode() context"
        return current

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode a list of bytes."""
        if isinstance(bufs, bytestr):
            # Single-frame input: no side buffers to publish.
            decoded = self.decoder.decode(bufs)
        else:
            # Frame 0 is the msgpack payload; the rest are referenced by index.
            token = _decoder_aux_buffers.set(bufs)
            try:
                decoded = self.decoder.decode(bufs[0])  # type: ignore[index]
            finally:
                _decoder_aux_buffers.reset(token)
        # Post-process to reconstruct TensorDict from marked dicts
        return self._reconstruct_special_types(decoded)

    def _reconstruct_special_types(self, obj: Any) -> Any:
        """Recursively reconstruct special types (TensorDict) from their dict representation."""
        if isinstance(obj, dict):
            # A marked dict stands in for a serialized TensorDict.
            if obj.get("__tq_tensordict__"):
                return self._reconstruct_tensordict(obj)
            return {key: self._reconstruct_special_types(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            rebuilt = [self._reconstruct_special_types(item) for item in obj]
            return rebuilt if isinstance(obj, list) else tuple(rebuilt)
        return obj

    def _reconstruct_tensordict(self, obj: dict) -> Any:
        """Reconstruct TensorDict from marked dict structure."""
        try:
            from tensordict import TensorDict

            # Recursively process nested data before wrapping.
            contents = self._reconstruct_special_types(obj["data"])
            return TensorDict(contents, batch_size=obj["batch_size"])
        except ImportError:
            # If tensordict not available, return as dict
            return obj

    def _decode_tensor(self, meta: tuple) -> torch.Tensor:
        """Decode tensor from (dtype, shape, buffer_idx) tuple."""
        dtype_name, shape, buffer_idx = meta
        torch_dtype = getattr(torch, dtype_name)
        raw = self.aux_buffers[buffer_idx]
        if not raw:  # Handle empty tensors
            return torch.empty(shape, dtype=torch_dtype)
        # Byte-level reconstruction: uint8 tensor over the buffer, then
        # reinterpret as the original dtype and restore the shape.
        flat = torch.frombuffer(raw, dtype=torch.uint8)
        return flat.view(torch_dtype).view(shape)

    def _decode_nested_tensor(self, nested_meta: dict) -> torch.Tensor:
        """Decode nested tensor from serialized sub-tensors."""
        parts = [self._decode_tensor(sub_meta) for sub_meta in nested_meta["tensors"]]
        # Reconstruct nested tensor with the layout recorded at encode time.
        target_layout = torch.jagged if nested_meta["layout"] == "jagged" else torch.strided
        return torch.nested.as_nested_tensor(parts, layout=target_layout)

    def _decode_numpy(self, meta: tuple) -> np.ndarray:
        """Decode numpy array from (dtype_str, shape, buffer_idx) tuple."""
        dtype_str, shape, buffer_idx = meta
        np_dtype = np.dtype(dtype_str)
        raw = self.aux_buffers[buffer_idx]
        if not raw:  # empty array
            return np.empty(shape, dtype=np_dtype)
        # Reconstruct from raw bytes: uint8 view → reinterpret as original dtype
        return np.frombuffer(raw, dtype=np.uint8).view(np_dtype).reshape(shape)

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Custom decoding hook for types msgspec doesn't natively support.

        Dispatches on the Ext type code written by MsgpackEncoder: plain
        pickle/cloudpickle payloads are loaded directly; tensor/numpy codes
        carry pickled metadata that is resolved against aux_buffers.
        """
        if code == CUSTOM_TYPE_PICKLE:
            return pickle.loads(data)
        if code == CUSTOM_TYPE_CLOUDPICKLE:
            return cloudpickle.loads(data)
        metadata_handlers = {
            CUSTOM_TYPE_TENSOR: self._decode_tensor,
            CUSTOM_TYPE_NESTED_TENSOR: self._decode_nested_tensor,
            CUSTOM_TYPE_NUMPY: self._decode_numpy,
        }
        handler = metadata_handlers.get(code)
        if handler is not None:
            return handler(pickle.loads(data))
        raise NotImplementedError(f"Extension type code {code} is not supported")
# Module-level singletons. Safe to share across threads/coroutines because all
# per-call buffer state lives in ContextVars, not on the instances themselves.
_encoder = MsgpackEncoder()
_decoder = MsgpackDecoder()
def encode(obj: Any) -> list[bytestr]:
    """Encode an object via msgpack zero-copy; falls back to pickle on failure.

    The pickle path is a normal degradation path (e.g. body contains torch.dtype
    objects). Use this as the single entry point for all ZMQ message serialization.
    """
    try:
        frames = _encoder.encode(obj)
    except (TypeError, ValueError) as err:
        logger.debug(
            "encode: msgpack failed (%s), falling back to pickle.",
            type(err).__name__,
        )
        # Sentinel frame tells decode() to take the pickle path.
        return [_PICKLE_FALLBACK_SENTINEL, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)]
    return list(frames)
def decode(frames: list) -> Any:
    """Decode frames produced by encode.

    Transparently handles both the msgpack zero-copy path and the pickle
    fallback path based on the leading sentinel frame.
    """
    # Pickle fallback: a sentinel frame followed by the pickled payload.
    if len(frames) >= 2:
        if frames[0] == _PICKLE_FALLBACK_SENTINEL:
            return pickle.loads(frames[1])
    return _decoder.decode(frames)