|
53 | 53 | ) |
54 | 54 | from nemo_rl.models.huggingface.common import ModelFlag |
55 | 55 | from nemo_rl.models.policy.utils import is_vllm_v1_engine_enabled |
| 56 | +from nemo_rl.utils.nsys import wrap_with_nvtx_name |
56 | 57 |
|
57 | 58 |
|
58 | 59 | class VllmSpecificArgs(TypedDict): |
@@ -448,6 +449,7 @@ def _build_sampling_params( |
448 | 449 | include_stop_str_in_output=True, |
449 | 450 | ) |
450 | 451 |
|
| 452 | + @wrap_with_nvtx_name("vllm_generation_worker/generate") |
451 | 453 | def generate( |
452 | 454 | self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False |
453 | 455 | ) -> BatchedDataDict[GenerationOutputSpec]: |
@@ -811,6 +813,7 @@ async def process_single_sample(sample_idx): |
811 | 813 | await asyncio.gather(*sample_tasks, return_exceptions=True) |
812 | 814 | raise e |
813 | 815 |
|
| 816 | + @wrap_with_nvtx_name("vllm_generation_worker/generate_text") |
814 | 817 | def generate_text( |
815 | 818 | self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False |
816 | 819 | ) -> BatchedDataDict[GenerationOutputSpec]: |
@@ -1045,6 +1048,7 @@ async def prepare_refit_info_async(self, state_dict_info: dict[str, Any]) -> Non |
1045 | 1048 | """Async version of prepare_refit_info.""" |
1046 | 1049 | await self.llm.collective_rpc("prepare_refit_info", args=(state_dict_info,)) |
1047 | 1050 |
|
| 1051 | + @wrap_with_nvtx_name("vllm_generation_worker/update_weights_from_ipc_handles") |
1048 | 1052 | def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: |
1049 | 1053 | """Update weights from IPC handles by delegating to the vLLM Worker implementation. |
1050 | 1054 |
|
@@ -1156,6 +1160,7 @@ async def update_weights_from_ipc_handles_async( |
1156 | 1160 | traceback.print_exc() |
1157 | 1161 | return False |
1158 | 1162 |
|
| 1163 | + @wrap_with_nvtx_name("vllm_generation_worker/update_weights_from_collective") |
1159 | 1164 | def update_weights_from_collective(self) -> bool: |
1160 | 1165 | """Update the model weights from collective communication.""" |
1161 | 1166 | try: |
|
0 commit comments