Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
468 changes: 466 additions & 2 deletions tests/test_samplers.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions transfer_queue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .sampler import BaseSampler
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
from .sampler.rank_aware_sampler import RankAwareSampler
from .sampler.seqlen_balanced_sampler import SeqlenBalancedSampler
from .sampler.sequential_sampler import SequentialSampler

__all__ = (
Expand Down Expand Up @@ -76,6 +77,7 @@
"GRPOGroupNSampler",
"SequentialSampler",
"RankAwareSampler",
"SeqlenBalancedSampler",
]
)

Expand Down
2 changes: 1 addition & 1 deletion transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ async def async_put(

for field_name, field_data in data.items():
if isinstance(field_data, torch.Tensor) and field_data.ndim == 1:
logger.warning(
logger.info(
f"[{self.client_id}]: Data field '{field_name}' is a tensor with only one dimension. "
f"You may receive 2D tensors in key-value based backend."
)
Expand Down
47 changes: 30 additions & 17 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
# Sample pre-allocation for StreamingDataLoader compatibility.
# By pre-allocating sample indices (typically global_batch_size), consumers can accurately
# determine consumption status even before producers have generated the samples.
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))


class PartitionIndexManager:
Expand Down Expand Up @@ -335,6 +334,7 @@ class DataPartitionStatus:

# Production status tensor - dynamically expandable
# Values: 0 = not produced, 1 = ready for consumption
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

production_status: Tensor = torch.zeros(...) is a mutable dataclass default, so all DataPartitionStatus instances will share the same underlying tensor unless it gets reassigned. Because methods like update_production_status mutate self.production_status in-place, this can cause cross-partition contamination. Use field(default_factory=lambda: torch.zeros(...)) (and similarly for any other tensor defaults) so each partition gets its own tensor instance.

Suggested change
production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
production_status: Tensor = field(
default_factory=lambda: torch.zeros(
DataPartitionStatus.TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8
)
)

Copilot uses AI. Check for mistakes.

Expand Down Expand Up @@ -1050,6 +1050,8 @@ def create_partition(self, partition_id: str) -> bool:
Returns:
True if partition was created successfully, False if it already exists
"""
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this variable change during training?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it won’t change during training. However, the previous implementation sometimes missed the environment variable at initialization, and it wasn’t read at runtime.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's quite strange...


if partition_id in self.partitions:
logger.warning(f"Partition {partition_id} already exists")
return False
Expand Down Expand Up @@ -1313,38 +1315,49 @@ def get_metadata(

if len(ready_for_consume_indexes) < batch_size:
if self.polling_mode:
logger.debug(
f"[{self.controller_id}]: Not enough data for task {task_name} in partition {partition_id}."
f" Required: {batch_size}, Available: {len(ready_for_consume_indexes)}."
f" Returning None due to polling mode."
# Return cached result if available
if self.sampler.has_cached_result(partition_id, task_name, sampling_config):
break
else:
logger.debug(
f"[{self.controller_id}]: Not enough data for task {task_name} in "
f"partition {partition_id}. Required: {batch_size}, "
f"Available: {len(ready_for_consume_indexes)}."
f" Returning None due to polling mode."
)
return BatchMeta.empty()
else:
logger.warning(
f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} "
f"samples with fields {data_fields} in partition {partition_id}, but only have "
f"{len(ready_for_consume_indexes)} samples meeting the criteria. "
f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
)
return BatchMeta.empty()
time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT:
raise TimeoutError(
f"Timeout while waiting for sufficient data for task {task_name}. "
f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}"
)
logger.warning(
f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} "
f"samples with fields {data_fields} in partition {partition_id}, but only have "
f"{len(ready_for_consume_indexes)} samples meeting the criteria. "
f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
)
time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
else:
break

batch_global_indexes, consumed_indexes = self.sampler(
ready_for_consume_indexes,
batch_size,
partition=self._get_partition(partition_id),
**(sampling_config or {}),
**kwargs,
)

# Check if we got valid results from the sampler
if len(batch_global_indexes) != batch_size:
# Check if we got valid results from the sampler.
# Some samplers (e.g. SeqlenBalancedSampler) may return variable-size
# batches per DP rank, so we only check for empty results.
if len(batch_global_indexes) == 0:
if self.polling_mode:
return BatchMeta.empty()
raise RuntimeError(
f"Sampler returned insufficient samples. Please check the sampler logic. "
f"Sampler returned no samples. Please check the sampler logic. "
f"Expected: {batch_size}, before sampling: {len(ready_for_consume_indexes)}, "
f"after sampling: {len(batch_global_indexes)}"
)
Expand Down Expand Up @@ -1826,7 +1839,7 @@ def _process_request(self):
partition_id=params["partition_id"],
mode=params.get("mode", "fetch"),
task_name=params.get("task_name"),
sampling_config=params.get("sampling_config"),
sampling_config=params.get("sampling_config", {}),
)

response_msg = ZMQMessage.create(
Expand Down
106 changes: 48 additions & 58 deletions transfer_queue/dataloader/streaming_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@
import os
import time
import uuid
import warnings
from typing import Callable, Iterator

from omegaconf import DictConfig
from tensordict import TensorDict
from torch.utils.data import IterableDataset

from transfer_queue import TransferQueueClient
from transfer_queue.client import TransferQueueClient
from transfer_queue.metadata import BatchMeta
from transfer_queue.utils.zmq_utils import ZMQServerInfo

TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float(
os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1)
Expand Down Expand Up @@ -77,6 +75,7 @@ def __init__(
partition_id: str,
task_name: str,
dp_rank: int,
should_check_consumption_status: bool = False,
fetch_batch_fn: Callable | None = None,
process_batch_fn: Callable | None = None,
):
Expand All @@ -98,6 +97,14 @@ def __init__(
which samples have been consumed by which task.
dp_rank: The group ID of the current data group. All
ranks with the same dp_rank will receive identical samples.
should_check_consumption_status: Whether to check the consumption status of the
partition to decide when to stop iterating. Defaults to ``False``, which
means the iterator runs as an **infinite stream** — it will continuously
poll for new data and never exit on its own. This is the typical mode for
online/streaming training where producers keep feeding data indefinitely.
Set to ``True`` when the total number of samples is known in advance (i.e.
finite-dataset mode); the iterator will then stop once all samples in the
partition have been consumed.
fetch_batch_fn: Optional custom function to retrieve batch data.
If None, uses default_fetch_batch_fn function.
process_batch_fn: Optional custom function to post-process
Expand All @@ -123,6 +130,7 @@ def __init__(
self.partition_id = partition_id
self.task_name = task_name
self.dp_rank = dp_rank
self.should_check_consumption_status = should_check_consumption_status
self.fetch_batch_fn = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn
self.process_batch_fn = process_batch_fn if process_batch_fn else chunk_batch_fn

Expand All @@ -149,65 +157,45 @@ def __init__(
super().__init__()

def _create_client(self):
"""Create and initialize a TransferQueue client.

