Skip to content

Commit e8d4de5

Browse files
author
Francisco Muñoz
committed
feat: delete not implemented error in convolutional module
1 parent 80d9542 commit e8d4de5

File tree

1 file changed

+1
-10
lines changed

1 file changed

+1
-10
lines changed

ot/bregman/_convolutional.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,6 @@ def _convolutional_barycenter2d_log(
246246
A = list_to_array(A)
247247

248248
nx = get_backend(A)
249-
if nx.__name__ in ("jax", "tf"):
250-
raise NotImplementedError(
251-
"Log-domain functions are not yet implemented"
252-
" for Jax and TF. Use numpy or torch arrays instead."
253-
)
254249

255250
n_hists, width, height = A.shape
256251

@@ -483,11 +478,7 @@ def _convolutional_barycenter2d_debiased_log(
483478
A = list_to_array(A)
484479
n_hists, width, height = A.shape
485480
nx = get_backend(A)
486-
if nx.__name__ in ("jax", "tf"):
487-
raise NotImplementedError(
488-
"Log-domain functions are not yet implemented"
489-
" for Jax and TF. Use numpy or torch arrays instead."
490-
)
481+
491482
if weights is None:
492483
weights = nx.ones((n_hists,), type_as=A) / n_hists
493484
else:

0 commit comments

Comments
 (0)