Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
489 changes: 487 additions & 2 deletions tests/test_samplers.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions transfer_queue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .sampler import BaseSampler
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
from .sampler.rank_aware_sampler import RankAwareSampler
from .sampler.seqlen_balanced_sampler import SeqlenBalancedSampler
from .sampler.sequential_sampler import SequentialSampler

__all__ = (
Expand Down Expand Up @@ -76,6 +77,7 @@
"GRPOGroupNSampler",
"SequentialSampler",
"RankAwareSampler",
"SeqlenBalancedSampler",
]
)

Expand Down
2 changes: 1 addition & 1 deletion transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ async def async_put(

for field_name, field_data in data.items():
if isinstance(field_data, torch.Tensor) and field_data.ndim == 1:
logger.warning(
logger.info(
f"[{self.client_id}]: Data field '{field_name}' is a tensor with only one dimension. "
f"You may receive 2D tensors in key-value based backend."
)
Expand Down
52 changes: 35 additions & 17 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
# Sample pre-allocation for StreamingDataLoader compatibility.
# By pre-allocating sample indices (typically global_batch_size), consumers can accurately
# determine consumption status even before producers have generated the samples.
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))


class PartitionIndexManager:
Expand Down Expand Up @@ -335,6 +334,7 @@ class DataPartitionStatus:

# Production status tensor - dynamically expandable
# Values: 0 = not produced, 1 = ready for consumption
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

production_status is initialized with a concrete torch.zeros(...) value at class definition time, so all DataPartitionStatus instances will share the same tensor object. Any in-place updates to one partition’s production_status can affect other partitions. Use a per-instance field(default_factory=...) to create a new tensor for each partition.

Suggested change
production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
production_status: Tensor = field(
default_factory=lambda: torch.zeros(
int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1)),
1,
dtype=torch.int8,
)
)

Copilot uses AI. Check for mistakes.

