Device Management in Multi-GPU systems, v2 #182
Conversation
@adenzler-nvidia sorry to ask a really dumb question, but what is the use case here? At first glance, I would have expected something like wp.ScopedDevice() to be a user concept, e.g. I would expect the user (not us) to do something like:

```python
device = ...
m = io.put_model(..., device=device)
d = io.put_data(..., device=device)
with wp.ScopedDevice(device):
    mjwarp.step(m, d)
```
Valid question - we could push this entirely to the user for sure. Our experience as users ourselves tells us that it's very easy to forget this, though, and random weird stuff happens once you have systems with multiple GPUs and other parts of your workflow also doing GPU work.

The trickiness is mostly about how Warp selects the default device. This default device can change if any other part of your system changes the currently bound CUDA context, and at that point you either pay the price for data migration/remote access, or some stuff even stops working. We also need to make sure we're not messing with the currently bound context of other users on the system, by restoring the existing state once we're done. Given that we are likely going to be using MJWarp in tandem with ML workloads, rendering, user code that also uses Warp, etc., this introduces a bit of safety for everyone.

We definitely want to reconsider this once we really think about multi-GPU with MJWarp. Right now this just follows best practices we use in all of our other Warp code.

That being said - I'm also not a fan of enforcing ScopedDevice on all interface functions, and the API could be better by only having to specify the device for the model, and inferring it for data. But that would mean passing the model to put_data. Curious to hear opinions, though.
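The "specify the device for the model, and infer it for data" idea could be sketched roughly as below. Note that `Model` and `put_data` here are illustrative stand-ins, not the actual MJWarp signatures; the point is only that the model carries its device, so the caller never has to repeat it:

```python
# Hypothetical sketch: the model remembers which device it was placed on,
# and put_data reuses that device instead of taking its own device argument.

class Model:
    """Stand-in for an MJWarp model; real models hold Warp arrays."""
    def __init__(self, device):
        self.device = device  # device the model's arrays live on


def put_data(mjd, model):
    # Infer the device from the model - the single source of truth.
    device = model.device
    # ... allocate data arrays on `device` here ...
    return {"device": device}


m = Model(device="cuda:0")
d = put_data(None, m)
print(d["device"])  # cuda:0
```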
@adenzler-nvidia what's holding things back from considering multi-GPU in the shorter term? Coming from JAX and some of our recent work in MuJoCo land, multi-GPU has been seamless to use and critically important for our research velocity.
Nothing specific - happy to talk about multi-GPU. Interested in how you guys have been using it so far. I'm seeing this PR as a first stepping stone to make sure we're getting the device right for 1 GPU, but we can go for more immediately after that. The main point is to go from an implicit device selection to something explicit. |
Plan after offline discussion with @erikfrey:
Replaces #130 after discussions about best practices.
All the API functions now use wp.ScopedDevice. This has the benefit of restoring the previously active context after the function returns, which avoids issues if you're doing other work on your GPU in between MJWarp simulations.
We use the device used in put_model as the source of truth. So this does not enable any fancy multi-GPU running; it only makes sure that you use the same GPU for everything in MJWarp.
I opted for a test that checks that all the API functions have a ScopedDevice block; I didn't find a better way to test this.
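One hypothetical way such a check could work, without parsing source files, is to look for a `ScopedDevice` reference in each API function's compiled code object (`co_names` lists the global and attribute names a function's bytecode references). The `wp` class below is a fake stand-in for the warp module so the sketch runs anywhere; this is not the actual MJWarp test:

```python
# Sketch: detect whether a function's body references ScopedDevice by
# inspecting its code object, rather than its source text.

class _FakeScope:
    """Stand-in for wp.ScopedDevice (assumed name, matching warp's API)."""
    def __init__(self, device):
        self.device = device
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        return False


class wp:  # fake warp module for the sketch
    ScopedDevice = _FakeScope


def step(m, d):
    with wp.ScopedDevice("cuda:0"):
        pass  # kernel launches would go here


def bad_step(m, d):
    pass  # forgot the ScopedDevice block


def uses_scoped_device(fn):
    # co_names includes attribute names like "ScopedDevice" if referenced.
    return "ScopedDevice" in fn.__code__.co_names


assert uses_scoped_device(step)
assert not uses_scoped_device(bad_step)
```

A test could then iterate over the public functions of the io/forward modules and assert this predicate for each, flagging any API entry point that forgot the scope.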
Also, I'm really sorry for whoever has to review this, but basically all the changes are indentation changes.