@@ -126,6 +126,116 @@ def _process_define_macro(formatted_options, macro):
126126 raise RuntimeError (f"Expected define_macro { union_type } , list[{ union_type } ], got { macro } " )
127127
128128
129+ def _format_options_for_backend (options_dict : dict , backend : str ) -> list [str ]:
130+ """Format compilation options for a specific backend.
131+
132+ This helper function converts a dictionary of option names and values into
133+ properly formatted string options for the specified backend. Different backends
134+ (NVRTC, NVVM, nvJitLink) use slightly different option naming conventions and
135+ value formats.
136+
137+ Parameters
138+ ----------
139+ options_dict : dict
140+ Dictionary mapping option names to their values. The keys should be
141+ generic option names (e.g., "arch", "debug", "ftz").
142+ backend : str
143+ The backend to format options for. Must be one of "NVRTC", "NVVM", or "nvJitLink".
144+
145+ Returns
146+ -------
147+ list[str]
148+ List of formatted option strings suitable for the specified backend.
149+
150+ Raises
151+ ------
152+ ValueError
153+ If an unsupported backend is specified.
154+
155+ Notes
156+ -----
157+ - NVRTC uses `--` prefix and "true"/"false" for booleans
158+ - NVVM uses `-` prefix and "1"/"0" for booleans
159+ - nvJitLink uses `-` prefix and "true"/"false" for booleans
160+ """
161+ if backend not in ("NVRTC" , "NVVM" , "nvJitLink" ):
162+ raise ValueError (f"Unsupported backend '{ backend } '. Must be one of: NVRTC, NVVM, nvJitLink" )
163+
164+ formatted = []
165+
166+ for key , value in options_dict .items ():
167+ if value is None :
168+ continue
169+
170+ if backend == "NVRTC" :
171+ # NVRTC uses -- prefix
172+ if key == "arch" :
173+ formatted .append (f"-arch={ value } " )
174+ elif key == "debug" and value :
175+ formatted .append ("--device-debug" )
176+ elif key == "lineinfo" and value :
177+ formatted .append ("--generate-line-info" )
178+ elif key == "max_register_count" :
179+ formatted .append (f"--maxrregcount={ value } " )
180+ elif key in ("ftz" , "prec_sqrt" , "prec_div" ):
181+ bool_val = "true" if value else "false"
182+ # NVRTC uses hyphens in option names
183+ option_name = key .replace ("_" , "-" )
184+ formatted .append (f"--{ option_name } ={ bool_val } " )
185+ elif key == "fma" :
186+ bool_val = "true" if value else "false"
187+ formatted .append (f"--fmad={ bool_val } " )
188+ elif key == "device_code_optimize" and value :
189+ formatted .append ("--dopt=on" )
190+ elif key == "use_fast_math" and value :
191+ formatted .append ("--use_fast_math" )
192+ elif key == "link_time_optimization" and value :
193+ formatted .append ("--dlink-time-opt" )
194+ # Add more NVRTC-specific options as needed
195+
196+ elif backend == "NVVM" :
197+ # NVVM uses - prefix and 1/0 for booleans
198+ if key == "arch" :
199+ # NVVM uses compute_ instead of sm_
200+ arch_val = value
201+ if arch_val .startswith ("sm_" ):
202+ arch_val = f"compute_{ arch_val [3 :]} "
203+ formatted .append (f"-arch={ arch_val } " )
204+ elif key == "debug" and value :
205+ formatted .append ("-g" )
206+ elif key == "device_code_optimize" :
207+ if value is False :
208+ formatted .append ("-opt=0" )
209+ elif value is True :
210+ formatted .append ("-opt=3" )
211+ elif key in ("ftz" , "prec_sqrt" , "prec_div" , "fma" ):
212+ bool_val = "1" if value else "0"
213+ # NVVM uses hyphens in option names
214+ option_name = key .replace ("_" , "-" )
215+ formatted .append (f"-{ option_name } ={ bool_val } " )
216+
217+ elif backend == "nvJitLink" :
218+ # nvJitLink uses - prefix and true/false for booleans
219+ if key == "arch" :
220+ formatted .append (f"-arch={ value } " )
221+ elif key == "debug" and value :
222+ formatted .append ("-g" )
223+ elif key == "lineinfo" and value :
224+ formatted .append ("-lineinfo" )
225+ elif key == "max_register_count" :
226+ formatted .append (f"-maxrregcount={ value } " )
227+ elif key in ("ftz" , "prec_sqrt" , "prec_div" , "fma" ):
228+ bool_val = "true" if value else "false"
229+ # nvJitLink uses hyphens in option names
230+ option_name = key .replace ("_" , "-" )
231+ formatted .append (f"-{ option_name } ={ bool_val } " )
232+ elif key == "link_time_optimization" and value :
233+ formatted .append ("-lto" )
234+ # Add more nvJitLink-specific options as needed
235+
236+ return formatted
237+
238+
129239@dataclass
130240class ProgramOptions :
131241 """Customizable options for configuring `Program`.
@@ -422,28 +532,91 @@ def __post_init__(self):
422532 if self .numba_debug :
423533 self ._formatted_options .append ("--numba-debug" )
424534
425- def as_bytes (self ) -> list [bytes ]:
535+ def as_bytes (self , backend : str = "NVRTC" ) -> list [bytes ]:
426536 """Convert the formatted program options to a list of byte strings.
427537
428- This method encodes each of the formatted options stored in this
429- `ProgramOptions` instance into byte strings, suitable for passing
538+ This method encodes the options stored in this `ProgramOptions` instance
539+ into byte strings formatted for the specified backend , suitable for passing
430540 to C libraries that calls the underlying compiler library.
431541
542+ Parameters
543+ ----------
544+ backend : str, optional
545+ The compiler backend to format options for. Must be one of:
546+
547+ - "NVRTC" (default): NVIDIA NVRTC compiler, supports all ProgramOptions
548+ - "NVVM": NVIDIA NVVM compiler, supports a subset of options
549+ - "nvJitLink": NVIDIA nvJitLink linker, supports a subset of options
550+
551+ Different backends use different option naming conventions and support
552+ different option subsets. This method will format and filter options
553+ appropriately for the chosen backend.
554+
432555 Returns
433556 -------
434557 list[bytes]
435558 A list of byte-encoded option strings. Each element represents
436559 a single compilation option in the format expected by the underlying compiler library.
437560
561+ Raises
562+ ------
563+ ValueError
564+ If an unsupported backend is specified.
565+
438566 Examples
439567 --------
440568 >>> options = ProgramOptions(arch="sm_80", debug=True)
441- >>> byte_options = options.as_bytes()
442- >>> print(byte_options)
569+ >>> # Get options for NVRTC (default)
570+ >>> nvrtc_options = options.as_bytes()
571+ >>> print(nvrtc_options)
443572 [b'-arch=sm_80', b'--device-debug']
573+ >>>
574+ >>> # Get options for NVVM
575+ >>> nvvm_options = options.as_bytes("NVVM")
576+ >>> print(nvvm_options)
577+ [b'-arch=compute_80', b'-g']
578+ >>>
579+ >>> # Get options for nvJitLink
580+ >>> nvjitlink_options = options.as_bytes("nvJitLink")
581+ >>> print(nvjitlink_options)
582+ [b'-arch=sm_80', b'-g']
444583 """
445- # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
446- return list (o .encode () for o in self ._formatted_options )
584+ if backend == "NVRTC" :
585+ # For NVRTC, use the pre-formatted options (backward compatible)
586+ # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
587+ return list (o .encode () for o in self ._formatted_options )
588+
589+ elif backend in ("NVVM" , "nvJitLink" ):
590+ # For NVVM and nvJitLink, extract common options and format appropriately
591+ options_dict = {}
592+
593+ # Common options supported by multiple backends
594+ if self .arch is not None :
595+ options_dict ["arch" ] = self .arch
596+ if self .debug is not None :
597+ options_dict ["debug" ] = self .debug
598+ if self .lineinfo is not None and backend == "nvJitLink" :
599+ options_dict ["lineinfo" ] = self .lineinfo
600+ if self .max_register_count is not None :
601+ options_dict ["max_register_count" ] = self .max_register_count
602+ if self .ftz is not None :
603+ options_dict ["ftz" ] = self .ftz
604+ if self .prec_sqrt is not None :
605+ options_dict ["prec_sqrt" ] = self .prec_sqrt
606+ if self .prec_div is not None :
607+ options_dict ["prec_div" ] = self .prec_div
608+ if self .fma is not None :
609+ options_dict ["fma" ] = self .fma
610+ if self .device_code_optimize is not None and backend == "NVVM" :
611+ options_dict ["device_code_optimize" ] = self .device_code_optimize
612+ if self .link_time_optimization is not None and backend == "nvJitLink" :
613+ options_dict ["link_time_optimization" ] = self .link_time_optimization
614+
615+ formatted_options = _format_options_for_backend (options_dict , backend )
616+ return list (o .encode () for o in formatted_options )
617+
618+ else :
619+ raise ValueError (f"Unsupported backend '{ backend } '. Must be one of: NVRTC, NVVM, nvJitLink" )
447620
448621 def __repr__ (self ):
449622 # __TODO__ improve this
0 commit comments