Skip to content

Commit 94f085a

Browse files
committed
Add support for backend-specific Python dependencies
Add `general.cuda.python-depends` and `general.xpu.python-depends`. Currently `nvidia-cutlass-dsl` is supported for CUDA and `onednn` for XPU.
1 parent d7aa270 commit 94f085a

File tree

16 files changed

+299
-78
lines changed

16 files changed

+299
-78
lines changed

build2cmake/Cargo.lock

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

build2cmake/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ rand = "0.8"
2020
serde = { version = "1", features = ["derive"] }
2121
serde_json = "1"
2222
serde-value = "0.7"
23+
thiserror = "1"
2324
toml = "0.8"
2425

2526
[build-dependencies]

build2cmake/src/config/v3.rs

Lines changed: 146 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,77 @@ use std::{
33
fmt::Display,
44
path::PathBuf,
55
str::FromStr,
6+
sync::LazyLock,
67
};
78

89
use eyre::Result;
910
use itertools::Itertools;
1011
use serde::{Deserialize, Serialize};
12+
use thiserror::Error;
1113

14+
use super::{common::Dependency, v2};
1215
use crate::version::Version;
1316

14-
use super::{common::Dependency, v2};
17+
#[derive(Debug, Error)]
18+
enum DependencyError {
19+
#[error("No dependencies are defined: {backend:?}")]
20+
UnknownBackend { backend: String },
21+
#[error("Unknown dependency `{dependency:?}` for backend `{backend:?}`")]
22+
UnknownDependency { backend: String, dependency: String },
23+
#[error("Unknown dependency: `{dependency:?}`")]
24+
UnknownGeneralDependency { dependency: String },
25+
}
26+
27+
#[derive(Debug, Deserialize, Serialize)]
28+
#[serde(deny_unknown_fields)]
29+
struct PythonDependencies {
30+
general: HashMap<String, PythonDependency>,
31+
backends: HashMap<Backend, HashMap<String, PythonDependency>>,
32+
}
33+
34+
impl PythonDependencies {
35+
fn get_dependency(&self, dependency: &str) -> Result<&[String], DependencyError> {
36+
match self.general.get(dependency) {
37+
None => Err(DependencyError::UnknownGeneralDependency {
38+
dependency: dependency.to_string(),
39+
}),
40+
Some(dep) => Ok(&dep.python),
41+
}
42+
}
43+
44+
fn get_backend_dependency(
45+
&self,
46+
backend: Backend,
47+
dependency: &str,
48+
) -> Result<&[String], DependencyError> {
49+
let backend_deps = match self.backends.get(&backend) {
50+
None => {
51+
return Err(DependencyError::UnknownBackend {
52+
backend: backend.to_string(),
53+
})
54+
}
55+
Some(backend_deps) => backend_deps,
56+
};
57+
match backend_deps.get(dependency) {
58+
None => {
59+
return Err(DependencyError::UnknownDependency {
60+
backend: backend.to_string(),
61+
dependency: dependency.to_string(),
62+
})
63+
}
64+
Some(dep) => Ok(&dep.python),
65+
}
66+
}
67+
}
68+
69+
#[derive(Debug, Deserialize, Serialize)]
70+
struct PythonDependency {
71+
nix: Vec<String>,
72+
python: Vec<String>,
73+
}
74+
75+
/// Dependency table parsed on first use from `python_dependencies.json`,
/// which is embedded into the binary at compile time.
///
/// # Panics
///
/// The first access panics if the embedded JSON does not deserialize into
/// `PythonDependencies`. The file is part of the crate, so this indicates a
/// bug in the crate rather than bad user input.
static PYTHON_DEPENDENCIES: LazyLock<PythonDependencies> = LazyLock::new(|| {
    serde_json::from_str(include_str!("../python_dependencies.json"))
        .expect("embedded python_dependencies.json must deserialize into PythonDependencies")
});
1577

1678
#[derive(Debug, Deserialize, Serialize)]
1779
#[serde(deny_unknown_fields)]
@@ -44,46 +106,114 @@ pub struct General {
44106

45107
pub hub: Option<Hub>,
46108

47-
pub python_depends: Option<Vec<PythonDependency>>,
109+
pub python_depends: Option<Vec<String>>,
110+
111+
pub xpu: Option<XpuGeneral>,
48112
}
49113

