Skip to content

Commit 647fcb3

Browse files
authored
Add support for backend-specific Python dependencies (#325)
Add `general.cuda.python-depends` and `general.xpu.python-depends`. Currently `nvidia-cutlass-dsl` is supported for CUDA and `onednn` for XPU.
1 parent 3eeb97b commit 647fcb3

File tree

16 files changed

+262
-78
lines changed

16 files changed

+262
-78
lines changed

build2cmake/Cargo.lock

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

build2cmake/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ rand = "0.8"
2020
serde = { version = "1", features = ["derive"] }
2121
serde_json = "1"
2222
serde-value = "0.7"
23+
thiserror = "1"
2324
toml = "0.8"
2425

2526
[build-dependencies]

build2cmake/src/config/v3.rs

Lines changed: 122 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,75 @@ use std::{
33
fmt::Display,
44
path::PathBuf,
55
str::FromStr,
6+
sync::LazyLock,
67
};
78

89
use eyre::Result;
910
use itertools::Itertools;
1011
use serde::{Deserialize, Serialize};
12+
use thiserror::Error;
1113

14+
use super::{common::Dependency, v2};
1215
use crate::version::Version;
1316

14-
use super::{common::Dependency, v2};
17+
#[derive(Debug, Error)]
18+
enum DependencyError {
19+
#[error("No dependencies are defined for backend: {backend:?}")]
20+
Backend { backend: String },
21+
#[error("Unknown dependency `{dependency:?}` for backend `{backend:?}`")]
22+
Dependency { backend: String, dependency: String },
23+
#[error("Unknown dependency: `{dependency:?}`")]
24+
GeneralDependency { dependency: String },
25+
}
26+
27+
#[derive(Debug, Deserialize, Serialize)]
28+
#[serde(deny_unknown_fields)]
29+
struct PythonDependencies {
30+
general: HashMap<String, PythonDependency>,
31+
backends: HashMap<Backend, HashMap<String, PythonDependency>>,
32+
}
33+
34+
impl PythonDependencies {
35+
fn get_dependency(&self, dependency: &str) -> Result<&[String], DependencyError> {
36+
match self.general.get(dependency) {
37+
None => Err(DependencyError::GeneralDependency {
38+
dependency: dependency.to_string(),
39+
}),
40+
Some(dep) => Ok(&dep.python),
41+
}
42+
}
43+
44+
fn get_backend_dependency(
45+
&self,
46+
backend: Backend,
47+
dependency: &str,
48+
) -> Result<&[String], DependencyError> {
49+
let backend_deps = match self.backends.get(&backend) {
50+
None => {
51+
return Err(DependencyError::Backend {
52+
backend: backend.to_string(),
53+
})
54+
}
55+
Some(backend_deps) => backend_deps,
56+
};
57+
match backend_deps.get(dependency) {
58+
None => Err(DependencyError::Dependency {
59+
backend: backend.to_string(),
60+
dependency: dependency.to_string(),
61+
}),
62+
Some(dep) => Ok(&dep.python),
63+
}
64+
}
65+
}
66+
67+
#[derive(Debug, Deserialize, Serialize)]
68+
struct PythonDependency {
69+
nix: Vec<String>,
70+
python: Vec<String>,
71+
}
72+
73+
static PYTHON_DEPENDENCIES: LazyLock<PythonDependencies> =
74+
LazyLock::new(|| serde_json::from_str(include_str!("../python_dependencies.json")).unwrap());
1575

1676
#[derive(Debug, Deserialize, Serialize)]
1777
#[serde(deny_unknown_fields)]
@@ -44,44 +104,84 @@ pub struct General {
44104

45105
pub hub: Option<Hub>,
46106

47-
pub python_depends: Option<Vec<PythonDependency>>,
107+
pub python_depends: Option<Vec<String>>,
108+
109+
pub xpu: Option<XpuGeneral>,
48110
}
49111

50112
impl General {
51113
/// Name of the kernel as a Python extension.
52114
pub fn python_name(&self) -> String {
53115
self.name.replace("-", "_")
54116
}
117+
118+
pub fn python_depends(&self) -> Box<dyn Iterator<Item = Result<String>> + '_> {
119+
let general_python_deps = match self.python_depends.as_ref() {
120+
Some(deps) => deps,
121+
None => {
122+
return Box::new(std::iter::empty());
123+
}
124+
};
125+
126+
Box::new(general_python_deps.iter().flat_map(move |dep| {
127+
match PYTHON_DEPENDENCIES.get_dependency(dep) {
128+
Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::<Vec<_>>(),
129+
Err(e) => vec![Err(e.into())],
130+
}
131+
}))
132+
}
133+
134+
pub fn backend_python_depends(
135+
&self,
136+
backend: Backend,
137+
) -> Box<dyn Iterator<Item = Result<String>> + '_> {
138+
let backend_python_deps = match backend {
139+
Backend::Cuda => self
140+
.cuda
141+
.as_ref()
142+
.and_then(|cuda| cuda.python_depends.as_ref()),
143+
Backend::Xpu => self
144+
.xpu
145+
.as_ref()
146+
.and_then(|xpu| xpu.python_depends.as_ref()),
147+
_ => None,
148+
};
149+
150+
let backend_python_deps = match backend_python_deps {
151+
Some(deps) => deps,
152+
None => {
153+
return Box::new(std::iter::empty());
154+
}
155+
};
156+
157+
Box::new(backend_python_deps.iter().flat_map(move |dep| {
158+
match PYTHON_DEPENDENCIES.get_backend_dependency(backend, dep) {
159+
Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::<Vec<_>>(),
160+
Err(e) => vec![Err(e.into())],
161+
}
162+
}))
163+
}
55164
}
56165

