Skip to content

Commit b9a325b

Browse files
committed
[feat] Add SeqlenBalancedSampler and improve StreamingDataset
- Add SeqlenBalancedSampler based on Karmarkar-Karp algorithm to balance sequence lengths across DP ranks for GRPO training - Add streaming mode support for StreamingDataset via should_check_consumption_status parameter - Add polling_mode sampler cache lookup in controller to avoid redundant sampling when data is insufficient - Replace print() with logger.info() in controller - Downgrade 1D tensor warnings to info level in client and metadata - Add comprehensive unit tests for SeqlenBalancedSampler and KarmarkarKarp
1 parent 4e9ae22 commit b9a325b

8 files changed

Lines changed: 902 additions & 82 deletions

File tree

tests/test_samplers.py

Lines changed: 487 additions & 2 deletions
Large diffs are not rendered by default.

transfer_queue/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from .sampler import BaseSampler
3939
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
4040
from .sampler.rank_aware_sampler import RankAwareSampler
41+
from .sampler.seqlen_balanced_sampler import SeqlenBalancedSampler
4142
from .sampler.sequential_sampler import SequentialSampler
4243

4344
__all__ = (
@@ -76,6 +77,7 @@
7677
"GRPOGroupNSampler",
7778
"SequentialSampler",
7879
"RankAwareSampler",
80+
"SeqlenBalancedSampler",
7981
]
8082
)
8183

transfer_queue/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ async def async_put(
389389

390390
for field_name, field_data in data.items():
391391
if isinstance(field_data, torch.Tensor) and field_data.ndim == 1:
392-
logger.warning(
392+
logger.info(
393393
f"[{self.client_id}]: Data field '{field_name}' is a tensor with only one dimension. "
394394
f"You may receive 2D tensors in key-value based backend."
395395
)

transfer_queue/controller.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@
6363
# Sample pre-allocation for StreamingDataLoader compatibility.
6464
# By pre-allocating sample indices (typically global_batch_size), consumers can accurately
6565
# determine consumption status even before producers have generated the samples.
66-
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))
6766

6867

6968
class PartitionIndexManager:
@@ -335,6 +334,7 @@ class DataPartitionStatus:
335334

336335
# Production status tensor - dynamically expandable
337336
# Values: 0 = not produced, 1 = ready for consumption
337+
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))
338338

339339
production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
340340

