Jax transforms and modules cannot be mixed. #968

Answered by jheek
yang-song asked this question in General

We will add full-fledged cond support that would allow module-in-cond.

For your particular case, though, I would advise not using cond at all.
Cond has a high overhead on accelerators, GPUs in particular, which is also why it hasn't been implemented yet. It is rarely a good idea to use it inside a model: the condition needs to be evaluated on the host, and the next op can only be scheduled afterward, leaving the GPU idle in the meantime.

jnp.where will get you better performance here. It executes both branches and picks the correct result, which avoids a sync between GPU and host and also leaves XLA more room for optimization.
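A minimal sketch of the jnp.where approach, for illustration only. The two branch functions and the names select_with_where and select_with_cond are hypothetical stand-ins, not code from the original question; lax.cond is shown next to it only to check that both forms give the same result.

```python
import jax
import jax.numpy as jnp


def branch_true(x):
    # Hypothetical branch taken when the predicate is True.
    return 2.0 * x


def branch_false(x):
    # Hypothetical branch taken when the predicate is False.
    return x + 1.0


@jax.jit
def select_with_where(pred, x):
    # Both branches are computed; jnp.where then picks the result
    # that matches pred, so no data-dependent control flow is needed.
    return jnp.where(pred, branch_true(x), branch_false(x))


@jax.jit
def select_with_cond(pred, x):
    # Same selection expressed with lax.cond, for comparison.
    return jax.lax.cond(pred, branch_true, branch_false, x)


x = jnp.arange(4.0)
out_where = select_with_where(True, x)
out_cond = select_with_cond(True, x)
```

Because both branches are evaluated, this only fits cases where each branch is cheap and safe to compute for every input, and where both return values of the same shape and dtype.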

Alternatively, you could use a boolean first_row, first_col and use jit(s…

Replies: 2 comments 3 replies

jheek Feb 1, 2021
Maintainer

Answer selected by yang-song