Skip to content

Commit 1030815

Browse files
authored
refactor: change function _get_convol_img_fn for more clarity
1 parent dadb470 commit 1030815

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

ot/bregman/_convolutional.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,7 @@ def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
3131
t2 = nx.linspace(0, 1, height, type_as=type_as)
3232
Y2, X2 = nx.meshgrid(t2, t2)
3333
M2 = -((X2 - Y2) ** 2) / reg
34-
35-
# As M1 and M2 are computed first, we can use them to compute the convolution in log-domain
36-
def convol_imgs(log_imgs):
37-
log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
38-
log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
39-
return log_imgs
40-
34+
4135
# If normal domain is selected, we can use M1 and M2 to compute the convolution
4236
if not log_domain:
4337
K1, K2 = nx.exp(M1), nx.exp(M2)
@@ -47,6 +41,13 @@ def convol_imgs(imgs):
4741
kxy = nx.einsum("...ij,klj->kli", K2, kx)
4842
return kxy
4943

44+
# Else, we can use M1 and M2 to compute the convolution in log-domain
45+
else:
46+
def convol_imgs(log_imgs):
47+
log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
48+
log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
49+
return log_imgs
50+
5051
return convol_imgs
5152

5253

0 commit comments

Comments
 (0)