Skip to content

Commit da2ad54

Browse files
[chore] Optimize pyproject and CI script (#13)
1. Optimize `pyproject.toml` and add `build`, `test`, and `yuanrong` as optional dependencies. 2. Fix pre-commit checks that went missing due to the change in folder structure. 3. Separate pytest and the build test in CI, which simplifies dependency installation. --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com> Signed-off-by: tianyi-ge <tianyig@outlook.com> Co-authored-by: tianyi-ge <tianyig@outlook.com>
1 parent b1d50d7 commit da2ad54

14 files changed

Lines changed: 98 additions & 63 deletions

File tree

.github/workflows/python-package.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,18 @@ jobs:
3131
- name: Install dependencies
3232
run: |
3333
python -m pip install --upgrade pip
34-
python -m pip install flake8 pytest build pytest_asyncio pytest-mock openyuanrong-datasystem
35-
python -m build --wheel
3634
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
37-
pip install dist/*.whl
35+
pip install -e ".[test,build,yuanrong]"
3836
- name: Lint with flake8
3937
run: |
4038
# stop the build if there are Python syntax errors or undefined names
4139
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
4240
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
4341
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
42+
- name: Build wheel and test installed distribution
43+
run: |
44+
python -m build --wheel
45+
pip install dist/*.whl --force-reinstall
4446
- name: Test with pytest
4547
run: |
4648
pytest

.github/workflows/sanity.yml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,6 @@ jobs:
3838
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
3939
with:
4040
python-version: ${{ matrix.python-version }}
41-
- name: Install dependencies
42-
run: |
43-
python -m pip install --upgrade pip
44-
python -m pip install build
45-
python -m build --wheel
46-
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
47-
pip install dist/*.whl
4841
- name: Run license test
4942
run: |
5043
python3 tests/sanity/check_license.py --directories .

pyproject.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,7 @@ filterwarnings = [
8787

8888
[[tool.mypy.overrides]]
8989
module = [
90-
"transfer_queue.data_system.*",
91-
"transfer_queue.utils.utils.*",
92-
"transfer_queue.utils.zmq_utils.*",
93-
"transfer_queue.utils.serial_utils.*",
90+
"transfer_queue.*",
9491
]
9592
ignore_errors = false
9693

@@ -108,9 +105,17 @@ version = {file = "transfer_queue/version/version"}
108105
dependencies = {file = "requirements.txt"}
109106

110107
[project.optional-dependencies]
108+
build = [
109+
"build"
110+
]
111111
test = [
112112
"pytest>=7.0.0",
113113
"pytest-asyncio>=0.20.0",
114+
"flake8",
115+
"pytest-mock",
116+
]
117+
yuanrong = [
118+
"openyuanrong-datasystem"
114119
]
115120

116121
# If you need to mimic `package_dir={'': '.'}`:

tests/test_controller.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_controller_with_single_partition(self, ray_setup):
7575
ProductionStatus.NOT_PRODUCED
7676
)
7777
partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
78-
assert partition_index_range == set(range(gbs * num_n_samples))
78+
assert partition_index_range == list(range(gbs * num_n_samples))
7979

8080
print("✓ Initial get metadata correct")
8181

@@ -194,7 +194,7 @@ def test_controller_with_single_partition(self, ray_setup):
194194
ray.get(tq_controller.clear_partition.remote(partition_id))
195195
partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
196196
partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id))
197-
assert partition_index_range == set()
197+
assert partition_index_range == []
198198
assert partition is None
199199
print("✓ Clear partition correct")
200200

@@ -307,7 +307,7 @@ def test_controller_with_multi_partitions(self, ray_setup):
307307
[int(sample.fields.get("attention_mask").production_status) for sample in val_metadata.samples]
308308
) == int(ProductionStatus.NOT_PRODUCED)
309309
partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
310-
assert partition_index_range == set(range(part1_index_range, part2_index_range + part1_index_range))
310+
assert partition_index_range == list(range(part1_index_range, part2_index_range + part1_index_range))
311311

312312
# Update production status
313313
dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in val_metadata.global_indexes}
@@ -359,11 +359,11 @@ def test_controller_with_multi_partitions(self, ray_setup):
359359

360360
assert not partition_index_range_1_after_clear
361361
assert partition_1_after_clear is None
362-
assert partition_index_range_1_after_clear == set()
362+
assert partition_index_range_1_after_clear == []
363363

364364
partition_2 = ray.get(tq_controller.get_partition_snapshot.remote(partition_id_2))
365365
partition_index_range_2 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_2))
366-
assert partition_index_range_2 == set([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
366+
assert partition_index_range_2 == [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]
367367
assert torch.all(
368368
partition_2.production_status[list(partition_index_range_2), : len(val_metadata.field_names)] == 1
369369
)
@@ -387,7 +387,7 @@ def test_controller_with_multi_partitions(self, ray_setup):
387387
[int(sample.fields.get("attention_mask").production_status) for sample in metadata_2.samples]
388388
) == int(ProductionStatus.NOT_PRODUCED)
389389
partition_index_range = ray.get(tq_controller.get_partition_index_range.remote(partition_id_3))
390-
assert partition_index_range == set(list(range(32)) + list(range(48, 80)))
390+
assert partition_index_range == list(range(32)) + list(range(48, 80))
391391
print("✓ Correctly assign partition_3")
392392

393393
def test_controller_clear_meta(self, ray_setup):

tests/test_kv_storage_manager.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@ def get_meta(data, global_indexes=None):
5757
@pytest.fixture
5858
def test_data():
5959
"""Fixture providing test configuration, data, and metadata."""
60-
cfg = {"client_name": "YuanrongStorageClient", "host": "127.0.0.1", "port": 31501, "device_id": 0}
60+
cfg = {
61+
"controller_info": MagicMock(),
62+
"client_name": "YuanrongStorageClient",
63+
"host": "127.0.0.1",
64+
"port": 31501,
65+
"device_id": 0,
66+
}
6167
global_indexes = [8, 9, 10]
6268

6369
data = TensorDict(
@@ -288,7 +294,7 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo
288294
mock_storage_client.put.return_value = mock_custom_meta
289295

290296
# Create manager with mocked dependencies
291-
config = {"client_name": "MockClient"}
297+
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
292298
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
293299
manager = KVStorageManager(config)
294300

@@ -338,7 +344,7 @@ def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data):
338344
mock_storage_client.put.return_value = None
339345

340346
# Create manager with mocked dependencies
341-
config = {"client_name": "MockClient"}
347+
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
342348
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
343349
manager = KVStorageManager(config)
344350

@@ -361,7 +367,7 @@ def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_dat
361367
mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}]
362368

363369
# Create manager with mocked dependencies
364-
config = {"client_name": "MockClient"}
370+
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
365371
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
366372
manager = KVStorageManager(config)
367373

transfer_queue/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,7 @@ async def async_get_partition_list(
758758
)
759759

760760
try:
761+
assert socket is not None
761762
await socket.send_multipart(request_msg.serialize())
762763
response_serialized = await socket.recv_multipart()
763764
response_msg = ZMQMessage.deserialize(response_serialized)
@@ -1049,10 +1050,10 @@ def process_zmq_server_info(
10491050
>>> info_dict = process_zmq_server_info(handlers)"""
10501051
# Handle single handler object case
10511052
if not isinstance(handlers, dict):
1052-
return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[attr-defined]
1053+
return ray.get(handlers.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined]
10531054
else:
10541055
# Handle dictionary case
10551056
server_info = {}
10561057
for name, handler in handlers.items():
1057-
server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[attr-defined]
1058+
server_info[name] = ray.get(handler.get_zmq_server_info.remote()) # type: ignore[union-attr, attr-defined]
10581059
return server_info

transfer_queue/controller.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,17 @@ def release_indexes(self, partition_id: str, indexes_to_release: list[int]):
188188
if not partition_indexes:
189189
self.partition_to_indexes.pop(partition_id, None)
190190

191-
def get_indexes_for_partition(self, partition_id) -> set[int]:
191+
def get_indexes_for_partition(self, partition_id) -> list[int]:
192192
"""
193193
Get all global_indexes for the specified partition.
194194
195195
Args:
196196
partition_id: Partition ID
197197
198198
Returns:
199-
set: Set of global_indexes for this partition
199+
list: List of global_indexes for this partition
200200
"""
201-
return self.partition_to_indexes.get(partition_id, set()).copy()
201+
return list(self.partition_to_indexes.get(partition_id, set()).copy())
202202

203203

204204
@dataclass
@@ -216,7 +216,7 @@ class DataPartitionStatus:
216216

217217
# Production status tensor - dynamically expandable
218218
# Values: 0 = not produced, 1 = ready for consumption
219-
production_status: Optional[Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)
219+
production_status: Tensor = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)
220220

221221
# Consumption status per task - task_name -> consumption_tensor
222222
# Each tensor tracks which samples have been consumed by that task
@@ -251,16 +251,16 @@ def total_fields_num(self) -> int:
251251
@property
252252
def allocated_fields_num(self) -> int:
253253
"""Current number of allocated columns in the tensor."""
254-
return self.production_status.shape[1] if self.production_status is not None else 0
254+
return self.production_status.shape[1]
255255

256256
@property
257257
def allocated_samples_num(self) -> int:
258258
"""Current number of allocated rows in the tensor."""
259-
return self.production_status.shape[0] if self.production_status is not None else 0
259+
return self.production_status.shape[0]
260260

261261
# ==================== Dynamic Expansion Methods ====================
262262

263-
def ensure_samples_capacity(self, required_samples: int) -> bool:
263+
def ensure_samples_capacity(self, required_samples: int) -> None:
264264
"""
265265
Ensure the production status tensor has enough rows for the required samples.
266266
Dynamically expands if needed using unified minimum expansion size.
@@ -291,17 +291,14 @@ def ensure_samples_capacity(self, required_samples: int) -> bool:
291291
f"to {new_samples} samples (added {min_expansion} samples)"
292292
)
293293

294-
def ensure_fields_capacity(self, required_fields: int):
294+
def ensure_fields_capacity(self, required_fields: int) -> None:
295295
"""
296296
Ensure the production status tensor has enough columns for the required fields.
297297
Dynamically expands if needed using unified minimum expansion size.
298298
299299
Args:
300300
required_fields: Minimum number of fields needed
301301
"""
302-
if self.production_status is None:
303-
# Will be initialized when samples are added
304-
return
305302

306303
current_fields = self.production_status.shape[1]
307304
if required_fields > current_fields:
@@ -498,7 +495,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
498495
return partition_global_index, consumption_status
499496

500497
# ==================== Production Status Interface ====================
501-
def get_production_status_for_fields(self, field_names: list[str], mask: bool = False) -> tuple[Tensor, Tensor]:
498+
def get_production_status_for_fields(
499+
self, field_names: list[str], mask: bool = False
500+
) -> tuple[Optional[Tensor], Optional[Tensor]]:
502501
"""
503502
Check if all samples for specified fields are fully produced and ready.
504503
@@ -511,13 +510,13 @@ def get_production_status_for_fields(self, field_names: list[str], mask: bool =
511510
- Partition global index tensor
512511
- Production status tensor for the specified task. 1 for ready, 0 for not ready.
513512
"""
514-
if self.production_status is None or field_names is None or len(field_names) == 0:
515-
return False
513+
if field_names is None or len(field_names) == 0:
514+
return None, None
516515

517516
# Check if all requested fields are registered
518517
for field_name in field_names:
519518
if field_name not in self.field_name_mapping:
520-
return False
519+
return None, None
521520

522521
# Create column mask for requested fields
523522
col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
@@ -548,8 +547,6 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]:
548547
Returns:
549548
List of sample indices that are ready for consumption
550549
"""
551-
if self.production_status is None:
552-
return []
553550

554551
# Check if all requested fields are registered
555552
for field_name in field_names:
@@ -837,15 +834,15 @@ def list_partitions(self) -> list[str]:
837834

838835
# ==================== Partition Index Management API ====================
839836

840-
def get_partition_index_range(self, partition: DataPartitionStatus) -> set:
837+
def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]:
841838
"""
842839
Get all indexes for a specific partition.
843840
844841
Args:
845842
partition: Partition identifier
846843
847844
Returns:
848-
Set of indexes allocated to the partition
845+
List of indexes allocated to the partition
849846
"""
850847
return self.index_manager.get_indexes_for_partition(partition)
851848

@@ -980,6 +977,9 @@ def get_metadata(
980977
if mode == "fetch":
981978
# Find ready samples within current data partition and package into BatchMeta when reading
982979

980+
if batch_size is None:
981+
raise ValueError("must provide batch_size in fetch mode")
982+
983983
start_time = time.time()
984984
while True:
985985
# ready_for_consume_indexes: samples where all required fields are produced

transfer_queue/dataloader/streaming_dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
170170
if self._tq_client is None:
171171
self._create_client()
172172

173+
assert self._tq_client is not None, "Failed to create TransferQueue client"
174+
173175
# TODO: need to consider the async scenario where the number of samples in the partition is dynamically increasing
174176
while not self._tq_client.check_consumption_status(self.task_name, self.partition_id):
175177
try:

transfer_queue/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def get_all_custom_meta(self) -> dict[int, dict[str, Any]]:
265265
"""Get the entire custom meta dictionary"""
266266
return copy.deepcopy(self._custom_meta)
267267

268-
def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]] = None):
268+
def update_custom_meta(self, new_custom_meta: Optional[dict[int, dict[str, Any]]]):
269269
"""Update custom meta with a new dictionary"""
270270
if new_custom_meta:
271271
self._custom_meta.update(new_custom_meta)

transfer_queue/storage/clients/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ class TransferQueueStorageKVClient(ABC):
2323
Subclasses must implement the core methods: put, get, and clear.
2424
"""
2525

26+
def __init__(self, config: dict[str, Any]):
27+
"""
28+
Initialize the storage client with configuration.
29+
Args:
30+
config (dict[str, Any]): Configuration dictionary for the storage client.
31+
"""
32+
self.config = config
33+
2634
@abstractmethod
2735
def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
2836
"""

0 commit comments

Comments
 (0)