Skip to content

Commit fa2ce29

Browse files
Copilot and leofang committed
Add backend parameter to as_bytes() method with support for NVRTC, NVVM, and nvJitLink
Co-authored-by: leofang <5534781+leofang@users.noreply.github.com>
1 parent 523a246 commit fa2ce29

File tree

2 files changed

+244
-7
lines changed

2 files changed

+244
-7
lines changed

cuda_core/cuda/core/experimental/_program.py

Lines changed: 180 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,116 @@ def _process_define_macro(formatted_options, macro):
126126
raise RuntimeError(f"Expected define_macro {union_type}, list[{union_type}], got {macro}")
127127

128128

129+
def _format_options_for_backend(options_dict: dict, backend: str) -> list[str]:
    """Format compilation options for a specific backend.

    Converts a dictionary of generic option names and values into the option
    strings understood by the requested backend. Options whose value is
    ``None`` are skipped, and options the backend has no spelling for are
    silently dropped.

    Parameters
    ----------
    options_dict : dict
        Dictionary mapping generic option names (e.g., "arch", "debug", "ftz")
        to their values. Iteration order determines the output order.
    backend : str
        The backend to format options for. Must be one of "NVRTC", "NVVM", or "nvJitLink".

    Returns
    -------
    list[str]
        List of formatted option strings suitable for the specified backend.

    Raises
    ------
    ValueError
        If an unsupported backend is specified.

    Notes
    -----
    - NVRTC uses `--` prefix (except `-arch=`) and "true"/"false" for booleans
    - NVVM uses `-` prefix and "1"/"0" for booleans
    - nvJitLink uses `-` prefix and "true"/"false" for booleans
    """
    if backend not in ("NVRTC", "NVVM", "nvJitLink"):
        raise ValueError(f"Unsupported backend '{backend}'. Must be one of: NVRTC, NVVM, nvJitLink")

    # The backend is fixed for the whole call, so pick its formatter once
    # instead of re-dispatching on it for every option.
    if backend == "NVRTC":
        format_option = _format_nvrtc_option
    elif backend == "NVVM":
        format_option = _format_nvvm_option
    else:
        format_option = _format_nvjitlink_option

    formatted = []
    for key, value in options_dict.items():
        if value is None:
            continue
        option = format_option(key, value)
        if option is not None:
            formatted.append(option)
    return formatted


def _format_nvrtc_option(key, value):
    """Return the NVRTC spelling of one option, or ``None`` if nothing is emitted."""
    if key == "arch":
        # arch is the one option here that is spelled with a single dash.
        return f"-arch={value}"
    if key == "max_register_count":
        return f"--maxrregcount={value}"
    if key in ("ftz", "prec_sqrt", "prec_div"):
        bool_val = "true" if value else "false"
        # NVRTC uses hyphens in option names
        option_name = key.replace("_", "-")
        return f"--{option_name}={bool_val}"
    if key == "fma":
        bool_val = "true" if value else "false"
        return f"--fmad={bool_val}"

    # The remaining options are bare switches, emitted only when enabled.
    switches = {
        "debug": "--device-debug",
        "lineinfo": "--generate-line-info",
        "device_code_optimize": "--dopt=on",
        "use_fast_math": "--use_fast_math",
        "link_time_optimization": "--dlink-time-opt",
    }
    return switches.get(key) if value else None


def _format_nvvm_option(key, value):
    """Return the NVVM spelling of one option, or ``None`` if nothing is emitted."""
    if key == "arch":
        # NVVM uses compute_ instead of sm_; str() keeps a non-str arch value
        # from raising AttributeError on .startswith().
        arch_val = str(value)
        if arch_val.startswith("sm_"):
            arch_val = f"compute_{arch_val[3:]}"
        return f"-arch={arch_val}"
    if key == "debug":
        return "-g" if value else None
    if key == "device_code_optimize":
        # Truthiness (not identity with True/False) so this option behaves
        # like every other boolean option.
        return "-opt=3" if value else "-opt=0"
    if key in ("ftz", "prec_sqrt", "prec_div", "fma"):
        bool_val = "1" if value else "0"
        # NVVM uses hyphens in option names
        option_name = key.replace("_", "-")
        return f"-{option_name}={bool_val}"
    return None


def _format_nvjitlink_option(key, value):
    """Return the nvJitLink spelling of one option, or ``None`` if nothing is emitted."""
    if key == "arch":
        return f"-arch={value}"
    if key == "max_register_count":
        return f"-maxrregcount={value}"
    if key in ("ftz", "prec_sqrt", "prec_div", "fma"):
        bool_val = "true" if value else "false"
        # nvJitLink uses hyphens in option names
        option_name = key.replace("_", "-")
        return f"-{option_name}={bool_val}"

    # The remaining options are bare switches, emitted only when enabled.
    switches = {
        "debug": "-g",
        "lineinfo": "-lineinfo",
        "link_time_optimization": "-lto",
    }
    return switches.get(key) if value else None
