diff --git a/contrib/jax_unroll.py b/contrib/jax_unroll.py index ec6db438..f29d68d5 100644 --- a/contrib/jax_unroll.py +++ b/contrib/jax_unroll.py @@ -33,7 +33,7 @@ mjd.qvel = np.random.uniform(-0.01, 0.01, mjm.nv) mujoco.mj_step(mjm, mjd, 3) # let dynamics get state significantly non-zero mujoco.mj_forward(mjm, mjd) -m = mjwarp.put_model(mjm) +m = mjwarp.put_model(mjm, nworld=NWORLDS) d = mjwarp.put_data(mjm, mjd, nworld=NWORLDS, nconmax=131012, njmax=131012 * 4) diff --git a/mujoco_warp/_src/broad_phase_test.py b/mujoco_warp/_src/broad_phase_test.py index 08febab4..d1b949bf 100644 --- a/mujoco_warp/_src/broad_phase_test.py +++ b/mujoco_warp/_src/broad_phase_test.py @@ -208,6 +208,7 @@ def test_nxn_broadphase(self): np.testing.assert_allclose(d2.collision_pair.numpy()[2][1], 2) # two worlds and four collisions + m3 = mjwarp.put_model(mjm, nworld=2) d3 = mjwarp.make_data(mjm, nworld=2) d3.geom_xpos = wp.array( np.vstack( @@ -216,7 +217,7 @@ def test_nxn_broadphase(self): dtype=wp.vec3, ) - collision_driver.nxn_broadphase(m, d3) + collision_driver.nxn_broadphase(m3, d3) np.testing.assert_allclose(d3.ncollision.numpy()[0], 4) np.testing.assert_allclose(d3.collision_pair.numpy()[0][0], 0) np.testing.assert_allclose(d3.collision_pair.numpy()[0][1], 1) diff --git a/mujoco_warp/_src/collision_box.py b/mujoco_warp/_src/collision_box.py index a7ba6270..a318c926 100644 --- a/mujoco_warp/_src/collision_box.py +++ b/mujoco_warp/_src/collision_box.py @@ -306,7 +306,7 @@ def box_box_kernel( for i in range(4): pos[i] = pos[idx] - margin = wp.max(m.geom_margin[ga], m.geom_margin[gb]) + margin = wp.max(m.geom_margin[worldid, ga], m.geom_margin[worldid, gb]) for i in range(4): pos_glob = b_mat @ pos[i] + b_pos n_glob = b_mat @ sep_axis diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index ffc81e4b..cb0bdf9e 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -26,13 +26,15 @@ @wp.func -def _geom_filter(m: Model, geom1: int, geom2: int, filterparent: bool) -> bool: +def _geom_filter( + m: Model, geom1: int, geom2: int, filterparent: bool, worldid: int +) -> bool: bodyid1 = m.geom_bodyid[geom1] bodyid2 = m.geom_bodyid[geom2] - contype1 = m.geom_contype[geom1] - contype2 = m.geom_contype[geom2] - conaffinity1 = m.geom_conaffinity[geom1] - conaffinity2 = m.geom_conaffinity[geom2] + contype1 = m.geom_contype[worldid, geom1] + contype2 = m.geom_contype[worldid, geom2] + conaffinity1 = m.geom_conaffinity[worldid, geom1] + conaffinity2 = m.geom_conaffinity[worldid, geom2] weldid1 = m.body_weldid[bodyid1] weldid2 = m.body_weldid[bodyid2] weld_parentid1 = m.body_weldid[m.body_parentid[weldid1]] @@ -82,7 +84,7 @@ def broadphase_project_spheres_onto_sweep_direction_kernel( if r == 0.0: # current geom is a plane r = 1000000000.0 - sphere_radius = r + m.geom_margin[i] + sphere_radius = r + m.geom_margin[worldid, i] center = wp.dot(direction, c) f = center - sphere_radius @@ -146,7 +148,7 @@ def reorder_bounding_spheres_kernel( # Get the bounding volume c = d.geom_xpos[worldid, mapped] r = m.geom_rbound[mapped] - margin = m.geom_margin[mapped] + margin = m.geom_margin[worldid, mapped] # Reorder the box into the sorted array if r == 0.0: @@ -291,7 +293,7 @@ def sap_broadphase_kernel(m: Model, d: Data, num_threads: int, filter_parent: bo idx1 = d.sap_sort_index[worldid, i] idx2 = d.sap_sort_index[worldid, j] - if not _geom_filter(m, idx1, idx2, filter_parent): + if not _geom_filter(m, idx1, idx2, filter_parent, worldid): threadId += num_threads continue @@ -317,17 +319,19 @@ def get_contact_solver_params_kernel( g1 = geoms.x g2 = geoms.y - margin = wp.max(m.geom_margin[g1], m.geom_margin[g2]) - gap = wp.max(m.geom_gap[g1], m.geom_gap[g2]) - solmix1 = m.geom_solmix[g1] - solmix2 = m.geom_solmix[g2] + worldid = d.contact.worldid[tid] + + margin = wp.max(m.geom_margin[worldid, g1], m.geom_margin[worldid, g2]) + gap = wp.max(m.geom_gap[worldid, g1], m.geom_gap[worldid, g2]) + solmix1 = m.geom_solmix[worldid, g1] + solmix2 = m.geom_solmix[worldid, g2] mix = solmix1 / (solmix1 + solmix2) mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 < MJ_MINVAL), 0.5, mix) mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 >= MJ_MINVAL), 0.0, mix) mix = wp.where((solmix1 >= MJ_MINVAL) and (solmix2 < MJ_MINVAL), 1.0, mix) - p1 = m.geom_priority[g1] - p2 = m.geom_priority[g2] + p1 = m.geom_priority[worldid, g1] + p2 = m.geom_priority[worldid, g2] mix = wp.where(p1 == p2, mix, wp.where(p1 > p2, 1.0, 0.0)) condim1 = m.geom_condim[g1] @@ -337,15 +341,21 @@ def get_contact_solver_params_kernel( ) d.contact.dim[tid] = condim - if m.geom_solref[g1].x > 0.0 and m.geom_solref[g2].x > 0.0: - d.contact.solref[tid] = mix * m.geom_solref[g1] + (1.0 - mix) * m.geom_solref[g2] + if m.geom_solref[worldid, g1].x > 0.0 and m.geom_solref[worldid, g2].x > 0.0: + d.contact.solref[tid] = ( + mix * m.geom_solref[worldid, g1] + (1.0 - mix) * m.geom_solref[worldid, g2] + ) else: - d.contact.solref[tid] = wp.min(m.geom_solref[g1], m.geom_solref[g2]) + d.contact.solref[tid] = wp.min( + m.geom_solref[worldid, g1], m.geom_solref[worldid, g2] + ) d.contact.includemargin[tid] = margin - gap - friction_ = wp.max(m.geom_friction[g1], m.geom_friction[g2]) + friction_ = wp.max(m.geom_friction[worldid, g1], m.geom_friction[worldid, g2]) friction5 = vec5(friction_[0], friction_[0], friction_[1], friction_[2], friction_[2]) d.contact.friction[tid] = friction5 - d.contact.solimp[tid] = mix * m.geom_solimp[g1] + (1.0 - mix) * m.geom_solimp[g2] + d.contact.solimp[tid] = ( + mix * m.geom_solimp[worldid, g1] + (1.0 - mix) * m.geom_solimp[worldid, g2] + ) def sap_broadphase(m: Model, d: Data): @@ -357,7 +367,7 @@ def sap_broadphase(m: Model, d: Data): wp.launch( kernel=broadphase_project_spheres_onto_sweep_direction_kernel, - dim=(d.nworld, m.ngeom), + dim=(m.nworld, m.ngeom), inputs=[m, d, direction], ) @@ -367,19 +377,19 @@ def sap_broadphase(m: Model, d: Data): if tile_sort_available: segmented_sort_kernel = create_segmented_sort_kernel(m.ngeom) wp.launch_tiled( - kernel=segmented_sort_kernel, dim=(d.nworld), inputs=[m, d], block_dim=128 + kernel=segmented_sort_kernel, dim=(m.nworld), inputs=[m, d], block_dim=128 ) print("tile sort available") elif segmented_sort_available: wp.utils.segmented_sort_pairs( d.sap_projection_lower, d.sap_sort_index, - m.ngeom * d.nworld, + m.ngeom * m.nworld, d.sap_segment_index, ) else: # Sort each world's segment separately - for world_id in range(d.nworld): + for world_id in range(m.nworld): start_idx = world_id * m.ngeom # Create temporary arrays for sorting @@ -431,13 +441,13 @@ def sap_broadphase(m: Model, d: Data): wp.launch( kernel=reorder_bounding_spheres_kernel, - dim=(d.nworld, m.ngeom), + dim=(m.nworld, m.ngeom), inputs=[m, d], ) wp.launch( kernel=sap_broadphase_prepare_kernel, - dim=(d.nworld, m.ngeom), + dim=(m.nworld, m.ngeom), inputs=[m, d], ) @@ -445,7 +455,7 @@ def sap_broadphase(m: Model, d: Data): wp.utils.array_scan(d.sap_range.reshape(-1), d.sap_cumulative_sum, True) # Estimate how many overlap checks need to be done - assumes each box has to be compared to 5 other boxes (and batched over all worlds) - num_sweep_threads = 5 * d.nworld * m.ngeom + num_sweep_threads = 5 * m.nworld * m.ngeom filter_parent = not m.opt.disableflags & DisableBit.FILTERPARENT.value wp.launch( kernel=sap_broadphase_kernel, @@ -478,8 +488,8 @@ def _nxn_broadphase(m: Model, d: Data): + (m.ngeom - geom1) * ((m.ngeom - geom1) - 1) // 2 ) - margin1 = m.geom_margin[geom1] - margin2 = m.geom_margin[geom2] + margin1 = m.geom_margin[worldid, geom1] + margin2 = m.geom_margin[worldid, geom2] pos1 = d.geom_xpos[worldid, geom1] pos2 = d.geom_xpos[worldid, geom2] size1 = m.geom_rbound[geom1] @@ -503,13 +513,13 @@ def _nxn_broadphase(m: Model, d: Data): dist = wp.dot(-dif, wp.vec3(xmat2[0, 2], xmat2[1, 2], xmat2[2, 2])) bounds_filter = dist <= bound - geom_filter = _geom_filter(m, geom1, geom2, filterparent) + geom_filter = _geom_filter(m, geom1, geom2, filterparent, worldid) if bounds_filter and geom_filter: _add_geom_pair(m, d, geom1, geom2, worldid) wp.launch( - _nxn_broadphase, dim=(d.nworld, m.ngeom * (m.ngeom - 1) // 2), inputs=[m, d] + _nxn_broadphase, dim=(m.nworld, m.ngeom * (m.ngeom - 1) // 2), inputs=[m, d] ) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index e9831732..b8e006dd 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -253,7 +253,7 @@ def _primitive_narrowphase( geom1 = _geom(g1, m, d.geom_xpos[worldid], d.geom_xmat[worldid]) geom2 = _geom(g2, m, d.geom_xpos[worldid], d.geom_xmat[worldid]) - margin = wp.max(m.geom_margin[g1], m.geom_margin[g2]) + margin = wp.max(m.geom_margin[worldid, g1], m.geom_margin[worldid, g2]) # TODO(team): static loop unrolling to remove unnecessary branching if type1 == int(GeomType.PLANE.value) and type2 == int(GeomType.SPHERE.value): diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index 7ebb87f8..50060373 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -109,9 +109,9 @@ def _efc_limit_slide_hinge( jntid = m.jnt_limited_slide_hinge_adr[jntlimitedid] qpos = d.qpos[worldid, m.jnt_qposadr[jntid]] - jnt_range = m.jnt_range[jntid] + jnt_range = m.jnt_range[worldid, jntid] dist_min, dist_max = qpos - jnt_range[0], jnt_range[1] - qpos - pos = wp.min(dist_min, dist_max) - m.jnt_margin[jntid] + pos = wp.min(dist_min, dist_max) - m.jnt_margin[worldid, jntid] active = pos < 0 if active: @@ -130,10 +130,10 @@ def _efc_limit_slide_hinge( efcid, pos, pos, - m.dof_invweight0[dofadr], - m.jnt_solref[jntid], - m.jnt_solimp[jntid], - m.jnt_margin[jntid], + m.dof_invweight0[worldid, dofadr], + m.jnt_solref[worldid, jntid], + m.jnt_solimp[worldid, jntid], + m.jnt_margin[worldid, jntid], refsafe, Jqvel, ) @@ -179,7 +179,9 @@ def _efc_contact_pyramidal( frame = d.contact.frame[conid] # pyramidal has common invweight across all edges - invweight = m.body_invweight0[body1, 0] + m.body_invweight0[body2, 0] + invweight = ( + m.body_invweight0[worldid, body1, 0] + m.body_invweight0[worldid, body2, 0] + ) if condim > 1: dimid2 = dimid / 2 + 1 @@ -285,7 +287,9 @@ def _efc_contact_elliptic( d.efc.J[efcid, i] = J Jqvel += J * d.qvel[worldid, i] - invweight = m.body_invweight0[body1, 0] + m.body_invweight0[body2, 0] + invweight = ( + m.body_invweight0[worldid, body1, 0] + m.body_invweight0[worldid, body2, 0] + ) ref = d.contact.solref[conid] pos_aref = pos @@ -340,7 +344,7 @@ def make_constraint(m: types.Model, d: types.Data): ): wp.launch( _efc_limit_slide_hinge, - dim=(d.nworld, m.jnt_limited_slide_hinge_adr.size), + dim=(m.nworld, m.jnt_limited_slide_hinge_adr.size), inputs=[m, d, refsafe], ) diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 6e8ada75..3ae7e80e 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -58,8 +58,8 @@ def next_activation( # get the high/low range for each actuator state limited = m.actuator_actlimited[actid] - range_low = wp.where(limited, m.actuator_actrange[actid][0], -wp.inf) - range_high = wp.where(limited, m.actuator_actrange[actid][1], wp.inf) + range_low = wp.where(limited, m.actuator_actrange[worldid, actid][0], -wp.inf) + range_high = wp.where(limited, m.actuator_actrange[worldid, actid][1], wp.inf) # get the actual actuation - skip if -1 (means stateless actuator) act_adr = m.actuator_actadr[actid] @@ -74,7 +74,7 @@ def next_activation( # check dynType dyn_type = m.actuator_dyntype[actid] - dyn_prm = m.actuator_dynprm[actid][0] + dyn_prm = m.actuator_dynprm[worldid, actid][0] # advance the actuation if dyn_type == wp.static(DynType.FILTEREXACT.value): @@ -148,9 +148,9 @@ def integrate_joint_positions(m: Model, d: Data, qvel_in: array2df): # skip if no stateful actuators. if m.na: - wp.launch(next_activation, dim=(d.nworld, m.nu), inputs=[m, d, act_dot]) + wp.launch(next_activation, dim=(m.nworld, m.nu), inputs=[m, d, act_dot]) - wp.launch(advance_velocities, dim=(d.nworld, m.nv), inputs=[m, d, qacc]) + wp.launch(advance_velocities, dim=(m.nworld, m.nv), inputs=[m, d, qacc]) # advance positions with qvel if given, d.qvel otherwise (semi-implicit) if qvel is not None: @@ -158,7 +158,7 @@ def integrate_joint_positions(m: Model, d: Data, qvel_in: array2df): else: qvel_in = d.qvel - wp.launch(integrate_joint_positions, dim=(d.nworld, m.njnt), inputs=[m, d, qvel_in]) + wp.launch(integrate_joint_positions, dim=(m.nworld, m.njnt), inputs=[m, d, qvel_in]) d.time = d.time + m.opt.timestep @@ -175,14 +175,16 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): worldid, tid = wp.tid() dof_Madr = m.dof_Madr[tid] - d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * m.dof_damping[tid] + d.qM_integration[worldid, 0, dof_Madr] += ( + m.opt.timestep * m.dof_damping[worldid, tid] + ) d.qfrc_integration[worldid, tid] = ( d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid] ) kernel_copy(d.qM_integration, d.qM) - wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(m.nworld, m.nv), inputs=[m, d]) smooth.factor_solve_i( m, d, @@ -196,15 +198,15 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): def eulerdamp_fused_dense(m: Model, d: Data): def tile_eulerdamp(adr: int, size: int, tilesize: int): @kernel - def eulerdamp( - m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int - ): + def eulerdamp(m: Model, d: Data, damping: array2df, leveladr: int): worldid, nodeid = wp.tid() dofid = m.qLD_tile[leveladr + nodeid] M_tile = wp.tile_load( d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) ) - damping_tile = wp.tile_load(damping, shape=(tilesize,), offset=(dofid,)) + damping_tile = wp.tile_load( + damping[worldid], shape=(tilesize,), offset=(dofid,) + ) damping_scaled = damping_tile * m.opt.timestep qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled) @@ -222,7 +224,7 @@ def eulerdamp( wp.tile_store(d.qacc_integration[worldid], qacc_tile, offset=(dofid)) wp.launch_tiled( - eulerdamp, dim=(d.nworld, size), inputs=[m, d, m.dof_damping, adr], block_dim=32 + eulerdamp, dim=(m.nworld, size), inputs=[m, d, m.dof_damping, adr], block_dim=32 ) qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() @@ -287,10 +289,10 @@ def actuator_bias_gain_vel(m: Model, d: Data): actuator_dyntype = m.actuator_dyntype[actid] if actuator_biastype == wp.static(BiasType.AFFINE.value): - bias_vel = m.actuator_biasprm[actid, 2] + bias_vel = m.actuator_biasprm[worldid, actid, 2] if actuator_gaintype == wp.static(GainType.AFFINE.value): - gain_vel = m.actuator_gainprm[actid, 2] + gain_vel = m.actuator_gainprm[worldid, actid, 2] ctrl = d.ctrl[worldid, actid] @@ -299,9 +301,7 @@ def actuator_bias_gain_vel(m: Model, d: Data): d.act_vel_integration[worldid, actid] = bias_vel + gain_vel * ctrl - def qderiv_actuator_damping_fused( - m: Model, d: Data, damping: wp.array(dtype=wp.float32) - ): + def qderiv_actuator_damping_fused(m: Model, d: Data, damping: array2df): if actuation_enabled: block_dim = 64 else: @@ -316,7 +316,7 @@ def qderiv_actuator_damping_tiled( ): @kernel def qderiv_actuator_fused_kernel( - m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int + m: Model, d: Data, damping: array2df, leveladr: int ): worldid, nodeid = wp.tid() offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] @@ -343,7 +343,9 @@ def qderiv_actuator_fused_kernel( ) if wp.static(passive_enabled): - dof_damping = wp.tile_load(damping, shape=tilesize_nv, offset=offset_nv) + dof_damping = wp.tile_load( + damping[worldid], shape=tilesize_nv, offset=offset_nv + ) negative = wp.neg(dof_damping) qderiv_tile = wp.tile_diag_add(qderiv_tile, negative) @@ -368,7 +370,7 @@ def qderiv_actuator_fused_kernel( wp.launch_tiled( qderiv_actuator_fused_kernel, - dim=(d.nworld, size), + dim=(m.nworld, size), inputs=[m, d, damping, adr], block_dim=block_dim, ) @@ -391,7 +393,7 @@ def qderiv_actuator_fused_kernel( if actuation_enabled: wp.launch( actuator_bias_gain_vel, - dim=(d.nworld, m.nu), + dim=(m.nworld, m.nu), inputs=[m, d], ) @@ -436,7 +438,7 @@ def _actuator_velocity(d: Data): qvel = d.qvel[worldid] wp.atomic_add(d.actuator_velocity[worldid], actid, moment[dofid] * qvel[dofid]) - wp.launch(_actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d]) + wp.launch(_actuator_velocity, dim=(m.nworld, m.nu, m.nv), inputs=[d]) else: def actuator_velocity( @@ -466,7 +468,7 @@ def _actuator_velocity( wp.launch_tiled( _actuator_velocity, - dim=(d.nworld, size), + dim=(m.nworld, size), inputs=[ m, d, @@ -523,22 +525,22 @@ def _force( actuator_length = d.actuator_length[worldid, uid] actuator_velocity = d.actuator_velocity[worldid, uid] - gain = m.actuator_gainprm[uid, 0] - gain += m.actuator_gainprm[uid, 1] * actuator_length - gain += m.actuator_gainprm[uid, 2] * actuator_velocity + gain = m.actuator_gainprm[worldid, uid, 0] + gain += m.actuator_gainprm[worldid, uid, 1] * actuator_length + gain += m.actuator_gainprm[worldid, uid, 2] * actuator_velocity - bias = m.actuator_biasprm[uid, 0] - bias += m.actuator_biasprm[uid, 1] * actuator_length - bias += m.actuator_biasprm[uid, 2] * actuator_velocity + bias = m.actuator_biasprm[worldid, uid, 0] + bias += m.actuator_biasprm[worldid, uid, 1] * actuator_length + bias += m.actuator_biasprm[worldid, uid, 2] * actuator_velocity ctrl = d.ctrl[worldid, uid] disable_clampctrl = m.opt.disableflags & wp.static(DisableBit.CLAMPCTRL.value) if m.actuator_ctrllimited[uid] and not disable_clampctrl: - r = m.actuator_ctrlrange[uid] + r = m.actuator_ctrlrange[worldid, uid] ctrl = wp.clamp(ctrl, r[0], r[1]) f = gain * ctrl + bias if m.actuator_forcelimited[uid]: - r = m.actuator_forcerange[uid] + r = m.actuator_forcerange[worldid, uid] f = wp.clamp(f, r[0], r[1]) force[worldid, uid] = f @@ -549,8 +551,8 @@ def _qfrc_limited(m: Model, d: Data): if m.jnt_actfrclimited[jntid]: d.qfrc_actuator[worldid, dofid] = wp.clamp( d.qfrc_actuator[worldid, dofid], - m.jnt_actfrcrange[jntid][0], - m.jnt_actfrcrange[jntid][1], + m.jnt_actfrcrange[worldid, jntid][0], + m.jnt_actfrcrange[worldid, jntid][1], ) if m.opt.is_sparse: @@ -565,18 +567,18 @@ def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df): s += moment[worldid, uid, vid] * force[worldid, uid] jntid = m.dof_jntid[vid] if m.jnt_actfrclimited[jntid]: - r = m.jnt_actfrcrange[jntid] + r = m.jnt_actfrcrange[worldid, jntid] s = wp.clamp(s, r[0], r[1]) qfrc[worldid, vid] = s - wp.launch(_force, dim=[d.nworld, m.nu], inputs=[m, d], outputs=[d.actuator_force]) + wp.launch(_force, dim=[m.nworld, m.nu], inputs=[m, d], outputs=[d.actuator_force]) if m.opt.is_sparse: # TODO(team): sparse version wp.launch( _qfrc, - dim=(d.nworld, m.nv), + dim=(m.nworld, m.nv), inputs=[m, d.actuator_moment, d.actuator_force], outputs=[d.qfrc_actuator], ) @@ -611,7 +613,7 @@ def qfrc_actuator_kernel( wp.launch_tiled( qfrc_actuator_kernel, - dim=(d.nworld, size), + dim=(m.nworld, size), inputs=[ m, d, @@ -636,7 +638,7 @@ def qfrc_actuator_kernel( beg, end - beg, int(qderiv_tilesize_nu[i]), int(qderiv_tilesize_nv[i]) ) - wp.launch(_qfrc_limited, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(_qfrc_limited, dim=(m.nworld, m.nv), inputs=[m, d]) # TODO actuator-level gravity compensation, skip if added as passive force @@ -655,7 +657,7 @@ def _qfrc_smooth(d: Data): + d.qfrc_applied[worldid, dofid] ) - wp.launch(_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_qfrc_smooth, dim=(m.nworld, m.nv), inputs=[d]) xfrc_accumulate(m, d, d.qfrc_smooth) smooth.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 1012c780..9f056fd3 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -24,7 +24,14 @@ from . import types -def put_model(mjm: mujoco.MjModel) -> types.Model: +# TODO(erikfrey): would it be better to tile on the gpu? +def tile(x, nworld): + return np.tile(x, (nworld,) + (1,) * len(x.shape)) + + +def put_model( + mjm: mujoco.MjModel, nworld: int = 1, expand_fields: set[str] = set() +) -> types.Model: # check supported features for field, field_types, field_str in ( (mjm.actuator_trntype, types.TrnType, "Actuator transmission type"), @@ -104,9 +111,22 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m.opt.is_sparse = support.is_sparse(mjm) m.opt.ls_parallel = False m.stat.meaninertia = mjm.stat.meaninertia + m.nworld = nworld - m.qpos0 = wp.array(mjm.qpos0, dtype=wp.float32, ndim=1) - m.qpos_spring = wp.array(mjm.qpos_spring, dtype=wp.float32, ndim=1) + def create_nworld_array(mjm_array, dtype, expand): + if expand: + array = wp.array(tile(mjm_array, nworld), dtype=dtype) + else: + array = wp.array(mjm_array, dtype=dtype) + array.ndim += 1 + array.shape = (nworld,) + array.shape + array.strides = (0,) + array.strides + return array + + m.qpos0 = create_nworld_array(mjm.qpos0, wp.float32, "qpos0" in expand_fields) + m.qpos_spring = create_nworld_array( + mjm.qpos_spring, wp.float32, "qpos_spring" in expand_fields + ) # dof lower triangle row and column indices dof_tri_row, dof_tri_col = np.tril_indices(mjm.nv) @@ -155,8 +175,8 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: tree_off = [0] + [len(bodies[i]) for i in range(len(bodies))] body_treeadr = np.cumsum(tree_off)[:-1] - m.body_tree = wp.array(body_tree, dtype=wp.int32, ndim=1) - m.body_treeadr = wp.array(body_treeadr, dtype=wp.int32, ndim=1, device="cpu") + m.body_tree = wp.array(body_tree, dtype=wp.int32) + m.body_treeadr = wp.array(body_treeadr, dtype=wp.int32, device="cpu") qLD_update_tree = np.empty(shape=(0, 3), dtype=int) qLD_update_treeadr = np.empty(shape=(0,), dtype=int) @@ -268,143 +288,229 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: [int(a[1]) for a in sorted_keys] ) # for this level - m.qM_fullm_i = wp.array(qM_fullm_i, dtype=wp.int32, ndim=1) - m.qM_fullm_j = wp.array(qM_fullm_j, dtype=wp.int32, ndim=1) - m.qM_mulm_i = wp.array(qM_mulm_i, dtype=wp.int32, ndim=1) - m.qM_mulm_j = wp.array(qM_mulm_j, dtype=wp.int32, ndim=1) - m.qM_madr_ij = wp.array(qM_madr_ij, dtype=wp.int32, ndim=1) + m.qM_fullm_i = wp.array(qM_fullm_i, dtype=wp.int32) + m.qM_fullm_j = wp.array(qM_fullm_j, dtype=wp.int32) + m.qM_mulm_i = wp.array(qM_mulm_i, dtype=wp.int32) + m.qM_mulm_j = wp.array(qM_mulm_j, dtype=wp.int32) + m.qM_madr_ij = wp.array(qM_madr_ij, dtype=wp.int32) m.qLD_update_tree = wp.array(qLD_update_tree, dtype=wp.vec3i, ndim=1) m.qLD_update_treeadr = wp.array( qLD_update_treeadr, dtype=wp.int32, ndim=1, device="cpu" ) - m.qLD_tile = wp.array(qLD_tile, dtype=wp.int32, ndim=1) - m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, ndim=1, device="cpu") - m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, ndim=1, device="cpu") - m.actuator_moment_offset_nv = wp.array( - actuator_moment_offset_nv, dtype=wp.int32, ndim=1 - ) - m.actuator_moment_offset_nu = wp.array( - actuator_moment_offset_nu, dtype=wp.int32, ndim=1 - ) + m.qLD_tile = wp.array(qLD_tile, dtype=wp.int32) + m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, device="cpu") + m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, device="cpu") + m.actuator_moment_offset_nv = wp.array(actuator_moment_offset_nv, dtype=wp.int32) + m.actuator_moment_offset_nu = wp.array(actuator_moment_offset_nu, dtype=wp.int32) m.actuator_moment_tileadr = wp.array( - actuator_moment_tileadr, dtype=wp.int32, ndim=1, device="cpu" + actuator_moment_tileadr, dtype=wp.int32, device="cpu" ) m.actuator_moment_tilesize_nv = wp.array( - actuator_moment_tilesize_nv, dtype=wp.int32, ndim=1, device="cpu" + actuator_moment_tilesize_nv, dtype=wp.int32, device="cpu" ) m.actuator_moment_tilesize_nu = wp.array( - actuator_moment_tilesize_nu, dtype=wp.int32, ndim=1, device="cpu" + actuator_moment_tilesize_nu, dtype=wp.int32, device="cpu" ) m.alpha_candidate = wp.array(np.linspace(0.0, 1.0, m.nlsp), dtype=wp.float32) - m.body_dofadr = wp.array(mjm.body_dofadr, dtype=wp.int32, ndim=1) - m.body_dofnum = wp.array(mjm.body_dofnum, dtype=wp.int32, ndim=1) - m.body_jntadr = wp.array(mjm.body_jntadr, dtype=wp.int32, ndim=1) - m.body_jntnum = wp.array(mjm.body_jntnum, dtype=wp.int32, ndim=1) - m.body_parentid = wp.array(mjm.body_parentid, dtype=wp.int32, ndim=1) - m.body_mocapid = wp.array(mjm.body_mocapid, dtype=wp.int32, ndim=1) - m.body_weldid = wp.array(mjm.body_weldid, dtype=wp.int32, ndim=1) - m.body_pos = wp.array(mjm.body_pos, dtype=wp.vec3, ndim=1) - m.body_quat = wp.array(mjm.body_quat, dtype=wp.quat, ndim=1) - m.body_ipos = wp.array(mjm.body_ipos, dtype=wp.vec3, ndim=1) - m.body_iquat = wp.array(mjm.body_iquat, dtype=wp.quat, ndim=1) - m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32, ndim=1) - m.body_inertia = wp.array(mjm.body_inertia, dtype=wp.vec3, ndim=1) - m.body_mass = wp.array(mjm.body_mass, dtype=wp.float32, ndim=1) + m.body_dofadr = wp.array(mjm.body_dofadr, dtype=wp.int32) + m.body_dofnum = wp.array(mjm.body_dofnum, dtype=wp.int32) + m.body_jntadr = wp.array(mjm.body_jntadr, dtype=wp.int32) + m.body_jntnum = wp.array(mjm.body_jntnum, dtype=wp.int32) + m.body_parentid = wp.array(mjm.body_parentid, dtype=wp.int32) + m.body_mocapid = wp.array(mjm.body_mocapid, dtype=wp.int32) + m.body_weldid = wp.array(mjm.body_weldid, dtype=wp.int32) + m.body_pos = create_nworld_array(mjm.body_pos, wp.vec3, "body_pos" in expand_fields) + m.body_quat = create_nworld_array( + mjm.body_quat, wp.quat, "body_quat" in expand_fields + ) + m.body_ipos = create_nworld_array( + mjm.body_ipos, wp.vec3, "body_ipos" in expand_fields + ) + m.body_iquat = create_nworld_array( + mjm.body_iquat, wp.quat, "body_iquat" in expand_fields + ) + m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32) + m.body_inertia = create_nworld_array( + mjm.body_inertia, wp.vec3, "body_inertia" in expand_fields + ) + m.body_mass = create_nworld_array( + mjm.body_mass, wp.float32, "body_mass" in expand_fields + ) subtree_mass = np.copy(mjm.body_mass) # TODO(team): should this be [mjm.nbody - 1, 0) ? for i in range(mjm.nbody - 1, -1, -1): subtree_mass[mjm.body_parentid[i]] += subtree_mass[i] - m.subtree_mass = wp.array(subtree_mass, dtype=wp.float32, ndim=1) - m.body_invweight0 = wp.array(mjm.body_invweight0, dtype=wp.float32, ndim=2) - m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32, ndim=1) - m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32, ndim=1) - m.body_contype = wp.array(mjm.body_contype, dtype=wp.int32, ndim=1) - m.body_conaffinity = wp.array(mjm.body_conaffinity, dtype=wp.int32, ndim=1) - m.jnt_bodyid = wp.array(mjm.jnt_bodyid, dtype=wp.int32, ndim=1) - m.jnt_limited = wp.array(mjm.jnt_limited, dtype=wp.int32, ndim=1) - m.jnt_limited_slide_hinge_adr = wp.array( - jnt_limited_slide_hinge_adr, dtype=wp.int32, ndim=1 - ) - m.jnt_type = wp.array(mjm.jnt_type, dtype=wp.int32, ndim=1) - m.jnt_solref = wp.array(mjm.jnt_solref, dtype=wp.vec2f, ndim=1) - m.jnt_solimp = wp.array(mjm.jnt_solimp, dtype=types.vec5, ndim=1) - m.jnt_qposadr = wp.array(mjm.jnt_qposadr, dtype=wp.int32, ndim=1) - m.jnt_dofadr = wp.array(mjm.jnt_dofadr, dtype=wp.int32, ndim=1) - m.jnt_axis = wp.array(mjm.jnt_axis, dtype=wp.vec3, ndim=1) - m.jnt_pos = wp.array(mjm.jnt_pos, dtype=wp.vec3, ndim=1) - m.jnt_range = wp.array(mjm.jnt_range, dtype=wp.float32, ndim=2) - m.jnt_margin = wp.array(mjm.jnt_margin, dtype=wp.float32, ndim=1) - m.jnt_stiffness = wp.array(mjm.jnt_stiffness, dtype=wp.float32, ndim=1) - m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool, ndim=1) - m.jnt_actfrcrange = wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2, ndim=1) - m.geom_type = wp.array(mjm.geom_type, dtype=wp.int32, ndim=1) - m.geom_bodyid = wp.array(mjm.geom_bodyid, dtype=wp.int32, ndim=1) - m.geom_conaffinity = wp.array(mjm.geom_conaffinity, dtype=wp.int32, ndim=1) - m.geom_contype = wp.array(mjm.geom_contype, dtype=wp.int32, ndim=1) - m.geom_condim = wp.array(mjm.geom_condim, dtype=wp.int32, ndim=1) - m.geom_pos = wp.array(mjm.geom_pos, dtype=wp.vec3, ndim=1) - m.geom_quat = wp.array(mjm.geom_quat, dtype=wp.quat, ndim=1) - m.geom_size = wp.array(mjm.geom_size, dtype=wp.vec3, ndim=1) - m.geom_priority = wp.array(mjm.geom_priority, dtype=wp.int32, ndim=1) - m.geom_solmix = wp.array(mjm.geom_solmix, dtype=wp.float32, ndim=1) - m.geom_solref = wp.array(mjm.geom_solref, dtype=wp.vec2, ndim=1) - m.geom_solimp = wp.array(mjm.geom_solimp, dtype=types.vec5, ndim=1) - m.geom_friction = wp.array(mjm.geom_friction, dtype=wp.vec3, ndim=1) - m.geom_margin = wp.array(mjm.geom_margin, dtype=wp.float32, ndim=1) - m.geom_gap = wp.array(mjm.geom_gap, dtype=wp.float32, ndim=1) - m.geom_aabb = wp.array(mjm.geom_aabb, dtype=wp.vec3, ndim=3) - m.geom_rbound = wp.array(mjm.geom_rbound, dtype=wp.float32, ndim=1) - m.geom_dataid = wp.array(mjm.geom_dataid, dtype=wp.int32, ndim=1) - m.mesh_vertadr = wp.array(mjm.mesh_vertadr, dtype=wp.int32, ndim=1) - m.mesh_vertnum = wp.array(mjm.mesh_vertnum, dtype=wp.int32, ndim=1) - m.mesh_vert = wp.array(mjm.mesh_vert, dtype=wp.vec3, ndim=1) - m.site_pos = wp.array(mjm.site_pos, dtype=wp.vec3, ndim=1) - m.site_quat = wp.array(mjm.site_quat, dtype=wp.quat, ndim=1) - m.site_bodyid = wp.array(mjm.site_bodyid, dtype=wp.int32, ndim=1) - m.cam_mode = wp.array(mjm.cam_mode, dtype=wp.int32, ndim=1) - m.cam_bodyid = wp.array(mjm.cam_bodyid, dtype=wp.int32, ndim=1) - m.cam_targetbodyid = wp.array(mjm.cam_targetbodyid, dtype=wp.int32, ndim=1) - m.cam_pos = wp.array(mjm.cam_pos, dtype=wp.vec3, ndim=1) - m.cam_quat = wp.array(mjm.cam_quat, dtype=wp.quat, ndim=1) - m.cam_poscom0 = wp.array(mjm.cam_poscom0, dtype=wp.vec3, ndim=1) - m.cam_pos0 = wp.array(mjm.cam_pos0, dtype=wp.vec3, ndim=1) - m.cam_mat0 = wp.array(mjm.cam_mat0.reshape(-1, 3, 3), dtype=wp.mat33, ndim=1) - m.light_mode = wp.array(mjm.light_mode, dtype=wp.int32, ndim=1) - m.light_bodyid = wp.array(mjm.light_bodyid, dtype=wp.int32, ndim=1) - m.light_targetbodyid = wp.array(mjm.light_targetbodyid, dtype=wp.int32, ndim=1) - m.light_pos = wp.array(mjm.light_pos, dtype=wp.vec3, ndim=1) - m.light_dir = wp.array(mjm.light_dir, dtype=wp.vec3, ndim=1) - m.light_poscom0 = wp.array(mjm.light_poscom0, dtype=wp.vec3, ndim=1) - m.light_pos0 = wp.array(mjm.light_pos0, dtype=wp.vec3, ndim=1) - m.light_dir0 = wp.array(mjm.light_dir0, dtype=wp.vec3, ndim=1) - m.dof_bodyid = wp.array(mjm.dof_bodyid, dtype=wp.int32, ndim=1) - m.dof_jntid = wp.array(mjm.dof_jntid, dtype=wp.int32, ndim=1) - m.dof_parentid = wp.array(mjm.dof_parentid, dtype=wp.int32, ndim=1) - m.dof_Madr = wp.array(mjm.dof_Madr, dtype=wp.int32, ndim=1) - m.dof_armature = wp.array(mjm.dof_armature, dtype=wp.float32, ndim=1) - m.dof_damping = wp.array(mjm.dof_damping, dtype=wp.float32, ndim=1) + m.subtree_mass = create_nworld_array( + subtree_mass, wp.float32, "subtree_mass" in expand_fields + ) # how should we hande this dependency on body_mass being potentially expanded? + m.body_invweight0 = create_nworld_array( + mjm.body_invweight0, wp.float32, "body_invweight0" in expand_fields + ) + m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32) + m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32) + m.jnt_bodyid = wp.array(mjm.jnt_bodyid, dtype=wp.int32) + m.jnt_limited = wp.array( + mjm.jnt_limited, dtype=wp.int32 + ) # should that be varying per-world? + m.jnt_limited_slide_hinge_adr = wp.array(jnt_limited_slide_hinge_adr, dtype=wp.int32) + m.jnt_type = wp.array(mjm.jnt_type, dtype=wp.int32) + m.jnt_solref = create_nworld_array( + mjm.jnt_solref, wp.vec2f, "jnt_solref" in expand_fields + ) + m.jnt_solimp = create_nworld_array( + mjm.jnt_solimp, types.vec5, "jnt_solimp" in expand_fields + ) + m.jnt_qposadr = wp.array(mjm.jnt_qposadr, dtype=wp.int32) + m.jnt_dofadr = wp.array(mjm.jnt_dofadr, dtype=wp.int32) + m.jnt_axis = wp.array( + mjm.jnt_axis, dtype=wp.vec3 + ) # should that be varying per-world? + m.jnt_pos = wp.array(mjm.jnt_pos, dtype=wp.vec3) # should that be varying per-world? + m.jnt_range = create_nworld_array( + mjm.jnt_range, wp.float32, "jnt_range" in expand_fields + ) + m.jnt_margin = create_nworld_array( + mjm.jnt_margin, wp.float32, "jnt_margin" in expand_fields + ) + m.jnt_stiffness = create_nworld_array( + mjm.jnt_stiffness, wp.float32, "jnt_stiffness" in expand_fields + ) + m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool) + m.jnt_actfrcrange = create_nworld_array( + mjm.jnt_actfrcrange, wp.vec2, "jnt_actfrcrange" in expand_fields + ) + m.geom_type = wp.array(mjm.geom_type, dtype=wp.int32) + m.geom_bodyid = wp.array(mjm.geom_bodyid, dtype=wp.int32) + m.geom_conaffinity = create_nworld_array( + mjm.geom_conaffinity, wp.int32, "geom_conaffinity" in expand_fields + ) + m.geom_contype = create_nworld_array( + mjm.geom_contype, wp.int32, "geom_contype" in expand_fields + ) + m.geom_condim = wp.array( + mjm.geom_condim, dtype=wp.int32 + ) # should that be varying per-world? + m.geom_pos = create_nworld_array(mjm.geom_pos, wp.vec3, "geom_pos" in expand_fields) + m.geom_quat = create_nworld_array( + mjm.geom_quat, wp.quat, "geom_quat" in expand_fields + ) + m.geom_size = wp.array( + mjm.geom_size, dtype=wp.vec3 + ) # should that be varying per-world? + m.geom_priority = create_nworld_array( + mjm.geom_priority, wp.int32, "geom_priority" in expand_fields + ) + m.geom_solmix = create_nworld_array( + mjm.geom_solmix, wp.float32, "geom_solmix" in expand_fields + ) + m.geom_solref = create_nworld_array( + mjm.geom_solref, wp.vec2, "geom_solref" in expand_fields + ) + m.geom_solimp = create_nworld_array( + mjm.geom_solimp, types.vec5, "geom_solimp" in expand_fields + ) + m.geom_friction = create_nworld_array( + mjm.geom_friction, wp.vec3, "geom_friction" in expand_fields + ) + m.geom_margin = create_nworld_array( + mjm.geom_margin, wp.float32, "geom_margin" in expand_fields + ) + m.geom_gap = create_nworld_array( + mjm.geom_gap, wp.float32, "geom_gap" in expand_fields + ) + m.geom_aabb = wp.array( + mjm.geom_aabb, dtype=wp.vec3 + ) # should that be varying per-world? + m.geom_rbound = wp.array( + mjm.geom_rbound, dtype=wp.float32 + ) # should that be varying per-world? + m.geom_dataid = wp.array(mjm.geom_dataid, dtype=wp.int32) + m.mesh_vertadr = wp.array(mjm.mesh_vertadr, dtype=wp.int32) + m.mesh_vertnum = wp.array(mjm.mesh_vertnum, dtype=wp.int32) + m.mesh_vert = wp.array(mjm.mesh_vert, dtype=wp.vec3) + m.site_pos = create_nworld_array(mjm.site_pos, wp.vec3, "site_pos" in expand_fields) + m.site_quat = create_nworld_array( + mjm.site_quat, wp.quat, "site_quat" in expand_fields + ) + m.site_bodyid = wp.array(mjm.site_bodyid, dtype=wp.int32) + m.cam_mode = wp.array(mjm.cam_mode, dtype=wp.int32) + m.cam_bodyid = wp.array(mjm.cam_bodyid, dtype=wp.int32) + m.cam_targetbodyid = wp.array(mjm.cam_targetbodyid, dtype=wp.int32) + m.cam_pos = create_nworld_array(mjm.cam_pos, wp.vec3, "cam_pos" in expand_fields) + m.cam_quat = create_nworld_array(mjm.cam_quat, wp.quat, "cam_quat" in expand_fields) + m.cam_poscom0 = create_nworld_array( + mjm.cam_poscom0, wp.vec3, "cam_poscom0" in expand_fields + ) + m.cam_pos0 = create_nworld_array(mjm.cam_pos0, wp.vec3, "cam_pos0" in expand_fields) + m.cam_mat0 = create_nworld_array( + mjm.cam_mat0.reshape(-1, 3, 3), wp.mat33, "cam_mat0" in expand_fields + ) + m.light_mode = wp.array(mjm.light_mode, dtype=wp.int32) + m.light_bodyid = wp.array(mjm.light_bodyid, dtype=wp.int32) + m.light_targetbodyid = wp.array(mjm.light_targetbodyid, dtype=wp.int32) + m.light_pos = create_nworld_array( + mjm.light_pos, wp.vec3, "light_pos" in expand_fields + ) + m.light_dir = create_nworld_array( + mjm.light_dir, wp.vec3, "light_dir" in expand_fields + ) + m.light_poscom0 = create_nworld_array( + mjm.light_poscom0, wp.vec3, "light_poscom0" in expand_fields + ) + m.light_pos0 = create_nworld_array( + mjm.light_pos0, wp.vec3, "light_pos0" in expand_fields + ) + m.light_dir0 = create_nworld_array( + mjm.light_dir0, wp.vec3, "light_dir0" in expand_fields + ) + m.dof_bodyid = wp.array(mjm.dof_bodyid, dtype=wp.int32) + m.dof_jntid = wp.array(mjm.dof_jntid, dtype=wp.int32) + m.dof_parentid = wp.array(mjm.dof_parentid, dtype=wp.int32) + m.dof_Madr = wp.array(mjm.dof_Madr, dtype=wp.int32) + m.dof_armature = create_nworld_array( + mjm.dof_armature, wp.float32, "dof_armature" in expand_fields + ) + m.dof_damping = create_nworld_array( + mjm.dof_damping, wp.float32, "dof_damping" in expand_fields + ) m.dof_tri_row = wp.from_numpy(dof_tri_row, dtype=wp.int32) m.dof_tri_col = wp.from_numpy(dof_tri_col, dtype=wp.int32) - m.dof_invweight0 = wp.array(mjm.dof_invweight0, dtype=wp.float32, ndim=1) - m.actuator_trntype = wp.array(mjm.actuator_trntype, dtype=wp.int32, ndim=1) - m.actuator_trnid = wp.array(mjm.actuator_trnid, dtype=wp.int32, ndim=2) - m.actuator_ctrllimited = wp.array(mjm.actuator_ctrllimited, dtype=wp.bool, ndim=1) - m.actuator_ctrlrange = wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2, ndim=1) - m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool, ndim=1) - m.actuator_forcerange = wp.array(mjm.actuator_forcerange, dtype=wp.vec2, ndim=1) - m.actuator_gaintype = wp.array(mjm.actuator_gaintype, dtype=wp.int32, ndim=1) - m.actuator_gainprm = wp.array(mjm.actuator_gainprm, dtype=wp.float32, ndim=2) - m.actuator_biastype = wp.array(mjm.actuator_biastype, dtype=wp.int32, ndim=1) - m.actuator_biasprm = wp.array(mjm.actuator_biasprm, dtype=wp.float32, ndim=2) - m.actuator_gear = wp.array(mjm.actuator_gear, dtype=wp.spatial_vector, ndim=1) - m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool, ndim=1) - m.actuator_actrange = wp.array(mjm.actuator_actrange, dtype=wp.vec2, ndim=1) - m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32, ndim=1) - m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32, ndim=1) - m.actuator_dynprm = wp.array(mjm.actuator_dynprm, dtype=types.vec10f, ndim=1) - m.exclude_signature = wp.array(mjm.exclude_signature, dtype=wp.int32, ndim=1) + m.dof_invweight0 = create_nworld_array( + mjm.dof_invweight0, wp.float32, "dof_invweight0" in expand_fields + ) + m.actuator_trntype = wp.array(mjm.actuator_trntype, dtype=wp.int32) + m.actuator_trnid = wp.array(mjm.actuator_trnid, dtype=wp.int32) + m.actuator_ctrllimited = wp.array(mjm.actuator_ctrllimited, dtype=wp.bool) + m.actuator_ctrlrange = create_nworld_array( + mjm.actuator_ctrlrange, wp.vec2, "actuator_ctrlrange" in expand_fields + ) + m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool) + m.actuator_forcerange = create_nworld_array( + mjm.actuator_forcerange, wp.vec2, "actuator_forcerange" in expand_fields + ) + m.actuator_gaintype = wp.array(mjm.actuator_gaintype, dtype=wp.int32) + m.actuator_gainprm = create_nworld_array( + mjm.actuator_gainprm, wp.float32, "actuator_gainprm" in expand_fields + ) + m.actuator_biastype = wp.array(mjm.actuator_biastype, dtype=wp.int32) + m.actuator_biasprm = create_nworld_array( + mjm.actuator_biasprm, wp.float32, "actuator_biasprm" in expand_fields + ) + m.actuator_gear = create_nworld_array( + mjm.actuator_gear, wp.spatial_vector, "actuator_gear" in expand_fields + ) + m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool) + m.actuator_actrange = create_nworld_array( + mjm.actuator_actrange, wp.vec2, "actuator_actrange" in expand_fields + ) + m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32) + m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32) + m.actuator_dynprm = create_nworld_array( + mjm.actuator_dynprm, types.vec10, "actuator_dynprm" in expand_fields + ) + m.exclude_signature = wp.array(mjm.exclude_signature, dtype=wp.int32) # short-circuiting here allows us to skip a lot of code in implicit integration m.actuator_affine_bias_gain = bool( @@ -413,11 +519,13 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: ) # tendon - m.tendon_adr = wp.array(mjm.tendon_adr, dtype=wp.int32, ndim=1) - m.tendon_num = wp.array(mjm.tendon_num, dtype=wp.int32, ndim=1) - m.wrap_objid = wp.array(mjm.wrap_objid, dtype=wp.int32, ndim=1) - m.wrap_prm = wp.array(mjm.wrap_prm, dtype=wp.float32, ndim=1) - m.wrap_type = wp.array(mjm.wrap_type, dtype=wp.int32, ndim=1) + m.tendon_adr = wp.array(mjm.tendon_adr, dtype=wp.int32) + m.tendon_num = wp.array(mjm.tendon_num, dtype=wp.int32) + m.wrap_objid = wp.array(mjm.wrap_objid, dtype=wp.int32) + m.wrap_prm = create_nworld_array( + mjm.wrap_prm, wp.float32, "wrap_prm" in expand_fields + ) + m.wrap_type = wp.array(mjm.wrap_type, dtype=wp.int32) tendon_jnt_adr = [] wrap_jnt_adr = [] @@ -429,33 +537,27 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: tendon_jnt_adr.append(i) wrap_jnt_adr.append(adr + j) - m.tendon_jnt_adr = wp.array(tendon_jnt_adr, dtype=wp.int32, ndim=1) - m.wrap_jnt_adr = wp.array(wrap_jnt_adr, dtype=wp.int32, ndim=1) + m.tendon_jnt_adr = wp.array(tendon_jnt_adr, dtype=wp.int32) + m.wrap_jnt_adr = wp.array(wrap_jnt_adr, dtype=wp.int32) # sensors - m.sensor_type = wp.array(mjm.sensor_type, dtype=wp.int32, ndim=1) - m.sensor_datatype = wp.array(mjm.sensor_datatype, dtype=wp.int32, ndim=1) - m.sensor_objtype = wp.array(mjm.sensor_objtype, dtype=wp.int32, ndim=1) - m.sensor_objid = wp.array(mjm.sensor_objid, dtype=wp.int32, ndim=1) - m.sensor_reftype = wp.array(mjm.sensor_reftype, dtype=wp.int32, ndim=1) - m.sensor_refid = wp.array(mjm.sensor_refid, dtype=wp.int32, ndim=1) - m.sensor_dim = wp.array(mjm.sensor_dim, dtype=wp.int32, ndim=1) - m.sensor_adr = wp.array(mjm.sensor_adr, dtype=wp.int32, ndim=1) - m.sensor_cutoff = wp.array(mjm.sensor_cutoff, dtype=wp.float32, ndim=1) + m.sensor_type = wp.array(mjm.sensor_type, dtype=wp.int32) + m.sensor_datatype = wp.array(mjm.sensor_datatype, dtype=wp.int32) + m.sensor_objtype = wp.array(mjm.sensor_objtype, dtype=wp.int32) + m.sensor_objid = wp.array(mjm.sensor_objid, dtype=wp.int32) + m.sensor_reftype = wp.array(mjm.sensor_reftype, dtype=wp.int32) + m.sensor_refid = wp.array(mjm.sensor_refid, dtype=wp.int32) + m.sensor_dim = wp.array(mjm.sensor_dim, dtype=wp.int32) + m.sensor_adr = wp.array(mjm.sensor_adr, dtype=wp.int32) + m.sensor_cutoff = wp.array(mjm.sensor_cutoff, dtype=wp.float32) m.sensor_pos_adr = wp.array( - np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS)[0], - dtype=wp.int32, - ndim=1, + np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS)[0], dtype=wp.int32 ) m.sensor_vel_adr = wp.array( - np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_VEL)[0], - dtype=wp.int32, - ndim=1, + np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_VEL)[0], dtype=wp.int32 ) m.sensor_acc_adr = wp.array( - np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_ACC)[0], - dtype=wp.int32, - ndim=1, + np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_ACC)[0], dtype=wp.int32 ) return m @@ -523,7 +625,6 @@ def make_data( mjm: mujoco.MjModel, nworld: int = 1, nconmax: int = -1, njmax: int = -1 ) -> types.Data: d = types.Data() - d.nworld = nworld # TODO(team): move to Model? if nconmax == -1: @@ -534,7 +635,6 @@ def make_data( # TODO(team): heuristic for njmax njmax = 512 d.njmax = njmax - d.ncon = wp.zeros(1, dtype=wp.int32) d.nefc = wp.zeros(1, dtype=wp.int32, ndim=1) d.nl = 0 @@ -597,7 +697,7 @@ def make_data( d.contact.geom = wp.zeros((nconmax,), dtype=wp.vec2i) d.contact.efc_address = wp.zeros((nconmax, np.max(mjm.geom_condim)), dtype=wp.int32) d.contact.worldid = wp.zeros((nconmax,), dtype=wp.int32) - d.efc = _constraint(mjm, d.nworld, d.njmax) + d.efc = _constraint(mjm, nworld, d.njmax) d.qfrc_passive = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qfrc_spring = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qfrc_damper = wp.zeros((nworld, mjm.nv), dtype=wp.float32) @@ -606,8 +706,8 @@ def make_data( d.qfrc_constraint = wp.zeros((nworld, mjm.nv), dtype=wp.float32) d.qacc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.rne_cacc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector) - d.rne_cfrc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector) + d.rne_cacc = wp.zeros(shape=(nworld, mjm.nbody), dtype=wp.spatial_vector) + d.rne_cfrc = wp.zeros(shape=(nworld, mjm.nbody), dtype=wp.spatial_vector) d.xfrc_applied = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) @@ -674,7 +774,6 @@ def put_data( if nworld * mjd.nefc > njmax: raise ValueError(f"njmax overflow (njmax must be >= {nworld * mjd.nefc})") - d.nworld = nworld # TODO(team): move nconmax and njmax to Model? d.nconmax = nconmax d.njmax = njmax @@ -684,10 +783,6 @@ def put_data( d.nefc = wp.array([mjd.nefc * nworld], dtype=wp.int32, ndim=1) d.time = mjd.time - # TODO(erikfrey): would it be better to tile on the gpu? - def tile(x): - return np.tile(x, (nworld,) + (1,) * len(x.shape)) - if support.is_sparse(mjm): qM = np.expand_dims(mjd.qM, axis=0) qLD = np.expand_dims(mjd.qLD, axis=0) @@ -711,52 +806,64 @@ def tile(x): mjd.moment_colind, ) - d.qpos = wp.array(tile(mjd.qpos), dtype=wp.float32, ndim=2) - d.qvel = wp.array(tile(mjd.qvel), dtype=wp.float32, ndim=2) - d.qacc_warmstart = wp.array(tile(mjd.qacc_warmstart), dtype=wp.float32, ndim=2) - d.qfrc_applied = wp.array(tile(mjd.qfrc_applied), dtype=wp.float32, ndim=2) - d.mocap_pos = wp.array(tile(mjd.mocap_pos), dtype=wp.vec3, ndim=2) - d.mocap_quat = wp.array(tile(mjd.mocap_quat), dtype=wp.quat, ndim=2) - d.qacc = wp.array(tile(mjd.qacc), dtype=wp.float32, ndim=2) - d.xanchor = wp.array(tile(mjd.xanchor), dtype=wp.vec3, ndim=2) - d.xaxis = wp.array(tile(mjd.xaxis), dtype=wp.vec3, ndim=2) - d.xmat = wp.array(tile(mjd.xmat), dtype=wp.mat33, ndim=2) - d.xpos = wp.array(tile(mjd.xpos), dtype=wp.vec3, ndim=2) - d.xquat = wp.array(tile(mjd.xquat), dtype=wp.quat, ndim=2) - d.xipos = wp.array(tile(mjd.xipos), dtype=wp.vec3, ndim=2) - d.ximat = wp.array(tile(mjd.ximat), dtype=wp.mat33, ndim=2) - d.subtree_com = wp.array(tile(mjd.subtree_com), dtype=wp.vec3, ndim=2) - d.geom_xpos = wp.array(tile(mjd.geom_xpos), dtype=wp.vec3, ndim=2) - d.geom_xmat = wp.array(tile(mjd.geom_xmat), dtype=wp.mat33, ndim=2) - d.site_xpos = wp.array(tile(mjd.site_xpos), dtype=wp.vec3, ndim=2) - d.site_xmat = wp.array(tile(mjd.site_xmat), dtype=wp.mat33, ndim=2) - d.cam_xpos = wp.array(tile(mjd.cam_xpos), dtype=wp.vec3, ndim=2) - d.cam_xmat = wp.array(tile(mjd.cam_xmat.reshape(-1, 3, 3)), dtype=wp.mat33, ndim=2) - d.light_xpos = wp.array(tile(mjd.light_xpos), dtype=wp.vec3, ndim=2) - d.light_xdir = wp.array(tile(mjd.light_xdir), dtype=wp.vec3, ndim=2) - d.cinert = wp.array(tile(mjd.cinert), dtype=types.vec10, ndim=2) - d.cdof = wp.array(tile(mjd.cdof), dtype=wp.spatial_vector, ndim=2) - d.crb = wp.array(tile(mjd.crb), dtype=types.vec10, ndim=2) - d.qM = wp.array(tile(qM), dtype=wp.float32, ndim=3) - d.qLD = wp.array(tile(qLD), dtype=wp.float32, ndim=3) - d.qLDiagInv = wp.array(tile(mjd.qLDiagInv), dtype=wp.float32, ndim=2) - d.ctrl = wp.array(tile(mjd.ctrl), dtype=wp.float32, ndim=2) - d.actuator_velocity = wp.array(tile(mjd.actuator_velocity), dtype=wp.float32, ndim=2) - d.actuator_force = wp.array(tile(mjd.actuator_force), dtype=wp.float32, ndim=2) - d.actuator_length = wp.array(tile(mjd.actuator_length), dtype=wp.float32, ndim=2) - d.actuator_moment = wp.array(tile(actuator_moment), dtype=wp.float32, ndim=3) - d.cvel = wp.array(tile(mjd.cvel), dtype=wp.spatial_vector, ndim=2) - d.cdof_dot = wp.array(tile(mjd.cdof_dot), dtype=wp.spatial_vector, ndim=2) - d.qfrc_bias = wp.array(tile(mjd.qfrc_bias), dtype=wp.float32, ndim=2) - d.qfrc_passive = wp.array(tile(mjd.qfrc_passive), dtype=wp.float32, ndim=2) - d.qfrc_spring = wp.array(tile(mjd.qfrc_spring), dtype=wp.float32, ndim=2) - d.qfrc_damper = wp.array(tile(mjd.qfrc_damper), dtype=wp.float32, ndim=2) - d.qfrc_actuator = wp.array(tile(mjd.qfrc_actuator), dtype=wp.float32, ndim=2) - d.qfrc_smooth = wp.array(tile(mjd.qfrc_smooth), dtype=wp.float32, ndim=2) - d.qfrc_constraint = wp.array(tile(mjd.qfrc_constraint), dtype=wp.float32, ndim=2) - d.qacc_smooth = wp.array(tile(mjd.qacc_smooth), dtype=wp.float32, ndim=2) - d.act = wp.array(tile(mjd.act), dtype=wp.float32, ndim=2) - d.act_dot = wp.array(tile(mjd.act_dot), dtype=wp.float32, ndim=2) + d.qpos = wp.array(tile(mjd.qpos, nworld), dtype=wp.float32, ndim=2) + d.qvel = wp.array(tile(mjd.qvel, nworld), dtype=wp.float32, ndim=2) + d.qacc_warmstart = wp.array( + tile(mjd.qacc_warmstart, nworld), dtype=wp.float32, ndim=2 + ) + d.qfrc_applied = wp.array(tile(mjd.qfrc_applied, nworld), dtype=wp.float32, ndim=2) + d.mocap_pos = wp.array(tile(mjd.mocap_pos, nworld), dtype=wp.vec3, ndim=2) + d.mocap_quat = wp.array(tile(mjd.mocap_quat, nworld), dtype=wp.quat, ndim=2) + d.qacc = wp.array(tile(mjd.qacc, nworld), dtype=wp.float32, ndim=2) + d.xanchor = wp.array(tile(mjd.xanchor, nworld), dtype=wp.vec3, ndim=2) + d.xaxis = wp.array(tile(mjd.xaxis, nworld), dtype=wp.vec3, ndim=2) + d.xmat = wp.array(tile(mjd.xmat, nworld), dtype=wp.mat33, ndim=2) + d.xpos = wp.array(tile(mjd.xpos, nworld), dtype=wp.vec3, ndim=2) + d.xquat = wp.array(tile(mjd.xquat, nworld), dtype=wp.quat, ndim=2) + d.xipos = wp.array(tile(mjd.xipos, nworld), dtype=wp.vec3, ndim=2) + d.ximat = wp.array(tile(mjd.ximat, nworld), dtype=wp.mat33, ndim=2) + d.subtree_com = wp.array(tile(mjd.subtree_com, nworld), dtype=wp.vec3, ndim=2) + d.geom_xpos = wp.array(tile(mjd.geom_xpos, nworld), dtype=wp.vec3, ndim=2) + d.geom_xmat = wp.array(tile(mjd.geom_xmat, nworld), dtype=wp.mat33, ndim=2) + d.site_xpos = wp.array(tile(mjd.site_xpos, nworld), dtype=wp.vec3, ndim=2) + d.site_xmat = wp.array(tile(mjd.site_xmat, nworld), dtype=wp.mat33, ndim=2) + d.cam_xpos = wp.array(tile(mjd.cam_xpos, nworld), dtype=wp.vec3, ndim=2) + d.cam_xmat = wp.array( + tile(mjd.cam_xmat.reshape(-1, 3, 3), nworld), dtype=wp.mat33, ndim=2 + ) + d.light_xpos = wp.array(tile(mjd.light_xpos, nworld), dtype=wp.vec3, ndim=2) + d.light_xdir = wp.array(tile(mjd.light_xdir, nworld), dtype=wp.vec3, ndim=2) + d.cinert = wp.array(tile(mjd.cinert, nworld), dtype=types.vec10, ndim=2) + d.cdof = wp.array(tile(mjd.cdof, nworld), dtype=wp.spatial_vector, ndim=2) + d.crb = wp.array(tile(mjd.crb, nworld), dtype=types.vec10, ndim=2) + d.qM = wp.array(tile(qM, nworld), dtype=wp.float32, ndim=3) + d.qLD = wp.array(tile(qLD, nworld), dtype=wp.float32, ndim=3) + d.qLDiagInv = wp.array(tile(mjd.qLDiagInv, nworld), dtype=wp.float32, ndim=2) + d.ctrl = wp.array(tile(mjd.ctrl, nworld), dtype=wp.float32, ndim=2) + d.actuator_velocity = wp.array( + tile(mjd.actuator_velocity, nworld), dtype=wp.float32, ndim=2 + ) + d.actuator_force = wp.array( + tile(mjd.actuator_force, nworld), dtype=wp.float32, ndim=2 + ) + d.actuator_length = wp.array( + tile(mjd.actuator_length, nworld), dtype=wp.float32, ndim=2 + ) + d.actuator_moment = wp.array(tile(actuator_moment, nworld), dtype=wp.float32, ndim=3) + d.cvel = wp.array(tile(mjd.cvel, nworld), dtype=wp.spatial_vector, ndim=2) + d.cdof_dot = wp.array(tile(mjd.cdof_dot, nworld), dtype=wp.spatial_vector, ndim=2) + d.qfrc_bias = wp.array(tile(mjd.qfrc_bias, nworld), dtype=wp.float32, ndim=2) + d.qfrc_passive = wp.array(tile(mjd.qfrc_passive, nworld), dtype=wp.float32, ndim=2) + d.qfrc_spring = wp.array(tile(mjd.qfrc_spring, nworld), dtype=wp.float32, ndim=2) + d.qfrc_damper = wp.array(tile(mjd.qfrc_damper, nworld), dtype=wp.float32, ndim=2) + d.qfrc_actuator = wp.array(tile(mjd.qfrc_actuator, nworld), dtype=wp.float32, ndim=2) + d.qfrc_smooth = wp.array(tile(mjd.qfrc_smooth, nworld), dtype=wp.float32, ndim=2) + d.qfrc_constraint = wp.array( + tile(mjd.qfrc_constraint, nworld), dtype=wp.float32, ndim=2 + ) + d.qacc_smooth = wp.array(tile(mjd.qacc_smooth, nworld), dtype=wp.float32, ndim=2) + d.act = wp.array(tile(mjd.act, nworld), dtype=wp.float32, ndim=2) + d.act_dot = wp.array(tile(mjd.act_dot, nworld), dtype=wp.float32, ndim=2) nefc = mjd.nefc efc_worldid = np.zeros(njmax, dtype=int) @@ -845,10 +952,10 @@ def tile(x): d.contact.efc_address = wp.array(con_efc_address_fill, dtype=wp.int32, ndim=2) d.contact.worldid = wp.array(con_worldid, dtype=wp.int32, ndim=1) - d.rne_cacc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector) - d.rne_cfrc = wp.zeros(shape=(d.nworld, mjm.nbody), dtype=wp.spatial_vector) + d.rne_cacc = wp.zeros(shape=(nworld, mjm.nbody), dtype=wp.spatial_vector) + d.rne_cfrc = wp.zeros(shape=(nworld, mjm.nbody), dtype=wp.spatial_vector) - d.efc = _constraint(mjm, d.nworld, d.njmax) + d.efc = _constraint(mjm, nworld, d.njmax) d.efc.J = wp.array(efc_J_fill, dtype=wp.float32, ndim=2) d.efc.D = wp.array(efc_D_fill, dtype=wp.float32, ndim=1) d.efc.pos = wp.array(efc_pos_fill, dtype=wp.float32, ndim=1) @@ -857,7 +964,9 @@ def tile(x): d.efc.margin = wp.array(efc_margin_fill, dtype=wp.float32, ndim=1) d.efc.worldid = wp.from_numpy(efc_worldid, dtype=wp.int32) - d.xfrc_applied = wp.array(tile(mjd.xfrc_applied), dtype=wp.spatial_vector, ndim=2) + d.xfrc_applied = wp.array( + tile(mjd.xfrc_applied, nworld), dtype=wp.spatial_vector, ndim=2 + ) # internal tmp arrays d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) @@ -883,7 +992,7 @@ def tile(x): d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) # tendon - d.ten_length = wp.array(tile(mjd.ten_length), dtype=wp.float32, ndim=2) + d.ten_length = wp.array(tile(mjd.ten_length, nworld), dtype=wp.float32, ndim=2) if support.is_sparse(mjm) and mjm.ntendon: ten_J = np.zeros((mjm.ntendon, mjm.nv)) @@ -893,10 +1002,10 @@ def tile(x): else: ten_J = mjd.ten_J.reshape((mjm.ntendon, mjm.nv)) - d.ten_J = wp.array(tile(ten_J), dtype=wp.float32, ndim=3) + d.ten_J = wp.array(tile(ten_J, nworld), dtype=wp.float32, ndim=3) # sensors - d.sensordata = wp.array(tile(mjd.sensordata), dtype=wp.float32, ndim=2) + d.sensordata = wp.array(tile(mjd.sensordata, nworld), dtype=wp.float32, ndim=2) return d @@ -905,9 +1014,10 @@ def get_data_into( result: mujoco.MjData, mjm: mujoco.MjModel, d: types.Data, + nworld: int = 1, ): """Gets Data from a device into an existing mujoco.MjData.""" - if d.nworld > 1: + if nworld > 1: raise NotImplementedError("only nworld == 1 supported for now") ncon = d.ncon.numpy()[0] diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index 4f4f6c32..8dbc1ad8 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -23,6 +23,18 @@ import mujoco_warp as mjwarp +from . import test_util + +# tolerance for difference between MuJoCo and MJWarp smooth calculations - mostly +# due to float precision +_TOLERANCE = 5e-5 + + +def _assert_eq(a, b, name): + tol = _TOLERANCE * 10 # avoid test noise + err_msg = f"mismatch: {name}" + np.testing.assert_allclose(a, b, err_msg=err_msg, atol=tol, rtol=tol) + class IOTest(absltest.TestCase): def test_equality(self): @@ -277,6 +289,33 @@ def test_option_physical_constants(self): with self.assertRaises(NotImplementedError): mjwarp.put_model(mjm) + def test_model_batching(self): + mjm, mjd, _, _ = test_util.fixture("humanoid/humanoid.xml") + + m = mjwarp.put_model(mjm, nworld=2, expand_fields={"dof_damping"}) + d = mjwarp.put_data(mjm, mjd, nworld=2) + + self.assertEqual(m.nworld, 2) + + # randomize dof_damping + dof_damping = m.dof_damping.numpy() + dof_damping[1, :] *= 0.5 + m.dof_damping = wp.from_numpy(dof_damping, dtype=wp.float32) + + mjwarp.passive(m, d) + + # mujoco reference, just have 2 separate model/data strctures + mujoco.mj_passive(mjm, mjd) + + m2, d2, _, _ = test_util.fixture("humanoid/humanoid.xml") + d2.qvel = mjd.qvel # need to copy qvel because of randomization + m2.dof_damping *= 0.5 + + mujoco.mj_passive(m2, d2) + + _assert_eq(d.qfrc_damper.numpy()[0, :], mjd.qfrc_damper, "qfrc_damper") + _assert_eq(d.qfrc_damper.numpy()[1, :], d2.qfrc_damper, "qfrc_damper") + if __name__ == "__main__": wp.init() diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index c8cbc30b..353c2fc2 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -35,7 +35,7 @@ def passive(m: Model, d: Data): @kernel def _spring(m: Model, d: Data): worldid, jntid = wp.tid() - stiffness = m.jnt_stiffness[jntid] + stiffness = m.jnt_stiffness[worldid, jntid] dofid = m.jnt_dofadr[jntid] if stiffness == 0.0: @@ -46,9 +46,9 @@ def _spring(m: Model, d: Data): if jnt_type == wp.static(JointType.FREE.value): dif = wp.vec3( - d.qpos[worldid, qposid + 0] - m.qpos_spring[qposid + 0], - d.qpos[worldid, qposid + 1] - m.qpos_spring[qposid + 1], - d.qpos[worldid, qposid + 2] - m.qpos_spring[qposid + 2], + d.qpos[worldid, qposid + 0] - m.qpos_spring[worldid, qposid + 0], + d.qpos[worldid, qposid + 1] - m.qpos_spring[worldid, qposid + 1], + d.qpos[worldid, qposid + 2] - m.qpos_spring[worldid, qposid + 2], ) d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] @@ -60,10 +60,10 @@ def _spring(m: Model, d: Data): d.qpos[worldid, qposid + 6], ) ref = wp.quat( - m.qpos_spring[qposid + 3], - m.qpos_spring[qposid + 4], - m.qpos_spring[qposid + 5], - m.qpos_spring[qposid + 6], + m.qpos_spring[worldid, qposid + 3], + m.qpos_spring[worldid, qposid + 4], + m.qpos_spring[worldid, qposid + 5], + m.qpos_spring[worldid, qposid + 6], ) dif = math.quat_sub(rot, ref) d.qfrc_spring[worldid, dofid + 3] = -stiffness * dif[0] @@ -77,23 +77,23 @@ def _spring(m: Model, d: Data): d.qpos[worldid, qposid + 3], ) ref = wp.quat( - m.qpos_spring[qposid + 0], - m.qpos_spring[qposid + 1], - m.qpos_spring[qposid + 2], - m.qpos_spring[qposid + 3], + m.qpos_spring[worldid, qposid + 0], + m.qpos_spring[worldid, qposid + 1], + m.qpos_spring[worldid, qposid + 2], + m.qpos_spring[worldid, qposid + 3], ) dif = math.quat_sub(rot, ref) d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] d.qfrc_spring[worldid, dofid + 2] = -stiffness * dif[2] else: # mjJNT_SLIDE, mjJNT_HINGE - fdif = d.qpos[worldid, qposid] - m.qpos_spring[qposid] + fdif = d.qpos[worldid, qposid] - m.qpos_spring[worldid, qposid] d.qfrc_spring[worldid, dofid] = -stiffness * fdif @kernel def _damper_passive(m: Model, d: Data): worldid, dofid = wp.tid() - damping = m.dof_damping[dofid] + damping = m.dof_damping[worldid, dofid] qfrc_damper = -damping * d.qvel[worldid, dofid] d.qfrc_damper[worldid, dofid] = qfrc_damper @@ -104,5 +104,5 @@ def _damper_passive(m: Model, d: Data): # TODO(team): mj_inertiaBoxFluidModell d.qfrc_spring.zero_() - wp.launch(_spring, dim=(d.nworld, m.njnt), inputs=[m, d]) - wp.launch(_damper_passive, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(_spring, dim=(m.nworld, m.njnt), inputs=[m, d]) + wp.launch(_damper_passive, dim=(m.nworld, m.nv), inputs=[m, d]) diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index 8d496bb2..541385d5 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -46,7 +46,7 @@ def _sensor_pos(m: Model, d: Data): if (m.sensor_pos_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): return - wp.launch(_sensor_pos, dim=(d.nworld, m.sensor_pos_adr.size), inputs=[m, d]) + wp.launch(_sensor_pos, dim=(m.nworld, m.sensor_pos_adr.size), inputs=[m, d]) @wp.func @@ -72,7 +72,7 @@ def _sensor_vel(m: Model, d: Data): if (m.sensor_vel_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): return - wp.launch(_sensor_vel, dim=(d.nworld, m.sensor_vel_adr.size), inputs=[m, d]) + wp.launch(_sensor_vel, dim=(m.nworld, m.sensor_vel_adr.size), inputs=[m, d]) @wp.func @@ -98,4 +98,4 @@ def _sensor_acc(m: Model, d: Data): if (m.sensor_acc_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): return - wp.launch(_sensor_acc, dim=(d.nworld, m.sensor_acc_adr.size), inputs=[m, d]) + wp.launch(_sensor_acc, dim=(m.nworld, m.sensor_acc_adr.size), inputs=[m, d]) diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index cf88fc88..adb7a303 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -56,8 +56,8 @@ def _level(m: Model, d: Data, leveladr: int): if jntnum == 0: # no joints - apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] - xpos = (d.xmat[worldid, pid] * m.body_pos[bodyid]) + d.xpos[worldid, pid] - xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[bodyid]) + xpos = (d.xmat[worldid, pid] * m.body_pos[worldid, bodyid]) + d.xpos[worldid, pid] + xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[worldid, bodyid]) elif jntnum == 1 and m.jnt_type[jntadr] == wp.static(JointType.FREE.value): # free joint qadr = m.jnt_qposadr[jntadr] @@ -69,8 +69,8 @@ def _level(m: Model, d: Data, leveladr: int): # regular or no joints # apply fixed translation and rotation relative to parent pid = m.body_parentid[bodyid] - xpos = (d.xmat[worldid, pid] * m.body_pos[bodyid]) + d.xpos[worldid, pid] - xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[bodyid]) + xpos = (d.xmat[worldid, pid] * m.body_pos[worldid, bodyid]) + d.xpos[worldid, pid] + xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[worldid, bodyid]) for _ in range(jntnum): qadr = m.jnt_qposadr[jntadr] @@ -90,9 +90,9 @@ def _level(m: Model, d: Data, leveladr: int): # correct for off-center rotation xpos = xanchor - math.rot_vec_quat(m.jnt_pos[jntadr], xquat) elif jnt_type == wp.static(JointType.SLIDE.value): - xpos += xaxis * (qpos[qadr] - m.qpos0[qadr]) + xpos += xaxis * (qpos[qadr] - m.qpos0[worldid, qadr]) elif jnt_type == wp.static(JointType.HINGE.value): - qpos0 = m.qpos0[qadr] + qpos0 = m.qpos0[worldid, qadr] qloc = math.axis_angle_to_quat(jnt_axis, qpos[qadr] - qpos0) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation @@ -106,9 +106,11 @@ def _level(m: Model, d: Data, leveladr: int): xquat = wp.normalize(xquat) d.xquat[worldid, bodyid] = xquat d.xmat[worldid, bodyid] = math.quat_to_mat(xquat) - d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat(m.body_ipos[bodyid], xquat) + d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat( + m.body_ipos[worldid, bodyid], xquat + ) d.ximat[worldid, bodyid] = math.quat_to_mat( - math.mul_quat(xquat, m.body_iquat[bodyid]) + math.mul_quat(xquat, m.body_iquat[worldid, bodyid]) ) @kernel @@ -117,9 +119,11 @@ def geom_local_to_global(m: Model, d: Data): bodyid = m.geom_bodyid[geomid] xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] - d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat(m.geom_pos[geomid], xquat) + d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat( + m.geom_pos[worldid, geomid], xquat + ) d.geom_xmat[worldid, geomid] = math.quat_to_mat( - math.mul_quat(xquat, m.geom_quat[geomid]) + math.mul_quat(xquat, m.geom_quat[worldid, geomid]) ) @kernel @@ -128,24 +132,26 @@ def site_local_to_global(m: Model, d: Data): bodyid = m.site_bodyid[siteid] xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] - d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat(m.site_pos[siteid], xquat) + d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat( + m.site_pos[worldid, siteid], xquat + ) d.site_xmat[worldid, siteid] = math.quat_to_mat( - math.mul_quat(xquat, m.site_quat[siteid]) + math.mul_quat(xquat, m.site_quat[worldid, siteid]) ) - wp.launch(_root, dim=(d.nworld), inputs=[m, d]) + wp.launch(_root, dim=(m.nworld), inputs=[m, d]) body_treeadr = m.body_treeadr.numpy() for i in range(1, len(body_treeadr)): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch(_level, dim=(m.nworld, end - beg), inputs=[m, d, beg]) if m.ngeom: - wp.launch(geom_local_to_global, dim=(d.nworld, m.ngeom), inputs=[m, d]) + wp.launch(geom_local_to_global, dim=(m.nworld, m.ngeom), inputs=[m, d]) if m.nsite: - wp.launch(site_local_to_global, dim=(d.nworld, m.nsite), inputs=[m, d]) + wp.launch(site_local_to_global, dim=(m.nworld, m.nsite), inputs=[m, d]) @event_scope @@ -155,7 +161,9 @@ def com_pos(m: Model, d: Data): @kernel def subtree_com_init(m: Model, d: Data): worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * m.body_mass[bodyid] + d.subtree_com[worldid, bodyid] = ( + d.xipos[worldid, bodyid] * m.body_mass[worldid, bodyid] + ) @kernel def subtree_com_acc(m: Model, d: Data, leveladr: int): @@ -167,14 +175,14 @@ def subtree_com_acc(m: Model, d: Data, leveladr: int): @kernel def subtree_div(m: Model, d: Data): worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] /= m.subtree_mass[bodyid] + d.subtree_com[worldid, bodyid] /= m.subtree_mass[worldid, bodyid] @kernel def cinert(m: Model, d: Data): worldid, bodyid = wp.tid() mat = d.ximat[worldid, bodyid] - inert = m.body_inertia[bodyid] - mass = m.body_mass[bodyid] + inert = m.body_inertia[worldid, bodyid] + mass = m.body_mass[worldid, bodyid] dif = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]] # express inertia in com-based frame (mju_inertCom) @@ -234,18 +242,18 @@ def cdof(m: Model, d: Data): elif jnt_type == wp.static(JointType.HINGE.value): # hinge res[dofid] = wp.spatial_vector(xaxis, wp.cross(xaxis, offset)) - wp.launch(subtree_com_init, dim=(d.nworld, m.nbody), inputs=[m, d]) + wp.launch(subtree_com_init, dim=(m.nworld, m.nbody), inputs=[m, d]) body_treeadr = m.body_treeadr.numpy() for i in reversed(range(len(body_treeadr))): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(subtree_com_acc, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch(subtree_com_acc, dim=(m.nworld, end - beg), inputs=[m, d, beg]) - wp.launch(subtree_div, dim=(d.nworld, m.nbody), inputs=[m, d]) - wp.launch(cinert, dim=(d.nworld, m.nbody), inputs=[m, d]) - wp.launch(cdof, dim=(d.nworld, m.njnt), inputs=[m, d]) + wp.launch(subtree_div, dim=(m.nworld, m.nbody), inputs=[m, d]) + wp.launch(cinert, dim=(m.nworld, m.nbody), inputs=[m, d]) + wp.launch(cdof, dim=(m.nworld, m.njnt), inputs=[m, d]) @event_scope @@ -259,9 +267,11 @@ def cam_local_to_global(m: Model, d: Data): bodyid = m.cam_bodyid[camid] xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] - d.cam_xpos[worldid, camid] = xpos + math.rot_vec_quat(m.cam_pos[camid], xquat) + d.cam_xpos[worldid, camid] = xpos + math.rot_vec_quat( + m.cam_pos[worldid, camid], xquat + ) d.cam_xmat[worldid, camid] = math.quat_to_mat( - math.mul_quat(xquat, m.cam_quat[camid]) + math.mul_quat(xquat, m.cam_quat[worldid, camid]) ) @kernel @@ -275,10 +285,10 @@ def cam_fn(m: Model, d: Data): return elif m.cam_mode[camid] == wp.static(CamLightType.TRACK.value): body_xpos = d.xpos[worldid, m.cam_bodyid[camid]] - d.cam_xpos[worldid, camid] = body_xpos + m.cam_pos0[camid] + d.cam_xpos[worldid, camid] = body_xpos + m.cam_pos0[worldid, camid] elif m.cam_mode[camid] == wp.static(CamLightType.TRACKCOM.value): d.cam_xpos[worldid, camid] = ( - d.subtree_com[worldid, m.cam_bodyid[camid]] + m.cam_poscom0[camid] + d.subtree_com[worldid, m.cam_bodyid[camid]] + m.cam_poscom0[worldid, camid] ) elif m.cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value) or m.cam_mode[ camid @@ -307,9 +317,11 @@ def light_local_to_global(m: Model, d: Data): xpos = d.xpos[worldid, bodyid] xquat = d.xquat[worldid, bodyid] d.light_xpos[worldid, lightid] = xpos + math.rot_vec_quat( - m.light_pos[lightid], xquat + m.light_pos[worldid, lightid], xquat + ) + d.light_xdir[worldid, lightid] = math.rot_vec_quat( + m.light_dir[worldid, lightid], xquat ) - d.light_xdir[worldid, lightid] = math.rot_vec_quat(m.light_dir[lightid], xquat) @kernel def light_fn(m: Model, d: Data): @@ -322,10 +334,11 @@ def light_fn(m: Model, d: Data): return elif m.light_mode[lightid] == wp.static(CamLightType.TRACK.value): body_xpos = d.xpos[worldid, m.light_bodyid[lightid]] - d.light_xpos[worldid, lightid] = body_xpos + m.light_pos0[lightid] + d.light_xpos[worldid, lightid] = body_xpos + m.light_pos0[worldid, lightid] elif m.light_mode[lightid] == wp.static(CamLightType.TRACKCOM.value): d.light_xpos[worldid, lightid] = ( - d.subtree_com[worldid, m.light_bodyid[lightid]] + m.light_poscom0[lightid] + d.subtree_com[worldid, m.light_bodyid[lightid]] + + m.light_poscom0[worldid, lightid] ) elif m.light_mode[lightid] == wp.static( CamLightType.TARGETBODY.value @@ -337,11 +350,11 @@ def light_fn(m: Model, d: Data): d.light_xdir[worldid, lightid] = wp.normalize(d.light_xdir[worldid, lightid]) if m.ncam > 0: - wp.launch(cam_local_to_global, dim=(d.nworld, m.ncam), inputs=[m, d]) - wp.launch(cam_fn, dim=(d.nworld, m.ncam), inputs=[m, d]) + wp.launch(cam_local_to_global, dim=(m.nworld, m.ncam), inputs=[m, d]) + wp.launch(cam_fn, dim=(m.nworld, m.ncam), inputs=[m, d]) if m.nlight > 0: - wp.launch(light_local_to_global, dim=(d.nworld, m.nlight), inputs=[m, d]) - wp.launch(light_fn, dim=(d.nworld, m.nlight), inputs=[m, d]) + wp.launch(light_local_to_global, dim=(m.nworld, m.nlight), inputs=[m, d]) + wp.launch(light_fn, dim=(m.nworld, m.nlight), inputs=[m, d]) @event_scope @@ -366,7 +379,7 @@ def qM_sparse(m: Model, d: Data): bodyid = m.dof_bodyid[dofid] # init M(i,i) with armature inertia - d.qM[worldid, 0, madr_ij] = m.dof_armature[dofid] + d.qM[worldid, 0, madr_ij] = m.dof_armature[worldid, dofid] # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) @@ -383,7 +396,7 @@ def qM_dense(m: Model, d: Data): bodyid = m.dof_bodyid[dofid] # init M(i,i) with armature inertia - M = m.dof_armature[dofid] + M = m.dof_armature[worldid, dofid] # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) @@ -404,13 +417,13 @@ def qM_dense(m: Model, d: Data): for i in reversed(range(len(body_treeadr))): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(crb_accumulate, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch(crb_accumulate, dim=(m.nworld, end - beg), inputs=[m, d, beg]) d.qM.zero_() if m.opt.is_sparse: - wp.launch(qM_sparse, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(qM_sparse, dim=(m.nworld, m.nv), inputs=[m, d]) else: - wp.launch(qM_dense, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(qM_dense, dim=(m.nworld, m.nv), inputs=[m, d]) def _factor_i_sparse_legacy(m: Model, d: Data, M: array3df, L: array3df, D: array2df): @@ -444,9 +457,9 @@ def qLDiag_div(m: Model, L: array3df, D: array2df): beg, end = qLD_update_treeadr[i], m.qLD_update_tree.shape[0] else: beg, end = qLD_update_treeadr[i], qLD_update_treeadr[i + 1] - wp.launch(qLD_acc, dim=(d.nworld, end - beg), inputs=[m, beg, L]) + wp.launch(qLD_acc, dim=(m.nworld, end - beg), inputs=[m, beg, L]) - wp.launch(qLDiag_div, dim=(d.nworld, m.nv), inputs=[m, L, D]) + wp.launch(qLDiag_div, dim=(m.nworld, m.nv), inputs=[m, L, D]) def _factor_i_sparse(m: Model, d: Data, M: array3df, L: array3df, D: array2df): @@ -482,7 +495,7 @@ def copy_CSR(L: array3df, M: array3df, mapM2M: wp.array(dtype=wp.int32, ndim=1)) worldid, ind = wp.tid() L[worldid, 0, ind] = M[worldid, 0, mapM2M[ind]] - wp.launch(copy_CSR, dim=(d.nworld, m.nM), inputs=[L, M, m.mapM2M]) + wp.launch(copy_CSR, dim=(m.nworld, m.nM), inputs=[L, M, m.mapM2M]) qLD_update_treeadr = m.qLD_update_treeadr.numpy() @@ -491,9 +504,9 @@ def copy_CSR(L: array3df, M: array3df, mapM2M: wp.array(dtype=wp.int32, ndim=1)) beg, end = qLD_update_treeadr[i], m.qLD_update_tree.shape[0] else: beg, end = qLD_update_treeadr[i], qLD_update_treeadr[i + 1] - wp.launch(qLD_acc, dim=(d.nworld, end - beg), inputs=[m, beg, L]) + wp.launch(qLD_acc, dim=(m.nworld, end - beg), inputs=[m, beg, L]) - wp.launch(qLDiag_div, dim=(d.nworld, m.nv), inputs=[m, L, D]) + wp.launch(qLDiag_div, dim=(m.nworld, m.nv), inputs=[m, L, D]) def _factor_i_dense(m: Model, d: Data, M: wp.array, L: wp.array): @@ -514,7 +527,7 @@ def cholesky(m: Model, leveladr: int, M: array3df, L: array3df): wp.tile_store(L[worldid], L_tile, offset=(dofid, dofid)) wp.launch_tiled( - cholesky, dim=(d.nworld, size), inputs=[m, adr, M, L], block_dim=block_dim + cholesky, dim=(m.nworld, size), inputs=[m, adr, M, L], block_dim=block_dim ) qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() @@ -597,22 +610,22 @@ def qfrc_bias(m: Model, d: Data): d.cdof[worldid, dofid], d.rne_cfrc[worldid, bodyid] ) - wp.launch(cacc_world, dim=[d.nworld], inputs=[m, d]) + wp.launch(cacc_world, dim=[m.nworld], inputs=[m, d]) body_treeadr = m.body_treeadr.numpy() for i in range(len(body_treeadr)): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(cacc_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch(cacc_level, dim=(m.nworld, end - beg), inputs=[m, d, beg]) - wp.launch(frc_fn, dim=[d.nworld, m.nbody], inputs=[d]) + wp.launch(frc_fn, dim=[m.nworld, m.nbody], inputs=[d]) for i in reversed(range(len(body_treeadr))): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(cfrc_fn, dim=[d.nworld, end - beg], inputs=[m, d, beg]) + wp.launch(cfrc_fn, dim=[m.nworld, end - beg], inputs=[m, d, beg]) - wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d]) + wp.launch(qfrc_bias, dim=[m.nworld, m.nv], inputs=[m, d]) @event_scope @@ -636,7 +649,7 @@ def _transmission( qadr = m.jnt_qposadr[jntid] vadr = m.jnt_dofadr[jntid] trntype = m.actuator_trntype[actid] - gear = m.actuator_gear[actid] + gear = m.actuator_gear[worldid, actid] if trntype == wp.static(TrnType.JOINT.value) or trntype == wp.static( TrnType.JOINTINPARENT.value ): @@ -679,7 +692,7 @@ def _transmission( wp.launch( _transmission, - dim=[d.nworld, m.nu], + dim=[m.nworld, m.nu], inputs=[m, d], outputs=[d.actuator_length, d.actuator_moment], ) @@ -746,13 +759,13 @@ def _level(m: Model, d: Data, leveladr: int): d.cvel[worldid, bodyid] = cvel - wp.launch(_root, dim=(d.nworld, 6), inputs=[d]) + wp.launch(_root, dim=(m.nworld, 6), inputs=[d]) body_treeadr = m.body_treeadr.numpy() for i in range(1, len(body_treeadr)): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch(_level, dim=(m.nworld, end - beg), inputs=[m, d, beg]) def _solve_LD_sparse( @@ -788,16 +801,16 @@ def x_acc_down(m: Model, L: array3df, x: array2df, leveladr: int): beg, end = qLD_update_treeadr[i], m.qLD_update_tree.shape[0] else: beg, end = qLD_update_treeadr[i], qLD_update_treeadr[i + 1] - wp.launch(x_acc_up, dim=(d.nworld, end - beg), inputs=[m, L, x, beg]) + wp.launch(x_acc_up, dim=(m.nworld, end - beg), inputs=[m, L, x, beg]) - wp.launch(qLDiag_mul, dim=(d.nworld, m.nv), inputs=[D, x]) + wp.launch(qLDiag_mul, dim=(m.nworld, m.nv), inputs=[D, x]) for i in range(len(qLD_update_treeadr)): if i == len(qLD_update_treeadr) - 1: beg, end = qLD_update_treeadr[i], m.qLD_update_tree.shape[0] else: beg, end = qLD_update_treeadr[i], qLD_update_treeadr[i + 1] - wp.launch(x_acc_down, dim=(d.nworld, end - beg), inputs=[m, L, x, beg]) + wp.launch(x_acc_down, dim=(m.nworld, end - beg), inputs=[m, L, x, beg]) def _solve_LD_dense(m: Model, d: Data, L: array3df, x: array2df, y: array2df): @@ -819,7 +832,7 @@ def cho_solve(m: Model, L: array3df, x: array2df, y: array2df, leveladr: int): wp.tile_store(x[worldid], x_slice, offset=(dofid,)) wp.launch_tiled( - cho_solve, dim=(d.nworld, size), inputs=[m, L, x, y, adr], block_dim=block_dim + cho_solve, dim=(m.nworld, size), inputs=[m, L, x, y, adr], block_dim=block_dim ) qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() @@ -864,7 +877,7 @@ def cholesky(m: Model, leveladr: int, M: array3df, x: array2df, y: array2df): wp.tile_store(x[worldid], x_slice, offset=(dofid,)) wp.launch_tiled( - cholesky, dim=(d.nworld, size), inputs=[m, adr, M, x, y], block_dim=block_dim + cholesky, dim=(m.nworld, size), inputs=[m, adr, M, x, y], block_dim=block_dim ) qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() @@ -903,7 +916,7 @@ def _joint_tendon(m: Model, d: Data): wrap_jnt_adr = m.wrap_jnt_adr[wrapid] wrap_objid = m.wrap_objid[wrap_jnt_adr] - prm = m.wrap_prm[wrap_jnt_adr] + prm = m.wrap_prm[worldid, wrap_jnt_adr] # add to length L = prm * d.qpos[worldid, m.jnt_qposadr[wrap_objid]] @@ -913,6 +926,6 @@ def _joint_tendon(m: Model, d: Data): # add to moment d.ten_J[worldid, tendon_jnt_adr, m.jnt_dofadr[wrap_jnt_adr]] = prm - wp.launch(_joint_tendon, dim=(d.nworld, m.wrap_jnt_adr.size), inputs=[m, d]) + wp.launch(_joint_tendon, dim=(m.nworld, m.wrap_jnt_adr.size), inputs=[m, d]) # TODO(team): spatial diff --git a/mujoco_warp/_src/smooth_test.py b/mujoco_warp/_src/smooth_test.py index 2150c16d..4ab516ea 100644 --- a/mujoco_warp/_src/smooth_test.py +++ b/mujoco_warp/_src/smooth_test.py @@ -29,7 +29,7 @@ # tolerance for difference between MuJoCo and MJWarp smooth calculations - mostly # due to float precision -_TOLERANCE = 5e-5 +_TOLERANCE = 5e-4 def _assert_eq(a, b, name): diff --git a/mujoco_warp/_src/solver.py b/mujoco_warp/_src/solver.py index 4220cebd..3063e327 100644 --- a/mujoco_warp/_src/solver.py +++ b/mujoco_warp/_src/solver.py @@ -54,7 +54,7 @@ def _search(d: types.Data): d.efc.search[worldid, dofid] = search wp.atomic_add(d.efc.search_dot, worldid, search * search) - wp.launch(_init_context, dim=(d.nworld), inputs=[d]) + wp.launch(_init_context, dim=(m.nworld), inputs=[d]) # jaref = d.efc_J @ d.qacc - d.efc_aref d.efc.Jaref.zero_() @@ -69,7 +69,7 @@ def _search(d: types.Data): _update_gradient(m, d) # search = -Mgrad - wp.launch(_search, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_search, dim=(m.nworld, m.nv), inputs=[d]) def _update_constraint(m: types.Model, d: types.Data): @@ -159,18 +159,18 @@ def _gauss(d: types.Data): wp.atomic_add(d.efc.gauss, worldid, gauss_cost) wp.atomic_add(d.efc.cost, worldid, gauss_cost) - wp.launch(_init_cost, dim=(d.nworld), inputs=[d]) + wp.launch(_init_cost, dim=(m.nworld), inputs=[d]) wp.launch(_efc_kernel, dim=(d.njmax,), inputs=[d]) # qfrc_constraint = efc_J.T @ efc_force - wp.launch(_zero_qfrc_constraint, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_zero_qfrc_constraint, dim=(m.nworld, m.nv), inputs=[d]) wp.launch(_qfrc_constraint, dim=(m.nv, d.njmax), inputs=[d]) # gauss = 0.5 * (Ma - qfrc_smooth).T @ (qacc - qacc_smooth) - wp.launch(_gauss, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_gauss, dim=(m.nworld, m.nv), inputs=[d]) def _update_gradient(m: types.Model, d: types.Data): @@ -293,22 +293,22 @@ def _cholesky(d: types.Data): wp.tile_store(d.efc.Mgrad[worldid], output_tile) # grad = Ma - qfrc_smooth - qfrc_constraint - wp.launch(_zero_grad_dot, dim=(d.nworld), inputs=[d]) + wp.launch(_zero_grad_dot, dim=(m.nworld), inputs=[d]) - wp.launch(_grad, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_grad, dim=(m.nworld, m.nv), inputs=[d]) if m.opt.solver == types.SolverType.CG: smooth.solve_m(m, d, d.efc.Mgrad, d.efc.grad) elif m.opt.solver == types.SolverType.NEWTON: # h = qM + (efc_J.T * efc_D * active) @ efc_J if m.opt.is_sparse: - wp.launch(_zero_h_lower, dim=(d.nworld, m.dof_tri_row.size), inputs=[m, d]) + wp.launch(_zero_h_lower, dim=(m.nworld, m.dof_tri_row.size), inputs=[m, d]) wp.launch( - _set_h_qM_lower_sparse, dim=(d.nworld, m.qM_fullm_i.size), inputs=[m, d] + _set_h_qM_lower_sparse, dim=(m.nworld, m.qM_fullm_i.size), inputs=[m, d] ) else: - wp.launch(_copy_lower_triangle, dim=(d.nworld, m.dof_tri_row.size), inputs=[m, d]) + wp.launch(_copy_lower_triangle, dim=(m.nworld, m.dof_tri_row.size), inputs=[m, d]) # Optimization: launching _JTDAJ with limited number of blocks on a GPU. # Profiling suggests that only a fraction of blocks out of the original @@ -333,7 +333,7 @@ def _cholesky(d: types.Data): inputs=[m, d, int((d.njmax + dim_x - 1) / dim_x), dim_x], ) - wp.launch_tiled(_cholesky, dim=(d.nworld,), inputs=[d], block_dim=32) + wp.launch_tiled(_cholesky, dim=(m.nworld,), inputs=[d], block_dim=32) @wp.func @@ -640,7 +640,7 @@ def _swap( alpha = wp.where(improved and not plo_better, phi_alpha, alpha) d.efc.alpha[worldid] = alpha - wp.launch(_gtol, dim=(d.nworld,), inputs=[m, d]) + wp.launch(_gtol, dim=(m.nworld,), inputs=[m, d]) # linesearch points done = d.efc.ls_done @@ -659,17 +659,17 @@ def _swap( # initialize interval - wp.launch(_init_p0_gauss, dim=(d.nworld,), inputs=[p0, d]) + wp.launch(_init_p0_gauss, dim=(m.nworld,), inputs=[p0, d]) wp.launch(_init_p0, dim=(d.njmax,), inputs=[p0, d]) - wp.launch(_init_lo_gauss, dim=(d.nworld,), inputs=[p0, lo, lo_alpha, d]) + wp.launch(_init_lo_gauss, dim=(m.nworld,), inputs=[p0, lo, lo_alpha, d]) wp.launch(_init_lo, dim=(d.njmax,), inputs=[lo, lo_alpha, d]) # set the lo/hi interval bounds - wp.launch(_init_bounds, dim=(d.nworld,), inputs=[p0, lo, lo_alpha, hi, hi_alpha, d]) + wp.launch(_init_bounds, dim=(m.nworld,), inputs=[p0, lo, lo_alpha, hi, hi_alpha, d]) for _ in range(m.opt.ls_iterations): # note: we always launch ls_iterations kernels, but the kernels may early exit if done is true @@ -677,7 +677,7 @@ def _swap( # of extra launches inputs = [done, lo, lo_alpha, hi, hi_alpha, lo_next, lo_next_alpha, hi_next] inputs += [hi_next_alpha, mid, mid_alpha, d] - wp.launch(_next_alpha_gauss, dim=(d.nworld,), inputs=inputs) + wp.launch(_next_alpha_gauss, dim=(m.nworld,), inputs=inputs) inputs = [done, lo_next, lo_next_alpha, hi_next, hi_next_alpha, mid, mid_alpha] inputs += [d] @@ -685,7 +685,7 @@ def _swap( inputs = [done, p0, lo, lo_alpha, hi, hi_alpha, lo_next, lo_next_alpha, hi_next] inputs += [hi_next_alpha, mid, mid_alpha, d] - wp.launch(_swap, dim=(d.nworld,), inputs=inputs) + wp.launch(_swap, dim=(m.nworld,), inputs=inputs) def _linesearch_parallel(m: types.Model, d: types.Data): @@ -752,10 +752,10 @@ def _best_alpha(d: types.Data): bestid = wp.argmin(d.efc.cost_candidate[worldid]) d.efc.alpha[worldid] = m.alpha_candidate[bestid] - wp.launch(_quad_total, dim=(d.nworld, m.nlsp), inputs=[m, d]) + wp.launch(_quad_total, dim=(m.nworld, m.nlsp), inputs=[m, d]) wp.launch(_quad_total_candidate, dim=(d.njmax, m.nlsp), inputs=[m, d]) - wp.launch(_cost_alpha, dim=(d.nworld, m.nlsp), inputs=[m, d]) - wp.launch(_best_alpha, dim=(d.nworld), inputs=[d]) + wp.launch(_cost_alpha, dim=(m.nworld, m.nlsp), inputs=[m, d]) + wp.launch(_best_alpha, dim=(m.nworld), inputs=[d]) @event_scope @@ -878,9 +878,9 @@ def _jaref(d: types.Data): # prepare quadratics # quad_gauss = [gauss, search.T @ Ma - search.T @ qfrc_smooth, 0.5 * search.T @ mv] - wp.launch(_zero_quad_gauss, dim=(d.nworld), inputs=[d]) + wp.launch(_zero_quad_gauss, dim=(m.nworld), inputs=[d]) - wp.launch(_init_quad_gauss, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(_init_quad_gauss, dim=(m.nworld, m.nv), inputs=[m, d]) # quad = [0.5 * Jaref * Jaref * efc_D, jv * Jaref * efc_D, 0.5 * jv * jv * efc_D] @@ -891,7 +891,7 @@ def _jaref(d: types.Data): else: _linesearch_iterative(m, d) - wp.launch(_qacc_ma, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_qacc_ma, dim=(m.nworld, m.nv), inputs=[d]) wp.launch(_jaref, dim=(d.njmax,), inputs=[d]) @@ -1005,23 +1005,23 @@ def _beta(d: types.Data): _linesearch(m, d) if m.opt.solver == types.SolverType.CG: - wp.launch(_prev_grad_Mgrad, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_prev_grad_Mgrad, dim=(m.nworld, m.nv), inputs=[d]) _update_constraint(m, d) _update_gradient(m, d) # polak-ribiere if m.opt.solver == types.SolverType.CG: - wp.launch(_zero_beta_num_den, dim=(d.nworld), inputs=[d]) + wp.launch(_zero_beta_num_den, dim=(m.nworld), inputs=[d]) - wp.launch(_beta_num_den, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_beta_num_den, dim=(m.nworld, m.nv), inputs=[d]) - wp.launch(_beta, dim=(d.nworld,), inputs=[d]) + wp.launch(_beta, dim=(m.nworld,), inputs=[d]) - wp.launch(_zero_search_dot, dim=(d.nworld), inputs=[d]) + wp.launch(_zero_search_dot, dim=(m.nworld), inputs=[d]) - wp.launch(_search_update, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_search_update, dim=(m.nworld, m.nv), inputs=[d]) - wp.launch(_done, dim=(d.nworld,), inputs=[m, d, i]) + wp.launch(_done, dim=(m.nworld,), inputs=[m, d, i]) kernel_copy(d.qacc_warmstart, d.qacc) diff --git a/mujoco_warp/_src/solver_test.py b/mujoco_warp/_src/solver_test.py index 2c82151e..30a5e00c 100644 --- a/mujoco_warp/_src/solver_test.py +++ b/mujoco_warp/_src/solver_test.py @@ -62,7 +62,7 @@ def _load( mjd = mujoco.MjData(mjm) mujoco.mj_resetDataKeyframe(mjm, mjd, keyframe) mujoco.mj_step(mjm, mjd) - m = mjwarp.put_model(mjm) + m = mjwarp.put_model(mjm, nworld=nworld) m.opt.ls_parallel = ls_parallel d = mjwarp.put_data(mjm, mjd, nworld=nworld, njmax=njmax) return mjm, mjd, m, d diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index 80e4cab9..7d7c0ee2 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -71,7 +71,7 @@ def mul( wp.launch_tiled( mul, - dim=(d.nworld, size), + dim=(m.nworld, size), inputs=[ m, d, @@ -130,10 +130,10 @@ def _mul_m_sparse_ij( wp.atomic_add(res[worldid], i, qM * vec[worldid, j]) wp.atomic_add(res[worldid], j, qM * vec[worldid, i]) - wp.launch(_mul_m_sparse_diag, dim=(d.nworld, m.nv), inputs=[m, d, res, vec, skip]) + wp.launch(_mul_m_sparse_diag, dim=(m.nworld, m.nv), inputs=[m, d, res, vec, skip]) wp.launch( - _mul_m_sparse_ij, dim=(d.nworld, m.qM_madr_ij.size), inputs=[m, d, res, vec, skip] + _mul_m_sparse_ij, dim=(m.nworld, m.qM_madr_ij.size), inputs=[m, d, res, vec, skip] ) @@ -169,7 +169,7 @@ def _accumulate(m: Model, d: Data, qfrc: array2df): qfrc[worldid, dofid] += accumul - wp.launch(kernel=_accumulate, dim=(d.nworld, m.nv), inputs=[m, d, qfrc]) + wp.launch(kernel=_accumulate, dim=(m.nworld, m.nv), inputs=[m, d, qfrc]) @wp.func diff --git a/mujoco_warp/_src/support_test.py b/mujoco_warp/_src/support_test.py index 799892c8..215cbf95 100644 --- a/mujoco_warp/_src/support_test.py +++ b/mujoco_warp/_src/support_test.py @@ -50,7 +50,7 @@ def test_mul_m(self, sparse): res = wp.zeros((1, mjm.nv), dtype=wp.float32) vec = wp.from_numpy(np.expand_dims(mj_vec, axis=0), dtype=wp.float32) - skip = wp.zeros((d.nworld), dtype=bool) + skip = wp.zeros((m.nworld), dtype=bool) mjwarp.mul_m(m, d, res, vec, skip) _assert_eq(res.numpy()[0], mj_res, f"mul_m ({'sparse' if sparse else 'dense'})") diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index f0e8b5be..3b725cf4 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -423,10 +423,11 @@ class Model: nsensor: number of sensors () nsensordata: number of elements in sensor data vector () nlsp: number of step sizes for parallel linsearch () + nworld: number of parallel worlds () opt: physics options stat: model statistics - qpos0: qpos values at default pose (nq,) - qpos_spring: reference pose for springs (nq,) + qpos0: qpos values at default pose (nworld, nq) + qpos_spring: reference pose for springs (nworld, nq) body_tree: BFS ordering of body ids body_treeadr: starting index of each body tree level actuator_moment_offset_nv: tiling configuration @@ -459,75 +460,73 @@ class Model: body_dofadr: start addr of dofs; -1: no dofs (nbody,) body_geomnum: number of geoms (nbody,) body_geomadr: start addr of geoms; -1: no geoms (nbody,) - body_pos: position offset rel. to parent body (nbody, 3) - body_quat: orientation offset rel. to parent body (nbody, 4) - body_ipos: local position of center of mass (nbody, 3) - body_iquat: local orientation of inertia ellipsoid (nbody, 4) - body_mass: mass (nbody,) - subtree_mass: mass of subtree (nbody,) - body_inertia: diagonal inertia in ipos/iquat frame (nbody, 3) - body_invweight0: mean inv inert in qpos0 (trn, rot) (nbody, 2) - body_contype: OR over all geom contypes (nbody,) - body_conaffinity: OR over all geom conaffinities (nbody,) + body_pos: position offset rel. to parent body (nworld, nbody, 3) + body_quat: orientation offset rel. to parent body (nworld, nbody, 4) + body_ipos: local position of center of mass (nworld, nbody, 3) + body_iquat: local orientation of inertia ellipsoid (nworld, nbody, 4) + body_mass: mass (nworld, nbody) + subtree_mass: mass of subtree (nworld, nbody) + body_inertia: diagonal inertia in ipos/iquat frame (nworld, nbody, 3) + body_invweight0: mean inv inert in qpos0 (trn, rot) (nworld, nbody, 2) jnt_type: type of joint (mjtJoint) (njnt,) jnt_qposadr: start addr in 'qpos' for joint's data (njnt,) jnt_dofadr: start addr in 'qvel' for joint's data (njnt,) jnt_bodyid: id of joint's body (njnt,) jnt_limited: does joint have limits (njnt,) jnt_actfrclimited: does joint have actuator force limits (njnt,) - jnt_solref: constraint solver reference: limit (njnt, mjNREF) - jnt_solimp: constraint solver impedance: limit (njnt, mjNIMP) + jnt_solref: constraint solver reference: limit (nworld, njnt, mjNREF) + jnt_solimp: constraint solver impedance: limit (nworld, njnt, mjNIMP) jnt_pos: local anchor position (njnt, 3) jnt_axis: local joint axis (njnt, 3) - jnt_stiffness: stiffness coefficient (njnt,) - jnt_range: joint limits (njnt, 2) - jnt_actfrcrange: range of total actuator force (njnt, 2) - jnt_margin: min distance for limit detection (njnt,) + jnt_stiffness: stiffness coefficient (nworld, njnt) + jnt_range: joint limits (nworld, njnt, 2) + jnt_actfrcrange: range of total actuator force (nworld, njnt, 2) + jnt_margin: min distance for limit detection (nworld, njnt) jnt_limited_slide_hinge_adr: limited/slide/hinge jntadr dof_bodyid: id of dof's body (nv,) dof_jntid: id of dof's joint (nv,) dof_parentid: id of dof's parent; -1: none (nv,) dof_Madr: dof address in M-diagonal (nv,) - dof_armature: dof armature inertia/mass (nv,) - dof_damping: damping coefficient (nv,) - dof_invweight0: diag. inverse inertia in qpos0 (nv,) + dof_armature: dof armature inertia/mass (nworld, nv) + dof_damping: damping coefficient (nworld, nv) + dof_invweight0: diag. inverse inertia in qpos0 (nworld, nv) dof_tri_row: np.tril_indices (mjm.nv)[0] dof_tri_col: np.tril_indices (mjm.nv)[1] geom_type: geometric type (mjtGeom) (ngeom,) - geom_contype: geom contact type (ngeom,) - geom_conaffinity: geom contact affinity (ngeom,) + geom_contype: geom contact type (nworld, ngeom) + geom_conaffinity: geom contact affinity (nworld, ngeom) geom_condim: contact dimensionality (1, 3, 4, 6) (ngeom,) geom_bodyid: id of geom's body (ngeom,) geom_dataid: id of geom's mesh/hfield; -1: none (ngeom,) - geom_priority: geom contact priority (ngeom,) - geom_solmix: mixing coef for solref/imp in geom pair (ngeom,) - geom_solref: constraint solver reference: contact (ngeom, mjNREF) - geom_solimp: constraint solver impedance: contact (ngeom, mjNIMP) + geom_priority: geom contact priority (nworld, ngeom) + geom_solmix: mixing coef for solref/imp in geom pair (nworld, ngeom) + geom_solref: constraint solver reference: contact (nworld, ngeom, mjNREF) + geom_solimp: constraint solver impedance: contact (nworld, ngeom, mjNIMP) geom_size: geom-specific size parameters (ngeom, 3) geom_aabb: bounding box, (center, size) (ngeom, 6) geom_rbound: radius of bounding sphere (ngeom,) - geom_pos: local position offset rel. to body (ngeom, 3) - geom_quat: local orientation offset rel. to body (ngeom, 4) - geom_friction: friction for (slide, spin, roll) (ngeom, 3) - geom_margin: detect contact if dist