
[feat] Add SeqlenBalancedSampler and enhance StreamingDataset support#70

Merged
0oshowero0 merged 3 commits into Ascend:main from rednote-ai:feat/yuzhe/support_seqlen_balanced_samplers
Apr 2, 2026

Conversation

Contributor

@NINGBENZHE NINGBENZHE commented Apr 1, 2026

🎯 Summary

This PR introduces the SeqlenBalancedSampler to optimize sequence length distribution across Data Parallel (DP) ranks during GRPO training. It also enhances StreamingDataset with proper streaming mode support and refactors the controller's polling mechanism to improve efficiency when data is insufficient.

✨ Key Features & Enhancements

1. SeqlenBalancedSampler (Sequence-Length Balanced Sampling)

  • Karmarkar-Karp Algorithm: Added a new sampler that extends GRPOGroupNSampler. It uses the Karmarkar-Karp largest differencing method to balance sequence lengths (total_lengths) across DP ranks, ensuring that each rank processes approximately the same total token count.
  • Group Integrity: Guarantees that complete prompt groups remain intact across ranks to fulfill pass@k metrics and GRPO advantage normalization requirements.
  • Assignment Caching: Implements state caching (_balanced_cache) so that once global sampling and balancing are computed for a batch, subsequent DP ranks can quickly retrieve their assigned chunks.
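
To make the balancing step concrete, here is a minimal, self-contained sketch of the Karmarkar-Karp largest differencing method for k-way partitioning. The function name `karmarkar_karp_sketch` is illustrative only and is not the repository's API; the real sampler balances prompt *groups* by `total_lengths`, while this sketch balances raw lengths.

```python
import heapq
from itertools import count

def karmarkar_karp_sketch(seqlens, k):
    """Illustrative k-way largest differencing method (Karmarkar-Karp).

    Each heap entry holds k buckets of (subset_sum, item_indices),
    sorted by sum descending. We repeatedly pop the two entries with
    the largest spread (max sum - min sum) and merge them, pairing
    each entry's largest bucket with the other's smallest, which
    greedily cancels out imbalance.
    """
    uid = count()  # tie-breaker so the heap never compares bucket lists
    heap = []
    for i, n in enumerate(seqlens):
        # Start with each item alone in the first of k buckets.
        buckets = [(n, (i,))] + [(0, ())] * (k - 1)
        spread = buckets[0][0] - buckets[-1][0]
        heap.append((-spread, next(uid), buckets))  # max-heap via negation
    heapq.heapify(heap)
    while len(heap) > 1:
        _, _, a = heapq.heappop(heap)
        _, _, b = heapq.heappop(heap)
        # Pair a's largest bucket with b's smallest, and so on.
        merged = sorted(
            ((sa + sb, ia + ib) for (sa, ia), (sb, ib) in zip(a, reversed(b))),
            key=lambda t: -t[0],
        )
        spread = merged[0][0] - merged[-1][0]
        heapq.heappush(heap, (-spread, next(uid), merged))
    return [sorted(idx) for _, idx in heap[0][2]]
```

For example, `karmarkar_karp_sketch([8, 7, 6, 5, 4], 2)` yields partitions whose token totals differ by only 2 (16 vs 14), which is the standard LDM result for this instance.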

2. StreamingDataset Improvements

  • Finite vs. Infinite Stream: Introduced the should_check_consumption_status parameter.
    • False (Default): Operates in an infinite stream mode, continuously polling for new data (ideal for online/streaming pipelines).
    • True: Operates in finite-dataset mode, terminating iteration only after all samples in the partition are fully consumed.
  • Client Initialization Refactor: Refactored _create_client() to use init() and get_client() from transfer_queue.interface instead of manually setting up TransferQueueClient.
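
The two iteration modes can be modeled with a generic polling loop. This is a hedged sketch, not the actual `StreamingDataset` code: `fetch_fn` and `all_consumed` are hypothetical stand-ins for the TransferQueue fetch call and the partition consumption check.

```python
import time
from typing import Any, Callable, Iterator, Optional

def stream_batches(
    fetch_fn: Callable[[], Optional[Any]],
    should_check_consumption_status: bool = False,
    all_consumed: Callable[[], bool] = lambda: False,
    sleep_s: float = 1.0,
) -> Iterator[Any]:
    """Sketch of the finite vs. infinite iteration modes.

    fetch_fn returns the next batch, or None when nothing is ready.
    all_consumed reports whether the finite partition is fully drained.
    Both callables are illustrative assumptions, not real TQ APIs.
    """
    while True:
        batch = fetch_fn()
        if batch is not None:
            yield batch
        elif should_check_consumption_status and all_consumed():
            return  # finite-dataset mode: stop once fully consumed
        else:
            time.sleep(sleep_s)  # infinite-stream mode: keep polling
```

With `should_check_consumption_status=False` the generator never terminates on its own, which matches the default online/streaming behavior described above.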

3. Controller Optimizations

  • Polling Mode Sampler Cache Lookup: Updated get_metadata to look up cached sampler states when operating in polling_mode. If dp_rank and batch_index are cached, it immediately returns the data instead of failing or entering redundant wait loops when ready_for_consume_indexes are insufficient.
  • Variable-size Batch Support: Updated the sampler length validation logic to accommodate variable-size batches returned by samplers like SeqlenBalancedSampler.
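
The cache-lookup-before-wait pattern in the controller can be reduced to the following sketch. The function and its arguments are illustrative assumptions, not the controller's actual interface; the real code keys the cache per partition and task as well.

```python
from typing import Any, Callable, Dict

def lookup_or_compute(
    states: Dict[int, Dict[int, Any]],
    dp_rank: int,
    batch_index: int,
    compute_fn: Callable[[], Any],
) -> Any:
    """Return the cached assignment for (dp_rank, batch_index) if this
    batch was already balanced; otherwise compute it once and cache it,
    so later DP ranks skip the ready_for_consume_indexes wait loop."""
    cached = states.get(dp_rank, {}).get(batch_index)
    if cached is not None:
        return cached
    result = compute_fn()
    states.setdefault(dp_rank, {})[batch_index] = result
    return result
```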

🛠️ Refactoring & Minor Fixes

  • Log Level Adjustments: Downgraded the 1D tensor shape warnings to logger.info() in client.py and metadata.py to reduce unnecessary noise.
  • Pre-allocation Scope: Moved TQ_PRE_ALLOC_SAMPLE_NUM environment variable resolution into local method scopes where appropriate.

🧪 Testing

  • Added comprehensive unit tests for SeqlenBalancedSampler covering initialization, fallback behavior, balanced partitioning logic with mock custom meta, group level integrity, and caching mechanisms.
  • Added explicit utility tests for the karmarkar_karp and get_seqlen_balanced_partitions functions (TestKarmarkarKarp).

@ascend-robot

CLA Signature Pass

NINGBENZHE, thanks for your pull request. All authors of the commits have signed the CLA. 👍

- Add SeqlenBalancedSampler based on Karmarkar-Karp algorithm to balance
  sequence lengths across DP ranks for GRPO training
- Add streaming mode support for StreamingDataset via
  should_check_consumption_status parameter
- Add polling_mode sampler cache lookup in controller to avoid redundant
  sampling when data is insufficient
- Replace print() with logger.info() in controller
- Downgrade 1D tensor warnings to info level in client and metadata
- Add comprehensive unit tests for SeqlenBalancedSampler and KarmarkarKarp

Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>
@NINGBENZHE NINGBENZHE force-pushed the feat/yuzhe/support_seqlen_balanced_samplers branch from 5a52448 to 6e9f23c on April 1, 2026 at 02:51

@0oshowero0 0oshowero0 changed the title [Feat] Add SeqlenBalancedSampler and enhance StreamingDataset support [feat] Add SeqlenBalancedSampler and enhance StreamingDataset support Apr 1, 2026

- **False (default — streaming mode)**: The iterator runs as an
infinite stream, continuously polling TransferQueue for new data.
It will block (with a 1-second sleep) when no data is available and
Collaborator


Suggested change
It will block (with a 1-second sleep) when no data is available and
It will sleep for `TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL` seconds (default=1) when no data is available and

    Returns:
        True if partition was created successfully, False if it already exists
    """
    TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))
Collaborator


Will this variable change during training?

Contributor Author


No, it won’t change during training. However, the previous implementation sometimes missed the environment variable at initialization, and it wasn’t read at runtime.

Collaborator


That's quite strange..

Comment on lines +1318 to +1325
sampling_config = sampling_config or {}
states = self.sampler._states.get(partition_id, {}).get(task_name, {})
dp_rank = sampling_config.get("dp_rank", None)
batch_index = sampling_config.get("batch_index", None)

# Return cached result if available
if dp_rank is not None and dp_rank in states and batch_index in states[dp_rank]:
    break
Collaborator

@0oshowero0 0oshowero0 Apr 1, 2026


These lines are coupled with seqlen_balanced_sampler.py. We could implement a new interface that processes the sampler's state for each type of sampler.

Contributor Author


This is the generic logic for caching. After the initial consumption, the ready_for_consume_indexes check can be skipped if a cache exists.

Collaborator

@0oshowero0 0oshowero0 Apr 1, 2026


I agree that the logic is general. Maybe we can add a check_caching_states interface in XXSampler

class XXSampler:
    def check_caching_states(self):
        pass

Comment on lines +82 to +85
    partition_id: Partition identifier.
    **kwargs: Must include ``dp_rank``, ``batch_index``, and
        ``partition`` (the ``DataPartitionStatus`` object from the
        controller).
Collaborator


Do we need both partition and partition_id?

Collaborator


The DataPartitionStatus object already has its own name

Contributor Author


This is for caching and retrieving total_lengths via custom_meta.

Collaborator


I mean maybe we don't have to pass partition_id? Just get the ID from partition.partition_id

Comment on lines +726 to +730
def test_initialization_invalid_n_samples_per_prompt(self):
    """Test that n_samples_per_prompt must be positive (inherited from GRPO)."""
    with pytest.raises(ValueError) as exc_info:
        SeqlenBalancedSampler(n_samples_per_prompt=0, dp_size=2)
    assert "must be positive" in str(exc_info.value)
Collaborator


Remove if duplicated (test covered by GRPONSampler)

Contributor Author


done

Comment on lines +701 to +714
def test_initialization_default(self):
    """Test SeqlenBalancedSampler default initialization."""
    sampler = SeqlenBalancedSampler()
    assert isinstance(sampler, GRPOGroupNSampler)
    assert isinstance(sampler, BaseSampler)
    assert sampler.n_samples_per_prompt == 1
    assert sampler.dp_size == 1
    assert sampler._balanced_cache == {}

def test_initialization_custom(self):
    """Test SeqlenBalancedSampler custom initialization."""
    sampler = SeqlenBalancedSampler(n_samples_per_prompt=4, dp_size=2)
    assert sampler.n_samples_per_prompt == 4
    assert sampler.dp_size == 2
Collaborator


Not very necessary

Contributor Author


done

Contributor

Copilot AI left a comment


Pull request overview

Adds a new sequence-length–balanced sampler for GRPO workflows and improves streaming consumption behavior, including controller-side polling optimizations and expanded test coverage.

Changes:

  • Introduces SeqlenBalancedSampler (Karmarkar–Karp based) and exports it from package/sampler modules.
  • Enhances StreamingDataset to support infinite-stream mode via should_check_consumption_status, and updates controller polling to reuse sampler cache when data is insufficient.
  • Adjusts logging verbosity (1D tensor warnings → info) and adds extensive unit tests for the new sampler and balancing utilities.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

Show a summary per file

  • transfer_queue/sampler/seqlen_balanced_sampler.py — New sampler + Karmarkar–Karp utilities and caching logic for DP-rank balancing
  • transfer_queue/sampler/__init__.py — Re-exports SeqlenBalancedSampler
  • transfer_queue/__init__.py — Public API export for SeqlenBalancedSampler
  • transfer_queue/dataloader/streaming_dataset.py — Adds streaming vs finite-dataset iteration mode; refactors client init/calls
  • transfer_queue/controller.py — Polling-mode cache lookup to avoid redundant sampling; passes partition into sampler
  • transfer_queue/metadata.py — Downgrades 1D tensor warning to info
  • transfer_queue/client.py — Downgrades 1D tensor warning to info
  • tests/test_samplers.py — Adds comprehensive tests for SeqlenBalancedSampler and balancing helpers
Comments suppressed due to low confidence (1)

transfer_queue/dataloader/streaming_dataset.py:222

  • The empty-batch backoff is hardcoded to time.sleep(1), but TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL (env-configurable) is defined at the top of the file and is currently unused. Use the constant here (or remove it) so users can tune polling latency/CPU usage via TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL as intended.
                        self.buffer.append((batch_data, batch_meta))
                    else:
                        time.sleep(1)



Comment on lines +154 to +155
# Balance groups across DP ranks
balanced_group_partitions = get_seqlen_balanced_partitions(group_lengths, self.dp_size, equal_size=True)

Copilot AI Apr 1, 2026


balanced_group_partitions = get_seqlen_balanced_partitions(..., equal_size=True) will raise an AssertionError when the number of prompt-groups (num_groups = global_batch_size / n_samples_per_prompt) is not divisible by dp_size (e.g., n_samples_per_prompt=4, dp_size=2, per-DP batch_size=6 → global_batch_size=12 → num_groups=3). Consider either validating early that per-DP batch_size is a multiple of n_samples_per_prompt (so num_groups % dp_size == 0), or calling the balancer with equal_size=False at the group level and documenting that per-rank batch sizes may differ in that case.

Suggested change

    # Balance groups across DP ranks
    balanced_group_partitions = get_seqlen_balanced_partitions(group_lengths, self.dp_size, equal_size=True)

    # Balance groups across DP ranks. When the number of groups is
    # divisible by dp_size, we can enforce equal-sized group
    # counts per rank. Otherwise, we relax the constraint to
    # avoid assertion errors in the balancer.
    equal_size_groups = num_groups % self.dp_size == 0
    if not equal_size_groups:
        logger.warning(
            "SeqlenBalancedSampler: num_groups=%d is not divisible by dp_size=%d; "
            "falling back to equal_size=False at the group level. Per-rank group "
            "counts may differ.",
            num_groups,
            self.dp_size,
        )
    balanced_group_partitions = get_seqlen_balanced_partitions(
        group_lengths,
        self.dp_size,
        equal_size=equal_size_groups,
    )

Comment on lines +338 to +352
assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

def _check_and_sort_partitions(partitions):
    assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
    seen_idx = set()
    sorted_partitions: list[list[int]] = [[] for _ in range(k_partitions)]
    for i, partition in enumerate(partitions):
        assert len(partition) > 0, f"the {i}-th partition is empty"
        for idx in partition:
            seen_idx.add(idx)
        sorted_partitions[i] = sorted(partition)
    assert seen_idx == set(range(len(seqlen_list)))
    return sorted_partitions

partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)

Copilot AI Apr 1, 2026


karmarkar_karp / get_seqlen_balanced_partitions use assert for input validation (e.g., len(seqlen_list) >= k_partitions, non-empty partitions). These checks are stripped when Python runs with optimizations (-O), which can turn invalid inputs into hard-to-debug downstream errors. Prefer raising ValueError (or a custom exception) with the same messages so validation is always enforced in production.

Suggested change

    assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

    def _check_and_sort_partitions(partitions):
        assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
        ...

    if len(seqlen_list) < k_partitions:
        raise ValueError(
            f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"
        )

    def _check_and_sort_partitions(partitions):
        if len(partitions) != k_partitions:
            raise ValueError(f"{len(partitions)} != {k_partitions}")
        seen_idx = set()
        sorted_partitions: list[list[int]] = [[] for _ in range(k_partitions)]
        for i, partition in enumerate(partitions):
            if len(partition) == 0:
                raise ValueError(f"the {i}-th partition is empty")
            for idx in partition:
                seen_idx.add(idx)
            sorted_partitions[i] = sorted(partition)
        expected_idx = set(range(len(seqlen_list)))
        if seen_idx != expected_idx:
            raise ValueError(
                f"Invalid partition indices: expected {expected_idx}, got {seen_idx}"
            )
        return sorted_partitions

    partitions = karmarkar_karp(
        seqlen_list=seqlen_list,
        k_partitions=k_partitions,
        equal_size=equal_size,
    )

Comment on lines 209 to 217
Before:

    batch_data, batch_meta = self.fetch_batch_fn(
        self._tq_client,
        self.data_fields,
        self.batch_size,
        self.partition_id,
        self.task_name,
        self.sampling_config,
        self.batch_index,
    )

After:

    batch_data, batch_meta = self.fetch_batch_fn(
        tq_client=self._tq_client,
        data_fields=self.data_fields,
        batch_size=self.batch_size,
        partition_id=self.partition_id,
        task_name=self.task_name,
        sampling_config=self.sampling_config,
        batch_index=self.batch_index,
    )

Copilot AI Apr 1, 2026


fetch_batch_fn is now invoked with keyword arguments. This can be a breaking change for users who passed a custom fetch_batch_fn that accepts only positional parameters (the previous call-site was positional). To preserve backward compatibility, consider calling it positionally (or supporting both via a small adapter / try-except) while keeping the documented parameter order.

Copilot uses AI. Check for mistakes.
# Values: 0 = not produced, 1 = ready for consumption
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)

Copilot AI Apr 1, 2026


production_status: Tensor = torch.zeros(...) is a mutable dataclass default, so all DataPartitionStatus instances will share the same underlying tensor unless it gets reassigned. Because methods like update_production_status mutate self.production_status in-place, this can cause cross-partition contamination. Use field(default_factory=lambda: torch.zeros(...)) (and similarly for any other tensor defaults) so each partition gets its own tensor instance.

Suggested change
    production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)

    production_status: Tensor = field(
        default_factory=lambda: torch.zeros(
            DataPartitionStatus.TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8
        )
    )

Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>

Signed-off-by: 宁本哲 <ningbenzhe@xiaohongshu.com>

@0oshowero0 0oshowero0 merged commit f0047b9 into Ascend:main Apr 2, 2026
8 checks passed