Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,6 @@ batch_meta = client.get_meta(
batch_size=8,
partition_id="train_0",
task_name="generate_sequences",
sampling_config={"n_samples_per_prompt": 4} # Put the required sampling parameters here
)
```

Expand Down
1 change: 0 additions & 1 deletion recipe/simple_use_case/async_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def _initialize_data_system(self):
# self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)

# Then use sampling_config in get_meta calls:
# sampling_config={"n_samples_per_prompt": 4}
self.data_system_controller = TransferQueueController.remote()
logger.info("TransferQueueController has been created.")

Expand Down
1 change: 0 additions & 1 deletion recipe/simple_use_case/sync_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def initialize_data_system(config):
# data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)

# Then use sampling_config in get_meta calls:
# sampling_config={"n_samples_per_prompt": 4}
data_system_controller = TransferQueueController.remote()
logger.info("TransferQueueController has been created.")

Expand Down
415 changes: 173 additions & 242 deletions tests/test_samplers.py

Large diffs are not rendered by default.

5 changes: 1 addition & 4 deletions transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ async def async_get_meta(
- 'insert': Internal usage - should not be used by users
task_name: Optional task name associated with the request
sampling_config: Optional sampling configuration for custom samplers.
For GRPOGroupNSampler, should include "n_samples_per_prompt": int
socket: ZMQ async socket for message transmission (injected by decorator)

Returns:
Expand All @@ -206,7 +205,6 @@ async def async_get_meta(
... partition_id="train_0",
... mode="fetch",
... task_name="generate_sequences",
... sampling_config={"n_samples_per_prompt": 4}
... ))
>>> print(batch_meta.is_ready) # True if all samples ready
>>>
Expand Down Expand Up @@ -698,7 +696,7 @@ async def async_check_consumption_status(
partition_id=partition_id,
)

if consumption_status is None:
if consumption_status is None or consumption_status.numel() == 0:
return False
return torch.all(consumption_status == 1).item()

Expand Down Expand Up @@ -883,7 +881,6 @@ def get_meta(
partition_id: Target data partition id
task_name: Optional task name associated with the request
sampling_config: Optional sampling configuration for custom samplers.
For GRPOGroupNSampler, should include "n_samples_per_prompt": int

Returns:
BatchMeta: Batch metadata containing data location information
Expand Down
10 changes: 6 additions & 4 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ def __init__(
- If a BaseSampler subclass is provided, it will be instantiated
- Defaults to SequentialSampler for simple sequential sampling
- Example: sampler=GRPOGroupNSampler() (instance)
- Example: sampler=GRPOGroupNSampler (class)
- Example: sampler=SequentialSampler (class)
polling_mode: Whether to use polling mode for TransferQueue controller.
- If False, the controller will raise an error when no enough data is available.
- If True, the controller will return an empty BatchMeta when no enough data is available.
Expand Down Expand Up @@ -1015,12 +1015,12 @@ def get_metadata(
Raises:
TimeoutError: If waiting for sufficient data times out in fetch mode
"""
if partition_id not in self.partitions:
self.create_partition(partition_id)

if mode == "insert":
partition = self._get_partition(partition_id)
if partition_id not in self.partitions:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe for other modes when user doesn't specify partition_id?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I've checked—when partition_id is not provided, it defaults to returning an empty value, which should not cause any issues. I plan to rely on the pipeline to ensure no additional problems are introduced.

self.create_partition(partition_id)

partition = self._get_partition(partition_id)
if data_fields:
# This is called during put_data call without providing metadata.
# try to use pre-allocated global index first
Expand Down Expand Up @@ -1083,6 +1083,7 @@ def get_metadata(
ready_for_consume_indexes,
batch_size,
**(sampling_config or {}),
**kwargs,
)

# Check if we got valid results from the sampler
Expand Down Expand Up @@ -1240,6 +1241,7 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True):
partition.clear_data(global_indexes_range, clear_consumption)
self.index_manager.release_partition(partition_id)
self.partitions.pop(partition_id)
self.sampler.clear_cache(partition_id)

def clear_meta(
self,
Expand Down
30 changes: 30 additions & 0 deletions transfer_queue/dataloader/streaming_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
parameter in PyTorch DataLoader is set to None because batching is managed
by the StreamingDataset in coordination with RankAwareSampler.
"""
self.dataset: StreamingDataset = dataset

if collate_fn is None:
# use identical collate function to directly return the self-defined
Expand All @@ -137,3 +138,32 @@ def __init__(
persistent_workers=persistent_workers,
pin_memory_device=pin_memory_device,
)

def reset(self):
    """Rewind the loader so iteration starts from the first batch again.

    Delegates to ``StreamingDataset.reset()``, which discards any
    buffered batches and returns the internal batch index to zero.
    """
    self.dataset.reset()

def step(self, partition_id):
    """Retarget the loader at another partition and restart its state.

    The underlying ``StreamingDataset`` clears its buffer, resets the
    batch index, and records the new partition so subsequent batches
    are fetched from it (e.g. moving from "train" to "val").

    Args:
        partition_id: Identifier of the partition to read from next.
    """
    self.dataset.step(partition_id)

def get_buffer(self):
    """Expose the prefetch buffer held by the underlying dataset.

    Returns:
        list: The pre-fetched (TensorDict, BatchMeta) tuples currently
        stored by the ``StreamingDataset``.
    """
    return self.dataset.buffer
Loading