Device Management in Multi-GPU systems, v2 #182

Open
wants to merge 41 commits into main

Conversation

adenzler-nvidia
Collaborator

Replaces #130 after discussions about best practices.

All the API functions now use wp.ScopedDevice. This has the benefit of restoring the previously active context after the function returns, which avoids issues if you're doing other GPU work in between MJWarp simulations.

We use the device passed to put_model as the source of truth. So this does not enable any fancy multi-GPU execution; it only makes sure that you use the same GPU for everything MJWarp does.
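
For concreteness, a minimal sketch of the pattern (the device attribute on the returned model and the SimpleNamespace stand-ins are hypothetical, not the actual MJWarp types):

from types import SimpleNamespace

import warp as wp


def put_model(mjm, device=None):
    # Resolve the device once; it becomes the source of truth for everything else.
    device = wp.get_device(device)
    with wp.ScopedDevice(device):
        qpos0 = wp.array(mjm.qpos0, dtype=wp.float32)  # model arrays land on `device`
    return SimpleNamespace(qpos0=qpos0, device=device)  # stand-in for the real Model


def step(m, d):
    # Every API function pins its work to the model's device and restores
    # the previously active device/context when it returns.
    with wp.ScopedDevice(m.device):
        pass  # launch kernels here; they all run on m.device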

I opted for a test that checks that all the API functions have a ScopedDevice block; I couldn't figure out a better way to test this.
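
Roughly, such a check could look like the following sketch (the mujoco_warp import name and the exact assertion are assumptions, not necessarily the test in this PR):

import inspect

import mujoco_warp as mjwarp  # assumed import name


def test_api_functions_use_scoped_device():
    # Hypothetical check: every public top-level API function should contain
    # a `with wp.ScopedDevice(...)` block somewhere in its source.
    for name, fn in inspect.getmembers(mjwarp, inspect.isfunction):
        if name.startswith("_"):
            continue
        assert "wp.ScopedDevice" in inspect.getsource(fn), (
            f"{name} does not use wp.ScopedDevice"
        )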

Also, I'm really sorry for whoever has to review this, but basically all the changes are indentation changes.

@erikfrey
Collaborator

@adenzler-nvidia sorry to ask a really dumb question, but what is the use case here?

At first glance, I would have expected something like wp.ScopedDevice() to be a user concept, e.g. I would expect the user (not us) to do something like:

device = ...
m = io.put_model(..., device=device)
d = io.put_data(..., device=device)
with wp.ScopedDevice(device):
    mjwarp.step(m, d)

@adenzler-nvidia
Collaborator Author

Valid question - we could certainly push this entirely to the user. Our own experience as users tells us that it's very easy to forget, though, and strange things start happening once you have systems with multiple GPUs and other parts of your workflow also doing GPU work.

The trickiness is mostly in how Warp selects the default device. This default device can change if any other part of your system changes the currently bound CUDA context, and at that point you either pay the price for data migration/remote access or some things stop working entirely. We also need to make sure we're not messing with the currently bound context of other code on the system, by restoring the existing state once we're done.
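
As a small illustration of that save/restore behavior, here's a sketch assuming a machine with at least two CUDA devices ("cuda:1" is just an example):

import warp as wp

wp.init()
print(wp.get_device())             # whatever device is currently active, e.g. cuda:0

with wp.ScopedDevice("cuda:1"):    # assumes a second GPU is present
    a = wp.zeros(8, dtype=wp.float32)
    print(a.device)                # cuda:1 - allocations inside the block follow it

print(wp.get_device())             # the previously active device is restored on exit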

Given that MJWarp is likely going to be used in tandem with ML workloads, rendering, user code that also uses Warp, etc., this introduces a bit of safety for everyone.

We definitely want to reconsider this once we really think about multi-GPU with MJWarp. Right now this just follows the best practices we use in all of our other Warp code. That being said - I'm also not a fan of enforcing ScopedDevice on all interface functions, and the API could be nicer if you only had to specify the device for the model and it were inferred for the data. But that would mean passing the model to put_data.
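
A rough sketch of what that alternative signature could look like (hypothetical, not the current API; the SimpleNamespace is a stand-in for the real Data struct):

from types import SimpleNamespace

import warp as wp


def put_data(mjm, mjd, m):
    # Hypothetical variant: put_data takes the Model and inherits its device,
    # so the user only specifies the device once, in put_model.
    with wp.ScopedDevice(m.device):
        qpos = wp.array(mjd.qpos, dtype=wp.float32)  # allocated on m.device
    return SimpleNamespace(qpos=qpos, device=m.device)  # stand-in for the real Data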

Curious to hear opinions though.

@btaba
Collaborator

btaba commented Apr 25, 2025

@adenzler-nvidia what's holding things back from considering multi-GPU in the shorter term? Coming from JAX and some of our recent work in MuJoCo land, multi-GPU has been seamless to use and critically important for our research velocity.

@adenzler-nvidia
Collaborator Author

Nothing specific - happy to talk about multi-GPU. Interested in how you guys have been using it so far.

I see this PR as a first stepping stone to make sure we get the device right for a single GPU, but we can go for more immediately after that. The main point is to move from implicit device selection to something explicit.

@adenzler-nvidia
Collaborator Author

Plan after offline discussion with @erikfrey:

  • Let's abandon this PR and gather requirements for multi-GPU use cases immediately. The goal is not to implement it in the beta, but to get a head start on the changes/additions we need in Warp to make sure we're ready for this ASAP.
