Inference function block with global variables consumes less GPU RAM and is faster than purely functional code without global usage #3252
TheSeriousProgrammer asked this question in Q&A
Hey @TheSeriousProgrammer, you should be very careful when using globals in JAX programs, as you can easily leak tracers. The problem is that this code

```python
@jax.jit
def infer(x):
    global output
    output = model.apply(params, x)
```

is compiling to a constant function that returns
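The sketch below illustrates the tracer leak described in the reply. It assumes a toy computation (`jnp.tanh(x) * 2.0`) as a hypothetical stand-in for `model.apply(params, x)`; the names `infer_global` and `infer_return` are illustrative only. Because `infer_global` has no outputs, the call returns `None`, and the global is assigned a tracer created during tracing rather than a concrete array. With no outputs, there is also nothing the compiled program is required to compute, which is one plausible reason such a benchmark looks so fast.

```python
import jax
import jax.numpy as jnp

output = None  # module-level global, mirroring the question's pattern


@jax.jit
def infer_global(x):
    global output
    # Hypothetical stand-in for model.apply(params, x)
    output = jnp.tanh(x) * 2.0  # assigned during tracing, so this is a tracer


result = infer_global(jnp.ones(3))
print(result)  # the jitted function returns nothing, so this is None
print(type(output).__name__)  # the name of a tracer type, not a concrete array


@jax.jit
def infer_return(x):
    # Returning the value makes it an output the compiled program must compute
    return jnp.tanh(x) * 2.0


print(infer_return(jnp.ones(3)))  # a concrete jax.Array with real values
```

Using the leaked `output` later in ordinary code typically raises an `UnexpectedTracerError`, whereas the returned array from `infer_return` can be used freely.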
System information
- OS: Linux 6.1.38-1-MANJARO #1 SMP PREEMPT_DYNAMIC Wed Jul 5 23:49:30 UTC 2023 x86_64 GNU/Linux
- Flax, jax, jaxlib versions (`pip show flax jax jaxlib`):
- GPU: NVIDIA GeForce RTX 3070
- CUDA version: 12.2
Problem you have encountered:
I have the following snippet from some free-time research:
Before going further with the model, I wanted to know its throughput.
To measure it, I used the code below:
The output was:
I was a bit skeptical, so I tried another snippet where the `infer` function does not return the output but stores it in a global variable:
And the output was:
The difference between the two benchmarking code blocks is that the first one returns the output, whereas the second one stores it in a global variable. Yet the measured speeds differ by a ridiculous margin: 36 fps versus 1228 fps, respectively. Am I doing something wrong here?
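A fair throughput measurement in JAX usually needs two things: excluding the first call, which includes compilation, and calling `block_until_ready()` on the result, since JAX dispatches work asynchronously. The sketch below assumes a hypothetical `apply_fn` with a single weight matrix in place of the question's model (whose code is not shown); `benchmark`, `n_iters`, and the array shapes are illustrative choices. It also passes `params` as an argument instead of reading a global, because arrays read from globals inside `jax.jit` are captured as constants when the function is traced.

```python
import time

import jax
import jax.numpy as jnp


def apply_fn(params, x):
    # Hypothetical stand-in for model.apply(params, x)
    return jnp.tanh(x @ params["w"])


@jax.jit
def infer(params, x):
    # Take params as an argument and return the result, so XLA must compute it
    return apply_fn(params, x)


def benchmark(fn, *args, n_iters=100):
    """Return calls per second, excluding compilation time."""
    fn(*args).block_until_ready()  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(n_iters):
        # JAX dispatches asynchronously; block so each call's work is timed
        fn(*args).block_until_ready()
    elapsed = time.perf_counter() - start
    return n_iters / elapsed


params = {"w": jax.random.normal(jax.random.PRNGKey(0), (512, 512))}
x = jnp.ones((8, 512))
print(f"{benchmark(infer, params, x):.1f} calls/s")
```

Each call here processes one batch, so multiply calls per second by the batch size to get frames per second.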