Commit c2bb0fa

Aurelius84 and yuetian authored
[feat] Support async_reset_consumption to reuse data (#25)
## Title

Add reset_consumption Interface for Reusing Data in TransferQueue

## Description

### Overview

This PR adds the `reset_consumption` interface to TransferQueue, allowing users to reset the consumption status of data samples without clearing the actual data. This is particularly useful for debugging scenarios where the same rollout data needs to be processed multiple times without regenerating it.

### Changes Made

#### 1. **Core Controller Implementation** (`transfer_queue/controller.py`)

- Added `reset_consumption()` method in `DataPartitionStatus` class to reset consumption status for specific tasks or all tasks
- Added `reset_consumption()` method in `TransferQueueController` class to handle partition-level consumption reset
- Added request handling for `RESET_CONSUMPTION` and `RESET_CONSUMPTION_RESPONSE` message types in the controller's request processing logic

#### 2. **Client API** (`transfer_queue/client.py`)

- Added `async_reset_consumption()` async method for resetting consumption status via async operations
- Added `reset_consumption()` synchronous wrapper method for convenience
- Both methods support resetting consumption for a specific task or all tasks in a partition

#### 3. **Protocol Extensions** (`transfer_queue/utils/zmq_utils.py`)

- Added `RESET_CONSUMPTION` and `RESET_CONSUMPTION_RESPONSE` request types to `ZMQRequestType` enum

#### 4. **Comprehensive Testing**

**In `tests/test_client.py`:**

- Added `MockController` support for `RESET_CONSUMPTION` requests
- Added `test_reset_consumption()` - Tests synchronous reset with specific task name
- Added `test_reset_consumption_all_tasks()` - Tests synchronous reset for all tasks
- Added `test_async_reset_consumption()` - Tests async reset with specific task name
- Added `test_async_reset_consumption_all_tasks()` - Tests async reset for all tasks

**In `tests/test_controller.py`:**

- Added `test_controller_reset_consumption()` - Comprehensive integration test that verifies:
  - Consumption status before and after data consumption
  - Reset functionality for specific tasks
  - Reset functionality for all tasks (task_name=None)
  - Multi-task consumption scenarios

### Usage Examples

**Synchronous API:**

```python
from transfer_queue import TransferQueueClient

client = TransferQueueClient(client_id="client_0", controller_info=controller_info)

# Reset consumption for a specific task
success = client.reset_consumption(partition_id="train_0", task_name="generate_sequences")

# Reset consumption for all tasks in a partition
success = client.reset_consumption(partition_id="train_0")
```

**Asynchronous API:**

```python
import asyncio

# Reset consumption for a specific task
success = await client.async_reset_consumption(partition_id="train_0", task_name="generate_sequences")

# Reset consumption for all tasks
success = await client.async_reset_consumption(partition_id="train_0")
```

**Use Case Example:**

```python
# Scenario: Re-train on the same rollout data without regeneration

# Step 1: Fetch data and train
batch_meta = client.get_meta(
    data_fields=["prompts", "attention_mask"],
    batch_size=32,
    partition_id="train_0",
    mode="fetch",
    task_name="training"
)
batch_data = client.get_data(batch_meta)
# ... perform training ...

# Step 2: Reset consumption to allow re-training on same data
client.reset_consumption(partition_id="train_0", task_name="training")

# Step 3: Fetch and train again on the same data
batch_meta = client.get_meta(
    data_fields=["prompts", "attention_mask"],
    batch_size=32,
    partition_id="train_0",
    mode="fetch",
    task_name="training"
)
batch_data = client.get_data(batch_meta)
# ... perform training again ...
```

Signed-off-by: yuetian <zhangliujie@xiaohongshu.com>
Co-authored-by: yuetian <zhangliujie@xiaohongshu.com>
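The reset semantics described above (fetching marks samples as consumed per task, resetting clears those marks without touching the data) can be illustrated with a small in-memory model. This is only a sketch: `ToyPartitionStatus` and its `fetch` method are hypothetical names invented for illustration and are not the `DataPartitionStatus` implementation, which this page does not show.

```python
from typing import Dict, List, Optional


class ToyPartitionStatus:
    """Hypothetical stand-in for per-partition consumption tracking."""

    def __init__(self, num_samples: int) -> None:
        self.num_samples = num_samples
        # task name -> one 0/1 flag per sample (1 means "already consumed")
        self.consumed: Dict[str, List[int]] = {}

    def fetch(self, task_name: str, batch_size: int) -> List[int]:
        """Return up to batch_size unconsumed indexes and mark them consumed."""
        flags = self.consumed.setdefault(task_name, [0] * self.num_samples)
        picked = [i for i, flag in enumerate(flags) if flag == 0][:batch_size]
        for i in picked:
            flags[i] = 1
        return picked

    def reset_consumption(self, task_name: Optional[str] = None) -> None:
        """Zero the flags for one task, or for every task when task_name is None."""
        targets = list(self.consumed) if task_name is None else [task_name]
        for name in targets:
            if name in self.consumed:
                self.consumed[name] = [0] * self.num_samples


status = ToyPartitionStatus(num_samples=4)
first = status.fetch("training", batch_size=4)    # consumes every sample
drained = status.fetch("training", batch_size=4)  # nothing left for this task
status.reset_consumption("training")              # data untouched, flags cleared
second = status.fetch("training", batch_size=4)   # same samples are available again
```

In this toy model the stored samples are never modified; only the per-task flags change, which is the property that makes re-training on the same rollout data possible.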
1 parent fbdb58e commit c2bb0fa

5 files changed

Lines changed: 352 additions & 0 deletions

tests/test_client.py

Lines changed: 53 additions & 0 deletions
@@ -137,6 +137,13 @@ def _handle_requests(self):
                 elif request_msg.request_type == ZMQRequestType.SET_CUSTOM_META:
                     response_body = {"message": "success"}
                     response_type = ZMQRequestType.SET_CUSTOM_META_RESPONSE
+                elif request_msg.request_type == ZMQRequestType.RESET_CONSUMPTION:
+                    # Mock reset consumption - always succeed
+                    response_body = {
+                        "success": True,
+                        "message": "Consumption reset successfully",
+                    }
+                    response_type = ZMQRequestType.RESET_CONSUMPTION_RESPONSE
                 else:
                     response_body = {"error": f"Unknown request type: {request_msg.request_type}"}
                     response_type = ZMQRequestType.CLEAR_META_RESPONSE
@@ -531,6 +538,52 @@ def test_get_partition_list(client_setup):
     assert "test_partition" in partition_list
 
 
+def test_reset_consumption(client_setup):
+    """Test synchronous reset_consumption - resets consumption status for a partition"""
+    client, _, _ = client_setup
+
+    # Test synchronous reset_consumption with task_name
+    success = client.reset_consumption(partition_id="train_0", task_name="generate_sequences")
+    assert success is True
+
+    print("✓ reset_consumption with task_name returns True")
+
+
+def test_reset_consumption_all_tasks(client_setup):
+    """Test synchronous reset_consumption without task_name (resets all tasks)"""
+    client, _, _ = client_setup
+
+    # Test synchronous reset_consumption without task_name (reset all tasks)
+    success = client.reset_consumption(partition_id="train_0")
+    assert success is True
+
+    print("✓ reset_consumption without task_name (all tasks) returns True")
+
+
+@pytest.mark.asyncio
+async def test_async_reset_consumption(client_setup):
+    """Test async reset_consumption - resets consumption status for a partition"""
+    client, _, _ = client_setup
+
+    # Test async_reset_consumption with task_name
+    success = await client.async_reset_consumption(partition_id="train_0", task_name="generate_sequences")
+    assert success is True
+
+    print("✓ async_reset_consumption with task_name returns True")
+
+
+@pytest.mark.asyncio
+async def test_async_reset_consumption_all_tasks(client_setup):
+    """Test async reset_consumption without task_name (resets all tasks)"""
+    client, _, _ = client_setup
+
+    # Test async_reset_consumption without task_name (reset all tasks)
+    success = await client.async_reset_consumption(partition_id="train_0")
+    assert success is True
+
+    print("✓ async_reset_consumption without task_name (all tasks) returns True")
+
+
 @pytest.mark.asyncio
 async def test_async_check_consumption_status(client_setup):
     """Test async consumption status checking"""

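The mock in this file always reports success, so the client tests above exercise only the `success is True` path. A possible way to cover the failure path as well is a handler that can decline a reset. The sketch below is an assumption, not part of the commit: `mock_reset_handler`, the `known_partitions` argument, and the "Unknown partition" message are hypothetical, and whether the real controller answers an unknown partition with `success: False` is not shown on this page.

```python
from typing import Any, Dict, Set, Tuple

# String stand-ins for the ZMQRequestType members named in the diff.
RESET_RESPONSE = "RESET_CONSUMPTION_RESPONSE"


def mock_reset_handler(body: Dict[str, Any], known_partitions: Set[str]) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical mock: succeed only for partitions the mock knows about."""
    partition_id = body.get("partition_id")
    if partition_id in known_partitions:
        return RESET_RESPONSE, {"success": True, "message": "Consumption reset successfully"}
    # Failure is still reported with the RESET_CONSUMPTION_RESPONSE type,
    # so a client would return False instead of raising.
    return RESET_RESPONSE, {"success": False, "message": f"Unknown partition: {partition_id}"}


ok_type, ok_body = mock_reset_handler({"partition_id": "train_0"}, {"train_0"})
bad_type, bad_body = mock_reset_handler({"partition_id": "missing"}, {"train_0"})
```

With a handler like this, a test could assert that `reset_consumption` returns `False` for the unknown partition, which the always-succeeding mock cannot check.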
tests/test_controller.py

Lines changed: 157 additions & 0 deletions
@@ -198,6 +198,163 @@ def test_controller_with_single_partition(self, ray_setup):
         assert partition is None
         print("✓ Clear partition correct")
 
+    def test_controller_reset_consumption(self, ray_setup):
+        """Test reset_consumption functionality - allows data to be re-consumed"""
+        gbs = 4
+        num_n_samples = 2
+        partition_id = "test_reset_consumption"
+
+        tq_controller = TransferQueueController.remote()
+
+        # Step 1: Create metadata in insert mode
+        data_fields = ["prompt_ids", "attention_mask"]
+        metadata = ray.get(
+            tq_controller.get_metadata.remote(
+                data_fields=data_fields,
+                batch_size=gbs * num_n_samples,
+                partition_id=partition_id,
+                mode="insert",
+            )
+        )
+        assert metadata.global_indexes == list(range(gbs * num_n_samples))
+
+        # Step 2: Update production status
+        dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in metadata.global_indexes}
+        shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in metadata.global_indexes}
+        success = ray.get(
+            tq_controller.update_production_status.remote(
+                partition_id=partition_id,
+                global_indexes=metadata.global_indexes,
+                field_names=metadata.field_names,
+                dtypes=dtypes,
+                shapes=shapes,
+            )
+        )
+        assert success
+
+        # Step 3: Verify consumption status BEFORE consumption (should be all zeros)
+        global_index, consumption_status = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="generate_sequences",
+            )
+        )
+        expected_consumption_before = torch.zeros(gbs * num_n_samples, dtype=torch.int8)
+        assert torch.equal(consumption_status, expected_consumption_before)
+        print("✓ Consumption status before fetch is all zeros")
+
+        # Step 4: Fetch data (mark as consumed)
+        gen_meta = ray.get(
+            tq_controller.get_metadata.remote(
+                data_fields=["prompt_ids"],
+                batch_size=gbs * num_n_samples,
+                partition_id=partition_id,
+                mode="fetch",
+                task_name="generate_sequences",
+            )
+        )
+        assert gen_meta.global_indexes == list(range(gbs * num_n_samples))
+
+        # Step 5: Verify consumption status AFTER consumption (should be all ones)
+        global_index, consumption_status = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="generate_sequences",
+            )
+        )
+        expected_consumption_after = torch.ones(gbs * num_n_samples, dtype=torch.int8)
+        assert torch.equal(consumption_status, expected_consumption_after)
+        print("✓ Consumption status after fetch is all ones")
+
+        # Step 6: Reset consumption for specific task
+        ray.get(
+            tq_controller.reset_consumption.remote(
+                partition_id=partition_id,
+                task_name="generate_sequences",
+            )
+        )
+
+        # Step 7: Verify consumption status is reset (should be all zeros again)
+        global_index, consumption_status = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="generate_sequences",
+            )
+        )
+        expected_consumption_reset = torch.zeros(gbs * num_n_samples, dtype=torch.int8)
+        assert torch.equal(consumption_status, expected_consumption_reset)
+        print("✓ Consumption status after reset is all zeros")
+
+        # Step 8: Consume again and test reset all tasks
+        gen_meta_2 = ray.get(
+            tq_controller.get_metadata.remote(
+                data_fields=["prompt_ids"],
+                batch_size=gbs * num_n_samples,
+                partition_id=partition_id,
+                mode="fetch",
+                task_name="generate_sequences",
+            )
+        )
+        assert gen_meta_2.global_indexes == list(range(gbs * num_n_samples))
+
+        # Also consume with another task
+        gen_meta_3 = ray.get(
+            tq_controller.get_metadata.remote(
+                data_fields=["attention_mask"],
+                batch_size=gbs * num_n_samples,
+                partition_id=partition_id,
+                mode="fetch",
+                task_name="another_task",
+            )
+        )
+        assert gen_meta_3.global_indexes == list(range(gbs * num_n_samples))
+
+        # Verify both tasks have consumed
+        _, consumption_status_task1 = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="generate_sequences",
+            )
+        )
+        _, consumption_status_task2 = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="another_task",
+            )
+        )
+        assert torch.equal(consumption_status_task1, torch.ones(gbs * num_n_samples, dtype=torch.int8))
+        assert torch.equal(consumption_status_task2, torch.ones(gbs * num_n_samples, dtype=torch.int8))
+        print("✓ Both tasks consumed successfully")
+
+        # Step 9: Reset all tasks (task_name=None)
+        ray.get(
+            tq_controller.reset_consumption.remote(
+                partition_id=partition_id,
+                task_name=None,  # Reset all tasks
+            )
+        )
+
+        # Step 10: Verify all tasks are reset
+        _, consumption_status_task1_reset = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="generate_sequences",
+            )
+        )
+        _, consumption_status_task2_reset = ray.get(
+            tq_controller.get_consumption_status.remote(
+                partition_id=partition_id,
+                task_name="another_task",
+            )
+        )
+        assert torch.equal(consumption_status_task1_reset, torch.zeros(gbs * num_n_samples, dtype=torch.int8))
+        assert torch.equal(consumption_status_task2_reset, torch.zeros(gbs * num_n_samples, dtype=torch.int8))
+        print("✓ Reset all tasks successful - both tasks have zero consumption status")
+
+        # Clean up
+        ray.get(tq_controller.clear_partition.remote(partition_id))
+        print("✓ Reset consumption test completed successfully")
+
     def test_controller_with_multi_partitions(self, ray_setup):
         gbs_1 = 8
         num_n_samples_1 = 4

transfer_queue/client.py

Lines changed: 70 additions & 0 deletions
@@ -768,6 +768,63 @@ async def async_check_consumption_status(
             return False
         return torch.all(consumption_status == 1).item()
 
+    @dynamic_socket(socket_name="request_handle_socket")
+    async def async_reset_consumption(
+        self,
+        partition_id: str,
+        task_name: Optional[str] = None,
+        socket: Optional[zmq.asyncio.Socket] = None,
+    ) -> bool:
+        """Reset consumption status for a partition, allowing data to be re-consumed.
+        This is useful for debugging scenarios where the same rollout data needs to be
+        trained multiple times without regenerating the data.
+        Args:
+            partition_id: Partition id to reset consumption status for
+            task_name: Name of the task to reset. If None, resets all tasks.
+            socket: ZMQ async socket for message transmission (injected by decorator)
+        Returns:
+            bool: True if reset was successful, False otherwise
+        Raises:
+            RuntimeError: If communication fails or controller returns error response
+        Example:
+            >>> # Reset consumption for train task to re-train on same data
+            >>> success = asyncio.run(client.async_reset_consumption(
+            ...     partition_id="train_0",
+            ...     task_name="train"
+            ... ))
+            >>> print(f"Reset successful: {success}")
+        """
+        assert socket is not None
+        body = {"partition_id": partition_id}
+        if task_name is not None:
+            body["task_name"] = task_name
+        request_msg = ZMQMessage.create(
+            request_type=ZMQRequestType.RESET_CONSUMPTION,
+            sender_id=self.client_id,
+            receiver_id=self._controller.id,
+            body=body,
+        )
+        try:
+            await socket.send_multipart(request_msg.serialize())
+            response_serialized = await socket.recv_multipart()
+            response_msg = ZMQMessage.deserialize(response_serialized)
+            logger.debug(
+                f"[{self.client_id}]: Client reset consumption response: {response_msg} "
+                f"from controller {self._controller.id}"
+            )
+            if response_msg.request_type == ZMQRequestType.RESET_CONSUMPTION_RESPONSE:
+                success = response_msg.body.get("success", False)
+                if not success:
+                    logger.warning(f"[{self.client_id}]: Reset consumption failed: {response_msg.body.get('message')}")
+                return success
+            else:
+                raise RuntimeError(
+                    f"[{self.client_id}]: Failed to reset consumption from controller {self._controller.id}: "
+                    f"{response_msg.body.get('message', 'Unknown error')}"
+                )
+        except Exception as e:
+            raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e
+
     async def async_check_production_status(
         self,
         data_fields: list[str],
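The request and response handling in `async_reset_consumption` can be separated from the socket I/O to make its two rules easier to see: `task_name` is omitted from the body when it is `None`, and only a `RESET_CONSUMPTION_RESPONSE` is read for its `success` flag, while any other response type is treated as an error. The helpers below are a hypothetical restatement for illustration (`build_reset_body` and `interpret_reset_response` do not exist in the commit, and a plain string stands in for the `ZMQRequestType` member).

```python
from typing import Any, Dict, Optional

# String stand-in for ZMQRequestType.RESET_CONSUMPTION_RESPONSE.
RESET_RESPONSE = "RESET_CONSUMPTION_RESPONSE"


def build_reset_body(partition_id: str, task_name: Optional[str] = None) -> Dict[str, Any]:
    """Build the request body; leaving out task_name means 'reset all tasks'."""
    body: Dict[str, Any] = {"partition_id": partition_id}
    if task_name is not None:
        body["task_name"] = task_name
    return body


def interpret_reset_response(response_type: str, body: Dict[str, Any]) -> bool:
    """Return the success flag, or raise if the response type is unexpected."""
    if response_type != RESET_RESPONSE:
        raise RuntimeError(f"Failed to reset consumption: {body.get('message', 'Unknown error')}")
    # A missing "success" key is treated as failure, as in the diff.
    return bool(body.get("success", False))


all_tasks_body = build_reset_body("train_0")
one_task_body = build_reset_body("train_0", task_name="training")
accepted = interpret_reset_response(RESET_RESPONSE, {"success": True})
declined = interpret_reset_response(RESET_RESPONSE, {"message": "no such partition"})
```

A `False` return therefore means the controller answered but declined the reset, whereas a `RuntimeError` signals a transport problem or an unexpected response type.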
@@ -917,6 +974,7 @@ def wrapper(*args, **kwargs):
         self._check_production_status = _make_sync(self.async_check_production_status)
         self._get_partition_list = _make_sync(self.async_get_partition_list)
         self._set_custom_meta = _make_sync(self.async_set_custom_meta)
+        self._reset_consumption = _make_sync(self.async_reset_consumption)
 
     def put(
         self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None
@@ -1138,6 +1196,18 @@ def get_consumption_status(
         """
         return self._get_consumption_status(task_name, partition_id)
 
+    def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool:
+        """Synchronously reset consumption status for a partition.
+        This allows the same data to be re-consumed, useful for debugging scenarios
+        where the same rollout data needs to be trained multiple times.
+        Args:
+            partition_id: Partition id to reset consumption status for
+            task_name: Name of the task to reset. If None, resets all tasks.
+        Returns:
+            bool: True if reset was successful, False otherwise
+        """
+        return self._reset_consumption(partition_id, task_name)
+
     def check_production_status(self, data_fields: list[str], partition_id: str) -> bool:
         """Synchronously check if all samples for a partition are ready (produced) for consumption.
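The synchronous `reset_consumption` delegates to `self._reset_consumption`, which is produced by `_make_sync(self.async_reset_consumption)`. The body of `_make_sync` is not visible in this diff, so the sketch below only shows one common way such a wrapper can be written, using `asyncio.run`. `make_sync` and `ToyClient` are hypothetical names, and the project's actual helper may manage its event loop differently.

```python
import asyncio
import functools
from typing import Any, Callable, Coroutine, Optional


def make_sync(async_fn: Callable[..., Coroutine[Any, Any, Any]]) -> Callable[..., Any]:
    """Hypothetical wrapper: run a coroutine function to completion and return its result."""

    @functools.wraps(async_fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # asyncio.run starts a fresh event loop; it raises RuntimeError
        # if called from a thread that already has a running loop.
        return asyncio.run(async_fn(*args, **kwargs))

    return wrapper


class ToyClient:
    """Hypothetical client mirroring the async/sync method pairing in the diff."""

    def __init__(self) -> None:
        self._reset_consumption = make_sync(self.async_reset_consumption)

    async def async_reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool:
        await asyncio.sleep(0)  # placeholder for the socket round trip
        return True

    def reset_consumption(self, partition_id: str, task_name: Optional[str] = None) -> bool:
        return self._reset_consumption(partition_id, task_name)


result = ToyClient().reset_consumption("train_0", task_name="training")
```

One consequence of an `asyncio.run`-style wrapper is that the synchronous method cannot be called from inside a running event loop; code that is already async would use `async_reset_consumption` directly.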
