Skip to content

Commit bcb803f

Browse files
authored
Add __weakref__ support to core API classes (#1533)
* Add __weakref__ support to core API classes Enable weak referencing for Cython cdef classes: - Stream, Event, Context (cdef classes via pxd) - Buffer, LaunchConfig (cdef classes via pxd) Enable weak referencing for Python classes with __slots__: - Device, ObjectCode (added __weakref__ to __slots__) Note: Kernel already had __weakref__ in its __slots__. Memory resource classes inherit __weakref__ from _MemPool. Add test_weakref.py to verify all core API classes are weak-referenceable. * Fix test_weakref kernel fixture PTX version incompatibility Change object_code fixture to compile to cubin instead of ptx to avoid CUDA_ERROR_UNSUPPORTED_PTX_VERSION when the toolkit version is newer than the driver version on test machines.
1 parent d57a310 commit bcb803f

File tree

8 files changed

+82
-2
lines changed

8 files changed

+82
-2
lines changed

cuda_core/cuda/core/_context.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ cdef class Context:
1414
cdef:
1515
ContextHandle _h_context
1616
int _device_id
17+
object __weakref__
1718

1819
@staticmethod
1920
cdef Context _from_handle(type cls, ContextHandle h_context, int device_id)

cuda_core/cuda/core/_device.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,7 @@ class Device:
955955
Default value of `None` return the currently used device.
956956

957957
"""
958-
__slots__ = ("_device_id", "_memory_resource", "_has_inited", "_properties", "_uuid", "_context")
958+
__slots__ = ("_device_id", "_memory_resource", "_has_inited", "_properties", "_uuid", "_context", "__weakref__")
959959

960960
def __new__(cls, device_id: Device | int | None = None):
961961
if isinstance(device_id, Device):

cuda_core/cuda/core/_event.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ cdef class Event:
1616
bint _ipc_enabled
1717
object _ipc_descriptor
1818
int _device_id
19+
object __weakref__
1920

2021
@staticmethod
2122
cdef Event _init(type cls, int device_id, ContextHandle h_context, options, bint is_free)

cuda_core/cuda/core/_launch_config.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cdef class LaunchConfig:
1717
public bint cooperative_launch
1818

1919
vector[cydriver.CUlaunchAttribute] _attrs
20+
object __weakref__
2021

2122
cdef cydriver.CUlaunchConfig _to_native_launch_config(self)
2223

cuda_core/cuda/core/_memory/_buffer.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ cdef class Buffer:
2323
object _owner
2424
_MemAttrs _mem_attrs
2525
bint _mem_attrs_inited
26+
object __weakref__
2627

2728

2829
cdef class MemoryResource:

cuda_core/cuda/core/_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ class ObjectCode:
546546
:class:`~cuda.core.Program`
547547
"""
548548

549-
__slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map", "_name")
549+
__slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map", "_name", "__weakref__")
550550
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin", "object", "library")
551551

552552
def __new__(self, *args, **kwargs):

cuda_core/cuda/core/_stream.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ cdef class Stream:
1313
int _device_id
1414
int _nonblocking
1515
int _priority
16+
object __weakref__
1617

1718
@staticmethod
1819
cdef Stream _from_handle(type cls, StreamHandle h_stream)

cuda_core/tests/test_weakref.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import weakref
6+
7+
import pytest
8+
from cuda.core import Device
9+
10+
11+
@pytest.fixture(scope="module")
12+
def device():
13+
dev = Device()
14+
dev.set_current()
15+
return dev
16+
17+
18+
@pytest.fixture
19+
def stream(device):
20+
return device.create_stream()
21+
22+
23+
@pytest.fixture
24+
def event(device):
25+
return device.create_event()
26+
27+
28+
@pytest.fixture
29+
def context(device):
30+
return device.context
31+
32+
33+
@pytest.fixture
34+
def buffer(device):
35+
return device.allocate(1024)
36+
37+
38+
@pytest.fixture
39+
def launch_config():
40+
from cuda.core import LaunchConfig
41+
42+
return LaunchConfig(grid=(1,), block=(1,))
43+
44+
45+
@pytest.fixture
46+
def object_code():
47+
from cuda.core import Program
48+
49+
prog = Program('extern "C" __global__ void test_kernel() {}', "c++")
50+
return prog.compile("cubin")
51+
52+
53+
@pytest.fixture
54+
def kernel(object_code):
55+
return object_code.get_kernel("test_kernel")
56+
57+
58+
WEAK_REFERENCEABLE = [
59+
"device",
60+
"stream",
61+
"event",
62+
"context",
63+
"buffer",
64+
"launch_config",
65+
"object_code",
66+
"kernel",
67+
]
68+
69+
70+
@pytest.mark.parametrize("fixture_name", WEAK_REFERENCEABLE)
71+
def test_weakref(fixture_name, request):
72+
"""Core API classes should be weak-referenceable."""
73+
obj = request.getfixturevalue(fixture_name)
74+
ref = weakref.ref(obj)
75+
assert ref() is obj

0 commit comments

Comments
 (0)