Skip to content

Commit c543307

Browse files
committed
[FIX] torchscript pushpull (backport from torch-interpol)
1 parent 8279577 commit c543307

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

nitorch/_C/_ts/utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -312,16 +312,27 @@ def inbounds_mask_1d(extrapolate: int, gx, nx: int) -> Optional[Tensor]:
312312
return mask
313313

314314

315+
@torch.jit.script
316+
def list_prod_tensor(x: List[Tensor]) -> Tensor:
317+
if len(x) == 0:
318+
empty: List[int] = []
319+
return torch.ones(empty)
320+
x0 = x[0]
321+
for x1 in x[1:]:
322+
x0 = x0 * x1
323+
return x0
324+
325+
315326
@torch.jit.script
316327
def make_sign(sign: List[Optional[Tensor]]) -> Optional[Tensor]:
317-
osign: Optional[Tensor] = None
328+
is_none: List[bool] = [s is None for s in sign]
329+
if list_all(is_none):
330+
return None
331+
filt_sign: List[Tensor] = []
318332
for s in sign:
319333
if s is not None:
320-
if osign is None:
321-
osign = s
322-
else:
323-
osign = osign * s
324-
return osign
334+
filt_sign.append(s)
335+
return list_prod_tensor(filt_sign)
325336

326337

327338
@torch.jit.script

0 commit comments

Comments
 (0)