Skip to content

Commit 99fdbf7

Browse files
committed
Fix JAX vectorize.
1 parent 4919ea1 commit 99fdbf7

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

keras/src/backend/jax/numpy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,8 @@ def vstack(xs):
10241024

10251025

10261026
def vectorize(pyfunc, *, excluded=None, signature=None):
1027+
if excluded is None:
1028+
excluded = set()
10271029
return jnp.vectorize(pyfunc, excluded=excluded, signature=signature)
10281030

10291031

0 commit comments

Comments
 (0)