
[Feat] Add SeqlenBalancedSampler and enhance StreamingDataset support#69

Closed
NINGBENZHE wants to merge 2 commits into Ascend:main from redai-infra:main

Conversation

@NINGBENZHE
Contributor

[feat] Add SeqlenBalancedSampler and improve StreamingDataset

  • Add SeqlenBalancedSampler based on Karmarkar-Karp algorithm to balance
    sequence lengths across DP ranks for GRPO training
  • Add streaming mode support for StreamingDataset via
    should_check_consumption_status parameter
  • Add polling_mode sampler cache lookup in controller to avoid redundant
    sampling when data is insufficient
  • Replace print() with logger.info() in controller
  • Downgrade 1D tensor warnings to info level in client and metadata
  • Add comprehensive unit tests for SeqlenBalancedSampler and KarmarkarKarp

NINGBENZHE and others added 2 commits March 31, 2026 12:33

[feat] Update and Add SeqlenBalancedSampler and improve StreamingDataset
@ascend-robot

CLA Signature Guide

@NINGBENZHE , thanks for your pull request.

The following commit(s) are not associated with a signed Contributor License Agreement (CLA).

Commit and reason:
  • b9a325b ([feat] Add SeqlenBalancedSample...): the email used in this commit is not linked to a signed CLA! Please verify that it matches the email you used when signing the CLA.
  • 56cc9719 (Merge pull request #1 from NINGB...): the email used in this commit is not linked to a signed CLA! Please verify that it matches the email you used when signing the CLA.

To sign CLA, click here.

To check if your email is configured correctly, refer to the FAQs.

Once you've signed the CLA or updated your email, please comment /check-cla to revalidate the CLA status.

Contributor

Copilot AI left a comment


Pull request overview

Adds a new sequence-length balancing sampler for GRPO workloads and improves streaming/finite iteration behavior in StreamingDataset, along with controller-side polling optimizations and expanded sampler test coverage.

Changes:

  • Introduce SeqlenBalancedSampler (Karmarkar–Karp based) to balance token counts across DP ranks while preserving GRPO prompt-group integrity.
  • Extend StreamingDataset with an optional finite-dataset termination mode (should_check_consumption_status) and refactor client init to use transfer_queue.interface.
  • Update controller polling behavior to reuse sampler cache, pass partition context into samplers, and adjust logging severity/noise in a few places.
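
For readers unfamiliar with the heuristic named above, the Karmarkar–Karp (largest-differencing) method can be sketched as follows. This is an illustrative implementation only, not the PR's actual `SeqlenBalancedSampler` code; the function name `karmarkar_karp` and the index-list return shape are ours:

```python
import heapq

def karmarkar_karp(seqlens, k):
    """Greedy largest-differencing partitioning of `seqlens` into k sets
    with near-equal sums. Returns k lists of indices into `seqlens`."""
    heap = []
    for i, n in enumerate(seqlens):
        # Start with singleton states: item i alone plus k - 1 empty partitions.
        # Heap priority is the negated spread (max - min partition sum), so the
        # two most unbalanced states are merged first.
        state = [(n, [i])] + [(0, []) for _ in range(k - 1)]
        heapq.heappush(heap, (-n, i, state))
    while len(heap) > 1:
        _, _, a = heapq.heappop(heap)
        _, tie, b = heapq.heappop(heap)
        a.sort(reverse=True)  # largest partition sums first
        b.sort()              # smallest partition sums first
        # Pair a's heaviest partition with b's lightest to cancel imbalance.
        merged = [(sa + sb, ia + ib) for (sa, ia), (sb, ib) in zip(a, b)]
        spread = max(s for s, _ in merged) - min(s for s, _ in merged)
        heapq.heappush(heap, (-spread, tie, merged))
    return [indices for _, indices in heap[0][2]]

# Example: five "sequence lengths" split across 2 DP ranks.
parts = karmarkar_karp([9, 7, 6, 5, 4], 2)  # partition sums come out 16 and 15
```

The sampler in the PR additionally balances at prompt-group granularity so GRPO groups are never split across ranks; the sketch above shows only the core partitioning step.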

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.

Show a summary per file

  • transfer_queue/sampler/seqlen_balanced_sampler.py: New sampler + KK partitioning utilities for DP seqlen balancing (with caching).
  • transfer_queue/sampler/__init__.py: Export SeqlenBalancedSampler.
  • transfer_queue/__init__.py: Re-export SeqlenBalancedSampler at package level.
  • transfer_queue/controller.py: Polling-mode cache lookup, pass partition into sampler calls, and tweak sampling result validation.
  • transfer_queue/dataloader/streaming_dataset.py: Add finite-dataset termination option and switch client init to interface.init/get_client.
  • transfer_queue/metadata.py: Downgrade 1D tensor warning to info.
  • transfer_queue/client.py: Downgrade 1D tensor warning to info.
  • tests/test_samplers.py: Add extensive unit tests for SeqlenBalancedSampler and KK utilities; include sampler integration coverage.


if self.batch_index <= len(self.buffer) - 1:
    current_data = self.buffer[self.batch_index]
    self.batch_index += 1
    logger.info(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")

Copilot AI Mar 31, 2026


This per-microbatch logger.info runs on every yielded batch and can generate extremely high log volume in long-running streaming training jobs. Consider downgrading this to debug (or gating it behind a configurable log interval) to avoid performance/log-storage impact.

Suggested change
logger.info(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
logger.debug(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
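
If downgrading to debug is too blunt, the "configurable log interval" alternative the reviewer mentions could look like this sketch. The helper name `maybe_log_progress` and the `LOG_EVERY_N` knob are hypothetical, not part of the PR:

```python
import logging

logger = logging.getLogger("streaming_dataset")

LOG_EVERY_N = 100  # hypothetical knob; could be read from env or config

def maybe_log_progress(batch_index: int, total: int) -> None:
    # Emit at most one info line per LOG_EVERY_N batches, plus the final
    # batch, so long-running streaming jobs don't flood the logs.
    if batch_index % LOG_EVERY_N == 0 or batch_index == total:
        logger.info("StreamDataloader current batch index is %d/%d", batch_index, total)
```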

Comment on lines +153 to +157
group_lengths = [sum(total_lengths[g * group_size : (g + 1) * group_size]) for g in range(num_groups)]
# Balance groups across DP ranks
balanced_group_partitions = get_seqlen_balanced_partitions(group_lengths, self.dp_size, equal_size=True)
# Expand group indices back to sample indices
partitions = []

Copilot AI Mar 31, 2026


get_seqlen_balanced_partitions(..., equal_size=True) asserts len(group_lengths) % dp_size == 0 (and len(group_lengths) >= dp_size). If the number of groups can’t be evenly split across ranks, this will raise at runtime. Consider validating those conditions before calling, or fall back to equal_size=False / sample-level balancing.
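
One way to apply that suggestion, as a sketch; the guard function `choose_equal_size` is ours, and the call shape of `get_seqlen_balanced_partitions` is assumed from the snippet under review:

```python
def choose_equal_size(num_groups: int, dp_size: int) -> bool:
    # equal_size=True requires at least one group per rank and an even
    # split of groups across ranks; otherwise fall back to
    # equal_size=False (balanced sums, possibly uneven partition sizes).
    return num_groups >= dp_size and num_groups % dp_size == 0

# Usage sketch:
# equal = choose_equal_size(len(group_lengths), self.dp_size)
# balanced = get_seqlen_balanced_partitions(group_lengths, self.dp_size, equal_size=equal)
```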

# Values: 0 = not produced, 1 = ready for consumption
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)

Copilot AI Mar 31, 2026


production_status is initialized with a concrete torch.zeros(...) value at class definition time, so all DataPartitionStatus instances will share the same tensor object. Any in-place updates to one partition’s production_status can affect other partitions. Use a per-instance field(default_factory=...) to create a new tensor for each partition.

Suggested change
production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
production_status: Tensor = field(
    default_factory=lambda: torch.zeros(
        int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1)),
        1,
        dtype=torch.int8,
    )
)
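
The sharing pitfall the reviewer describes is easy to reproduce with any mutable class-level default. A minimal illustration, using a plain mutable stand-in instead of a torch.Tensor (the `Buffer`, `Broken`, and `Fixed` names are hypothetical):

```python
from dataclasses import dataclass, field

class Buffer:
    # Stand-in for torch.Tensor: mutable but hashable, so dataclasses
    # will NOT reject it as a default (unlike list/dict/set).
    def __init__(self, n: int):
        self.data = [0] * n

@dataclass
class Broken:
    # Evaluated once at class definition time; all instances alias it.
    status: Buffer = Buffer(3)

@dataclass
class Fixed:
    # default_factory runs per instance, so each gets a fresh Buffer.
    status: Buffer = field(default_factory=lambda: Buffer(3))

a, b = Broken(), Broken()
a.status.data[0] = 1   # in-place update, like tensor indexing
# b.status.data[0] is now 1 as well: state leaks across "partitions"

c, d = Fixed(), Fixed()
c.status.data[0] = 1
# d.status.data[0] is still 0
```

Note that dataclasses rejects list/dict/set defaults outright, but tensors (and other hashable mutables) slip through silently, which is why this bug is easy to miss.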


def get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool):
"""get order of seq lengths to make partitions balanced, this is
used in balacing sum of seqlength across dp ranks and microbatches

Copilot AI Mar 31, 2026


Docstring typo: “balacing” should be “balancing”.

Suggested change
used in balacing sum of seqlength across dp ranks and microbatches
used in balancing sum of seqlength across dp ranks and microbatches

Comment on lines 162 to 166
"""
client_id = uuid.uuid4().hex[:8]

# TODO: DEPRECATE in future
controller_config = self.config.get("controller", None)
if controller_config:
    controller_info = controller_config.get("zmq_info", None)
else:
    controller_info = self.config.get("controller_info", None)
    if controller_info:
        warnings.warn(
            "Config entry `controller_info` will be deprecated in 0.1.7, please "
            "use `controller.zmq_info` instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )

if not controller_info or not isinstance(controller_info, ZMQServerInfo):
    raise ValueError("Invalid or missing controller.zmq_info in config")

backend_config = self.config.get("backend", None)
if not backend_config:
    storage_backend = self.config.get("storage_backend", None)
    backend_config = self.config
    if storage_backend:
        warnings.warn(
            "Config entry `storage_backend` will be deprecated in 0.1.7, please "
            "use `backend.storage_backend` instead.",
            category=DeprecationWarning,
            stacklevel=2,
        )
else:
    storage_backend = backend_config.get("storage_backend", None)
    backend_config = self.config.backend[storage_backend]

if not storage_backend:
    raise ValueError("Missing storage_backend in config")

self._tq_client = TransferQueueClient(client_id, controller_info)
self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config)

init(self.config)
self._tq_client = get_client()


Copilot AI Mar 31, 2026


_create_client() now uses init(self.config) / get_client(), but StreamingDataset.__init__ still documents config keys like controller.controller_info (and other deprecated entries). Consider updating that docstring to reflect the current initialization path and required config fields (e.g., controller.zmq_info, backend.storage_backend, etc.) so users don’t follow outdated guidance.

@NINGBENZHE
Contributor Author

/check-cla

@ascend-robot

CLA Signature Guide

@NINGBENZHE , thanks for your pull request.

The following commit(s) are not associated with a signed Contributor License Agreement (CLA).

Commit and reason:
  • 56cc9719 (Merge pull request #1 from NINGB...): the email used in this commit is not linked to a signed CLA! Please verify that it matches the email you used when signing the CLA.

To sign CLA, click here.

To check if your email is configured correctly, refer to the FAQs.

Once you've signed the CLA or updated your email, please comment /check-cla to revalidate the CLA status.

@NINGBENZHE
Contributor Author

/check-cla

@ascend-robot

CLA Signature Guide

@NINGBENZHE , thanks for your pull request.

The following commit(s) are not associated with a signed Contributor License Agreement (CLA).

Commit and reason:
  • 56cc9719 (Merge pull request #1 from NINGB...): the email used in this commit is not linked to a signed CLA! Please verify that it matches the email you used when signing the CLA.

To sign CLA, click here.

To check if your email is configured correctly, refer to the FAQs.

Once you've signed the CLA or updated your email, please comment /check-cla to revalidate the CLA status.

@NINGBENZHE NINGBENZHE closed this Apr 1, 2026
