examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh (2 additions, 1 deletion)
@@ -32,4 +32,5 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
     model.target_modules=all-linear \
     model.strategy=fsdp \
     ulysses_sequence_parallel_size=2 \
-    use_remove_padding=true
+    use_remove_padding=true \
+    trainer.device=npu
examples/split_placement/main_ppo_split.py (1 deletion)
@@ -203,7 +203,6 @@ def main_task(config):
         ray_worker_group_cls=ray_worker_group_cls,
         reward_fn=reward_fn,
         val_reward_fn=val_reward_fn,
-        device_name=config.trainer.device,
     )
     trainer.init_workers()
     trainer.fit()
recipe/dapo/main_dapo.py (1 deletion)
@@ -161,7 +161,6 @@ def run(self, config):
             ray_worker_group_cls=ray_worker_group_cls,
             reward_fn=reward_fn,
             val_reward_fn=val_reward_fn,
-            device_name=config.trainer.device,
         )
         trainer.init_workers()
         trainer.fit()
recipe/prime/main_prime.py (1 deletion)
@@ -140,7 +140,6 @@ def main_task(config, compute_score=None):
         ray_worker_group_cls=ray_worker_group_cls,
         reward_fn=reward_fn,
         val_reward_fn=val_reward_fn,
-        device_name=config.trainer.device,
     )
     trainer.init_workers()
     trainer.fit()
recipe/spin/main_spin.py (1 deletion)
@@ -149,7 +149,6 @@ def run(self, config):
             ray_worker_group_cls=ray_worker_group_cls,
             reward_fn=reward_fn,
             val_reward_fn=val_reward_fn,
-            device_name=config.trainer.device,
         )
         trainer.init_workers()
         trainer.fit_dpo()
recipe/spin/spin_trainer.py (3 additions, 3 deletions)
@@ -368,7 +368,7 @@ def __init__(
         val_dataset: Optional[Dataset] = None,
         collate_fn=None,
         train_sampler: Optional[Sampler] = None,
-        device_name="cuda",
+        device_name=None,
     ):
         # assert get_torch_device().is_available(), 'cuda must be available on driver'

@@ -391,7 +391,7 @@ def __init__(
         self.ray_worker_group_cls = ray_worker_group_cls
         self.validation_generations_logger = ValidationGenerationsLogger()
         self.async_rollout_mode = False
-        self.device_name = device_name
+        self.device_name = device_name if device_name else self.config.trainer.device

         # define in-reward KL control
         # kl loss control currently not suppoorted
@@ -807,13 +807,13 @@ def init_workers(self):
         wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup
         if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
             wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
+        wg_kwargs["device_name"] = self.device_name

         for resource_pool, class_dict in self.resource_pool_to_cls.items():
             worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
             wg_dict = self.ray_worker_group_cls(
                 resource_pool=resource_pool,
                 ray_cls_with_init=worker_dict_cls,
-                device_name=self.device_name,
                 **wg_kwargs,
             )
             spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
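The same fallback pattern (an explicit device_name argument first, then config.trainer.device) is applied in recipe/sppo/sppo_ray_trainer.py and verl/trainer/ppo/ray_trainer.py below. A minimal, self-contained sketch of the resolution order, using an assumed config shape rather than verl's actual trainer class:

    # Sketch only: mirrors `device_name if device_name else self.config.trainer.device`.
    from omegaconf import OmegaConf

    config = OmegaConf.create({"trainer": {"device": "npu"}})

    def resolve_device_name(config, device_name=None):
        # An explicit argument still wins; otherwise fall back to the config value.
        return device_name if device_name else config.trainer.device

    print(resolve_device_name(config))          # -> "npu" (taken from the config)
    print(resolve_device_name(config, "cuda"))  # -> "cuda" (explicit argument wins)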
recipe/sppo/main_sppo.py (1 deletion)
@@ -144,7 +144,6 @@ def run(self, config):
             ray_worker_group_cls=ray_worker_group_cls,
             reward_fn=reward_fn,
             val_reward_fn=val_reward_fn,
-            device_name=config.trainer.device,
         )
         trainer.init_workers()
         trainer.fit()
recipe/sppo/sppo_ray_trainer.py (2 additions, 2 deletions)
@@ -95,7 +95,7 @@ def __init__(
         val_dataset: Optional[Dataset] = None,
         collate_fn=None,
         train_sampler: Optional[Sampler] = None,
-        device_name="cuda",
+        device_name=None,
     ):
         self.tokenizer = tokenizer
         self.processor = processor
@@ -115,7 +115,7 @@ def __init__(
         self.use_rm = Role.RewardModel in role_worker_mapping
         self.ray_worker_group_cls = ray_worker_group_cls
         self.validation_generations_logger = ValidationGenerationsLogger()
-        self.device_name = device_name
+        self.device_name = device_name if device_name else self.config.trainer.device

         # define in-reward KL control
         # kl loss control currently not supported
tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh (2 additions, 1 deletion)
@@ -24,6 +24,7 @@ torchrun --standalone --nnodes=1 --nproc_per_node=8 \
     model.target_modules=all-linear \
     model.strategy=fsdp \
     ulysses_sequence_parallel_size=2 \
-    use_remove_padding=true
+    use_remove_padding=true \
+    trainer.device=npu

 rm -rf ./outputs ./save_ckpts

Review comment (Contributor, high), on the added trainer.device=npu line:
Adding trainer.device=npu directly to the script makes the script specific to NPU devices. It would be more flexible to allow the user to specify the device via an environment variable or command-line argument. This way, the same script can be used for different devices without modification.
verl/single_controller/ray/base.py (1 addition, 2 deletions)
@@ -270,7 +270,6 @@ def __init__(
         worker_names=None,
         worker_handles: list[ray.actor.ActorHandle] = None,
         ray_wait_register_center_timeout: int = 300,
-        device_name="cuda",
         **kwargs,
     ) -> None:
         """Initialize a RayWorkerGroup.

Review comment (Collaborator), on the removed device_name parameter:
@as12138 I agree it is a good refactor to move device_name from a positional argument into kwargs. Could you please add this API change to issue #2528, so that any custom extensions of the single controller have a record of API changes for the next verl version?

Reply (Collaborator, author):
Of course, I have already added it.

@@ -294,7 +293,7 @@ def __init__(
         # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to
         # this WorkerGroup.
         self.sub_cls_name = ""
-        self.device_name = device_name
+        self.device_name = kwargs.get("device_name", "cuda")
         self.profile_steps = kwargs.get("profile_steps", None)
         self.worker_nsight_options = kwargs.get("worker_nsight_options", None)
         if self.worker_nsight_options is not None and self.worker_nsight_options["capture-range-end"] is None:

Review comment (Contributor, high), on the new kwargs.get("device_name", "cuda") line:
The default value for device_name is hardcoded to "cuda". If the intention is to support other devices such as npu, this should be configurable or determined dynamically from the available hardware; otherwise the program will crash when cuda is not available.
Suggested change:
-        self.device_name = kwargs.get("device_name", "cuda")
+        self.device_name = kwargs.get("device_name", self.config.trainer.device)
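A minimal sketch of the keyword-based pattern this file now uses; the class below is a stand-in, not verl's RayWorkerGroup, and the "npu" value only illustrates a caller forwarding the device through kwargs:

    # Sketch only: device_name is no longer a positional parameter; it is read
    # from **kwargs with "cuda" as the fallback default.
    class WorkerGroupSketch:
        def __init__(self, resource_pool=None, ray_cls_with_init=None, **kwargs):
            self.device_name = kwargs.get("device_name", "cuda")

    # Callers (the Ray trainers in this PR) collect the device in a kwargs dict,
    # typically resolved from config.trainer.device, and pass it through **kwargs.
    wg_kwargs = {"device_name": "npu"}
    assert WorkerGroupSketch(**wg_kwargs).device_name == "npu"

    # Without the kwarg, the reviewer's concern applies: the default stays "cuda".
    assert WorkerGroupSketch().device_name == "cuda"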
verl/trainer/config/sft_trainer.yaml (1 addition)
@@ -65,3 +65,4 @@ trainer:
   nnodes: 1
   n_gpus_per_node: 8
   max_ckpt_to_keep: null # TODO
+  device: cuda
verl/trainer/fsdp_sft_trainer.py (1 addition, 1 deletion)
@@ -119,7 +119,7 @@ def __init__(
         # TODO: add checkpoint manager
         if self.device_mesh.get_rank() == 0:
             print(self.config)
-        self.device_name = get_device_name()
+        self.device_name = self.config.trainer.device

     def _normalize_config_bsz(self):
         dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0)
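A minimal sketch of what the two SFT changes above amount to: the device now comes from the trainer.device key that sft_trainer.yaml defaults to cuda, instead of being auto-detected by get_device_name(), so it can be overridden from the command line (e.g. trainer.device=npu). The config shape is assumed from the YAML above:

    # Sketch only: config-driven device selection with a CLI-style override.
    from omegaconf import OmegaConf

    yaml_defaults = OmegaConf.create({"trainer": {"device": "cuda"}})
    cli_overrides = OmegaConf.from_dotlist(["trainer.device=npu"])
    config = OmegaConf.merge(yaml_defaults, cli_overrides)

    device_name = config.trainer.device  # replaces get_device_name()
    print(device_name)                   # -> "npu"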
verl/trainer/main_ppo.py (1 deletion)
@@ -231,7 +231,6 @@ def run(self, config):
             val_dataset=val_dataset,
             collate_fn=collate_fn,
             train_sampler=train_sampler,
-            device_name=config.trainer.device,
         )
         # Initialize the workers of the trainer.
         trainer.init_workers()
verl/trainer/ppo/ray_trainer.py (4 additions, 4 deletions)
@@ -315,7 +315,7 @@ def __init__(
         val_dataset: Optional[Dataset] = None,
         collate_fn=None,
         train_sampler: Optional[Sampler] = None,
-        device_name="cuda",
+        device_name=None,
     ):
         """
         Initialize distributed PPO trainer with Ray backend.
@@ -334,7 +334,7 @@ def __init__(
             val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.
             collate_fn: Function to collate data samples into batches.
             train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.
-            device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda".
+            device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None.
         """

         # Store the tokenizer for text processing
@@ -355,7 +355,7 @@ def __init__(
         self.use_reference_policy = Role.RefPolicy in role_worker_mapping
         self.use_rm = Role.RewardModel in role_worker_mapping
         self.ray_worker_group_cls = ray_worker_group_cls
-        self.device_name = device_name
+        self.device_name = device_name if device_name else self.config.trainer.device
         self.validation_generations_logger = ValidationGenerationsLogger()

         # if ref_in_actor is True, the reference policy will be actor without lora applied
@@ -895,13 +895,13 @@ def init_workers(self):
             wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
                 OmegaConf.select(self.config.trainer, "worker_nsight_options")
             )
+        wg_kwargs["device_name"] = self.device_name

         for resource_pool, class_dict in self.resource_pool_to_cls.items():
             worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
             wg_dict = self.ray_worker_group_cls(
                 resource_pool=resource_pool,
                 ray_cls_with_init=worker_dict_cls,
-                device_name=self.device_name,
                 **wg_kwargs,
             )
             spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())