File tree Expand file tree Collapse file tree 1 file changed +1
-10
lines changed
Expand file tree Collapse file tree 1 file changed +1
-10
lines changed Original file line number Diff line number Diff line change @@ -246,11 +246,6 @@ def _convolutional_barycenter2d_log(
246246 A = list_to_array (A )
247247
248248 nx = get_backend (A )
249- if nx .__name__ in ("jax" , "tf" ):
250- raise NotImplementedError (
251- "Log-domain functions are not yet implemented"
252- " for Jax and TF. Use numpy or torch arrays instead."
253- )
254249
255250 n_hists , width , height = A .shape
256251
@@ -483,11 +478,7 @@ def _convolutional_barycenter2d_debiased_log(
483478 A = list_to_array (A )
484479 n_hists , width , height = A .shape
485480 nx = get_backend (A )
486- if nx .__name__ in ("jax" , "tf" ):
487- raise NotImplementedError (
488- "Log-domain functions are not yet implemented"
489- " for Jax and TF. Use numpy or torch arrays instead."
490- )
481+
491482 if weights is None :
492483 weights = nx .ones ((n_hists ,), type_as = A ) / n_hists
493484 else :
You can’t perform that action at this time.
0 commit comments