Expand Down Expand Up @@ -1050,6 +1050,8 @@ def create_partition(self, partition_id: str) -> bool:
Returns:
True if partition was created successfully, False if it already exists
"""
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

if partition_id in self.partitions:
logger.warning(f"Partition {partition_id} already exists")
return False
Expand Down Expand Up @@ -1313,38 +1315,54 @@ def get_metadata(

if len(ready_for_consume_indexes) < batch_size:
if self.polling_mode:
logger.debug(
f"[{self.controller_id}]: Not enough data for task {task_name} in partition {partition_id}."
f" Required: {batch_size}, Available: {len(ready_for_consume_indexes)}."
f" Returning None due to polling mode."
sampling_config = sampling_config or {}
states = self.sampler._states.get(partition_id, {}).get(task_name, {})
dp_rank = sampling_config.get("dp_rank", None)
batch_index = sampling_config.get("batch_index", None)

# Return cached result if available
if dp_rank is not None and dp_rank in states and batch_index in states[dp_rank]:
break
else:
logger.debug(
f"[{self.controller_id}]: Not enough data for task {task_name} in "
f"partition {partition_id}. Required: {batch_size}, "
f"Available: {len(ready_for_consume_indexes)}."
f" Returning None due to polling mode."
)
return BatchMeta.empty()
else:
logger.warning(
f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} "
f"samples with fields {data_fields} in partition {partition_id}, but only have "
f"{len(ready_for_consume_indexes)} samples meeting the criteria. "
f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
)
return BatchMeta.empty()
time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT:
raise TimeoutError(
f"Timeout while waiting for sufficient data for task {task_name}. "
f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}"
)
logger.warning(
f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} "
f"samples with fields {data_fields} in partition {partition_id}, but only have "
f"{len(ready_for_consume_indexes)} samples meeting the criteria. "
f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
)
time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
else:
break

batch_global_indexes, consumed_indexes = self.sampler(
ready_for_consume_indexes,
batch_size,
partition=self._get_partition(partition_id),
**(sampling_config or {}),
**kwargs,
)

# Check if we got valid results from the sampler
if len(batch_global_indexes) != batch_size:
# Check if we got valid results from the sampler.
# Some samplers (e.g. SeqlenBalancedSampler) may return variable-size
# batches per DP rank, so we only check for empty results.
if len(batch_global_indexes) == 0:
if self.polling_mode:
return BatchMeta.empty()
raise RuntimeError(
f"Sampler returned insufficient samples. Please check the sampler logic. "
f"Sampler returned no samples. Please check the sampler logic. "
f"Expected: {batch_size}, before sampling: {len(ready_for_consume_indexes)}, "
f"after sampling: {len(batch_global_indexes)}"
)
Expand Down Expand Up @@ -1826,7 +1844,7 @@ def _process_request(self):
partition_id=params["partition_id"],
mode=params.get("mode", "fetch"),
task_name=params.get("task_name"),
sampling_config=params.get("sampling_config"),
sampling_config=params.get("sampling_config", {}),
)

response_msg = ZMQMessage.create(
Expand Down
98 changes: 38 additions & 60 deletions transfer_queue/dataloader/streaming_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,14 @@
import logging
import os
import time
import uuid
import warnings
from typing import Callable, Iterator

from omegaconf import DictConfig
from tensordict import TensorDict
from torch.utils.data import IterableDataset

from transfer_queue import TransferQueueClient
from transfer_queue.interface import get_client, init
from transfer_queue.metadata import BatchMeta
from transfer_queue.utils.zmq_utils import ZMQServerInfo

TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float(
os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1)
Expand Down Expand Up @@ -77,6 +74,7 @@ def __init__(
partition_id: str,
task_name: str,
dp_rank: int,
should_check_consumption_status: bool = False,
fetch_batch_fn: Callable | None = None,
process_batch_fn: Callable | None = None,
):
Expand All @@ -98,6 +96,14 @@ def __init__(
which samples have been consumed by which task.
dp_rank: The group ID of the current data group. All
ranks with the same dp_rank will receive identical samples.
should_check_consumption_status: Whether to check the consumption status of the
partition to decide when to stop iterating. Defaults to ``False``, which
means the iterator runs as an **infinite stream** — it will continuously
poll for new data and never exit on its own. This is the typical mode for
online/streaming training where producers keep feeding data indefinitely.
Set to ``True`` when the total number of samples is known in advance (i.e.
finite-dataset mode); the iterator will then stop once all samples in the
partition have been consumed.
fetch_batch_fn: Optional custom function to retrieve batch data.
If None, uses default_fetch_batch_fn function.
process_batch_fn: Optional custom function to post-process
Expand All @@ -123,6 +129,7 @@ def __init__(
self.partition_id = partition_id
self.task_name = task_name
self.dp_rank = dp_rank
self.should_check_consumption_status = should_check_consumption_status
self.fetch_batch_fn = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn
self.process_batch_fn = process_batch_fn if process_batch_fn else chunk_batch_fn

Expand Down Expand Up @@ -151,63 +158,32 @@ def __init__(
def _create_client(self):
"""Create and initialize a TransferQueue client.

This method initializes the TransferQueueClient with the provided configuration
and storage backend, and sets up the storage manager for data retrieval.

Raises:
ValueError: If controller_info or storage_backend is missing or invalid.
This method initializes the TransferQueueClient with the provided configuration.
"""
client_id = uuid.uuid4().hex[:8]

# TODO: DEPRECATE in future
controller_config = self.config.get("controller", None)
if controller_config:
controller_info = controller_config.get("zmq_info", None)
else:
controller_info = self.config.get("controller_info", None)
if controller_info:
warnings.warn(
"Config entry `controller_info` will be deprecated in 0.1.7, please "
"use `controller.zmq_info` instead.",
category=DeprecationWarning,
stacklevel=2,
)

if not controller_info or not isinstance(controller_info, ZMQServerInfo):
raise ValueError("Invalid or missing controller.zmq_info in config")

backend_config = self.config.get("backend", None)
if not backend_config:
storage_backend = self.config.get("storage_backend", None)
backend_config = self.config
if storage_backend:
warnings.warn(
"Config entry `storage_backend` will be deprecated in 0.1.7, please "
"use `backend.storage_backend` instead.",
category=DeprecationWarning,
stacklevel=2,
)
else:
storage_backend = backend_config.get("storage_backend", None)
backend_config = self.config.backend[storage_backend]

if not storage_backend:
raise ValueError("Missing storage_backend in config")

self._tq_client = TransferQueueClient(client_id, controller_info)
self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config)

init(self.config)
self._tq_client = get_client()

Comment on lines 162 to 166
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_create_client() now uses init(self.config) / get_client(), but StreamingDataset.__init__ still documents config keys like controller.controller_info (and other deprecated entries). Consider updating that docstring to reflect the current initialization path and required config fields (e.g., controller.zmq_info, backend.storage_backend, etc.) so users don’t follow outdated guidance.

Copilot uses AI. Check for mistakes.
def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
"""Iterate over the dataset, yielding batches of data.

