diff --git a/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh b/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh
index 45e427f39a1..4b2ed18d0fd 100644
--- a/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh
+++ b/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh
@@ -32,4 +32,5 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
     model.target_modules=all-linear \
     model.strategy=fsdp \
     ulysses_sequence_parallel_size=2 \
-    use_remove_padding=true
+    use_remove_padding=true \
+    trainer.device=npu
diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py
index c438e7a13f5..64be39da78a 100644
--- a/examples/split_placement/main_ppo_split.py
+++ b/examples/split_placement/main_ppo_split.py
@@ -203,7 +203,6 @@ def main_task(config):
         ray_worker_group_cls=ray_worker_group_cls,
         reward_fn=reward_fn,
         val_reward_fn=val_reward_fn,
-        device_name=config.trainer.device,
     )
     trainer.init_workers()
     trainer.fit()
diff --git a/recipe/dapo/main_dapo.py b/recipe/dapo/main_dapo.py
index 268591b8d79..bb0b32fc4af 100644
--- a/recipe/dapo/main_dapo.py
+++ b/recipe/dapo/main_dapo.py
@@ -161,7 +161,6 @@ def run(self, config):
             ray_worker_group_cls=ray_worker_group_cls,
             reward_fn=reward_fn,
             val_reward_fn=val_reward_fn,
-            device_name=config.trainer.device,
         )
         trainer.init_workers()
         trainer.fit()
diff --git a/recipe/prime/main_prime.py b/recipe/prime/main_prime.py
index 6bf7f5e45a1..882248cb389 100644
--- a/recipe/prime/main_prime.py
+++ b/recipe/prime/main_prime.py
@@ -140,7 +140,6 @@ def main_task(config, compute_score=None):
         ray_worker_group_cls=ray_worker_group_cls,
         reward_fn=reward_fn,
         val_reward_fn=val_reward_fn,
-        device_name=config.trainer.device,
     )
     trainer.init_workers()
     trainer.fit()
diff --git a/recipe/spin/main_spin.py b/recipe/spin/main_spin.py
index 9a879ee77c3..a38b8f860a2 100644
--- a/recipe/spin/main_spin.py
+++ b/recipe/spin/main_spin.py
@@ -149,7 +149,6 @@ def run(self, config):
             ray_worker_group_cls=ray_worker_group_cls,
             reward_fn=reward_fn,
             val_reward_fn=val_reward_fn,
-            device_name=config.trainer.device,
         )
         trainer.init_workers()
         trainer.fit_dpo()
diff --git a/recipe/spin/spin_trainer.py b/recipe/spin/spin_trainer.py
index fa435dbdd19..43789218f57 100644
--- a/recipe/spin/spin_trainer.py
+++ b/recipe/spin/spin_trainer.py
@@ -368,7 +368,7 @@ def __init__(
         val_dataset: Optional[Dataset] = None,
         collate_fn=None,
         train_sampler: Optional[Sampler] = None,
-        device_name="cuda",
+        device_name=None,
     ):
         # assert get_torch_device().is_available(), 'cuda must be available on driver'
 
@@ -391,7 +391,7 @@ def __init__(
         self.ray_worker_group_cls = ray_worker_group_cls
         self.validation_generations_logger = ValidationGenerationsLogger()
         self.async_rollout_mode = False
-        self.device_name = device_name
+        self.device_name = device_name if device_name else self.config.trainer.device
 
         # define in-reward KL control
         # kl loss control currently not suppoorted
@@ -807,13 +807,13 @@ def init_workers(self):
         wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup
         if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
             wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
+        wg_kwargs["device_name"] = self.device_name
 
         for resource_pool, class_dict in self.resource_pool_to_cls.items():
             worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
             wg_dict = self.ray_worker_group_cls(
                 resource_pool=resource_pool,
                 ray_cls_with_init=worker_dict_cls,
-                device_name=self.device_name,
                 **wg_kwargs,
             )
             spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
diff --git a/recipe/sppo/main_sppo.py b/recipe/sppo/main_sppo.py
index d99f4f2dc8c..9739009bc1e 100644
--- a/recipe/sppo/main_sppo.py
+++ b/recipe/sppo/main_sppo.py
@@ -144,7 +144,6 @@ def run(self, config):
             ray_worker_group_cls=ray_worker_group_cls,
             reward_fn=reward_fn,
             val_reward_fn=val_reward_fn,
-            device_name=config.trainer.device,
         )
         trainer.init_workers()
         trainer.fit()
diff --git a/recipe/sppo/sppo_ray_trainer.py b/recipe/sppo/sppo_ray_trainer.py
index 15e2f9c4085..0725d293e2b 100644
--- a/recipe/sppo/sppo_ray_trainer.py
+++ b/recipe/sppo/sppo_ray_trainer.py
@@ -95,7 +95,7 @@ def __init__(
         val_dataset: Optional[Dataset] = None,
         collate_fn=None,
         train_sampler: Optional[Sampler] = None,
-        device_name="cuda",
+        device_name=None,
     ):
         self.tokenizer = tokenizer
         self.processor = processor
