Skip to content

Commit d4f5a2d

Browse files
committed
Respect multi-GPU outputs in nvidia-smi
1 parent 93630a8 commit d4f5a2d

File tree

2 files changed

+229
-5
lines changed

2 files changed

+229
-5
lines changed

crates/uv-static/src/env_vars.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,4 +833,8 @@ impl EnvVars {
833833

834834
/// Disable Hugging Face authentication, even if `HF_TOKEN` is set.
835835
pub const UV_NO_HF_TOKEN: &'static str = "UV_NO_HF_TOKEN";
836+
837+
/// The visible devices for NVIDIA GPUs, to respect when querying `nvidia-smi` to detect GPU
838+
/// drivers.
839+
pub const NVIDIA_VISIBLE_DEVICES: &'static str = "NVIDIA_VISIBLE_DEVICES";
836840
}

crates/uv-torch/src/accelerator.rs

Lines changed: 225 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ pub enum AcceleratorError {
1616
Utf8(#[from] std::string::FromUtf8Error),
1717
#[error(transparent)]
1818
ParseInt(#[from] std::num::ParseIntError),
19+
#[error("Failed to parse NVIDIA device: {0}")]
20+
Device(String),
1921
#[error("Unknown AMD GPU architecture: {0}")]
2022
UnknownAmdGpuArchitecture(String),
2123
}
@@ -57,7 +59,7 @@ impl Accelerator {
5759
/// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable.
5860
/// 3. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
5961
/// 4. `/proc/driver/nvidia/version`, which contains the driver version among other information.
60-
/// 5. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
62+
/// 5. `nvidia-smi --query-gpu=index,uuid,driver_version --format=csv,noheader`.
6163
/// 6. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
6264
/// 7. `/sys/bus/pci/devices`, filtering for the Intel GPU via PCI.
6365
pub fn detect() -> Result<Option<Self>, AcceleratorError> {
@@ -121,14 +123,40 @@ impl Accelerator {
121123

122124
// Query `nvidia-smi`.
123125
if let Ok(output) = std::process::Command::new("nvidia-smi")
124-
.arg("--query-gpu=driver_version")
126+
.arg("--query-gpu=index,uuid,driver_version")
125127
.arg("--format=csv,noheader")
126128
.output()
127129
{
128130
if output.status.success() {
129-
let driver_version = Version::from_str(&String::from_utf8(output.stdout)?)?;
130-
debug!("Detected CUDA driver version from `nvidia-smi`: {driver_version}");
131-
return Ok(Some(Self::Cuda { driver_version }));
131+
let visible_devices = VisibleDevices::from_env()?.unwrap_or(VisibleDevices::All);
132+
let stdout = String::from_utf8(output.stdout)?;
133+
for line in stdout.lines() {
134+
let mut parts = line.split(',');
135+
136+
// Parse the GPU index.
137+
let index = parts.next().and_then(|s| s.trim().parse::<usize>().ok());
138+
139+
// Parse the GPU UUID.
140+
let uuid = parts.next().map(str::trim);
141+
142+
// Determine if this GPU is visible based on the environment variable.
143+
if visible_devices.includes(index, uuid) {
144+
if let Some(driver_version) = parts.next() {
145+
let driver_version = Version::from_str(driver_version.trim())?;
146+
debug!(
147+
"Detected CUDA driver version from `nvidia-smi`: {driver_version}"
148+
);
149+
return Ok(Some(Self::Cuda { driver_version }));
150+
}
151+
} else {
152+
debug!("Skipping invisible GPU {index:?} with UUID: {uuid:?}");
153+
}
154+
}
155+
if let Some(first_line) = stdout.lines().next() {
156+
let driver_version = Version::from_str(first_line.trim())?;
157+
debug!("Detected CUDA driver version from `nvidia-smi`: {driver_version}");
158+
return Ok(Some(Self::Cuda { driver_version }));
159+
}
132160
}
133161

134162
debug!(
@@ -193,6 +221,68 @@ impl Accelerator {
193221
}
194222
}
195223

224+
/// The set of NVIDIA GPUs selected through the `NVIDIA_VISIBLE_DEVICES` environment variable.
#[derive(Clone, Debug, PartialEq, Eq)]
enum VisibleDevices {
    /// Every GPU is visible.
    All,
    /// Every GPU is hidden.
    None,
    /// Only the listed GPUs are visible. A GPU may be named by its index, its UUID, or both.
    Some {
        /// UUIDs of the visible GPUs (entries that start with `GPU-`).
        uuids: Vec<String>,
        /// Indices of the visible GPUs.
        indices: Vec<usize>,
    },
}
236+
237+
impl VisibleDevices {
238+
/// Read and parse the [`NVIDIA_VISIBLE_DEVICES`] environment variable.
239+
fn from_env() -> Result<Option<Self>, AcceleratorError> {
240+
let Some(nvidia_visible_devices) = std::env::var(EnvVars::NVIDIA_VISIBLE_DEVICES).ok()
241+
else {
242+
return Ok(None);
243+
};
244+
Self::parse(&nvidia_visible_devices)
245+
}
246+
247+
/// Parse the [`NVIDIA_VISIBLE_DEVICES`] environment variable.
248+
fn parse(s: &str) -> Result<Option<Self>, AcceleratorError> {
249+
if s.is_empty() {
250+
Ok(None)
251+
} else if s == "void" {
252+
Ok(None)
253+
} else if s == "all" {
254+
Ok(Some(Self::All))
255+
} else if s == "none" {
256+
Ok(Some(Self::None))
257+
} else {
258+
let mut indices = Vec::new();
259+
let mut uuids = Vec::new();
260+
for device in s.split(',') {
261+
if device.starts_with("GPU-") {
262+
uuids.push(device.to_string());
263+
} else if let Ok(index) = device.parse::<usize>() {
264+
indices.push(index);
265+
} else {
266+
return Err(AcceleratorError::Device(device.to_string()));
267+
}
268+
}
269+
Ok(Some(Self::Some { uuids, indices }))
270+
}
271+
}
272+
273+
/// Return `true` if the given index or UUID is included in the visible devices.
274+
fn includes(&self, index: Option<usize>, uuid: Option<&str>) -> bool {
275+
match self {
276+
Self::All => true,
277+
Self::None => false,
278+
Self::Some { uuids, indices } => {
279+
index.is_some_and(|index| indices.contains(&index))
280+
|| uuid.is_some_and(|uuid| uuids.iter().any(|value| value == uuid))
281+
}
282+
}
283+
}
284+
}
285+
196286
/// Parse the CUDA driver version from the content of `/sys/module/nvidia/version`.
197287
fn parse_sys_module_nvidia_version(content: &str) -> Result<Version, AcceleratorError> {
198288
// Parse, e.g.:
@@ -304,4 +394,134 @@ mod tests {
304394
let result = parse_proc_driver_nvidia_version(content).unwrap();
305395
assert_eq!(result, Some(Version::from_str("375.74").unwrap()));
306396
}
397+
398+
#[test]
fn nvidia_smi_multi_gpu() {
    // `nvidia-smi` prints one line per GPU; only the first line is parsed, so single- and
    // multi-GPU output must produce the same driver version.
    let expected = Version::from_str("572.60").unwrap();
    for output in ["572.60\n", "572.60\n572.60\n"] {
        let first_line = output
            .lines()
            .next()
            .expect("fixture contains at least one line");
        let parsed = Version::from_str(first_line.trim()).unwrap();
        assert_eq!(parsed, expected);
    }
}
413+
414+
#[test]
fn visible_devices_parse() {
    // Asserts that `input` parses to an explicit device list with the given UUIDs and indices.
    let expect_list = |input: &str, uuids: &[&str], indices: &[usize]| {
        assert_eq!(
            VisibleDevices::parse(input).unwrap(),
            Some(VisibleDevices::Some {
                uuids: uuids.iter().map(|uuid| (*uuid).to_string()).collect(),
                indices: indices.to_vec(),
            })
        );
    };

    // Keywords that select everything or nothing.
    assert_eq!(
        VisibleDevices::parse("all").unwrap(),
        Some(VisibleDevices::All)
    );
    assert_eq!(
        VisibleDevices::parse("none").unwrap(),
        Some(VisibleDevices::None)
    );

    // Values that express no explicit selection.
    for unset in ["void", ""] {
        assert_eq!(VisibleDevices::parse(unset).unwrap(), None);
    }

    // Index-only lists.
    expect_list("0", &[], &[0]);
    expect_list("0,1,2", &[], &[0, 1, 2]);

    // UUID-only lists.
    expect_list(
        "GPU-12345678-abcd-efgh-ijkl-123456789abc",
        &["GPU-12345678-abcd-efgh-ijkl-123456789abc"],
        &[],
    );
    expect_list(
        "GPU-12345678,GPU-87654321",
        &["GPU-12345678", "GPU-87654321"],
        &[],
    );

    // Mixed lists keep UUIDs and indices in their order of appearance.
    expect_list(
        "0,GPU-12345678,1,GPU-87654321,2",
        &["GPU-12345678", "GPU-87654321"],
        &[0, 1, 2],
    );

    // An unrecognized entry is reported verbatim, wherever it appears in the list.
    for input in ["invalid", "0,invalid,1"] {
        assert!(matches!(
            VisibleDevices::parse(input).unwrap_err(),
            AcceleratorError::Device(s) if s == "invalid"
        ));
    }
}
480+
481+
#[test]
fn visible_devices_includes() {
    // `All` accepts and `None` rejects every probe: index only, UUID only, and both.
    let probes: [(Option<usize>, Option<&str>); 3] = [
        (Some(0), None),
        (None, Some("GPU-12345678")),
        (Some(999), Some("GPU-any")),
    ];
    for (index, uuid) in probes {
        assert!(VisibleDevices::All.includes(index, uuid));
        assert!(!VisibleDevices::None.includes(index, uuid));
    }

    // Selection by index only.
    let by_index = VisibleDevices::Some {
        uuids: vec![],
        indices: vec![0, 2, 4],
    };
    for listed in [0, 2, 4] {
        assert!(by_index.includes(Some(listed), None));
    }
    for unlisted in [1, 3] {
        assert!(!by_index.includes(Some(unlisted), None));
    }
    assert!(!by_index.includes(None, Some("GPU-12345678")));

    // Selection by UUID only.
    let by_uuid = VisibleDevices::Some {
        uuids: vec!["GPU-12345678".to_string(), "GPU-87654321".to_string()],
        indices: vec![],
    };
    for listed in ["GPU-12345678", "GPU-87654321"] {
        assert!(by_uuid.includes(None, Some(listed)));
    }
    assert!(!by_uuid.includes(None, Some("GPU-99999999")));
    assert!(!by_uuid.includes(Some(0), None));

    // Selection by both indices and UUIDs.
    let mixed = VisibleDevices::Some {
        uuids: vec!["GPU-12345678".to_string()],
        indices: vec![0, 1],
    };
    assert!(mixed.includes(Some(0), None));
    assert!(mixed.includes(Some(1), None));
    assert!(!mixed.includes(Some(2), None));
    assert!(mixed.includes(None, Some("GPU-12345678")));
    assert!(!mixed.includes(None, Some("GPU-87654321")));

    // When both identifiers are supplied, a match on either one is enough.
    assert!(mixed.includes(Some(0), Some("GPU-99999999")));
    assert!(mixed.includes(Some(99), Some("GPU-12345678")));
    assert!(!mixed.includes(Some(99), Some("GPU-99999999")));
}
307527
}

0 commit comments

Comments
 (0)