### Intro
Hi!
I'm a student using MuJoCo (in particular MJX) for rigid-body simulation and reinforcement learning.
### My setup
I tried both version 3.3.0 and the remote environment used by the MJX tutorial notebook linked from https://mujoco.readthedocs.io/en/stable/mjx.html.
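To make the setup easier to reproduce, the exact package versions and the JAX backend/device could be listed as well. A small stdlib-based helper such as the hypothetical `describe_setup` below collects them. The package names are the PyPI distributions `mujoco`, `mujoco-mjx`, `jax` and `jaxlib`. This is a convenience sketch, not part of MuJoCo's tooling.

```python
from importlib import metadata


def describe_setup(packages=("mujoco", "mujoco-mjx", "jax", "jaxlib")):
    """Hypothetical helper: collect package versions and the JAX backend for a bug report."""
    info = {}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    try:
        import jax
    except ImportError:
        info["jax backend"] = "unavailable"
    else:
        # "gpu", "tpu" or "cpu": the backend JAX (and hence MJX) uses by default.
        info["jax backend"] = jax.default_backend()
        # Hardware name of the first device, e.g. the GPU model.
        info["jax device"] = jax.devices()[0].device_kind
    return info


if __name__ == "__main__":
    for key, value in describe_setup().items():
        print(f"{key}: {value}")
```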
### What's happening? What did you expect?
I frequently get unphysical results when simulating with MJX on the GPU. When an object touches the ground, it either bounces abnormally or `data.xpos` and `qpos` immediately become NaN. Either way, the results differ substantially from the CPU simulation, which I expected MJX to match closely. Reducing the timestep or switching the integrator does not fix it.

Even a very simple scene, a single capsule falling onto the ground, reproduces the problem.
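To put a number on "substantially different", one option is to step the same model with `mujoco.mj_step` on the CPU and `mjx.step` side by side, then report the first step where the MJX state is non-finite or drifts from the CPU state. The sketch below does that. The helper names `record_trajectories` and `first_divergence` and the default tolerance `atol=1e-3` are my own choices, not MuJoCo API. Because JAX (and therefore MJX) computes in float32 by default while CPU MuJoCo uses float64, small mismatches are expected, so the tolerance is only a rough guess.

```python
import numpy as np


def first_divergence(cpu_qpos, mjx_qpos, atol=1e-3):
    """Return (step, reason) for the first step where the MJX qpos is non-finite
    or differs from the CPU qpos by more than `atol` (max absolute difference).
    Returns (None, None) if the trajectories agree at every compared step."""
    cpu_qpos = np.asarray(cpu_qpos, dtype=np.float64)
    mjx_qpos = np.asarray(mjx_qpos, dtype=np.float64)
    n = min(len(cpu_qpos), len(mjx_qpos))
    for step in range(n):
        if not np.all(np.isfinite(mjx_qpos[step])):
            return step, "non-finite"
        if np.max(np.abs(cpu_qpos[step] - mjx_qpos[step])) > atol:
            return step, "mismatch"
    return None, None


def record_trajectories(xml, n_steps):
    """Step one model on the CPU (mujoco.mj_step) and with MJX (mjx.step) from
    the same initial state; return both qpos histories as (n_steps, nq) arrays."""
    # Imported lazily so first_divergence can be used without MuJoCo/JAX installed.
    import jax
    import mujoco
    from mujoco import mjx

    m = mujoco.MjModel.from_xml_string(xml)
    d = mujoco.MjData(m)
    mx = mjx.put_model(m)
    dx = mjx.put_data(m, d)
    step = jax.jit(mjx.step)

    cpu, gpu = [], []
    for _ in range(n_steps):
        mujoco.mj_step(m, d)
        dx = step(mx, dx)
        cpu.append(d.qpos.copy())
        gpu.append(np.asarray(dx.qpos))
    return np.stack(cpu), np.stack(gpu)
```

Usage with the model string `xml` from the reproduction section: `cpu, gpu = record_trajectories(xml, n_steps=5000)` (0.5 s at a 0.0001 s timestep), then `print(first_divergence(cpu, gpu))`. Running it once more after `jax.config.update("jax_enable_x64", True)`, set before any MJX call, may help separate float32 precision effects from a contact problem.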
### Steps for reproduction
Define the model string below, then run the reproduction script that follows it.
### Minimal model for reproduction
```Python
xml = """
<?xml version="1.0" ?>
<mujoco>
  <default>
    <site rgba=".9 .9 .9 1"/>
  </default>
  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3" rgb2=".2 .3 .4" width="300" height="300" mark="none"/>
    <material name="grid" texture="grid" texrepeat="6 6" texuniform="true" reflectance=".2"/>
  </asset>
  <option timestep="0.0001">
    <flag energy="disable" contact="enable" gravity="enable"/>
  </option>
  <worldbody>
    <light name="light" pos="-.2 0 1"/>
    <geom name="ground" type="plane" pos="0 0 -0.001" size="0 0 10" material="grid" zaxis="0 0 1" friction="0.0"/>
    <camera name="cam" pos="0 0 7"/>
    <light pos="0 0 1"/>
    <light pos="-1 0 1"/>
    <light pos="-2 0 1"/>
    <light pos="-3 0 1"/>
    <light pos="-4 0 1"/>
    <light pos="-5 0 1"/>
    <body name="b1" pos="0 0 0.05">
      <joint name="j1" type="free"/>
      <geom friction="0" name="g1" type="capsule" size="0.04" fromto="0 0 0 0.19855637987881708 0 0"/>
    </body>
  </worldbody>
  <tendon/>
</mujoco>
"""
```
### Code required for reproduction
```Python
import jax
import jax.numpy as jp
import mujoco
from mujoco import mjx

# `xml` is the model string defined in the previous section.
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
mujoco.mj_resetData(mj_model, mj_data)

duration = 0.5  # simulated time (seconds)

# Copy model and data to the accelerator and JIT-compile the MJX step.
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
jit_step = jax.jit(mjx.step)

while mjx_data.time < duration:
    # Advance one timestep (0.0001 s).
    mjx_data = jit_step(mjx_model, mjx_data)
    # Sum of all body positions; jumps or turns NaN once the capsule hits the ground.
    print("t", float(mjx_data.time), "xpos", jp.sum(mjx_data.xpos))
```
### Confirmations
- [x] I searched the [latest documentation](https://mujoco.readthedocs.io/en/latest/overview.html) thoroughly before posting.
- [x] I searched previous [Issues](https://github.com/google-deepmind/mujoco/issues) and [Discussions](https://github.com/google-deepmind/mujoco/discussions), I am certain this has not been raised before.