Skip to content

Commit 9e799e4

Browse files
committed
ensure CUDA_PATH is honored by the build backend
1 parent 61617cf commit 9e799e4

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

cuda_core/build_hooks.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
get_requires_for_build_sdist = _build_meta.get_requires_for_build_sdist
2525

2626

27-
@functools.cache
2827
def _get_proper_cuda_bindings_major_version() -> str:
2928
# for local development (with/without build isolation)
3029
try:
@@ -72,10 +71,21 @@ def strip_prefix_suffix(filename):
7271
return filename[len(root_path) : -4]
7372

7473
module_names = (strip_prefix_suffix(f) for f in ext_files)
74+
75+
@functools.cache
76+
def get_cuda_paths():
77+
CUDA_PATH = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None))
78+
if not CUDA_PATH:
79+
raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set")
80+
CUDA_PATH = CUDA_PATH.split(os.pathsep)
81+
print("CUDA paths:", CUDA_PATH)
82+
return CUDA_PATH
83+
7584
ext_modules = tuple(
7685
Extension(
7786
f"cuda.core.experimental.{mod.replace(os.path.sep, '.')}",
7887
sources=[f"cuda/core/experimental/{mod}.pyx"],
88+
include_dirs=list(os.path.join(root, "include") for root in get_cuda_paths()),
7989
language="c++",
8090
)
8191
for mod in module_names

0 commit comments

Comments
 (0)