[feat] Add SeqlenBalancedSampler and enhance StreamingDataset support #70
Conversation
CLA Signature Pass: NINGBENZHE, thanks for your pull request. All authors of the commits have signed the CLA. 👍
- Add SeqlenBalancedSampler based on Karmarkar-Karp algorithm to balance sequence lengths across DP ranks for GRPO training
- Add streaming mode support for StreamingDataset via should_check_consumption_status parameter
- Add polling_mode sampler cache lookup in controller to avoid redundant sampling when data is insufficient
- Replace print() with logger.info() in controller
- Downgrade 1D tensor warnings to info level in client and metadata
- Add comprehensive unit tests for SeqlenBalancedSampler and KarmarkarKarp

Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>
Force-pushed from 5a52448 to 6e9f23c (compare)
> - **False (default — streaming mode)**: The iterator runs as an
>   infinite stream, continuously polling TransferQueue for new data.
>   It will block (with a 1-second sleep) when no data is available and
Suggested change:

```diff
-It will block (with a 1-second sleep) when no data is available and
+It will sleep for `TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL` seconds (default=1) when no data is available and
```
```python
    Returns:
        True if partition was created successfully, False if it already exists
    """
    TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))
```
Will this variable change during training?
No, it won’t change during training. However, the previous implementation sometimes missed the environment variable at initialization, and it wasn’t read at runtime.
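The difference between import-time and call-time resolution can be shown with a minimal standalone sketch (it assumes nothing about the real module beyond the variable name; `pre_alloc_sample_num` is a hypothetical helper):

```python
import os

# Start from a clean slate so the behavior below is deterministic.
os.environ.pop("TQ_PRE_ALLOC_SAMPLE_NUM", None)

# Reading at module import time freezes whatever value was present
# (or the default) when the module was first loaded:
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

def pre_alloc_sample_num() -> int:
    """Resolve the variable inside the method scope at call time, so a
    value exported after import (but before use) is still honored."""
    return int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

# Simulate the env var being set after the module was imported:
os.environ["TQ_PRE_ALLOC_SAMPLE_NUM"] = "8"
print(TQ_PRE_ALLOC_SAMPLE_NUM)   # 1: frozen at import
print(pre_alloc_sample_num())    # 8: picked up at call time
```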
transfer_queue/controller.py
Outdated
```python
sampling_config = sampling_config or {}
states = self.sampler._states.get(partition_id, {}).get(task_name, {})
dp_rank = sampling_config.get("dp_rank", None)
batch_index = sampling_config.get("batch_index", None)

# Return cached result if available
if dp_rank is not None and dp_rank in states and batch_index in states[dp_rank]:
    break
```
These lines are coupled with seqlen_balanced_sampler.py. We could implement a new interface that processes the sampler's state for each type of sampler.
This is the generic logic for caching. After the initial consumption, the ready_for_consume_indexes check can be skipped if a cache exists.
I agree that the logic is general. Maybe we can add a `check_caching_states` interface in `XXSampler`:

```python
class XXSampler:
    def check_caching_states(self):
        pass
```

> ```python
>     partition_id: Partition identifier.
>     **kwargs: Must include ``dp_rank``, ``batch_index``, and
>         ``partition`` (the ``DataPartitionStatus`` object from the
>         controller).
> ```
Do we need both partition and partition_id?
The DataPartitionStatus object already has its own name
This is for caching and retrieving total_lengths via custom_meta.
I mean maybe we don't have to pass partition_id? Just get the ID from partition.partition_id
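A possible shape for the `check_caching_states` interface floated in the thread above, sketched standalone (the class name `SeqlenBalancedSamplerSketch` and the exact state layout are hypothetical, modeled on the controller snippet quoted earlier):

```python
class BaseSampler:
    def check_caching_states(self, partition_id, task_name, **sampling_config):
        """Return a cached sampling result if this sampler has one, else None.

        Default: samplers keep no per-batch cache, so the controller
        always falls through to a fresh sampling pass.
        """
        return None


class SeqlenBalancedSamplerSketch(BaseSampler):
    def __init__(self):
        # partition_id -> task_name -> dp_rank -> batch_index -> result
        self._states = {}

    def check_caching_states(self, partition_id, task_name,
                             dp_rank=None, batch_index=None, **_):
        states = self._states.get(partition_id, {}).get(task_name, {})
        if dp_rank is not None and dp_rank in states:
            return states[dp_rank].get(batch_index)
        return None
```

The controller would then call `sampler.check_caching_states(...)` and skip the `ready_for_consume_indexes` wait on a cache hit, keeping sampler-specific state handling out of controller.py.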
tests/test_samplers.py
Outdated
```python
def test_initialization_invalid_n_samples_per_prompt(self):
    """Test that n_samples_per_prompt must be positive (inherited from GRPO)."""
    with pytest.raises(ValueError) as exc_info:
        SeqlenBalancedSampler(n_samples_per_prompt=0, dp_size=2)
    assert "must be positive" in str(exc_info.value)
```
Remove if duplicated (test covered by GRPONSampler).
tests/test_samplers.py
Outdated
```python
def test_initialization_default(self):
    """Test SeqlenBalancedSampler default initialization."""
    sampler = SeqlenBalancedSampler()
    assert isinstance(sampler, GRPOGroupNSampler)
    assert isinstance(sampler, BaseSampler)
    assert sampler.n_samples_per_prompt == 1
    assert sampler.dp_size == 1
    assert sampler._balanced_cache == {}

def test_initialization_custom(self):
    """Test SeqlenBalancedSampler custom initialization."""
    sampler = SeqlenBalancedSampler(n_samples_per_prompt=4, dp_size=2)
    assert sampler.n_samples_per_prompt == 4
    assert sampler.dp_size == 2
```
Pull request overview
Adds a new sequence-length–balanced sampler for GRPO workflows and improves streaming consumption behavior, including controller-side polling optimizations and expanded test coverage.
Changes:
- Introduces `SeqlenBalancedSampler` (Karmarkar–Karp based) and exports it from package/sampler modules.
- Enhances `StreamingDataset` to support infinite-stream mode via `should_check_consumption_status`, and updates controller polling to reuse sampler cache when data is insufficient.
- Adjusts logging verbosity (1D tensor warnings → info) and adds extensive unit tests for the new sampler and balancing utilities.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| transfer_queue/sampler/seqlen_balanced_sampler.py | New sampler + Karmarkar–Karp utilities and caching logic for DP-rank balancing |
| transfer_queue/sampler/`__init__.py` | Re-exports SeqlenBalancedSampler |
| transfer_queue/`__init__.py` | Public API export for SeqlenBalancedSampler |
| transfer_queue/dataloader/streaming_dataset.py | Adds streaming vs finite-dataset iteration mode; refactors client init/calls |
| transfer_queue/controller.py | Polling-mode cache lookup to avoid redundant sampling; passes partition into sampler |
| transfer_queue/metadata.py | Downgrades 1D tensor warning to info |
| transfer_queue/client.py | Downgrades 1D tensor warning to info |
| tests/test_samplers.py | Adds comprehensive tests for SeqlenBalancedSampler and balancing helpers |
Comments suppressed due to low confidence (1)
transfer_queue/dataloader/streaming_dataset.py:222
- The empty-batch backoff is hardcoded to `time.sleep(1)`, but `TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL` (env-configurable) is defined at the top of the file and is currently unused. Use the constant here (or remove it) so users can tune polling latency/CPU usage via `TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL` as intended.
```python
            self.buffer.append((batch_data, batch_meta))
        else:
            time.sleep(1)
```
```python
# Balance groups across DP ranks
balanced_group_partitions = get_seqlen_balanced_partitions(group_lengths, self.dp_size, equal_size=True)
```
balanced_group_partitions = get_seqlen_balanced_partitions(..., equal_size=True) will raise an AssertionError when the number of prompt-groups (num_groups = global_batch_size / n_samples_per_prompt) is not divisible by dp_size (e.g., n_samples_per_prompt=4, dp_size=2, per-DP batch_size=6 → global_batch_size=12 → num_groups=3). Consider either validating early that per-DP batch_size is a multiple of n_samples_per_prompt (so num_groups % dp_size == 0), or calling the balancer with equal_size=False at the group level and documenting that per-rank batch sizes may differ in that case.
Suggested change:

```diff
-# Balance groups across DP ranks
-balanced_group_partitions = get_seqlen_balanced_partitions(group_lengths, self.dp_size, equal_size=True)
+# Balance groups across DP ranks. When the number of groups is
+# divisible by dp_size, we can enforce equal-sized group
+# counts per rank. Otherwise, we relax the constraint to
+# avoid assertion errors in the balancer.
+equal_size_groups = num_groups % self.dp_size == 0
+if not equal_size_groups:
+    logger.warning(
+        "SeqlenBalancedSampler: num_groups=%d is not divisible by dp_size=%d; "
+        "falling back to equal_size=False at the group level. Per-rank group "
+        "counts may differ.",
+        num_groups,
+        self.dp_size,
+    )
+balanced_group_partitions = get_seqlen_balanced_partitions(
+    group_lengths,
+    self.dp_size,
+    equal_size=equal_size_groups,
+)
```
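For intuition about what the balancer computes, here is a standalone sketch of the k-way Karmarkar-Karp (largest differencing) method. This is a simplified illustration, not the repository's `karmarkar_karp` implementation, and it ignores the `equal_size` constraint discussed above:

```python
import heapq

def kk_partition(seqlens, k):
    """K-way Karmarkar-Karp: repeatedly merge the two partial solutions
    with the largest internal spread, pairing the biggest bucket of one
    with the smallest bucket of the other so their sums cancel out.
    Returns (index partitions, per-partition sums)."""
    heap = []
    for i, n in enumerate(seqlens):
        # Each heap entry: (-spread, tiebreak, sums desc, aligned partitions)
        sums = [n] + [0] * (k - 1)
        parts = [[i]] + [[] for _ in range(k - 1)]
        heapq.heappush(heap, (-(sums[0] - sums[-1]), i, sums, parts))
    tie = len(seqlens)
    while len(heap) > 1:
        _, _, s1, p1 = heapq.heappop(heap)
        _, _, s2, p2 = heapq.heappop(heap)
        # Largest bucket of s1 joins the smallest bucket of s2, and so on.
        merged = sorted(
            ((s1[j] + s2[k - 1 - j], p1[j] + p2[k - 1 - j]) for j in range(k)),
            key=lambda t: -t[0],
        )
        sums = [m[0] for m in merged]
        parts = [m[1] for m in merged]
        tie += 1
        heapq.heappush(heap, (-(sums[0] - sums[-1]), tie, sums, parts))
    _, _, sums, parts = heap[0]
    return parts, sums

# Example: 5 sequence lengths split across 2 DP ranks.
parts, sums = kk_partition([8, 7, 6, 5, 4], 2)
print(parts, sums)  # per-rank totals 16 and 14 for a grand total of 30
```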
```python
assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

def _check_and_sort_partitions(partitions):
    assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
    seen_idx = set()
    sorted_partitions: list[list[int]] = [[] for _ in range(k_partitions)]
    for i, partition in enumerate(partitions):
        assert len(partition) > 0, f"the {i}-th partition is empty"
        for idx in partition:
            seen_idx.add(idx)
        sorted_partitions[i] = sorted(partition)
    assert seen_idx == set(range(len(seqlen_list)))
    return sorted_partitions

partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
```
karmarkar_karp / get_seqlen_balanced_partitions use assert for input validation (e.g., len(seqlen_list) >= k_partitions, non-empty partitions). These checks are stripped when Python runs with optimizations (-O), which can turn invalid inputs into hard-to-debug downstream errors. Prefer raising ValueError (or a custom exception) with the same messages so validation is always enforced in production.
Suggested change:

```diff
-assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"
+if len(seqlen_list) < k_partitions:
+    raise ValueError(
+        f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"
+    )

 def _check_and_sort_partitions(partitions):
-    assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
+    if len(partitions) != k_partitions:
+        raise ValueError(f"{len(partitions)} != {k_partitions}")
     seen_idx = set()
     sorted_partitions: list[list[int]] = [[] for _ in range(k_partitions)]
     for i, partition in enumerate(partitions):
-        assert len(partition) > 0, f"the {i}-th partition is empty"
+        if len(partition) == 0:
+            raise ValueError(f"the {i}-th partition is empty")
         for idx in partition:
             seen_idx.add(idx)
         sorted_partitions[i] = sorted(partition)
-    assert seen_idx == set(range(len(seqlen_list)))
+    expected_idx = set(range(len(seqlen_list)))
+    if seen_idx != expected_idx:
+        raise ValueError(
+            f"Invalid partition indices: expected {expected_idx}, got {seen_idx}"
+        )
     return sorted_partitions

-partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
+partitions = karmarkar_karp(
+    seqlen_list=seqlen_list,
+    k_partitions=k_partitions,
+    equal_size=equal_size,
+)
```
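The `-O` behavior this comment relies on is easy to demonstrate by spawning the same interpreter with and without the flag:

```python
import subprocess
import sys

# `assert` statements vanish under -O, so assert-based input validation
# silently disappears in optimized runs.
code = "assert False, 'invalid input'"

normal = subprocess.run([sys.executable, "-c", code], capture_output=True)
optimized = subprocess.run([sys.executable, "-O", "-c", code], capture_output=True)

print(normal.returncode)     # non-zero: AssertionError raised
print(optimized.returncode)  # 0: the assert was stripped
```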
```diff
 batch_data, batch_meta = self.fetch_batch_fn(
-    self._tq_client,
-    self.data_fields,
-    self.batch_size,
-    self.partition_id,
-    self.task_name,
-    self.sampling_config,
-    self.batch_index,
+    tq_client=self._tq_client,
+    data_fields=self.data_fields,
+    batch_size=self.batch_size,
+    partition_id=self.partition_id,
+    task_name=self.task_name,
+    sampling_config=self.sampling_config,
+    batch_index=self.batch_index,
 )
```
fetch_batch_fn is now invoked with keyword arguments. This can be a breaking change for users who passed a custom fetch_batch_fn that accepts only positional parameters (the previous call-site was positional). To preserve backward compatibility, consider calling it positionally (or supporting both via a small adapter / try-except) while keeping the documented parameter order.
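One way to support both calling styles is a small `inspect`-based adapter. This is a hypothetical helper (`call_fetch_batch_fn` is not in the PR); the parameter order mirrors the call-site quoted above:

```python
import inspect

def call_fetch_batch_fn(fetch_batch_fn, **kwargs):
    """Invoke a user-supplied fetch_batch_fn by keyword when its signature
    supports it; otherwise fall back to the documented positional order
    for legacy callbacks written against the old positional call-site."""
    order = ("tq_client", "data_fields", "batch_size", "partition_id",
             "task_name", "sampling_config", "batch_index")
    params = inspect.signature(fetch_batch_fn).parameters
    has_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD
                     for p in params.values())
    by_keyword = has_var_kw or all(
        name in params
        and params[name].kind is not inspect.Parameter.POSITIONAL_ONLY
        for name in kwargs
    )
    if by_keyword:
        return fetch_batch_fn(**kwargs)
    # Legacy callback: its parameter names don't match, so call positionally.
    return fetch_batch_fn(*(kwargs[name] for name in order))

# Old-style callback written for the positional call-site:
def legacy_fetch(client, fields, bs, pid, task, cfg, idx):
    return (bs, task)

# New-style callback that names its parameters:
def modern_fetch(*, tq_client, data_fields, batch_size, partition_id,
                 task_name, sampling_config, batch_index):
    return (batch_size, task_name)
```

Both styles then resolve through the same adapter, so the documented keyword names stay the contract without breaking existing positional callbacks.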
```python
# Values: 0 = not produced, 1 = ready for consumption
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
```
production_status: Tensor = torch.zeros(...) is a mutable dataclass default, so all DataPartitionStatus instances will share the same underlying tensor unless it gets reassigned. Because methods like update_production_status mutate self.production_status in-place, this can cause cross-partition contamination. Use field(default_factory=lambda: torch.zeros(...)) (and similarly for any other tensor defaults) so each partition gets its own tensor instance.
Suggested change:

```diff
-production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
+production_status: Tensor = field(
+    default_factory=lambda: torch.zeros(
+        DataPartitionStatus.TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8
+    )
+)
```
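The sharing pitfall is easy to reproduce without torch. Here a `bytearray` stands in for the tensor, since dataclasses only reject `list`/`dict`/`set` defaults outright and silently accept any other mutable object:

```python
from dataclasses import dataclass, field

@dataclass
class SharedDefault:
    # One bytearray object, evaluated once and shared by every instance.
    status: bytearray = bytearray(4)

@dataclass
class PerInstanceDefault:
    # default_factory runs per instance, so each gets its own buffer.
    status: bytearray = field(default_factory=lambda: bytearray(4))

a, b = SharedDefault(), SharedDefault()
a.status[0] = 1
print(b.status[0])  # 1: in-place mutation leaks across instances

c, d = PerInstanceDefault(), PerInstanceDefault()
c.status[0] = 1
print(d.status[0])  # 0: each instance owns its buffer
```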
🎯 Summary
This PR introduces the `SeqlenBalancedSampler` to optimize sequence length distribution across Data Parallel (DP) ranks during GRPO training. It also enhances `StreamingDataset` with proper streaming mode support and refactors the controller's polling mechanism to improve efficiency when data is insufficient.

✨ Key Features & Enhancements

1. `SeqlenBalancedSampler` (Sequence-Length Balanced Sampling)

- Extends `GRPOGroupNSampler`. It uses the Karmarkar-Karp largest differencing method to balance sequence lengths (`total_lengths`) across DP ranks, ensuring that each rank processes approximately the same total token count.
- Maintains a cache (`_balanced_cache`) so that once global sampling and balancing are computed for a batch, subsequent DP ranks can quickly retrieve their assigned chunks.

2. `StreamingDataset` Improvements

- Adds a `should_check_consumption_status` parameter:
  - `False` (default): Operates in an infinite stream mode, continuously polling for new data (ideal for online/streaming pipelines).
  - `True`: Operates in finite-dataset mode, terminating iteration only after all samples in the partition are fully consumed.
- Refactors `_create_client()` to use `init()` and `get_client()` from `transfer_queue.interface` instead of manually setting up `TransferQueueClient`.

3. Controller Optimizations

- Updates `get_metadata` to look up cached sampler states when operating in `polling_mode`. If `dp_rank` and `batch_index` are cached, it immediately returns the data instead of failing or entering redundant wait loops when `ready_for_consume_indexes` are insufficient.
- Passes the partition object into the sampler, as required by `SeqlenBalancedSampler`.

🛠️ Refactoring & Minor Fixes

- Replaces `print()` with `logger.info()` in the controller.
- Downgrades 1D tensor warnings to `logger.info()` in `client.py` and `metadata.py` to reduce unnecessary noise.
- Moves `TQ_PRE_ALLOC_SAMPLE_NUM` environment variable resolution into local method scopes where appropriate.

🧪 Testing

- Comprehensive unit tests for `SeqlenBalancedSampler` covering initialization, fallback behavior, balanced partitioning logic with mock custom meta, group-level integrity, and caching mechanisms.
- Tests for `karmarkar_karp` and `get_seqlen_balanced_partitions` functions (`TestKarmarkarKarp`).
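The two iteration modes described in the summary can be sketched as a minimal polling loop. This is a hypothetical helper, not the PR's `StreamingDataset` code: `fetch_batch` and `fully_consumed` stand in for the TransferQueue client calls, and the sleep interval is read from the env-configurable constant discussed in the review:

```python
import os
import time

# Backoff interval for empty polls, tunable via environment variable.
TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float(
    os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1)
)

def stream_batches(fetch_batch, fully_consumed,
                   should_check_consumption_status=False):
    """Yield batches in one of two modes:

    - should_check_consumption_status=False (streaming): infinite
      iterator; sleep and retry whenever no batch is ready.
    - should_check_consumption_status=True (finite): stop once the
      partition reports that all samples have been consumed.
    """
    while True:
        batch = fetch_batch()
        if batch is not None:
            yield batch
        elif should_check_consumption_status and fully_consumed():
            return  # finite-dataset mode: everything consumed, stop iterating
        else:
            time.sleep(TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL)

# Finite-mode example: two batches, then the source reports exhaustion.
source = iter([{"x": 1}, {"x": 2}, None])
out = list(stream_batches(lambda: next(source), lambda: True,
                          should_check_consumption_status=True))
print(out)  # [{'x': 1}, {'x': 2}]
```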