                        grid_grad, grid_grad_backward)
 from .utils import fake_decorator
 try:
-    from torch.cuda.amp import custom_fwd, custom_bwd
+    from torch.amp import custom_fwd, custom_bwd
 except (ModuleNotFoundError, ImportError):
-    custom_fwd = custom_bwd = fake_decorator
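+    # torch.amp.custom_fwd/custom_bwd are unavailable (older PyTorch):
+    # rebuild an equivalent, device_type-aware API from the legacy decorators.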
+    try:
+        from torch.cuda.amp import (
+            custom_fwd as _custom_fwd_cuda,
+            custom_bwd as _custom_bwd_cuda
+        )
+    except (ModuleNotFoundError, ImportError):
+        _custom_fwd_cuda = _custom_bwd_cuda = fake_decorator
+
+    try:
+        from torch.cpu.amp import (
+            custom_fwd as _custom_fwd_cpu,
+            custom_bwd as _custom_bwd_cpu
+        )
+    except (ModuleNotFoundError, ImportError):
+        _custom_fwd_cpu = _custom_bwd_cpu = fake_decorator
+
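+    # Drop-in replacement mirroring the torch.amp signature:
+    # dispatch on device_type, falling back to fake_decorator.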
+    def custom_fwd(fwd=None, *, device_type, cast_inputs=None):
+        if device_type == 'cuda':
+            decorator = _custom_fwd_cuda(cast_inputs=cast_inputs)
+            return decorator(fwd) if fwd else decorator
+        if device_type == 'cpu':
+            decorator = _custom_fwd_cpu(cast_inputs=cast_inputs)
+            return decorator(fwd) if fwd else decorator
+        return fake_decorator(fwd) if fwd else fake_decorator
+
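+    # Same dispatch for the backward decorator (the legacy API takes no arguments).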
+    def custom_bwd(bwd=None, *, device_type):
+        if device_type == 'cuda':
+            decorator = _custom_bwd_cuda
+            return decorator(bwd) if bwd else decorator
+        if device_type == 'cpu':
+            decorator = _custom_bwd_cpu
+            return decorator(bwd) if bwd else decorator
+        return fake_decorator(bwd) if bwd else fake_decorator
 
 
 def make_list(x):
@@ -125,7 +157,7 @@ def inter_to_nitorch(inter, as_type='str'):
 class GridPull(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, input, grid, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -143,7 +175,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -155,7 +187,7 @@ def backward(ctx, grad):
 class GridPush(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -173,7 +205,7 @@ def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -185,7 +217,7 @@ def backward(ctx, grad):
 class GridCount(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, grid, shape, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -203,7 +235,7 @@ def forward(ctx, grid, shape, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -216,7 +248,7 @@ def backward(ctx, grad):
 class GridGrad(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, input, grid, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -234,7 +266,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -248,7 +280,7 @@ def backward(ctx, grad):
 class SplineCoeff(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type='cuda')
     def forward(ctx, input, bound, interpolation, dim, inplace):
 
         bound = bound_to_nitorch(make_list(bound)[0], as_type='int')
@@ -265,7 +297,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         # symmetric filter -> backward == forward
         # (I don't know if I can write into grad, so inplace=False to be safe)
@@ -276,7 +308,7 @@ def backward(ctx, grad):
 class SplineCoeffND(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type='cuda')
     def forward(ctx, input, bound, interpolation, dim, inplace):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -293,7 +325,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         # symmetric filter -> backward == forward
         # (I don't know if I can write into grad, so inplace=False to be safe)