@@ -115,7 +115,7 @@ def __init__(
         self.use_rm = Role.RewardModel in role_worker_mapping
         self.ray_worker_group_cls = ray_worker_group_cls
         self.validation_generations_logger = ValidationGenerationsLogger()
-        self.device_name = device_name
+        self.device_name = device_name if device_name else self.config.trainer.device
 
         # define in-reward KL control
         # kl loss control currently not supported
diff --git a/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh b/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh
index 1bb8fc4cdbc..f69a1105772 100644
--- a/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh
+++ b/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh
@@ -24,6 +24,7 @@ torchrun --standalone --nnodes=1 --nproc_per_node=8 \
     model.target_modules=all-linear \
     model.strategy=fsdp \
     ulysses_sequence_parallel_size=2 \
-    use_remove_padding=true
+    use_remove_padding=true \
+    trainer.device=npu
 
 rm -rf ./outputs ./save_ckpts
diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py
index 6c9495d6103..b692206beee 100644
--- a/verl/single_controller/ray/base.py
+++ b/verl/single_controller/ray/base.py
@@ -270,7 +270,6 @@ def __init__(
         worker_names=None,
         worker_handles: list[ray.actor.ActorHandle] = None,
         ray_wait_register_center_timeout: int = 300,
-        device_name="cuda",
         **kwargs,
     ) -> None:
         """Initialize a RayWorkerGroup.
@@ -294,7 +293,7 @@ def __init__(
         # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to
         # this WorkerGroup.
         self.sub_cls_name = ""
-        self.device_name = device_name
+        self.device_name = kwargs.get("device_name", "cuda")
         self.profile_steps = kwargs.get("profile_steps", None)
         self.worker_nsight_options = kwargs.get("worker_nsight_options", None)
         if self.worker_nsight_options is not None and self.worker_nsight_options["capture-range-end"] is None:
diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml
index c3af1a48fad..b31fb402eee 100644
--- a/verl/trainer/config/sft_trainer.yaml
+++ b/verl/trainer/config/sft_trainer.yaml
@@ -65,3 +65,4 @@ trainer:
   nnodes: 1
   n_gpus_per_node: 8
   max_ckpt_to_keep: null # TODO
+  device: cuda
diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py
index 531ebab6276..76f73653423 100644
--- a/verl/trainer/fsdp_sft_trainer.py
+++ b/verl/trainer/fsdp_sft_trainer.py
@@ -119,7 +119,7 @@ def __init__(
         # TODO: add checkpoint manager
         if self.device_mesh.get_rank() == 0:
             print(self.config)
-        self.device_name = get_device_name()
+        self.device_name = self.config.trainer.device
 
     def _normalize_config_bsz(self):
         dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0)
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index 2a0b21dedba..3b3dc04a602 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -231,7 +231,6 @@ def run(self, config):
             val_dataset=val_dataset,
             collate_fn=collate_fn,
             train_sampler=train_sampler,
-            device_name=config.trainer.device,
         )
         # Initialize the workers of the trainer.
         trainer.init_workers()
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 4f1de884d3a..9f81cfcd487 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -315,7 +315,7 @@ def __init__(
         val_dataset: Optional[Dataset] = None,
         collate_fn=None,
         train_sampler: Optional[Sampler] = None,
-        device_name="cuda",
+        device_name=None,
     ):
         """
         Initialize distributed PPO trainer with Ray backend.
@@ -334,7 +334,7 @@ def __init__(
             val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.
             collate_fn: Function to collate data samples into batches.
             train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.
-            device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda".
+            device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None.
         """
 
         # Store the tokenizer for text processing
@@ -355,7 +355,7 @@ def __init__(
         self.use_reference_policy = Role.RefPolicy in role_worker_mapping
         self.use_rm = Role.RewardModel in role_worker_mapping
         self.ray_worker_group_cls = ray_worker_group_cls
-        self.device_name = device_name
+        self.device_name = device_name if device_name else self.config.trainer.device
         self.validation_generations_logger = ValidationGenerationsLogger()
 
         # if ref_in_actor is True, the reference policy will be actor without lora applied
@@ -895,13 +895,13 @@ def init_workers(self):
             wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
                 OmegaConf.select(self.config.trainer, "worker_nsight_options")
             )
+        wg_kwargs["device_name"] = self.device_name
 
         for resource_pool, class_dict in self.resource_pool_to_cls.items():
             worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
             wg_dict = self.ray_worker_group_cls(
                 resource_pool=resource_pool,
                 ray_cls_with_init=worker_dict_cls,
-                device_name=self.device_name,
                 **wg_kwargs,
             )
             spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())