@@ -3,15 +3,77 @@ use std::{
33 fmt:: Display ,
44 path:: PathBuf ,
55 str:: FromStr ,
6+ sync:: LazyLock ,
67} ;
78
89use eyre:: Result ;
910use itertools:: Itertools ;
1011use serde:: { Deserialize , Serialize } ;
12+ use thiserror:: Error ;
1113
14+ use super :: { common:: Dependency , v2} ;
1215use crate :: version:: Version ;
1316
14- use super :: { common:: Dependency , v2} ;
17+ #[ derive( Debug , Error ) ]
18+ enum DependencyError {
19+ #[ error( "No dependencies are defined: {backend:?}" ) ]
20+ UnknownBackend { backend : String } ,
21+ #[ error( "Unknown dependency `{dependency:?}` for backend `{backend:?}`" ) ]
22+ UnknownDependency { backend : String , dependency : String } ,
23+ #[ error( "Unknown dependency: `{dependency:?}`" ) ]
24+ UnknownGeneralDependency { dependency : String } ,
25+ }
26+
27+ #[ derive( Debug , Deserialize , Serialize ) ]
28+ #[ serde( deny_unknown_fields) ]
29+ struct PythonDependencies {
30+ general : HashMap < String , PythonDependency > ,
31+ backends : HashMap < Backend , HashMap < String , PythonDependency > > ,
32+ }
33+
34+ impl PythonDependencies {
35+ fn get_dependency ( & self , dependency : & str ) -> Result < & [ String ] , DependencyError > {
36+ match self . general . get ( dependency) {
37+ None => Err ( DependencyError :: UnknownGeneralDependency {
38+ dependency : dependency. to_string ( ) ,
39+ } ) ,
40+ Some ( dep) => Ok ( & dep. python ) ,
41+ }
42+ }
43+
44+ fn get_backend_dependency (
45+ & self ,
46+ backend : Backend ,
47+ dependency : & str ,
48+ ) -> Result < & [ String ] , DependencyError > {
49+ let backend_deps = match self . backends . get ( & backend) {
50+ None => {
51+ return Err ( DependencyError :: UnknownBackend {
52+ backend : backend. to_string ( ) ,
53+ } )
54+ }
55+ Some ( backend_deps) => backend_deps,
56+ } ;
57+ match backend_deps. get ( dependency) {
58+ None => {
59+ return Err ( DependencyError :: UnknownDependency {
60+ backend : backend. to_string ( ) ,
61+ dependency : dependency. to_string ( ) ,
62+ } )
63+ }
64+ Some ( dep) => Ok ( & dep. python ) ,
65+ }
66+ }
67+ }
68+
69+ #[ derive( Debug , Deserialize , Serialize ) ]
70+ struct PythonDependency {
71+ nix : Vec < String > ,
72+ python : Vec < String > ,
73+ }
74+
75+ static PYTHON_DEPENDENCIES : LazyLock < PythonDependencies > =
76+ LazyLock :: new ( || serde_json:: from_str ( include_str ! ( "../python_dependencies.json" ) ) . unwrap ( ) ) ;
1577
1678#[ derive( Debug , Deserialize , Serialize ) ]
1779#[ serde( deny_unknown_fields) ]
@@ -44,46 +106,114 @@ pub struct General {
44106
45107 pub hub : Option < Hub > ,
46108
47- pub python_depends : Option < Vec < PythonDependency > > ,
109+ pub python_depends : Option < Vec < String > > ,
110+
111+ pub xpu : Option < XpuGeneral > ,
48112}
49113
50114impl General {
51115 /// Name of the kernel as a Python extension.
52116 pub fn python_name ( & self ) -> String {
53117 self . name . replace ( "-" , "_" )
54118 }
119+
120+ pub fn python_depends ( & self ) -> Box < dyn Iterator < Item = Result < String > > + ' _ > {
121+ let general_python_deps = match self . python_depends . as_ref ( ) {
122+ Some ( deps) => deps,
123+ None => {
124+ return Box :: new ( std:: iter:: empty ( ) ) ;
125+ }
126+ } ;
127+
128+ Box :: new ( general_python_deps. iter ( ) . flat_map ( move |dep| {
129+ match PYTHON_DEPENDENCIES . get_dependency ( dep) {
130+ Ok ( deps) => deps. iter ( ) . map ( |s| Ok ( s. clone ( ) ) ) . collect :: < Vec < _ > > ( ) ,
131+ Err ( e) => vec ! [ Err ( e. into( ) ) ] ,
132+ }
133+ } ) )
134+ }
135+
136+ pub fn backend_python_depends (
137+ & self ,
138+ backend : Backend ,
139+ ) -> Box < dyn Iterator < Item = Result < String > > + ' _ > {
140+ let backend_python_deps = match backend {
141+ Backend :: Cuda => self
142+ . cuda
143+ . as_ref ( )
144+ . and_then ( |cuda| cuda. python_depends . as_ref ( ) ) ,
145+ Backend :: Xpu => self
146+ . xpu
147+ . as_ref ( )
148+ . and_then ( |xpu| xpu. python_depends . as_ref ( ) ) ,
149+ _ => None ,
150+ } ;
151+
152+ let backend_python_deps = match backend_python_deps {
153+ Some ( deps) => deps,
154+ None => {
155+ return Box :: new ( std:: iter:: empty ( ) ) ;
156+ }
157+ } ;
158+
159+ Box :: new ( backend_python_deps. iter ( ) . flat_map ( move |dep| {
160+ match PYTHON_DEPENDENCIES . get_backend_dependency ( backend, dep) {
161+ Ok ( deps) => deps. iter ( ) . map ( |s| Ok ( s. clone ( ) ) ) . collect :: < Vec < _ > > ( ) ,
162+ Err ( e) => vec ! [ Err ( e. into( ) ) ] ,
163+ }
164+ } ) )
165+ }
55166}
56167
57168#[ derive( Debug , Deserialize , Serialize ) ]
58169#[ serde( deny_unknown_fields, rename_all = "kebab-case" ) ]
59170pub struct CudaGeneral {
60171 pub minver : Option < Version > ,
61172 pub maxver : Option < Version > ,
173+ pub python_depends : Option < Vec < String > > ,
62174}
63175
64176#[ derive( Debug , Deserialize , Serialize ) ]
65177#[ serde( deny_unknown_fields, rename_all = "kebab-case" ) ]
66- pub struct Hub {
67- pub repo_id : Option < String > ,
68- pub branch : Option < String > ,
178+ pub enum CudaPythonDependency {
179+ NvidiaCutlassDsl ,
69180}
70181
71- #[ derive( Clone , Debug , Deserialize , Serialize ) ]
182+ impl Display for CudaPythonDependency {
183+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
184+ match self {
185+ CudaPythonDependency :: NvidiaCutlassDsl => write ! ( f, "nvidia-cutlass-dsl" ) ,
186+ }
187+ }
188+ }
189+
190+ #[ derive( Debug , Deserialize , Serialize ) ]
72191#[ serde( deny_unknown_fields, rename_all = "kebab-case" ) ]
73- pub enum PythonDependency {
74- Einops ,
75- NvidiaCutlassDsl ,
192+ pub struct XpuGeneral {
193+ pub python_depends : Option < Vec < String > > ,
76194}
77195
78- impl Display for PythonDependency {
196+ #[ derive( Debug , Deserialize , Serialize ) ]
197+ #[ serde( deny_unknown_fields, rename_all = "kebab-case" ) ]
198+ pub enum XpuPythonDependency {
199+ Onednn ,
200+ }
201+
202+ impl Display for XpuPythonDependency {
79203 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
80204 match self {
81- PythonDependency :: Einops => write ! ( f, "einops" ) ,
82- PythonDependency :: NvidiaCutlassDsl => write ! ( f, "nvidia-cutlass-dsl" ) ,
205+ XpuPythonDependency :: Onednn => write ! ( f, "onednn-devel" ) ,
83206 }
84207 }
85208}
86209
210+ #[ derive( Debug , Deserialize , Serialize ) ]
211+ #[ serde( deny_unknown_fields, rename_all = "kebab-case" ) ]
212+ pub struct Hub {
213+ pub repo_id : Option < String > ,
214+ pub branch : Option < String > ,
215+ }
216+
87217#[ derive( Debug , Deserialize , Clone , Serialize ) ]
88218#[ serde( deny_unknown_fields) ]
89219pub struct Torch {
@@ -215,7 +345,7 @@ impl Kernel {
215345 }
216346}
217347
218- #[ derive( Clone , Copy , Debug , Deserialize , Eq , Ord , PartialEq , PartialOrd , Serialize ) ]
348+ #[ derive( Clone , Copy , Debug , Deserialize , Eq , Hash , Ord , PartialEq , PartialOrd , Serialize ) ]
219349#[ serde( deny_unknown_fields, rename_all = "kebab-case" ) ]
220350pub enum Backend {
221351 Cpu ,
@@ -290,6 +420,7 @@ impl General {
290420 Some ( CudaGeneral {
291421 minver : general. cuda_minver ,
292422 maxver : general. cuda_maxver ,
423+ python_depends : None ,
293424 } )
294425 } else {
295426 None
@@ -300,9 +431,8 @@ impl General {
300431 backends,
301432 cuda,
302433 hub : general. hub . map ( Into :: into) ,
303- python_depends : general
304- . python_depends
305- . map ( |deps| deps. into_iter ( ) . map ( Into :: into) . collect ( ) ) ,
434+ python_depends : None ,
435+ xpu : None ,
306436 }
307437 }
308438}
@@ -316,15 +446,6 @@ impl From<v2::Hub> for Hub {
316446 }
317447}
318448
319- impl From < v2:: PythonDependency > for PythonDependency {
320- fn from ( dep : v2:: PythonDependency ) -> Self {
321- match dep {
322- v2:: PythonDependency :: Einops => PythonDependency :: Einops ,
323- v2:: PythonDependency :: NvidiaCutlassDsl => PythonDependency :: NvidiaCutlassDsl ,
324- }
325- }
326- }
327-
328449impl From < v2:: Torch > for Torch {
329450 fn from ( torch : v2:: Torch ) -> Self {
330451 Self {
0 commit comments