Commit bab3b86
FIX: deprecated import + expose affine_grid (#21)

Fixes #19
Parent: a4d5f53

In short: the deprecated `torch.cuda.amp.custom_fwd`/`custom_bwd` imports are replaced by their `torch.amp` equivalents (with fallbacks to the legacy per-device decorators, or to a no-op decorator, on older PyTorch); `pull`, `push`, `count`, and `affine_grid` are added to the public `__all__` of `interpol.api`; and the CI test matrix gains Python 3.12 / PyTorch 2.4.
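For context, here is a minimal sketch of the decorator API change this commit adapts to, assuming PyTorch >= 2.4, where the `torch.cuda.amp` decorators emit deprecation warnings in favour of the `torch.amp` ones, which take an explicit `device_type`. The `Scale2` Function is hypothetical, for illustration only:

```python
import torch

class Scale2(torch.autograd.Function):
    """Toy autograd Function illustrating the old vs. new decorators."""

    @staticmethod
    # deprecated spelling: @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(ctx, x):
        # under CUDA autocast, x is cast to float32 before this runs
        return x * 2

    @staticmethod
    # deprecated spelling: @torch.cuda.amp.custom_bwd
    @torch.amp.custom_bwd(device_type='cuda')
    def backward(ctx, grad):
        return grad * 2
```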

4 files changed: +75, -29 lines
.github/workflows/test-and-publish.yaml (2 additions, 0 deletions)

```diff
@@ -50,6 +50,8 @@ jobs:
           pytorch-version: "1.11"
         - python-version: "3.11"
           pytorch-version: "2.0"
+        - python-version: "3.12"
+          pytorch-version: "2.4"
     steps:
       - uses: actions/checkout@v3
       - uses: ./.github/actions/test
```

interpol/__init__.py (4 additions, 4 deletions)

```diff
@@ -1,7 +1,7 @@
-from .api import *
-from .resize import *
-from .restrict import *
-from . import backend
+from .api import *  # noqa: F401, F403
+from .resize import *  # noqa: F401, F403
+from .restrict import *  # noqa: F401, F403
+from . import backend  # noqa: F401
 
 from . import _version
 __version__ = _version.get_versions()['version']
```

interpol/api.py (23 additions, 11 deletions)

```diff
@@ -1,8 +1,20 @@
 """High level interpolation API"""
 
-__all__ = ['grid_pull', 'grid_push', 'grid_count', 'grid_grad',
-           'spline_coeff', 'spline_coeff_nd',
-           'identity_grid', 'add_identity_grid', 'add_identity_grid_']
+__all__ = [
+    'pull',
+    'push',
+    'count',
+    'grid_pull',
+    'grid_push',
+    'grid_count',
+    'grid_grad',
+    'spline_coeff',
+    'spline_coeff_nd',
+    'identity_grid',
+    'add_identity_grid',
+    'add_identity_grid_',
+    'affine_grid',
+]
 
 import torch
 from .utils import expanded_shape, matvec
```
The file's remaining hunks (at old lines 44, 61, 143, and 290, in the `_doc_bound_coeff` docstring and the `grid_pull` and `grid_grad` docstrings) change no wording: each removed line is re-added character for character, so these hunks appear to strip trailing whitespace only.
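The headline change above is `affine_grid` joining `__all__`. Because `interpol/__init__.py` does `from .api import *`, this makes `affine_grid` importable from the package root. A usage sketch; the call convention shown (a homogeneous affine matrix plus a spatial output shape) is an assumption by analogy with the other grid helpers, not something this diff confirms:

```python
import torch
from interpol import affine_grid, grid_pull  # affine_grid is newly re-exported

mat = torch.eye(3)                  # 2D homogeneous affine matrix, (3, 3)
grid = affine_grid(mat, (32, 32))   # assumed: dense sampling grid, (32, 32, 2)

image = torch.rand(1, 1, 32, 32)       # (batch, channel, *spatial)
warped = grid_pull(image, grid[None])  # resample image at the grid locations
```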

interpol/autograd.py (46 additions, 14 deletions)

```diff
@@ -10,9 +10,41 @@
                      grid_grad, grid_grad_backward)
 from .utils import fake_decorator
 try:
-    from torch.cuda.amp import custom_fwd, custom_bwd
+    from torch.amp import custom_fwd, custom_bwd
 except (ModuleNotFoundError, ImportError):
-    custom_fwd = custom_bwd = fake_decorator
+    try:
+        from torch.cuda.amp import (
+            custom_fwd as _custom_fwd_cuda,
+            custom_bwd as _custom_bwd_cuda
+        )
+    except (ModuleNotFoundError, ImportError):
+        _custom_fwd_cuda = _custom_bwd_cuda = fake_decorator
+
+    try:
+        from torch.cpu.amp import (
+            custom_fwd as _custom_fwd_cpu,
+            custom_bwd as _custom_bwd_cpu
+        )
+    except (ModuleNotFoundError, ImportError):
+        _custom_fwd_cpu = _custom_bwd_cpu = fake_decorator
+
+    def custom_fwd(fwd=None, *, device_type, cast_inputs=None):
+        if device_type == 'cuda':
+            decorator = _custom_fwd_cuda(cast_inputs=cast_inputs)
+            return decorator(fwd) if fwd else decorator
+        if device_type == 'cpu':
+            decorator = _custom_fwd_cpu(cast_inputs=cast_inputs)
+            return decorator(fwd) if fwd else decorator
+        return fake_decorator(fwd) if fwd else fake_decorator
+
+    def custom_bwd(bwd=None, *, device_type):
+        if device_type == 'cuda':
+            decorator = _custom_bwd_cuda
+            return decorator(bwd) if bwd else decorator
+        if device_type == 'cpu':
+            decorator = _custom_bwd_cpu
+            return decorator(bwd) if bwd else decorator
+        return fake_decorator(bwd) if bwd else fake_decorator
 
 
 def make_list(x):
```
```diff
@@ -125,7 +157,7 @@ def inter_to_nitorch(inter, as_type='str'):
 class GridPull(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, input, grid, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -143,7 +175,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -155,7 +187,7 @@ def backward(ctx, grad):
 class GridPush(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -173,7 +205,7 @@ def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -185,7 +217,7 @@ def backward(ctx, grad):
 class GridCount(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, grid, shape, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -203,7 +235,7 @@ def forward(ctx, grid, shape, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -216,7 +248,7 @@ def backward(ctx, grad):
 class GridGrad(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, input, grid, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -234,7 +266,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -248,7 +280,7 @@ def backward(ctx, grad):
 class SplineCoeff(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type='cuda')
     def forward(ctx, input, bound, interpolation, dim, inplace):
 
         bound = bound_to_nitorch(make_list(bound)[0], as_type='int')
@@ -265,7 +297,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         # symmetric filter -> backward == forward
         # (I don't know if I can write into grad, so inplace=False to be safe)
@@ -276,7 +308,7 @@ def backward(ctx, grad):
 class SplineCoeffND(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type='cuda')
     def forward(ctx, input, bound, interpolation, dim, inplace):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -293,7 +325,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         # symmetric filter -> backward == forward
         # (I don't know if I can write into grad, so inplace=False to be safe)
```
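With the shim in place, the decorated Functions keep a single code path on every PyTorch version. A sketch of the net effect under autocast; the shapes and the `identity_grid(shape, device=...)` keyword are assumptions based on this package's API conventions, and a CUDA device is required:

```python
import torch
import interpol

image = torch.rand(1, 1, 32, 32, device='cuda')         # (B, C, H, W)
grid = interpol.identity_grid((32, 32), device='cuda')  # assumed (H, W, 2)

with torch.autocast('cuda', dtype=torch.float16):
    # cast_inputs=torch.float32 on GridPull.forward keeps the interpolation
    # in float32 even inside the float16 autocast region
    out = interpol.grid_pull(image, grid[None])
```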
