diff --git a/pyproject.toml b/pyproject.toml index f6243e0..87a0a97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,8 @@ dependencies = {file = "requirements.txt"} test = [ "pytest>=7.0.0", "pytest-asyncio>=0.20.0", + "pytest-mock>=3.15.0", + "openyuanrong-datasystem>=0.6.3", ] # If you need to mimic `package_dir={'': '.'}`: diff --git a/tests/test_yuanrong_storage_manager.py b/tests/test_yuanrong_storage_manager.py index 4c38b7c..9433d6a 100644 --- a/tests/test_yuanrong_storage_manager.py +++ b/tests/test_yuanrong_storage_manager.py @@ -78,10 +78,10 @@ def side_effect_mcreate(keys, sizes): buffers = [MockBuffer(size) for size in sizes] for b in buffers: stored_raw_buffers.append(b.MutableData()) - return 0, buffers + return buffers storage_client._cpu_ds_client.mcreate.side_effect = side_effect_mcreate - storage_client._cpu_ds_client.get_buffers.return_value = (0, stored_raw_buffers) + storage_client._cpu_ds_client.get_buffers.return_value = stored_raw_buffers storage_client.mset_zcopy( ["tensor_key", "string_key"], [torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), "hello yuanrong"] diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index c233472..8ff124a 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -193,7 +193,7 @@ def mset_zcopy(self, keys: list[str], objs: list[Any]): """ items_list = [[memoryview(b) for b in _encoder.encode(obj)] for obj in objs] packed_sizes = [calc_packed_size(items) for items in items_list] - status, buffers = self._cpu_ds_client.mcreate(keys, packed_sizes) + buffers = self._cpu_ds_client.mcreate(keys, packed_sizes) tasks = [(target.MutableData(), item) for target, item in zip(buffers, items_list, strict=False)] with ThreadPoolExecutor(max_workers=DS_MAX_WORKERS) as executor: list(executor.map(lambda p: pack_into(*p), tasks)) @@ -208,7 +208,7 @@ def mget_zcopy(self, keys: list[str]) -> list[Any]: Returns: list[Any]: List of deserialized objects corresponding to the input keys. """ - status, buffers = self._cpu_ds_client.get_buffers(keys, timeout_ms=500) + buffers = self._cpu_ds_client.get_buffers(keys) return [_decoder.decode(unpack_from(buffer)) if buffer is not None else None for buffer in buffers] def _batch_put(self, keys: list[str], values: list[Any]):