
FSDP2+TP2 demo script does not work #621

@xxman-google

Description


Describe the bug

Running SFT with the config examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml fails with a DTensor broadcast error during worker initialization.

Steps/Code to reproduce bug

Under main, run

uv run examples/run_sft.py --config examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-fsdp2tp2sp.v2.yaml

Expected behavior

Training should start and run without errors. Instead, the run fails with the following error:

[Rank 1] Loading state dict from rank 0... [repeated 2x across cluster]
(DTensorPolicyWorker pid=343603) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::lm_policy-0-2:DTensorPolicyWorker.__init__() (pid=343603, ip=10.182.0.80, actor_id=87eda088ff87e9a825b334fc01000000, repr=DTensorPolicyWorker[rank=2]) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 12x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/nemo-rl/nemo_rl/models/policy/dtensor_policy_worker.py", line 266, in __init__ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     torch.distributed.broadcast(buf, src=0) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     return func(*args, **kwargs) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2714, in broadcast [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     work = group.broadcast([tensor], opts) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     return disable_fn(*args, **kwargs) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     return fn(*args, **kwargs) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_api.py", line 344, in __torch_dispatch__ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     return DTensor._op_dispatcher.dispatch( [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 167, in dispatch [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     op_info = self.unwrap_to_op_info(op_call, args, kwargs) [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)   File "/app/ray_venvs/nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker/lib/python3.12/site-packages/torch/distributed/tensor/_dispatch.py", line 393, in unwrap_to_op_info [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)     assert compute_mesh is not None, ( [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603)            ^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 6x across cluster]
(DTensorPolicyWorker pid=343603) AssertionError: found no DeviceMesh from dtensor args for c10d.broadcast_.default! [repeated 6x across cluster]
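The assertion fires inside DTensor's dispatch path: when a sharded buffer reaches a plain c10d collective such as torch.distributed.broadcast, the call is routed through DTensor.__torch_dispatch__, which cannot resolve a DeviceMesh for c10d.broadcast_.default. A common workaround pattern (a sketch only, not NeMo-RL's actual fix; the helper name is hypothetical) is to unwrap DTensor buffers to their local shards before invoking the collective:

```python
def unwrap_for_collective(buf):
    """Return a tensor that is safe to pass to plain c10d collectives.

    DTensor instances expose to_local(); unwrapping them avoids routing
    the collective through DTensor.__torch_dispatch__, which raises
    "found no DeviceMesh from dtensor args" for c10d.broadcast_.default.
    Plain tensors (anything without to_local) pass through unchanged.
    """
    if hasattr(buf, "to_local"):
        return buf.to_local()
    return buf


# Hypothetical call site, mirroring the failing line in
# dtensor_policy_worker.py:
#   torch.distributed.broadcast(unwrap_for_collective(buf), src=0)
```

Whether this is the right fix for NeMo-RL depends on how the state dict is replicated across the TP mesh; the sketch only illustrates why the broadcast reaches DTensor dispatch at all.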

Environment overview (please complete the following information)

Environment details

If an NVIDIA Docker image is used, you don't need to specify these.
Otherwise, please provide:

  • OS version
  • PyTorch version
  • Python version

Additional context