237+
238+
129239
@dataclass
130240
class ProgramOptions:
131241
"""Customizable options for configuring `Program`.
@@ -422,28 +532,91 @@ def __post_init__(self):
422532
if self.numba_debug:
423533
self._formatted_options.append("--numba-debug")
424534

425-
def as_bytes(self, backend: str = "NVRTC") -> list[bytes]:
    """Convert the program options to a list of byte strings.

    This method encodes the options stored in this `ProgramOptions` instance
    into byte strings formatted for the specified backend, suitable for passing
    to C libraries that call the underlying compiler library.

    Parameters
    ----------
    backend : str, optional
        The compiler backend to format options for. Must be one of:

        - "NVRTC" (default): the options already formatted on this instance
          are encoded as-is
        - "NVVM": only arch, debug, ftz, prec_sqrt, prec_div, fma and
          device_code_optimize are considered
        - "nvJitLink": only arch, debug, lineinfo, max_register_count, ftz,
          prec_sqrt, prec_div, fma and link_time_optimization are considered

        Options that are unset (``None``) are omitted.

    Returns
    -------
    list[bytes]
        A list of byte-encoded option strings. Each element represents
        a single compilation option in the format expected by the underlying compiler library.

    Raises
    ------
    ValueError
        If an unsupported backend is specified.

    Examples
    --------
    >>> options = ProgramOptions(arch="sm_80", debug=True)
    >>> # Get options for NVRTC (default)
    >>> nvrtc_options = options.as_bytes()
    >>> print(nvrtc_options)
    [b'-arch=sm_80', b'--device-debug']
    >>>
    >>> # Get options for NVVM
    >>> nvvm_options = options.as_bytes("NVVM")
    >>> print(nvvm_options)
    [b'-arch=compute_80', b'-g']
    >>>
    >>> # Get options for nvJitLink
    >>> nvjitlink_options = options.as_bytes("nvJitLink")
    >>> print(nvjitlink_options)
    [b'-arch=sm_80', b'-g']
    """
    if backend == "NVRTC":
        # For NVRTC, use the pre-formatted options (backward compatible)
        # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
        return [o.encode() for o in self._formatted_options]

    # Attribute names forwarded to each non-NVRTC backend. The tuple order is
    # the order in which the resulting options are emitted.
    if backend == "NVVM":
        option_names = (
            "arch",
            "debug",
            "ftz",
            "prec_sqrt",
            "prec_div",
            "fma",
            "device_code_optimize",
        )
    elif backend == "nvJitLink":
        option_names = (
            "arch",
            "debug",
            "lineinfo",
            "max_register_count",
            "ftz",
            "prec_sqrt",
            "prec_div",
            "fma",
            "link_time_optimization",
        )
    else:
        raise ValueError(f"Unsupported backend '{backend}'. Must be one of: NVRTC, NVVM, nvJitLink")

    # Only forward options the user actually set.
    options_dict = {}
    for name in option_names:
        value = getattr(self, name)
        if value is not None:
            options_dict[name] = value

    formatted_options = _format_options_for_backend(options_dict, backend)
    return [o.encode() for o in formatted_options]
447620

448621
def __repr__(self):
449622
# __TODO__ improve this

cuda_core/tests/test_program.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,67 @@ def test_program_options_as_bytes_empty():
453453
assert all(isinstance(opt, bytes) for opt in byte_options)
454454
# The arch option should be present (automatically determined from current device)
455455
assert any(b"-arch=" in opt for opt in byte_options)
456+
457+
458+
def test_program_options_as_bytes_nvvm_backend():
    """Check that as_bytes("NVVM") emits NVVM-style option spellings."""
    nvvm_bytes = ProgramOptions(
        arch="sm_80",
        debug=True,
        ftz=True,
        prec_sqrt=False,
        prec_div=True,
        fma=False,
        device_code_optimize=True,
    ).as_bytes("NVVM")

    # The result must be a list containing only bytes objects.
    assert isinstance(nvvm_bytes, list)
    for entry in nvvm_bytes:
        assert isinstance(entry, bytes)

    # NVVM spells the arch as compute_XX, uses a single-dash prefix with
    # hyphenated names, and encodes booleans as 1/0.
    expected_entries = (
        b"-arch=compute_80",
        b"-g",
        b"-ftz=1",
        b"-prec-sqrt=0",
        b"-prec-div=1",
        b"-fma=0",
        b"-opt=3",
    )
    for expected in expected_entries:
        assert expected in nvvm_bytes
484+
485+
486+
def test_program_options_as_bytes_nvjitlink_backend():
    """Check that as_bytes("nvJitLink") emits nvJitLink-style option spellings."""
    nvjitlink_bytes = ProgramOptions(
        arch="sm_80",
        debug=True,
        lineinfo=True,
        max_register_count=32,
        ftz=False,
        prec_sqrt=True,
        link_time_optimization=True,
    ).as_bytes("nvJitLink")

    # The result must be a list containing only bytes objects.
    assert isinstance(nvjitlink_bytes, list)
    for entry in nvjitlink_bytes:
        assert isinstance(entry, bytes)

    # nvJitLink uses a single-dash prefix with hyphenated names and encodes
    # booleans as true/false.
    expected_entries = (
        b"-arch=sm_80",
        b"-g",
        b"-lineinfo",
        b"-maxrregcount=32",
        b"-ftz=false",
        b"-prec-sqrt=true",
        b"-lto",
    )
    for expected in expected_entries:
        assert expected in nvjitlink_bytes
512+
513+
514+
def test_program_options_as_bytes_invalid_backend():
    """Check that as_bytes() rejects a backend name it does not know."""
    default_options = ProgramOptions()
    expected_message = "Unsupported backend 'invalid'"

    with pytest.raises(ValueError, match=expected_message):
        default_options.as_bytes("invalid")

0 commit comments

Comments
 (0)