Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ dependencies = {file = "requirements.txt"}
test = [
"pytest>=7.0.0",
"pytest-asyncio>=0.20.0",
"pytest-mock>=3.15.0",
"openyuanrong-datasystem>=0.6.3",
]

# If you need to mimic `package_dir={'': '.'}`:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_yuanrong_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ def side_effect_mcreate(keys, sizes):
buffers = [MockBuffer(size) for size in sizes]
for b in buffers:
stored_raw_buffers.append(b.MutableData())
return 0, buffers
return buffers

storage_client._cpu_ds_client.mcreate.side_effect = side_effect_mcreate
storage_client._cpu_ds_client.get_buffers.return_value = (0, stored_raw_buffers)
storage_client._cpu_ds_client.get_buffers.return_value = stored_raw_buffers

storage_client.mset_zcopy(
["tensor_key", "string_key"], [torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), "hello yuanrong"]
Expand Down
4 changes: 2 additions & 2 deletions transfer_queue/storage/clients/yuanrong_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def mset_zcopy(self, keys: list[str], objs: list[Any]):
"""
items_list = [[memoryview(b) for b in _encoder.encode(obj)] for obj in objs]
packed_sizes = [calc_packed_size(items) for items in items_list]
status, buffers = self._cpu_ds_client.mcreate(keys, packed_sizes)
buffers = self._cpu_ds_client.mcreate(keys, packed_sizes)
tasks = [(target.MutableData(), item) for target, item in zip(buffers, items_list, strict=False)]
with ThreadPoolExecutor(max_workers=DS_MAX_WORKERS) as executor:
list(executor.map(lambda p: pack_into(*p), tasks))
Expand All @@ -208,7 +208,7 @@ def mget_zcopy(self, keys: list[str]) -> list[Any]:
Returns:
list[Any]: List of deserialized objects corresponding to the input keys.
"""
status, buffers = self._cpu_ds_client.get_buffers(keys, timeout_ms=500)
buffers = self._cpu_ds_client.get_buffers(keys)
return [_decoder.decode(unpack_from(buffer)) if buffer is not None else None for buffer in buffers]

def _batch_put(self, keys: list[str], values: list[Any]):
Expand Down