-
Notifications
You must be signed in to change notification settings - Fork 17
[Feat] Add SeqlenBalancedSampler and enhance StreamingDataset support #69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -16,17 +16,14 @@ | |||||
| import logging | ||||||
| import os | ||||||
| import time | ||||||
| import uuid | ||||||
| import warnings | ||||||
| from typing import Callable, Iterator | ||||||
|
|
||||||
| from omegaconf import DictConfig | ||||||
| from tensordict import TensorDict | ||||||
| from torch.utils.data import IterableDataset | ||||||
|
|
||||||
| from transfer_queue import TransferQueueClient | ||||||
| from transfer_queue.interface import get_client, init | ||||||
| from transfer_queue.metadata import BatchMeta | ||||||
| from transfer_queue.utils.zmq_utils import ZMQServerInfo | ||||||
|
|
||||||
| TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float( | ||||||
| os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1) | ||||||
|
|
@@ -77,6 +74,7 @@ def __init__( | |||||
| partition_id: str, | ||||||
| task_name: str, | ||||||
| dp_rank: int, | ||||||
| should_check_consumption_status: bool = False, | ||||||
| fetch_batch_fn: Callable | None = None, | ||||||
| process_batch_fn: Callable | None = None, | ||||||
| ): | ||||||
|
|
@@ -98,6 +96,14 @@ def __init__( | |||||
| which samples have been consumed by which task. | ||||||
| dp_rank: The group ID of the current data group. All | ||||||
| ranks with the same dp_rank will receive identical samples. | ||||||
| should_check_consumption_status: Whether to check the consumption status of the | ||||||
| partition to decide when to stop iterating. Defaults to ``False``, which | ||||||
| means the iterator runs as an **infinite stream** — it will continuously | ||||||
| poll for new data and never exit on its own. This is the typical mode for | ||||||
| online/streaming training where producers keep feeding data indefinitely. | ||||||
| Set to ``True`` when the total number of samples is known in advance (i.e. | ||||||
| finite-dataset mode); the iterator will then stop once all samples in the | ||||||
| partition have been consumed. | ||||||
| fetch_batch_fn: Optional custom function to retrieve batch data. | ||||||
| If None, uses default_fetch_batch_fn function. | ||||||
| process_batch_fn: Optional custom function to post-process | ||||||
|
|
@@ -123,6 +129,7 @@ def __init__( | |||||
| self.partition_id = partition_id | ||||||
| self.task_name = task_name | ||||||
| self.dp_rank = dp_rank | ||||||
| self.should_check_consumption_status = should_check_consumption_status | ||||||
| self.fetch_batch_fn = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn | ||||||
| self.process_batch_fn = process_batch_fn if process_batch_fn else chunk_batch_fn | ||||||
|
|
||||||
|
|
@@ -151,63 +158,32 @@ def __init__( | |||||
| def _create_client(self): | ||||||
| """Create and initialize a TransferQueue client. | ||||||
|
|
||||||
| This method initializes the TransferQueueClient with the provided configuration | ||||||
| and storage backend, and sets up the storage manager for data retrieval. | ||||||
|
|
||||||
| Raises: | ||||||
| ValueError: If controller_info or storage_backend is missing or invalid. | ||||||
| This method initializes the TransferQueueClient with the provided configuration. | ||||||
| """ | ||||||
| client_id = uuid.uuid4().hex[:8] | ||||||
|
|
||||||
| # TODO: DEPRECATE in future | ||||||
| controller_config = self.config.get("controller", None) | ||||||
| if controller_config: | ||||||
| controller_info = controller_config.get("zmq_info", None) | ||||||
| else: | ||||||
| controller_info = self.config.get("controller_info", None) | ||||||
| if controller_info: | ||||||
| warnings.warn( | ||||||
| "Config entry `controller_info` will be deprecated in 0.1.7, please " | ||||||
| "use `controller.zmq_info` instead.", | ||||||
| category=DeprecationWarning, | ||||||
| stacklevel=2, | ||||||
| ) | ||||||
|
|
||||||
| if not controller_info or not isinstance(controller_info, ZMQServerInfo): | ||||||
| raise ValueError("Invalid or missing controller.zmq_info in config") | ||||||
|
|
||||||
| backend_config = self.config.get("backend", None) | ||||||
| if not backend_config: | ||||||
| storage_backend = self.config.get("storage_backend", None) | ||||||
| backend_config = self.config | ||||||
| if storage_backend: | ||||||
| warnings.warn( | ||||||
| "Config entry `storage_backend` will be deprecated in 0.1.7, please " | ||||||
| "use `backend.storage_backend` instead.", | ||||||
| category=DeprecationWarning, | ||||||
| stacklevel=2, | ||||||
| ) | ||||||
| else: | ||||||
| storage_backend = backend_config.get("storage_backend", None) | ||||||
| backend_config = self.config.backend[storage_backend] | ||||||
|
|
||||||
| if not storage_backend: | ||||||
| raise ValueError("Missing storage_backend in config") | ||||||
|
|
||||||
| self._tq_client = TransferQueueClient(client_id, controller_info) | ||||||
| self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config) | ||||||
|
|
||||||
| init(self.config) | ||||||
| self._tq_client = get_client() | ||||||
|
|
||||||
|
Comment on lines 162 to 166
|
||||||
| def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: | ||||||
| """Iterate over the dataset, yielding batches of data. | ||||||
|
|
||||||
| The iteration behaviour depends on ``should_check_consumption_status``: | ||||||
|
|
||||||
| - **False (default — streaming mode)**: The iterator runs as an | ||||||
| infinite stream, continuously polling TransferQueue for new data. | ||||||
| It will block (with a 1-second sleep) when no data is available and | ||||||
| resume once new batches are produced. This is the standard mode for | ||||||
| online / streaming training pipelines where producers feed data | ||||||
| indefinitely. | ||||||
| - **True (finite-dataset mode)**: The iterator terminates once all | ||||||
| samples in the partition have been consumed (as reported by | ||||||
| ``check_consumption_status``), *and* all buffered batches have been | ||||||
| yielded. | ||||||
|
|
||||||
| Yields: | ||||||
| Tuple[TensorDict, BatchMeta]: A tuple containing: | ||||||
| - TensorDict: Batch of data with the requested fields. | ||||||
| - BatchMeta: Corresponding metadata to interact with TransferQueue. | ||||||
| Note: | ||||||
| This iterator runs indefinitely until the data source is exhausted. | ||||||
| The caller should handle StopIteration when appropriate (e.g., when | ||||||
| all data has been consumed and no more data will be produced). | ||||||
| """ | ||||||
| if self._tq_client is None: | ||||||
| self._create_client() | ||||||
|
|
@@ -218,24 +194,26 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]: | |||||
| # TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately | ||||||
| # determine consumption status even before producers have generated the samples. | ||||||
| while ( | ||||||
| not self._tq_client.check_consumption_status(self.task_name, self.partition_id) | ||||||
| not self.should_check_consumption_status | ||||||
| or not self._tq_client.check_consumption_status(self.task_name, self.partition_id) | ||||||
| or self.batch_index <= len(self.buffer) - 1 | ||||||
| ): | ||||||
| try: | ||||||
| if self.batch_index <= len(self.buffer) - 1: | ||||||
| current_data = self.buffer[self.batch_index] | ||||||
| self.batch_index += 1 | ||||||
| logger.info(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}") | ||||||
|
||||||
| logger.info(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}") | |
| logger.debug(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`production_status` is initialized with a concrete `torch.zeros(...)` value at class definition time, so all `DataPartitionStatus` instances will share the same tensor object. Any in-place updates to one partition's `production_status` can affect other partitions. Use a per-instance `field(default_factory=...)` to create a new tensor for each partition.