Skip to content

Commit 22b0e2e

Browse files
committed
make it clear we need two CUDA contexts for retriving the stream's device
1 parent 3809a33 commit 22b0e2e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

cuda_core/cuda/core/experimental/_stream.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,12 +327,12 @@ cdef class Stream:
327327
cdef int _get_device_and_context(self) except?-1:
328328
cdef cydriver.CUcontext curr_ctx
329329
if self._device_id == cydriver.CU_DEVICE_INVALID:
330-
# TODO: It is likely faster/safer to call cuCtxGetCurrent?
331-
from cuda.core.experimental._device import Device # avoid circular import
332-
curr_ctx = <cydriver.CUcontext><uintptr_t>(Device().context._handle)
333330
with nogil:
334-
# Get the stream context first
331+
# Get the current context
332+
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&curr_ctx))
333+
# Get the stream's context (self.ctx_handle is populated)
335334
self._get_context()
335+
# Get the stream's device (may require a context-switching dance)
336336
self._device_id = get_device_from_ctx(self._ctx_handle, curr_ctx)
337337
return 0
338338

0 commit comments

Comments
 (0)