@@ -1050,6 +1050,8 @@ def create_partition(self, partition_id: str) -> bool:
10501050
Returns:
10511051
True if partition was created successfully, False if it already exists
10521052
"""
1053+
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))
1054+
10531055
if partition_id in self.partitions:
10541056
logger.warning(f"Partition {partition_id} already exists")
10551057
return False
@@ -1313,38 +1315,54 @@ def get_metadata(
13131315

13141316
if len(ready_for_consume_indexes) < batch_size:
13151317
if self.polling_mode:
1316-
logger.debug(
1317-
f"[{self.controller_id}]: Not enough data for task {task_name} in partition {partition_id}."
1318-
f" Required: {batch_size}, Available: {len(ready_for_consume_indexes)}."
1319-
f" Returning None due to polling mode."
1318+
sampling_config = sampling_config or {}
1319+
states = self.sampler._states.get(partition_id, {}).get(task_name, {})
1320+
dp_rank = sampling_config.get("dp_rank", None)
1321+
batch_index = sampling_config.get("batch_index", None)
1322+
1323+
# Return cached result if available
1324+
if dp_rank is not None and dp_rank in states and batch_index in states[dp_rank]:
1325+
break
1326+
else:
1327+
logger.debug(
1328+
f"[{self.controller_id}]: Not enough data for task {task_name} in "
1329+
f"partition {partition_id}. Required: {batch_size}, "
1330+
f"Available: {len(ready_for_consume_indexes)}."
1331+
f" Returning None due to polling mode."
1332+
)
1333+
return BatchMeta.empty()
1334+
else:
1335+
logger.warning(
1336+
f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} "
1337+
f"samples with fields {data_fields} in partition {partition_id}, but only have "
1338+
f"{len(ready_for_consume_indexes)} samples meeting the criteria. "
1339+
f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
13201340
)
1321-
return BatchMeta.empty()
1341+
time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
13221342
if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT:
13231343
raise TimeoutError(
13241344
f"Timeout while waiting for sufficient data for task {task_name}. "
13251345
f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}"
13261346
)
1327-
logger.warning(
1328-
f"[{self.controller_id}]: Insufficient data for task {task_name}. Required: {batch_size} "
1329-
f"samples with fields {data_fields} in partition {partition_id}, but only have "
1330-
f"{len(ready_for_consume_indexes)} samples meeting the criteria. "
1331-
f"Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
1332-
)
1333-
time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
13341347
else:
13351348
break
13361349

13371350
batch_global_indexes, consumed_indexes = self.sampler(
13381351
ready_for_consume_indexes,
13391352
batch_size,
1353+
partition=self._get_partition(partition_id),
13401354
**(sampling_config or {}),
13411355
**kwargs,
13421356
)
13431357

1344-
# Check if we got valid results from the sampler
1345-
if len(batch_global_indexes) != batch_size:
1358+
# Check if we got valid results from the sampler.
1359+
# Some samplers (e.g. SeqlenBalancedSampler) may return variable-size
1360+
# batches per DP rank, so we only check for empty results.
1361+
if len(batch_global_indexes) == 0:
1362+
if self.polling_mode:
1363+
return BatchMeta.empty()
13461364
raise RuntimeError(
1347-
f"Sampler returned insufficient samples. Please check the sampler logic. "
1365+
f"Sampler returned no samples. Please check the sampler logic. "
13481366
f"Expected: {batch_size}, before sampling: {len(ready_for_consume_indexes)}, "
13491367
f"after sampling: {len(batch_global_indexes)}"
13501368
)
@@ -1826,7 +1844,7 @@ def _process_request(self):
18261844
partition_id=params["partition_id"],
18271845
mode=params.get("mode", "fetch"),
18281846
task_name=params.get("task_name"),
1829-
sampling_config=params.get("sampling_config"),
1847+
sampling_config=params.get("sampling_config", {}),
18301848
)
18311849

18321850
response_msg = ZMQMessage.create(

transfer_queue/dataloader/streaming_dataset.py

Lines changed: 38 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,14 @@
1616
import logging
1717
import os
1818
import time
19-
import uuid
20-
import warnings
2119
from typing import Callable, Iterator
2220

2321
from omegaconf import DictConfig
2422
from tensordict import TensorDict
2523
from torch.utils.data import IterableDataset
2624

27-
from transfer_queue import TransferQueueClient
25+
from transfer_queue.interface import get_client, init
2826
from transfer_queue.metadata import BatchMeta
29-
from transfer_queue.utils.zmq_utils import ZMQServerInfo
3027

3128
TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float(
3229
os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1)
@@ -77,6 +74,7 @@ def __init__(
7774
partition_id: str,
7875
task_name: str,
7976
dp_rank: int,
77+
should_check_consumption_status: bool = False,
8078
fetch_batch_fn: Callable | None = None,
8179
process_batch_fn: Callable | None = None,
8280
):
@@ -98,6 +96,14 @@ def __init__(
9896
which samples have been consumed by which task.
9997
dp_rank: The group ID of the current data group. All
10098
ranks with the same dp_rank will receive identical samples.
99+
should_check_consumption_status: Whether to check the consumption status of the
100+
partition to decide when to stop iterating. Defaults to ``False``, which
101+
means the iterator runs as an **infinite stream** — it will continuously
102+
poll for new data and never exit on its own. This is the typical mode for
103+
online/streaming training where producers keep feeding data indefinitely.
104+
Set to ``True`` when the total number of samples is known in advance (i.e.
105+
finite-dataset mode); the iterator will then stop once all samples in the
106+
partition have been consumed.
101107
fetch_batch_fn: Optional custom function to retrieve batch data.
102108
If None, uses default_fetch_batch_fn function.
103109
process_batch_fn: Optional custom function to post-process
@@ -123,6 +129,7 @@ def __init__(
123129
self.partition_id = partition_id
124130
self.task_name = task_name
125131
self.dp_rank = dp_rank
132+
self.should_check_consumption_status = should_check_consumption_status
126133
self.fetch_batch_fn = fetch_batch_fn if fetch_batch_fn else default_fetch_batch_fn
127134
self.process_batch_fn = process_batch_fn if process_batch_fn else chunk_batch_fn
128135

@@ -151,63 +158,32 @@ def __init__(
151158
def _create_client(self):
152159
"""Create and initialize a TransferQueue client.
153160
154-
This method initializes the TransferQueueClient with the provided configuration
155-
and storage backend, and sets up the storage manager for data retrieval.
156-
157-
Raises:
158-
ValueError: If controller_info or storage_backend is missing or invalid.
161+
This method initializes the TransferQueueClient with the provided configuration.
159162
"""
160-
client_id = uuid.uuid4().hex[:8]
161-
162-
# TODO: DEPRECATE in future
163-
controller_config = self.config.get("controller", None)
164-
if controller_config:
165-
controller_info = controller_config.get("zmq_info", None)
166-
else:
167-
controller_info = self.config.get("controller_info", None)
168-
if controller_info:
169-
warnings.warn(
170-
"Config entry `controller_info` will be deprecated in 0.1.7, please "
171-
"use `controller.zmq_info` instead.",
172-
category=DeprecationWarning,
173-
stacklevel=2,
174-
)
175-
176-
if not controller_info or not isinstance(controller_info, ZMQServerInfo):
177-
raise ValueError("Invalid or missing controller.zmq_info in config")
178-
179-
backend_config = self.config.get("backend", None)
180-
if not backend_config:
181-
storage_backend = self.config.get("storage_backend", None)
182-
backend_config = self.config
183-
if storage_backend:
184-
warnings.warn(
185-
"Config entry `storage_backend` will be deprecated in 0.1.7, please "
186-
"use `backend.storage_backend` instead.",
187-
category=DeprecationWarning,
188-
stacklevel=2,
189-
)
190-
else:
191-
storage_backend = backend_config.get("storage_backend", None)
192-
backend_config = self.config.backend[storage_backend]
193-
194-
if not storage_backend:
195-
raise ValueError("Missing storage_backend in config")
196-
197-
self._tq_client = TransferQueueClient(client_id, controller_info)
198-
self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config)
163+
164+
init(self.config)
165+
self._tq_client = get_client()
199166

200167
def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
201168
"""Iterate over the dataset, yielding batches of data.
202169
170+
The iteration behaviour depends on ``should_check_consumption_status``:
171+
172+
- **False (default — streaming mode)**: The iterator runs as an
173+
infinite stream, continuously polling TransferQueue for new data.
174+
It will block (with a 1-second sleep) when no data is available and
175+
resume once new batches are produced. This is the standard mode for
176+
online / streaming training pipelines where producers feed data
177+
indefinitely.
178+
- **True (finite-dataset mode)**: The iterator terminates once all
179+
samples in the partition have been consumed (as reported by
180+
``check_consumption_status``), *and* all buffered batches have been
181+
yielded.
182+
203183
Yields:
204184
Tuple[TensorDict, BatchMeta]: A tuple containing:
205185
- TensorDict: Batch of data with the requested fields.
206186
- BatchMeta: Corresponding metadata to interact with TransferQueue.
207-
Note:
208-
This iterator runs indefinitely until the data source is exhausted.
209-
The caller should handle StopIteration when appropriate (e.g., when
210-
all data has been consumed and no more data will be produced).
211187
"""
212188
if self._tq_client is None:
213189
self._create_client()
@@ -218,24 +194,26 @@ def __iter__(self) -> Iterator[tuple[TensorDict, BatchMeta]]:
218194
# TQ_PRE_ALLOC_SAMPLE_NUM to the required global_batch_size to make sure consumers can accurately
219195
# determine consumption status even before producers have generated the samples.
220196
while (
221-
not self._tq_client.check_consumption_status(self.task_name, self.partition_id)
197+
not self.should_check_consumption_status
198+
or not self._tq_client.check_consumption_status(self.task_name, self.partition_id)
222199
or self.batch_index <= len(self.buffer) - 1
223200
):
224201
try:
225202
if self.batch_index <= len(self.buffer) - 1:
226203
current_data = self.buffer[self.batch_index]
227204
self.batch_index += 1
205+
logger.info(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
228206
yield from self.process_batch_fn(*current_data, micro_batch_size=self.micro_batch_size)
229207

230208
else:
231209
batch_data, batch_meta = self.fetch_batch_fn(
232-
self._tq_client,
233-
self.data_fields,
234-
self.batch_size,
235-
self.partition_id,
236-
self.task_name,
237-
self.sampling_config,
238-
self.batch_index,
210+
tq_client=self._tq_client,
211+
data_fields=self.data_fields,
212+
batch_size=self.batch_size,
213+
partition_id=self.partition_id,
214+
task_name=self.task_name,
215+
sampling_config=self.sampling_config,
216+
batch_index=self.batch_index,
239217
)
240218
if batch_data is not None:
241219
self.buffer.append((batch_data, batch_meta))

transfer_queue/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def extract_field_schema(data: TensorDict) -> dict[str, dict[str, Any]]:
169169
f"expected batch_size[0]={batch_size}, got value.shape[0]={value.shape[0]}"
170170
)
171171
if len(value.shape) == 1:
172-
logger.warning(f"Receiving 1D tensor for field '{field_name}'. Unsqueeze the last dimension.")
172+
logger.info(f"Receiving 1D tensor for field '{field_name}'. Unsqueeze the last dimension.")
173173
value = value.unsqueeze(-1)
174174
first_item = value[0]
175175
else:

transfer_queue/sampler/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .base import BaseSampler
1717
from .grpo_group_n_sampler import GRPOGroupNSampler
1818
from .rank_aware_sampler import RankAwareSampler
19+
from .seqlen_balanced_sampler import SeqlenBalancedSampler
1920
from .sequential_sampler import SequentialSampler
2021

21-
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler"]
22+
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler", "SeqlenBalancedSampler"]

0 commit comments

Comments
 (0)