Can you be more specific about what is 5x slower? How are you measuring this? Have you isolated compilation time from runtime? (See https://docs.jax.dev/en/latest/benchmarking.html#benchmarking-jax-code for details.)
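The pattern the linked page describes can be sketched like this: time the first call (which includes tracing and XLA compilation) separately from warmed-up calls, and use `block_until_ready()` so JAX's asynchronous dispatch doesn't end the timing early. `step` here is an illustrative stand-in, not code from the repo.

```python
# Sketch: separate compilation time from steady-state runtime when
# benchmarking a jitted function. `step` is a hypothetical stand-in.
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


x = jnp.arange(1024.0)

# First call: traces the Python function and compiles it with XLA.
t0 = time.perf_counter()
step(x).block_until_ready()
first_call = time.perf_counter() - t0

# Warmed-up calls hit the compilation cache; block_until_ready() makes
# sure the asynchronous dispatch has finished before the clock stops.
t0 = time.perf_counter()
for _ in range(100):
    step(x).block_until_ready()
steady_state = (time.perf_counter() - t0) / 100

print(f"first call: {first_call:.4f}s, steady state: {steady_state:.6f}s")
```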

-
tl;dr
DDPG is a value-based RL algorithm that needs a replay buffer to store experiences (obs, action, reward, etc.), implemented as a circular buffer in my case. Every training step, one step of experience is pushed into the buffer. Due to the purely functional nature of jax.jit, a large replay buffer can make pushing experience expensive in both time and memory. So it is useful to donate the buffer state arrays to ensure an in-place modification, which saves a lot of runtime.
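As a minimal sketch of that pattern (with hypothetical names — `push`, `storage`, `cursor` — not the author's actual code): donating the big storage array lets XLA reuse its memory for the output instead of allocating a fresh array and copying. Note that donation is a hint; XLA may still copy if it cannot alias the buffer.

```python
# Sketch of a donated circular-buffer push; names are hypothetical.
from functools import partial

import jax
import jax.numpy as jnp

CAPACITY = 100_000  # large enough that a per-step copy would hurt
OBS_DIM = 8


# donate_argnums=0 donates the big storage array, allowing XLA to reuse
# its memory for the output instead of allocating a new one and copying.
@partial(jax.jit, donate_argnums=0)
def push(storage, cursor, obs):
    storage = storage.at[cursor].set(obs)  # write one experience
    cursor = (cursor + 1) % CAPACITY       # circular-buffer wraparound
    return storage, cursor


storage = jnp.zeros((CAPACITY, OBS_DIM))
cursor = jnp.array(0)
storage, cursor = push(storage, cursor, jnp.ones(OBS_DIM))
```

After the call, the old `storage` reference must not be reused — donation invalidates it — so the result is rebound to the same name.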
But I found that jitting the outermost function train_one_step doesn't save any runtime (almost the same); it seems the compiler neglects the donation decorators and decides to allocate a new buffer state array and copy the whole old buffer. I failed to validate this: I inspected the jaxpr, but it carries no information about whether large memory is copied. What improves the performance significantly (5x) is to split train_one_step into rollout_and_push and update_model. Specifically, donate the large buffer_state argument of the rollout_and_push function, which is the critical point for ensuring an under-the-hood in-place update of the large buffer. So why does jitting everything together degrade performance? My guess, from XLA's perspective:
code snippet
If you jit the outermost function train_one_step, training is 5x slower!
train_one_step jitted: 37.45s
rollout_and_push and update_model jitted separately: 7.8s
benchmark code:
function definition:
Complete code
https://github.com/zzhixin/jaxrl-learning/tree/b2a9d1ab64a413ed1cfcb03352f1c6c3a2b2be4f
Just run the benchmark_ddpg.py file.
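On the validation point above: the jaxpr indeed won't show whether a donated input was aliased, but the compiled HLO will. A hedged sketch using jit's ahead-of-time API (`lower` / `compile` / `as_text`); the exact annotation text varies across JAX/XLA versions and backends, and on some backends donation may simply be ignored with a warning.

```python
# Sketch: inspect the compiled HLO (not the jaxpr) to see whether a
# donated input was actually aliased to the output. `push` is a
# hypothetical stand-in, not the repo's code.
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=0)
def push(storage, obs):
    return storage.at[0].set(obs)


storage = jnp.zeros((4096, 8))
obs = jnp.ones(8)

compiled = push.lower(storage, obs).compile()
hlo_text = compiled.as_text()

# If donation took effect, the entry computation carries an
# input_output_alias annotation tying the donated input to the output.
print("input_output_alias" in hlo_text)
```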
Comments
Generally, one would expect performance no worse when jitting the outer function than when jitting only the inner functions separately. Unless I did something stupid, this case shows that jitting the outer function can sometimes lead to worse performance.
I suppose this is not a new discovery, but the performance drop still surprised me. In theory, such degradation from jitting the outer function strongly suggests the compiler could be made smarter once the failure mode is identified. In other words, if, for such cases, the compiler were smart enough to effectively jit the pieces separately even when one blindly jits the outer function, the user experience would be more consistent: one could follow the principle of "jit the outermost function as much as possible" with more confidence.
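For concreteness, the two arrangements being compared can be skeletonized as follows. The function bodies are placeholders, not the author's DDPG code; only the jit and donation boundaries matter here.

```python
# Hedged skeleton of the two jit arrangements; bodies are placeholders.
import jax
import jax.numpy as jnp

CAPACITY = 1024


def rollout_and_push(buffer_state, cursor, obs):
    buffer_state = buffer_state.at[cursor].set(obs)  # push one experience
    return buffer_state, (cursor + 1) % CAPACITY


def update_model(params, buffer_state):
    return params + 1e-3 * buffer_state.mean()  # stand-in for a real update


def train_one_step(params, buffer_state, cursor, obs):
    buffer_state, cursor = rollout_and_push(buffer_state, cursor, obs)
    params = update_model(params, buffer_state)
    return params, buffer_state, cursor


# Arrangement A (slow in the report): one outer jit, with the big
# buffer donated at the outer boundary.
train_outer = jax.jit(train_one_step, donate_argnums=1)

# Arrangement B (fast in the report): the two stages jitted separately,
# with the big buffer donated exactly where it is rewritten.
push_jit = jax.jit(rollout_and_push, donate_argnums=0)
update_jit = jax.jit(update_model)
```

Both arrangements compute the same values; the question in this thread is which one lets XLA actually honor the donation of buffer_state.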
Questions
system info
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: GenuineIntel
Model name: 11th Gen Intel(R) Core(TM) i7-11700KF @ 3.60GHz