I'm trying to use float32 to reduce memory usage. What do I need to do to change the default dtype to float32? I tried calling `jax.config.update('jax_enable_x64', False)` at the beginning of my code, but it doesn't seem to work.
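One way to narrow this down is to check which dtype each array actually has. The sketch below is a minimal check under one assumption: that the float64 arrays come from NumPy inputs. The question doesn't say where they come from. With `jax_enable_x64` left at `False` (JAX's default), arrays created by `jax.numpy` are float32. NumPy arrays created with NumPy's own defaults stay float64 in host memory until they are cast or converted. The variable names are illustrative only.

```python
import jax

# Should run before any JAX arrays are created. False is already JAX's default,
# so this line only matters if x64 was enabled elsewhere (e.g. JAX_ENABLE_X64).
jax.config.update("jax_enable_x64", False)

import jax.numpy as jnp
import numpy as np

# Arrays created through jax.numpy use float32 when x64 is disabled.
x = jnp.ones(3)
print(x.dtype)  # float32

# NumPy's own default float dtype is float64, independent of the JAX flag.
y_np = np.ones(3)
print(y_np.dtype)  # float64

# Converting to a JAX array canonicalizes the dtype to float32 when x64 is off,
# but the original NumPy array still occupies float64 memory on the host.
y = jnp.asarray(y_np)
print(y.dtype)  # float32

# Casting on the NumPy side first avoids keeping a float64 copy around.
z = jnp.asarray(y_np.astype(np.float32))
print(z.dtype)  # float32
```

If `x.dtype` prints `float64` here, x64 mode was most likely enabled somewhere else before this code ran.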