
[rdt] Add CUDA IPC transport #59838

Merged
stephanie-wang merged 42 commits into ray-project:master from stephanie-wang:ipc
Jan 16, 2026

Conversation

@stephanie-wang
Contributor

Adds a CUDA IPC transport to RDT. This relies on an internal torch function to serialize and deserialize a CUDA tensor across different processes. It may break if there are changes to torch.multiprocessing.reductions, but this seems to be the best stopgap solution.

One minor issue: right now the receiver's buffers are allocated outside of the tensor transport manager. Ideally the tensor transport itself should allocate the receiver's buffers, since in this case we don't need to allocate any new buffers on the receiver. Will address this in a followup that updates the tensor transport manager interface for recv_multiple_tensors.
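The reduce/rebuild mechanism the description leans on can be sketched in a few lines. This is a hedged illustration using a CPU tensor so it runs without a GPU; CUDA tensors go through a different branch of the same function that exports a CUDA IPC handle instead, but the serialize-to-(function, args) shape is the same.

```python
# Sketch of the torch.multiprocessing.reductions reduce/rebuild pattern.
# Shown with a CPU tensor; for CUDA tensors the same function returns a
# rebuild function plus a CUDA IPC handle rather than the storage itself.
import torch
from torch.multiprocessing.reductions import reduce_tensor

t = torch.arange(4, dtype=torch.float32)

# "Serialize": a picklable (rebuild_fn, args) pair that another process
# (or, here, the same process) can use to reattach to the same storage.
rebuild_fn, args = reduce_tensor(t)
clone = rebuild_fn(*args)

# The memory is shared, not copied: writes through one view are visible
# through the other.
clone.mul_(2)
assert torch.equal(t, torch.tensor([0.0, 2.0, 4.0, 6.0]))
```

This is also why the transport is fragile: the exact contents of `args` are an internal detail of torch.multiprocessing.reductions and can change between releases.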

avigyabb and others added 30 commits August 18, 2025 23:44
Signed-off-by: Avi Basnet <avigyabb@stanford.edu>
Signed-off-by: Ubuntu <ubuntu@ip-172-31-3-214.us-east-2.compute.internal>
Signed-off-by: Qiaolin-Yu <liin1211@outlook.com>
ray-gardener bot added the community-contribution (Contributed by the community) label Jan 5, 2026
return False

def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool:
return torch.cuda.is_available()
Contributor

This will check if CUDA is available on the driver, not the actor. Either way, long term I'm not sure actor_has_tensor_transport should exist in its current form, because we can't do a ray.get on a .remote call.

Contributor Author

Ah good call, thanks. I guess I'll remove this for now and leave a TODO.

torch.cuda.current_stream().record_event(event)

device = gpu_object[0].device
ray_gpu_idx = ray.get_gpu_ids()[device.index]
Contributor

Is there a guarantee that the torch device index will be the right index into the ray GPU IDs list?

Contributor Author

Ray will set the CUDA_VISIBLE_DEVICES to the assigned GPU IDs, so it should be the right index. I think the only case where it wouldn't be is if the user sets CUDA_VISIBLE_DEVICES themselves after the actor has been created. I'll add a note to the exception.
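The mapping described here can be illustrated without a Ray cluster. The GPU IDs and device index below are hypothetical; the real mechanisms are ray.get_gpu_ids() and Ray's setting of CUDA_VISIBLE_DEVICES for the actor.

```python
# Hedged illustration of why tensor.device.index indexes correctly into
# ray.get_gpu_ids(): Ray sets CUDA_VISIBLE_DEVICES to the actor's assigned
# GPUs, so torch renumbers exactly those devices 0..n-1 in the same order.
assigned_gpu_ids = ["3", "5"]   # hypothetical ray.get_gpu_ids() result
# Ray exports CUDA_VISIBLE_DEVICES="3,5" for this actor, so torch sees:
#   cuda:0 -> physical GPU 3, cuda:1 -> physical GPU 5
torch_device_index = 1          # e.g. tensor.device.index for a cuda:1 tensor
physical_gpu = assigned_gpu_ids[torch_device_index]
assert physical_gpu == "5"      # only breaks if the user re-sets
                                # CUDA_VISIBLE_DEVICES after actor creation
```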

raise RuntimeError(
f"Expected CUDA IPC tensor reconstruction list_args[6] to be device ID, but got {list_args[6]}. Please file an issue at https://github.com/ray-project/ray/issues/new/choose."
)
list_args[6] = device.index
Contributor

😭

Contributor Author

Yeah :(


def double(self, data):
data.mul_(2)
torch.cuda.synchronize()
Contributor

Why this synchronize? Everything should still work without it.

Contributor Author

Hmm yeah can probably remove it.

@tianyi-ge
Contributor

Hi, I'm curious if ray can reuse nixl (ucx) to enable cuda_ipc? Is torch.multiprocessing.reductions better on performance?

@stephanie-wang
Contributor Author

Hi, I'm curious if ray can reuse nixl (ucx) to enable cuda_ipc? Is torch.multiprocessing.reductions better on performance?

I think UCX will not support this behavior because the memory is actually shared, no copies.

Signed-off-by: Stephanie Wang <smwang@cs.washington.edu>
@stephanie-wang
Contributor Author

Hi, I'm curious if ray can reuse nixl (ucx) to enable cuda_ipc? Is torch.multiprocessing.reductions better on performance?

I think UCX will not support this behavior because the memory is actually shared, no copies.

Hmm actually it does seem to support CUDA IPC but I'm not sure how it's exposed exactly since it is through shared memory.

@dayshah
Contributor

dayshah commented Jan 13, 2026

#60076: allowing recv to create the tensors

Signed-off-by: Stephanie Wang <smwang@cs.washington.edu>
@stephanie-wang added the go (add ONLY when ready to merge, run all tests) label Jan 14, 2026
@dayshah (Contributor) left a comment

needs a couple updates from the latest merges

)

@staticmethod
def get_tensor_transport_metadata(
Contributor

this function isn't needed anymore

Comment on lines +148 to +149
tensor_transport_metadata: CudaIpcTransportMetadata,
communicator_metadata: CudaIpcCommunicatorMetadata,
Contributor

should be the parent type when overriding


@staticmethod
def recv_multiple_tensors(
tensors,
Contributor

tensors isn't a param anymore, have to return the tensors now

Signed-off-by: Stephanie Wang <smwang@cs.washington.edu>
Signed-off-by: Stephanie Wang <smwang@cs.washington.edu>
tensors.append(tensor)
return tensors

@staticmethod
Contributor

I think it's inconsistent for these to be static while the parent class methods they override are not. Same for all the other statics here that don't match the parent.

@dayshah (Contributor) left a comment

🚢

Signed-off-by: Stephanie Wang <smwang@cs.washington.edu>
Signed-off-by: Stephanie Wang <smwang@cs.washington.edu>

@property
def tensor_transport_backend(self) -> str:
return "CUDA_IPC"

Property decorator inconsistent with parent method interface

Low Severity

The tensor_transport_backend method uses a @property decorator, but the parent class TensorTransportManager declares it as a regular abstract method. All other implementations (NixlTensorTransport, CollectiveTensorTransport) define it as a regular method. This inconsistency means any polymorphic code calling transport.tensor_transport_backend() as a method would fail for CudaIpcTransport with a TypeError, since accessing the property already returns a string, not a callable.
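A minimal sketch of the mismatch this note describes. The class names below are stand-ins, not the actual Ray classes:

```python
# Stand-in classes illustrating the property-vs-method mismatch; the real
# classes are TensorTransportManager and its transport implementations.
class Parent:
    def tensor_transport_backend(self) -> str:  # declared as a plain method
        raise NotImplementedError

class AsMethod(Parent):                         # the NixlTensorTransport style
    def tensor_transport_backend(self) -> str:
        return "NIXL"

class AsProperty(Parent):                       # the CudaIpcTransport style
    @property
    def tensor_transport_backend(self) -> str:
        return "CUDA_IPC"

assert AsMethod().tensor_transport_backend() == "NIXL"
try:
    AsProperty().tensor_transport_backend()     # the property already gave a str
except TypeError:                               # 'str' object is not callable
    pass
```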


@stephanie-wang stephanie-wang merged commit 5ee47bd into ray-project:master Jan 16, 2026
5 of 6 checks passed
limarkdcunha pushed a commit to limarkdcunha/ray that referenced this pull request Jan 18, 2026
jeffery4011 pushed a commit to jeffery4011/ray that referenced this pull request Jan 20, 2026
ryanaoleary pushed a commit to ryanaoleary/ray that referenced this pull request Feb 3, 2026
peterxcli pushed a commit to peterxcli/ray that referenced this pull request Feb 25, 2026
peterxcli pushed a commit to peterxcli/ray that referenced this pull request Feb 25, 2026

Labels

community-contribution (Contributed by the community), core (Issues that should be addressed in Ray Core), data (Ray Data-related issues), go (add ONLY when ready to merge, run all tests), gpu (GPU related issues), train (Ray Train Related Issue)


7 participants