50114
impl General {
51115
/// Name of the kernel as a Python extension.
52116
pub fn python_name(&self) -> String {
53117
self.name.replace("-", "_")
54118
}
119+
120+
pub fn python_depends(&self) -> Box<dyn Iterator<Item = Result<String>> + '_> {
121+
let general_python_deps = match self.python_depends.as_ref() {
122+
Some(deps) => deps,
123+
None => {
124+
return Box::new(std::iter::empty());
125+
}
126+
};
127+
128+
Box::new(general_python_deps.iter().flat_map(move |dep| {
129+
match PYTHON_DEPENDENCIES.get_dependency(dep) {
130+
Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::<Vec<_>>(),
131+
Err(e) => vec![Err(e.into())],
132+
}
133+
}))
134+
}
135+
136+
pub fn backend_python_depends(
137+
&self,
138+
backend: Backend,
139+
) -> Box<dyn Iterator<Item = Result<String>> + '_> {
140+
let backend_python_deps = match backend {
141+
Backend::Cuda => self
142+
.cuda
143+
.as_ref()
144+
.and_then(|cuda| cuda.python_depends.as_ref()),
145+
Backend::Xpu => self
146+
.xpu
147+
.as_ref()
148+
.and_then(|xpu| xpu.python_depends.as_ref()),
149+
_ => None,
150+
};
151+
152+
let backend_python_deps = match backend_python_deps {
153+
Some(deps) => deps,
154+
None => {
155+
return Box::new(std::iter::empty());
156+
}
157+
};
158+
159+
Box::new(backend_python_deps.iter().flat_map(move |dep| {
160+
match PYTHON_DEPENDENCIES.get_backend_dependency(backend, dep) {
161+
Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::<Vec<_>>(),
162+
Err(e) => vec![Err(e.into())],
163+
}
164+
}))
165+
}
55166
}
56167

57168
#[derive(Debug, Deserialize, Serialize)]
58169
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
59170
pub struct CudaGeneral {
60171
pub minver: Option<Version>,
61172
pub maxver: Option<Version>,
173+
pub python_depends: Option<Vec<String>>,
62174
}
63175

64176
#[derive(Debug, Deserialize, Serialize)]
65177
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
66-
pub struct Hub {
67-
pub repo_id: Option<String>,
68-
pub branch: Option<String>,
178+
pub enum CudaPythonDependency {
179+
NvidiaCutlassDsl,
69180
}
70181

71-
#[derive(Clone, Debug, Deserialize, Serialize)]
182+
impl Display for CudaPythonDependency {
183+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184+
match self {
185+
CudaPythonDependency::NvidiaCutlassDsl => write!(f, "nvidia-cutlass-dsl"),
186+
}
187+
}
188+
}
189+
190+
#[derive(Debug, Deserialize, Serialize)]
72191
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
73-
pub enum PythonDependency {
74-
Einops,
75-
NvidiaCutlassDsl,
192+
pub struct XpuGeneral {
193+
pub python_depends: Option<Vec<String>>,
76194
}
77195

78-
impl Display for PythonDependency {
196+
#[derive(Debug, Deserialize, Serialize)]
197+
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
198+
pub enum XpuPythonDependency {
199+
Onednn,
200+
}
201+
202+
impl Display for XpuPythonDependency {
79203
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80204
match self {
81-
PythonDependency::Einops => write!(f, "einops"),
82-
PythonDependency::NvidiaCutlassDsl => write!(f, "nvidia-cutlass-dsl"),
205+
XpuPythonDependency::Onednn => write!(f, "onednn-devel"),
83206
}
84207
}
85208
}
86209

210+
#[derive(Debug, Deserialize, Serialize)]
211+
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
212+
pub struct Hub {
213+
pub repo_id: Option<String>,
214+
pub branch: Option<String>,
215+
}
216+
87217
#[derive(Debug, Deserialize, Clone, Serialize)]
88218
#[serde(deny_unknown_fields)]
89219
pub struct Torch {
@@ -215,7 +345,7 @@ impl Kernel {
215345
}
216346
}
217347

218-
#[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)]
348+
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
219349
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
220350
pub enum Backend {
221351
Cpu,
@@ -290,6 +420,7 @@ impl General {
290420
Some(CudaGeneral {
291421
minver: general.cuda_minver,
292422
maxver: general.cuda_maxver,
423+
python_depends: None,
293424
})
294425
} else {
295426
None
@@ -300,9 +431,8 @@ impl General {
300431
backends,
301432
cuda,
302433
hub: general.hub.map(Into::into),
303-
python_depends: general
304-
.python_depends
305-
.map(|deps| deps.into_iter().map(Into::into).collect()),
434+
python_depends: None,
435+
xpu: None,
306436
}
307437
}
308438
}
@@ -316,15 +446,6 @@ impl From<v2::Hub> for Hub {
316446
}
317447
}
318448

