
Model batching #109


Conversation

@adenzler-nvidia

Sharing a WIP for model batching. Closes #63.

This currently contains the following changes:

  • move the nworld parameter to the Model
  • add an "expand_fields" set to put_model so we know which arrays to tile
  • all other arrays get stride 0 in the first dimension, i.e. they are constant across all worlds

The set of arrays that can be made per-world is very much a best guess; I tried not to include anything topology-relevant, but I might have missed some.
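
For illustration, here is the stride-0 idea in plain NumPy (not the actual put_model code; names are hypothetical): broadcasting to a leading world dimension gives a zero-copy view whose first-dimension stride is 0, so every world reads the same underlying data.

```python
import numpy as np

nworld = 4
geom_margin = np.array([0.0, 0.01, 0.0])  # per-geom data, shared by all worlds

# Broadcasting to a leading world dimension gives a zero-copy view whose
# first-dimension stride is 0: every world reads the same underlying memory.
batched = np.broadcast_to(geom_margin, (nworld, geom_margin.shape[0]))
print(batched.strides)  # (0, 8) - stride 0 in the world dimension
print(batched[2, 1])    # 0.01, identical for every worldid
```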

@erikfrey commented Apr 5, 2025

Whew this is gonna be a monster of a PR :-) thank you so much for taking this on.

Two design considerations:

  1. Model is pretty heavy - in JAX we deal with this by letting the user explicitly choose which arrays to expand, and the others are left unbatched. See here, for example, where we produce a Model with only 5 batched fields:

https://github.com/google-deepmind/mujoco_playground/blob/main/mujoco_playground/_src/locomotion/g1/randomize.py#L92

I think this one is important.

  2. Less important than 1, but: could we allow a Model batch size different from nworld? I'd keep those two concepts separate. Let's say nworld is 4 and nmodel (or whatever we call it) is 2; then:

| Model Id | Data Id |
|----------|---------|
| 0        | 0       |
| 1        | 1       |
| 0        | 2       |
| 1        | 3       |

WDYT of these two design factors?

@erikfrey commented Apr 6, 2025

Oh! I just looked at the code and I see you have (1) handled already, very cool. So it's more about (2) - let me know what you think.

@adenzler-nvidia

Yeah, I think we got (1) covered nicely with the stride-0 arrays - still a few wonky things that I need to iron out, but I think the approach works.

For (2) - I didn't think of that, but it's certainly possible. Do you have a JAX example handy showing how you usually do this? I can see it being a bit weird for kernel writers, as there might be two different batching indices at that point, which might not be obvious from the get-go. I would prefer to avoid an indirect lookup (like `modelid = d.modelid[worldid]`) in every kernel, but if we can get the model index from the world index with a calculation, it should be fine. Depends a bit on the requirements here.

@eric-heiden

Can we assume nworld is always a multiple of nmodel, i.e. there is always a constant number N >= 1 of states per model? In that case we could just have `modelid = worldid % nmodel`.

@erikfrey commented Apr 7, 2025

@eric-heiden Sure, I think we could, but wouldn't `modelid = worldid % nmodel` work even if nworld is not a perfect multiple of nmodel?

@adenzler-nvidia One scenario I'm thinking of, which I think we'll want quite soon, is domain randomizing the objects in a scene, e.g. if we want to train a "grasp anything" type policy - actually forcing the user to have 8k objects on the model (or whatever nworld happens to be) might be prohibitive, not just for the user to supply but also for what we can populate on device.

Generally speaking, would it be safe to query the shape of the model array in question for the item to retrieve? So something like:

```python
marginid = worldid % m.geom_margin.shape[0]
margin = m.geom_margin[marginid, ...]
```
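
For concreteness, a hedged sketch of how that shape query could look inside a Warp kernel - the kernel and field names are illustrative, not the actual mujoco_warp code. Warp arrays expose their shape inside kernels, so the wrap-around needs no extra parameter:

```python
import warp as wp

@wp.kernel
def read_margin(
    geom_margin: wp.array2d(dtype=float),  # first dim: this field's batch size
    out: wp.array2d(dtype=float),          # (nworld, ngeom)
):
    worldid, geomid = wp.tid()
    # Wrap the world index into whatever batch size this field was given;
    # a field with first dim 1 is then simply constant across all worlds.
    modelid = worldid % geom_margin.shape[0]
    out[worldid, geomid] = geom_margin[modelid, geomid]

# launched over all worlds and geoms, e.g.:
# wp.launch(read_margin, dim=(nworld, ngeom), inputs=[margins, out])
```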

@adenzler-nvidia

There is an obvious need for running heterogeneous environments, for sure. I don't think it makes a lot of sense to implement that in this PR, as we still need to figure out how best to do it with an eye on performance. It also depends a lot on what exactly is randomized - if all the objects have the same tree topology, that is a different thing than suddenly having a different tree for each world. For example, just changing the collision geometry of a free-jointed object is going to be easy, but having two different robots is a completely different beast.

We need to think about that not only in terms of worlds, but rather in terms of which axis we should parallelize over. If we go fully heterogeneous, I think there needs to be a compilation step that reorders some of the subtrees so we can extract as much parallelism as possible. The tricky part at that point is how we remap all the parameters, whether that is an object id, a world id, or a model id. And then we need to think about de-duplication - how to make sure we're limiting memory usage.

I think we can separate API- and implementation-level concerns here a bit, but we need to be clear on the requirements to avoid driving ourselves into a corner. On top of that, we need to come up with something that still makes it possible to stay sane while developing and debugging the engine.

@erikfrey commented Apr 8, 2025

Oh, definitely agree we should not implement it now. My suggestion is exactly about avoiding driving ourselves into a corner, as you put it - if we explicitly tie nworld to both Model and Data, we may have to undo it later, possibly leading to wailing and gnashing of teeth among our users.

That's why I'm suggesting the option that seems to impose the fewest API assumptions we may have to undo later, which is to just query the array shape for the Model field in question and use that - is there somewhere that might bite us?

@adenzler-nvidia

Makes sense - I'll test drive the shape lookup today.

So in the end you would avoid having any world/model id parameter on the Model side altogether, and just allow any size in the first dimension? I'll have to test-drive how well this works from an API point of view.

My biggest concern is that it might become unnecessarily hard for us developers to figure out which part of the model has which size, etc. If we could somehow get the model arrays to wrap around, that would be cool.

@erikfrey commented Apr 9, 2025

> So in the end you would avoid having any world/model id parameter on the Model side altogether, and just allow any size in the first dimension? I'll have to test-drive how well this works from an API point of view.

Just for now until there's more clarity on what parameters we'll want that encompass all our model batching use cases.

@adenzler-nvidia

Went ahead and implemented the modulo approach here: https://github.com/adenzler-nvidia/mujoco_warp/tree/dev/adenzler/modulo-experiments

Roughly, what I did is have all the "expandable" model fields use shape 1 and stride 0 in the first dim, which is equivalent to the proposal above. But as a user you can then replace that with an array of any size in the first dim to get more or less arbitrary model->world mappings, and the kernels just do the modulo calculation (sketched below).
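
A minimal sketch of the user-facing side of that branch, assuming the defaults described above (names and sizes illustrative, not the branch's actual API):

```python
import numpy as np
import warp as wp

nmodel, ngeom = 2, 3  # illustrative sizes

# An expandable field starts out with shape (1, ngeom) and stride 0 in the
# first dim (shared by all worlds). To batch it, swap in an array with any
# first-dim size; kernels then wrap with worldid % nmodel as sketched above.
margins = np.random.uniform(0.0, 0.01, size=(nmodel, ngeom)).astype(np.float32)
geom_margin = wp.array(margins)  # shape (nmodel, ngeom) on device
```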

Interestingly, the modulo approach is quite a bit (~2%) faster on the humanoid. What is even more confusing is that the change in runtime is in the solver, which isn't even touched by any of the changes here. I would have expected the opposite, given that we do more computation during array indexing, which happens often and can be one of those silent performance killers. So right now my gut feeling is that I have a bug somewhere else that makes the solver terminate earlier or otherwise do less work. Need to figure that out.

That being said, I think the approach can work. My main reservations are:

  • it's going to be hard for devs to figure out when to do the modulo indexing - I'm pretty sure we're going to mess this up all the time.
  • it's not having a big effect on perf right now, but we're bottlenecked by, and suffering the consequences of, launching tons of tiny kernels with almost-empty threads. I'm a bit worried that a change like this is something we cannot reverse later if it starts becoming an issue, because it's part of the API. We also likely won't ever know for sure that it is an issue, because it won't show up as one big block in a profile, but rather as small inefficiencies scattered all over the place.

I'm a bit out of ideas but will try to explore some more. What's clear to me is that we need to find ways of not paying the price for the complexities of all kinds of batching if you don't need that complexity in the first place. Maybe some clever use of `wp.static` is important here.

@erikfrey

I hear you that these changes introduce more opportunities for bugs. Maybe between this PR and #148 it's worth thinking through some helper interfaces to minimize the surface area and verbosity. What do you think?

@adenzler-nvidia

Reading through this again, I think it's time to make a decision. It seems clear that we do not want modelid to be tied to worldid, which makes a lot of sense.

So the remaining question is whether to allow different sizes for different model parameters. I'm personally torn on that one - I like the idea of having one modelid that we can calculate upfront and then use for all model fields; it's simple. On the other hand, I can see the benefits of having different sizes, but I'm worried about the price of looking up shapes all the time.

I think I can make both approaches work, though, with some helpers. The goal should be to make the easy use case (nmodel == nworld) usable without too many restrictions.

WDYT?

@btaba commented Apr 25, 2025

@adenzler-nvidia pointed me to take a look at this PR. Here are my high-level thoughts:

  1. I strongly suspect `expand_fields` will not play nicely with our JAX workflow, but I could be wrong (I'd need to go through the motions). The modulo PR seems like the better approach.
  2. If there's a concern about maintainability with the modulo approach, we can override `__getitem__` on model fields (see the sketch below).

If there isn't a big rush to get this PR in, I would wait for a real use case to battle-test the implementation (i.e. JAX interop with MJX and domain randomization hooked up, which is more or less imminent).
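
For the maintainability point in 2, a minimal host-side sketch of such an override (hypothetical helper, plain Python; Warp kernels would still need the explicit modulo):

```python
class BatchedField:
    """Hypothetical wrapper hiding the modulo lookup behind plain indexing."""

    def __init__(self, arr):
        self._arr = arr  # first dim is the model batch, any size >= 1

    def __getitem__(self, idx):
        worldid, *rest = idx if isinstance(idx, tuple) else (idx,)
        # Wrap the world index into this field's actual batch size.
        return self._arr[(worldid % self._arr.shape[0], *rest)]

# e.g. BatchedField(margins)[worldid, geomid] behaves like a size-nworld
# array regardless of the field's real first-dim size.
```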

@adenzler-nvidia

Heads-up: I'm currently working on the next version in a new branch, and will retarget this PR as soon as I have all the tests passing. Going for a modulo version with helpers.

So the expand_fields approach is dead - I agree it's unlikely to play well with JAX. Happy to test-drive this with a real use case; let me know when you have something ready.

@adenzler-nvidia

New PR: #195

@adenzler-nvidia commented Apr 28, 2025

Plan after offline discussion with @erikfrey:

  • let's forget about the modulo indexing and just have all the batched fields be size nworld, with the user repeating data for now (see the sketch after this list) - but only expand fields that need different data for different model batches.
  • use the kernel analyzer to enforce correctness around which fields can be expanded and which can't.
  • wait on this until #148 ([WIP] API changes to address multiple issues) is merged.
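
A minimal sketch of the "user repeats data" part of this plan (sizes and names illustrative):

```python
import numpy as np
import warp as wp

nworld, nmodel, ngeom = 4, 2, 3  # illustrative sizes

# Draw nmodel distinct randomizations, then repeat them out to nworld rows so
# every batched field is simply size nworld in its first dimension.
margins = np.random.uniform(0.0, 0.01, size=(nmodel, ngeom))
tiled = np.tile(margins, (nworld // nmodel, 1)).astype(np.float32)  # (nworld, ngeom)
geom_margin = wp.array(tiled)  # kernels index it directly as [worldid, geomid]
```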

@adenzler-nvidia

Closing in favor of #231.
