1616import logging
1717import os
1818import time
19- import uuid
20- import warnings
2119from typing import Callable , Iterator
2220
2321from omegaconf import DictConfig
2422from tensordict import TensorDict
2523from torch .utils .data import IterableDataset
2624
27- from transfer_queue import TransferQueueClient
25+ from transfer_queue . interface import get_client , init
2826from transfer_queue .metadata import BatchMeta
29- from transfer_queue .utils .zmq_utils import ZMQServerInfo
3027
3128TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float (
3229 os .environ .get ("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL" , 1 )
@@ -77,6 +74,7 @@ def __init__(
7774 partition_id : str ,
7875 task_name : str ,
7976 dp_rank : int ,
77+ should_check_consumption_status : bool = False ,
8078 fetch_batch_fn : Callable | None = None ,
8179 process_batch_fn : Callable | None = None ,
8280 ):
@@ -98,6 +96,14 @@ def __init__(
9896 which samples have been consumed by which task.
9997 dp_rank: The group ID of the current data group. All
10098 ranks with the same dp_rank will receive identical samples.
99+ should_check_consumption_status: Whether to check the consumption status of the
100+ partition to decide when to stop iterating. Defaults to ``False``, which
101+ means the iterator runs as an **infinite stream** — it will continuously
102+ poll for new data and never exit on its own. This is the typical mode for
103+ online/streaming training where producers keep feeding data indefinitely.
104+ Set to ``True`` when the total number of samples is known in advance (i.e.
105+ finite-dataset mode); the iterator will then stop once all samples in the
106+ partition have been consumed.
101107 fetch_batch_fn: Optional custom function to retrieve batch data.
102108 If None, uses default_fetch_batch_fn function.
103109 process_batch_fn: Optional custom function to post-process
@@ -123,6 +129,7 @@ def __init__(
123129 self .partition_id = partition_id
124130 self .task_name = task_name
125131 self .dp_rank = dp_rank
132+ self .should_check_consumption_status = should_check_consumption_status
126133 self .fetch_batch_fn = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn
127134 self .process_batch_fn = process_batch_fn if process_batch_fn else chunk_batch_fn
128135
@@ -151,63 +158,32 @@ def __init__(
151158 def _create_client (self ):
152159 """Create and initialize a TransferQueue client.
153160
154- This method initializes the TransferQueueClient with the provided configuration
155- and storage backend, and sets up the storage manager for data retrieval.
156-
157- Raises:
158- ValueError: If controller_info or storage_backend is missing or invalid.
161+ This method initializes the TransferQueueClient with the provided configuration.
159162 """
160- client_id = uuid .uuid4 ().hex [:8 ]
161-
162- # TODO: DEPRECATE in future
163- controller_config = self .config .get ("controller" , None )
164- if controller_config :
165- controller_info = controller_config .get ("zmq_info" , None )
166- else :
167- controller_info = self .config .get ("controller_info" , None )
168- if controller_info :
169- warnings .warn (
170- "Config entry `controller_info` will be deprecated in 0.1.7, please "
171- "use `controller.zmq_info` instead." ,
172- category = DeprecationWarning ,
173- stacklevel = 2 ,
174- )
175-
176- if not controller_info or not isinstance (controller_info , ZMQServerInfo ):
177- raise ValueError ("Invalid or missing controller.zmq_info in config" )
178-
179- backend_config = self .config .get ("backend" , None )
180- if not backend_config :
181- storage_backend = self .config .get ("storage_backend" , None )
182- backend_config = self .config
183- if storage_backend :
184- warnings .warn (
185- "Config entry `storage_backend` will be deprecated in 0.1.7, please "
186- "use `backend.storage_backend` instead." ,
187- category = DeprecationWarning ,
188- stacklevel = 2 ,
189- )
190- else :
191- storage_backend = backend_config .get ("storage_backend" , None )
192- backend_config = self .config .backend [storage_backend ]
193-
194- if not storage_backend :
195- raise ValueError ("Missing storage_backend in config" )
196-
197- self ._tq_client = TransferQueueClient (client_id , controller_info )
198- self ._tq_client .initialize_storage_manager (manager_type = storage_backend , config = backend_config )
163+
164+ init (self .config )
165+ self ._tq_client = get_client ()
199166
200167 def __iter__ (self ) -> Iterator [tuple [TensorDict , BatchMeta ]]:
201168 """Iterate over the dataset, yielding batches of data.
202169
170+ The iteration behaviour depends on ``should_check_consumption_status``:
171+
172+ - **False (default — streaming mode)**: The iterator runs as an
173+ infinite stream, continuously polling TransferQueue for new data.
174+ It will block (with a 1-second sleep) when no data is available and
175+ resume once new batches are produced. This is the standard mode for
176+ online / streaming training pipelines where producers feed data
177+ indefinitely.
178+ - **True (finite-dataset mode)**: The iterator terminates once all
179+ samples in the partition have been consumed (as reported by
180+ ``check_consumption_status``), *and* all buffered batches have been
181+ yielded.
182+
203183 Yields:
204184 Tuple[TensorDict, BatchMeta]: A tuple containing:
205185 - TensorDict: Batch of data with the requested fields.
206186 - BatchMeta: Corresponding metadata to interact with TransferQueue.
207- Note:
208- This iterator runs indefinitely until the data source is exhausted.
209- The caller should handle StopIteration when appropriate (e.g., when
210- all data has been consumed and no more data will be produced).
211187 """
212188 if self ._tq_client is None :
213189 self ._create_client ()
@@ -218,24 +194,26 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
218194 # TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately
219195 # determine consumption status even before producers have generated the samples.
220196 while (
221- not self ._tq_client .check_consumption_status (self .task_name , self .partition_id )
197+ not self .should_check_consumption_status
198+ or not self ._tq_client .check_consumption_status (self .task_name , self .partition_id )
222199 or self .batch_index <= len (self .buffer ) - 1
223200 ):
224201 try :
225202 if self .batch_index <= len (self .buffer ) - 1 :
226203 current_data = self .buffer [self .batch_index ]
227204 self .batch_index += 1
205+ logger .info (f"StreamDataloader current batch index is { self .batch_index } /{ len (self .buffer )} " )
228206 yield from self .process_batch_fn (* current_data , micro_batch_size = self .micro_batch_size )
229207
230208 else :
231209 batch_data , batch_meta = self .fetch_batch_fn (
232- self ._tq_client ,
233- self .data_fields ,
234- self .batch_size ,
235- self .partition_id ,
236- self .task_name ,
237- self .sampling_config ,
238- self .batch_index ,
210+ tq_client = self ._tq_client ,
211+ data_fields = self .data_fields ,
212+ batch_size = self .batch_size ,
213+ partition_id = self .partition_id ,
214+ task_name = self .task_name ,
215+ sampling_config = self .sampling_config ,
216+ batch_index = self .batch_index ,
239217 )
240218 if batch_data is not None :
241219 self .buffer .append ((batch_data , batch_meta ))
0 commit comments