57166
#[derive(Debug, Deserialize, Serialize)]
58167
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
59168
pub struct CudaGeneral {
60169
pub minver: Option<Version>,
61170
pub maxver: Option<Version>,
171+
pub python_depends: Option<Vec<String>>,
62172
}
63173

64174
#[derive(Debug, Deserialize, Serialize)]
65175
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
66-
pub struct Hub {
67-
pub repo_id: Option<String>,
68-
pub branch: Option<String>,
176+
pub struct XpuGeneral {
177+
pub python_depends: Option<Vec<String>>,
69178
}
70179

71-
#[derive(Clone, Debug, Deserialize, Serialize)]
180+
#[derive(Debug, Deserialize, Serialize)]
72181
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
73-
pub enum PythonDependency {
74-
Einops,
75-
NvidiaCutlassDsl,
76-
}
77-
78-
impl Display for PythonDependency {
79-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80-
match self {
81-
PythonDependency::Einops => write!(f, "einops"),
82-
PythonDependency::NvidiaCutlassDsl => write!(f, "nvidia-cutlass-dsl"),
83-
}
84-
}
182+
pub struct Hub {
183+
pub repo_id: Option<String>,
184+
pub branch: Option<String>,
85185
}
86186

87187
#[derive(Debug, Deserialize, Clone, Serialize)]
@@ -215,7 +315,7 @@ impl Kernel {
215315
}
216316
}
217317

218-
#[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)]
318+
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
219319
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
220320
pub enum Backend {
221321
Cpu,
@@ -290,6 +390,7 @@ impl General {
290390
Some(CudaGeneral {
291391
minver: general.cuda_minver,
292392
maxver: general.cuda_maxver,
393+
python_depends: None,
293394
})
294395
} else {
295396
None
@@ -300,9 +401,8 @@ impl General {
300401
backends,
301402
cuda,
302403
hub: general.hub.map(Into::into),
303-
python_depends: general
304-
.python_depends
305-
.map(|deps| deps.into_iter().map(Into::into).collect()),
404+
python_depends: None,
405+
xpu: None,
306406
}
307407
}
308408
}
@@ -316,15 +416,6 @@ impl From<v2::Hub> for Hub {
316416
}
317417
}
318418