This method initializes the TransferQueueClient with the provided configuration
and storage backend, and sets up the storage manager for data retrieval.

Raises:
ValueError: If controller_info or storage_backend is missing or invalid.
"""Create and initialize a TransferQueue client directly from config.

This method creates a ``TransferQueueClient`` using the ZMQ address and
storage backend information already present in ``self.config``. It
intentionally does **not** call ``tq.init()`` because that relies on Ray
internally (``ray.get_actor`` / ``ray.get``), which is **unsafe in
forked subprocesses** spawned by PyTorch DataLoader (``num_workers > 0``).
Creating the client directly via ZMQ avoids this issue.
"""
client_id = uuid.uuid4().hex[:8]

# TODO: DEPRECATE in future
controller_config = self.config.get("controller", None)
if controller_config:
controller_info = controller_config.get("zmq_info", None)
else:
controller_info = self.config.get("controller_info", None)
if controller_info:
warnings.warn(
"Config entry `controller_info` will be deprecated in 0.1.7, please "
"use `controller.zmq_info` instead.",
category=DeprecationWarning,
stacklevel=2,
)

if not controller_info or not isinstance(controller_info, ZMQServerInfo):
raise ValueError("Invalid or missing controller.zmq_info in config")

backend_config = self.config.get("backend", None)
if not backend_config:
storage_backend = self.config.get("storage_backend", None)
backend_config = self.config
if storage_backend:
warnings.warn(
"Config entry `storage_backend` will be deprecated in 0.1.7, please "
"use `backend.storage_backend` instead.",
category=DeprecationWarning,
stacklevel=2,
)
else:
storage_backend = backend_config.get("storage_backend", None)
backend_config = self.config.backend[storage_backend]

if not storage_backend:
raise ValueError("Missing storage_backend in config")
client_id = f"StreamingDataset_{uuid.uuid4().hex[:8]}"

controller_info = self.config.controller.zmq_info
storage_backend = self.config.backend.storage_backend
backend_config = self.config.backend[storage_backend]

self._tq_client = TransferQueueClient(client_id, controller_info)
self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config)

