diff --git a/docs/workers/sglang_worker.rst b/docs/workers/sglang_worker.rst
index 13e0066503f..e42b2004358 100644
--- a/docs/workers/sglang_worker.rst
+++ b/docs/workers/sglang_worker.rst
@@ -37,6 +37,7 @@ We use Qwen/Qwen2-7B-Instruct on the gsm8k dataset for a simple test.
 
 .. code-block:: bash
 
+   export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=True
    PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
       data.train_files=$HOME/data/gsm8k/train.parquet \
       data.val_files=$HOME/data/gsm8k/test.parquet \
@@ -70,6 +71,51 @@ We use Qwen/Qwen2-7B-Instruct on the gsm8k dataset for a simple test.
       trainer.test_freq=10 \
       trainer.total_epochs=15 2>&1 | tee verl_demo.log
 
+Why export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK?
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+1. During rollout, ``verl`` initializes a ``SGLangRollout`` module, which is used to evaluate/generate samples.
+
+2. ``SGLangRollout`` initializes a ``VerlEngine``, which in turn creates a ``torch.distributed.DeviceMesh`` to support Tensor Parallelism (TP).
+
+3. During ``DeviceMesh`` initialization, the free GPU memory of all participating devices is checked. If the difference is too large (more than ~10%), an error is raised immediately to avoid initialization failures or deadlocks.
+
+Why might there be inconsistent GPU memory?
+""""""""""""""""""""""""""""""""""""""""""""
+
+**1. Ray distributed actors load the model at different times**
+
+``verl`` uses Ray-based multi-process, multi-GPU concurrent training. Each ``WorkerDict`` may be called at a different time:
+
+.. code-block:: python
+
+   self.rollout = SGLangRollout(...)
+
+Different workers initialize the model at different times → different memory usage.
+
+**2. Delayed initialization causes memory skew**
+
+Some workers start model loading/inference (e.g., ``generate_sequences()``, ``compute_log_prob()``) earlier than others.
+Early workers have already consumed GPU memory while late workers still have mostly free memory → a memory gap appears.
+
+**3. SGLang's TP init uses an "all-device broadcast", but there is no uniform release timing**
+
+Although ``SGLangRollout`` may only involve a subset of GPUs, its ``VerlEngine`` initialization calls ``torch.distributed.init_process_group()`` and broadcasts weights, so:
+
+- Non-rollout GPUs also join the communication.
+- Later, ``DeviceMesh`` initialization fails with an "inconsistent memory" error.
+
+**4. Different FSDP/TP loading behaviors also lead to a mismatch**
+
+If using:
+
+.. code-block:: bash
+
+   actor.fsdp_config.param_offload=True
+   ref.fsdp_config.param_offload=True
+
+then some workers keep parameters on the CPU while others have already sharded them to the GPU → an asymmetric memory layout.
+
 Using SGLang as the Inference Backend for PPO Training Across Multiple Machines
 --------------------------------------------------------------------------------
 
 SGLang also supports running verl's RAY-based cross-machine inference in IPv4 and IPv6 scenarios. In the script below, we use TP=16 for cross-machine inference. Suppose we have two interconnected machines: node0 with IP 10.94.16.4 and node1 with IP 10.94.16.5.
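+
+Before launching the training script, the two machines are typically joined into a single Ray cluster. The commands below are a minimal sketch of that step; the port number and dashboard flag are illustrative assumptions, and the complete launch script shown below remains the reference.
+
+.. code-block:: bash
+
+   # On node0 (10.94.16.4): start the Ray head node (6379 is Ray's default GCS port)
+   ray start --head --port=6379 --dashboard-host=0.0.0.0
+
+   # On node1 (10.94.16.5): join the cluster by pointing at the head node's address
+   ray start --address=10.94.16.4:6379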