[rollout, data] fix: honor train_max_samples/val_max_samples in fully async rollouter #5359
Conversation
Code Review
The pull request successfully implements the plumbing for train_max_samples and val_max_samples in the FullyAsyncRollouter, bringing it closer to parity with the standard PPO trainer. However, there are a few issues regarding API consistency and potential runtime crashes when these new configuration options are used in conjunction with the use_trainer_do_validate flag.
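For context, "honoring" these options presumably amounts to capping the loaded dataset at the configured size, with the default of -1 used in this PR meaning no cap. Below is a minimal illustrative sketch of that semantics, not the verl implementation; `cap_dataset` is a hypothetical helper and assumes a map-style dataset with `__len__`. The actual change in this PR is shown in the diff excerpts that follow.

```python
from torch.utils.data import Dataset, Subset

def cap_dataset(dataset: Dataset, max_samples: int) -> Dataset:
    """Keep only the first max_samples items; a non-positive cap means 'use everything'."""
    if max_samples <= 0 or max_samples >= len(dataset):
        return dataset
    return Subset(dataset, range(max_samples))
```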
```python
train_dataset = create_rl_dataset(
    config.data.train_files,
    config.data,
    tokenizer,
    processor,
    max_samples=config.data.get("train_max_samples", -1),
)
```
For consistency with the main PPO trainer path (verl/trainer/main_ppo.py) and to ensure the call matches the intended API usage, it is recommended to explicitly pass is_train=True when creating the training dataset.
```python
train_dataset = create_rl_dataset(
    config.data.train_files,
    config.data,
    tokenizer,
    processor,
    is_train=True,
    max_samples=config.data.get("train_max_samples", -1),
)
```

```python
val_dataset = create_rl_dataset(
    config.data.val_files,
    config.data,
    tokenizer,
    processor,
    max_samples=config.data.get("val_max_samples", -1),
)
```
The call to create_rl_dataset for the validation dataset is missing the is_train=False argument. While this parameter is currently unused in the default implementation of create_rl_dataset, it is part of the function signature and is explicitly passed in the main PPO trainer path. Including it ensures logical correctness and prevents potential issues if the dataset initialization logic is updated to rely on this flag in the future.
```python
val_dataset = create_rl_dataset(
    config.data.val_files,
    config.data,
    tokenizer,
    processor,
    is_train=False,
    max_samples=config.data.get("val_max_samples", -1),
)
```

```python
    config.data,
    tokenizer,
    processor,
    max_samples=config.data.get("val_max_samples", -1),
```
Setting val_max_samples introduces a risk of runtime crashes when async_training.use_trainer_do_validate is enabled. At line 123, val_dataset.split(total_gpus) is called, and the RLHFDataset.split implementation raises a ValueError if the dataset size is not exactly divisible by the number of splits. If a user sets a val_max_samples value that isn't a multiple of the total GPU count, the training will fail during initialization. It is recommended to either document this constraint or ensure the dataset size is adjusted to be divisible by total_gpus before splitting.
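A minimal sketch of the "adjust before splitting" option, reusing the `val_dataset` / `total_gpus` names from the comment above. `divisible_len` and the trim call are hypothetical; in practice the trim would have to happen inside RLHFDataset so that `.split()` remains available.

```python
def divisible_len(n_samples: int, total_gpus: int) -> int:
    """Largest count <= n_samples that divides evenly across total_gpus."""
    n = (n_samples // total_gpus) * total_gpus
    if n == 0:
        raise ValueError(f"only {n_samples} val samples for {total_gpus} GPUs; raise val_max_samples")
    return n

# e.g. keep = divisible_len(len(val_dataset), total_gpus)   # drop the trailing remainder
#      val_dataset = val_dataset.select(range(keep))        # hypothetical trim API
#      chunks = val_dataset.split(total_gpus)               # now guaranteed to divide evenly
```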
What does this PR do?
Checklist Before Starting
- Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If the PR is a breaking change, add `[BREAKING]` to the beginning of the title, like `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
API and Usage Example
Design & Code Changes
- `verl/experimental/fully_async_policy/fully_async_rollouter.py`: pass `max_samples=` when calling `create_rl_dataset()` for the train/val datasets (see the config sketch below).
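A hypothetical usage sketch, assuming verl's OmegaConf-based config; only `train_max_samples`, `val_max_samples`, and the `config.data.get(...)` lookup pattern come from this PR, everything else is a placeholder.

```python
from omegaconf import OmegaConf

# Placeholder config fragment illustrating the two new data knobs.
config = OmegaConf.create({
    "data": {
        "train_files": "train.parquet",  # placeholder path
        "val_files": "val.parquet",      # placeholder path
        "train_max_samples": 1024,       # cap the train set; the PR's default is -1
        "val_max_samples": 64,           # small val cap, e.g. for smoke tests
    }
})

# Lookup pattern used by the rollouter for the caps.
assert config.data.get("train_max_samples", -1) == 1024
assert config.data.get("val_max_samples", -1) == 64
```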
Checklist Before Submitting

Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Request CI through the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If this PR touches the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.