Skip to content

Commit b6dc417

Browse files
committed
[FIX] ensure shapes are lists in TS code -- fixes things when PYTORCH_JIT=0
1 parent c90a40b commit b6dc417

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

nitorch/_C/_ts/iso0.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
3030
"""
3131
dim = 3
3232
boundx, boundy, boundz = bound
33-
oshape = g.shape[-dim-1:-1]
33+
oshape = list(g.shape[-dim-1:-1])
3434
g = g.reshape([g.shape[0], 1, -1, dim])
3535
gx, gy, gz = g.unbind(-1)
3636
batch = max(inp.shape[0], gx.shape[0])
@@ -47,7 +47,7 @@ def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
4747
gz, signz = get_indices(gz, nz, boundz)
4848

4949
# gather
50-
inp = inp.reshape(inp.shape[:2] + [-1])
50+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
5151
idx = sub2ind_list([gx, gy, gz], shape)
5252
idx = idx.expand([batch, channel, idx.shape[-1]])
5353
out = inp.gather(-1, idx)
@@ -56,7 +56,7 @@ def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
5656
out *= sign
5757
if mask is not None:
5858
out *= mask
59-
out = out.reshape(out.shape[:2] + oshape)
59+
out = out.reshape(list(out.shape[:2]) + oshape)
6060
return out
6161

6262

@@ -75,10 +75,10 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
7575
boundx, boundy, boundz = bound
7676
if inp.shape[-dim:] != g.shape[-dim-1:-1]:
7777
raise ValueError('Input and grid should have the same spatial shape')
78-
ishape = inp.shape[-dim:]
78+
ishape = list(inp.shape[-dim:])
7979
g = g.reshape([g.shape[0], 1, -1, dim])
8080
gx, gy, gz = torch.unbind(g, -1)
81-
inp = inp.reshape(inp.shape[:2] + [-1])
81+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
8282
batch = max(inp.shape[0], gx.shape[0])
8383
channel = inp.shape[1]
8484

@@ -107,7 +107,7 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
107107
inp *= mask
108108
out.scatter_add_(-1, idx, inp)
109109

110-
out = out.reshape(out.shape[:2] + shape)
110+
out = out.reshape(list(out.shape[:2]) + shape)
111111
return out
112112

113113

@@ -127,12 +127,12 @@ def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
127127
"""
128128
dim = 2
129129
boundx, boundy = bound
130-
oshape = g.shape[-dim-1:-1]
130+
oshape = list(g.shape[-dim-1:-1])
131131
g = g.reshape([g.shape[0], 1, -1, dim])
132132
gx, gy = g.unbind(-1)
133133
batch = max(inp.shape[0], gx.shape[0])
134134
channel = inp.shape[1]
135-
shape = inp.shape[-dim:]
135+
shape = list(inp.shape[-dim:])
136136
nx, ny = shape
137137

138138
# mask of inbounds voxels
@@ -171,10 +171,10 @@ def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
171171
boundx, boundy = bound
172172
if inp.shape[-dim:] != g.shape[-dim-1:-1]:
173173
raise ValueError('Input and grid should have the same spatial shape')
174-
ishape = inp.shape[-dim:]
174+
ishape = list(inp.shape[-dim:])
175175
g = g.reshape([g.shape[0], 1, -1, dim])
176176
gx, gy = torch.unbind(g, -1)
177-
inp = inp.reshape(inp.shape[:2] + [-1])
177+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
178178
batch = max(inp.shape[0], gx.shape[0])
179179
channel = inp.shape[1]
180180

@@ -202,7 +202,7 @@ def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
202202
inp *= mask
203203
out.scatter_add_(-1, idx, inp)
204204

205-
out = out.reshape(out.shape[:2] + shape)
205+
out = out.reshape(list(out.shape[:2]) + shape)
206206
return out
207207

208208

@@ -222,12 +222,12 @@ def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
222222
"""
223223
dim = 1
224224
boundx = bound[0]
225-
oshape = g.shape[-dim-1:-1]
225+
oshape = list(g.shape[-dim-1:-1])
226226
g = g.reshape([g.shape[0], 1, -1, dim])
227227
gx = g.squeeze(-1)
228228
batch = max(inp.shape[0], gx.shape[0])
229229
channel = inp.shape[1]
230-
shape = inp.shape[-dim:]
230+
shape = list(inp.shape[-dim:])
231231
nx = shape[0]
232232

233233
# mask of inbounds voxels
@@ -237,7 +237,7 @@ def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
237237
gx, signx = get_indices(gx, nx, boundx)
238238

239239
# gather
240-
inp = inp.reshape(inp.shape[:2] + [-1])
240+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
241241
idx = gx
242242
idx = idx.expand([batch, channel, idx.shape[-1]])
243243
out = inp.gather(-1, idx)
@@ -246,7 +246,7 @@ def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
246246
out *= sign
247247
if mask is not None:
248248
out *= mask
249-
out = out.reshape(out.shape[:2] + oshape)
249+
out = out.reshape(list(out.shape[:2]) + oshape)
250250
return out
251251

252252

@@ -265,10 +265,10 @@ def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
265265
boundx = bound[0]
266266
if inp.shape[-dim:] != g.shape[-dim-1:-1]:
267267
raise ValueError('Input and grid should have the same spatial shape')
268-
ishape = inp.shape[-dim:]
268+
ishape = list(inp.shape[-dim:])
269269
g = g.reshape([g.shape[0], 1, -1, dim])
270270
gx = g.squeeze(-1)
271-
inp = inp.reshape(inp.shape[:2] + [-1])
271+
inp = inp.reshape(list(inp.shape[:2]) + [-1])
272272
batch = max(inp.shape[0], gx.shape[0])
273273
channel = inp.shape[1]
274274

@@ -295,7 +295,7 @@ def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
295295
inp *= mask
296296
out.scatter_add_(-1, idx, inp)
297297

298-
out = out.reshape(out.shape[:2] + shape)
298+
out = out.reshape(list(out.shape[:2]) + shape)
299299
return out
300300

301301

@@ -336,7 +336,7 @@ def pushgrad(inp, g, shape: Optional[List[int]], bound: List[Bound],
336336
dim = g.shape[-1]
337337
if inp.shape[-dim-1:-1] != g.shape[-dim-1:-1]:
338338
raise ValueError('Input and grid should have the same spatial shape')
339-
ishape = inp.shape[-dim-1:-1]
339+
ishape = list(inp.shape[-dim-1:-1])
340340
batch = max(inp.shape[0], g.shape[0])
341341
channel = inp.shape[1]
342342

nitorch/_C/_ts/nd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,4 +440,4 @@ def hess(inp, grid, bound: List[Bound], spline: List[Spline],
440440
out.unbind(-1)[d2].unbind(-1)[d].copy_(out.unbind(-1)[d].unbind(-1)[d2])
441441

442442
out = out.reshape(list(out.shape[:2]) + oshape + list(out.shape[-2:]))
443-
return out
443+
return out

0 commit comments

Comments
 (0)