[Feat]: Implement partial rollout #1826
```diff
@@ -37,7 +37,7 @@
 import torch.distributed
 from omegaconf import DictConfig, OmegaConf
 from tensordict import TensorDict
-from vllm import LLM, SamplingParams
+from vllm import LLM, RequestOutput, SamplingParams
 from vllm.distributed import parallel_state as vllm_ps
 from vllm.lora.request import LoRARequest
 from vllm.worker.worker_base import WorkerWrapperBase
```
|
|
```diff
@@ -239,12 +239,18 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
         if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
             raise RuntimeError("vllm sharding manager is not work properly.")
 
+        raw_prompt_ids = non_tensor_batch.pop("raw_prompt_ids")
+        if "raw_response_ids" in non_tensor_batch:
+            raw_response_ids = non_tensor_batch.pop("raw_response_ids")
+        else:
+            raw_response_ids = np.fromiter(([] for _ in range(batch_size)), dtype=object)
+
         if "multi_modal_data" in non_tensor_batch:
             vllm_inputs = []
             for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")):
                 vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data})
         else:
-            vllm_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")]
+            vllm_inputs = [{"prompt_token_ids": raw_prompt_ids_ + raw_response_ids_} for raw_prompt_ids_, raw_response_ids_ in zip(raw_prompt_ids, raw_response_ids)]
 
         # ensure the type of `prompt_token_ids` passed to vllm is list[int]
         # https://github.com/volcengine/verl/pull/772
```
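The `else` branch above is the core of the partial-rollout continuation: on a resumed step, each request's input is the original prompt token ids concatenated with the response ids generated so far (empty for fresh prompts). A minimal sketch of that idea, using plain Python lists in place of the numpy object arrays in the PR:

```python
# Sketch of the continuation mechanism: prompt ids plus any partial
# response ids become the next step's prompt_token_ids.
def build_vllm_inputs(raw_prompt_ids, raw_response_ids):
    # raw_response_ids holds [] for fresh prompts, partial outputs otherwise
    return [
        {"prompt_token_ids": p + r}
        for p, r in zip(raw_prompt_ids, raw_response_ids)
    ]

# First request already generated two tokens; second is fresh.
inputs = build_vllm_inputs([[1, 2, 3], [4, 5]], [[10, 11], []])
```

Here the first request resumes mid-generation while the second starts from its prompt alone.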
|
|
```diff
@@ -273,6 +279,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
                 "temperature": self.config.val_kwargs.temperature,
                 "n": 1,  # if validate, already repeat in ray_trainer
             }
+        else:
+            kwargs = {
+                "n": 1,  # also repeated in ray_trainer
+                "max_tokens": self.config.response_length // self.config.partial_rollout_max_split,
+            }
 
         lora_requests = None
         if self.lora_kwargs:
```
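The `max_tokens` expression above caps each rollout step at a fraction of the full response budget, so an unfinished sequence is completed over at most `partial_rollout_max_split` steps. An illustration with assumed config values (not taken from the PR):

```python
# Assumed example values: a 4096-token response budget split over 4
# partial-rollout steps gives each step a 1024-token cap.
response_length = 4096
partial_rollout_max_split = 4
max_tokens = response_length // partial_rollout_max_split
```

Note the integer division silently drops a remainder when `response_length` is not a multiple of the split count, so configs where it divides evenly are the safest choice.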
```diff
@@ -283,7 +294,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 
         # users can customize different sampling_params at different run
         with self.update_sampling_params(**kwargs):
-            outputs = self.inference_engine.generate(
+            outputs: list[RequestOutput] = self.inference_engine.generate(
                 prompts=vllm_inputs,  # because we have already convert it to prompt token id
                 sampling_params=self.sampling_params,
                 lora_request=lora_requests,
```
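The next hunk decides whether a sample is finished by inspecting vLLM's `finish_reason`: `"length"` means generation stopped only because it hit the `max_tokens` budget, while other reasons (e.g. `"stop"` for EOS) mean the rollout genuinely completed. A one-line sketch of that predicate:

```python
# A sample is finished unless vLLM truncated it at the token budget.
def is_finished(finish_reason: str) -> bool:
    return finish_reason != "length"
```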
```diff
@@ -294,15 +305,22 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
         # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
 
         response = []
+        finished = []
         rollout_log_probs = []
         for output in outputs:
             for sample_id in range(len(output.outputs)):
                 response_ids = output.outputs[sample_id].token_ids
-                response.append(response_ids)
+                filtered_response = [id if id < 151669 else 0 for id in response_ids]
```
**Collaborator:** What is this magic number, 151669?

**Collaborator:** It looks like the vocab size of a certain model, but why do you need a filter here? Did you encounter the model generating token ids larger than what's in the tokenizer?

**Author:** Yes. Qwen3 will do that, and it causes vLLM to raise an exception when Qwen's output is fed back into its input. I'm sorry that I forgot to mention this when writing the PR.

**Author:** The same happens with Qwen 2.5 Math too.
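The thread above suggests 151669 is a model-specific vocab size. A less brittle variant (an alternative sketch, not what the PR does) would take the cutoff from the tokenizer or model config instead of hardcoding it:

```python
# Illustrative alternative to the hardcoded 151669: derive the cutoff
# from a vocab_size parameter and map any out-of-vocab token id to a
# safe placeholder the tokenizer can decode.
def filter_out_of_vocab(response_ids, vocab_size, replacement=0):
    return [t if t < vocab_size else replacement for t in response_ids]

filtered = filter_out_of_vocab([5, 151669, 151700, 7], vocab_size=151669)
```

In practice the `vocab_size` would come from something like the tokenizer's vocabulary length, keeping the rollout code model-agnostic.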
The hunk continues after the review thread:

```diff
+                response.append(filtered_response)
+                finished.append(output.outputs[sample_id].finish_reason != "length")
                 curr_log_prob = []
                 for i, logprob in enumerate(output.outputs[sample_id].logprobs):
                     curr_log_prob.append(logprob[response_ids[i]].logprob)
                 rollout_log_probs.append(curr_log_prob)
+        non_tensor_batch["finished"] = np.array(finished)
+        response = raw_response_ids + np.fromiter(response, dtype=object)
+        non_tensor_batch["raw_response_ids"] = response
+        non_tensor_batch["raw_prompt_ids"] = raw_prompt_ids
 
         response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device)
         rollout_log_probs = pad_2d_list_to_length(rollout_log_probs, -1, max_length=self.config.response_length).to(idx.device)
```
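The two `pad_2d_list_to_length` calls above turn the ragged per-sample id lists into fixed-width batches. A pure-Python sketch of the padding semantics this assumes (the real verl helper returns a torch tensor; only the right-padding behavior is shown here):

```python
# Right-pad each variable-length row with pad_value up to max_length,
# mirroring the assumed behavior of verl's pad_2d_list_to_length.
def pad_rows(rows, pad_value, max_length):
    return [row + [pad_value] * (max_length - len(row)) for row in rows]

padded = pad_rows([[1, 2], [3]], pad_value=0, max_length=4)
```

Responses are padded with `pad_token_id`, while log-probs use `-1` as a sentinel for padded positions.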
**Reviewer:** You will run into issues here if `filter_mask` is either all True or all False:

```
RuntimeError: batch dimension mismatch, got self.batch_size=torch.Size([32]) and value.shape=torch.Size([0, 4096]).
```

Basically, `select_idxs` won't support an empty selection.

**Author:** Thank you. I'll test further as soon as the machines are vacant. I remember encountering and fixing similar issues during early development. I think during later training runs there were cases where all or none of the prompts were finished, and there was no RuntimeError.

**Reviewer:** Never mind, it turned out I was using an older version of verl (0.3). I think this issue has been fixed. Thanks!
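The thread above reports `select_idxs` (in an older verl) failing when a boolean mask selects nothing or everything. A defensive guard (illustrative only, not the actual verl fix) short-circuits the degenerate cases before selecting:

```python
# Guard against empty or full boolean-mask selection before handing
# the indices to a selector that cannot handle zero-length batches.
def select_by_mask(batch, mask):
    if not any(mask):
        return []             # no prompts finished: nothing to select
    if all(mask):
        return list(batch)    # everything finished: take the whole batch
    return [x for x, m in zip(batch, mask) if m]
```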