319-
impl From<v2::PythonDependency> for PythonDependency {
320-
fn from(dep: v2::PythonDependency) -> Self {
321-
match dep {
322-
v2::PythonDependency::Einops => PythonDependency::Einops,
323-
v2::PythonDependency::NvidiaCutlassDsl => PythonDependency::NvidiaCutlassDsl,
324-
}
325-
}
326-
}
327-
328419
impl From<v2::Torch> for Torch {
329420
fn from(torch: v2::Torch) -> Self {
330421
Self {

build2cmake/src/main.rs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ fn generate_torch(
172172
};
173173

174174
let file_set = if build.is_noarch() {
175-
write_torch_ext_noarch(&env, &build, target_dir.clone(), ops_id)?
175+
write_torch_ext_noarch(&env, backend, &build, target_dir.clone(), ops_id)?
176176
} else {
177177
match backend {
178178
Backend::Cpu => write_torch_ext_cpu(&env, &build, target_dir.clone(), ops_id)?,
@@ -375,13 +375,11 @@ fn get_generated_files(
375375
) -> Result<Vec<PathBuf>> {
376376
let mut all_set = FileSet::new();
377377

378-
if build.is_noarch() {
379-
let set = write_torch_ext_noarch(env, build, target_dir.clone(), ops_id.clone())?;
380-
381-
all_set.extend(set);
382-
} else {
383-
for backend in &build.general.backends {
384-
let set = match backend {
378+
for backend in &build.general.backends {
379+
let set = if build.is_noarch() {
380+
write_torch_ext_noarch(env, *backend, build, target_dir.clone(), ops_id.clone())?
381+
} else {
382+
match backend {
385383
Backend::Cpu => {
386384
write_torch_ext_cpu(env, build, target_dir.clone(), ops_id.clone())?
387385
}
@@ -394,10 +392,9 @@ fn get_generated_files(
394392
Backend::Xpu => {
395393
write_torch_ext_xpu(env, build, target_dir.clone(), ops_id.clone())?
396394
}
397-
};
398-
399-
all_set.extend(set);
400-
}
395+
}
396+
};
397+
all_set.extend(set);
401398
}
402399

403400
Ok(all_set.into_names())
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"general": {
3+
"einops": {
4+
"nix": ["einops"],
5+
"python": ["einops"]
6+
}
7+
},
8+
"backends": {
9+
"cpu": {},
10+
"cuda": {
11+
"nvidia-cutlass-dsl": {
12+
"nix": ["nvidia-cutlass-dsl"],
13+
"python": ["nvidia-cutlass-dsl"]
14+
}
15+
},
16+
"metal": {},
17+
"rocm": {},
18+
"xpu": {
19+
"onednn": {
20+
"nix": [],
21+
"python": ["onednn-devel"]
22+
}
23+
}
24+
}
25+
}

build2cmake/src/torch/common.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,23 @@ use eyre::{Context, Result};
22
use itertools::Itertools;
33
use minijinja::{context, Environment};
44

5-
use crate::{config::General, FileSet};
5+
use crate::config::{Backend, General};
6+
use crate::FileSet;
67

78
pub fn write_pyproject_toml(
89
env: &Environment,
10+
backend: Backend,
911
general: &General,
1012
file_set: &mut FileSet,
1113
) -> Result<()> {
1214
let writer = file_set.entry("pyproject.toml");
1315

14-
let python_dependencies = general
15-
.python_depends
16-
.as_ref()
17-
.unwrap_or(&vec![])
18-
.iter()
19-
.map(|d| format!("\"{d}\""))
20-
.join(", ");
16+
let python_dependencies = itertools::process_results(
17+
general
18+
.python_depends()
19+
.chain(general.backend_python_depends(backend)),
20+
|iter| iter.map(|d| format!("\"{d}\"")).join(", "),
21+
)?;
2122

2223
env.get_template("pyproject.toml")
2324
.wrap_err("Cannot get pyproject.toml template")?

build2cmake/src/torch/cpu.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use minijinja::{context, Environment};
66

77
use super::{common::write_pyproject_toml, kernel_ops_identifier};
88
use crate::{
9-
config::{Build, Kernel, Torch},
9+
config::{Backend, Build, Kernel, Torch},
1010
fileset::FileSet,
1111
version::Version,
1212
};
@@ -48,7 +48,7 @@ pub fn write_torch_ext_cpu(
4848

4949
write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?;
5050

51-
write_pyproject_toml(env, &build.general, &mut file_set)?;
51+
write_pyproject_toml(env, Backend::Cpu, &build.general, &mut file_set)?;
5252

5353
write_torch_registration_macros(&mut file_set)?;
5454

build2cmake/src/torch/cuda.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ pub fn write_torch_ext_cuda(
6161

6262
write_ops_py(env, &build.general.python_name(), &ops_name, &mut file_set)?;
6363

64-
write_pyproject_toml(env, &build.general, &mut file_set)?;
64+
write_pyproject_toml(env, backend, &build.general, &mut file_set)?;
6565

6666
write_torch_registration_macros(&mut file_set)?;
6767

0 commit comments

Comments
 (0)