Skip to content

Commit 3ffdcf3

Browse files
committed
fix: better CUDA detection
1 parent 612e4b4 commit 3ffdcf3

File tree

2 files changed

+26
-12
lines changed

2 files changed

+26
-12
lines changed

crates/cudart-sys/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ name = "era_cudart_sys"
1111
description = "Raw CUDA bindings for ZKsync"
1212

1313
[dependencies]
14-
serde_json = "1.0"
14+
regex-lite = "0.1"
1515

1616
[build-dependencies]
17-
serde_json = "1.0"
17+
regex-lite = "0.1"

crates/cudart-sys/src/utils.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ use std::path::{Path, PathBuf};
44
pub fn get_cuda_path() -> Option<&'static Path> {
55
#[cfg(target_os = "linux")]
66
{
7-
let path = Path::new("/usr/local/cuda");
8-
if path.exists() {
9-
Some(path)
10-
} else {
11-
None
7+
for path_name in [option_env!("CUDA_PATH"), Some("/usr/local/cuda")].iter().flatten() {
8+
let path = Path::new(path_name);
9+
if path.exists() {
10+
println!("CUDA installation found at `{}`", path.display());
11+
return Some(path)
12+
}
1213
}
14+
println!("CUDA installation path not found")
1315
}
1416
#[cfg(target_os = "windows")]
1517
{
@@ -42,12 +44,24 @@ pub fn get_cuda_lib_path() -> Option<PathBuf> {
4244

4345
pub fn get_cuda_version() -> Option<String> {
4446
if let Some(version) = option_env!("CUDA_VERSION") {
45-
Some(version.to_string())
47+
println!("CUDA version defined in CUDA_VERSION as `{}`", version);
48+
version.to_string()
4649
} else if let Some(path) = get_cuda_path() {
47-
let file = File::open(path.join("version.json")).expect("CUDA Toolkit should be installed");
48-
let reader = std::io::BufReader::new(file);
49-
let value: serde_json::Value = serde_json::from_reader(reader).unwrap();
50-
Some(value["cuda"]["version"].as_str().unwrap().to_string())
50+
println!("inferring CUDA version from nvcc output...");
51+
let re = regex_lite::Regex::new(r"V(?<version>\d{2}\.\d+\.\d+)").unwrap();
52+
let nvcc_out = std::process::Command::new("nvcc")
53+
.arg("--version")
54+
.output()
55+
.expect("failed to start `nvcc`");
56+
let nvcc_str = std::str::from_utf8(&nvcc_out.stdout).expect("`nvcc` output is not UTF8");
57+
let captures = re.captures(&nvcc_str).unwrap();
58+
let version = captures
59+
.get(0)
60+
.expect("unable to find nvcc version in the form VMM.mm.pp in the output of `nvcc --version`:\n{nvcc_str}")
61+
.as_str()
62+
.to_string();
63+
println!("CUDA version inferred to be `{version}`.");
64+
Some(version)
5165
} else {
5266
None
5367
}

0 commit comments

Comments
 (0)