[Feat] Add SeqlenBalancedSampler and enhance StreamingDataset support #69
NINGBENZHE wants to merge 2 commits into Ascend:main from …ced_samplers
Conversation
- Add SeqlenBalancedSampler based on the Karmarkar-Karp algorithm to balance sequence lengths across DP ranks for GRPO training
- Add streaming mode support for StreamingDataset via the should_check_consumption_status parameter
- Add polling_mode sampler cache lookup in the controller to avoid redundant sampling when data is insufficient
- Replace print() with logger.info() in the controller
- Downgrade 1D tensor warnings to info level in client and metadata
- Add comprehensive unit tests for SeqlenBalancedSampler and KarmarkarKarp
[feat] Update and Add SeqlenBalancedSampler and improve StreamingDataset
CLA Signature Guide
@NINGBENZHE, thanks for your pull request. The following commit(s) are not associated with a signed Contributor License Agreement (CLA).
To sign the CLA, click here. To check if your email is configured correctly, refer to the FAQs. Once you've signed the CLA or updated your email, please comment.
Pull request overview
Adds a new sequence-length balancing sampler for GRPO workloads and improves streaming/finite iteration behavior in StreamingDataset, along with controller-side polling optimizations and expanded sampler test coverage.
Changes:
- Introduce SeqlenBalancedSampler (Karmarkar–Karp based) to balance token counts across DP ranks while preserving GRPO prompt-group integrity.
- Extend StreamingDataset with an optional finite-dataset termination mode (should_check_consumption_status) and refactor client init to use transfer_queue.interface.
- Update controller polling behavior to reuse the sampler cache, pass partition context into samplers, and adjust logging severity/noise in a few places.
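The PR's actual Karmarkar–Karp partitioner is not shown in this diff. As a rough illustration of the goal (near-equal token sums per DP rank), here is a simpler greedy longest-processing-time sketch; `balanced_partitions` is a hypothetical helper, not the PR's implementation:

```python
import heapq

def balanced_partitions(seqlens: list[int], k: int) -> list[list[int]]:
    """Assign each sequence (longest first) to the currently lightest
    partition. A simple stand-in for the Karmarkar-Karp scheme the PR
    uses; returns k lists of sample indices."""
    heap = [(0, i) for i in range(k)]  # (partition token total, partition id)
    heapq.heapify(heap)
    parts: list[list[int]] = [[] for _ in range(k)]
    for idx in sorted(range(len(seqlens)), key=lambda i: -seqlens[i]):
        total, p = heapq.heappop(heap)
        parts[p].append(idx)
        heapq.heappush(heap, (total + seqlens[idx], p))
    return parts

parts = balanced_partitions([9, 7, 6, 5, 4, 3], 2)
# Both partitions end up with a token sum of 17 here.
```

Karmarkar–Karp (largest-differencing) typically yields tighter balance than this greedy pass, but both keep every index in exactly one partition, which is the invariant the sampler relies on.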
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| transfer_queue/sampler/seqlen_balanced_sampler.py | New sampler + KK partitioning utilities for DP seqlen balancing (with caching). |
| transfer_queue/sampler/__init__.py | Export SeqlenBalancedSampler. |
| transfer_queue/__init__.py | Re-export SeqlenBalancedSampler at package level. |
| transfer_queue/controller.py | Polling-mode cache lookup, pass partition into sampler calls, and tweak sampling result validation. |
| transfer_queue/dataloader/streaming_dataset.py | Add finite-dataset termination option and switch client init to interface.init/get_client. |
| transfer_queue/metadata.py | Downgrade 1D tensor warning to info. |
| transfer_queue/client.py | Downgrade 1D tensor warning to info. |
| tests/test_samplers.py | Add extensive unit tests for SeqlenBalancedSampler and KK utilities; include sampler integration coverage. |
```python
if self.batch_index <= len(self.buffer) - 1:
    current_data = self.buffer[self.batch_index]
    self.batch_index += 1
    logger.info(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
```
This per-microbatch logger.info runs on every yielded batch and can generate extremely high log volume in long-running streaming training jobs. Consider downgrading this to debug (or gating it behind a configurable log interval) to avoid performance/log-storage impact.
```diff
- logger.info(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
+ logger.debug(f"StreamDataloader current batch index is {self.batch_index}/{len(self.buffer)}")
```
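The "configurable log interval" option could look like the following sketch; `log_batch_progress` and its `interval` default are hypothetical, not part of the PR:

```python
import logging

logger = logging.getLogger("stream_dataloader")

def log_batch_progress(batch_index: int, total: int, interval: int = 100) -> bool:
    """Log progress only every `interval` batches (and at the final batch),
    so long-running streaming jobs do not emit one line per microbatch.
    Returns whether a line was logged, for easy testing."""
    should_log = batch_index % interval == 0 or batch_index == total
    if should_log:
        logger.info("StreamDataloader current batch index is %d/%d", batch_index, total)
    return should_log
```

Using %-style lazy formatting also avoids paying the f-string cost on the batches that are not logged.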
```python
group_lengths = [sum(total_lengths[g * group_size : (g + 1) * group_size]) for g in range(num_groups)]
# Balance groups across DP ranks
balanced_group_partitions = get_seqlen_balanced_partitions(group_lengths, self.dp_size, equal_size=True)
# Expand group indices back to sample indices
partitions = []
```
get_seqlen_balanced_partitions(..., equal_size=True) asserts len(group_lengths) % dp_size == 0 (and len(group_lengths) >= dp_size). If the number of groups can’t be evenly split across ranks, this will raise at runtime. Consider validating those conditions before calling, or fall back to equal_size=False / sample-level balancing.
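One way to guard that call, as a sketch (`choose_partition_mode` is a hypothetical helper, not code from the PR):

```python
def choose_partition_mode(num_groups: int, dp_size: int) -> bool:
    """Return True (use equal_size=True) only when the groups can be split
    evenly across ranks; otherwise the caller should fall back to
    equal_size=False so the assertions inside
    get_seqlen_balanced_partitions cannot fire at runtime."""
    return num_groups >= dp_size and num_groups % dp_size == 0

# Intended usage at the call site (illustrative only):
# equal_size = choose_partition_mode(len(group_lengths), self.dp_size)
# partitions = get_seqlen_balanced_partitions(group_lengths, self.dp_size, equal_size=equal_size)
```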
```python
# Values: 0 = not produced, 1 = ready for consumption
TQ_PRE_ALLOC_SAMPLE_NUM = int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1))

production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
```
production_status is initialized with a concrete torch.zeros(...) value at class definition time, so all DataPartitionStatus instances will share the same tensor object. Any in-place updates to one partition’s production_status can affect other partitions. Use a per-instance field(default_factory=...) to create a new tensor for each partition.
```diff
- production_status: Tensor = torch.zeros(TQ_PRE_ALLOC_SAMPLE_NUM, 1, dtype=torch.int8)
+ production_status: Tensor = field(
+     default_factory=lambda: torch.zeros(
+         int(os.environ.get("TQ_PRE_ALLOC_SAMPLE_NUM", 1)),
+         1,
+         dtype=torch.int8,
+     )
+ )
```
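The shared-default pitfall is easy to reproduce without torch: dataclasses only reject list/dict/set defaults, so any other mutable object (like a tensor) is silently shared across instances. A minimal demonstration, using a hypothetical `Buffer` class as a tensor stand-in:

```python
from dataclasses import dataclass, field

class Buffer:
    """Stand-in for a torch tensor: a mutable object that dataclasses
    will happily accept as a class-level default."""
    def __init__(self):
        self.values = [0]

@dataclass
class SharedStatus:
    buf: Buffer = Buffer()  # one Buffer created once, at class definition

@dataclass
class PerInstanceStatus:
    buf: Buffer = field(default_factory=Buffer)  # fresh Buffer per instance

a, b = SharedStatus(), SharedStatus()
a.buf.values[0] = 1          # b.buf.values[0] is now also 1: a and b alias one Buffer

c, d = PerInstanceStatus(), PerInstanceStatus()
c.buf.values[0] = 1          # d.buf.values[0] stays 0: each got its own Buffer
```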
```python
def get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool):
    """get order of seq lengths to make partitions balanced, this is
    used in balacing sum of seqlength across dp ranks and microbatches
```
Docstring typo: “balacing” should be “balancing”.
```diff
- used in balacing sum of seqlength across dp ranks and microbatches
+ used in balancing sum of seqlength across dp ranks and microbatches
```
```diff
  """
- client_id = uuid.uuid4().hex[:8]
-
- # TODO: DEPRECATE in future
- controller_config = self.config.get("controller", None)
- if controller_config:
-     controller_info = controller_config.get("zmq_info", None)
- else:
-     controller_info = self.config.get("controller_info", None)
-     if controller_info:
-         warnings.warn(
-             "Config entry `controller_info` will be deprecated in 0.1.7, please "
-             "use `controller.zmq_info` instead.",
-             category=DeprecationWarning,
-             stacklevel=2,
-         )
-
- if not controller_info or not isinstance(controller_info, ZMQServerInfo):
-     raise ValueError("Invalid or missing controller.zmq_info in config")
-
- backend_config = self.config.get("backend", None)
- if not backend_config:
-     storage_backend = self.config.get("storage_backend", None)
-     backend_config = self.config
-     if storage_backend:
-         warnings.warn(
-             "Config entry `storage_backend` will be deprecated in 0.1.7, please "
-             "use `backend.storage_backend` instead.",
-             category=DeprecationWarning,
-             stacklevel=2,
-         )
- else:
-     storage_backend = backend_config.get("storage_backend", None)
-     backend_config = self.config.backend[storage_backend]
-
- if not storage_backend:
-     raise ValueError("Missing storage_backend in config")
-
- self._tq_client = TransferQueueClient(client_id, controller_info)
- self._tq_client.initialize_storage_manager(manager_type=storage_backend, config=backend_config)
+ init(self.config)
+ self._tq_client = get_client()
```
_create_client() now uses init(self.config) / get_client(), but StreamingDataset.__init__ still documents config keys like controller.controller_info (and other deprecated entries). Consider updating that docstring to reflect the current initialization path and required config fields (e.g., controller.zmq_info, backend.storage_backend, etc.) so users don’t follow outdated guidance.
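A sketch of the config shape an updated docstring could document, based only on the keys named in this review (controller.zmq_info, backend.storage_backend); the placeholder values and the nested backend section layout are assumptions, not the library's confirmed schema:

```python
# Illustrative config for the init(config) / get_client() path.
# "<your_backend>" and the nested settings dict are placeholders.
config = {
    "controller": {
        "zmq_info": ...,  # a ZMQServerInfo describing the controller endpoint
    },
    "backend": {
        "storage_backend": "<your_backend>",  # which storage manager to use
        # "<your_backend>": {...},            # backend-specific settings (assumed)
    },
}
# Intended usage (illustrative only):
# init(config)
# client = get_client()
```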
/check-cla |