Skip to content

Commit bafe1e8

Browse files
committed
Fix NVTX range names for the vLLM internal worker extension and add NVTX ranges to the vLLM generation worker
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
1 parent 790d63f commit bafe1e8

2 files changed

Lines changed: 14 additions & 3 deletions

File tree

nemo_rl/models/generation/vllm.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@
5353
)
5454
from nemo_rl.models.huggingface.common import ModelFlag
5555
from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled
56+
from nemo_rl.utils.nsys import wrap_with_nvtx_name
5657

5758

5859
class VllmSpecificArgs(TypedDict):
@@ -448,6 +449,7 @@ def _build_sampling_params(
448449
include_stop_str_in_output=True,
449450
)
450451

452+
@wrap_with_nvtx_name("vllm_genertion_worker/generate")
451453
def generate(
452454
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
453455
) -> BatchedDataDict[GenerationOutputSpec]:
@@ -811,6 +813,7 @@ async def process_single_sample(sample_idx):
811813
await asyncio.gather(*sample_tasks, return_exceptions=True)
812814
raise e
813815

816+
@wrap_with_nvtx_name("vllm_genertion_worker/generate_text")
814817
def generate_text(
815818
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
816819
) -> BatchedDataDict[GenerationOutputSpec]:
@@ -1045,6 +1048,7 @@ async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> Non
10451048
"""Async version of prepare_refit_info."""
10461049
await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,))
10471050

1051+
@wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_ipc_handles")
10481052
def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
10491053
"""Update weights from IPC handles by delegating to the vLLM Worker implementation.
10501054
@@ -1156,6 +1160,7 @@ async def update_weights_from_ipc_handles_async(
11561160
traceback.print_exc()
11571161
return False
11581162

1163+
@wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_collective")
11591164
def update_weights_from_collective(self) -> bool:
11601165
"""Update the model weights from collective communication."""
11611166
try:

nemo_rl/models/generation/vllm_backend.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,9 @@ def prepare_refit_info(
6868
"""
6969
self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored
7070

71-
@wrap_with_nvtx_name("update_weights_from_global_ipc_handles")
71+
@wrap_with_nvtx_name(
72+
"vllm_internal_worker_extension/update_weights_from_global_ipc_handles"
73+
)
7274
def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
7375
"""Update weights from global IPC handles.
7476
@@ -82,7 +84,9 @@ def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
8284
local_device_ipc_handles = global_device_ipc_handles[device_uuid]
8385
return self.update_weights_from_local_ipc_handles(local_device_ipc_handles)
8486

85-
@wrap_with_nvtx_name("update_weights_from_local_ipc_handles")
87+
@wrap_with_nvtx_name(
88+
"vllm_internal_worker_extension/update_weights_from_local_ipc_handles"
89+
)
8690
def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
8791
"""Update weights from local IPC handles.
8892
@@ -159,7 +163,9 @@ def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
159163
)
160164
return False
161165

162-
@wrap_with_nvtx_name("update_weights_from_collective")
166+
@wrap_with_nvtx_name(
167+
"vllm_internal_worker_extension/update_weights_from_collective"
168+
)
163169
def update_weights_from_collective(self) -> bool:
164170
"""Update the model weights from collective communication."""
165171
assert self.state_dict_info is not None, (

0 commit comments

Comments
 (0)