@@ -31,13 +31,7 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
3131 t2 = nx .linspace (0 , 1 , height , type_as = type_as )
3232 Y2 , X2 = nx .meshgrid (t2 , t2 )
3333 M2 = - ((X2 - Y2 ) ** 2 ) / reg
34-
35- # As M1 and M2 are computed first, we can use them to compute the convolution in log-domain
36- def convol_imgs (log_imgs ):
37- log_imgs = nx .logsumexp (M1 [:, :, None ] + log_imgs [None ], axis = 1 )
38- log_imgs = nx .logsumexp (M2 [:, :, None ] + log_imgs .T [None ], axis = 1 ).T
39- return log_imgs
40-
34+
4135 # If normal domain is selected, we can use M1 and M2 to compute the convolution
4236 if not log_domain :
4337 K1 , K2 = nx .exp (M1 ), nx .exp (M2 )
@@ -47,6 +41,13 @@ def convol_imgs(imgs):
4741 kxy = nx .einsum ("...ij,klj->kli" , K2 , kx )
4842 return kxy
4943
44+ # Else, we can use M1 and M2 to compute the convolution in log-domain
45+ else :
46+ def convol_imgs (log_imgs ):
47+ log_imgs = nx .logsumexp (M1 [:, :, None ] + log_imgs [None ], axis = 1 )
48+ log_imgs = nx .logsumexp (M2 [:, :, None ] + log_imgs .T [None ], axis = 1 ).T
49+ return log_imgs
50+
5051 return convol_imgs
5152
5253
0 commit comments