319-
impl From<v2::PythonDependency> for PythonDependency {
320-
fn from(dep: v2::PythonDependency) -> Self {
321-
match dep {
322-
v2::PythonDependency::Einops => PythonDependency::Einops,
323-
v2::PythonDependency::NvidiaCutlassDsl => PythonDependency::NvidiaCutlassDsl,
324-
}
325-
}
326-
}
327-
328449
impl From<v2::Torch> for Torch {
329450
fn from(torch: v2::Torch) -> Self {
330451
Self {

build2cmake/src/main.rs

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ fn generate_torch(
172172
};
173173

174174
let file_set = if build.is_noarch() {
175-
write_torch_ext_noarch(&env, &build, target_dir.clone(), ops_id)?
175+
write_torch_ext_noarch(&env, backend, &build, target_dir.clone(), ops_id)?
176176
} else {
177177
match backend {
178178
Backend::Cpu => write_torch_ext_cpu(&env, &build, target_dir.clone(), ops_id)?,
@@ -376,26 +376,30 @@ fn get_generated_files(
376376
let mut all_set = FileSet::new();
377377

378378
if build.is_noarch() {
379-
let set = write_torch_ext_noarch(env, build, target_dir.clone(), ops_id.clone())?;
380-
381-
all_set.extend(set);
382379
} else {
383380
for backend in &build.general.backends {
384-
let set = match backend {
385-
Backend::Cpu => {
386-
write_torch_ext_cpu(env, build, target_dir.clone(), ops_id.clone())?
387-
}
388-
Backend::Cuda | Backend::Rocm => {
389-
write_torch_ext_cuda(env, *backend, build, target_dir.clone(), ops_id.clone())?
390-
}
391-
Backend::Metal => {
392-
write_torch_ext_metal(env, build, target_dir.clone(), ops_id.clone())?
393-
}
394-
Backend::Xpu => {
395-
write_torch_ext_xpu(env, build, target_dir.clone(), ops_id.clone())?
381+
let set = if build.is_noarch() {
382+
write_torch_ext_noarch(env, *backend, build, target_dir.clone(), ops_id.clone())?
383+
} else {
384+
match backend {
385+
Backend::Cpu => {
386+
write_torch_ext_cpu(env, build, target_dir.clone(), ops_id.clone())?
387+
}
388+
Backend::Cuda | Backend::Rocm => write_torch_ext_cuda(
389+
env,
390+
*backend,
391+
build,
392+
target_dir.clone(),
393+
ops_id.clone(),
394+
)?,
395+
Backend::Metal => {
396+
write_torch_ext_metal(env, build, target_dir.clone(), ops_id.clone())?
397+
}
398+
Backend::Xpu => {
399+
write_torch_ext_xpu(env, build, target_dir.clone(), ops_id.clone())?
400+
}
396401
}
397402
};
398-
399403
all_set.extend(set);
400404
}
401405
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"general": {
3+
"einops": {
4+
"nix": ["einops"],
5+
"python": ["einops"]
6+
}
7+
},
8+
"backends": {
9+
"cpu": {},
10+
"cuda": {
11+
"nvidia-cutlass-dsl": {
12+
"nix": ["nvidia-cutlass-dsl"],
13+
"python": ["nvidia-cutlass-dsl"]
14+
}
15+
},
16+
"metal": {},
17+
"rocm": {},
18+
"xpu": {
19+
"onednn": {
20+
"nix": [],
21+
"python": ["onednn-devel"]
22+
}
23+
}
24+
}
25+
}

0 commit comments

Comments
 (0)