[MJX] Best method for using mjx.ray in RL multi-environment training without running out of memory #2668
AlexS28 asked this question in Asking for Help (unanswered, 0 replies)
Intro
Hi!
I am a postdoc using MuJoCo to write research papers on manipulation.
My setup
NVIDIA RTX 4090 GPU
My question
I have pasted my code below: a ray function that wraps mjx.ray, and compute_camera_ray_hits_mjx, the function I call from the get-observation function of my MJX RL pipeline. The assumed camera resolution is low (40 × 40), and the number of environments is also modest (128). I flatten the resulting dists array and include it in the observation space. Even with a drastic reduction, such as a 2 × 2 resolution or only 2 environments, I still run out of memory, so I must have a mistake in my implementation logic. It only works with a single environment, which makes me suspect a vmap or jit issue. I also tried running without jit and vmap, but initialization took so long that I could not tell whether I would hit the same memory issue. Any help would be greatly appreciated. The warning XLA prints is:
2025-06-05 20:30:32.817821: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 16.64GiB (17871886491 bytes) by rematerialization; only reduced to 26.34TiB (28956816731264 bytes), down from 26.34TiB (28958311494848 bytes) originally
Minimal model and/or code that explain my question
Confirmations