Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,34 @@ def repeat(self, repeat_times=2, interleave=True):
meta_info=self.meta_info,
)

@staticmethod
def split(data_proto: "DataProto", filter_mask) -> tuple["DataProto", "DataProto"]:
    """
    Split a DataProto into two based on a boolean mask.

    Args:
        data_proto: The DataProto to split.
        filter_mask: Boolean tensor, list, or numpy array where True values
            go to the first DataProto.

    Returns:
        Tuple[DataProto, DataProto]: First DataProto with items where mask is True,
            Second DataProto with items where mask is False.
    """
    # Convert to tensor if it's a list or numpy array
    if isinstance(filter_mask, list):
        filter_mask = torch.tensor(filter_mask, dtype=torch.bool)
    elif isinstance(filter_mask, np.ndarray):
        filter_mask = torch.from_numpy(filter_mask)

    # Coerce to bool dtype: `~` on an integer tensor (e.g. built from an
    # int numpy array of 0/1) is bitwise NOT (~1 == -2), which would make
    # the inverse mask select everything instead of the complement.
    filter_mask = filter_mask.to(torch.bool)

    # Create inverse mask
    inverse_mask = ~filter_mask

    # Split into two DataProtos
    # NOTE(review): select_idxs with an all-True/all-False mask produced an
    # empty-selection error on older verl (<=0.3); reported fixed upstream —
    # confirm against the pinned verl version before relying on it.
    first_proto = data_proto.select_idxs(filter_mask)
    second_proto = data_proto.select_idxs(inverse_mask)

    return first_proto, second_proto