The iteration behaviour depends on ``should_check_consumption_status``:

- **False (default — streaming mode)**: The iterator runs as an
infinite stream, continuously polling TransferQueue for new data.
It will block (with a 1-second sleep) when no data is available and
resume once new batches are produced. This is the standard mode for
online / streaming training pipelines where producers feed data
indefinitely.
- **True (finite-dataset mode)**: The iterator terminates once all
samples in the partition have been consumed (as reported by
``check_consumption_status``), *and* all buffered batches have been
yielded.

Yields:
Tuple[TensorDict, BatchMeta]: A tuple containing:
- TensorDict: Batch of data with the requested fields.
- BatchMeta: Corresponding metadata to interact with TransferQueue.
Note:
This iterator runs indefinitely until the data source is exhausted.
The caller should handle StopIteration when appropriate (e.g., when
all data has been consumed and no more data will be produced).
"""
if self._tq_client is None:
self._create_client()
Expand All @@ -218,24 +194,26 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
# TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately
# determine consumption status even before producers have generated the samples.
while (
not self._tq_client.check_consumption_status(self.task_name, self.partition_id)
not self.should_check_consumption_status
or not self._tq_client.check_consumption_status(self.task_name, self.partition_id)
or self.batch_index <= len(self.buffer) - 1
):
try:
if self.batch_index <= len(self.buffer) - 1:
current_data = self.buffer[self.batch_index]
self.batch_index += 1
logger.info(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This per-microbatch logger.info runs on every yielded batch and can generate extremely high log volume in long-running streaming training jobs. Consider downgrading this to debug (or gating it behind a configurable log interval) to avoid performance/log-storage impact.

Suggested change
logger.info(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
logger.debug(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")

Copilot uses AI. Check for mistakes.
yield from self.process_batch_fn(*current_data, micro_batch_size=self.micro_batch_size)

else:
batch_data, batch_meta = self.fetch_batch_fn(
self._tq_client,
self.data_fields,
self.batch_size,
self.partition_id,
self.task_name,
self.sampling_config,
self.batch_index,
tq_client=self._tq_client,
data_fields=self.data_fields,
batch_size=self.batch_size,
partition_id=self.partition_id,
task_name=self.task_name,
sampling_config=self.sampling_config,
batch_index=self.batch_index,
)
if batch_data is not None:
self.buffer.append((batch_data, batch_meta))
Expand Down
2 changes: 1 addition & 1 deletion transfer_queue/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]:
f"expected batch_size[0]={batch_size}, got value.shape[0]={value.shape[0]}"
)
if len(value.shape) == 1:
logger.warning(f"Receiving 1D tensor for field '{field_name}'. Unsqueeze the last dimension.")
logger.info(f"Receiving 1D tensor for field '{field_name}'. Unsqueeze the last dimension.")
value = value.unsqueeze(-1)
first_item = value[0]
else:
Expand Down
3 changes: 2 additions & 1 deletion transfer_queue/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .base import BaseSampler
from .grpo_group_n_sampler import GRPOGroupNSampler
from .rank_aware_sampler import RankAwareSampler
from .seqlen_balanced_sampler import SeqlenBalancedSampler
from .sequential_sampler import SequentialSampler

__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler"]
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler", "SeqlenBalancedSampler"]
Loading
Loading