Why does jitting fail when passing model as argument? #3520

Answered by chiamp
Rasmuskh asked this question in Q&A
`jax.jit` won't work on a function that takes a Flax `Module` as one of its (non-static) arguments. We can demonstrate this with a dummy function that has the same argument signature:

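```python
import jax
import flax.linen

@jax.jit
def infer_v2(params, m, x):
    return  # dummy body: the failure happens before the body matters

key1, key2 = jax.random.PRNGKey(102), jax.random.PRNGKey(774)
x = jax.random.normal(key1, shape=(1, 3), dtype=jax.numpy.float32)
model = flax.linen.Dense(2)
params = model.init(key2, x)['params']

y2 = infer_v2(params, model, x)  # fails: `model` is not a valid JAX type
```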

We get an error:

<class 'flax.linen.linear.Dense'> is not a valid JAX type.

One solution is to mark that argument as static via the `static_argnums` parameter of `jax.jit` (see the `jax.jit` API documentation on `static_argnums`):

```python
def infer_v2(params, m, x):
    y = m.apply({'params': params…
```

Answer selected by Rasmuskh