def unfold_column_chunks(self, n_split: int, split_keys: Optional[List[str]] = None):
"""Split along the second dim into `n_split`, unfold it to the first dim (batch dim)
Useful in passing grouped tensors that doesn't want to be shuffled in dataset.
Expand Down
2 changes: 2 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ actor_rollout_ref:
top_p: 1
prompt_length: ${data.max_prompt_length} # for xperf_gpt
response_length: ${data.max_response_length}
partial_rollout_max_split: ${algorithm.partial_rollout_max_split}
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.5
Expand Down Expand Up @@ -335,6 +336,7 @@ algorithm:
norm_adv_by_std_in_grpo: True
use_kl_in_reward: False
kl_penalty: kl # how to estimate kl divergence
partial_rollout_max_split: 1 # max rounds of rollout before the prompt is forced finished, 1 means no partial rollout
kl_ctrl:
type: fixed
kl_coef: 0.001
Expand Down
5 changes: 5 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ actor_rollout_ref:
# typically the same as data max response length
response_length: ${data.max_response_length}

partial_rollout_max_split: ${algorithm.partial_rollout_max_split}

# for vllm rollout
# Rollout model parameters type. Align with actor model's FSDP/Megatron type.
dtype: bfloat16
Expand Down Expand Up @@ -814,6 +816,9 @@ algorithm:
# How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full"
kl_penalty: kl

# max rounds of rollout before the prompt is forced finished, 1 means no partial rollout
partial_rollout_max_split: 1

# KL control configuration
kl_ctrl:

Expand Down
85 changes: 79 additions & 6 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ def __init__(
else:
raise NotImplementedError

self.enable_partial_rollout: bool = self.config.algorithm.partial_rollout_max_split > 1

self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

Expand Down Expand Up @@ -463,6 +465,9 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
assert config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None, "tool_config_path must be set when enabling multi_turn with tool, due to no role-playing support"
assert config.algorithm.adv_estimator in [AdvantageEstimator.GRPO], "only GRPO is tested for multi-turn with tool"

# check partial rollout config
assert config.data.max_response_length % config.algorithm.partial_rollout_max_split == 0, "max_response_length must be divisible by partial_rollout_max_split"

print("[validate_config] All configuration checks passed successfully!")

def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
Expand Down Expand Up @@ -921,6 +926,9 @@ def fit(self):
self.global_steps += 1
last_val_metrics = None

partial_batch: Optional[DataProto] = None # samples whose rollout is not finished yet
staged_batch: Optional[DataProto] = None # samples whose rollout has been finished but not yet trained on

for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False
Expand All @@ -937,9 +945,17 @@ def fit(self):
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)

batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch.non_tensor_batch["age"] = np.ones(len(batch.batch), dtype=int)
batch.non_tensor_batch["raw_response_ids"] = np.fromiter(([] for _ in range(len(batch.batch))), dtype=object)

batch = DataProto.concat([partial_batch, batch]) if partial_batch is not None else batch

# pop those keys for generation
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids", "raw_response_ids"]
if "multi_modal_data" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("multi_modal_data")
if "raw_prompt" in batch.non_tensor_batch:
Expand All @@ -965,6 +981,61 @@ def fit(self):
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)

with marked_timer("filter", timing_raw):
batch = batch.union(gen_batch_output)

finished_mask = batch.non_tensor_batch.pop("finished")
finished_mask = (batch.non_tensor_batch["age"] == self.config.algorithm.partial_rollout_max_split) | finished_mask
staged_out, partial_batch = DataProto.split(batch, finished_mask)
staged_out.non_tensor_batch.pop("raw_prompt_ids")
staged_out.non_tensor_batch.pop("raw_response_ids")

partial_batch.non_tensor_batch["age"] += 1

if len(partial_batch.batch) > 0:
for key in ("input_ids", "attention_mask", "position_ids"):
tmp = partial_batch.batch.pop(key, None)
partial_batch.batch[key] = tmp[:, : self.config.data.max_prompt_length]

for key in ("prompts", "responses", "rollout_log_probs"):
# we don't support rollout_log_probs in this feature branch yet
partial_batch.batch.pop(key)
else:
partial_batch = None

# note that we no longer ensure the order of samples in staged_batch
staged_batch = DataProto.concat([staged_batch, staged_out]) if staged_batch is not None else staged_out

# prompts whose number of finished rollout is enough can be trained on
# while filtering, we ensure sample number is divisible by n_gpus_per_node and as large as possible
can_train_mask = np.zeros(len(staged_batch.batch), dtype=bool)
id2count = defaultdict(int)
required_rollouts = self.config.actor_rollout_ref.rollout.n
divisor = self.config.actor_rollout_ref.actor.ppo_mini_batch_size * required_rollouts

for uid in staged_batch.non_tensor_batch["uid"]:
id2count[uid] += 1
assert not id2count or max(id2count.values()) <= required_rollouts, "max number of responses exceeds rollout n"

complete_uids = [uid for uid, count in id2count.items() if count == required_rollouts]

total_complete_samples = len(complete_uids) * required_rollouts
max_usable_groups = (total_complete_samples // divisor) * divisor // required_rollouts
can_train_count = max_usable_groups * required_rollouts

if can_train_count == 0:
print(f"{total_complete_samples=}, no complete uid groups available. Keep generating...")
continue

selected_uids = set(complete_uids[:max_usable_groups])

for i, uid in enumerate(staged_batch.non_tensor_batch["uid"]):
if uid in selected_uids:
can_train_mask[i] = True

batch, staged_batch = DataProto.split(staged_batch, can_train_mask)
staged_batch.non_tensor_batch["age"] += 1

if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with marked_timer("gen_max", timing_raw, color="purple"):
gen_baseline_batch = deepcopy(gen_batch)
Expand All @@ -981,11 +1052,6 @@ def fit(self):

del gen_baseline_batch, gen_baseline_output

batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

batch.batch["response_mask"] = compute_response_mask(batch)
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
Expand Down Expand Up @@ -1143,6 +1209,13 @@ def fit(self):
"training/epoch": epoch,
}
)
if self.enable_partial_rollout:
metrics.update(
{
"training/can_train_count": can_train_count,
"training/total_complete_samples": total_complete_samples,
}
)
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
Expand Down
26 changes: 22 additions & 4 deletions verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import torch.distributed
from omegaconf import DictConfig, OmegaConf
from tensordict import TensorDict
from vllm import LLM, SamplingParams
from vllm import LLM, RequestOutput, SamplingParams
from vllm.distributed import parallel_state as vllm_ps
from vllm.lora.request import LoRARequest
from vllm.worker.worker_base import WorkerWrapperBase
Expand Down Expand Up @@ -239,12 +239,18 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
raise RuntimeError("vllm sharding manager is not work properly.")

raw_prompt_ids = non_tensor_batch.pop("raw_prompt_ids")
if "raw_response_ids" in non_tensor_batch:
raw_response_ids = non_tensor_batch.pop("raw_response_ids")
else:
raw_response_ids = np.fromiter(([] for _ in range(batch_size)), dtype=object)

if "multi_modal_data" in non_tensor_batch:
vllm_inputs = []
for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")):
vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data})
else:
vllm_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")]
vllm_inputs = [{"prompt_token_ids": raw_prompt_ids_ + raw_response_ids_} for raw_prompt_ids_, raw_response_ids_ in zip(raw_prompt_ids, raw_response_ids)]

# ensure the type of `prompt_token_ids` passed to vllm is list[int]
# https://github.com/volcengine/verl/pull/772
Expand Down Expand Up @@ -273,6 +279,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
"temperature": self.config.val_kwargs.temperature,
"n": 1, # if validate, already repeat in ray_trainer
}
else:
kwargs = {
"n": 1, # also repeated in ray_trainer
"max_tokens": self.config.response_length // self.config.partial_rollout_max_split,
}

lora_requests = None
if self.lora_kwargs:
Expand All @@ -283,7 +294,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:

# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):
outputs = self.inference_engine.generate(
outputs: list[RequestOutput] = self.inference_engine.generate(
prompts=vllm_inputs, # because we have already convert it to prompt token id
sampling_params=self.sampling_params,
lora_request=lora_requests,
Expand All @@ -294,15 +305,22 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
# if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)

response = []
finished = []
rollout_log_probs = []
for output in outputs:
for sample_id in range(len(output.outputs)):
response_ids = output.outputs[sample_id].token_ids
response.append(response_ids)
filtered_response = [id if id < 151669 else 0 for id in response_ids]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this magic number? 151669

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like the vocab size of a certain model, but why do you need filter here? did you encounter model to generate token id larger than what's in the tokenizer?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Qwen3 will do that, and will cause vllm to raise an exception with Qwen's output is fed back to its input. I'm sorry that I have forgotten to mention it when writing the PR.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same happens to Qwen 2.5 Math too.

response.append(filtered_response)
finished.append(output.outputs[sample_id].finish_reason != "length")
curr_log_prob = []
for i, logprob in enumerate(output.outputs[sample_id].logprobs):
curr_log_prob.append(logprob[response_ids[i]].logprob)
rollout_log_probs.append(curr_log_prob)
non_tensor_batch["finished"] = np.array(finished)
response = raw_response_ids + np.fromiter(response, dtype=object)
non_tensor_batch["raw_response_ids"] = response
non_tensor_batch["raw_prompt_ids"] = raw_prompt_ids

response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device)
rollout_log_probs = pad_2d_list_to_length(rollout_log_probs, -1, max_length=self.config.response_length).to(idx.device)
Expand Down