tests/e2e/run_deepseek_megatron_parallelism.sh (4 additions, 1 deletion)
@@ -21,7 +21,8 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \
actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \
-    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
+    actor_rollout_ref.actor.megatron.context_parallel_size=2 \
+    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
@@ -30,13 +31,15 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \
actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \
+    actor_rollout_ref.ref.megatron.context_parallel_size=2 \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \
critic.optim.lr=2e-5 \
critic.model.path=$HOME/models/deepseek-ai/deepseek-coder-1.3b-instruct \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size_per_gpu=4 \
critic.megatron.pipeline_model_parallel_size=2 \
critic.megatron.virtual_pipeline_model_parallel_size=2 \
+    critic.megatron.context_parallel_size=2 \
critic.megatron.tensor_model_parallel_size=2 \
algorithm.use_kl_in_reward=True \
algorithm.kl_penalty=kl \
tests/e2e/run_qwen_megatron_parallelism.sh (3 additions, 0 deletions)
@@ -24,6 +24,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \
actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \
+    actor_rollout_ref.actor.megatron.context_parallel_size=2 \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
@@ -33,13 +34,15 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \
actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \
+    actor_rollout_ref.ref.megatron.context_parallel_size=2 \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \
critic.optim.lr=2e-5 \
critic.model.path=$HOME/models/Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size_per_gpu=4 \
critic.megatron.pipeline_model_parallel_size=2 \
critic.megatron.virtual_pipeline_model_parallel_size=2 \
+    critic.megatron.context_parallel_size=2 \
critic.megatron.tensor_model_parallel_size=2 \
algorithm.use_kl_in_reward=True \
algorithm.kl_penalty=kl \
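Both e2e scripts above turn on Megatron context parallelism by adding `megatron.context_parallel_size=2` for the actor, reference, and critic workers; the DeepSeek script additionally drops the actor's tensor parallel size from 4 to 2, so the number of GPUs spanned by one model replica (TP x PP x CP) is unchanged. A minimal sketch of that sizing arithmetic, assuming a hypothetical 8-GPU test node (the world size is not stated in these scripts):

```python
# Illustration only: how the parallelism overrides above carve up a machine.
# world_size = 8 is an assumption for the example, not taken from the scripts.
world_size = 8
tp, pp, cp = 2, 2, 2                      # tensor / pipeline / context parallel sizes
gpus_per_replica = tp * pp * cp           # GPUs holding one model replica
assert world_size % gpus_per_replica == 0, "TP * PP * CP must divide the world size"
dp = world_size // gpus_per_replica       # the remaining dimension is data parallelism
print(f"replica spans {gpus_per_replica} GPUs, data parallel size = {dp}")
```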
tests/ray/detached_worker/server.py (2 additions, 11 deletions)
@@ -35,7 +35,7 @@
from megatron.core import parallel_state as mpu
from megatron.core.models.gpt.gpt_model import ModelType
from megatron.core import tensor_parallel
-from verl.utils.megatron_utils import get_model, init_megatron_optim_config, init_model_parallel_config
+from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config
from verl.utils.megatron.optimizer import get_megatron_optimizer

from transformers import LlamaConfig
@@ -78,16 +78,7 @@ def init_model(self):
num_attention_heads=16,
num_key_value_heads=16)

-        megatron_config = OmegaConf.create({
-            'sequence_parallel_enabled': True,
-            'param_dtype': 'bf16',
-            'pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank(),
-            'pipeline_model_parallel_size': mpu.get_pipeline_model_parallel_world_size(),
-            'virtual_pipeline_model_parallel_rank': mpu.get_virtual_pipeline_model_parallel_rank(),
-            'virtual_pipeline_model_parallel_size': mpu.get_virtual_pipeline_model_parallel_world_size()
-        })
-
-        megatron_config = init_model_parallel_config(megatron_config)
+        megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)
self.megatron_config = megatron_config

def megatron_actor_model_provider(pre_process, post_process):
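In server.py the hand-built OmegaConf dict (which pulled pipeline ranks and sizes from `mpu`) is replaced by a single helper call. A minimal sketch mirroring the call in the diff above, assuming `mcore_model_parallel_config` in `verl.utils.megatron_utils` fills in the parallel-state fields itself; this is not an authoritative description of its full signature:

```python
import torch
from verl.utils.megatron_utils import mcore_model_parallel_config

# Only the two arguments shown in the diff are passed here; anything else the
# helper needs is presumably derived from Megatron's initialized parallel state.
megatron_config = mcore_model_parallel_config(
    sequence_parallel=True,        # replaces 'sequence_parallel_enabled': True
    params_dtype=torch.bfloat16,   # replaces 'param_dtype': 'bf16'
)
```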
verl/models/mcore/gpt_model.py (70 additions, 14 deletions)
@@ -34,11 +34,12 @@ def gptmodel_forward(model,
batch_size, seq_len = attention_mask.shape[:2]
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)
input_ids_rmpad = input_ids_rmpad.contiguous()
-        output = model(input_ids=input_ids_rmpad,
-                       attention_mask=None,
-                       position_ids=position_ids,
-                       packed_seq_params=packed_seq_params)
-        output = postprocess_packed_seqs(output,
+        output_orig = model(input_ids=input_ids_rmpad,
+                            attention_mask=None,
+                            position_ids=position_ids,
+                            packed_seq_params=packed_seq_params)
+
+        output = postprocess_packed_seqs(output_orig,
packed_seq_params,
attention_mask,
batch_size,
@@ -67,12 +68,21 @@ def preprocess_packed_seqs(input_ids: torch.Tensor,
pre_process: bool = True) -> tuple[torch.Tensor, PackedSeqParams]:
"""
Preprocess packed sequences
+    CP splits the sequence into CP*2 chunks; each GPU gets 2 chunks (GPU0 gets the first and last chunks,
+    GPU1 gets the second and second-to-last chunks, and so on). This balances load under causal masking.
+    See https://github.com/NVIDIA/TransformerEngine/issues/1368
"""
batch_size = input_ids.shape[0]

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
tp_size = mpu.get_tensor_model_parallel_world_size()
-    pad_size = (tp_size - seqlens_in_batch % tp_size) % tp_size
+    cp_size = mpu.get_context_parallel_world_size()
+    cp_rank = mpu.get_context_parallel_rank()
+    if cp_size > 1:
+        align_size = tp_size * cp_size * 2
+    else:
+        align_size = tp_size
+
+    pad_size = (align_size - seqlens_in_batch % align_size) % align_size
seqlens_in_batch_padded = seqlens_in_batch + pad_size
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)
@@ -81,12 +91,28 @@
max_seqlen_in_batch = seqlens_in_batch_padded.max().item()

shape = list(input_ids.shape[1:])
-    shape[0] = seqlens_in_batch_padded.sum().item()
+    shape[0] = seqlens_in_batch_padded.sum().item() // cp_size
if pre_process:
input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
for i in range(batch_size):
-            seqlen = seqlens_in_batch[i]
-            input_ids_rmpad[cu_seqlens_padded[i]:cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]]
+            if cp_size <= 1:
+                seqlen = seqlens_in_batch[i]
+                input_ids_rmpad[cu_seqlens_padded[i]:cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]]
+                continue
+            seqlen = seqlens_in_batch_padded[i] // cp_size
+            half_seqlen = seqlen // 2
+            start_idx = cu_seqlens_padded[i] // cp_size
+            # split to 2 chunks
+            d = input_ids[i, attention_mask[i]]
+            input_ids_rmpad[start_idx:start_idx + half_seqlen] = d[half_seqlen * cp_rank:half_seqlen * (cp_rank + 1)]
+
+            remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1)
+            remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank
+            remain_end = min(remain_end, d.shape[0])
+            remain_len = remain_end - remain_start
+            if remain_len > 0:
+                input_ids_rmpad[start_idx + half_seqlen:start_idx + half_seqlen + remain_len] = d[remain_start:remain_end]

packed_seq_params = PackedSeqParams(qkv_format='thd',
cu_seqlens_q=cu_seqlens_padded,
@@ -114,11 +140,40 @@ def postprocess_packed_seqs(output: torch.Tensor,
return output
shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim
output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)

+    cp_size = mpu.get_context_parallel_world_size()
+    # all gather output across context parallel group
+    if cp_size > 1:
+        # output shape: [1, packed_len, hidden_dim]
+        # need to gather across cp group and concatenate in sequence dimension
+        output_list = [torch.empty_like(output) for _ in range(cp_size)]
+        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
+        output_list[mpu.get_context_parallel_rank()] = output
+    else:
+        output_list = [output]
for i in range(batch_size):
-        s = attention_mask[i].sum().item()
-        output_new[i, attention_mask[i]] = output[0][packed_seq_params.cu_seqlens_q_padded[i]:packed_seq_params.cu_seqlens_q_padded[i] + s]
+        if cp_size <= 1:
+            s = attention_mask[i].sum().item()
+            output_new[i, attention_mask[i]] = output[0][packed_seq_params.cu_seqlens_q_padded[i]:packed_seq_params.cu_seqlens_q_padded[i] + s]
+            continue
+        s_len_padded_chunk = (packed_seq_params.cu_seqlens_q_padded[i + 1] -
+                              packed_seq_params.cu_seqlens_q_padded[i]) // cp_size
+        half_seqlen = s_len_padded_chunk // 2
+        s_len = attention_mask[i].sum().item()
+        s_len_padded = s_len_padded_chunk * cp_size
+        tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)
+        for j in range(cp_size):
+            o = output_list[j][0]
+            # split to 2 chunks
+            packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size
+            o0 = o[packed_start_idx:packed_start_idx + half_seqlen]
+            o1 = o[packed_start_idx + half_seqlen:packed_start_idx + s_len_padded_chunk]
+            tmp[j * half_seqlen:(j + 1) * half_seqlen] = o0
+            tmp[s_len_padded - (j + 1) * half_seqlen:s_len_padded - j * half_seqlen] = o1
+        output_new[i, attention_mask[i]] = tmp[:s_len]

return output_new

@@ -134,7 +189,8 @@ def remove_left_padding(input_ids: torch.Tensor,
"""
assert attention_mask.ndim == 2
assert position_ids.ndim == 2

+    cp_size = mpu.get_context_parallel_world_size()
+    assert cp_size == 1, 'Context parallel size without seq_pack is not supported'
batch_size = input_ids.shape[0]
shape = list(input_ids.shape) # batch_size, seq_len,...
seq_lens = attention_mask.sum(dim=1)
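The new code in `preprocess_packed_seqs` and `postprocess_packed_seqs` implements the scheme described in the added docstring: each padded sequence is cut into 2*CP chunks, and CP rank r keeps chunk r from the front plus the mirrored chunk from the back, so every rank gets a balanced mix of early (cheap) and late (expensive) positions under a causal mask; the postprocess step all-gathers the per-rank outputs and stitches them back together. A self-contained sketch of that split and its inverse on plain Python lists (illustration only, not verl code):

```python
def cp_split(tokens, cp_size, cp_rank):
    # Assumes len(tokens) is already padded to a multiple of 2 * cp_size.
    chunk = len(tokens) // (2 * cp_size)
    front = tokens[cp_rank * chunk:(cp_rank + 1) * chunk]
    back = tokens[len(tokens) - (cp_rank + 1) * chunk:len(tokens) - cp_rank * chunk]
    return front + back

def cp_gather(per_rank, cp_size):
    # Inverse of cp_split: stitch every rank's two chunks back into one sequence.
    chunk = len(per_rank[0]) // 2
    total = 2 * cp_size * chunk
    out = [None] * total
    for r, piece in enumerate(per_rank):
        out[r * chunk:(r + 1) * chunk] = piece[:chunk]
        out[total - (r + 1) * chunk:total - r * chunk] = piece[chunk:]
    return out

# cp_size = 2: 8 tokens -> 4 chunks; rank 0 holds chunks 0 and 3, rank 1 holds 1 and 2.
seq = list(range(8))
parts = [cp_split(seq, 2, r) for r in range(2)]   # [[0, 1, 6, 7], [2, 3, 4, 5]]
assert cp_gather(parts, 2) == seq
```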