def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
"""Iterate over the dataset, yielding batches of data.

The iteration behaviour depends on ``should_check_consumption_status``:

- **False (default — streaming mode)**: The iterator runs as an
infinite stream, continuously polling TransferQueue for new data.
It will sleep for `TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL` seconds
(default=1) when no data is available and
resume once new batches are produced. This is the standard mode for
online / streaming training pipelines where producers feed data
indefinitely.
- **True (finite-dataset mode)**: The iterator terminates once all
samples in the partition have been consumed (as reported by
``check_consumption_status``), *and* all buffered batches have been
yielded.

Yields:
Tuple[TensorDict, BatchMeta]: A tuple containing:
- TensorDict: Batch of data with the requested fields.
- BatchMeta: Corresponding metadata to interact with TransferQueue.
Note:
This iterator runs indefinitely until the data source is exhausted.
The caller should handle StopIteration when appropriate (e.g., when
all data has been consumed and no more data will be produced).
"""
if self._tq_client is None:
self._create_client()
Expand All @@ -218,24 +206,26 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
# TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately
# determine consumption status even before producers have generated the samples.
while (
not self._tq_client.check_consumption_status(self.task_name, self.partition_id)
not self.should_check_consumption_status
or not self._tq_client.check_consumption_status(self.task_name, self.partition_id)
or self.batch_index <= len(self.buffer) - 1
):
try:
if self.batch_index <= len(self.buffer) - 1:
current_data = self.buffer[self.batch_index]
self.batch_index += 1
logger.debug(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
yield from self.process_batch_fn(*current_data, micro_batch_size=self.micro_batch_size)

else:
batch_data, batch_meta = self.fetch_batch_fn(
self._tq_client,
self.data_fields,
self.batch_size,
self.partition_id,
self.task_name,
self.sampling_config,
self.batch_index,
tq_client=self._tq_client,
data_fields=self.data_fields,
batch_size=self.batch_size,
partition_id=self.partition_id,
task_name=self.task_name,
sampling_config=self.sampling_config,
batch_index=self.batch_index,
)
Comment on lines 221 to 229
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fetch_batch_fn is now invoked with keyword arguments. This can be a breaking change for users who passed a custom fetch_batch_fn that accepts only positional parameters (the previous call-site was positional). To preserve backward compatibility, consider calling it positionally (or supporting both via a small adapter / try-except) while keeping the documented parameter order.

Copilot uses AI. Check for mistakes.
if batch_data is not None:
self.buffer.append((batch_data, batch_meta))
Expand Down
9 changes: 6 additions & 3 deletions transfer_queue/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _init_from_existing() -> bool:


# ==================== Initialization API ====================
def init(conf: Optional[DictConfig] = None) -> None:
def init(conf: Optional[DictConfig] = None) -> Optional[DictConfig]:
"""Initialize the TransferQueue system.

This function sets up the TransferQueue controller, distributed storage, and client.
Expand All @@ -234,6 +234,8 @@ def init(conf: Optional[DictConfig] = None) -> None:
the default config from 'config.yaml'. This is only used for first-time
initializing. When connecting to an existing controller, this parameter
is ignored.
Returns:
The merged configuration dictionary.

Raises:
ValueError: If config is not valid or required configuration keys are missing.
Expand All @@ -251,7 +253,7 @@ def init(conf: Optional[DictConfig] = None) -> None:
>>> data = tq.get_data(metadata)
"""
if _init_from_existing():
return
return conf

# First-time initialize TransferQueue
logger.info("No TransferQueueController found. Starting first-time initialization...")
Expand Down Expand Up @@ -289,7 +291,7 @@ def init(conf: Optional[DictConfig] = None) -> None:
except ValueError:
logger.info("Some other rank has initialized TransferQueueController. Try to connect to existing controller.")
_init_from_existing()
return
return final_conf

controller_zmq_info = process_zmq_server_info(_TRANSFER_QUEUE_CONTROLLER)
final_conf.controller.zmq_info = controller_zmq_info
Expand All @@ -303,6 +305,7 @@ def init(conf: Optional[DictConfig] = None) -> None:

# create client
_maybe_create_transferqueue_client(final_conf)
return final_conf


def close():
Expand Down
2 changes: 1 addition & 1 deletion transfer_queue/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]:
f"expected batch_size[0]={batch_size}, got value.shape[0]={value.shape[0]}"
)
if len(value.shape) == 1:
logger.warning(f"Receiving 1D tensor for field '{field_name}'. Unsqueeze the last dimension.")
logger.info(f"Receiving 1D tensor for field '{field_name}'. Unsqueeze the last dimension.")
value = value.unsqueeze(-1)
first_item = value[0]
else:
Expand Down
3 changes: 2 additions & 1 deletion transfer_queue/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .base import BaseSampler
from .grpo_group_n_sampler import GRPOGroupNSampler
from .rank_aware_sampler import RankAwareSampler
from .seqlen_balanced_sampler import SeqlenBalancedSampler
from .sequential_sampler import SequentialSampler

__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler"]
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler", "SeqlenBalancedSampler"]
Loading
Loading