diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 3612b379a7a..7f8d32fe4ad 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -441,6 +441,9 @@ actor_rollout_ref:
     # number of responses (i.e. num sample times). > 1 for grpo
     n: 1
 
+    # Whether to wake up the inference engine in multiple stages (resume model weights first, then the kv cache)
+    multi_stage_wake_up: false
+
     # Extra inference engine arguments (vllm, sglang).
     engine_kwargs:
 
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 6ca7c15c411..203e56cd84e 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -478,6 +478,7 @@ def _build_rollout(self, trust_remote_code=False):
                 full_params="hf" in self.config.rollout.load_format,
                 device_mesh=rollout_device_mesh,
                 offload_param=self._is_offload_param,
+                multi_stage_wake_up=self.config.rollout.multi_stage_wake_up,
             )
             log_gpu_memory_usage("After building sharding manager", logger=logger)
 
diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
index 73b6b43bd2e..728bfd262ec 100644
--- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -127,21 +127,27 @@ def __init__(self, **kwargs):
         # default to use dummy load format, which need to reload weights in first time
         self._need_reload = True
 
-    async def release_memory_occupation(self):
+    async def release_memory_occupation(self, tags: Optional[list[str]] = None):
         """Release GPU occupation temporarily."""
-        obj = ReleaseMemoryOccupationReqInput()
+        if tags is None:
+            obj = ReleaseMemoryOccupationReqInput()
+        else:
+            obj = ReleaseMemoryOccupationReqInput(tags=tags)
         return await self.tokenizer_manager.release_memory_occupation(obj, None)
 
-    async def resume_memory_occupation(self):
+    async def resume_memory_occupation(self, tags: Optional[list[str]] = None):
         """Resume GPU occupation."""
-        # because __init__ is a sync method, it can not call the async release_memory_occupation
         # have to move release_memory_occupation from __init__ to here
+        # For multi-stage wake-up, release both weights and kv_cache when resuming weights for the first time.
         if self._need_reload:
             await self.release_memory_occupation()
             self._need_reload = False
 
-        obj = ResumeMemoryOccupationReqInput()
+        if tags is None:
+            obj = ResumeMemoryOccupationReqInput()
+        else:
+            obj = ResumeMemoryOccupationReqInput(tags=tags)
         return await self.tokenizer_manager.resume_memory_occupation(obj, None)
 
     async def update_weights_from_tensor(
diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py
index ff72e382ad7..7a6050a9080 100644
--- a/verl/workers/sharding_manager/fsdp_sglang.py
+++ b/verl/workers/sharding_manager/fsdp_sglang.py
@@ -59,12 +59,14 @@ def __init__(
         full_params: bool = False,
         device_mesh: DeviceMesh = None,
         offload_param: bool = False,
+        multi_stage_wake_up: bool = False,
     ):
         self.module = module
         self.inference_engine = inference_engine
         self.model_config = model_config
         self.device_mesh = device_mesh
         self.offload_param = offload_param
+        self.multi_stage_wake_up = multi_stage_wake_up
 
         # Full params
         self.full_params = full_params
@@ -95,7 +97,17 @@ def __enter__(self):
         self.timing = {}
         with simple_timer("reshard", self.timing):
+            loop = asyncio.get_event_loop()
+
+            if self.device_mesh["infer_tp"].get_local_rank() == 0:
+                if self.multi_stage_wake_up:
+                    loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
+                    log_gpu_memory_usage("After resume SGLang weights in sharding manager", logger=logger)
+                else:
+                    loop.run_until_complete(self.inference_engine.resume_memory_occupation())
+                    log_gpu_memory_usage("After resume SGLang weights + kv_cache in sharding manager", logger=logger)
             get_torch_device().empty_cache()
+            log_gpu_memory_usage("Before state_dict() in sharding manager", logger=logger)
 
             if self.offload_param:
                 load_fsdp_model_to_gpu(self.module)
 
@@ -105,7 +117,6 @@ def __enter__(self):
             params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()}
             params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
             # Copy, not share memory
-            loop = asyncio.get_event_loop()
             loop.run_until_complete(self.update_weights(params))
             log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
 
@@ -115,6 +126,10 @@ def __enter__(self):
             get_torch_device().empty_cache()
             log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
 
+            if self.multi_stage_wake_up:
+                loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
+                log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)
+
             # important: need to manually set the random states of each tp to be identical.
             if self.device_mesh is not None:
                 self.torch_random_states = get_torch_device().get_rng_state()
@@ -138,9 +153,6 @@ def __exit__(self, exc_type, exc_value, traceback):
             get_torch_device().set_rng_state(self.torch_random_states)
 
     async def update_weights(self, params):
-        if self.device_mesh["infer_tp"].get_local_rank() == 0:
-            await self.inference_engine.resume_memory_occupation()
-
         # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
         named_tensors = [(k, v) for k, v in params.items()]
         load_format = None
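
For review convenience, here is a minimal runnable sketch of the control flow this patch introduces. `_DummyEngine`, `enter_sharding_manager`, and `sync_weights` are hypothetical stand-ins invented for illustration; only the `resume_memory_occupation` method and the `"weights"` / `"kv_cache"` tags come from the diff itself:

```python
import asyncio


class _DummyEngine:
    """Hypothetical stand-in for the SGLang engine wrapper; mimics only the tagged API."""

    async def resume_memory_occupation(self, tags=None):
        print(f"resume_memory_occupation(tags={tags})")


async def enter_sharding_manager(engine, sync_weights, multi_stage_wake_up: bool):
    if multi_stage_wake_up:
        # Stage 1: bring back only the model weights; the kv cache stays released,
        # so the incoming FSDP state_dict and the engine weights can coexist.
        await engine.resume_memory_occupation(tags=["weights"])
    else:
        # Single-stage path: weights and kv cache are resumed together.
        await engine.resume_memory_occupation()

    await sync_weights()  # placeholder for update_weights(params) in the patch

    # (the real code frees the training-side state_dict and empties the cache here)

    if multi_stage_wake_up:
        # Stage 2: only after the weight copy is freed is the kv cache restored.
        await engine.resume_memory_occupation(tags=["kv_cache"])


async def main():
    async def sync_weights():
        print("sync weights into engine")

    await enter_sharding_manager(_DummyEngine(), sync_weights, multi_stage_wake_up=True)


if __name__ == "__main__":
    asyncio.run(main())
```

Assuming the standard verl PPO entry point, the feature stays opt-in: setting `actor_rollout_ref.rollout.multi_stage_wake_up=True` in the trainer config flips the new flag, which `_build_rollout` forwards into `FSDPSGLangShardingManager`.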