@@ -16,6 +16,8 @@ pub enum AcceleratorError {
1616 Utf8 ( #[ from] std:: string:: FromUtf8Error ) ,
1717 #[ error( transparent) ]
1818 ParseInt ( #[ from] std:: num:: ParseIntError ) ,
19+ #[ error( "Failed to parse NVIDIA device: {0}" ) ]
20+ Device ( String ) ,
1921 #[ error( "Unknown AMD GPU architecture: {0}" ) ]
2022 UnknownAmdGpuArchitecture ( String ) ,
2123}
@@ -57,7 +59,7 @@ impl Accelerator {
5759 /// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable.
5860 /// 3. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
5961 /// 4. `/proc/driver/nvidia/version`, which contains the driver version among other information.
60- /// 5. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
62+ /// 5. `nvidia-smi --query-gpu=index,uuid, driver_version --format=csv,noheader`.
6163 /// 6. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
6264 /// 7. `/sys/bus/pci/devices`, filtering for the Intel GPU via PCI.
6365 pub fn detect ( ) -> Result < Option < Self > , AcceleratorError > {
@@ -121,14 +123,40 @@ impl Accelerator {
121123
122124 // Query `nvidia-smi`.
123125 if let Ok ( output) = std:: process:: Command :: new ( "nvidia-smi" )
124- . arg ( "--query-gpu=driver_version" )
126+ . arg ( "--query-gpu=index,uuid, driver_version" )
125127 . arg ( "--format=csv,noheader" )
126128 . output ( )
127129 {
128130 if output. status . success ( ) {
129- let driver_version = Version :: from_str ( & String :: from_utf8 ( output. stdout ) ?) ?;
130- debug ! ( "Detected CUDA driver version from `nvidia-smi`: {driver_version}" ) ;
131- return Ok ( Some ( Self :: Cuda { driver_version } ) ) ;
131+ let visible_devices = VisibleDevices :: from_env ( ) ?. unwrap_or ( VisibleDevices :: All ) ;
132+ let stdout = String :: from_utf8 ( output. stdout ) ?;
133+ for line in stdout. lines ( ) {
134+ let mut parts = line. split ( ',' ) ;
135+
136+ // Parse the GPU index.
137+ let index = parts. next ( ) . and_then ( |s| s. trim ( ) . parse :: < usize > ( ) . ok ( ) ) ;
138+
139+ // Parse the GPU UUID.
140+ let uuid = parts. next ( ) . map ( str:: trim) ;
141+
142+ // Determine if this GPU is visible based on the environment variable.
143+ if visible_devices. includes ( index, uuid) {
144+ if let Some ( driver_version) = parts. next ( ) {
145+ let driver_version = Version :: from_str ( driver_version. trim ( ) ) ?;
146+ debug ! (
147+ "Detected CUDA driver version from `nvidia-smi`: {driver_version}"
148+ ) ;
149+ return Ok ( Some ( Self :: Cuda { driver_version } ) ) ;
150+ }
151+ } else {
152+ debug ! ( "Skipping invisible GPU {index:?} with UUID: {uuid:?}" ) ;
153+ }
154+ }
155+ if let Some ( first_line) = stdout. lines ( ) . next ( ) {
156+ let driver_version = Version :: from_str ( first_line. trim ( ) ) ?;
157+ debug ! ( "Detected CUDA driver version from `nvidia-smi`: {driver_version}" ) ;
158+ return Ok ( Some ( Self :: Cuda { driver_version } ) ) ;
159+ }
132160 }
133161
134162 debug ! (
@@ -193,6 +221,68 @@ impl Accelerator {
193221 }
194222}
195223
224+ #[ derive( Debug , Clone , Eq , PartialEq ) ]
225+ enum VisibleDevices {
226+ /// All GPUs are visible.
227+ All ,
228+ /// No GPUs are visible.
229+ None ,
230+ /// Some GPUs are visible, specified by their indices and/or UUIDs.
231+ Some {
232+ uuids : Vec < String > ,
233+ indices : Vec < usize > ,
234+ } ,
235+ }
236+
237+ impl VisibleDevices {
238+ /// Read and parse the [`NVIDIA_VISIBLE_DEVICES`] environment variable.
239+ fn from_env ( ) -> Result < Option < Self > , AcceleratorError > {
240+ let Some ( nvidia_visible_devices) = std:: env:: var ( EnvVars :: NVIDIA_VISIBLE_DEVICES ) . ok ( )
241+ else {
242+ return Ok ( None ) ;
243+ } ;
244+ Self :: parse ( & nvidia_visible_devices)
245+ }
246+
247+ /// Parse the [`NVIDIA_VISIBLE_DEVICES`] environment variable.
248+ fn parse ( s : & str ) -> Result < Option < Self > , AcceleratorError > {
249+ if s. is_empty ( ) {
250+ Ok ( None )
251+ } else if s == "void" {
252+ Ok ( None )
253+ } else if s == "all" {
254+ Ok ( Some ( Self :: All ) )
255+ } else if s == "none" {
256+ Ok ( Some ( Self :: None ) )
257+ } else {
258+ let mut indices = Vec :: new ( ) ;
259+ let mut uuids = Vec :: new ( ) ;
260+ for device in s. split ( ',' ) {
261+ if device. starts_with ( "GPU-" ) {
262+ uuids. push ( device. to_string ( ) ) ;
263+ } else if let Ok ( index) = device. parse :: < usize > ( ) {
264+ indices. push ( index) ;
265+ } else {
266+ return Err ( AcceleratorError :: Device ( device. to_string ( ) ) ) ;
267+ }
268+ }
269+ Ok ( Some ( Self :: Some { uuids, indices } ) )
270+ }
271+ }
272+
273+ /// Return `true` if the given index or UUID is included in the visible devices.
274+ fn includes ( & self , index : Option < usize > , uuid : Option < & str > ) -> bool {
275+ match self {
276+ Self :: All => true ,
277+ Self :: None => false ,
278+ Self :: Some { uuids, indices } => {
279+ index. is_some_and ( |index| indices. contains ( & index) )
280+ || uuid. is_some_and ( |uuid| uuids. iter ( ) . any ( |value| value == uuid) )
281+ }
282+ }
283+ }
284+ }
285+
196286/// Parse the CUDA driver version from the content of `/sys/module/nvidia/version`.
197287fn parse_sys_module_nvidia_version ( content : & str ) -> Result < Version , AcceleratorError > {
198288 // Parse, e.g.:
@@ -304,4 +394,134 @@ mod tests {
304394 let result = parse_proc_driver_nvidia_version ( content) . unwrap ( ) ;
305395 assert_eq ! ( result, Some ( Version :: from_str( "375.74" ) . unwrap( ) ) ) ;
306396 }
397+
#[test]
fn nvidia_smi_multi_gpu() {
    // `nvidia-smi` output may span multiple lines (multiple GPUs); in both the single-GPU and
    // the multi-GPU fixture, the first line must parse to the expected driver version.
    let fixtures = ["572.60\n ", "572.60\n 572.60\n "];
    for output in fixtures {
        if let Some(first_line) = output.lines().next() {
            let parsed = Version::from_str(first_line.trim()).unwrap();
            assert_eq!(parsed, Version::from_str("572.60").unwrap());
        }
    }
}
413+
#[test]
fn visible_devices_parse() {
    // Shorthand for the expected `VisibleDevices::Some { .. }` result of a device list.
    let listed = |uuids: &[&str], indices: &[usize]| {
        Some(VisibleDevices::Some {
            uuids: uuids.iter().map(|uuid| (*uuid).to_string()).collect(),
            indices: indices.to_vec(),
        })
    };

    // Keywords that select everything or nothing.
    assert_eq!(
        VisibleDevices::parse("all").unwrap(),
        Some(VisibleDevices::All)
    );
    assert_eq!(
        VisibleDevices::parse("none").unwrap(),
        Some(VisibleDevices::None)
    );

    // Values that are treated as "not set".
    assert_eq!(VisibleDevices::parse("void").unwrap(), None);
    assert_eq!(VisibleDevices::parse("").unwrap(), None);

    // Index-only lists.
    assert_eq!(VisibleDevices::parse("0").unwrap(), listed(&[], &[0]));
    assert_eq!(
        VisibleDevices::parse("0,1,2").unwrap(),
        listed(&[], &[0, 1, 2])
    );

    // UUID-only lists.
    assert_eq!(
        VisibleDevices::parse("GPU-12345678-abcd-efgh-ijkl-123456789abc").unwrap(),
        listed(&["GPU-12345678-abcd-efgh-ijkl-123456789abc"], &[])
    );
    assert_eq!(
        VisibleDevices::parse("GPU-12345678,GPU-87654321").unwrap(),
        listed(&["GPU-12345678", "GPU-87654321"], &[])
    );

    // Mixed lists: indices and UUIDs are collected separately, each in input order.
    assert_eq!(
        VisibleDevices::parse("0,GPU-12345678,1,GPU-87654321,2").unwrap(),
        listed(&["GPU-12345678", "GPU-87654321"], &[0, 1, 2])
    );

    // An unrecognized entry is reported by name, wherever it appears in the list.
    for input in ["invalid", "0,invalid,1"] {
        let err = VisibleDevices::parse(input).unwrap_err();
        assert!(matches!(err, AcceleratorError::Device(s) if s == "invalid"));
    }
}
480+
#[test]
fn visible_devices_includes() {
    // `All` accepts every probe; `None` rejects every probe.
    let probes = [
        (Some(0), None),
        (None, Some("GPU-12345678")),
        (Some(999), Some("GPU-any")),
    ];
    for (index, uuid) in probes {
        assert!(VisibleDevices::All.includes(index, uuid));
    }
    for (index, uuid) in probes {
        assert!(!VisibleDevices::None.includes(index, uuid));
    }

    // Only indices are listed: UUIDs never match.
    let by_index = VisibleDevices::Some {
        uuids: Vec::new(),
        indices: vec![0, 2, 4],
    };
    for visible in [0, 2, 4] {
        assert!(by_index.includes(Some(visible), None));
    }
    for hidden in [1, 3] {
        assert!(!by_index.includes(Some(hidden), None));
    }
    assert!(!by_index.includes(None, Some("GPU-12345678")));

    // Only UUIDs are listed: indices never match.
    let by_uuid = VisibleDevices::Some {
        uuids: vec!["GPU-12345678".to_string(), "GPU-87654321".to_string()],
        indices: Vec::new(),
    };
    for visible in ["GPU-12345678", "GPU-87654321"] {
        assert!(by_uuid.includes(None, Some(visible)));
    }
    assert!(!by_uuid.includes(None, Some("GPU-99999999")));
    assert!(!by_uuid.includes(Some(0), None));

    // Both indices and UUIDs are listed.
    let mixed = VisibleDevices::Some {
        uuids: vec!["GPU-12345678".to_string()],
        indices: vec![0, 1],
    };
    assert!(mixed.includes(Some(0), None));
    assert!(mixed.includes(Some(1), None));
    assert!(!mixed.includes(Some(2), None));
    assert!(mixed.includes(None, Some("GPU-12345678")));
    assert!(!mixed.includes(None, Some("GPU-87654321")));

    // A match on either the index or the UUID is sufficient; a miss on both is not.
    assert!(mixed.includes(Some(0), Some("GPU-99999999")));
    assert!(mixed.includes(Some(99), Some("GPU-12345678")));
    assert!(!mixed.includes(Some(99), Some("GPU-99999999")));
}
307527}
0 commit comments