@@ -30,7 +30,7 @@ def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
3030 """
3131 dim = 3
3232 boundx , boundy , boundz = bound
33- oshape = g .shape [- dim - 1 :- 1 ]
33+ oshape = list ( g .shape [- dim - 1 :- 1 ])
3434 g = g .reshape ([g .shape [0 ], 1 , - 1 , dim ])
3535 gx , gy , gz = g .unbind (- 1 )
3636 batch = max (inp .shape [0 ], gx .shape [0 ])
@@ -47,7 +47,7 @@ def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
4747 gz , signz = get_indices (gz , nz , boundz )
4848
4949 # gather
50- inp = inp .reshape (inp .shape [:2 ] + [- 1 ])
50+ inp = inp .reshape (list ( inp .shape [:2 ]) + [- 1 ])
5151 idx = sub2ind_list ([gx , gy , gz ], shape )
5252 idx = idx .expand ([batch , channel , idx .shape [- 1 ]])
5353 out = inp .gather (- 1 , idx )
@@ -56,7 +56,7 @@ def pull3d(inp, g, bound: List[Bound], extrapolate: int = 1):
5656 out *= sign
5757 if mask is not None :
5858 out *= mask
59- out = out .reshape (out .shape [:2 ] + oshape )
59+ out = out .reshape (list ( out .shape [:2 ]) + oshape )
6060 return out
6161
6262
@@ -75,10 +75,10 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
7575 boundx , boundy , boundz = bound
7676 if inp .shape [- dim :] != g .shape [- dim - 1 :- 1 ]:
7777 raise ValueError ('Input and grid should have the same spatial shape' )
78- ishape = inp .shape [- dim :]
78+ ishape = list ( inp .shape [- dim :])
7979 g = g .reshape ([g .shape [0 ], 1 , - 1 , dim ])
8080 gx , gy , gz = torch .unbind (g , - 1 )
81- inp = inp .reshape (inp .shape [:2 ] + [- 1 ])
81+ inp = inp .reshape (list ( inp .shape [:2 ]) + [- 1 ])
8282 batch = max (inp .shape [0 ], gx .shape [0 ])
8383 channel = inp .shape [1 ]
8484
@@ -107,7 +107,7 @@ def push3d(inp, g, shape: Optional[List[int]], bound: List[Bound],
107107 inp *= mask
108108 out .scatter_add_ (- 1 , idx , inp )
109109
110- out = out .reshape (out .shape [:2 ] + shape )
110+ out = out .reshape (list ( out .shape [:2 ]) + shape )
111111 return out
112112
113113
@@ -127,12 +127,12 @@ def pull2d(inp, g, bound: List[Bound], extrapolate: int = 1):
127127 """
128128 dim = 2
129129 boundx , boundy = bound
130- oshape = g .shape [- dim - 1 :- 1 ]
130+ oshape = list ( g .shape [- dim - 1 :- 1 ])
131131 g = g .reshape ([g .shape [0 ], 1 , - 1 , dim ])
132132 gx , gy = g .unbind (- 1 )
133133 batch = max (inp .shape [0 ], gx .shape [0 ])
134134 channel = inp .shape [1 ]
135- shape = inp .shape [- dim :]
135+ shape = list ( inp .shape [- dim :])
136136 nx , ny = shape
137137
138138 # mask of inbounds voxels
@@ -171,10 +171,10 @@ def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
171171 boundx , boundy = bound
172172 if inp .shape [- dim :] != g .shape [- dim - 1 :- 1 ]:
173173 raise ValueError ('Input and grid should have the same spatial shape' )
174- ishape = inp .shape [- dim :]
174+ ishape = list ( inp .shape [- dim :])
175175 g = g .reshape ([g .shape [0 ], 1 , - 1 , dim ])
176176 gx , gy = torch .unbind (g , - 1 )
177- inp = inp .reshape (inp .shape [:2 ] + [- 1 ])
177+ inp = inp .reshape (list ( inp .shape [:2 ]) + [- 1 ])
178178 batch = max (inp .shape [0 ], gx .shape [0 ])
179179 channel = inp .shape [1 ]
180180
@@ -202,7 +202,7 @@ def push2d(inp, g, shape: Optional[List[int]], bound: List[Bound],
202202 inp *= mask
203203 out .scatter_add_ (- 1 , idx , inp )
204204
205- out = out .reshape (out .shape [:2 ] + shape )
205+ out = out .reshape (list ( out .shape [:2 ]) + shape )
206206 return out
207207
208208
@@ -222,12 +222,12 @@ def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
222222 """
223223 dim = 1
224224 boundx = bound [0 ]
225- oshape = g .shape [- dim - 1 :- 1 ]
225+ oshape = list ( g .shape [- dim - 1 :- 1 ])
226226 g = g .reshape ([g .shape [0 ], 1 , - 1 , dim ])
227227 gx = g .squeeze (- 1 )
228228 batch = max (inp .shape [0 ], gx .shape [0 ])
229229 channel = inp .shape [1 ]
230- shape = inp .shape [- dim :]
230+ shape = list ( inp .shape [- dim :])
231231 nx = shape [0 ]
232232
233233 # mask of inbounds voxels
@@ -237,7 +237,7 @@ def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
237237 gx , signx = get_indices (gx , nx , boundx )
238238
239239 # gather
240- inp = inp .reshape (inp .shape [:2 ] + [- 1 ])
240+ inp = inp .reshape (list ( inp .shape [:2 ]) + [- 1 ])
241241 idx = gx
242242 idx = idx .expand ([batch , channel , idx .shape [- 1 ]])
243243 out = inp .gather (- 1 , idx )
@@ -246,7 +246,7 @@ def pull1d(inp, g, bound: List[Bound], extrapolate: int = 1):
246246 out *= sign
247247 if mask is not None :
248248 out *= mask
249- out = out .reshape (out .shape [:2 ] + oshape )
249+ out = out .reshape (list ( out .shape [:2 ]) + oshape )
250250 return out
251251
252252
@@ -265,10 +265,10 @@ def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
265265 boundx = bound [0 ]
266266 if inp .shape [- dim :] != g .shape [- dim - 1 :- 1 ]:
267267 raise ValueError ('Input and grid should have the same spatial shape' )
268- ishape = inp .shape [- dim :]
268+ ishape = list ( inp .shape [- dim :])
269269 g = g .reshape ([g .shape [0 ], 1 , - 1 , dim ])
270270 gx = g .squeeze (- 1 )
271- inp = inp .reshape (inp .shape [:2 ] + [- 1 ])
271+ inp = inp .reshape (list ( inp .shape [:2 ]) + [- 1 ])
272272 batch = max (inp .shape [0 ], gx .shape [0 ])
273273 channel = inp .shape [1 ]
274274
@@ -295,7 +295,7 @@ def push1d(inp, g, shape: Optional[List[int]], bound: List[Bound],
295295 inp *= mask
296296 out .scatter_add_ (- 1 , idx , inp )
297297
298- out = out .reshape (out .shape [:2 ] + shape )
298+ out = out .reshape (list ( out .shape [:2 ]) + shape )
299299 return out
300300
301301
@@ -336,7 +336,7 @@ def pushgrad(inp, g, shape: Optional[List[int]], bound: List[Bound],
336336 dim = g .shape [- 1 ]
337337 if inp .shape [- dim - 1 :- 1 ] != g .shape [- dim - 1 :- 1 ]:
338338 raise ValueError ('Input and grid should have the same spatial shape' )
339- ishape = inp .shape [- dim - 1 :- 1 ]
339+ ishape = list ( inp .shape [- dim - 1 :- 1 ])
340340 batch = max (inp .shape [0 ], g .shape [0 ])
341341 channel = inp .shape [1 ]
342342
0 commit comments