Hi team,

I'm trying to implement a decoder-only model with Flax. I used `flax.linen.scan` to build the repeated layers, and I want to use `flax.linen.while_loop` to loop over the decoding steps. However, I get this error: `flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed.` It is caused by using both `flax.linen.scan` and `flax.linen.while_loop`. If I change the loop to a plain Python while loop, the error goes away. How can I use both `flax.linen.scan` and `flax.linen.while_loop` for this purpose?

I've also checked the T5X implementation. T5X uses a simple while loop for the repeated layers (https://github.com/google-research/t5x/blob/36e5f02f87669e3c38a9699001a4a154b514a115/t5x/examples/decoder_only/network.py#LL197C9-L197C9). I thought scan should perform better than a simple while loop. What is the consideration here for T5X? Is it because T5X also uses `lax.while_loop` for decoding (https://github.com/google-research/t5x/blob/36e5f02f87669e3c38a9699001a4a154b514a115/t5x/decoding.py#LL670C5-L670C5), so it cannot use scan?

Thank you!