-
Notifications
You must be signed in to change notification settings - Fork 16
[feat] Add SeqlenBalancedSampler and enhance StreamingDataset support
#70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,7 +63,6 @@ | |
| # Sample pre-allocation for StreamingDataLoader compatibility. | ||
| # By pre-allocating sample indices (typically global_batch_size), consumers can accurately | ||
| # determine consumption status even before producers have generated the samples. | ||
| TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1)) | ||
|
|
||
|
|
||
| class PartitionIndexManager: | ||
|
|
@@ -335,6 +334,7 @@ class DataPartitionStatus: | |
|
|
||
| # Production status tensor - dynamically expandable | ||
| # Values: 0 = not produced, 1 = ready for consumption | ||
| TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1)) | ||
|
|
||
| production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8) | ||
|
|
||
|
|
@@ -1050,6 +1050,8 @@ def create_partition(self, partition_id: str) -> bool: | |
| Returns: | ||
| True if partition was created successfully, False if it already exists | ||
| """ | ||
| TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Will this variable change during training?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, it won’t change during training. However, the previous implementation sometimes missed the environment variable at initialization, and it wasn’t read at runtime.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's quite strange... |
||
|
|
||
| if partition_id in self.partitions: | ||
| logger.warning(f"Partition {partition_id} already exists") | ||
| return False | ||
|
|
@@ -1313,38 +1315,49 @@ def get_metadata( | |
|
|
||
| if len(ready_for_consume_indexes) < batch_size: | ||
| if self.polling_mode: | ||
| logger.debug( | ||
| f"[{self.controller_id}]: Not enough data for task {task_name} in partition {partition_id}." | ||
| f" Required: {batch_size}, Available: {len(ready_for_consume_indexes)}." | ||
| f" Returning None due to polling mode." | ||
| # Return cached result if available | ||
| if self.sampler.has_cached_result(partition_id, task_name, sampling_config): | ||
| break | ||
| else: | ||
| logger.debug( | ||
| f"[{self.controller_id}]: Not enough data for task {task_name} in " | ||
| f"partition {partition_id}. Required: {batch_size}, " | ||
| f"Available: {len(ready_for_consume_indexes)}." | ||
| f" Returning None due to polling mode." | ||
| ) | ||
| return BatchMeta.empty() | ||
| else: | ||
| logger.warning( | ||
| f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} " | ||
| f"samples with fields {data_fields} in partition {partition_id}, but only have " | ||
| f"{len(ready_for_consume_indexes)} samples meeting the criteria. " | ||
| f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..." | ||
| ) | ||
| return BatchMeta.empty() | ||
| time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL) | ||
| if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT: | ||
| raise TimeoutError( | ||
| f"Timeout while waiting for sufficient data for task {task_name}. " | ||
| f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}" | ||
| ) | ||
| logger.warning( | ||
| f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} " | ||
| f"samples with fields {data_fields} in partition {partition_id}, but only have " | ||
| f"{len(ready_for_consume_indexes)} samples meeting the criteria. " | ||
| f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..." | ||
| ) | ||
| time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL) | ||
| else: | ||
| break | ||
|
|
||
| batch_global_indexes, consumed_indexes = self.sampler( | ||
| ready_for_consume_indexes, | ||
| batch_size, | ||
| partition=self._get_partition(partition_id), | ||
| **(sampling_config or {}), | ||
| **kwargs, | ||
| ) | ||
|
|
||
| # Check if we got valid results from the sampler | ||
| if len(batch_global_indexes) != batch_size: | ||
| # Check if we got valid results from the sampler. | ||
| # Some samplers (e.g. SeqlenBalancedSampler) may return variable-size | ||
| # batches per DP rank, so we only check for empty results. | ||
| if len(batch_global_indexes) == 0: | ||
| if self.polling_mode: | ||
| return BatchMeta.empty() | ||
| raise RuntimeError( | ||
| f"Sampler returned insufficient samples. Please check the sampler logic. " | ||
| f"Sampler returned no samples. Please check the sampler logic. " | ||
| f"Expected: {batch_size}, before sampling: {len(ready_for_consume_indexes)}, " | ||
| f"after sampling: {len(batch_global_indexes)}" | ||
| ) | ||
|
|
@@ -1826,7 +1839,7 @@ def _process_request(self): | |
| partition_id=params["partition_id"], | ||
| mode=params.get("mode", "fetch"), | ||
| task_name=params.get("task_name"), | ||
| sampling_config=params.get("sampling_config"), | ||
| sampling_config=params.get("sampling_config", {}), | ||
| ) | ||
|
|
||
| response_msg = ZMQMessage.create( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,16 +17,14 @@ | |
| import os | ||
| import time | ||
| import uuid | ||
| import warnings | ||
| from typing import Callable, Iterator | ||
|
|
||
| from omegaconf import DictConfig | ||
| from tensordict import TensorDict | ||
| from torch.utils.data import IterableDataset | ||
|
|
||
| from transfer_queue import TransferQueueClient | ||
| from transfer_queue.client import TransferQueueClient | ||
| from transfer_queue.metadata import BatchMeta | ||
| from transfer_queue.utils.zmq_utils import ZMQServerInfo | ||
|
|
||
| TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float( | ||
| os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1) | ||
|
|
@@ -77,6 +75,7 @@ def __init__( | |
| partition_id: str, | ||
| task_name: str, | ||
| dp_rank: int, | ||
| should_check_consumption_status: bool = False, | ||
| fetch_batch_fn: Callable | None = None, | ||
| process_batch_fn: Callable | None = None, | ||
| ): | ||
|
|
@@ -98,6 +97,14 @@ def __init__( | |
| which samples have been consumed by which task. | ||
| dp_rank: The group ID of the current data group. All | ||
| ranks with the same dp_rank will receive identical samples. | ||
| should_check_consumption_status: Whether to check the consumption status of the | ||
| partition to decide when to stop iterating. Defaults to ``False``, which | ||
| means the iterator runs as an **infinite stream** — it will continuously | ||
| poll for new data and never exit on its own. This is the typical mode for | ||
| online/streaming training where producers keep feeding data indefinitely. | ||
| Set to ``True`` when the total number of samples is known in advance (i.e. | ||
| finite-dataset mode); the iterator will then stop once all samples in the | ||
| partition have been consumed. | ||
| fetch_batch_fn: Optional custom function to retrieve batch data. | ||
| If None, uses default_fetch_batch_fn function. | ||
| process_batch_fn: Optional custom function to post-process | ||
|
|
@@ -123,6 +130,7 @@ def __init__( | |
| self.partition_id = partition_id | ||
| self.task_name = task_name | ||
| self.dp_rank = dp_rank | ||
| self.should_check_consumption_status = should_check_consumption_status | ||
| self.fetch_batch_fn = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn | ||
| self.process_batch_fn = process_batch_fn if process_batch_fn else chunk_batch_fn | ||
|
|
||
|
|
@@ -149,65 +157,45 @@ def __init__( | |
| super().__init__() | ||
|
|
||
| def _create_client(self): | ||
| """Create and initialize a TransferQueue client. | ||
|
|
||
| This method initializes the TransferQueueClient with the provided configuration | ||
| and storage backend, and sets up the storage manager for data retrieval. | ||
|
|
||
| Raises: | ||
| ValueError: If controller_info or storage_backend is missing or invalid. | ||
| """Create and initialize a TransferQueue client directly from config. | ||
|
|
||
| This method creates a ``TransferQueueClient`` using the ZMQ address and | ||
| storage backend information already present in ``self.config``. It | ||
| intentionally does **not** call ``tq.init()`` because that relies on Ray | ||
| internally (``ray.get_actor`` / ``ray.get``), which is **unsafe in | ||
| forked subprocesses** spawned by PyTorch DataLoader (``num_workers > 0``). | ||
| Creating the client directly via ZMQ avoids this issue. | ||
| """ | ||
| client_id = uuid.uuid4().hex[:8] | ||
|
|
||
| # TODO: DEPRECATE in future | ||
| controller_config = self.config.get("controller", None) | ||
| if controller_config: | ||
| controller_info = controller_config.get("zmq_info", None) | ||
| else: | ||
| controller_info = self.config.get("controller_info", None) | ||
| if controller_info: | ||
| warnings.warn( | ||
| "Config entry `controller_info` will be deprecated in 0.1.7, please " | ||
| "use `controller.zmq_info` instead.", | ||
| category=DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
|
|
||
| if not controller_info or not isinstance(controller_info, ZMQServerInfo): | ||
| raise ValueError("Invalid or missing controller.zmq_info in config") | ||
|
|
||
| backend_config = self.config.get("backend", None) | ||
| if not backend_config: | ||
| storage_backend = self.config.get("storage_backend", None) | ||
| backend_config = self.config | ||
| if storage_backend: | ||
| warnings.warn( | ||
| "Config entry `storage_backend` will be deprecated in 0.1.7, please " | ||
| "use `backend.storage_backend` instead.", | ||
| category=DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| else: | ||
| storage_backend = backend_config.get("storage_backend", None) | ||
| backend_config = self.config.backend[storage_backend] | ||
|
|
||
| if not storage_backend: | ||
| raise ValueError("Missing storage_backend in config") | ||
| client_id = f"StreamingDataset_{uuid.uuid4().hex[:8]}" | ||
|
|
||
| controller_info = self.config.controller.zmq_info | ||
| storage_backend = self.config.backend.storage_backend | ||
| backend_config = self.config.backend[storage_backend] | ||
|
|
||
| self._tq_client = TransferQueueClient(client_id, controller_info) | ||
| self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config) | ||
|
|
||
| def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: | ||
| """Iterate over the dataset, yielding batches of data. | ||
|
|
||
| The iteration behaviour depends on ``should_check_consumption_status``: | ||
|
|
||
| - **False (default — streaming mode)**: The iterator runs as an | ||
| infinite stream, continuously polling TransferQueue for new data. | ||
| It will sleep for `TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL` seconds | ||
| (default=1) when no data is available and | ||
| resume once new batches are produced. This is the standard mode for | ||
| online / streaming training pipelines where producers feed data | ||
| indefinitely. | ||
| - **True (finite-dataset mode)**: The iterator terminates once all | ||
| samples in the partition have been consumed (as reported by | ||
| ``check_consumption_status``), *and* all buffered batches have been | ||
| yielded. | ||
|
|
||
| Yields: | ||
| Tuple[TensorDict, BatchMeta]: A tuple containing: | ||
| - TensorDict: Batch of data with the requested fields. | ||
| - BatchMeta: Corresponding metadata to interact with TransferQueue. | ||
| Note: | ||
| This iterator runs indefinitely until the data source is exhausted. | ||
| The caller should handle StopIteration when appropriate (e.g., when | ||
| all data has been consumed and no more data will be produced). | ||
| """ | ||
| if self._tq_client is None: | ||
| self._create_client() | ||
|
|
@@ -218,24 +206,26 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: | |
| # TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately | ||
| # determine consumption status even before producers have generated the samples. | ||
| while ( | ||
| not self._tq_client.check_consumption_status(self.task_name, self.partition_id) | ||
| not self.should_check_consumption_status | ||
| or not self._tq_client.check_consumption_status(self.task_name, self.partition_id) | ||
| or self.batch_index <= len(self.buffer) - 1 | ||
| ): | ||
| try: | ||
| if self.batch_index <= len(self.buffer) - 1: | ||
| current_data = self.buffer[self.batch_index] | ||
| self.batch_index += 1 | ||
| logger.debug(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}") | ||
| yield from self.process_batch_fn(*current_data, micro_batch_size=self.micro_batch_size) | ||
|
|
||
| else: | ||
| batch_data, batch_meta = self.fetch_batch_fn( | ||
| self._tq_client, | ||
| self.data_fields, | ||
| self.batch_size, | ||
| self.partition_id, | ||
| self.task_name, | ||
| self.sampling_config, | ||
| self.batch_index, | ||
| tq_client=self._tq_client, | ||
| data_fields=self.data_fields, | ||
| batch_size=self.batch_size, | ||
| partition_id=self.partition_id, | ||
| task_name=self.task_name, | ||
| sampling_config=self.sampling_config, | ||
| batch_index=self.batch_index, | ||
| ) | ||
|
Comment on lines
221
to
229
|
||
| if batch_data is not None: | ||
| self.buffer.append((batch_data, batch_meta)) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`production_status: Tensor = torch.zeros(...)` is a mutable dataclass default, so all `DataPartitionStatus` instances will share the same underlying tensor unless it gets reassigned. Because methods like `update_production_status` mutate `self.production_status` in-place, this can cause cross-partition contamination. Use `field(default_factory=lambda: torch.zeros(...))` (and similarly for any other tensor defaults) so each partition gets its own tensor instance.