diff --git a/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py b/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py index 498377e4..87cb15b6 100644 --- a/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py +++ b/contrib/kernel_analyzer/kernel_analyzer/ast_analyzer_test.py @@ -136,8 +136,8 @@ def test_all_issues( @kernel def test_no_issues( # Model: - qpos0: wp.array(dtype=float), - geom_pos: wp.array(dtype=wp.vec3), + qpos0: wp.array2d(dtype=float), + geom_pos: wp.array2d(dtype=wp.vec3), # Data in: qpos_in: wp.array2d(dtype=float), qvel_in: wp.array2d(dtype=float), diff --git a/mujoco_warp/_src/collision_box.py b/mujoco_warp/_src/collision_box.py index 562cf6ff..5c4f2038 100644 --- a/mujoco_warp/_src/collision_box.py +++ b/mujoco_warp/_src/collision_box.py @@ -191,21 +191,21 @@ def _box_box( # Model: geom_type: wp.array(dtype=int), geom_condim: wp.array(dtype=int), - geom_priority: wp.array(dtype=int), - geom_solmix: wp.array(dtype=float), - geom_solref: wp.array(dtype=wp.vec2), - geom_solimp: wp.array(dtype=vec5), - geom_size: wp.array(dtype=wp.vec3), - geom_friction: wp.array(dtype=wp.vec3), - geom_margin: wp.array(dtype=float), - geom_gap: wp.array(dtype=float), + geom_priority: wp.array2d(dtype=int), + geom_solmix: wp.array2d(dtype=float), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_size: wp.array2d(dtype=wp.vec3), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + geom_gap: wp.array2d(dtype=float), pair_dim: wp.array(dtype=int), - pair_solref: wp.array(dtype=wp.vec2), - pair_solreffriction: wp.array(dtype=wp.vec2), - pair_solimp: wp.array(dtype=vec5), - pair_margin: wp.array(dtype=float), - pair_gap: wp.array(dtype=float), - pair_friction: wp.array(dtype=vec5), + pair_solref: wp.array2d(dtype=wp.vec2), + pair_solreffriction: wp.array2d(dtype=wp.vec2), + pair_solimp: wp.array2d(dtype=vec5), + pair_margin: wp.array2d(dtype=float), + pair_gap: wp.array2d(dtype=float), + pair_friction: wp.array2d(dtype=vec5), # Data in: nconmax_in: int, geom_xpos_in: wp.array2d(dtype=wp.vec3), @@ -262,6 +262,7 @@ def _box_box( collision_pair_in, collision_pairid_in, tid, + worldid, ) # transformations @@ -271,8 +272,8 @@ def _box_box( trans_atob = b_mat_inv @ (a_pos - b_pos) rot_atob = b_mat_inv @ a_mat - a_size = geom_size[ga] - b_size = geom_size[gb] + a_size = geom_size[worldid, ga] + b_size = geom_size[worldid, gb] a = box(rot_atob, trans_atob, a_size) b = box(wp.identity(3, wp.float32), wp.vec3(0.0), b_size) @@ -363,7 +364,7 @@ def _box_box( for i in range(4): pos[i] = pos[idx] - margin = wp.max(geom_margin[ga], geom_margin[gb]) + margin = wp.max(geom_margin[worldid, ga], geom_margin[worldid, gb]) for i in range(4): pos_glob = b_mat @ pos[i] + b_pos n_glob = b_mat @ sep_axis diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index 6e8b5ce3..43e077e5 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -708,24 +708,24 @@ def gjk_epa_sparse( geom_type: wp.array(dtype=int), geom_condim: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), - geom_priority: wp.array(dtype=int), - geom_solmix: wp.array(dtype=float), - geom_solref: wp.array(dtype=wp.vec2), - geom_solimp: wp.array(dtype=vec5), - geom_size: wp.array(dtype=wp.vec3), - geom_friction: wp.array(dtype=wp.vec3), - geom_margin: wp.array(dtype=float), - geom_gap: wp.array(dtype=float), + geom_priority: wp.array2d(dtype=int), + geom_solmix: wp.array2d(dtype=float), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_size: wp.array2d(dtype=wp.vec3), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + geom_gap: wp.array2d(dtype=float), mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), mesh_vert: wp.array(dtype=wp.vec3), pair_dim: wp.array(dtype=int), - pair_solref: wp.array(dtype=wp.vec2), - pair_solreffriction: wp.array(dtype=wp.vec2), - pair_solimp: wp.array(dtype=vec5), - pair_margin: wp.array(dtype=float), - pair_gap: wp.array(dtype=float), - pair_friction: wp.array(dtype=vec5), + pair_solref: wp.array2d(dtype=wp.vec2), + pair_solreffriction: wp.array2d(dtype=wp.vec2), + pair_solimp: wp.array2d(dtype=vec5), + pair_margin: wp.array2d(dtype=float), + pair_gap: wp.array2d(dtype=float), + pair_friction: wp.array2d(dtype=vec5), # Data in: nconmax_in: int, geom_xpos_in: wp.array2d(dtype=wp.vec3), @@ -772,6 +772,7 @@ def gjk_epa_sparse( collision_pair_in, collision_pairid_in, tid, + worldid, ) g1 = geoms[0] @@ -783,7 +784,7 @@ def gjk_epa_sparse( geom1 = _geom( geom_type, geom_dataid, - geom_size, + geom_size[worldid], mesh_vertadr, mesh_vertnum, mesh_vert, @@ -796,7 +797,7 @@ def gjk_epa_sparse( geom2 = _geom( geom_type, geom_dataid, - geom_size, + geom_size[worldid], mesh_vertadr, mesh_vertnum, mesh_vert, @@ -806,7 +807,7 @@ def gjk_epa_sparse( g2, ) - margin = wp.max(geom_margin[g1], geom_margin[g2]) + margin = wp.max(geom_margin[worldid, g1], geom_margin[worldid, g2]) simplex, normal = _gjk(mesh_vert, geom1, geom2) diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index 799051cc..be3a4732 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -32,8 +32,8 @@ @wp.func def _sphere_filter( # Model: - geom_rbound: wp.array(dtype=float), - geom_margin: wp.array(dtype=float), + geom_rbound: wp.array2d(dtype=float), + geom_margin: wp.array2d(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), geom_xmat_in: wp.array2d(dtype=wp.mat33), @@ -42,12 +42,12 @@ def _sphere_filter( geom2: int, worldid: int, ) -> bool: - margin1 = geom_margin[geom1] - margin2 = geom_margin[geom2] + margin1 = geom_margin[worldid, geom1] + margin2 = geom_margin[worldid, geom2] pos1 = geom_xpos_in[worldid, geom1] pos2 = geom_xpos_in[worldid, geom2] - size1 = geom_rbound[geom1] - size2 = geom_rbound[geom2] + size1 = geom_rbound[worldid, geom1] + size2 = geom_rbound[worldid, geom2] bound = size1 + size2 + wp.max(margin1, margin2) dif = pos2 - pos1 @@ -124,8 +124,8 @@ def _upper_tri_index(n: int, i: int, j: int) -> int: @wp.kernel def _sap_project( # Model: - geom_rbound: wp.array(dtype=float), - geom_margin: wp.array(dtype=float), + geom_rbound: wp.array2d(dtype=float), + geom_margin: wp.array2d(dtype=float), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), # In: @@ -138,13 +138,13 @@ def _sap_project( worldid, geomid = wp.tid() xpos = geom_xpos_in[worldid, geomid] - rbound = geom_rbound[geomid] + rbound = geom_rbound[worldid, geomid] if rbound == 0.0: # geom is a plane rbound = MJ_MAXVAL - radius = rbound + geom_margin[geomid] + radius = rbound + geom_margin[worldid, geomid] center = wp.dot(direction_in, xpos) sap_projection_lower_out[worldid, geomid] = center - radius @@ -182,8 +182,8 @@ def _sap_broadphase( # Model: ngeom: int, geom_type: wp.array(dtype=int), - geom_rbound: wp.array(dtype=float), - geom_margin: wp.array(dtype=float), + geom_rbound: wp.array2d(dtype=float), + geom_margin: wp.array2d(dtype=float), nxn_pairid: wp.array(dtype=int), # Data in: nworld_in: int, @@ -343,8 +343,8 @@ def sap_broadphase(m: Model, d: Data): def _nxn_broadphase( # Model: geom_type: wp.array(dtype=int), - geom_rbound: wp.array(dtype=float), - geom_margin: wp.array(dtype=float), + geom_rbound: wp.array2d(dtype=float), + geom_margin: wp.array2d(dtype=float), nxn_geom_pair: wp.array(dtype=wp.vec2i), nxn_pairid: wp.array(dtype=int), # Data in: diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 9c7ffd77..cab4f503 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -44,7 +44,7 @@ def _geom( # Model: geom_type: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), - geom_size: wp.array(dtype=wp.vec3), + geom_size: wp.array2d(dtype=wp.vec3), mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), mesh_vert: wp.array(dtype=wp.vec3), @@ -59,7 +59,7 @@ def _geom( geom.pos = geom_xpos_in[worldid, gid] rot = geom_xmat_in[worldid, gid] geom.rot = rot - geom.size = geom_size[gid] + geom.size = geom_size[worldid, gid] geom.normal = wp.vec3(rot[0, 2], rot[1, 2], rot[2, 2]) # plane dataid = geom_dataid[gid] @@ -1255,46 +1255,47 @@ def plane_cylinder( def contact_params( # Model: geom_condim: wp.array(dtype=int), - geom_priority: wp.array(dtype=int), - geom_solmix: wp.array(dtype=float), - geom_solref: wp.array(dtype=wp.vec2), - geom_solimp: wp.array(dtype=vec5), - geom_friction: wp.array(dtype=wp.vec3), - geom_margin: wp.array(dtype=float), - geom_gap: wp.array(dtype=float), + geom_priority: wp.array2d(dtype=int), + geom_solmix: wp.array2d(dtype=float), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + geom_gap: wp.array2d(dtype=float), pair_dim: wp.array(dtype=int), - pair_solref: wp.array(dtype=wp.vec2), - pair_solreffriction: wp.array(dtype=wp.vec2), - pair_solimp: wp.array(dtype=vec5), - pair_margin: wp.array(dtype=float), - pair_gap: wp.array(dtype=float), - pair_friction: wp.array(dtype=vec5), + pair_solref: wp.array2d(dtype=wp.vec2), + pair_solreffriction: wp.array2d(dtype=wp.vec2), + pair_solimp: wp.array2d(dtype=vec5), + pair_margin: wp.array2d(dtype=float), + pair_gap: wp.array2d(dtype=float), + pair_friction: wp.array2d(dtype=vec5), # Data in: collision_pair_in: wp.array(dtype=wp.vec2i), collision_pairid_in: wp.array(dtype=int), # In: cid: int, + worldid: int, ): geoms = collision_pair_in[cid] pairid = collision_pairid_in[cid] if pairid > -1: - margin = pair_margin[pairid] - gap = pair_gap[pairid] + margin = pair_margin[worldid, pairid] + gap = pair_gap[worldid, pairid] condim = pair_dim[pairid] - friction = pair_friction[pairid] - solref = pair_solref[pairid] - solreffriction = pair_solreffriction[pairid] - solimp = pair_solimp[pairid] + friction = pair_friction[worldid, pairid] + solref = pair_solref[worldid, pairid] + solreffriction = pair_solreffriction[worldid, pairid] + solimp = pair_solimp[worldid, pairid] else: g1 = geoms[0] g2 = geoms[1] - p1 = geom_priority[g1] - p2 = geom_priority[g2] + p1 = geom_priority[worldid, g1] + p2 = geom_priority[worldid, g2] - solmix1 = geom_solmix[g1] - solmix2 = geom_solmix[g2] + solmix1 = geom_solmix[worldid, g1] + solmix2 = geom_solmix[worldid, g2] mix = solmix1 / (solmix1 + solmix2) mix = wp.where((solmix1 < MJ_MINVAL) and (solmix2 < MJ_MINVAL), 0.5, mix) @@ -1302,14 +1303,14 @@ def contact_params( mix = wp.where((solmix1 >= MJ_MINVAL) and (solmix2 < MJ_MINVAL), 1.0, mix) mix = wp.where(p1 == p2, mix, wp.where(p1 > p2, 1.0, 0.0)) - margin = wp.max(geom_margin[g1], geom_margin[g2]) - gap = wp.max(geom_gap[g1], geom_gap[g2]) + margin = wp.max(geom_margin[worldid, g1], geom_margin[worldid, g2]) + gap = wp.max(geom_gap[worldid, g1], geom_gap[worldid, g2]) condim1 = geom_condim[g1] condim2 = geom_condim[g2] condim = wp.where(p1 == p2, wp.max(condim1, condim2), wp.where(p1 > p2, condim1, condim2)) - max_geom_friction = wp.max(geom_friction[g1], geom_friction[g2]) + max_geom_friction = wp.max(geom_friction[worldid, g1], geom_friction[worldid, g2]) friction = vec5( max_geom_friction[0], max_geom_friction[0], @@ -1318,14 +1319,14 @@ def contact_params( max_geom_friction[2], ) - if geom_solref[g1].x > 0.0 and geom_solref[g2].x > 0.0: - solref = mix * geom_solref[g1] + (1.0 - mix) * geom_solref[g2] + if geom_solref[worldid, g1].x > 0.0 and geom_solref[worldid, g2].x > 0.0: + solref = mix * geom_solref[worldid, g1] + (1.0 - mix) * geom_solref[worldid, g2] else: - solref = wp.min(geom_solref[g1], geom_solref[g2]) + solref = wp.min(geom_solref[worldid, g1], geom_solref[worldid, g2]) solreffriction = wp.vec2(0.0, 0.0) - solimp = mix * geom_solimp[g1] + (1.0 - mix) * geom_solimp[g2] + solimp = mix * geom_solimp[worldid, g1] + (1.0 - mix) * geom_solimp[worldid, g2] return geoms, margin, gap, condim, friction, solref, solreffriction, solimp @@ -1876,24 +1877,24 @@ def _primitive_narrowphase( geom_type: wp.array(dtype=int), geom_condim: wp.array(dtype=int), geom_dataid: wp.array(dtype=int), - geom_priority: wp.array(dtype=int), - geom_solmix: wp.array(dtype=float), - geom_solref: wp.array(dtype=wp.vec2), - geom_solimp: wp.array(dtype=vec5), - geom_size: wp.array(dtype=wp.vec3), - geom_friction: wp.array(dtype=wp.vec3), - geom_margin: wp.array(dtype=float), - geom_gap: wp.array(dtype=float), + geom_priority: wp.array2d(dtype=int), + geom_solmix: wp.array2d(dtype=float), + geom_solref: wp.array2d(dtype=wp.vec2), + geom_solimp: wp.array2d(dtype=vec5), + geom_size: wp.array2d(dtype=wp.vec3), + geom_friction: wp.array2d(dtype=wp.vec3), + geom_margin: wp.array2d(dtype=float), + geom_gap: wp.array2d(dtype=float), mesh_vertadr: wp.array(dtype=int), mesh_vertnum: wp.array(dtype=int), mesh_vert: wp.array(dtype=wp.vec3), pair_dim: wp.array(dtype=int), - pair_solref: wp.array(dtype=wp.vec2), - pair_solreffriction: wp.array(dtype=wp.vec2), - pair_solimp: wp.array(dtype=vec5), - pair_margin: wp.array(dtype=float), - pair_gap: wp.array(dtype=float), - pair_friction: wp.array(dtype=vec5), + pair_solref: wp.array2d(dtype=wp.vec2), + pair_solreffriction: wp.array2d(dtype=wp.vec2), + pair_solimp: wp.array2d(dtype=vec5), + pair_margin: wp.array2d(dtype=float), + pair_gap: wp.array2d(dtype=float), + pair_friction: wp.array2d(dtype=vec5), # Data in: nconmax_in: int, geom_xpos_in: wp.array2d(dtype=wp.vec3), @@ -1921,6 +1922,8 @@ def _primitive_narrowphase( if tid >= ncollision_in[0]: return + worldid = collision_worldid_in[tid] + geoms, margin, gap, condim, friction, solref, solreffriction, solimp = contact_params( geom_condim, geom_priority, @@ -1940,12 +1943,11 @@ def _primitive_narrowphase( collision_pair_in, collision_pairid_in, tid, + worldid, ) g1 = geoms[0] g2 = geoms[1] - worldid = collision_worldid_in[tid] - geom1 = _geom( geom_type, geom_dataid, diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index f7b5481b..4341f505 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -99,14 +99,14 @@ def _efc_equality_connect( opt_timestep: float, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=float), + body_invweight0: wp.array3d(dtype=float), dof_bodyid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), eq_objtype: wp.array(dtype=int), - eq_solref: wp.array(dtype=wp.vec2), - eq_solimp: wp.array(dtype=vec5), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), eq_data: wp.array(dtype=vec11), eq_connect_adr: wp.array(dtype=int), # Data in: @@ -195,11 +195,11 @@ def _efc_equality_connect( efc_J_out[efcid + 2, dofid] = j1mj2[2] Jqvel += j1mj2 * qvel_in[worldid, dofid] - invweight = body_invweight0[body1id, 0] + body_invweight0[body2id, 0] + invweight = body_invweight0[worldid, body1id, 0] + body_invweight0[worldid, body2id, 0] pos_imp = wp.length(pos) - solref = eq_solref[i_eq] - solimp = eq_solimp[i_eq] + solref = eq_solref[worldid, i_eq] + solimp = eq_solimp[worldid, i_eq] for i in range(3): efcidi = efcid + i @@ -231,14 +231,14 @@ def _efc_equality_connect( def _efc_equality_joint( # Model: opt_timestep: float, - qpos0: wp.array(dtype=float), + qpos0: wp.array2d(dtype=float), jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), - dof_invweight0: wp.array(dtype=float), + dof_invweight0: wp.array2d(dtype=float), eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), - eq_solref: wp.array(dtype=wp.vec2), - eq_solimp: wp.array(dtype=vec5), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), eq_data: wp.array(dtype=vec11), eq_jnt_adr: wp.array(dtype=int), # Data in: @@ -281,22 +281,22 @@ def _efc_equality_joint( # Two joint constraint qposadr2 = jnt_qposadr[jntid_2] dofadr2 = jnt_dofadr[jntid_2] - dif = qpos_in[worldid, qposadr2] - qpos0[qposadr2] + dif = qpos_in[worldid, qposadr2] - qpos0[worldid, qposadr2] # Horner's method for polynomials rhs = data[0] + dif * (data[1] + dif * (data[2] + dif * (data[3] + dif * data[4]))) deriv_2 = data[1] + dif * (2.0 * data[2] + dif * (3.0 * data[3] + dif * 4.0 * data[4])) - pos = qpos_in[worldid, qposadr1] - qpos0[qposadr1] - rhs + pos = qpos_in[worldid, qposadr1] - qpos0[worldid, qposadr1] - rhs Jqvel = qvel_in[worldid, dofadr1] - qvel_in[worldid, dofadr2] * deriv_2 - invweight = dof_invweight0[dofadr1] + dof_invweight0[dofadr2] + invweight = dof_invweight0[worldid, dofadr1] + dof_invweight0[worldid, dofadr2] efc_J_out[efcid, dofadr2] = -deriv_2 else: # Single joint constraint - pos = qpos_in[worldid, qposadr1] - qpos0[qposadr1] - data[0] + pos = qpos_in[worldid, qposadr1] - qpos0[worldid, qposadr1] - data[0] Jqvel = qvel_in[worldid, dofadr1] - invweight = dof_invweight0[dofadr1] + invweight = dof_invweight0[worldid, dofadr1] # Update constraint parameters _update_efc_row( @@ -306,8 +306,8 @@ def _efc_equality_joint( pos, pos, invweight, - eq_solref[i_eq], - eq_solimp[i_eq], + eq_solref[worldid, i_eq], + eq_solimp[worldid, i_eq], 0.0, Jqvel, 0.0, @@ -328,12 +328,12 @@ def _efc_equality_tendon( opt_timestep: float, eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), - eq_solref: wp.array(dtype=wp.vec2), - eq_solimp: wp.array(dtype=vec5), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), eq_data: wp.array(dtype=vec11), eq_ten_adr: wp.array(dtype=int), - tendon_length0: wp.array(dtype=float), - tendon_invweight0: wp.array(dtype=float), + tendon_length0: wp.array2d(dtype=float), + tendon_invweight0: wp.array2d(dtype=float), # Data in: ne_connect_in: wp.array(dtype=int), ne_weld_in: wp.array(dtype=int), @@ -369,15 +369,15 @@ def _efc_equality_tendon( obj1id = eq_obj1id[eqid] obj2id = eq_obj2id[eqid] data = eq_data[eqid] - solref = eq_solref[eqid] - solimp = eq_solimp[eqid] - pos1 = ten_length_in[worldid, obj1id] - tendon_length0[obj1id] - pos2 = ten_length_in[worldid, obj2id] - tendon_length0[obj2id] + solref = eq_solref[worldid, eqid] + solimp = eq_solimp[worldid, eqid] + pos1 = ten_length_in[worldid, obj1id] - tendon_length0[worldid, obj1id] + pos2 = ten_length_in[worldid, obj2id] - tendon_length0[worldid, obj2id] jac1 = ten_J_in[worldid, obj1id] jac2 = ten_J_in[worldid, obj2id] if obj2id > -1: - invweight = tendon_invweight0[obj1id] + tendon_invweight0[obj2id] + invweight = tendon_invweight0[worldid, obj1id] + tendon_invweight0[worldid, obj2id] dif = pos2 dif2 = dif * dif @@ -387,7 +387,7 @@ def _efc_equality_tendon( pos = pos1 - (data[0] + data[1] * dif + data[2] * dif2 + data[3] * dif3 + data[4] * dif4) deriv = data[1] + 2.0 * data[2] * dif + 3.0 * data[3] * dif2 + 4.0 * data[4] * dif3 else: - invweight = tendon_invweight0[obj1id] + invweight = tendon_invweight0[worldid, obj1id] pos = pos1 - data[0] deriv = 0.0 @@ -426,10 +426,10 @@ def _efc_equality_tendon( def _efc_friction( # Model: opt_timestep: float, - dof_invweight0: wp.array(dtype=float), - dof_frictionloss: wp.array(dtype=float), - dof_solimp: wp.array(dtype=vec5), - dof_solref: wp.array(dtype=wp.vec2), + dof_invweight0: wp.array2d(dtype=float), + dof_frictionloss: wp.array2d(dtype=float), + dof_solimp: wp.array2d(dtype=vec5), + dof_solref: wp.array2d(dtype=wp.vec2), # Data in: qvel_in: wp.array2d(dtype=float), # In: @@ -449,7 +449,7 @@ def _efc_friction( # TODO(team): tendon worldid, dofid = wp.tid() - if dof_frictionloss[dofid] <= 0.0: + if dof_frictionloss[worldid, dofid] <= 0.0: return efcid = wp.atomic_add(nefc_out, 0, 1) @@ -465,12 +465,12 @@ def _efc_friction( efcid, 0.0, 0.0, - dof_invweight0[dofid], - dof_solref[dofid], - dof_solimp[dofid], + dof_invweight0[worldid, dofid], + dof_solref[worldid, dofid], + dof_solimp[worldid, dofid], 0.0, Jqvel, - dof_frictionloss[dofid], + dof_frictionloss[worldid, dofid], dofid, efc_id_out, efc_pos_out, @@ -489,15 +489,15 @@ def _efc_equality_weld( opt_timestep: float, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=float), + body_invweight0: wp.array3d(dtype=float), dof_bodyid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), - site_quat: wp.array(dtype=wp.quat), + site_quat: wp.array2d(dtype=wp.quat), eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), eq_objtype: wp.array(dtype=int), - eq_solref: wp.array(dtype=wp.vec2), - eq_solimp: wp.array(dtype=vec5), + eq_solref: wp.array2d(dtype=wp.vec2), + eq_solimp: wp.array2d(dtype=vec5), eq_data: wp.array(dtype=vec11), eq_wld_adr: wp.array(dtype=int), # Data in: @@ -552,8 +552,8 @@ def _efc_equality_weld( pos1 = site_xpos_in[worldid, obj1id] pos2 = site_xpos_in[worldid, obj2id] - quat = math.mul_quat(xquat_in[worldid, body1id], site_quat[obj1id]) - quat1 = math.quat_inv(math.mul_quat(xquat_in[worldid, body2id], site_quat[obj2id])) + quat = math.mul_quat(xquat_in[worldid, body1id], site_quat[worldid, obj1id]) + quat1 = math.quat_inv(math.mul_quat(xquat_in[worldid, body2id], site_quat[worldid, obj2id])) else: body1id = obj1id @@ -612,12 +612,12 @@ def _efc_equality_weld( crotq = math.mul_quat(quat1, quat) # copy axis components crot = wp.vec3(crotq[1], crotq[2], crotq[3]) * torquescale - invweight_t = body_invweight0[body1id, 0] + body_invweight0[body2id, 0] + invweight_t = body_invweight0[worldid, body1id, 0] + body_invweight0[worldid, body2id, 0] pos_imp = wp.sqrt(wp.length_sq(cpos) + wp.length_sq(crot)) - solref = eq_solref[i_eq] - solimp = eq_solimp[i_eq] + solref = eq_solref[worldid, i_eq] + solimp = eq_solimp[worldid, i_eq] for i in range(3): _update_efc_row( @@ -641,7 +641,7 @@ def _efc_equality_weld( efc_frictionloss_out, ) - invweight_r = body_invweight0[body1id, 1] + body_invweight0[body2id, 1] + invweight_r = body_invweight0[worldid, body1id, 1] + body_invweight0[worldid, body2id, 1] for i in range(3): _update_efc_row( @@ -672,12 +672,12 @@ def _efc_limit_slide_hinge( opt_timestep: float, jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), - jnt_solref: wp.array(dtype=wp.vec2), - jnt_solimp: wp.array(dtype=vec5), - jnt_range: wp.array2d(dtype=float), - jnt_margin: wp.array(dtype=float), + jnt_solref: wp.array2d(dtype=wp.vec2), + jnt_solimp: wp.array2d(dtype=vec5), + jnt_range: wp.array3d(dtype=float), + jnt_margin: wp.array2d(dtype=float), jnt_limited_slide_hinge_adr: wp.array(dtype=int), - dof_invweight0: wp.array(dtype=float), + dof_invweight0: wp.array2d(dtype=float), # Data in: nefc_in: wp.array(dtype=int), qpos_in: wp.array2d(dtype=float), @@ -697,10 +697,10 @@ def _efc_limit_slide_hinge( ): worldid, jntlimitedid = wp.tid() jntid = jnt_limited_slide_hinge_adr[jntlimitedid] - jntrange = jnt_range[jntid] + jntrange = jnt_range[worldid, jntid] qpos = qpos_in[worldid, jnt_qposadr[jntid]] - jntmargin = jnt_margin[jntid] + jntmargin = jnt_margin[worldid, jntid] dist_min, dist_max = qpos - jntrange[0], jntrange[1] - qpos pos = wp.min(dist_min, dist_max) - jntmargin active = pos < 0 @@ -722,9 +722,9 @@ def _efc_limit_slide_hinge( efcid, pos, pos, - dof_invweight0[dofadr], - jnt_solref[jntid], - jnt_solimp[jntid], + dof_invweight0[worldid, dofadr], + jnt_solref[worldid, jntid], + jnt_solimp[worldid, jntid], jntmargin, Jqvel, 0.0, @@ -744,12 +744,12 @@ def _efc_limit_ball( opt_timestep: float, jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), - jnt_solref: wp.array(dtype=wp.vec2), - jnt_solimp: wp.array(dtype=vec5), - jnt_range: wp.array2d(dtype=float), - jnt_margin: wp.array(dtype=float), + jnt_solref: wp.array2d(dtype=wp.vec2), + jnt_solimp: wp.array2d(dtype=vec5), + jnt_range: wp.array3d(dtype=float), + jnt_margin: wp.array2d(dtype=float), jnt_limited_ball_adr: wp.array(dtype=int), - dof_invweight0: wp.array(dtype=float), + dof_invweight0: wp.array2d(dtype=float), # Data in: nefc_in: wp.array(dtype=int), qpos_in: wp.array2d(dtype=float), @@ -774,9 +774,9 @@ def _efc_limit_ball( qpos = qpos_in[worldid] jnt_quat = wp.quat(qpos[qposadr + 0], qpos[qposadr + 1], qpos[qposadr + 2], qpos[qposadr + 3]) axis_angle = math.quat_to_vel(jnt_quat) - jntrange = jnt_range[jntid] + jntrange = jnt_range[worldid, jntid] axis, angle = math.normalize_with_norm(axis_angle) - jntmargin = jnt_margin[jntid] + jntmargin = jnt_margin[worldid, jntid] pos = wp.max(jntrange[0], jntrange[1]) - angle - jntmargin active = pos < 0 @@ -802,9 +802,9 @@ def _efc_limit_ball( efcid, pos, pos, - dof_invweight0[dofadr], - jnt_solref[jntid], - jnt_solimp[jntid], + dof_invweight0[worldid, dofadr], + jnt_solref[worldid, jntid], + jnt_solimp[worldid, jntid], jntmargin, Jqvel, 0.0, @@ -827,11 +827,11 @@ def _efc_limit_tendon( tendon_adr: wp.array(dtype=int), tendon_num: wp.array(dtype=int), tendon_limited_adr: wp.array(dtype=int), - tendon_solref_lim: wp.array(dtype=wp.vec2), - tendon_solimp_lim: wp.array(dtype=vec5), - tendon_range: wp.array(dtype=wp.vec2), - tendon_margin: wp.array(dtype=float), - tendon_invweight0: wp.array(dtype=float), + tendon_solref_lim: wp.array2d(dtype=wp.vec2), + tendon_solimp_lim: wp.array2d(dtype=vec5), + tendon_range: wp.array2d(dtype=wp.vec2), + tendon_margin: wp.array2d(dtype=float), + tendon_invweight0: wp.array2d(dtype=float), wrap_objid: wp.array(dtype=int), wrap_type: wp.array(dtype=int), # Data in: @@ -855,10 +855,10 @@ def _efc_limit_tendon( worldid, tenlimitedid = wp.tid() tenid = tendon_limited_adr[tenlimitedid] - tenrange = tendon_range[tenid] + tenrange = tendon_range[worldid, tenid] length = ten_length_in[worldid, tenid] dist_min, dist_max = length - tenrange[0], tenrange[1] - length - tenmargin = tendon_margin[tenid] + tenmargin = tendon_margin[worldid, tenid] pos = wp.min(dist_min, dist_max) - tenmargin active = pos < 0 @@ -890,9 +890,9 @@ def _efc_limit_tendon( efcid, pos, pos, - tendon_invweight0[tenid], - tendon_solref_lim[tenid], - tendon_solimp_lim[tenid], + tendon_invweight0[worldid, tenid], + tendon_solref_lim[worldid, tenid], + tendon_solimp_lim[worldid, tenid], tenmargin, Jqvel, 0.0, @@ -914,7 +914,7 @@ def _efc_contact_pyramidal( opt_impratio: float, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=float), + body_invweight0: wp.array3d(dtype=float), dof_bodyid: wp.array(dtype=int), geom_bodyid: wp.array(dtype=int), # Data in: @@ -974,7 +974,7 @@ def _efc_contact_pyramidal( frame = frame_in[conid] # pyramidal has common invweight across all edges - invweight = body_invweight0[body1, 0] + body_invweight0[body2, 0] + invweight = body_invweight0[worldid, body1, 0] + body_invweight0[worldid, body2, 0] if condim > 1: dimid2 = dimid / 2 + 1 @@ -1060,7 +1060,7 @@ def _efc_contact_elliptic( opt_impratio: float, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - body_invweight0: wp.array2d(dtype=float), + body_invweight0: wp.array3d(dtype=float), dof_bodyid: wp.array(dtype=int), geom_bodyid: wp.array(dtype=int), # Data in: @@ -1158,7 +1158,7 @@ def _efc_contact_elliptic( efc_J_out[efcid, i] = J Jqvel += J * qvel_in[worldid, i] - invweight = body_invweight0[body1, 0] + body_invweight0[body2, 0] + invweight = body_invweight0[worldid, body1, 0] + body_invweight0[worldid, body2, 0] ref = solref_in[conid] pos_aref = pos diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 0b104c75..c3c9cb7e 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -141,8 +141,8 @@ def _next_activation( opt_timestep: float, actuator_dyntype: wp.array(dtype=int), actuator_actlimited: wp.array(dtype=bool), - actuator_dynprm: wp.array(dtype=vec10f), - actuator_actrange: wp.array(dtype=wp.vec2), + actuator_dynprm: wp.array2d(dtype=vec10f), + actuator_actrange: wp.array2d(dtype=wp.vec2), # Data in: act_in: wp.array2d(dtype=float), act_dot_in: wp.array2d(dtype=float), @@ -159,7 +159,7 @@ def _next_activation( # advance the actuation if actuator_dyntype[actid] == wp.static(DynType.FILTEREXACT.value): - dyn_prm = actuator_dynprm[actid] + dyn_prm = actuator_dynprm[worldid, actid] tau = wp.max(MJ_MINVAL, dyn_prm[0]) act += act_dot_scale_in * act_dot * tau * (1.0 - wp.exp(-opt_timestep / tau)) else: @@ -167,7 +167,7 @@ def _next_activation( # clamp to actrange if limit and actuator_actlimited[actid]: - actrange = actuator_actrange[actid] + actrange = actuator_actrange[worldid, actid] act = wp.clamp(act, actrange[0], actrange[1]) act_out[worldid, actid] = act @@ -267,7 +267,7 @@ def _euler_damp_qfrc_sparse( # Model: opt_timestep: float, dof_Madr: wp.array(dtype=int), - dof_damping: wp.array(dtype=float), + dof_damping: wp.array2d(dtype=float), # Data in: qfrc_smooth_in: wp.array2d(dtype=float), qfrc_constraint_in: wp.array2d(dtype=float), @@ -278,7 +278,7 @@ def _euler_damp_qfrc_sparse( worldid, tid = wp.tid() adr = dof_Madr[tid] - qM_integration_out[worldid, 0, adr] += opt_timestep * dof_damping[tid] + qM_integration_out[worldid, 0, adr] += opt_timestep * dof_damping[worldid, tid] qfrc_integration_out[worldid, tid] = qfrc_smooth_in[worldid, tid] + qfrc_constraint_in[worldid, tid] @@ -314,7 +314,7 @@ def _tile_euler_dense(tile: TileSet): @nested_kernel def euler_dense( # Model: - dof_damping: wp.array(dtype=float), + dof_damping: wp.array2d(dtype=float), opt_timestep: float, # Data in: qM_in: wp.array3d(dtype=float), @@ -330,7 +330,7 @@ def euler_dense( dofid = adr_in[nodeid] M_tile = wp.tile_load(qM_in[worldid], shape=(TILE_SIZE, TILE_SIZE), offset=(dofid, dofid)) - damping_tile = wp.tile_load(dof_damping, shape=(TILE_SIZE,), offset=(dofid,)) + damping_tile = wp.tile_load(dof_damping[worldid], shape=(TILE_SIZE,), offset=(dofid,)) damping_scaled = damping_tile * opt_timestep qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled) @@ -481,8 +481,8 @@ def _implicit_actuator_bias_gain_vel( actuator_dyntype: wp.array(dtype=int), actuator_gaintype: wp.array(dtype=int), actuator_biastype: wp.array(dtype=int), - actuator_gainprm: wp.array(dtype=vec10f), - actuator_biasprm: wp.array(dtype=vec10f), + actuator_gainprm: wp.array2d(dtype=vec10f), + actuator_biasprm: wp.array2d(dtype=vec10f), # Data in: act_in: wp.array2d(dtype=float), ctrl_in: wp.array2d(dtype=float), @@ -492,12 +492,12 @@ def _implicit_actuator_bias_gain_vel( worldid, actid = wp.tid() if actuator_biastype[actid] == wp.static(BiasType.AFFINE.value): - bias_vel = actuator_biasprm[actid][2] + bias_vel = actuator_biasprm[worldid, actid][2] else: bias_vel = 0.0 if actuator_gaintype[actid] == wp.static(GainType.AFFINE.value): - gain_vel = actuator_gainprm[actid][2] + gain_vel = actuator_gainprm[worldid, actid][2] else: gain_vel = 0.0 @@ -523,7 +523,7 @@ def subtract_multiply(x: float, y: float): @nested_kernel def implicit_actuator_qderiv( # Model: - dof_damping: wp.array(dtype=float), + dof_damping: wp.array2d(dtype=float), # Data in: actuator_moment_in: wp.array3d(dtype=float), qM_in: wp.array3d(dtype=float), @@ -562,7 +562,7 @@ def implicit_actuator_qderiv( qderiv_tile = wp.tile_zeros(shape=(TILE_NV_SIZE, TILE_NV_SIZE), dtype=wp.float32) if wp.static(passive_enabled): - dof_damping_tile = wp.tile_load(dof_damping, shape=TILE_NV_SIZE, offset=offset_nv) + dof_damping_tile = wp.tile_load(dof_damping[worldid], shape=TILE_NV_SIZE, offset=offset_nv) negative = wp.neg(dof_damping_tile) qderiv_tile = wp.tile_diag_add(qderiv_tile, negative) @@ -818,11 +818,11 @@ def _actuator_force( actuator_actnum: wp.array(dtype=int), actuator_ctrllimited: wp.array(dtype=bool), actuator_forcelimited: wp.array(dtype=bool), - actuator_dynprm: wp.array(dtype=vec10f), - actuator_gainprm: wp.array(dtype=vec10f), - actuator_biasprm: wp.array(dtype=vec10f), - actuator_ctrlrange: wp.array(dtype=wp.vec2), - actuator_forcerange: wp.array(dtype=wp.vec2), + actuator_dynprm: wp.array2d(dtype=vec10f), + actuator_gainprm: wp.array2d(dtype=vec10f), + actuator_biasprm: wp.array2d(dtype=vec10f), + actuator_ctrlrange: wp.array2d(dtype=wp.vec2), + actuator_forcerange: wp.array2d(dtype=wp.vec2), # Data in: act_in: wp.array2d(dtype=float), ctrl_in: wp.array2d(dtype=float), @@ -839,7 +839,7 @@ def _actuator_force( ctrl = ctrl_in[worldid, uid] if actuator_ctrllimited[uid] and not dsbl_clampctrl: - ctrlrange = actuator_ctrlrange[uid] + ctrlrange = actuator_ctrlrange[worldid, uid] ctrl = wp.clamp(ctrl, ctrlrange[0], ctrlrange[1]) if na: @@ -848,7 +848,7 @@ def _actuator_force( if dyntype == int(DynType.INTEGRATOR.value): act_dot_out[worldid, actuator_actadr[uid]] = ctrl elif dyntype == int(DynType.FILTER.value) or dyntype == int(DynType.FILTEREXACT.value): - dynprm = actuator_dynprm[uid] + dynprm = actuator_dynprm[worldid, uid] actadr = actuator_actadr[uid] act = act_in[worldid, actadr] act_dot_out[worldid, actadr] = (ctrl - act) / wp.max(dynprm[0], MJ_MINVAL) @@ -867,7 +867,7 @@ def _actuator_force( # gain gaintype = actuator_gaintype[uid] - gainprm = actuator_gainprm[uid] + gainprm = actuator_gainprm[worldid, uid] gain = 0.0 if gaintype == int(GainType.FIXED.value): @@ -879,7 +879,7 @@ def _actuator_force( # bias biastype = actuator_biastype[uid] - biasprm = actuator_biasprm[uid] + biasprm = actuator_biasprm[worldid, uid] bias = 0.0 # BiasType.NONE if biastype == int(BiasType.AFFINE.value): @@ -892,7 +892,7 @@ def _actuator_force( # TODO(team): tendon total force clamping if actuator_forcelimited[uid]: - forcerange = actuator_forcerange[uid] + forcerange = actuator_forcerange[worldid, uid] force = wp.clamp(force, forcerange[0], forcerange[1]) actuator_force_out[worldid, uid] = force @@ -904,7 +904,7 @@ def _qfrc_actuator_sparse( nu: int, ngravcomp: int, jnt_actfrclimited: wp.array(dtype=bool), - jnt_actfrcrange: wp.array(dtype=wp.vec2), + jnt_actfrcrange: wp.array2d(dtype=wp.vec2), jnt_actgravcomp: wp.array(dtype=int), dof_jntid: wp.array(dtype=int), # Data in: @@ -928,7 +928,7 @@ def _qfrc_actuator_sparse( qfrc += qfrc_gravcomp_in[worldid, dofid] if jnt_actfrclimited[jntid]: - frcrange = jnt_actfrcrange[jntid] + frcrange = jnt_actfrcrange[worldid, jntid] qfrc = wp.clamp(qfrc, frcrange[0], frcrange[1]) qfrc_actuator_out[worldid, dofid] = qfrc @@ -939,7 +939,7 @@ def _qfrc_actuator_limited( # Model: ngravcomp: int, jnt_actfrclimited: wp.array(dtype=bool), - jnt_actfrcrange: wp.array(dtype=wp.vec2), + jnt_actfrcrange: wp.array2d(dtype=wp.vec2), jnt_actgravcomp: wp.array(dtype=int), dof_jntid: wp.array(dtype=int), # Data in: @@ -957,7 +957,7 @@ def _qfrc_actuator_limited( qfrc_dof += qfrc_gravcomp_in[worldid, dofid] if jnt_actfrclimited[jntid]: - frcrange = jnt_actfrcrange[jntid] + frcrange = jnt_actfrcrange[worldid, jntid] qfrc_dof = wp.clamp(qfrc_dof, frcrange[0], frcrange[1]) qfrc_actuator_out[worldid, dofid] = qfrc_dof diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 6e828415..f508240f 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -250,6 +250,13 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: nxn_geom_pair.append((geom1, geom2)) nxn_pairid.append(pairid) + def create_nmodel_batched_array(mjm_array, dtype): + array = wp.array(mjm_array, dtype=dtype) + array.ndim += 1 + array.shape = (1,) + array.shape + array.strides = (0,) + array.strides + return array + return types.Model( nq=mjm.nq, nv=mjm.nv, @@ -294,8 +301,8 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: stat=types.Statistic( meaninertia=mjm.stat.meaninertia, ), - qpos0=wp.array(mjm.qpos0, dtype=float), - qpos_spring=wp.array(mjm.qpos_spring, dtype=float), + qpos0=create_nmodel_batched_array(mjm.qpos0, dtype=float), + qpos_spring=create_nmodel_batched_array(mjm.qpos_spring, dtype=float), qM_fullm_i=wp.array(qM_fullm_i, dtype=int), qM_fullm_j=wp.array(qM_fullm_j, dtype=int), qM_mulm_i=wp.array(qM_mulm_i, dtype=int), @@ -319,15 +326,15 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: body_dofadr=wp.array(mjm.body_dofadr, dtype=int), body_geomnum=wp.array(mjm.body_geomnum, dtype=int), body_geomadr=wp.array(mjm.body_geomadr, dtype=int), - body_pos=wp.array(mjm.body_pos, dtype=wp.vec3), - body_quat=wp.array(mjm.body_quat, dtype=wp.quat), - body_ipos=wp.array(mjm.body_ipos, dtype=wp.vec3), - body_iquat=wp.array(mjm.body_iquat, dtype=wp.quat), - body_mass=wp.array(mjm.body_mass, dtype=float), - body_subtreemass=wp.array(mjm.body_subtreemass, dtype=float), - subtree_mass=wp.array(subtree_mass, dtype=float), - body_inertia=wp.array(mjm.body_inertia, dtype=wp.vec3), - body_invweight0=wp.array(mjm.body_invweight0, dtype=float), + body_pos=create_nmodel_batched_array(mjm.body_pos, dtype=wp.vec3), + body_quat=create_nmodel_batched_array(mjm.body_quat, dtype=wp.quat), + body_ipos=create_nmodel_batched_array(mjm.body_ipos, dtype=wp.vec3), + body_iquat=create_nmodel_batched_array(mjm.body_iquat, dtype=wp.quat), + body_mass=create_nmodel_batched_array(mjm.body_mass, dtype=float), + body_subtreemass=create_nmodel_batched_array(mjm.body_subtreemass, dtype=float), + subtree_mass=create_nmodel_batched_array(subtree_mass, dtype=float), + body_inertia=create_nmodel_batched_array(mjm.body_inertia, dtype=wp.vec3), + body_invweight0=create_nmodel_batched_array(mjm.body_invweight0, dtype=float), body_contype=wp.array(mjm.body_contype, dtype=int), body_conaffinity=wp.array(mjm.body_conaffinity, dtype=int), body_gravcomp=wp.array(mjm.body_gravcomp, dtype=float), @@ -337,14 +344,14 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: jnt_bodyid=wp.array(mjm.jnt_bodyid, dtype=int), jnt_limited=wp.array(mjm.jnt_limited, dtype=int), jnt_actfrclimited=wp.array(mjm.jnt_actfrclimited, dtype=bool), - jnt_solref=wp.array(mjm.jnt_solref, dtype=wp.vec2), - jnt_solimp=wp.array(mjm.jnt_solimp, dtype=types.vec5), - jnt_pos=wp.array(mjm.jnt_pos, dtype=wp.vec3), + jnt_solref=create_nmodel_batched_array(mjm.jnt_solref, dtype=wp.vec2), + jnt_solimp=create_nmodel_batched_array(mjm.jnt_solimp, dtype=types.vec5), + jnt_pos=create_nmodel_batched_array(mjm.jnt_pos, dtype=wp.vec3), jnt_axis=wp.array(mjm.jnt_axis, dtype=wp.vec3), - jnt_stiffness=wp.array(mjm.jnt_stiffness, dtype=float), - jnt_range=wp.array(mjm.jnt_range, dtype=float), - jnt_actfrcrange=wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2), - jnt_margin=wp.array(mjm.jnt_margin, dtype=float), + jnt_stiffness=create_nmodel_batched_array(mjm.jnt_stiffness, dtype=float), + jnt_range=create_nmodel_batched_array(mjm.jnt_range, dtype=float), + jnt_actfrcrange=create_nmodel_batched_array(mjm.jnt_actfrcrange, dtype=wp.vec2), + jnt_margin=create_nmodel_batched_array(mjm.jnt_margin, dtype=float), # these jnt_limited adrs are used in constraint.py jnt_limited_slide_hinge_adr=wp.array( np.nonzero( @@ -361,12 +368,12 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: dof_jntid=wp.array(mjm.dof_jntid, dtype=int), dof_parentid=wp.array(mjm.dof_parentid, dtype=int), dof_Madr=wp.array(mjm.dof_Madr, dtype=int), - dof_armature=wp.array(mjm.dof_armature, dtype=float), - dof_damping=wp.array(mjm.dof_damping, dtype=float), - dof_invweight0=wp.array(mjm.dof_invweight0, dtype=float), - dof_frictionloss=wp.array(mjm.dof_frictionloss, dtype=float), - dof_solimp=wp.array(mjm.dof_solimp, dtype=types.vec5), - dof_solref=wp.array(mjm.dof_solref, dtype=wp.vec2), + dof_armature=create_nmodel_batched_array(mjm.dof_armature, dtype=float), + dof_damping=create_nmodel_batched_array(mjm.dof_damping, dtype=float), + dof_invweight0=create_nmodel_batched_array(mjm.dof_invweight0, dtype=float), + dof_frictionloss=create_nmodel_batched_array(mjm.dof_frictionloss, dtype=float), + dof_solimp=create_nmodel_batched_array(mjm.dof_solimp, dtype=types.vec5), + dof_solref=create_nmodel_batched_array(mjm.dof_solref, dtype=wp.vec2), dof_tri_row=wp.array(dof_tri_row, dtype=int), dof_tri_col=wp.array(dof_tri_col, dtype=int), geom_type=wp.array(mjm.geom_type, dtype=int), @@ -375,28 +382,28 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: geom_condim=wp.array(mjm.geom_condim, dtype=int), geom_bodyid=wp.array(mjm.geom_bodyid, dtype=int), geom_dataid=wp.array(mjm.geom_dataid, dtype=int), - geom_priority=wp.array(mjm.geom_priority, dtype=int), - geom_solmix=wp.array(mjm.geom_solmix, dtype=float), - geom_solref=wp.array(mjm.geom_solref, dtype=wp.vec2), - geom_solimp=wp.array(mjm.geom_solimp, dtype=types.vec5), - geom_size=wp.array(mjm.geom_size, dtype=wp.vec3), - geom_aabb=wp.array(mjm.geom_aabb, dtype=wp.vec3), - geom_rbound=wp.array(mjm.geom_rbound, dtype=float), - geom_pos=wp.array(mjm.geom_pos, dtype=wp.vec3), - geom_quat=wp.array(mjm.geom_quat, dtype=wp.quat), - geom_friction=wp.array(mjm.geom_friction, dtype=wp.vec3), - geom_margin=wp.array(mjm.geom_margin, dtype=float), - geom_gap=wp.array(mjm.geom_gap, dtype=float), + geom_priority=create_nmodel_batched_array(mjm.geom_priority, dtype=int), + geom_solmix=create_nmodel_batched_array(mjm.geom_solmix, dtype=float), + geom_solref=create_nmodel_batched_array(mjm.geom_solref, dtype=wp.vec2), + geom_solimp=create_nmodel_batched_array(mjm.geom_solimp, dtype=types.vec5), + geom_size=create_nmodel_batched_array(mjm.geom_size, dtype=wp.vec3), + geom_aabb=create_nmodel_batched_array(mjm.geom_aabb, dtype=wp.vec3), + geom_rbound=create_nmodel_batched_array(mjm.geom_rbound, dtype=float), + geom_pos=create_nmodel_batched_array(mjm.geom_pos, dtype=wp.vec3), + geom_quat=create_nmodel_batched_array(mjm.geom_quat, dtype=wp.quat), + geom_friction=create_nmodel_batched_array(mjm.geom_friction, dtype=wp.vec3), + geom_margin=create_nmodel_batched_array(mjm.geom_margin, dtype=float), + geom_gap=create_nmodel_batched_array(mjm.geom_gap, dtype=float), site_bodyid=wp.array(mjm.site_bodyid, dtype=int), - site_pos=wp.array(mjm.site_pos, dtype=wp.vec3), - site_quat=wp.array(mjm.site_quat, dtype=wp.quat), + site_pos=create_nmodel_batched_array(mjm.site_pos, dtype=wp.vec3), + site_quat=create_nmodel_batched_array(mjm.site_quat, dtype=wp.quat), cam_mode=wp.array(mjm.cam_mode, dtype=int), cam_bodyid=wp.array(mjm.cam_bodyid, dtype=int), cam_targetbodyid=wp.array(mjm.cam_targetbodyid, dtype=int), - cam_pos=wp.array(mjm.cam_pos, dtype=wp.vec3), - cam_quat=wp.array(mjm.cam_quat, dtype=wp.quat), - cam_poscom0=wp.array(mjm.cam_poscom0, dtype=wp.vec3), - cam_pos0=wp.array(mjm.cam_pos0, dtype=wp.vec3), + cam_pos=create_nmodel_batched_array(mjm.cam_pos, dtype=wp.vec3), + cam_quat=create_nmodel_batched_array(mjm.cam_quat, dtype=wp.quat), + cam_poscom0=create_nmodel_batched_array(mjm.cam_poscom0, dtype=wp.vec3), + cam_pos0=create_nmodel_batched_array(mjm.cam_pos0, dtype=wp.vec3), cam_fovy=wp.array(mjm.cam_fovy, dtype=float), cam_resolution=wp.array(mjm.cam_resolution, dtype=wp.vec2i), cam_sensorsize=wp.array(mjm.cam_sensorsize, dtype=wp.vec2), @@ -404,10 +411,10 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: light_mode=wp.array(mjm.light_mode, dtype=int), light_bodyid=wp.array(mjm.light_bodyid, dtype=int), light_targetbodyid=wp.array(mjm.light_targetbodyid, dtype=int), - light_pos=wp.array(mjm.light_pos, dtype=wp.vec3), - light_dir=wp.array(mjm.light_dir, dtype=wp.vec3), - light_poscom0=wp.array(mjm.light_poscom0, dtype=wp.vec3), - light_pos0=wp.array(mjm.light_pos0, dtype=wp.vec3), + light_pos=create_nmodel_batched_array(mjm.light_pos, dtype=wp.vec3), + light_dir=create_nmodel_batched_array(mjm.light_dir, dtype=wp.vec3), + light_poscom0=create_nmodel_batched_array(mjm.light_poscom0, dtype=wp.vec3), + light_pos0=create_nmodel_batched_array(mjm.light_pos0, dtype=wp.vec3), mesh_vertadr=wp.array(mjm.mesh_vertadr, dtype=int), mesh_vertnum=wp.array(mjm.mesh_vertnum, dtype=int), mesh_vert=wp.array(mjm.mesh_vert, dtype=wp.vec3), @@ -416,8 +423,8 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: eq_obj2id=wp.array(mjm.eq_obj2id, dtype=int), eq_objtype=wp.array(mjm.eq_objtype, dtype=int), eq_active0=wp.array(mjm.eq_active0, dtype=bool), - eq_solref=wp.array(mjm.eq_solref, dtype=wp.vec2), - eq_solimp=wp.array(mjm.eq_solimp, dtype=types.vec5), + eq_solref=create_nmodel_batched_array(mjm.eq_solref, dtype=wp.vec2), + eq_solimp=create_nmodel_batched_array(mjm.eq_solimp, dtype=types.vec5), eq_data=wp.array(mjm.eq_data, dtype=types.vec11), # pre-compute indices of equality constraints eq_connect_adr=wp.array(np.nonzero(mjm.eq_type == types.EqType.CONNECT.value)[0], dtype=int), @@ -436,13 +443,13 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: actuator_ctrllimited=wp.array(mjm.actuator_ctrllimited, dtype=bool), actuator_forcelimited=wp.array(mjm.actuator_forcelimited, dtype=bool), actuator_actlimited=wp.array(mjm.actuator_actlimited, dtype=bool), - actuator_dynprm=wp.array(mjm.actuator_dynprm, dtype=types.vec10f), - actuator_gainprm=wp.array(mjm.actuator_gainprm, dtype=types.vec10f), - actuator_biasprm=wp.array(mjm.actuator_biasprm, dtype=types.vec10f), - actuator_ctrlrange=wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2), - actuator_forcerange=wp.array(mjm.actuator_forcerange, dtype=wp.vec2), - actuator_actrange=wp.array(mjm.actuator_actrange, dtype=wp.vec2), - actuator_gear=wp.array(mjm.actuator_gear, dtype=wp.spatial_vector), + actuator_dynprm=create_nmodel_batched_array(mjm.actuator_dynprm, dtype=types.vec10f), + actuator_gainprm=create_nmodel_batched_array(mjm.actuator_gainprm, dtype=types.vec10f), + actuator_biasprm=create_nmodel_batched_array(mjm.actuator_biasprm, dtype=types.vec10f), + actuator_ctrlrange=create_nmodel_batched_array(mjm.actuator_ctrlrange, dtype=wp.vec2), + actuator_forcerange=create_nmodel_batched_array(mjm.actuator_forcerange, dtype=wp.vec2), + actuator_actrange=create_nmodel_batched_array(mjm.actuator_actrange, dtype=wp.vec2), + actuator_gear=create_nmodel_batched_array(mjm.actuator_gear, dtype=wp.spatial_vector), exclude_signature=wp.array(mjm.exclude_signature, dtype=int), # short-circuiting here allows us to skip a lot of code in implicit integration actuator_affine_bias_gain=bool( @@ -454,25 +461,25 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: pair_dim=wp.array(mjm.pair_dim, dtype=int), pair_geom1=wp.array(mjm.pair_geom1, dtype=int), pair_geom2=wp.array(mjm.pair_geom2, dtype=int), - pair_solref=wp.array(mjm.pair_solref, dtype=wp.vec2), - pair_solreffriction=wp.array(mjm.pair_solreffriction, dtype=wp.vec2), - pair_solimp=wp.array(mjm.pair_solimp, dtype=types.vec5), - pair_margin=wp.array(mjm.pair_margin, dtype=float), - pair_gap=wp.array(mjm.pair_gap, dtype=float), - pair_friction=wp.array(mjm.pair_friction, dtype=types.vec5), + pair_solref=create_nmodel_batched_array(mjm.pair_solref, dtype=wp.vec2), + pair_solreffriction=create_nmodel_batched_array(mjm.pair_solreffriction, dtype=wp.vec2), + pair_solimp=create_nmodel_batched_array(mjm.pair_solimp, dtype=types.vec5), + pair_margin=create_nmodel_batched_array(mjm.pair_margin, dtype=float), + pair_gap=create_nmodel_batched_array(mjm.pair_gap, dtype=float), + pair_friction=create_nmodel_batched_array(mjm.pair_friction, dtype=types.vec5), condim_max=np.max(mjm.pair_dim) if mjm.npair else np.max(mjm.geom_condim), # TODO(team): get max after filtering, tendon_adr=wp.array(mjm.tendon_adr, dtype=int), tendon_num=wp.array(mjm.tendon_num, dtype=int), tendon_limited=wp.array(mjm.tendon_limited, dtype=int), tendon_limited_adr=wp.array(np.nonzero(mjm.tendon_limited)[0], dtype=wp.int32, ndim=1), - tendon_solref_lim=wp.array(mjm.tendon_solref_lim, dtype=wp.vec2f), - tendon_solimp_lim=wp.array(mjm.tendon_solimp_lim, dtype=types.vec5), - tendon_range=wp.array(mjm.tendon_range, dtype=wp.vec2f), - tendon_margin=wp.array(mjm.tendon_margin, dtype=float), - tendon_length0=wp.array(mjm.tendon_length0, dtype=float), - tendon_invweight0=wp.array(mjm.tendon_invweight0, dtype=float), + tendon_solref_lim=create_nmodel_batched_array(mjm.tendon_solref_lim, dtype=wp.vec2f), + tendon_solimp_lim=create_nmodel_batched_array(mjm.tendon_solimp_lim, dtype=types.vec5), + tendon_range=create_nmodel_batched_array(mjm.tendon_range, dtype=wp.vec2f), + tendon_margin=create_nmodel_batched_array(mjm.tendon_margin, dtype=float), + tendon_length0=create_nmodel_batched_array(mjm.tendon_length0, dtype=float), + tendon_invweight0=create_nmodel_batched_array(mjm.tendon_invweight0, dtype=float), wrap_objid=wp.array(mjm.wrap_objid, dtype=int), - wrap_prm=wp.array(mjm.wrap_prm, dtype=float), + wrap_prm=create_nmodel_batched_array(mjm.wrap_prm, dtype=float), wrap_type=wp.array(mjm.wrap_type, dtype=int), tendon_jnt_adr=wp.array(tendon_jnt_adr, dtype=int), tendon_site_adr=wp.array(tendon_site_adr, dtype=int), @@ -490,7 +497,7 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: sensor_refid=wp.array(mjm.sensor_refid, dtype=int), sensor_dim=wp.array(mjm.sensor_dim, dtype=int), sensor_adr=wp.array(mjm.sensor_adr, dtype=int), - sensor_cutoff=wp.array(mjm.sensor_cutoff, dtype=float), + sensor_cutoff=create_nmodel_batched_array(mjm.sensor_cutoff, dtype=float), sensor_pos_adr=wp.array( np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS)[0], dtype=int, diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index c382716d..51c7b325 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -27,18 +27,18 @@ @wp.kernel def _spring_passive( # Model: - qpos_spring: wp.array(dtype=float), + qpos_spring: wp.array2d(dtype=float), jnt_type: wp.array(dtype=int), jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), - jnt_stiffness: wp.array(dtype=float), + jnt_stiffness: wp.array2d(dtype=float), # Data in: qpos_in: wp.array2d(dtype=float), # Data out: qfrc_spring_out: wp.array2d(dtype=float), ): worldid, jntid = wp.tid() - stiffness = jnt_stiffness[jntid] + stiffness = jnt_stiffness[worldid, jntid] dofid = jnt_dofadr[jntid] if stiffness == 0.0: @@ -49,9 +49,9 @@ def _spring_passive( if jnttype == wp.static(JointType.FREE.value): dif = wp.vec3( - qpos_in[worldid, qposid + 0] - qpos_spring[qposid + 0], - qpos_in[worldid, qposid + 1] - qpos_spring[qposid + 1], - qpos_in[worldid, qposid + 2] - qpos_spring[qposid + 2], + qpos_in[worldid, qposid + 0] - qpos_spring[worldid, qposid + 0], + qpos_in[worldid, qposid + 1] - qpos_spring[worldid, qposid + 1], + qpos_in[worldid, qposid + 2] - qpos_spring[worldid, qposid + 2], ) qfrc_spring_out[worldid, dofid + 0] = -stiffness * dif[0] qfrc_spring_out[worldid, dofid + 1] = -stiffness * dif[1] @@ -63,10 +63,10 @@ def _spring_passive( qpos_in[worldid, qposid + 6], ) ref = wp.quat( - qpos_spring[qposid + 3], - qpos_spring[qposid + 4], - qpos_spring[qposid + 5], - qpos_spring[qposid + 6], + qpos_spring[worldid, qposid + 3], + qpos_spring[worldid, qposid + 4], + qpos_spring[worldid, qposid + 5], + qpos_spring[worldid, qposid + 6], ) dif = math.quat_sub(rot, ref) qfrc_spring_out[worldid, dofid + 3] = -stiffness * dif[0] @@ -80,24 +80,24 @@ def _spring_passive( qpos_in[worldid, qposid + 3], ) ref = wp.quat( - qpos_spring[qposid + 0], - qpos_spring[qposid + 1], - qpos_spring[qposid + 2], - qpos_spring[qposid + 3], + qpos_spring[worldid, qposid + 0], + qpos_spring[worldid, qposid + 1], + qpos_spring[worldid, qposid + 2], + qpos_spring[worldid, qposid + 3], ) dif = math.quat_sub(rot, ref) qfrc_spring_out[worldid, dofid + 0] = -stiffness * dif[0] qfrc_spring_out[worldid, dofid + 1] = -stiffness * dif[1] qfrc_spring_out[worldid, dofid + 2] = -stiffness * dif[2] else: # mjJNT_SLIDE, mjJNT_HINGE - fdif = qpos_in[worldid, qposid] - qpos_spring[qposid] + fdif = qpos_in[worldid, qposid] - qpos_spring[worldid, qposid] qfrc_spring_out[worldid, dofid] = -stiffness * fdif @wp.kernel def _damper_passive( # Model: - dof_damping: wp.array(dtype=float), + dof_damping: wp.array2d(dtype=float), # Data in: qvel_in: wp.array2d(dtype=float), qfrc_spring_in: wp.array2d(dtype=float), @@ -107,7 +107,7 @@ def _damper_passive( ): worldid, dofid = wp.tid() - qfrc_damper = -dof_damping[dofid] * qvel_in[worldid, dofid] + qfrc_damper = -dof_damping[worldid, dofid] * qvel_in[worldid, dofid] qfrc_damper_out[worldid, dofid] = qfrc_damper qfrc_passive_out[worldid, dofid] = qfrc_damper + qfrc_spring_in[worldid, dofid] @@ -119,7 +119,7 @@ def _gravity_force( opt_gravity: wp.vec3, body_parentid: wp.array(dtype=int), body_rootid: wp.array(dtype=int), - body_mass: wp.array(dtype=float), + body_mass: wp.array2d(dtype=float), body_gravcomp: wp.array(dtype=float), dof_bodyid: wp.array(dtype=int), # Data in: @@ -134,7 +134,7 @@ def _gravity_force( gravcomp = body_gravcomp[bodyid] if gravcomp: - force = -opt_gravity * body_mass[bodyid] * gravcomp + force = -opt_gravity * body_mass[worldid, bodyid] * gravcomp pos = xipos_in[worldid, bodyid] jac, _ = support.jac(body_parentid, body_rootid, dof_bodyid, subtree_com_in, cdof_in, pos, bodyid, dofid, worldid) diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index a1e3a385..660a836a 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -34,15 +34,16 @@ def _write_scalar( # Model: sensor_datatype: wp.array(dtype=int), sensor_adr: wp.array(dtype=int), - sensor_cutoff: wp.array(dtype=float), + sensor_cutoff: wp.array2d(dtype=float), # In: sensorid: int, sensor: Any, + worldid: int, # Out: out: wp.array(dtype=float), ): adr = sensor_adr[sensorid] - cutoff = sensor_cutoff[sensorid] + cutoff = sensor_cutoff[worldid, sensorid] if cutoff > 0.0: datatype = sensor_datatype[sensorid] @@ -59,16 +60,17 @@ def _write_vector( # Model: sensor_datatype: wp.array(dtype=int), sensor_adr: wp.array(dtype=int), - sensor_cutoff: wp.array(dtype=float), + sensor_cutoff: wp.array2d(dtype=float), # In: sensorid: int, sensordim: int, sensor: Any, + worldid: int, # Out: out: wp.array(dtype=float), ): adr = sensor_adr[sensorid] - cutoff = sensor_cutoff[sensorid] + cutoff = sensor_cutoff[worldid, sensorid] if cutoff > 0.0: datatype = sensor_datatype[sensorid] @@ -292,13 +294,13 @@ def _frame_axis( @wp.func def _frame_quat( # Model: - body_iquat: wp.array(dtype=wp.quat), + body_iquat: wp.array2d(dtype=wp.quat), geom_bodyid: wp.array(dtype=int), - geom_quat: wp.array(dtype=wp.quat), + geom_quat: wp.array2d(dtype=wp.quat), site_bodyid: wp.array(dtype=int), - site_quat: wp.array(dtype=wp.quat), + site_quat: wp.array2d(dtype=wp.quat), cam_bodyid: wp.array(dtype=int), - cam_quat: wp.array(dtype=wp.quat), + cam_quat: wp.array2d(dtype=wp.quat), # Data in: xquat_in: wp.array2d(dtype=wp.quat), # In: @@ -309,15 +311,15 @@ def _frame_quat( reftype: int, ) -> wp.quat: if objtype == int(ObjType.BODY.value): - quat = math.mul_quat(xquat_in[worldid, objid], body_iquat[objid]) + quat = math.mul_quat(xquat_in[worldid, objid], body_iquat[worldid, objid]) elif objtype == int(ObjType.XBODY.value): quat = xquat_in[worldid, objid] elif objtype == int(ObjType.GEOM.value): - quat = math.mul_quat(xquat_in[worldid, geom_bodyid[objid]], geom_quat[objid]) + quat = math.mul_quat(xquat_in[worldid, geom_bodyid[objid]], geom_quat[worldid, objid]) elif objtype == int(ObjType.SITE.value): - quat = math.mul_quat(xquat_in[worldid, site_bodyid[objid]], site_quat[objid]) + quat = math.mul_quat(xquat_in[worldid, site_bodyid[objid]], site_quat[worldid, objid]) elif objtype == int(ObjType.CAMERA.value): - quat = math.mul_quat(xquat_in[worldid, cam_bodyid[objid]], cam_quat[objid]) + quat = math.mul_quat(xquat_in[worldid, cam_bodyid[objid]], cam_quat[worldid, objid]) else: # UNKNOWN quat = wp.quat(1.0, 0.0, 0.0, 0.0) @@ -325,15 +327,15 @@ def _frame_quat( return quat if reftype == int(ObjType.BODY.value): - refquat = math.mul_quat(xquat_in[worldid, refid], body_iquat[refid]) + refquat = math.mul_quat(xquat_in[worldid, refid], body_iquat[worldid, refid]) elif reftype == int(ObjType.XBODY.value): refquat = xquat_in[worldid, refid] elif reftype == int(ObjType.GEOM.value): - refquat = math.mul_quat(xquat_in[worldid, geom_bodyid[refid]], geom_quat[refid]) + refquat = math.mul_quat(xquat_in[worldid, geom_bodyid[refid]], geom_quat[worldid, refid]) elif reftype == int(ObjType.SITE.value): - refquat = math.mul_quat(xquat_in[worldid, site_bodyid[refid]], site_quat[refid]) + refquat = math.mul_quat(xquat_in[worldid, site_bodyid[refid]], site_quat[worldid, refid]) elif reftype == int(ObjType.CAMERA.value): - refquat = math.mul_quat(xquat_in[worldid, cam_bodyid[refid]], cam_quat[refid]) + refquat = math.mul_quat(xquat_in[worldid, cam_bodyid[refid]], cam_quat[worldid, refid]) else: # UNKNOWN refquat = wp.quat(1.0, 0.0, 0.0, 0.0) @@ -353,14 +355,14 @@ def _clock(time_in: wp.array(dtype=float), worldid: int) -> float: @wp.kernel def _sensor_pos( # Model: - body_iquat: wp.array(dtype=wp.quat), + body_iquat: wp.array2d(dtype=wp.quat), jnt_qposadr: wp.array(dtype=int), geom_bodyid: wp.array(dtype=int), - geom_quat: wp.array(dtype=wp.quat), + geom_quat: wp.array2d(dtype=wp.quat), site_bodyid: wp.array(dtype=int), - site_quat: wp.array(dtype=wp.quat), + site_quat: wp.array2d(dtype=wp.quat), cam_bodyid: wp.array(dtype=int), - cam_quat: wp.array(dtype=wp.quat), + cam_quat: wp.array2d(dtype=wp.quat), cam_fovy: wp.array(dtype=float), cam_resolution: wp.array(dtype=wp.vec2i), cam_sensorsize: wp.array(dtype=wp.vec2), @@ -372,7 +374,7 @@ def _sensor_pos( sensor_reftype: wp.array(dtype=int), sensor_refid: wp.array(dtype=int), sensor_adr: wp.array(dtype=int), - sensor_cutoff: wp.array(dtype=float), + sensor_cutoff: wp.array2d(dtype=float), sensor_pos_adr: wp.array(dtype=int), # Data in: time_in: wp.array(dtype=float), @@ -405,19 +407,19 @@ def _sensor_pos( vec2 = _cam_projection( cam_fovy, cam_resolution, cam_sensorsize, cam_intrinsic, site_xpos_in, cam_xpos_in, cam_xmat_in, worldid, objid, refid ) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 2, vec2, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 2, vec2, worldid, out) elif sensortype == int(SensorType.JOINTPOS.value): val = _joint_pos(jnt_qposadr, qpos_in, worldid, objid) - _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, out) + _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, worldid, out) elif sensortype == int(SensorType.TENDONPOS.value): val = _tendon_pos(ten_length_in, worldid, objid) - _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, out) + _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, worldid, out) elif sensortype == int(SensorType.ACTUATORPOS.value): val = _actuator_pos(actuator_length_in, worldid, objid) - _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, out) + _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, worldid, out) elif sensortype == int(SensorType.BALLQUAT.value): quat = _ball_quat(jnt_qposadr, qpos_in, worldid, objid) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 4, quat, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 4, quat, worldid, out) elif sensortype == int(SensorType.FRAMEPOS.value): objtype = sensor_objtype[sensorid] refid = sensor_refid[sensorid] @@ -439,7 +441,7 @@ def _sensor_pos( refid, reftype, ) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif ( sensortype == int(SensorType.FRAMEXAXIS.value) or sensortype == int(SensorType.FRAMEYAXIS.value) @@ -457,7 +459,7 @@ def _sensor_pos( vec3 = _frame_axis( ximat_in, xmat_in, geom_xmat_in, site_xmat_in, cam_xmat_in, worldid, objid, objtype, refid, reftype, axis ) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif sensortype == int(SensorType.FRAMEQUAT.value): objtype = sensor_objtype[sensorid] refid = sensor_refid[sensorid] @@ -477,13 +479,13 @@ def _sensor_pos( refid, reftype, ) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 4, quat, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 4, quat, worldid, out) elif sensortype == int(SensorType.SUBTREECOM.value): vec3 = _subtree_com(subtree_com_in, worldid, objid) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif sensortype == int(SensorType.CLOCK.value): val = _clock(time_in, worldid) - _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, out) + _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, worldid, out) @event_scope @@ -858,7 +860,7 @@ def _sensor_vel( sensor_reftype: wp.array(dtype=int), sensor_refid: wp.array(dtype=int), sensor_adr: wp.array(dtype=int), - sensor_cutoff: wp.array(dtype=float), + sensor_cutoff: wp.array2d(dtype=float), sensor_vel_adr: wp.array(dtype=int), # Data in: qvel_in: wp.array2d(dtype=float), @@ -889,22 +891,22 @@ def _sensor_vel( if sensortype == int(SensorType.VELOCIMETER.value): vec3 = _velocimeter(body_rootid, site_bodyid, site_xpos_in, site_xmat_in, subtree_com_in, cvel_in, worldid, objid) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif sensortype == int(SensorType.GYRO.value): vec3 = _gyro(site_bodyid, site_xmat_in, cvel_in, worldid, objid) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif sensortype == int(SensorType.JOINTVEL.value): val = _joint_vel(jnt_dofadr, qvel_in, worldid, objid) - _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, out) + _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, worldid, out) elif sensortype == int(SensorType.TENDONVEL.value): val = _tendon_vel(ten_velocity_in, worldid, objid) - _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, out) + _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, worldid, out) elif sensortype == int(SensorType.ACTUATORVEL.value): val = _actuator_vel(actuator_velocity_in, worldid, objid) - _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, out) + _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, worldid, out) elif sensortype == int(SensorType.BALLANGVEL.value): vec3 = _ball_ang_vel(jnt_dofadr, qvel_in, worldid, objid) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif sensortype == int(SensorType.FRAMELINVEL.value): objtype = sensor_objtype[sensorid] refid = sensor_refid[sensorid] @@ -932,7 +934,7 @@ def _sensor_vel( refid, reftype, ) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, frame_linvel, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, frame_linvel, worldid, out) elif sensortype == int(SensorType.FRAMEANGVEL.value): objtype = sensor_objtype[sensorid] refid = sensor_refid[sensorid] @@ -960,13 +962,13 @@ def _sensor_vel( refid, reftype, ) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, frame_angvel, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, frame_angvel, worldid, out) elif sensortype == int(SensorType.SUBTREELINVEL.value): vec3 = _subtree_linvel(subtree_linvel_in, worldid, objid) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif sensortype == int(SensorType.SUBTREEANGMOM.value): vec3 = _subtree_angmom(subtree_angmom_in, worldid, objid) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) @event_scope @@ -1198,7 +1200,7 @@ def _sensor_acc( sensor_objtype: wp.array(dtype=int), sensor_objid: wp.array(dtype=int), sensor_adr: wp.array(dtype=int), - sensor_cutoff: wp.array(dtype=float), + sensor_cutoff: wp.array2d(dtype=float), sensor_acc_adr: wp.array(dtype=int), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), @@ -1226,19 +1228,19 @@ def _sensor_acc( vec3 = _accelerometer( body_rootid, site_bodyid, site_xpos_in, site_xmat_in, subtree_com_in, cvel_in, cacc_in, worldid, objid ) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif sensortype == int(SensorType.FORCE.value): vec3 = _force(site_bodyid, site_xmat_in, cfrc_int_in, worldid, objid) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif sensortype == int(SensorType.TORQUE.value): vec3 = _torque(body_rootid, site_bodyid, site_xpos_in, site_xmat_in, subtree_com_in, cfrc_int_in, worldid, objid) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif sensortype == int(SensorType.ACTUATORFRC.value): val = _actuator_force(actuator_force_in, worldid, objid) - _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, out) + _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, worldid, out) elif sensortype == int(SensorType.JOINTACTFRC.value): val = _joint_actuator_force(jnt_dofadr, qfrc_actuator_in, worldid, objid) - _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, out) + _write_scalar(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, val, worldid, out) elif sensortype == int(SensorType.FRAMELINACC.value): objtype = sensor_objtype[sensorid] vec3 = _framelinacc( @@ -1258,7 +1260,7 @@ def _sensor_acc( objid, objtype, ) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) elif sensortype == int(SensorType.FRAMEANGACC.value): objtype = sensor_objtype[sensorid] vec3 = _frameangacc( @@ -1270,7 +1272,7 @@ def _sensor_acc( objid, objtype, ) - _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, out) + _write_vector(sensor_datatype, sensor_adr, sensor_cutoff, sensorid, 3, vec3, worldid, out) @event_scope diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index 06f6dfa1..87b2e96d 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -60,17 +60,17 @@ def _kinematics_root( @wp.kernel def _kinematics_level( # Model: - qpos0: wp.array(dtype=float), + qpos0: wp.array2d(dtype=float), body_parentid: wp.array(dtype=int), body_jntnum: wp.array(dtype=int), body_jntadr: wp.array(dtype=int), - body_pos: wp.array(dtype=wp.vec3), - body_quat: wp.array(dtype=wp.quat), - body_ipos: wp.array(dtype=wp.vec3), - body_iquat: wp.array(dtype=wp.quat), + body_pos: wp.array2d(dtype=wp.vec3), + body_quat: wp.array2d(dtype=wp.quat), + body_ipos: wp.array2d(dtype=wp.vec3), + body_iquat: wp.array2d(dtype=wp.quat), jnt_type: wp.array(dtype=int), jnt_qposadr: wp.array(dtype=int), - jnt_pos: wp.array(dtype=wp.vec3), + jnt_pos: wp.array2d(dtype=wp.vec3), jnt_axis: wp.array(dtype=wp.vec3), # Data in: qpos_in: wp.array2d(dtype=float), @@ -97,8 +97,8 @@ def _kinematics_level( if jntnum == 0: # no joints - apply fixed translation and rotation relative to parent pid = body_parentid[bodyid] - xpos = (xmat_in[worldid, pid] * body_pos[bodyid]) + xpos_in[worldid, pid] - xquat = math.mul_quat(xquat_in[worldid, pid], body_quat[bodyid]) + xpos = (xmat_in[worldid, pid] * body_pos[worldid, bodyid]) + xpos_in[worldid, pid] + xquat = math.mul_quat(xquat_in[worldid, pid], body_quat[worldid, bodyid]) elif jntnum == 1 and jnt_type[jntadr] == wp.static(JointType.FREE.value): # free joint qadr = jnt_qposadr[jntadr] @@ -110,14 +110,14 @@ def _kinematics_level( # regular or no joints # apply fixed translation and rotation relative to parent pid = body_parentid[bodyid] - xpos = (xmat_in[worldid, pid] * body_pos[bodyid]) + xpos_in[worldid, pid] - xquat = math.mul_quat(xquat_in[worldid, pid], body_quat[bodyid]) + xpos = (xmat_in[worldid, pid] * body_pos[worldid, bodyid]) + xpos_in[worldid, pid] + xquat = math.mul_quat(xquat_in[worldid, pid], body_quat[worldid, bodyid]) for _ in range(jntnum): qadr = jnt_qposadr[jntadr] jnt_type_ = jnt_type[jntadr] jnt_axis_ = jnt_axis[jntadr] - xanchor = math.rot_vec_quat(jnt_pos[jntadr], xquat) + xpos + xanchor = math.rot_vec_quat(jnt_pos[worldid, jntadr], xquat) + xpos xaxis = math.rot_vec_quat(jnt_axis_, xquat) if jnt_type_ == wp.static(JointType.BALL.value): @@ -129,15 +129,15 @@ def _kinematics_level( ) xquat = math.mul_quat(xquat, qloc) # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(jnt_pos[jntadr], xquat) + xpos = xanchor - math.rot_vec_quat(jnt_pos[worldid, jntadr], xquat) elif jnt_type_ == wp.static(JointType.SLIDE.value): - xpos += xaxis * (qpos[qadr] - qpos0[qadr]) + xpos += xaxis * (qpos[qadr] - qpos0[worldid, qadr]) elif jnt_type_ == wp.static(JointType.HINGE.value): - qpos0_ = qpos0[qadr] + qpos0_ = qpos0[worldid, qadr] qloc_ = math.axis_angle_to_quat(jnt_axis_, qpos[qadr] - qpos0_) xquat = math.mul_quat(xquat, qloc_) # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(jnt_pos[jntadr], xquat) + xpos = xanchor - math.rot_vec_quat(jnt_pos[worldid, jntadr], xquat) xanchor_out[worldid, jntadr] = xanchor xaxis_out[worldid, jntadr] = xaxis @@ -146,16 +146,16 @@ def _kinematics_level( xpos_out[worldid, bodyid] = xpos xquat_out[worldid, bodyid] = wp.normalize(xquat) xmat_out[worldid, bodyid] = math.quat_to_mat(xquat) - xipos_out[worldid, bodyid] = xpos + math.rot_vec_quat(body_ipos[bodyid], xquat) - ximat_out[worldid, bodyid] = math.quat_to_mat(math.mul_quat(xquat, body_iquat[bodyid])) + xipos_out[worldid, bodyid] = xpos + math.rot_vec_quat(body_ipos[worldid, bodyid], xquat) + ximat_out[worldid, bodyid] = math.quat_to_mat(math.mul_quat(xquat, body_iquat[worldid, bodyid])) @wp.kernel def _geom_local_to_global( # Model: geom_bodyid: wp.array(dtype=int), - geom_pos: wp.array(dtype=wp.vec3), - geom_quat: wp.array(dtype=wp.quat), + geom_pos: wp.array2d(dtype=wp.vec3), + geom_quat: wp.array2d(dtype=wp.quat), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), xquat_in: wp.array2d(dtype=wp.quat), @@ -167,16 +167,16 @@ def _geom_local_to_global( bodyid = geom_bodyid[geomid] xpos = xpos_in[worldid, bodyid] xquat = xquat_in[worldid, bodyid] - geom_xpos_out[worldid, geomid] = xpos + math.rot_vec_quat(geom_pos[geomid], xquat) - geom_xmat_out[worldid, geomid] = math.quat_to_mat(math.mul_quat(xquat, geom_quat[geomid])) + geom_xpos_out[worldid, geomid] = xpos + math.rot_vec_quat(geom_pos[worldid, geomid], xquat) + geom_xmat_out[worldid, geomid] = math.quat_to_mat(math.mul_quat(xquat, geom_quat[worldid, geomid])) @wp.kernel def _site_local_to_global( # Model: site_bodyid: wp.array(dtype=int), - site_pos: wp.array(dtype=wp.vec3), - site_quat: wp.array(dtype=wp.quat), + site_pos: wp.array2d(dtype=wp.vec3), + site_quat: wp.array2d(dtype=wp.quat), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), xquat_in: wp.array2d(dtype=wp.quat), @@ -188,15 +188,15 @@ def _site_local_to_global( bodyid = site_bodyid[siteid] xpos = xpos_in[worldid, bodyid] xquat = xquat_in[worldid, bodyid] - site_xpos_out[worldid, siteid] = xpos + math.rot_vec_quat(site_pos[siteid], xquat) - site_xmat_out[worldid, siteid] = math.quat_to_mat(math.mul_quat(xquat, site_quat[siteid])) + site_xpos_out[worldid, siteid] = xpos + math.rot_vec_quat(site_pos[worldid, siteid], xquat) + site_xmat_out[worldid, siteid] = math.quat_to_mat(math.mul_quat(xquat, site_quat[worldid, siteid])) @wp.kernel def _mocap( # Model: - body_ipos: wp.array(dtype=wp.vec3), - body_iquat: wp.array(dtype=wp.quat), + body_ipos: wp.array2d(dtype=wp.vec3), + body_iquat: wp.array2d(dtype=wp.quat), mocap_bodyid: wp.array(dtype=int), # Data in: mocap_pos_in: wp.array2d(dtype=wp.vec3), @@ -215,8 +215,8 @@ def _mocap( xpos_out[worldid, bodyid] = xpos xquat_out[worldid, bodyid] = mocap_quat xmat_out[worldid, bodyid] = math.quat_to_mat(mocap_quat) - xipos_out[worldid, bodyid] = xpos + math.rot_vec_quat(body_ipos[bodyid], mocap_quat) - ximat_out[worldid, bodyid] = math.quat_to_mat(math.mul_quat(mocap_quat, body_iquat[bodyid])) + xipos_out[worldid, bodyid] = xpos + math.rot_vec_quat(body_ipos[worldid, bodyid], mocap_quat) + ximat_out[worldid, bodyid] = math.quat_to_mat(math.mul_quat(mocap_quat, body_iquat[worldid, bodyid])) @event_scope @@ -285,14 +285,14 @@ def kinematics(m: Model, d: Data): @wp.kernel def _subtree_com_init( # Model: - body_mass: wp.array(dtype=float), + body_mass: wp.array2d(dtype=float), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), # Data out: xipos_out: wp.array2d(dtype=wp.vec3), ): worldid, bodyid = wp.tid() - xipos_out[worldid, bodyid] = xipos_in[worldid, bodyid] * body_mass[bodyid] + xipos_out[worldid, bodyid] = xipos_in[worldid, bodyid] * body_mass[worldid, bodyid] @wp.kernel @@ -315,20 +315,20 @@ def _subtree_com_acc( @wp.kernel def _subtree_div( # Model: - subtree_mass: wp.array(dtype=float), + subtree_mass: wp.array2d(dtype=float), # Data out: subtree_com_out: wp.array2d(dtype=wp.vec3), ): worldid, bodyid = wp.tid() - subtree_com_out[worldid, bodyid] /= subtree_mass[bodyid] + subtree_com_out[worldid, bodyid] /= subtree_mass[worldid, bodyid] @wp.kernel def _cinert( # Model: body_rootid: wp.array(dtype=int), - body_mass: wp.array(dtype=float), - body_inertia: wp.array(dtype=wp.vec3), + body_mass: wp.array2d(dtype=float), + body_inertia: wp.array2d(dtype=wp.vec3), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), ximat_in: wp.array2d(dtype=wp.mat33), @@ -338,8 +338,8 @@ def _cinert( ): worldid, bodyid = wp.tid() mat = ximat_in[worldid, bodyid] - inert = body_inertia[bodyid] - mass = body_mass[bodyid] + inert = body_inertia[worldid, bodyid] + mass = body_mass[worldid, bodyid] dif = xipos_in[worldid, bodyid] - subtree_com_in[worldid, body_rootid[bodyid]] # express inertia in com-based frame (mju_inertCom) @@ -472,8 +472,8 @@ def com_pos(m: Model, d: Data): def _cam_local_to_global( # Model: cam_bodyid: wp.array(dtype=int), - cam_pos: wp.array(dtype=wp.vec3), - cam_quat: wp.array(dtype=wp.quat), + cam_pos: wp.array2d(dtype=wp.vec3), + cam_quat: wp.array2d(dtype=wp.quat), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), xquat_in: wp.array2d(dtype=wp.quat), @@ -486,8 +486,8 @@ def _cam_local_to_global( bodyid = cam_bodyid[camid] xpos = xpos_in[worldid, bodyid] xquat = xquat_in[worldid, bodyid] - cam_xpos_out[worldid, camid] = xpos + math.rot_vec_quat(cam_pos[camid], xquat) - cam_xmat_out[worldid, camid] = math.quat_to_mat(math.mul_quat(xquat, cam_quat[camid])) + cam_xpos_out[worldid, camid] = xpos + math.rot_vec_quat(cam_pos[worldid, camid], xquat) + cam_xmat_out[worldid, camid] = math.quat_to_mat(math.mul_quat(xquat, cam_quat[worldid, camid])) @wp.kernel @@ -496,8 +496,8 @@ def _cam_fn( cam_mode: wp.array(dtype=int), cam_bodyid: wp.array(dtype=int), cam_targetbodyid: wp.array(dtype=int), - cam_poscom0: wp.array(dtype=wp.vec3), - cam_pos0: wp.array(dtype=wp.vec3), + cam_poscom0: wp.array2d(dtype=wp.vec3), + cam_pos0: wp.array2d(dtype=wp.vec3), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), subtree_com_in: wp.array2d(dtype=wp.vec3), @@ -514,9 +514,9 @@ def _cam_fn( return elif cam_mode[camid] == wp.static(CamLightType.TRACK.value): body_xpos = xpos_in[worldid, cam_bodyid[camid]] - cam_xpos_out[worldid, camid] = body_xpos + cam_pos0[camid] + cam_xpos_out[worldid, camid] = body_xpos + cam_pos0[worldid, camid] elif cam_mode[camid] == wp.static(CamLightType.TRACKCOM.value): - cam_xpos_out[worldid, camid] = subtree_com_in[worldid, cam_bodyid[camid]] + cam_poscom0[camid] + cam_xpos_out[worldid, camid] = subtree_com_in[worldid, cam_bodyid[camid]] + cam_poscom0[worldid, camid] elif cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value) or cam_mode[camid] == wp.static( CamLightType.TARGETBODYCOM.value ): @@ -541,8 +541,8 @@ def _cam_fn( def _light_local_to_global( # Model: light_bodyid: wp.array(dtype=int), - light_pos: wp.array(dtype=wp.vec3), - light_dir: wp.array(dtype=wp.vec3), + light_pos: wp.array2d(dtype=wp.vec3), + light_dir: wp.array2d(dtype=wp.vec3), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), xquat_in: wp.array2d(dtype=wp.quat), @@ -555,8 +555,8 @@ def _light_local_to_global( bodyid = light_bodyid[lightid] xpos = xpos_in[worldid, bodyid] xquat = xquat_in[worldid, bodyid] - light_xpos_out[worldid, lightid] = xpos + math.rot_vec_quat(light_pos[lightid], xquat) - light_xdir_out[worldid, lightid] = math.rot_vec_quat(light_dir[lightid], xquat) + light_xpos_out[worldid, lightid] = xpos + math.rot_vec_quat(light_pos[worldid, lightid], xquat) + light_xdir_out[worldid, lightid] = math.rot_vec_quat(light_dir[worldid, lightid], xquat) @wp.kernel @@ -565,8 +565,8 @@ def _light_fn( light_mode: wp.array(dtype=int), light_bodyid: wp.array(dtype=int), light_targetbodyid: wp.array(dtype=int), - light_poscom0: wp.array(dtype=wp.vec3), - light_pos0: wp.array(dtype=wp.vec3), + light_poscom0: wp.array2d(dtype=wp.vec3), + light_pos0: wp.array2d(dtype=wp.vec3), # Data in: xpos_in: wp.array2d(dtype=wp.vec3), light_xpos_in: wp.array2d(dtype=wp.vec3), @@ -584,9 +584,9 @@ def _light_fn( return elif light_mode[lightid] == wp.static(CamLightType.TRACK.value): body_xpos = xpos_in[worldid, light_bodyid[lightid]] - light_xpos_out[worldid, lightid] = body_xpos + light_pos0[lightid] + light_xpos_out[worldid, lightid] = body_xpos + light_pos0[worldid, lightid] elif light_mode[lightid] == wp.static(CamLightType.TRACKCOM.value): - light_xpos_out[worldid, lightid] = subtree_com_in[worldid, light_bodyid[lightid]] + light_poscom0[lightid] + light_xpos_out[worldid, lightid] = subtree_com_in[worldid, light_bodyid[lightid]] + light_poscom0[worldid, lightid] elif light_mode[lightid] == wp.static(CamLightType.TARGETBODY.value) or light_mode[lightid] == wp.static( CamLightType.TARGETBODYCOM.value ): @@ -670,7 +670,7 @@ def _qM_sparse( dof_bodyid: wp.array(dtype=int), dof_parentid: wp.array(dtype=int), dof_Madr: wp.array(dtype=int), - dof_armature: wp.array(dtype=float), + dof_armature: wp.array2d(dtype=float), # Data in: cdof_in: wp.array2d(dtype=wp.spatial_vector), crb_in: wp.array2d(dtype=vec10), @@ -682,7 +682,7 @@ def _qM_sparse( bodyid = dof_bodyid[dofid] # init M(i,i) with armature inertia - qM_out[worldid, 0, madr_ij] = dof_armature[dofid] + qM_out[worldid, 0, madr_ij] = dof_armature[worldid, dofid] # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(crb_in[worldid, bodyid], cdof_in[worldid, dofid]) @@ -699,7 +699,7 @@ def _qM_dense( # Model: dof_bodyid: wp.array(dtype=int), dof_parentid: wp.array(dtype=int), - dof_armature: wp.array(dtype=float), + dof_armature: wp.array2d(dtype=float), # Data in: cdof_in: wp.array2d(dtype=wp.spatial_vector), crb_in: wp.array2d(dtype=vec10), @@ -709,7 +709,7 @@ def _qM_dense( worldid, dofid = wp.tid() bodyid = dof_bodyid[dofid] # init M(i,i) with armature inertia - M = dof_armature[dofid] + M = dof_armature[worldid, dofid] # precompute buf = crb_body_i * cdof_i buf = math.inert_vec(crb_in[worldid, bodyid], cdof_in[worldid, dofid]) @@ -1140,7 +1140,7 @@ def _cfrc_ext_equality( # Model: body_rootid: wp.array(dtype=int), site_bodyid: wp.array(dtype=int), - site_pos: wp.array(dtype=wp.vec3), + site_pos: wp.array2d(dtype=wp.vec3), eq_obj1id: wp.array(dtype=int), eq_obj2id: wp.array(dtype=int), eq_objtype: wp.array(dtype=int), @@ -1203,7 +1203,7 @@ def _cfrc_ext_equality( else: offset = wp.vec3(eq_data_[3], eq_data_[4], eq_data_[5]) else: - offset = site_pos[obj1] + offset = site_pos[worldid, obj1] # transform point on body1: local -> global pos = xmat_in[worldid, bodyid1] @ offset + xpos_in[worldid, bodyid1] @@ -1225,7 +1225,7 @@ def _cfrc_ext_equality( else: offset = wp.vec3(eq_data_[0], eq_data_[1], eq_data_[2]) else: - offset = site_pos[obj2] + offset = site_pos[worldid, obj2] # transform point on body2: local -> global pos = xmat_in[worldid, bodyid2] @ offset + xpos_in[worldid, bodyid2] @@ -1480,7 +1480,7 @@ def _transmission( jnt_dofadr: wp.array(dtype=int), actuator_trntype: wp.array(dtype=int), actuator_trnid: wp.array(dtype=wp.vec2i), - actuator_gear: wp.array(dtype=wp.spatial_vector), + actuator_gear: wp.array2d(dtype=wp.spatial_vector), tendon_adr: wp.array(dtype=int), tendon_num: wp.array(dtype=int), wrap_objid: wp.array(dtype=int), @@ -1495,7 +1495,7 @@ def _transmission( ): worldid, actid = wp.tid() trntype = actuator_trntype[actid] - gear = actuator_gear[actid] + gear = actuator_gear[worldid, actid] if trntype == wp.static(TrnType.JOINT.value) or trntype == wp.static(TrnType.JOINTINPARENT.value): qpos = qpos_in[worldid] jntid = actuator_trnid[actid][0] @@ -1747,8 +1747,8 @@ def factor_solve_i(m, d, M, L, D, x, y): def _subtree_vel_forward( # Model: body_rootid: wp.array(dtype=int), - body_mass: wp.array(dtype=float), - body_inertia: wp.array(dtype=wp.vec3), + body_mass: wp.array2d(dtype=float), + body_inertia: wp.array2d(dtype=wp.vec3), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), ximat_in: wp.array2d(dtype=wp.mat33), @@ -1771,11 +1771,11 @@ def _subtree_vel_forward( # update linear velocity lin -= wp.cross(xipos - subtree_com_root, ang) - subtree_linvel_out[worldid, bodyid] = body_mass[bodyid] * lin + subtree_linvel_out[worldid, bodyid] = body_mass[worldid, bodyid] * lin dv = wp.transpose(ximat) @ ang - dv[0] *= body_inertia[bodyid][0] - dv[1] *= body_inertia[bodyid][1] - dv[2] *= body_inertia[bodyid][2] + dv[0] *= body_inertia[worldid, bodyid][0] + dv[1] *= body_inertia[worldid, bodyid][1] + dv[2] *= body_inertia[worldid, bodyid][2] subtree_angmom_out[worldid, bodyid] = ximat @ dv subtree_bodyvel_out[worldid, bodyid] = wp.spatial_vector(ang, lin) @@ -1784,7 +1784,7 @@ def _subtree_vel_forward( def _linear_momentum( # Model: body_parentid: wp.array(dtype=int), - body_subtreemass: wp.array(dtype=float), + body_subtreemass: wp.array2d(dtype=float), # Data in: subtree_linvel_in: wp.array2d(dtype=wp.vec3), # In: @@ -1797,15 +1797,15 @@ def _linear_momentum( if bodyid: pid = body_parentid[bodyid] wp.atomic_add(subtree_linvel_out[worldid], pid, subtree_linvel_in[worldid, bodyid]) - subtree_linvel_out[worldid, bodyid] /= wp.max(MJ_MINVAL, body_subtreemass[bodyid]) + subtree_linvel_out[worldid, bodyid] /= wp.max(MJ_MINVAL, body_subtreemass[worldid, bodyid]) @wp.kernel def _angular_momentum( # Model: body_parentid: wp.array(dtype=int), - body_mass: wp.array(dtype=float), - body_subtreemass: wp.array(dtype=float), + body_mass: wp.array2d(dtype=float), + body_subtreemass: wp.array2d(dtype=float), # Data in: xipos_in: wp.array2d(dtype=wp.vec3), subtree_com_in: wp.array2d(dtype=wp.vec3), @@ -1830,8 +1830,8 @@ def _angular_momentum( vel = subtree_bodyvel_in[worldid, bodyid] linvel = subtree_linvel_in[worldid, bodyid] linvel_parent = subtree_linvel_in[worldid, pid] - mass = body_mass[bodyid] - subtreemass = body_subtreemass[bodyid] + mass = body_mass[worldid, bodyid] + subtreemass = body_subtreemass[worldid, bodyid] # momentum wrt body i dx = xipos - com @@ -1905,7 +1905,7 @@ def _joint_tendon( jnt_qposadr: wp.array(dtype=int), jnt_dofadr: wp.array(dtype=int), wrap_objid: wp.array(dtype=int), - wrap_prm: wp.array(dtype=float), + wrap_prm: wp.array2d(dtype=float), tendon_jnt_adr: wp.array(dtype=int), wrap_jnt_adr: wp.array(dtype=int), # Data in: @@ -1920,7 +1920,7 @@ def _joint_tendon( wrap_jnt_adr_ = wrap_jnt_adr[wrapid] wrap_objid_ = wrap_objid[wrap_jnt_adr_] - prm = wrap_prm[wrap_jnt_adr_] + prm = wrap_prm[worldid, wrap_jnt_adr_] # add to length L = prm * qpos_in[worldid, jnt_qposadr[wrap_objid_]] diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 4482f25e..f448ae22 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -587,8 +587,8 @@ class Model: nlsp: number of step sizes for parallel linsearch () opt: physics options stat: model statistics - qpos0: qpos values at default pose (nq,) - qpos_spring: reference pose for springs (nq,) + qpos0: qpos values at default pose (nworld, nq) + qpos_spring: reference pose for springs (nworld, nq) qM_fullm_i: sparse mass matrix addressing qM_fullm_j: sparse mass matrix addressing qM_mulm_i: sparse mass matrix addressing @@ -612,15 +612,15 @@ class Model: body_dofadr: start addr of dofs; -1: no dofs (nbody,) body_geomnum: number of geoms (nbody,) body_geomadr: start addr of geoms; -1: no geoms (nbody,) - body_pos: position offset rel. to parent body (nbody, 3) - body_quat: orientation offset rel. to parent body (nbody, 4) - body_ipos: local position of center of mass (nbody, 3) - body_iquat: local orientation of inertia ellipsoid (nbody, 4) - body_mass: mass (nbody,) - body_subtreemass: mass of subtree starting at this body (nbody,) - subtree_mass: mass of subtree (nbody,) - body_inertia: diagonal inertia in ipos/iquat frame (nbody, 3) - body_invweight0: mean inv inert in qpos0 (trn, rot) (nbody, 2) + body_pos: position offset rel. to parent body (nworld, nbody, 3) + body_quat: orientation offset rel. to parent body (nworld, nbody, 4) + body_ipos: local position of center of mass (nworld, nbody, 3) + body_iquat: local orientation of inertia ellipsoid (nworld, nbody, 4) + body_mass: mass (nworld, nbody,) + body_subtreemass: mass of subtree starting at this body (nworld, nbody,) + subtree_mass: mass of subtree (nworld, nbody,) + body_inertia: diagonal inertia in ipos/iquat frame (nworld, nbody, 3) + body_invweight0: mean inv inert in qpos0 (trn, rot) (nworld, nbody, 2) body_contype: OR over all geom contypes (nbody,) body_conaffinity: OR over all geom conaffinities (nbody,) body_gravcomp: antigravity force, units of body weight (nbody,) @@ -630,14 +630,14 @@ class Model: jnt_bodyid: id of joint's body (njnt,) jnt_limited: does joint have limits (njnt,) jnt_actfrclimited: does joint have actuator force limits (njnt,) - jnt_solref: constraint solver reference: limit (njnt, mjNREF) - jnt_solimp: constraint solver impedance: limit (njnt, mjNIMP) - jnt_pos: local anchor position (njnt, 3) + jnt_solref: constraint solver reference: limit (nworld, njnt, mjNREF) + jnt_solimp: constraint solver impedance: limit (nworld, njnt, mjNIMP) + jnt_pos: local anchor position (nworld,njnt, 3) jnt_axis: local joint axis (njnt, 3) - jnt_stiffness: stiffness coefficient (njnt,) - jnt_range: joint limits (njnt, 2) - jnt_actfrcrange: range of total actuator force (njnt, 2) - jnt_margin: min distance for limit detection (njnt,) + jnt_stiffness: stiffness coefficient (nworld, njnt) + jnt_range: joint limits (nworld, njnt, 2) + jnt_actfrcrange: range of total actuator force (nworldnjnt, 2) + jnt_margin: min distance for limit detection (nworld, njnt) jnt_limited_slide_hinge_adr: limited/slide/hinge jntadr jnt_limited_ball_adr: limited/ball jntadr jnt_actgravcomp: is gravcomp force applied via actuators (njnt,) @@ -645,12 +645,12 @@ class Model: dof_jntid: id of dof's joint (nv,) dof_parentid: id of dof's parent; -1: none (nv,) dof_Madr: dof address in M-diagonal (nv,) - dof_armature: dof armature inertia/mass (nv,) - dof_damping: damping coefficient (nv,) - dof_invweight0: diag. inverse inertia in qpos0 (nv,) - dof_frictionloss: dof friction loss (nv,) - dof_solimp: constraint solver impedance: frictionloss (nv, NIMP) - dof_solref: constraint solver reference: frictionloss (nv, NREF) + dof_armature: dof armature inertia/mass (nworld, nv) + dof_damping: damping coefficient (nworld, nv) + dof_invweight0: diag. inverse inertia in qpos0 (nworld, nv) + dof_frictionloss: dof friction loss (nworld, nv) + dof_solimp: constraint solver impedance: frictionloss (nworld, nv, NIMP) + dof_solref: constraint solver reference: frictionloss (nworld,nv, NREF) dof_tri_row: np.tril_indices (mjm.nv)[0] dof_tri_col: np.tril_indices (mjm.nv)[1] geom_type: geometric type (mjtGeom) (ngeom,) @@ -659,28 +659,28 @@ class Model: geom_condim: contact dimensionality (1, 3, 4, 6) (ngeom,) geom_bodyid: id of geom's body (ngeom,) geom_dataid: id of geom's mesh/hfield; -1: none (ngeom,) - geom_priority: geom contact priority (ngeom,) - geom_solmix: mixing coef for solref/imp in geom pair (ngeom,) - geom_solref: constraint solver reference: contact (ngeom, mjNREF) - geom_solimp: constraint solver impedance: contact (ngeom, mjNIMP) - geom_size: geom-specific size parameters (ngeom, 3) - geom_aabb: bounding box, (center, size) (ngeom, 6) - geom_rbound: radius of bounding sphere (ngeom,) - geom_pos: local position offset rel. to body (ngeom, 3) - geom_quat: local orientation offset rel. to body (ngeom, 4) - geom_friction: friction for (slide, spin, roll) (ngeom, 3) - geom_margin: detect contact if dist