@@ -3,15 +3,75 @@ use std::{
33 fmt:: Display ,
44 path:: PathBuf ,
55 str:: FromStr ,
6+ sync:: LazyLock ,
67} ;
78
89use eyre:: Result ;
910use itertools:: Itertools ;
1011use serde:: { Deserialize , Serialize } ;
12+ use thiserror:: Error ;
1113
14+ use super :: { common:: Dependency , v2} ;
1215use crate :: version:: Version ;
1316
14- use super :: { common:: Dependency , v2} ;
17+ #[ derive( Debug , Error ) ]
18+ enum DependencyError {
19+ #[ error( "No dependencies are defined for backend: {backend:?}" ) ]
20+ Backend { backend : String } ,
21+ #[ error( "Unknown dependency `{dependency:?}` for backend `{backend:?}`" ) ]
22+ Dependency { backend : String , dependency : String } ,
23+ #[ error( "Unknown dependency: `{dependency:?}`" ) ]
24+ GeneralDependency { dependency : String } ,
25+ }
26+
27+ #[ derive( Debug , Deserialize , Serialize ) ]
28+ #[ serde( deny_unknown_fields) ]
29+ struct PythonDependencies {
30+ general : HashMap < String , PythonDependency > ,
31+ backends : HashMap < Backend , HashMap < String , PythonDependency > > ,
32+ }
33+
34+ impl PythonDependencies {
35+ fn get_dependency ( & self , dependency : & str ) -> Result < & [ String ] , DependencyError > {
36+ match self . general . get ( dependency) {
37+ None => Err ( DependencyError :: GeneralDependency {
38+ dependency : dependency. to_string ( ) ,
39+ } ) ,
40+ Some ( dep) => Ok ( & dep. python ) ,
41+ }
42+ }
43+
44+ fn get_backend_dependency (
45+ & self ,
46+ backend : Backend ,
47+ dependency : & str ,
48+ ) -> Result < & [ String ] , DependencyError > {
49+ let backend_deps = match self . backends . get ( & backend) {
50+ None => {
51+ return Err ( DependencyError :: Backend {
52+ backend : backend. to_string ( ) ,
53+ } )
54+ }
55+ Some ( backend_deps) => backend_deps,
56+ } ;
57+ match backend_deps. get ( dependency) {
58+ None => Err ( DependencyError :: Dependency {
59+ backend : backend. to_string ( ) ,
60+ dependency : dependency. to_string ( ) ,
61+ } ) ,
62+ Some ( dep) => Ok ( & dep. python ) ,
63+ }
64+ }
65+ }
66+
67+ #[ derive( Debug , Deserialize , Serialize ) ]
68+ struct PythonDependency {
69+ nix : Vec < String > ,
70+ python : Vec < String > ,
71+ }
72+
73+ static PYTHON_DEPENDENCIES : LazyLock < PythonDependencies > =
74+ LazyLock :: new ( || serde_json:: from_str ( include_str ! ( "../python_dependencies.json" ) ) . unwrap ( ) ) ;
1575
1676#[ derive( Debug , Deserialize , Serialize ) ]
1777#[ serde( deny_unknown_fields) ]
@@ -44,44 +104,84 @@ pub struct General {
44104
45105 pub hub : Option < Hub > ,
46106
47- pub python_depends : Option < Vec < PythonDependency > > ,
107+ pub python_depends : Option < Vec < String > > ,
108+
109+ pub xpu : Option < XpuGeneral > ,
48110}
49111
50112impl General {
51113 /// Name of the kernel as a Python extension.
52114 pub fn python_name ( & self ) -> String {
53115 self . name . replace ( "-" , "_" )
54116 }
117+
118+ pub fn python_depends ( & self ) -> Box < dyn Iterator < Item = Result < String > > + ' _ > {
119+ let general_python_deps = match self . python_depends . as_ref ( ) {
120+ Some ( deps) => deps,
121+ None => {
122+ return Box :: new ( std:: iter:: empty ( ) ) ;
123+ }
124+ } ;
125+
126+ Box :: new ( general_python_deps. iter ( ) . flat_map ( move |dep| {
127+ match PYTHON_DEPENDENCIES . get_dependency ( dep) {
128+ Ok ( deps) => deps. iter ( ) . map ( |s| Ok ( s. clone ( ) ) ) . collect :: < Vec < _ > > ( ) ,
129+ Err ( e) => vec ! [ Err ( e. into( ) ) ] ,
130+ }
131+ } ) )
132+ }
133+
134+ pub fn backend_python_depends (
135+ & self ,
136+ backend : Backend ,
137+ ) -> Box < dyn Iterator < Item = Result < String > > + ' _ > {
138+ let backend_python_deps = match backend {
139+ Backend :: Cuda => self
140+ . cuda
141+ . as_ref ( )
142+ . and_then ( |cuda| cuda. python_depends . as_ref ( ) ) ,
143+ Backend :: Xpu => self
144+ . xpu
145+ . as_ref ( )
146+ . and_then ( |xpu| xpu. python_depends . as_ref ( ) ) ,
147+ _ => None ,
148+ } ;
149+
150+ let backend_python_deps = match backend_python_deps {
151+ Some ( deps) => deps,
152+ None => {
153+ return Box :: new ( std:: iter:: empty ( ) ) ;
154+ }
155+ } ;
156+
157+ Box :: new ( backend_python_deps. iter ( ) . flat_map ( move |dep| {
158+ match PYTHON_DEPENDENCIES . get_backend_dependency ( backend, dep) {
159+ Ok ( deps) => deps. iter ( ) . map ( |s| Ok ( s. clone ( ) ) ) . collect :: < Vec < _ > > ( ) ,
160+ Err ( e) => vec ! [ Err ( e. into( ) ) ] ,
161+ }
162+ } ) )
163+ }
55164}
56165
57166#[ derive( Debug , Deserialize , Serialize ) ]
58167#[ serde( deny_unknown_fields, rename_all = "kebab-case" ) ]
59168pub struct CudaGeneral {
60169 pub minver : Option < Version > ,
61170 pub maxver : Option < Version > ,
171+ pub python_depends : Option < Vec < String > > ,
62172}
63173
64174#[ derive( Debug , Deserialize , Serialize ) ]
65175#[ serde( deny_unknown_fields, rename_all = "kebab-case" ) ]
66- pub struct Hub {
67- pub repo_id : Option < String > ,
68- pub branch : Option < String > ,
176+ pub struct XpuGeneral {
177+ pub python_depends : Option < Vec < String > > ,
69178}
70179
71- #[ derive( Clone , Debug , Deserialize , Serialize ) ]
180+ #[ derive( Debug , Deserialize , Serialize ) ]
72181#[ serde( deny_unknown_fields, rename_all = "kebab-case" ) ]
73- pub enum PythonDependency {
74- Einops ,
75- NvidiaCutlassDsl ,
76- }
77-
78- impl Display for PythonDependency {
79- fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
80- match self {
81- PythonDependency :: Einops => write ! ( f, "einops" ) ,
82- PythonDependency :: NvidiaCutlassDsl => write ! ( f, "nvidia-cutlass-dsl" ) ,
83- }
84- }
182+ pub struct Hub {
183+ pub repo_id : Option < String > ,
184+ pub branch : Option < String > ,
85185}
86186
87187#[ derive( Debug , Deserialize , Clone , Serialize ) ]
@@ -215,7 +315,7 @@ impl Kernel {
215315 }
216316}
217317
218- #[ derive( Clone , Copy , Debug , Deserialize , Eq , Ord , PartialEq , PartialOrd , Serialize ) ]
318+ #[ derive( Clone , Copy , Debug , Deserialize , Eq , Hash , Ord , PartialEq , PartialOrd , Serialize ) ]
219319#[ serde( deny_unknown_fields, rename_all = "kebab-case" ) ]
220320pub enum Backend {
221321 Cpu ,
@@ -290,6 +390,7 @@ impl General {
290390 Some ( CudaGeneral {
291391 minver : general. cuda_minver ,
292392 maxver : general. cuda_maxver ,
393+ python_depends : None ,
293394 } )
294395 } else {
295396 None
@@ -300,9 +401,8 @@ impl General {
300401 backends,
301402 cuda,
302403 hub : general. hub . map ( Into :: into) ,
303- python_depends : general
304- . python_depends
305- . map ( |deps| deps. into_iter ( ) . map ( Into :: into) . collect ( ) ) ,
404+ python_depends : None ,
405+ xpu : None ,
306406 }
307407 }
308408}
@@ -316,15 +416,6 @@ impl From<v2::Hub> for Hub {
316416 }
317417}
318418
319- impl From < v2:: PythonDependency > for PythonDependency {
320- fn from ( dep : v2:: PythonDependency ) -> Self {
321- match dep {
322- v2:: PythonDependency :: Einops => PythonDependency :: Einops ,
323- v2:: PythonDependency :: NvidiaCutlassDsl => PythonDependency :: NvidiaCutlassDsl ,
324- }
325- }
326- }
327-
328419impl From < v2:: Torch > for Torch {
329420 fn from ( torch : v2:: Torch ) -> Self {
330421 Self {
0 commit comments