Hi, I noticed that jit compilation fails when my model is passed as a parameter to a function, but not when it is just used from within the outer scope. Here is a minimal working example:
jit compiling works fine for If I remove (Additional context: I normally create a flax
Replies: 1 comment 1 reply
`jax.jit` won't work on a function with a `Module` argument type. We can demonstrate this by making a dummy function with the same argument signature:

We get an error:

One solution is to specify that the argument is static (check out the API documentation on `static_argnums` here) via:

But Alternatively, you could pass in a string to the

Flax also provides "lifted transformed" versions of JAX transformations, which would allow you to use But I think for your situation, one of the above solutions should suffice.