diff --git a/mujoco_warp/__init__.py b/mujoco_warp/__init__.py index 4bedba93..7fa5acac 100644 --- a/mujoco_warp/__init__.py +++ b/mujoco_warp/__init__.py @@ -51,7 +51,6 @@ from ._src.smooth import tendon as tendon from ._src.smooth import transmission as transmission from ._src.solver import solve as solve -from ._src.support import contact_force as contact_force from ._src.support import is_sparse as is_sparse from ._src.support import mul_m as mul_m from ._src.support import xfrc_accumulate as xfrc_accumulate diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index 26b5e1ab..6723bb65 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -167,97 +167,99 @@ def _sap_broadphase(m: Model, d: Data, nsweep: int, filterparent: bool): def sap_broadphase(m: Model, d: Data): """Broadphase collision detection via sweep-and-prune.""" - nworldgeom = d.nworld * m.ngeom + with wp.ScopedDevice(m.qpos0.device): + nworldgeom = d.nworld * m.ngeom - # TODO(team): direction + # TODO(team): direction - # random fixed direction - direction = wp.vec3(0.5935, 0.7790, 0.1235) - direction = wp.normalize(direction) + # random fixed direction + direction = wp.vec3(0.5935, 0.7790, 0.1235) + direction = wp.normalize(direction) - wp.launch( - kernel=_sap_project, - dim=(d.nworld, m.ngeom), - inputs=[m, d, direction], - ) + wp.launch( + kernel=_sap_project, + dim=(d.nworld, m.ngeom), + inputs=[m, d, direction], + ) - # TODO(team): tile sort + # TODO(team): tile sort - wp.utils.segmented_sort_pairs( - d.sap_projection_lower, - d.sap_sort_index, - nworldgeom, - d.sap_segment_index, - ) + wp.utils.segmented_sort_pairs( + d.sap_projection_lower, + d.sap_sort_index, + nworldgeom, + d.sap_segment_index, + ) - wp.launch( - kernel=_sap_range, - dim=(d.nworld, m.ngeom), - inputs=[m, d], - ) + wp.launch( + kernel=_sap_range, + dim=(d.nworld, m.ngeom), + inputs=[m, d], + ) - # scan is used for load balancing among the threads - wp.utils.array_scan(d.sap_range.reshape(-1), d.sap_cumulative_sum, True) + # scan is used for load balancing among the threads + wp.utils.array_scan(d.sap_range.reshape(-1), d.sap_cumulative_sum, True) - # estimate number of overlap checks - assumes each geom has 5 other geoms (batched over all worlds) - nsweep = 5 * nworldgeom - filterparent = not m.opt.disableflags & DisableBit.FILTERPARENT.value - wp.launch( - kernel=_sap_broadphase, - dim=nsweep, - inputs=[m, d, nsweep, filterparent], - ) + # estimate number of overlap checks - assumes each geom has 5 other geoms (batched over all worlds) + nsweep = 5 * nworldgeom + filterparent = not m.opt.disableflags & DisableBit.FILTERPARENT.value + wp.launch( + kernel=_sap_broadphase, + dim=nsweep, + inputs=[m, d, nsweep, filterparent], + ) def nxn_broadphase(m: Model, d: Data): """Broadphase collision detective via brute-force search.""" - @wp.kernel - def _nxn_broadphase(m: Model, d: Data): - worldid, elementid = wp.tid() + with wp.ScopedDevice(m.qpos0.device): - # check for valid geom pair - if m.nxn_pairid[elementid] < -1: - return + @wp.kernel + def _nxn_broadphase(m: Model, d: Data): + worldid, elementid = wp.tid() - geom = m.nxn_geom_pair[elementid] - geom1 = geom[0] - geom2 = geom[1] + # check for valid geom pair + if m.nxn_pairid[elementid] < -1: + return - if _sphere_filter(m, d, geom1, geom2, worldid): - _add_geom_pair(m, d, geom1, geom2, worldid, elementid) + geom = m.nxn_geom_pair[elementid] + geom1 = geom[0] + geom2 = geom[1] - if m.nxn_geom_pair.shape[0]: - wp.launch(_nxn_broadphase, dim=(d.nworld, m.nxn_geom_pair.shape[0]), inputs=[m, d]) + if _sphere_filter(m, d, geom1, geom2, worldid): + _add_geom_pair(m, d, geom1, geom2, worldid, elementid) + + if m.nxn_geom_pair.shape[0]: + wp.launch( + _nxn_broadphase, dim=(d.nworld, m.nxn_geom_pair.shape[0]), inputs=[m, d] + ) @event_scope def collision(m: Model, d: Data): """Collision detection.""" - # AD: based on engine_collision_driver.py in Eric's warp fork/mjx-collisions-dev - # which is further based on the CUDA code here: - # https://github.com/btaba/mujoco/blob/warp-collisions/mjx/mujoco/mjx/_src/cuda/engine_collision_driver.cu.cc#L458-L583 - - d.ncollision.zero_() - d.ncon.zero_() + with wp.ScopedDevice(m.qpos0.device): + d.ncollision.zero_() + d.ncon.zero_() - if d.nconmax == 0: - return + if d.nconmax == 0: + return - dsbl_flgs = m.opt.disableflags - if (dsbl_flgs & DisableBit.CONSTRAINT) | (dsbl_flgs & DisableBit.CONTACT): - return + dsbl_flgs = m.opt.disableflags + if (dsbl_flgs & DisableBit.CONSTRAINT) | (dsbl_flgs & DisableBit.CONTACT): + return - # TODO(team): determine ngeom to switch from n^2 to sap - if m.ngeom <= 100: - nxn_broadphase(m, d) - else: - sap_broadphase(m, d) - - # TODO(team): we should reject far-away contacts in the narrowphase instead of constraint - # partitioning because we can move some pressure of the atomics - # TODO(team) switch between collision functions and GJK/EPA here - gjk_narrowphase(m, d) - primitive_narrowphase(m, d) - box_box_narrowphase(m, d) + # TODO(team): determine ngeom to switch from n^2 to sap + if m.ngeom <= 100: + nxn_broadphase(m, d) + else: + sap_broadphase(m, d) + + # TODO(team): we should reject far-away contacts in the narrowphase instead of constraint + # partitioning because we can move some pressure of the atomics + # TODO(team) switch between collision functions and GJK/EPA here + gjk_narrowphase(m, d) + primitive_narrowphase(m, d) + box_box_narrowphase(m, d) diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 426ec093..f606cf77 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -1021,6 +1021,7 @@ def _primitive_narrowphase( def primitive_narrowphase(m: Model, d: Data): - # we need to figure out how to keep the overhead of this small - not launching anything - # for pair types without collisions, as well as updating the launch dimensions. - wp.launch(_primitive_narrowphase, dim=d.nconmax, inputs=[m, d]) + with wp.ScopedDevice(m.qpos0.device): + # we need to figure out how to keep the overhead of this small - not launching anything + # for pair types without collisions, as well as updating the launch dimensions. + wp.launch(_primitive_narrowphase, dim=d.nconmax, inputs=[m, d]) diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index efaa02f9..b46f093f 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -704,79 +704,80 @@ def _update_nefc(d: types.Data): def make_constraint(m: types.Model, d: types.Data): """Creates constraint jacobians and other supporting data.""" - d.ne.zero_() - d.ne_connect.zero_() - d.ne_weld.zero_() - d.ne_jnt.zero_() - d.nefc.zero_() - d.nf.zero_() - d.nl.zero_() - - if not (m.opt.disableflags & types.DisableBit.CONSTRAINT.value): - d.efc.J.zero_() - - if not (m.opt.disableflags & types.DisableBit.EQUALITY.value): - wp.launch( - _efc_equality_connect, - dim=(d.nworld, m.eq_connect_adr.size), - inputs=[m, d], - ) - wp.launch( - _efc_equality_weld, - dim=(d.nworld, m.eq_wld_adr.size), - inputs=[m, d], - ) - wp.launch( - _efc_equality_joint, - dim=(d.nworld, m.eq_jnt_adr.size), - inputs=[m, d], - ) - - wp.launch(_num_equality, dim=(1,), inputs=[d]) - - if not (m.opt.disableflags & types.DisableBit.FRICTIONLOSS.value): - wp.launch( - _efc_friction, - dim=(d.nworld, m.nv), - inputs=[m, d], - ) - - # limit - if not (m.opt.disableflags & types.DisableBit.LIMIT.value): - limit_ball = m.jnt_limited_ball_adr.size > 0 - if limit_ball: + with wp.ScopedDevice(m.qpos0.device): + d.ne.zero_() + d.ne_connect.zero_() + d.ne_weld.zero_() + d.ne_jnt.zero_() + d.nefc.zero_() + d.nf.zero_() + d.nl.zero_() + + if not (m.opt.disableflags & types.DisableBit.CONSTRAINT.value): + d.efc.J.zero_() + + if not (m.opt.disableflags & types.DisableBit.EQUALITY.value): wp.launch( - _efc_limit_ball, - dim=(d.nworld, m.jnt_limited_ball_adr.size), + _efc_equality_connect, + dim=(d.nworld, m.eq_connect_adr.size), inputs=[m, d], ) - - limit_slide_hinge = m.jnt_limited_slide_hinge_adr.size > 0 - if limit_slide_hinge: wp.launch( - _efc_limit_slide_hinge, - dim=(d.nworld, m.jnt_limited_slide_hinge_adr.size), + _efc_equality_weld, + dim=(d.nworld, m.eq_wld_adr.size), inputs=[m, d], ) - - limit_tendon = m.tendon_limited_adr.size > 0 - if limit_tendon: wp.launch( - _efc_limit_tendon, - dim=(d.nworld, m.tendon_limited_adr.size), + _efc_equality_joint, + dim=(d.nworld, m.eq_jnt_adr.size), inputs=[m, d], ) - if limit_ball or limit_slide_hinge or limit_tendon: - wp.launch(_update_nefc, dim=(1,), inputs=[d]) + wp.launch(_num_equality, dim=(1,), inputs=[d]) - # contact - if not (m.opt.disableflags & types.DisableBit.CONTACT.value): - if m.opt.cone == types.ConeType.PYRAMIDAL.value: + if not (m.opt.disableflags & types.DisableBit.FRICTIONLOSS.value): wp.launch( - _efc_contact_pyramidal, - dim=(d.nconmax, 2 * (m.condim_max - 1) if m.condim_max > 1 else 1), + _efc_friction, + dim=(d.nworld, m.nv), inputs=[m, d], ) - elif m.opt.cone == types.ConeType.ELLIPTIC.value: - wp.launch(_efc_contact_elliptic, dim=(d.nconmax, m.condim_max), inputs=[m, d]) + + # limit + if not (m.opt.disableflags & types.DisableBit.LIMIT.value): + limit_ball = m.jnt_limited_ball_adr.size > 0 + if limit_ball: + wp.launch( + _efc_limit_ball, + dim=(d.nworld, m.jnt_limited_ball_adr.size), + inputs=[m, d], + ) + + limit_slide_hinge = m.jnt_limited_slide_hinge_adr.size > 0 + if limit_slide_hinge: + wp.launch( + _efc_limit_slide_hinge, + dim=(d.nworld, m.jnt_limited_slide_hinge_adr.size), + inputs=[m, d], + ) + + limit_tendon = m.tendon_limited_adr.size > 0 + if limit_tendon: + wp.launch( + _efc_limit_tendon, + dim=(d.nworld, m.tendon_limited_adr.size), + inputs=[m, d], + ) + + if limit_ball or limit_slide_hinge or limit_tendon: + wp.launch(_update_nefc, dim=(1,), inputs=[d]) + + # contact + if not (m.opt.disableflags & types.DisableBit.CONTACT.value): + if m.opt.cone == types.ConeType.PYRAMIDAL.value: + wp.launch( + _efc_contact_pyramidal, + dim=(d.nconmax, 2 * (m.condim_max - 1) if m.condim_max > 1 else 1), + inputs=[m, d], + ) + elif m.opt.cone == types.ConeType.ELLIPTIC.value: + wp.launch(_efc_contact_elliptic, dim=(d.nconmax, m.condim_max), inputs=[m, d]) diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 01fc6385..b9389e3c 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -185,654 +185,673 @@ def _time(m: Model, d: Data): def euler(m: Model, d: Data): """Euler integrator, semi-implicit in velocity.""" - # integrate damping implicitly + with wp.ScopedDevice(m.qpos0.device): + # integrate damping implicitly - def eulerdamp_sparse(m: Model, d: Data): - @kernel - def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): - worldid, tid = wp.tid() + def eulerdamp_sparse(m: Model, d: Data): + @kernel + def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): + worldid, tid = wp.tid() - dof_Madr = m.dof_Madr[tid] - d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * m.dof_damping[tid] + dof_Madr = m.dof_Madr[tid] + d.qM_integration[worldid, 0, dof_Madr] += m.opt.timestep * m.dof_damping[tid] + + d.qfrc_integration[worldid, tid] = ( + d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid] + ) - d.qfrc_integration[worldid, tid] = ( - d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid] + wp.copy(d.qM_integration, d.qM) + wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(d.nworld, m.nv), inputs=[m, d]) + smooth.factor_solve_i( + m, + d, + d.qM_integration, + d.qLD_integration, + d.qLDiagInv_integration, + d.qacc_integration, + d.qfrc_integration, ) - wp.copy(d.qM_integration, d.qM) - wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(d.nworld, m.nv), inputs=[m, d]) - smooth.factor_solve_i( - m, - d, - d.qM_integration, - d.qLD_integration, - d.qLDiagInv_integration, - d.qacc_integration, - d.qfrc_integration, - ) + def eulerdamp_fused_dense(m: Model, d: Data): + def tile_eulerdamp(adr: int, size: int, tilesize: int): + @kernel + def eulerdamp( + m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int + ): + worldid, nodeid = wp.tid() + dofid = m.qLD_tile[leveladr + nodeid] + M_tile = wp.tile_load( + d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) + ) + damping_tile = wp.tile_load(damping, shape=(tilesize,), offset=(dofid,)) + damping_scaled = damping_tile * m.opt.timestep + qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled) - def eulerdamp_fused_dense(m: Model, d: Data): - def tile_eulerdamp(adr: int, size: int, tilesize: int): - @kernel - def eulerdamp( - m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int - ): - worldid, nodeid = wp.tid() - dofid = m.qLD_tile[leveladr + nodeid] - M_tile = wp.tile_load( - d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) - ) - damping_tile = wp.tile_load(damping, shape=(tilesize,), offset=(dofid,)) - damping_scaled = damping_tile * m.opt.timestep - qm_integration_tile = wp.tile_diag_add(M_tile, damping_scaled) + qfrc_smooth_tile = wp.tile_load( + d.qfrc_smooth[worldid], shape=(tilesize,), offset=(dofid,) + ) + qfrc_constraint_tile = wp.tile_load( + d.qfrc_constraint[worldid], shape=(tilesize,), offset=(dofid,) + ) - qfrc_smooth_tile = wp.tile_load( - d.qfrc_smooth[worldid], shape=(tilesize,), offset=(dofid,) - ) - qfrc_constraint_tile = wp.tile_load( - d.qfrc_constraint[worldid], shape=(tilesize,), offset=(dofid,) - ) + qfrc_tile = qfrc_smooth_tile + qfrc_constraint_tile - qfrc_tile = qfrc_smooth_tile + qfrc_constraint_tile + L_tile = wp.tile_cholesky(qm_integration_tile) + qacc_tile = wp.tile_cholesky_solve(L_tile, qfrc_tile) + wp.tile_store(d.qacc_integration[worldid], qacc_tile, offset=(dofid)) - L_tile = wp.tile_cholesky(qm_integration_tile) - qacc_tile = wp.tile_cholesky_solve(L_tile, qfrc_tile) - wp.tile_store(d.qacc_integration[worldid], qacc_tile, offset=(dofid)) + wp.launch_tiled( + eulerdamp, + dim=(d.nworld, size), + inputs=[m, d, m.dof_damping, adr], + block_dim=32, + ) - wp.launch_tiled( - eulerdamp, dim=(d.nworld, size), inputs=[m, d, m.dof_damping, adr], block_dim=32 - ) + qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() - qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() + for i in range(len(qLD_tileadr)): + beg = qLD_tileadr[i] + end = m.qLD_tile.shape[0] if i == len(qLD_tileadr) - 1 else qLD_tileadr[i + 1] + tile_eulerdamp(beg, end - beg, int(qLD_tilesize[i])) - for i in range(len(qLD_tileadr)): - beg = qLD_tileadr[i] - end = m.qLD_tile.shape[0] if i == len(qLD_tileadr) - 1 else qLD_tileadr[i + 1] - tile_eulerdamp(beg, end - beg, int(qLD_tilesize[i])) + if not m.opt.disableflags & DisableBit.EULERDAMP.value: + if m.opt.is_sparse: + eulerdamp_sparse(m, d) + else: + eulerdamp_fused_dense(m, d) - if not m.opt.disableflags & DisableBit.EULERDAMP.value: - if m.opt.is_sparse: - eulerdamp_sparse(m, d) + _advance(m, d, d.qacc_integration) else: - eulerdamp_fused_dense(m, d) - - _advance(m, d, d.qacc_integration) - else: - _advance(m, d, d.qacc) + _advance(m, d, d.qacc) @event_scope def rungekutta4(m: Model, d: Data): """Runge-Kutta explicit order 4 integrator.""" - wp.copy(d.qpos_t0, d.qpos) - wp.copy(d.qvel_t0, d.qvel) - if m.na: - wp.copy(d.act_t0, d.act) - - A, B = _RK4_A, _RK4_B + with wp.ScopedDevice(m.qpos0.device): + wp.copy(d.qpos_t0, d.qpos) + wp.copy(d.qvel_t0, d.qvel) + if m.na: + wp.copy(d.act_t0, d.act) - def rk_accumulate(d: Data, b: float): - """Computes one term of 1/6 k_1 + 1/3 k_2 + 1/3 k_3 + 1/6 k_4""" + A, B = _RK4_A, _RK4_B - @kernel - def _qvel_acc(d: Data, b: float): - worldid, tid = wp.tid() - d.qvel_rk[worldid, tid] += b * d.qvel[worldid, tid] - d.qacc_rk[worldid, tid] += b * d.qacc[worldid, tid] - - if m.na: + def rk_accumulate(d: Data, b: float): + """Computes one term of 1/6 k_1 + 1/3 k_2 + 1/3 k_3 + 1/6 k_4""" @kernel - def _act_dot(d: Data, b: float): + def _qvel_acc(d: Data, b: float): worldid, tid = wp.tid() - d.act_dot_rk[worldid, tid] += b * d.act_dot[worldid, tid] + d.qvel_rk[worldid, tid] += b * d.qvel[worldid, tid] + d.qacc_rk[worldid, tid] += b * d.qacc[worldid, tid] - wp.launch(_qvel_acc, dim=(d.nworld, m.nv), inputs=[d, b]) + if m.na: - if m.na: - wp.launch(_act_dot, dim=(d.nworld, m.na), inputs=[d, b]) + @kernel + def _act_dot(d: Data, b: float): + worldid, tid = wp.tid() + d.act_dot_rk[worldid, tid] += b * d.act_dot[worldid, tid] - def perturb_state(m: Model, d: Data, a: float): - @kernel - def _qpos(m: Model, d: Data): - """Integrate joint positions""" - worldid, jntId = wp.tid() - _integrate_pos(worldid, jntId, m, d.qpos, d.qpos_t0, d.qvel, qvel_scale=a) + wp.launch(_qvel_acc, dim=(d.nworld, m.nv), inputs=[d, b]) - if m.na: + if m.na: + wp.launch(_act_dot, dim=(d.nworld, m.na), inputs=[d, b]) + def perturb_state(m: Model, d: Data, a: float): @kernel - def _act(m: Model, d: Data): - worldid, tid = wp.tid() - dact_dot = a * d.act_dot[worldid, tid] - d.act[worldid, tid] = d.act_t0[worldid, tid] + dact_dot * m.opt.timestep + def _qpos(m: Model, d: Data): + """Integrate joint positions""" + worldid, jntId = wp.tid() + _integrate_pos(worldid, jntId, m, d.qpos, d.qpos_t0, d.qvel, qvel_scale=a) - @kernel - def _qvel(m: Model, d: Data): - worldid, tid = wp.tid() - dqacc = a * d.qacc[worldid, tid] - d.qvel[worldid, tid] = d.qvel_t0[worldid, tid] + dqacc * m.opt.timestep + if m.na: - wp.launch(_qpos, dim=(d.nworld, m.njnt), inputs=[m, d]) - if m.na: - wp.launch(_act, dim=(d.nworld, m.na), inputs=[m, d]) - wp.launch(_qvel, dim=(d.nworld, m.nv), inputs=[m, d]) + @kernel + def _act(m: Model, d: Data): + worldid, tid = wp.tid() + dact_dot = a * d.act_dot[worldid, tid] + d.act[worldid, tid] = d.act_t0[worldid, tid] + dact_dot * m.opt.timestep - rk_accumulate(d, B[0]) - for i in range(3): - a, b = float(A[i][i]), B[i + 1] - perturb_state(m, d, a) - forward(m, d) - rk_accumulate(d, b) - - wp.copy(d.qpos, d.qpos_t0) - wp.copy(d.qvel, d.qvel_t0) - if m.na: - wp.copy(d.act, d.act_t0) - wp.copy(d.act_dot, d.act_dot_rk) - _advance(m, d, d.qacc_rk, d.qvel_rk) + @kernel + def _qvel(m: Model, d: Data): + worldid, tid = wp.tid() + dqacc = a * d.qacc[worldid, tid] + d.qvel[worldid, tid] = d.qvel_t0[worldid, tid] + dqacc * m.opt.timestep + + wp.launch(_qpos, dim=(d.nworld, m.njnt), inputs=[m, d]) + if m.na: + wp.launch(_act, dim=(d.nworld, m.na), inputs=[m, d]) + wp.launch(_qvel, dim=(d.nworld, m.nv), inputs=[m, d]) + + rk_accumulate(d, B[0]) + for i in range(3): + a, b = float(A[i][i]), B[i + 1] + perturb_state(m, d, a) + forward(m, d) + rk_accumulate(d, b) + + wp.copy(d.qpos, d.qpos_t0) + wp.copy(d.qvel, d.qvel_t0) + if m.na: + wp.copy(d.act, d.act_t0) + wp.copy(d.act_dot, d.act_dot_rk) + _advance(m, d, d.qacc_rk, d.qvel_rk) @event_scope def implicit(m: Model, d: Data): """Integrates fully implicit in velocity.""" - # optimization comments (AD) - # I went from small kernels for every step to a relatively big single - # kernel using tile API because it kept improving performance - - # 30M to 50M FPS on an A6000. - # - # The main benefit is reduced global memory roundtrips, but I assume - # there is also some benefit to loading data as early as possible. - # - # I further tried fusing in the cholesky factor/solve but the high - # storage requirements led to low occupancy and thus worse performance. - # - # The actuator_bias_gain_vel kernel could theoretically be fused in as well, - # but it's pretty clean straight-line code that loads a lot of data but - # only stores one array, so I think the benefit of keeping that one on-chip - # is likely not worth it compared to the compromises we're making with tile API. - # It would also need a different data layout for the biasprm/gainprm arrays - # to be tileable. - - # assumptions - assert not m.opt.is_sparse # unsupported - # TODO(team): add sparse version - - # compile-time constants - passive_enabled = not m.opt.disableflags & DisableBit.PASSIVE.value - actuation_enabled = ( - not m.opt.disableflags & DisableBit.ACTUATION.value - ) and m.actuator_affine_bias_gain + with wp.ScopedDevice(m.qpos0.device): + # optimization comments (AD) + # I went from small kernels for every step to a relatively big single + # kernel using tile API because it kept improving performance - + # 30M to 50M FPS on an A6000. + # + # The main benefit is reduced global memory roundtrips, but I assume + # there is also some benefit to loading data as early as possible. + # + # I further tried fusing in the cholesky factor/solve but the high + # storage requirements led to low occupancy and thus worse performance. + # + # The actuator_bias_gain_vel kernel could theoretically be fused in as well, + # but it's pretty clean straight-line code that loads a lot of data but + # only stores one array, so I think the benefit of keeping that one on-chip + # is likely not worth it compared to the compromises we're making with tile API. + # It would also need a different data layout for the biasprm/gainprm arrays + # to be tileable. + + # assumptions + assert not m.opt.is_sparse # unsupported + # TODO(team): add sparse version + + # compile-time constants + passive_enabled = not m.opt.disableflags & DisableBit.PASSIVE.value + actuation_enabled = ( + not m.opt.disableflags & DisableBit.ACTUATION.value + ) and m.actuator_affine_bias_gain - @kernel - def actuator_bias_gain_vel(m: Model, d: Data): - worldid, actid = wp.tid() + @kernel + def actuator_bias_gain_vel(m: Model, d: Data): + worldid, actid = wp.tid() - bias_vel = 0.0 - gain_vel = 0.0 + bias_vel = 0.0 + gain_vel = 0.0 - actuator_biastype = m.actuator_biastype[actid] - actuator_gaintype = m.actuator_gaintype[actid] - actuator_dyntype = m.actuator_dyntype[actid] + actuator_biastype = m.actuator_biastype[actid] + actuator_gaintype = m.actuator_gaintype[actid] + actuator_dyntype = m.actuator_dyntype[actid] - if actuator_biastype == wp.static(BiasType.AFFINE.value): - bias_vel = m.actuator_biasprm[actid][2] + if actuator_biastype == wp.static(BiasType.AFFINE.value): + bias_vel = m.actuator_biasprm[actid][2] - if actuator_gaintype == wp.static(GainType.AFFINE.value): - gain_vel = m.actuator_gainprm[actid][2] + if actuator_gaintype == wp.static(GainType.AFFINE.value): + gain_vel = m.actuator_gainprm[actid][2] - ctrl = d.ctrl[worldid, actid] + ctrl = d.ctrl[worldid, actid] - if actuator_dyntype != wp.static(DynType.NONE.value): - ctrl = d.act[worldid, actid] + if actuator_dyntype != wp.static(DynType.NONE.value): + ctrl = d.act[worldid, actid] - d.act_vel_integration[worldid, actid] = bias_vel + gain_vel * ctrl + d.act_vel_integration[worldid, actid] = bias_vel + gain_vel * ctrl - def qderiv_actuator_damping_fused( - m: Model, d: Data, damping: wp.array(dtype=wp.float32) - ): - if actuation_enabled: - block_dim = 64 - else: - block_dim = 256 + def qderiv_actuator_damping_fused( + m: Model, d: Data, damping: wp.array(dtype=wp.float32) + ): + if actuation_enabled: + block_dim = 64 + else: + block_dim = 256 - @wp.func - def subtract_multiply(x: wp.float32, y: wp.float32): - return x - y * wp.static(m.opt.timestep) + @wp.func + def subtract_multiply(x: wp.float32, y: wp.float32): + return x - y * wp.static(m.opt.timestep) - def qderiv_actuator_damping_tiled( - adr: int, size: int, tilesize_nv: int, tilesize_nu: int - ): - @kernel - def qderiv_actuator_fused_kernel( - m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int + def qderiv_actuator_damping_tiled( + adr: int, size: int, tilesize_nv: int, tilesize_nu: int ): - worldid, nodeid = wp.tid() - offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] - - # skip tree with no actuators. - if wp.static(actuation_enabled and tilesize_nu != 0): - offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid] - actuator_moment_tile = wp.tile_load( - d.actuator_moment[worldid], - shape=(tilesize_nu, tilesize_nv), - offset=(offset_nu, offset_nv), + @kernel + def qderiv_actuator_fused_kernel( + m: Model, d: Data, damping: wp.array(dtype=wp.float32), leveladr: int + ): + worldid, nodeid = wp.tid() + offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] + + # skip tree with no actuators. + if wp.static(actuation_enabled and tilesize_nu != 0): + offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid] + actuator_moment_tile = wp.tile_load( + d.actuator_moment[worldid], + shape=(tilesize_nu, tilesize_nv), + offset=(offset_nu, offset_nv), + ) + zeros = wp.tile_zeros(shape=(tilesize_nu, tilesize_nu), dtype=wp.float32) + vel_tile = wp.tile_load( + d.act_vel_integration[worldid], shape=(tilesize_nu), offset=offset_nu + ) + diag = wp.tile_diag_add(zeros, vel_tile) + actuator_moment_T = wp.tile_transpose(actuator_moment_tile) + amTVel = wp.tile_matmul(actuator_moment_T, diag) + qderiv_tile = wp.tile_matmul(amTVel, actuator_moment_tile) + else: + qderiv_tile = wp.tile_zeros( + shape=(tilesize_nv, tilesize_nv), dtype=wp.float32 + ) + + if wp.static(passive_enabled): + dof_damping = wp.tile_load(damping, shape=tilesize_nv, offset=offset_nv) + negative = wp.neg(dof_damping) + qderiv_tile = wp.tile_diag_add(qderiv_tile, negative) + + # add to qM + qM_tile = wp.tile_load( + d.qM[worldid], + shape=(tilesize_nv, tilesize_nv), + offset=(offset_nv, offset_nv), + ) + qderiv_tile = wp.tile_map(subtract_multiply, qM_tile, qderiv_tile) + wp.tile_store( + d.qM_integration[worldid], qderiv_tile, offset=(offset_nv, offset_nv) ) - zeros = wp.tile_zeros(shape=(tilesize_nu, tilesize_nu), dtype=wp.float32) - vel_tile = wp.tile_load( - d.act_vel_integration[worldid], shape=(tilesize_nu), offset=offset_nu + + # sum qfrc + qfrc_smooth_tile = wp.tile_load( + d.qfrc_smooth[worldid], shape=tilesize_nv, offset=offset_nv ) - diag = wp.tile_diag_add(zeros, vel_tile) - actuator_moment_T = wp.tile_transpose(actuator_moment_tile) - amTVel = wp.tile_matmul(actuator_moment_T, diag) - qderiv_tile = wp.tile_matmul(amTVel, actuator_moment_tile) - else: - qderiv_tile = wp.tile_zeros( - shape=(tilesize_nv, tilesize_nv), dtype=wp.float32 + qfrc_constraint_tile = wp.tile_load( + d.qfrc_constraint[worldid], shape=tilesize_nv, offset=offset_nv ) + qfrc_combined = wp.add(qfrc_smooth_tile, qfrc_constraint_tile) + wp.tile_store(d.qfrc_integration[worldid], qfrc_combined, offset=offset_nv) + + wp.launch_tiled( + qderiv_actuator_fused_kernel, + dim=(d.nworld, size), + inputs=[m, d, damping, adr], + block_dim=block_dim, + ) - if wp.static(passive_enabled): - dof_damping = wp.tile_load(damping, shape=tilesize_nv, offset=offset_nv) - negative = wp.neg(dof_damping) - qderiv_tile = wp.tile_diag_add(qderiv_tile, negative) + qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy() + qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() + qderiv_tileadr = m.actuator_moment_tileadr.numpy() - # add to qM - qM_tile = wp.tile_load( - d.qM[worldid], shape=(tilesize_nv, tilesize_nv), offset=(offset_nv, offset_nv) - ) - qderiv_tile = wp.tile_map(subtract_multiply, qM_tile, qderiv_tile) - wp.tile_store( - d.qM_integration[worldid], qderiv_tile, offset=(offset_nv, offset_nv) + for i in range(len(qderiv_tileadr)): + beg = qderiv_tileadr[i] + end = ( + m.qLD_tile.shape[0] if i == len(qderiv_tileadr) - 1 else qderiv_tileadr[i + 1] ) + if qderiv_tilesize_nv[i] != 0: + qderiv_actuator_damping_tiled( + beg, end - beg, int(qderiv_tilesize_nv[i]), int(qderiv_tilesize_nu[i]) + ) - # sum qfrc - qfrc_smooth_tile = wp.tile_load( - d.qfrc_smooth[worldid], shape=tilesize_nv, offset=offset_nv - ) - qfrc_constraint_tile = wp.tile_load( - d.qfrc_constraint[worldid], shape=tilesize_nv, offset=offset_nv + if passive_enabled or actuation_enabled: + if actuation_enabled: + wp.launch( + actuator_bias_gain_vel, + dim=(d.nworld, m.nu), + inputs=[m, d], ) - qfrc_combined = wp.add(qfrc_smooth_tile, qfrc_constraint_tile) - wp.tile_store(d.qfrc_integration[worldid], qfrc_combined, offset=offset_nv) - wp.launch_tiled( - qderiv_actuator_fused_kernel, - dim=(d.nworld, size), - inputs=[m, d, damping, adr], - block_dim=block_dim, - ) - - qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy() - qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() - qderiv_tileadr = m.actuator_moment_tileadr.numpy() - - for i in range(len(qderiv_tileadr)): - beg = qderiv_tileadr[i] - end = ( - m.qLD_tile.shape[0] if i == len(qderiv_tileadr) - 1 else qderiv_tileadr[i + 1] - ) - if qderiv_tilesize_nv[i] != 0: - qderiv_actuator_damping_tiled( - beg, end - beg, int(qderiv_tilesize_nv[i]), int(qderiv_tilesize_nu[i]) - ) + qderiv_actuator_damping_fused(m, d, m.dof_damping) - if passive_enabled or actuation_enabled: - if actuation_enabled: - wp.launch( - actuator_bias_gain_vel, - dim=(d.nworld, m.nu), - inputs=[m, d], + smooth._factor_solve_i_dense( + m, d, d.qM_integration, d.qacc_integration, d.qfrc_integration ) - qderiv_actuator_damping_fused(m, d, m.dof_damping) - - smooth._factor_solve_i_dense( - m, d, d.qM_integration, d.qacc_integration, d.qfrc_integration - ) - - _advance(m, d, d.qacc_integration) - else: - _advance(m, d, d.qacc) + _advance(m, d, d.qacc_integration) + else: + _advance(m, d, d.qacc) @event_scope def fwd_position(m: Model, d: Data): """Position-dependent computations.""" - smooth.kinematics(m, d) - smooth.com_pos(m, d) - smooth.camlight(m, d) - smooth.tendon(m, d) - smooth.crb(m, d) - smooth.factor_m(m, d) - collision_driver.collision(m, d) - constraint.make_constraint(m, d) - smooth.transmission(m, d) + with wp.ScopedDevice(m.qpos0.device): + smooth.kinematics(m, d) + smooth.com_pos(m, d) + smooth.camlight(m, d) + smooth.tendon(m, d) + smooth.crb(m, d) + smooth.factor_m(m, d) + collision_driver.collision(m, d) + constraint.make_constraint(m, d) + smooth.transmission(m, d) @event_scope def fwd_velocity(m: Model, d: Data): """Velocity-dependent computations.""" - if m.opt.is_sparse: - # TODO(team): sparse version - NV = m.nv - - @kernel - def _actuator_velocity(d: Data): - worldid, actid = wp.tid() - moment_tile = wp.tile_load(d.actuator_moment[worldid, actid], shape=NV) - qvel_tile = wp.tile_load(d.qvel[worldid], shape=NV) - moment_qvel_tile = wp.tile_map(wp.mul, moment_tile, qvel_tile) - actuator_velocity_tile = wp.tile_reduce(wp.add, moment_qvel_tile) - wp.tile_store(d.actuator_velocity[worldid], actuator_velocity_tile) - - wp.launch_tiled(_actuator_velocity, dim=(d.nworld, m.nu), inputs=[d], block_dim=32) - else: + with wp.ScopedDevice(m.qpos0.device): + if m.opt.is_sparse: + # TODO(team): sparse version + NV = m.nv - def actuator_velocity( - adr: int, - size: int, - tilesize_nu: int, - tilesize_nv: int, - ): @kernel - def _actuator_velocity( - m: Model, d: Data, leveladr: int, velocity: array3df, qvel: array3df - ): - worldid, nodeid = wp.tid() - offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid] - offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] - actuator_moment_tile = wp.tile_load( - d.actuator_moment[worldid], - shape=(tilesize_nu, tilesize_nv), - offset=(offset_nu, offset_nv), - ) - qvel_tile = wp.tile_load( - qvel[worldid], shape=(tilesize_nv, 1), offset=(offset_nv, 0) - ) - velocity_tile = wp.tile_matmul(actuator_moment_tile, qvel_tile) - - wp.tile_store(velocity[worldid], velocity_tile, offset=(offset_nu, 0)) + def _actuator_velocity(d: Data): + worldid, actid = wp.tid() + moment_tile = wp.tile_load(d.actuator_moment[worldid, actid], shape=NV) + qvel_tile = wp.tile_load(d.qvel[worldid], shape=NV) + moment_qvel_tile = wp.tile_map(wp.mul, moment_tile, qvel_tile) + actuator_velocity_tile = wp.tile_reduce(wp.add, moment_qvel_tile) + wp.tile_store(d.actuator_velocity[worldid], actuator_velocity_tile) wp.launch_tiled( - _actuator_velocity, - dim=(d.nworld, size), - inputs=[ - m, - d, - adr, - d.actuator_velocity.reshape(d.actuator_velocity.shape + (1,)), - d.qvel.reshape(d.qvel.shape + (1,)), - ], - block_dim=32, + _actuator_velocity, dim=(d.nworld, m.nu), inputs=[d], block_dim=32 ) + else: - actuator_moment_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() - actuator_moment_tilesize_nv = m.actuator_moment_tilesize_nv.numpy() - actuator_moment_tileadr = m.actuator_moment_tileadr.numpy() + def actuator_velocity( + adr: int, + size: int, + tilesize_nu: int, + tilesize_nv: int, + ): + @kernel + def _actuator_velocity( + m: Model, d: Data, leveladr: int, velocity: array3df, qvel: array3df + ): + worldid, nodeid = wp.tid() + offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid] + offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] + actuator_moment_tile = wp.tile_load( + d.actuator_moment[worldid], + shape=(tilesize_nu, tilesize_nv), + offset=(offset_nu, offset_nv), + ) + qvel_tile = wp.tile_load( + qvel[worldid], shape=(tilesize_nv, 1), offset=(offset_nv, 0) + ) + velocity_tile = wp.tile_matmul(actuator_moment_tile, qvel_tile) + + wp.tile_store(velocity[worldid], velocity_tile, offset=(offset_nu, 0)) + + wp.launch_tiled( + _actuator_velocity, + dim=(d.nworld, size), + inputs=[ + m, + d, + adr, + d.actuator_velocity.reshape(d.actuator_velocity.shape + (1,)), + d.qvel.reshape(d.qvel.shape + (1,)), + ], + block_dim=32, + ) - for i in range(len(actuator_moment_tileadr)): - beg = actuator_moment_tileadr[i] - end = ( - m.actuator_moment_tileadr.shape[0] - if i == len(actuator_moment_tileadr) - 1 - else actuator_moment_tileadr[i + 1] - ) - if actuator_moment_tilesize_nu[i] != 0 and actuator_moment_tilesize_nv[i] != 0: - actuator_velocity( - beg, - end - beg, - int(actuator_moment_tilesize_nu[i]), - int(actuator_moment_tilesize_nv[i]), + actuator_moment_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() + actuator_moment_tilesize_nv = m.actuator_moment_tilesize_nv.numpy() + actuator_moment_tileadr = m.actuator_moment_tileadr.numpy() + + for i in range(len(actuator_moment_tileadr)): + beg = actuator_moment_tileadr[i] + end = ( + m.actuator_moment_tileadr.shape[0] + if i == len(actuator_moment_tileadr) - 1 + else actuator_moment_tileadr[i + 1] ) + if actuator_moment_tilesize_nu[i] != 0 and actuator_moment_tilesize_nv[i] != 0: + actuator_velocity( + beg, + end - beg, + int(actuator_moment_tilesize_nu[i]), + int(actuator_moment_tilesize_nv[i]), + ) - if m.ntendon > 0: - # TODO(team): sparse version - NV = m.nv + if m.ntendon > 0: + # TODO(team): sparse version + NV = m.nv - @kernel - def _tendon_velocity(d: Data): - worldid, tenid = wp.tid() - ten_J_tile = wp.tile_load(d.ten_J[worldid, tenid], shape=NV) - qvel_tile = wp.tile_load(d.qvel[worldid], shape=NV) - ten_J_qvel_tile = wp.tile_map(wp.mul, ten_J_tile, qvel_tile) - ten_velocity_tile = wp.tile_reduce(wp.add, ten_J_qvel_tile) - wp.tile_store(d.ten_velocity[worldid], ten_velocity_tile) - - wp.launch_tiled( - _tendon_velocity, dim=(d.nworld, m.ntendon), inputs=[d], block_dim=32 - ) + @kernel + def _tendon_velocity(d: Data): + worldid, tenid = wp.tid() + ten_J_tile = wp.tile_load(d.ten_J[worldid, tenid], shape=NV) + qvel_tile = wp.tile_load(d.qvel[worldid], shape=NV) + ten_J_qvel_tile = wp.tile_map(wp.mul, ten_J_tile, qvel_tile) + ten_velocity_tile = wp.tile_reduce(wp.add, ten_J_qvel_tile) + wp.tile_store(d.ten_velocity[worldid], ten_velocity_tile) - smooth.com_vel(m, d) - passive.passive(m, d) - smooth.rne(m, d) + wp.launch_tiled( + _tendon_velocity, dim=(d.nworld, m.ntendon), inputs=[d], block_dim=32 + ) + + smooth.com_vel(m, d) + passive.passive(m, d) + smooth.rne(m, d) @event_scope def fwd_actuation(m: Model, d: Data): """Actuation-dependent computations.""" - if not m.nu or (m.opt.disableflags & DisableBit.ACTUATION): - d.act_dot.zero_() - d.qfrc_actuator.zero_() - return - @kernel - def _force(m: Model, d: Data): - worldid, uid = wp.tid() + with wp.ScopedDevice(m.qpos0.device): + if not m.nu or (m.opt.disableflags & DisableBit.ACTUATION): + d.act_dot.zero_() + d.qfrc_actuator.zero_() + return - ctrl = d.ctrl[worldid, uid] - dsbl_clampctrl = m.opt.disableflags & wp.static(DisableBit.CLAMPCTRL.value) + @kernel + def _force(m: Model, d: Data): + worldid, uid = wp.tid() - if m.actuator_ctrllimited[uid] and not dsbl_clampctrl: - r = m.actuator_ctrlrange[uid] - ctrl = wp.clamp(ctrl, r[0], r[1]) + ctrl = d.ctrl[worldid, uid] + dsbl_clampctrl = m.opt.disableflags & wp.static(DisableBit.CLAMPCTRL.value) - if m.na: - dyntype = m.actuator_dyntype[uid] + if m.actuator_ctrllimited[uid] and not dsbl_clampctrl: + r = m.actuator_ctrlrange[uid] + ctrl = wp.clamp(ctrl, r[0], r[1]) - if dyntype == int(DynType.INTEGRATOR.value): - d.act_dot[worldid, m.actuator_actadr[uid]] = ctrl - elif dyntype == int(DynType.FILTER.value) or dyntype == int( - DynType.FILTEREXACT.value - ): - dynprm = m.actuator_dynprm[uid] - actadr = m.actuator_actadr[uid] - act = d.act[worldid, actadr] - d.act_dot[worldid, actadr] = (ctrl - act) / wp.max(dynprm[0], MJ_MINVAL) + if m.na: + dyntype = m.actuator_dyntype[uid] - # TODO(team): DynType.MUSCLE + if dyntype == int(DynType.INTEGRATOR.value): + d.act_dot[worldid, m.actuator_actadr[uid]] = ctrl + elif dyntype == int(DynType.FILTER.value) or dyntype == int( + DynType.FILTEREXACT.value + ): + dynprm = m.actuator_dynprm[uid] + actadr = m.actuator_actadr[uid] + act = d.act[worldid, actadr] + d.act_dot[worldid, actadr] = (ctrl - act) / wp.max(dynprm[0], MJ_MINVAL) - ctrl_act = ctrl - if m.na: - if m.actuator_actadr[uid] > -1: - ctrl_act = d.act[worldid, m.actuator_actadr[uid] + m.actuator_actnum[uid] - 1] + # TODO(team): DynType.MUSCLE - # TODO(team): actuator_actearly + ctrl_act = ctrl + if m.na: + if m.actuator_actadr[uid] > -1: + ctrl_act = d.act[worldid, m.actuator_actadr[uid] + m.actuator_actnum[uid] - 1] - length = d.actuator_length[worldid, uid] - velocity = d.actuator_velocity[worldid, uid] + # TODO(team): actuator_actearly - # gain - gaintype = m.actuator_gaintype[uid] - gainprm = m.actuator_gainprm[uid] + length = d.actuator_length[worldid, uid] + velocity = d.actuator_velocity[worldid, uid] - gain = 0.0 - if gaintype == int(GainType.FIXED.value): - gain = gainprm[0] - elif gaintype == int(GainType.AFFINE.value): - gain = gainprm[0] + gainprm[1] * length + gainprm[2] * velocity + # gain + gaintype = m.actuator_gaintype[uid] + gainprm = m.actuator_gainprm[uid] - # TODO(team): GainType.MUSCLE + gain = 0.0 + if gaintype == int(GainType.FIXED.value): + gain = gainprm[0] + elif gaintype == int(GainType.AFFINE.value): + gain = gainprm[0] + gainprm[1] * length + gainprm[2] * velocity - # bias - biastype = m.actuator_biastype[uid] - biasprm = m.actuator_biasprm[uid] + # TODO(team): GainType.MUSCLE - bias = 0.0 # BiasType.NONE - if biastype == int(BiasType.AFFINE.value): - bias = biasprm[0] + biasprm[1] * length + biasprm[2] * velocity + # bias + biastype = m.actuator_biastype[uid] + biasprm = m.actuator_biasprm[uid] - # TODO(team): BiasType.MUSCLE + bias = 0.0 # BiasType.NONE + if biastype == int(BiasType.AFFINE.value): + bias = biasprm[0] + biasprm[1] * length + biasprm[2] * velocity - f = gain * ctrl_act + bias + # TODO(team): BiasType.MUSCLE - # TODO(team): tendon total force clamping + f = gain * ctrl_act + bias - if m.actuator_forcelimited[uid]: - r = m.actuator_forcerange[uid] - f = wp.clamp(f, r[0], r[1]) - d.actuator_force[worldid, uid] = f + # TODO(team): tendon total force clamping - @kernel - def _qfrc_limited(m: Model, d: Data): - worldid, dofid = wp.tid() - jntid = m.dof_jntid[dofid] - if m.jnt_actfrclimited[jntid]: - d.qfrc_actuator[worldid, dofid] = wp.clamp( - d.qfrc_actuator[worldid, dofid], - m.jnt_actfrcrange[jntid][0], - m.jnt_actfrcrange[jntid][1], - ) + if m.actuator_forcelimited[uid]: + r = m.actuator_forcerange[uid] + f = wp.clamp(f, r[0], r[1]) + d.actuator_force[worldid, uid] = f - if m.opt.is_sparse: - # TODO(team): sparse version @kernel - def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df): - worldid, vid = wp.tid() - - s = float(0.0) - for uid in range(m.nu): - # TODO consider using Tile API or transpose moment for better access pattern - s += moment[worldid, uid, vid] * force[worldid, uid] - jntid = m.dof_jntid[vid] + def _qfrc_limited(m: Model, d: Data): + worldid, dofid = wp.tid() + jntid = m.dof_jntid[dofid] if m.jnt_actfrclimited[jntid]: - r = m.jnt_actfrcrange[jntid] - s = wp.clamp(s, r[0], r[1]) - qfrc[worldid, vid] = s + d.qfrc_actuator[worldid, dofid] = wp.clamp( + d.qfrc_actuator[worldid, dofid], + m.jnt_actfrcrange[jntid][0], + m.jnt_actfrcrange[jntid][1], + ) + + if m.opt.is_sparse: + # TODO(team): sparse version + @kernel + def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df): + worldid, vid = wp.tid() - wp.launch(_force, dim=[d.nworld, m.nu], inputs=[m, d]) + s = float(0.0) + for uid in range(m.nu): + # TODO consider using Tile API or transpose moment for better access pattern + s += moment[worldid, uid, vid] * force[worldid, uid] + jntid = m.dof_jntid[vid] + if m.jnt_actfrclimited[jntid]: + r = m.jnt_actfrcrange[jntid] + s = wp.clamp(s, r[0], r[1]) + qfrc[worldid, vid] = s - if m.opt.is_sparse: - # TODO(team): sparse version + wp.launch(_force, dim=[d.nworld, m.nu], inputs=[m, d]) - wp.launch( - _qfrc, - dim=(d.nworld, m.nv), - inputs=[m, d.actuator_moment, d.actuator_force], - outputs=[d.qfrc_actuator], - ) + if m.opt.is_sparse: + # TODO(team): sparse version - else: + wp.launch( + _qfrc, + dim=(d.nworld, m.nv), + inputs=[m, d.actuator_moment, d.actuator_force], + outputs=[d.qfrc_actuator], + ) - def qfrc_actuator(adr: int, size: int, tilesize_nu: int, tilesize_nv: int): - @kernel - def qfrc_actuator_kernel( - m: Model, - d: Data, - leveladr: int, - qfrc_actuator: array3df, - actuator_force: array3df, - ): - worldid, nodeid = wp.tid() - offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid] - offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] - - actuator_moment_tile = wp.tile_load( - d.actuator_moment[worldid], - shape=(tilesize_nu, tilesize_nv), - offset=(offset_nu, offset_nv), - ) - actuator_moment_T_tile = wp.tile_transpose(actuator_moment_tile) + else: - force_tile = wp.tile_load( - actuator_force[worldid], shape=(tilesize_nu, 1), offset=(offset_nu, 0) - ) - qfrc_tile = wp.tile_matmul(actuator_moment_T_tile, force_tile) - wp.tile_store(qfrc_actuator[worldid], qfrc_tile, offset=(offset_nv, 0)) + def qfrc_actuator(adr: int, size: int, tilesize_nu: int, tilesize_nv: int): + @kernel + def qfrc_actuator_kernel( + m: Model, + d: Data, + leveladr: int, + qfrc_actuator: array3df, + actuator_force: array3df, + ): + worldid, nodeid = wp.tid() + offset_nu = m.actuator_moment_offset_nu[leveladr + nodeid] + offset_nv = m.actuator_moment_offset_nv[leveladr + nodeid] - wp.launch_tiled( - qfrc_actuator_kernel, - dim=(d.nworld, size), - inputs=[ - m, - d, - adr, - d.qfrc_actuator.reshape(d.qfrc_actuator.shape + (1,)), - d.actuator_force.reshape(d.actuator_force.shape + (1,)), - ], - block_dim=32, - ) + actuator_moment_tile = wp.tile_load( + d.actuator_moment[worldid], + shape=(tilesize_nu, tilesize_nv), + offset=(offset_nu, offset_nv), + ) + actuator_moment_T_tile = wp.tile_transpose(actuator_moment_tile) - qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() - qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy() - qderiv_tileadr = m.actuator_moment_tileadr.numpy() + force_tile = wp.tile_load( + actuator_force[worldid], shape=(tilesize_nu, 1), offset=(offset_nu, 0) + ) + qfrc_tile = wp.tile_matmul(actuator_moment_T_tile, force_tile) + wp.tile_store(qfrc_actuator[worldid], qfrc_tile, offset=(offset_nv, 0)) + + wp.launch_tiled( + qfrc_actuator_kernel, + dim=(d.nworld, size), + inputs=[ + m, + d, + adr, + d.qfrc_actuator.reshape(d.qfrc_actuator.shape + (1,)), + d.actuator_force.reshape(d.actuator_force.shape + (1,)), + ], + block_dim=32, + ) - for i in range(len(qderiv_tileadr)): - beg = qderiv_tileadr[i] - end = ( - m.qLD_tile.shape[0] if i == len(qderiv_tileadr) - 1 else qderiv_tileadr[i + 1] - ) - if qderiv_tilesize_nu[i] != 0 and qderiv_tilesize_nv[i] != 0: - qfrc_actuator( - beg, end - beg, int(qderiv_tilesize_nu[i]), int(qderiv_tilesize_nv[i]) + qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() + qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy() + qderiv_tileadr = m.actuator_moment_tileadr.numpy() + + for i in range(len(qderiv_tileadr)): + beg = qderiv_tileadr[i] + end = ( + m.qLD_tile.shape[0] if i == len(qderiv_tileadr) - 1 else qderiv_tileadr[i + 1] ) + if qderiv_tilesize_nu[i] != 0 and qderiv_tilesize_nv[i] != 0: + qfrc_actuator( + beg, end - beg, int(qderiv_tilesize_nu[i]), int(qderiv_tilesize_nv[i]) + ) - wp.launch(_qfrc_limited, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(_qfrc_limited, dim=(d.nworld, m.nv), inputs=[m, d]) - # TODO actuator-level gravity compensation, skip if added as passive force + # TODO actuator-level gravity compensation, skip if added as passive force @event_scope def fwd_acceleration(m: Model, d: Data): """Add up all non-constraint forces, compute qacc_smooth.""" - @kernel - def _qfrc_smooth(d: Data): - worldid, dofid = wp.tid() - d.qfrc_smooth[worldid, dofid] = ( - d.qfrc_passive[worldid, dofid] - - d.qfrc_bias[worldid, dofid] - + d.qfrc_actuator[worldid, dofid] - + d.qfrc_applied[worldid, dofid] - ) + with wp.ScopedDevice(m.qpos0.device): + + @kernel + def _qfrc_smooth(d: Data): + worldid, dofid = wp.tid() + d.qfrc_smooth[worldid, dofid] = ( + d.qfrc_passive[worldid, dofid] + - d.qfrc_bias[worldid, dofid] + + d.qfrc_actuator[worldid, dofid] + + d.qfrc_applied[worldid, dofid] + ) - wp.launch(_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d]) - xfrc_accumulate(m, d, d.qfrc_smooth) + wp.launch(_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d]) + xfrc_accumulate(m, d, d.qfrc_smooth) - smooth.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth) + smooth.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth) @event_scope def forward(m: Model, d: Data): """Forward dynamics.""" - fwd_position(m, d) - sensor.sensor_pos(m, d) - fwd_velocity(m, d) - sensor.sensor_vel(m, d) - fwd_actuation(m, d) - fwd_acceleration(m, d) - sensor.sensor_acc(m, d) - - if d.njmax == 0: - wp.copy(d.qacc, d.qacc_smooth) - else: - solver.solve(m, d) + with wp.ScopedDevice(m.qpos0.device): + fwd_position(m, d) + sensor.sensor_pos(m, d) + fwd_velocity(m, d) + sensor.sensor_vel(m, d) + fwd_actuation(m, d) + fwd_acceleration(m, d) + sensor.sensor_acc(m, d) + + if d.njmax == 0: + wp.copy(d.qacc, d.qacc_smooth) + else: + solver.solve(m, d) @event_scope def step(m: Model, d: Data): """Advance simulation.""" - forward(m, d) - - if m.opt.integrator == mujoco.mjtIntegrator.mjINT_EULER: - euler(m, d) - elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_RK4: - rungekutta4(m, d) - elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_IMPLICITFAST: - implicit(m, d) - else: - raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.") + + with wp.ScopedDevice(m.qpos0.device): + forward(m, d) + + if m.opt.integrator == mujoco.mjtIntegrator.mjINT_EULER: + euler(m, d) + elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_RK4: + rungekutta4(m, d) + elif m.opt.integrator == mujoco.mjtIntegrator.mjINT_IMPLICITFAST: + implicit(m, d) + else: + raise NotImplementedError(f"integrator {m.opt.integrator} not implemented.") diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 299ce2b4..a15febea 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -70,539 +70,542 @@ def geom_pair(m: mujoco.MjModel) -> Tuple[np.array, np.array]: return np.array(geompairs), np.array(pairids) -def put_model(mjm: mujoco.MjModel) -> types.Model: - # check supported features - for field, field_types, field_str in ( - (mjm.actuator_trntype, types.TrnType, "Actuator transmission type"), - (mjm.actuator_dyntype, types.DynType, "Actuator dynamics type"), - (mjm.actuator_gaintype, types.GainType, "Gain type"), - (mjm.actuator_biastype, types.BiasType, "Bias type"), - (mjm.eq_type, types.EqType, "Equality constraint types"), - (mjm.geom_type, types.GeomType, "Geom type"), - (mjm.sensor_type, types.SensorType, "Sensor types"), - (mjm.wrap_type, types.WrapType, "Wrap types"), - ): - unsupported = ~np.isin(field, list(field_types)) - if unsupported.any(): - raise NotImplementedError(f"{field_str} {field[unsupported]} not supported.") - - if mjm.sensor_cutoff.any(): - raise NotImplementedError("Sensor cutoff is unsupported.") - - for n, msg in ( - (mjm.nplugin, "Plugins"), - (mjm.nflex, "Flexes"), - ): - if n > 0: - raise NotImplementedError(f"{msg} are unsupported.") - - if mjm.tendon_frictionloss.any(): - raise NotImplementedError("Tendon frictionloss is unsupported.") - - # check options - for opt, opt_types, msg in ( - (mjm.opt.integrator, types.IntegratorType, "Integrator"), - (mjm.opt.cone, types.ConeType, "Cone"), - (mjm.opt.solver, types.SolverType, "Solver"), - ): - if opt not in set(opt_types): - raise NotImplementedError(f"{msg} {opt} is unsupported.") - - if mjm.opt.wind.any(): - raise NotImplementedError("Wind is unsupported.") - - if mjm.opt.density > 0 or mjm.opt.viscosity > 0: - raise NotImplementedError("Fluid forces are unsupported.") - - # TODO(team): remove after solver._update_gradient for Newton solver utilizes tile operations for islands - nv_max = 60 - if mjm.nv > nv_max and (not mjm.opt.jacobian == mujoco.mjtJacobian.mjJAC_SPARSE): - raise ValueError(f"Dense is unsupported for nv > {nv_max} (nv = {mjm.nv}).") - - m = types.Model() - - m.nq = mjm.nq - m.nv = mjm.nv - m.na = mjm.na - m.nu = mjm.nu - m.nbody = mjm.nbody - m.njnt = mjm.njnt - m.ngeom = mjm.ngeom - m.nsite = mjm.nsite - m.ncam = mjm.ncam - m.nlight = mjm.nlight - m.nmocap = mjm.nmocap - m.nM = mjm.nM - m.ntendon = mjm.ntendon - m.nwrap = mjm.nwrap - m.nsensor = mjm.nsensor - m.nsensordata = mjm.nsensordata - m.nlsp = mjm.opt.ls_iterations # TODO(team): how to set nlsp? - m.npair = mjm.npair - m.nexclude = mjm.nexclude - m.neq = mjm.neq - m.opt.timestep = mjm.opt.timestep - m.opt.tolerance = mjm.opt.tolerance - m.opt.ls_tolerance = mjm.opt.ls_tolerance - m.opt.gravity = wp.vec3(mjm.opt.gravity) - m.opt.cone = mjm.opt.cone - m.opt.solver = mjm.opt.solver - m.opt.iterations = mjm.opt.iterations - m.opt.ls_iterations = mjm.opt.ls_iterations - m.opt.integrator = mjm.opt.integrator - m.opt.disableflags = mjm.opt.disableflags - m.opt.impratio = wp.float32(mjm.opt.impratio) - m.opt.is_sparse = support.is_sparse(mjm) - m.opt.ls_parallel = False - # TODO(team) Figure out good default parameters - m.opt.gjk_iteration_count = wp.int32(1) # warp only - m.opt.epa_iteration_count = wp.int32(12) # warp only - m.opt.epa_exact_neg_distance = wp.bool(False) # warp only - m.opt.depth_extension = wp.float32(0.1) # warp only - m.stat.meaninertia = mjm.stat.meaninertia - - m.qpos0 = wp.array(mjm.qpos0, dtype=wp.float32, ndim=1) - m.qpos_spring = wp.array(mjm.qpos_spring, dtype=wp.float32, ndim=1) - - # dof lower triangle row and column indices - dof_tri_row, dof_tri_col = np.tril_indices(mjm.nv) - - # indices for sparse qM full_m - is_, js = [], [] - for i in range(mjm.nv): - j = i - while j > -1: - is_.append(i) - js.append(j) - j = mjm.dof_parentid[j] - qM_fullm_i = is_ - qM_fullm_j = js - - # indices for sparse qM mul_m - is_, js, madr_ijs = [], [], [] - for i in range(mjm.nv): - madr_ij, j = mjm.dof_Madr[i], i - - while True: - madr_ij, j = madr_ij + 1, mjm.dof_parentid[j] - if j == -1: - break - is_, js, madr_ijs = is_ + [i], js + [j], madr_ijs + [madr_ij] - - qM_mulm_i, qM_mulm_j, qM_madr_ij = ( - np.array(x, dtype=np.int32) for x in (is_, js, madr_ijs) - ) - - jnt_limited_slide_hinge_adr = np.nonzero( - mjm.jnt_limited - & ( - (mjm.jnt_type == mujoco.mjtJoint.mjJNT_SLIDE) - | (mjm.jnt_type == mujoco.mjtJoint.mjJNT_HINGE) - ) - )[0] - - jnt_limited_ball_adr = np.nonzero( - mjm.jnt_limited & (mjm.jnt_type == mujoco.mjtJoint.mjJNT_BALL) - )[0] - - # body_tree is BFS ordering of body ids - # body_treeadr contains starting index of each body tree level - bodies, body_depth = {}, np.zeros(mjm.nbody, dtype=int) - 1 - for i in range(mjm.nbody): - body_depth[i] = body_depth[mjm.body_parentid[i]] + 1 - bodies.setdefault(body_depth[i], []).append(i) - body_tree = np.concatenate([bodies[i] for i in range(len(bodies))]) - tree_off = [0] + [len(bodies[i]) for i in range(len(bodies))] - body_treeadr = np.cumsum(tree_off)[:-1] - - m.body_tree = wp.array(body_tree, dtype=wp.int32, ndim=1) - m.body_treeadr = wp.array(body_treeadr, dtype=wp.int32, ndim=1, device="cpu") - - qLD_update_tree = np.empty(shape=(0, 3), dtype=int) - qLD_update_treeadr = np.empty(shape=(0,), dtype=int) - qLD_tile = np.empty(shape=(0,), dtype=int) - qLD_tileadr = np.empty(shape=(0,), dtype=int) - qLD_tilesize = np.empty(shape=(0,), dtype=int) +def put_model( + mjm: mujoco.MjModel, device: Optional[wp.context.Device] = None +) -> types.Model: + with wp.ScopedDevice(device): + # check supported features + for field, field_types, field_str in ( + (mjm.actuator_trntype, types.TrnType, "Actuator transmission type"), + (mjm.actuator_dyntype, types.DynType, "Actuator dynamics type"), + (mjm.actuator_gaintype, types.GainType, "Gain type"), + (mjm.actuator_biastype, types.BiasType, "Bias type"), + (mjm.eq_type, types.EqType, "Equality constraint types"), + (mjm.geom_type, types.GeomType, "Geom type"), + (mjm.sensor_type, types.SensorType, "Sensor types"), + (mjm.wrap_type, types.WrapType, "Wrap types"), + ): + unsupported = ~np.isin(field, list(field_types)) + if unsupported.any(): + raise NotImplementedError(f"{field_str} {field[unsupported]} not supported.") + + if mjm.sensor_cutoff.any(): + raise NotImplementedError("Sensor cutoff is unsupported.") + + for n, msg in ( + (mjm.nplugin, "Plugins"), + (mjm.nflex, "Flexes"), + ): + if n > 0: + raise NotImplementedError(f"{msg} are unsupported.") + + if mjm.tendon_frictionloss.any(): + raise NotImplementedError("Tendon frictionloss is unsupported.") + + # check options + for opt, opt_types, msg in ( + (mjm.opt.integrator, types.IntegratorType, "Integrator"), + (mjm.opt.cone, types.ConeType, "Cone"), + (mjm.opt.solver, types.SolverType, "Solver"), + ): + if opt not in set(opt_types): + raise NotImplementedError(f"{msg} {opt} is unsupported.") + + if mjm.opt.wind.any(): + raise NotImplementedError("Wind is unsupported.") + + if mjm.opt.density > 0 or mjm.opt.viscosity > 0: + raise NotImplementedError("Fluid forces are unsupported.") + + # TODO(team): remove after solver._update_gradient for Newton solver utilizes tile operations for islands + nv_max = 60 + if mjm.nv > nv_max and (not mjm.opt.jacobian == mujoco.mjtJacobian.mjJAC_SPARSE): + raise ValueError(f"Dense is unsupported for nv > {nv_max} (nv = {mjm.nv}).") + + m = types.Model() + + m.nq = mjm.nq + m.nv = mjm.nv + m.na = mjm.na + m.nu = mjm.nu + m.nbody = mjm.nbody + m.njnt = mjm.njnt + m.ngeom = mjm.ngeom + m.nsite = mjm.nsite + m.ncam = mjm.ncam + m.nlight = mjm.nlight + m.nmocap = mjm.nmocap + m.nM = mjm.nM + m.ntendon = mjm.ntendon + m.nwrap = mjm.nwrap + m.nsensor = mjm.nsensor + m.nsensordata = mjm.nsensordata + m.nlsp = mjm.opt.ls_iterations # TODO(team): how to set nlsp? + m.npair = mjm.npair + m.nexclude = mjm.nexclude + m.neq = mjm.neq + m.opt.timestep = mjm.opt.timestep + m.opt.tolerance = mjm.opt.tolerance + m.opt.ls_tolerance = mjm.opt.ls_tolerance + m.opt.gravity = wp.vec3(mjm.opt.gravity) + m.opt.cone = mjm.opt.cone + m.opt.solver = mjm.opt.solver + m.opt.iterations = mjm.opt.iterations + m.opt.ls_iterations = mjm.opt.ls_iterations + m.opt.integrator = mjm.opt.integrator + m.opt.disableflags = mjm.opt.disableflags + m.opt.impratio = wp.float32(mjm.opt.impratio) + m.opt.is_sparse = support.is_sparse(mjm) + m.opt.ls_parallel = False + # TODO(team) Figure out good default parameters + m.opt.gjk_iteration_count = wp.int32(1) # warp only + m.opt.epa_iteration_count = wp.int32(12) # warp only + m.opt.epa_exact_neg_distance = wp.bool(False) # warp only + m.opt.depth_extension = wp.float32(0.1) # warp only + m.stat.meaninertia = mjm.stat.meaninertia + + m.qpos0 = wp.array(mjm.qpos0, dtype=wp.float32, ndim=1) + m.qpos_spring = wp.array(mjm.qpos_spring, dtype=wp.float32, ndim=1) + + # dof lower triangle row and column indices + dof_tri_row, dof_tri_col = np.tril_indices(mjm.nv) + + # indices for sparse qM full_m + is_, js = [], [] + for i in range(mjm.nv): + j = i + while j > -1: + is_.append(i) + js.append(j) + j = mjm.dof_parentid[j] + qM_fullm_i = is_ + qM_fullm_j = js - if support.is_sparse(mjm): - # qLD_update_tree has dof tree ordering of qLD updates for sparse factor m - # qLD_update_treeadr contains starting index of each dof tree level - mjd = mujoco.MjData(mjm) - if version.parse(mujoco.__version__) > version.parse("3.2.7"): - m.M_rownnz = wp.array(mjd.M_rownnz, dtype=wp.int32, ndim=1) - m.M_rowadr = wp.array(mjd.M_rowadr, dtype=wp.int32, ndim=1) - m.M_colind = wp.array(mjd.M_colind, dtype=wp.int32, ndim=1) - m.mapM2M = wp.array(mjd.mapM2M, dtype=wp.int32, ndim=1) - qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1 - - rownnz = mjd.M_rownnz - rowadr = mjd.M_rowadr - - for k in range(mjm.nv): - dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1 - i = mjm.dof_parentid[k] - diag_k = rowadr[k] + rownnz[k] - 1 - Madr_ki = diag_k - 1 - while i > -1: - qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki)) - i = mjm.dof_parentid[i] - Madr_ki -= 1 - - qLD_update_tree = np.concatenate( - [qLD_updates[i] for i in range(len(qLD_updates))] - ) - tree_off = [0] + [len(qLD_updates[i]) for i in range(len(qLD_updates))] - qLD_update_treeadr = np.cumsum(tree_off)[:-1] - else: - qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1 - for k in range(mjm.nv): - dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1 - i = mjm.dof_parentid[k] - Madr_ki = mjm.dof_Madr[k] + 1 - while i > -1: - qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki)) - i = mjm.dof_parentid[i] - Madr_ki += 1 - - # qLD_treeadr contains starting indicies of each level of sparse updates - qLD_update_tree = np.concatenate( - [qLD_updates[i] for i in range(len(qLD_updates))] - ) - tree_off = [0] + [len(qLD_updates[i]) for i in range(len(qLD_updates))] - qLD_update_treeadr = np.cumsum(tree_off)[:-1] + # indices for sparse qM mul_m + is_, js, madr_ijs = [], [], [] + for i in range(mjm.nv): + madr_ij, j = mjm.dof_Madr[i], i - else: - # qLD_tile has the dof id of each tile in qLD for dense factor m - # qLD_tileadr contains starting index in qLD_tile of each tile group - # qLD_tilesize has the square tile size of each tile group - tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1] - tiles = {} - for i in range(len(tile_corners)): - tile_beg = tile_corners[i] - tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1] - tiles.setdefault(tile_end - tile_beg, []).append(tile_beg) - qLD_tile = np.concatenate([tiles[sz] for sz in sorted(tiles.keys())]) - tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())] - qLD_tileadr = np.cumsum(tile_off)[:-1] - qLD_tilesize = np.array(sorted(tiles.keys())) - - # tiles for actuator_moment - needs nu + nv tile size and offset - actuator_moment_offset_nv = np.empty(shape=(0,), dtype=int) - actuator_moment_offset_nu = np.empty(shape=(0,), dtype=int) - actuator_moment_tileadr = np.empty(shape=(0,), dtype=int) - actuator_moment_tilesize_nv = np.empty(shape=(0,), dtype=int) - actuator_moment_tilesize_nu = np.empty(shape=(0,), dtype=int) - - if not support.is_sparse(mjm): - # how many actuators for each tree - tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1] - tree_id = mjm.dof_treeid[tile_corners] - num_trees = int(np.max(tree_id)) - tree = mjm.body_treeid[mjm.jnt_bodyid[mjm.actuator_trnid[:, 0]]] - counts, ids = np.histogram(tree, bins=np.arange(0, num_trees + 2)) - acts_per_tree = dict(zip([int(i) for i in ids], [int(i) for i in counts])) - - tiles = {} - act_beg = 0 - for i in range(len(tile_corners)): - tile_beg = tile_corners[i] - tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1] - tree = int(tree_id[i]) - act_num = acts_per_tree[tree] - tiles.setdefault((tile_end - tile_beg, act_num), []).append((tile_beg, act_beg)) - act_beg += act_num - - sorted_keys = sorted(tiles.keys()) - actuator_moment_offset_nv = [ - t[0] for key in sorted_keys for t in tiles.get(key, []) - ] - actuator_moment_offset_nu = [ - t[1] for key in sorted_keys for t in tiles.get(key, []) - ] - tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())] - actuator_moment_tileadr = np.cumsum(tile_off)[:-1] # offset - actuator_moment_tilesize_nv = np.array( - [a[0] for a in sorted_keys] - ) # for this level - actuator_moment_tilesize_nu = np.array( - [int(a[1]) for a in sorted_keys] - ) # for this level - - m.qM_fullm_i = wp.array(qM_fullm_i, dtype=wp.int32, ndim=1) - m.qM_fullm_j = wp.array(qM_fullm_j, dtype=wp.int32, ndim=1) - m.qM_mulm_i = wp.array(qM_mulm_i, dtype=wp.int32, ndim=1) - m.qM_mulm_j = wp.array(qM_mulm_j, dtype=wp.int32, ndim=1) - m.qM_madr_ij = wp.array(qM_madr_ij, dtype=wp.int32, ndim=1) - m.qLD_update_tree = wp.array(qLD_update_tree, dtype=wp.vec3i, ndim=1) - m.qLD_update_treeadr = wp.array( - qLD_update_treeadr, dtype=wp.int32, ndim=1, device="cpu" - ) - m.qLD_tile = wp.array(qLD_tile, dtype=wp.int32, ndim=1) - m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, ndim=1, device="cpu") - m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, ndim=1, device="cpu") - m.actuator_moment_offset_nv = wp.array( - actuator_moment_offset_nv, dtype=wp.int32, ndim=1 - ) - m.actuator_moment_offset_nu = wp.array( - actuator_moment_offset_nu, dtype=wp.int32, ndim=1 - ) - m.actuator_moment_tileadr = wp.array( - actuator_moment_tileadr, dtype=wp.int32, ndim=1, device="cpu" - ) - m.actuator_moment_tilesize_nv = wp.array( - actuator_moment_tilesize_nv, dtype=wp.int32, ndim=1, device="cpu" - ) - m.actuator_moment_tilesize_nu = wp.array( - actuator_moment_tilesize_nu, dtype=wp.int32, ndim=1, device="cpu" - ) - m.alpha_candidate = wp.array(np.linspace(0.0, 1.0, m.nlsp), dtype=wp.float32) - m.body_dofadr = wp.array(mjm.body_dofadr, dtype=wp.int32, ndim=1) - m.body_dofnum = wp.array(mjm.body_dofnum, dtype=wp.int32, ndim=1) - m.body_jntadr = wp.array(mjm.body_jntadr, dtype=wp.int32, ndim=1) - m.body_jntnum = wp.array(mjm.body_jntnum, dtype=wp.int32, ndim=1) - m.body_parentid = wp.array(mjm.body_parentid, dtype=wp.int32, ndim=1) - m.body_mocapid = wp.array(mjm.body_mocapid, dtype=wp.int32, ndim=1) - m.body_weldid = wp.array(mjm.body_weldid, dtype=wp.int32, ndim=1) - m.body_pos = wp.array(mjm.body_pos, dtype=wp.vec3, ndim=1) - m.body_quat = wp.array(mjm.body_quat, dtype=wp.quat, ndim=1) - m.body_ipos = wp.array(mjm.body_ipos, dtype=wp.vec3, ndim=1) - m.body_iquat = wp.array(mjm.body_iquat, dtype=wp.quat, ndim=1) - m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32, ndim=1) - m.body_inertia = wp.array(mjm.body_inertia, dtype=wp.vec3, ndim=1) - m.body_mass = wp.array(mjm.body_mass, dtype=wp.float32, ndim=1) - m.body_subtreemass = wp.array(mjm.body_subtreemass, dtype=wp.float32, ndim=1) - - subtree_mass = np.copy(mjm.body_mass) - # TODO(team): should this be [mjm.nbody - 1, 0) ? - for i in range(mjm.nbody - 1, -1, -1): - subtree_mass[mjm.body_parentid[i]] += subtree_mass[i] - - m.subtree_mass = wp.array(subtree_mass, dtype=wp.float32, ndim=1) - m.body_invweight0 = wp.array(mjm.body_invweight0, dtype=wp.float32, ndim=2) - m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32, ndim=1) - m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32, ndim=1) - m.body_contype = wp.array(mjm.body_contype, dtype=wp.int32, ndim=1) - m.body_conaffinity = wp.array(mjm.body_conaffinity, dtype=wp.int32, ndim=1) - m.jnt_bodyid = wp.array(mjm.jnt_bodyid, dtype=wp.int32, ndim=1) - m.jnt_limited = wp.array(mjm.jnt_limited, dtype=wp.int32, ndim=1) - m.jnt_limited_slide_hinge_adr = wp.array( - jnt_limited_slide_hinge_adr, dtype=wp.int32, ndim=1 - ) - m.jnt_limited_ball_adr = wp.array(jnt_limited_ball_adr, dtype=wp.int32, ndim=1) - m.jnt_type = wp.array(mjm.jnt_type, dtype=wp.int32, ndim=1) - m.jnt_solref = wp.array(mjm.jnt_solref, dtype=wp.vec2f, ndim=1) - m.jnt_solimp = wp.array(mjm.jnt_solimp, dtype=types.vec5, ndim=1) - m.jnt_qposadr = wp.array(mjm.jnt_qposadr, dtype=wp.int32, ndim=1) - m.jnt_dofadr = wp.array(mjm.jnt_dofadr, dtype=wp.int32, ndim=1) - m.jnt_axis = wp.array(mjm.jnt_axis, dtype=wp.vec3, ndim=1) - m.jnt_pos = wp.array(mjm.jnt_pos, dtype=wp.vec3, ndim=1) - m.jnt_range = wp.array(mjm.jnt_range, dtype=wp.float32, ndim=2) - m.jnt_margin = wp.array(mjm.jnt_margin, dtype=wp.float32, ndim=1) - m.jnt_stiffness = wp.array(mjm.jnt_stiffness, dtype=wp.float32, ndim=1) - m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool, ndim=1) - m.jnt_actfrcrange = wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2, ndim=1) - m.geom_type = wp.array(mjm.geom_type, dtype=wp.int32, ndim=1) - m.geom_bodyid = wp.array(mjm.geom_bodyid, dtype=wp.int32, ndim=1) - m.geom_conaffinity = wp.array(mjm.geom_conaffinity, dtype=wp.int32, ndim=1) - m.geom_contype = wp.array(mjm.geom_contype, dtype=wp.int32, ndim=1) - m.geom_condim = wp.array(mjm.geom_condim, dtype=wp.int32, ndim=1) - m.geom_pos = wp.array(mjm.geom_pos, dtype=wp.vec3, ndim=1) - m.geom_quat = wp.array(mjm.geom_quat, dtype=wp.quat, ndim=1) - m.geom_size = wp.array(mjm.geom_size, dtype=wp.vec3, ndim=1) - m.geom_priority = wp.array(mjm.geom_priority, dtype=wp.int32, ndim=1) - m.geom_solmix = wp.array(mjm.geom_solmix, dtype=wp.float32, ndim=1) - m.geom_solref = wp.array(mjm.geom_solref, dtype=wp.vec2, ndim=1) - m.geom_solimp = wp.array(mjm.geom_solimp, dtype=types.vec5, ndim=1) - m.geom_friction = wp.array(mjm.geom_friction, dtype=wp.vec3, ndim=1) - m.geom_margin = wp.array(mjm.geom_margin, dtype=wp.float32, ndim=1) - m.geom_gap = wp.array(mjm.geom_gap, dtype=wp.float32, ndim=1) - m.geom_aabb = wp.array(mjm.geom_aabb, dtype=wp.vec3, ndim=3) - m.geom_rbound = wp.array(mjm.geom_rbound, dtype=wp.float32, ndim=1) - m.geom_dataid = wp.array(mjm.geom_dataid, dtype=wp.int32, ndim=1) - m.mesh_vertadr = wp.array(mjm.mesh_vertadr, dtype=wp.int32, ndim=1) - m.mesh_vertnum = wp.array(mjm.mesh_vertnum, dtype=wp.int32, ndim=1) - m.mesh_vert = wp.array(mjm.mesh_vert, dtype=wp.vec3, ndim=1) - m.eq_type = wp.array(mjm.eq_type, dtype=wp.int32, ndim=1) - m.eq_obj1id = wp.array(mjm.eq_obj1id, dtype=wp.int32, ndim=1) - m.eq_obj2id = wp.array(mjm.eq_obj2id, dtype=wp.int32, ndim=1) - m.eq_objtype = wp.array(mjm.eq_objtype, dtype=wp.int32, ndim=1) - m.eq_active0 = wp.array(mjm.eq_active0, dtype=wp.bool, ndim=1) - m.eq_solref = wp.array(mjm.eq_solref, dtype=wp.vec2, ndim=1) - m.eq_solimp = wp.array(mjm.eq_solimp, dtype=types.vec5, ndim=1) - m.eq_data = wp.array(mjm.eq_data, dtype=types.vec11, ndim=1) - m.site_pos = wp.array(mjm.site_pos, dtype=wp.vec3, ndim=1) - m.site_quat = wp.array(mjm.site_quat, dtype=wp.quat, ndim=1) - m.site_bodyid = wp.array(mjm.site_bodyid, dtype=wp.int32, ndim=1) - m.cam_mode = wp.array(mjm.cam_mode, dtype=wp.int32, ndim=1) - m.cam_bodyid = wp.array(mjm.cam_bodyid, dtype=wp.int32, ndim=1) - m.cam_targetbodyid = wp.array(mjm.cam_targetbodyid, dtype=wp.int32, ndim=1) - m.cam_pos = wp.array(mjm.cam_pos, dtype=wp.vec3, ndim=1) - m.cam_quat = wp.array(mjm.cam_quat, dtype=wp.quat, ndim=1) - m.cam_poscom0 = wp.array(mjm.cam_poscom0, dtype=wp.vec3, ndim=1) - m.cam_pos0 = wp.array(mjm.cam_pos0, dtype=wp.vec3, ndim=1) - m.light_mode = wp.array(mjm.light_mode, dtype=wp.int32, ndim=1) - m.light_bodyid = wp.array(mjm.light_bodyid, dtype=wp.int32, ndim=1) - m.light_targetbodyid = wp.array(mjm.light_targetbodyid, dtype=wp.int32, ndim=1) - m.light_pos = wp.array(mjm.light_pos, dtype=wp.vec3, ndim=1) - m.light_dir = wp.array(mjm.light_dir, dtype=wp.vec3, ndim=1) - m.light_poscom0 = wp.array(mjm.light_poscom0, dtype=wp.vec3, ndim=1) - m.light_pos0 = wp.array(mjm.light_pos0, dtype=wp.vec3, ndim=1) - m.dof_bodyid = wp.array(mjm.dof_bodyid, dtype=wp.int32, ndim=1) - m.dof_jntid = wp.array(mjm.dof_jntid, dtype=wp.int32, ndim=1) - m.dof_parentid = wp.array(mjm.dof_parentid, dtype=wp.int32, ndim=1) - m.dof_Madr = wp.array(mjm.dof_Madr, dtype=wp.int32, ndim=1) - m.dof_armature = wp.array(mjm.dof_armature, dtype=wp.float32, ndim=1) - m.dof_damping = wp.array(mjm.dof_damping, dtype=wp.float32, ndim=1) - m.dof_frictionloss = wp.array(mjm.dof_frictionloss, dtype=wp.float32, ndim=1) - m.dof_solimp = wp.array(mjm.dof_solimp, dtype=types.vec5, ndim=1) - m.dof_solref = wp.array(mjm.dof_solref, dtype=wp.vec2, ndim=1) - m.dof_tri_row = wp.from_numpy(dof_tri_row, dtype=wp.int32) - m.dof_tri_col = wp.from_numpy(dof_tri_col, dtype=wp.int32) - m.dof_invweight0 = wp.array(mjm.dof_invweight0, dtype=wp.float32, ndim=1) - m.actuator_trntype = wp.array(mjm.actuator_trntype, dtype=wp.int32, ndim=1) - m.actuator_trnid = wp.array(mjm.actuator_trnid, dtype=wp.int32, ndim=2) - m.actuator_ctrllimited = wp.array(mjm.actuator_ctrllimited, dtype=wp.bool, ndim=1) - m.actuator_ctrlrange = wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2, ndim=1) - m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool, ndim=1) - m.actuator_forcerange = wp.array(mjm.actuator_forcerange, dtype=wp.vec2, ndim=1) - m.actuator_gaintype = wp.array(mjm.actuator_gaintype, dtype=wp.int32, ndim=1) - m.actuator_gainprm = wp.array(mjm.actuator_gainprm, dtype=types.vec10f, ndim=1) - m.actuator_biastype = wp.array(mjm.actuator_biastype, dtype=wp.int32, ndim=1) - m.actuator_biasprm = wp.array(mjm.actuator_biasprm, dtype=types.vec10f, ndim=1) - m.actuator_gear = wp.array(mjm.actuator_gear, dtype=wp.spatial_vector, ndim=1) - m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool, ndim=1) - m.actuator_actrange = wp.array(mjm.actuator_actrange, dtype=wp.vec2, ndim=1) - m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32, ndim=1) - m.actuator_actnum = wp.array(mjm.actuator_actnum, dtype=wp.int32, ndim=1) - m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32, ndim=1) - m.actuator_dynprm = wp.array(mjm.actuator_dynprm, dtype=types.vec10f, ndim=1) - m.exclude_signature = wp.array(mjm.exclude_signature, dtype=wp.int32, ndim=1) - - # pre-compute indices of equality constraints - m.eq_connect_adr = wp.array( - np.nonzero(mjm.eq_type == types.EqType.CONNECT.value)[0], dtype=wp.int32, ndim=1 - ) - m.eq_wld_adr = wp.array( - np.nonzero(mjm.eq_type == types.EqType.WELD.value)[0], dtype=wp.int32, ndim=1 - ) - m.eq_jnt_adr = wp.array( - np.nonzero(mjm.eq_type == types.EqType.JOINT.value)[0], dtype=wp.int32, ndim=1 - ) + while True: + madr_ij, j = madr_ij + 1, mjm.dof_parentid[j] + if j == -1: + break + is_, js, madr_ijs = is_ + [i], js + [j], madr_ijs + [madr_ij] - # short-circuiting here allows us to skip a lot of code in implicit integration - m.actuator_affine_bias_gain = bool( - np.any(mjm.actuator_biastype == types.BiasType.AFFINE.value) - or np.any(mjm.actuator_gaintype == types.GainType.AFFINE.value) - ) + qM_mulm_i, qM_mulm_j, qM_madr_ij = ( + np.array(x, dtype=np.int32) for x in (is_, js, madr_ijs) + ) - geompair, pairid = geom_pair(mjm) - m.nxn_geom_pair = wp.array(geompair, dtype=wp.vec2i, ndim=1) - m.nxn_pairid = wp.array(pairid, dtype=wp.int32, ndim=1) - - # predefined collision pairs - m.pair_dim = wp.array(mjm.pair_dim, dtype=wp.int32, ndim=1) - m.pair_geom1 = wp.array(mjm.pair_geom1, dtype=wp.int32, ndim=1) - m.pair_geom2 = wp.array(mjm.pair_geom2, dtype=wp.int32, ndim=1) - m.pair_solref = wp.array(mjm.pair_solref, dtype=wp.vec2, ndim=1) - m.pair_solreffriction = wp.array(mjm.pair_solreffriction, dtype=wp.vec2, ndim=1) - m.pair_solimp = wp.array(mjm.pair_solimp, dtype=types.vec5, ndim=1) - m.pair_margin = wp.array(mjm.pair_margin, dtype=wp.float32, ndim=1) - m.pair_gap = wp.array(mjm.pair_gap, dtype=wp.float32, ndim=1) - m.pair_friction = wp.array(mjm.pair_friction, dtype=types.vec5, ndim=1) - m.condim_max = np.max(mjm.geom_condim) # TODO(team): get max after filtering + jnt_limited_slide_hinge_adr = np.nonzero( + mjm.jnt_limited + & ( + (mjm.jnt_type == mujoco.mjtJoint.mjJNT_SLIDE) + | (mjm.jnt_type == mujoco.mjtJoint.mjJNT_HINGE) + ) + )[0] + + jnt_limited_ball_adr = np.nonzero( + mjm.jnt_limited & (mjm.jnt_type == mujoco.mjtJoint.mjJNT_BALL) + )[0] + + # body_tree is BFS ordering of body ids + # body_treeadr contains starting index of each body tree level + bodies, body_depth = {}, np.zeros(mjm.nbody, dtype=int) - 1 + for i in range(mjm.nbody): + body_depth[i] = body_depth[mjm.body_parentid[i]] + 1 + bodies.setdefault(body_depth[i], []).append(i) + body_tree = np.concatenate([bodies[i] for i in range(len(bodies))]) + tree_off = [0] + [len(bodies[i]) for i in range(len(bodies))] + body_treeadr = np.cumsum(tree_off)[:-1] + + m.body_tree = wp.array(body_tree, dtype=wp.int32, ndim=1) + m.body_treeadr = wp.array(body_treeadr, dtype=wp.int32, ndim=1, device="cpu") + + qLD_update_tree = np.empty(shape=(0, 3), dtype=int) + qLD_update_treeadr = np.empty(shape=(0,), dtype=int) + qLD_tile = np.empty(shape=(0,), dtype=int) + qLD_tileadr = np.empty(shape=(0,), dtype=int) + qLD_tilesize = np.empty(shape=(0,), dtype=int) + + if support.is_sparse(mjm): + # qLD_update_tree has dof tree ordering of qLD updates for sparse factor m + # qLD_update_treeadr contains starting index of each dof tree level + mjd = mujoco.MjData(mjm) + if version.parse(mujoco.__version__) > version.parse("3.2.7"): + m.M_rownnz = wp.array(mjd.M_rownnz, dtype=wp.int32, ndim=1) + m.M_rowadr = wp.array(mjd.M_rowadr, dtype=wp.int32, ndim=1) + m.M_colind = wp.array(mjd.M_colind, dtype=wp.int32, ndim=1) + m.mapM2M = wp.array(mjd.mapM2M, dtype=wp.int32, ndim=1) + qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1 + + rownnz = mjd.M_rownnz + rowadr = mjd.M_rowadr + + for k in range(mjm.nv): + dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1 + i = mjm.dof_parentid[k] + diag_k = rowadr[k] + rownnz[k] - 1 + Madr_ki = diag_k - 1 + while i > -1: + qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki)) + i = mjm.dof_parentid[i] + Madr_ki -= 1 + + qLD_update_tree = np.concatenate( + [qLD_updates[i] for i in range(len(qLD_updates))] + ) + tree_off = [0] + [len(qLD_updates[i]) for i in range(len(qLD_updates))] + qLD_update_treeadr = np.cumsum(tree_off)[:-1] + else: + qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1 + for k in range(mjm.nv): + dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1 + i = mjm.dof_parentid[k] + Madr_ki = mjm.dof_Madr[k] + 1 + while i > -1: + qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki)) + i = mjm.dof_parentid[i] + Madr_ki += 1 + + # qLD_treeadr contains starting indicies of each level of sparse updates + qLD_update_tree = np.concatenate( + [qLD_updates[i] for i in range(len(qLD_updates))] + ) + tree_off = [0] + [len(qLD_updates[i]) for i in range(len(qLD_updates))] + qLD_update_treeadr = np.cumsum(tree_off)[:-1] - # tendon - m.tendon_adr = wp.array(mjm.tendon_adr, dtype=wp.int32, ndim=1) - m.tendon_num = wp.array(mjm.tendon_num, dtype=wp.int32, ndim=1) - m.tendon_limited = wp.array(mjm.tendon_limited, dtype=wp.int32, ndim=1) - m.tendon_limited_adr = wp.array( - np.nonzero(mjm.tendon_limited)[0], dtype=wp.int32, ndim=1 - ) - m.tendon_solref_lim = wp.array(mjm.tendon_solref_lim, dtype=wp.vec2f, ndim=1) - m.tendon_solimp_lim = wp.array(mjm.tendon_solimp_lim, dtype=types.vec5, ndim=1) - m.tendon_range = wp.array(mjm.tendon_range, dtype=wp.vec2f, ndim=1) - m.tendon_margin = wp.array(mjm.tendon_margin, dtype=wp.float32, ndim=1) - m.tendon_length0 = wp.array(mjm.tendon_length0, dtype=wp.float32, ndim=1) - m.tendon_invweight0 = wp.array(mjm.tendon_invweight0, dtype=wp.float32, ndim=1) - m.wrap_objid = wp.array(mjm.wrap_objid, dtype=wp.int32, ndim=1) - m.wrap_prm = wp.array(mjm.wrap_prm, dtype=wp.float32, ndim=1) - m.wrap_type = wp.array(mjm.wrap_type, dtype=wp.int32, ndim=1) - - # fixed tendon - tendon_jnt_adr = [] - wrap_jnt_adr = [] - for i in range(mjm.ntendon): - adr = mjm.tendon_adr[i] - if mjm.wrap_type[adr] == mujoco.mjtWrap.mjWRAP_JOINT: - tendon_num = mjm.tendon_num[i] - for j in range(tendon_num): - tendon_jnt_adr.append(i) - wrap_jnt_adr.append(adr + j) - - m.tendon_jnt_adr = wp.array(tendon_jnt_adr, dtype=wp.int32, ndim=1) - m.wrap_jnt_adr = wp.array(wrap_jnt_adr, dtype=wp.int32, ndim=1) - - # spatial tendon - tendon_site_adr = [] - tendon_site_pair_adr = [] - ten_wrapadr_site = [0] - ten_wrapnum_site = [] - for i, tendon_num in enumerate(mjm.tendon_num): - adr = mjm.tendon_adr[i] - if (mjm.wrap_type[adr : adr + tendon_num] == mujoco.mjtWrap.mjWRAP_SITE).all(): - if i < mjm.ntendon: - ten_wrapadr_site.append(ten_wrapadr_site[-1] + tendon_num) - ten_wrapnum_site.append(tendon_num) - for j in range(tendon_num): - if j < tendon_num - 1: - tendon_site_pair_adr.append(i) - tendon_site_adr.append(i) else: - if i < mjm.ntendon: - ten_wrapadr_site.append(ten_wrapadr_site[-1]) - ten_wrapnum_site.append(0) - - tendon_site_adr = np.array(tendon_site_adr) - tendon_site_pair_adr = np.array(tendon_site_pair_adr) - wrap_site_adr = np.nonzero(mjm.wrap_type == mujoco.mjtWrap.mjWRAP_SITE)[0] - wrap_site_pair_adr = np.setdiff1d( - wrap_site_adr[np.nonzero(np.diff(wrap_site_adr) == 1)[0]], mjm.tendon_adr[1:] - 1 - ) + # qLD_tile has the dof id of each tile in qLD for dense factor m + # qLD_tileadr contains starting index in qLD_tile of each tile group + # qLD_tilesize has the square tile size of each tile group + tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1] + tiles = {} + for i in range(len(tile_corners)): + tile_beg = tile_corners[i] + tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1] + tiles.setdefault(tile_end - tile_beg, []).append(tile_beg) + qLD_tile = np.concatenate([tiles[sz] for sz in sorted(tiles.keys())]) + tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())] + qLD_tileadr = np.cumsum(tile_off)[:-1] + qLD_tilesize = np.array(sorted(tiles.keys())) + + # tiles for actuator_moment - needs nu + nv tile size and offset + actuator_moment_offset_nv = np.empty(shape=(0,), dtype=int) + actuator_moment_offset_nu = np.empty(shape=(0,), dtype=int) + actuator_moment_tileadr = np.empty(shape=(0,), dtype=int) + actuator_moment_tilesize_nv = np.empty(shape=(0,), dtype=int) + actuator_moment_tilesize_nu = np.empty(shape=(0,), dtype=int) + + if not support.is_sparse(mjm): + # how many actuators for each tree + tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1] + tree_id = mjm.dof_treeid[tile_corners] + num_trees = int(np.max(tree_id)) + tree = mjm.body_treeid[mjm.jnt_bodyid[mjm.actuator_trnid[:, 0]]] + counts, ids = np.histogram(tree, bins=np.arange(0, num_trees + 2)) + acts_per_tree = dict(zip([int(i) for i in ids], [int(i) for i in counts])) + + tiles = {} + act_beg = 0 + for i in range(len(tile_corners)): + tile_beg = tile_corners[i] + tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1] + tree = int(tree_id[i]) + act_num = acts_per_tree[tree] + tiles.setdefault((tile_end - tile_beg, act_num), []).append((tile_beg, act_beg)) + act_beg += act_num + + sorted_keys = sorted(tiles.keys()) + actuator_moment_offset_nv = [ + t[0] for key in sorted_keys for t in tiles.get(key, []) + ] + actuator_moment_offset_nu = [ + t[1] for key in sorted_keys for t in tiles.get(key, []) + ] + tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())] + actuator_moment_tileadr = np.cumsum(tile_off)[:-1] # offset + actuator_moment_tilesize_nv = np.array( + [a[0] for a in sorted_keys] + ) # for this level + actuator_moment_tilesize_nu = np.array( + [int(a[1]) for a in sorted_keys] + ) # for this level + + m.qM_fullm_i = wp.array(qM_fullm_i, dtype=wp.int32, ndim=1) + m.qM_fullm_j = wp.array(qM_fullm_j, dtype=wp.int32, ndim=1) + m.qM_mulm_i = wp.array(qM_mulm_i, dtype=wp.int32, ndim=1) + m.qM_mulm_j = wp.array(qM_mulm_j, dtype=wp.int32, ndim=1) + m.qM_madr_ij = wp.array(qM_madr_ij, dtype=wp.int32, ndim=1) + m.qLD_update_tree = wp.array(qLD_update_tree, dtype=wp.vec3i, ndim=1) + m.qLD_update_treeadr = wp.array( + qLD_update_treeadr, dtype=wp.int32, ndim=1, device="cpu" + ) + m.qLD_tile = wp.array(qLD_tile, dtype=wp.int32, ndim=1) + m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, ndim=1, device="cpu") + m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, ndim=1, device="cpu") + m.actuator_moment_offset_nv = wp.array( + actuator_moment_offset_nv, dtype=wp.int32, ndim=1 + ) + m.actuator_moment_offset_nu = wp.array( + actuator_moment_offset_nu, dtype=wp.int32, ndim=1 + ) + m.actuator_moment_tileadr = wp.array( + actuator_moment_tileadr, dtype=wp.int32, ndim=1, device="cpu" + ) + m.actuator_moment_tilesize_nv = wp.array( + actuator_moment_tilesize_nv, dtype=wp.int32, ndim=1, device="cpu" + ) + m.actuator_moment_tilesize_nu = wp.array( + actuator_moment_tilesize_nu, dtype=wp.int32, ndim=1, device="cpu" + ) + m.alpha_candidate = wp.array(np.linspace(0.0, 1.0, m.nlsp), dtype=wp.float32) + m.body_dofadr = wp.array(mjm.body_dofadr, dtype=wp.int32, ndim=1) + m.body_dofnum = wp.array(mjm.body_dofnum, dtype=wp.int32, ndim=1) + m.body_jntadr = wp.array(mjm.body_jntadr, dtype=wp.int32, ndim=1) + m.body_jntnum = wp.array(mjm.body_jntnum, dtype=wp.int32, ndim=1) + m.body_parentid = wp.array(mjm.body_parentid, dtype=wp.int32, ndim=1) + m.body_mocapid = wp.array(mjm.body_mocapid, dtype=wp.int32, ndim=1) + m.body_weldid = wp.array(mjm.body_weldid, dtype=wp.int32, ndim=1) + m.body_pos = wp.array(mjm.body_pos, dtype=wp.vec3, ndim=1) + m.body_quat = wp.array(mjm.body_quat, dtype=wp.quat, ndim=1) + m.body_ipos = wp.array(mjm.body_ipos, dtype=wp.vec3, ndim=1) + m.body_iquat = wp.array(mjm.body_iquat, dtype=wp.quat, ndim=1) + m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32, ndim=1) + m.body_inertia = wp.array(mjm.body_inertia, dtype=wp.vec3, ndim=1) + m.body_mass = wp.array(mjm.body_mass, dtype=wp.float32, ndim=1) + m.body_subtreemass = wp.array(mjm.body_subtreemass, dtype=wp.float32, ndim=1) + + subtree_mass = np.copy(mjm.body_mass) + # TODO(team): should this be [mjm.nbody - 1, 0) ? + for i in range(mjm.nbody - 1, -1, -1): + subtree_mass[mjm.body_parentid[i]] += subtree_mass[i] + + m.subtree_mass = wp.array(subtree_mass, dtype=wp.float32, ndim=1) + m.body_invweight0 = wp.array(mjm.body_invweight0, dtype=wp.float32, ndim=2) + m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32, ndim=1) + m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32, ndim=1) + m.body_contype = wp.array(mjm.body_contype, dtype=wp.int32, ndim=1) + m.body_conaffinity = wp.array(mjm.body_conaffinity, dtype=wp.int32, ndim=1) + m.jnt_bodyid = wp.array(mjm.jnt_bodyid, dtype=wp.int32, ndim=1) + m.jnt_limited = wp.array(mjm.jnt_limited, dtype=wp.int32, ndim=1) + m.jnt_limited_slide_hinge_adr = wp.array( + jnt_limited_slide_hinge_adr, dtype=wp.int32, ndim=1 + ) + m.jnt_limited_ball_adr = wp.array(jnt_limited_ball_adr, dtype=wp.int32, ndim=1) + m.jnt_type = wp.array(mjm.jnt_type, dtype=wp.int32, ndim=1) + m.jnt_solref = wp.array(mjm.jnt_solref, dtype=wp.vec2f, ndim=1) + m.jnt_solimp = wp.array(mjm.jnt_solimp, dtype=types.vec5, ndim=1) + m.jnt_qposadr = wp.array(mjm.jnt_qposadr, dtype=wp.int32, ndim=1) + m.jnt_dofadr = wp.array(mjm.jnt_dofadr, dtype=wp.int32, ndim=1) + m.jnt_axis = wp.array(mjm.jnt_axis, dtype=wp.vec3, ndim=1) + m.jnt_pos = wp.array(mjm.jnt_pos, dtype=wp.vec3, ndim=1) + m.jnt_range = wp.array(mjm.jnt_range, dtype=wp.float32, ndim=2) + m.jnt_margin = wp.array(mjm.jnt_margin, dtype=wp.float32, ndim=1) + m.jnt_stiffness = wp.array(mjm.jnt_stiffness, dtype=wp.float32, ndim=1) + m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool, ndim=1) + m.jnt_actfrcrange = wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2, ndim=1) + m.geom_type = wp.array(mjm.geom_type, dtype=wp.int32, ndim=1) + m.geom_bodyid = wp.array(mjm.geom_bodyid, dtype=wp.int32, ndim=1) + m.geom_conaffinity = wp.array(mjm.geom_conaffinity, dtype=wp.int32, ndim=1) + m.geom_contype = wp.array(mjm.geom_contype, dtype=wp.int32, ndim=1) + m.geom_condim = wp.array(mjm.geom_condim, dtype=wp.int32, ndim=1) + m.geom_pos = wp.array(mjm.geom_pos, dtype=wp.vec3, ndim=1) + m.geom_quat = wp.array(mjm.geom_quat, dtype=wp.quat, ndim=1) + m.geom_size = wp.array(mjm.geom_size, dtype=wp.vec3, ndim=1) + m.geom_priority = wp.array(mjm.geom_priority, dtype=wp.int32, ndim=1) + m.geom_solmix = wp.array(mjm.geom_solmix, dtype=wp.float32, ndim=1) + m.geom_solref = wp.array(mjm.geom_solref, dtype=wp.vec2, ndim=1) + m.geom_solimp = wp.array(mjm.geom_solimp, dtype=types.vec5, ndim=1) + m.geom_friction = wp.array(mjm.geom_friction, dtype=wp.vec3, ndim=1) + m.geom_margin = wp.array(mjm.geom_margin, dtype=wp.float32, ndim=1) + m.geom_gap = wp.array(mjm.geom_gap, dtype=wp.float32, ndim=1) + m.geom_aabb = wp.array(mjm.geom_aabb, dtype=wp.vec3, ndim=3) + m.geom_rbound = wp.array(mjm.geom_rbound, dtype=wp.float32, ndim=1) + m.geom_dataid = wp.array(mjm.geom_dataid, dtype=wp.int32, ndim=1) + m.mesh_vertadr = wp.array(mjm.mesh_vertadr, dtype=wp.int32, ndim=1) + m.mesh_vertnum = wp.array(mjm.mesh_vertnum, dtype=wp.int32, ndim=1) + m.mesh_vert = wp.array(mjm.mesh_vert, dtype=wp.vec3, ndim=1) + m.eq_type = wp.array(mjm.eq_type, dtype=wp.int32, ndim=1) + m.eq_obj1id = wp.array(mjm.eq_obj1id, dtype=wp.int32, ndim=1) + m.eq_obj2id = wp.array(mjm.eq_obj2id, dtype=wp.int32, ndim=1) + m.eq_objtype = wp.array(mjm.eq_objtype, dtype=wp.int32, ndim=1) + m.eq_active0 = wp.array(mjm.eq_active0, dtype=wp.bool, ndim=1) + m.eq_solref = wp.array(mjm.eq_solref, dtype=wp.vec2, ndim=1) + m.eq_solimp = wp.array(mjm.eq_solimp, dtype=types.vec5, ndim=1) + m.eq_data = wp.array(mjm.eq_data, dtype=types.vec11, ndim=1) + m.site_pos = wp.array(mjm.site_pos, dtype=wp.vec3, ndim=1) + m.site_quat = wp.array(mjm.site_quat, dtype=wp.quat, ndim=1) + m.site_bodyid = wp.array(mjm.site_bodyid, dtype=wp.int32, ndim=1) + m.cam_mode = wp.array(mjm.cam_mode, dtype=wp.int32, ndim=1) + m.cam_bodyid = wp.array(mjm.cam_bodyid, dtype=wp.int32, ndim=1) + m.cam_targetbodyid = wp.array(mjm.cam_targetbodyid, dtype=wp.int32, ndim=1) + m.cam_pos = wp.array(mjm.cam_pos, dtype=wp.vec3, ndim=1) + m.cam_quat = wp.array(mjm.cam_quat, dtype=wp.quat, ndim=1) + m.cam_poscom0 = wp.array(mjm.cam_poscom0, dtype=wp.vec3, ndim=1) + m.cam_pos0 = wp.array(mjm.cam_pos0, dtype=wp.vec3, ndim=1) + m.light_mode = wp.array(mjm.light_mode, dtype=wp.int32, ndim=1) + m.light_bodyid = wp.array(mjm.light_bodyid, dtype=wp.int32, ndim=1) + m.light_targetbodyid = wp.array(mjm.light_targetbodyid, dtype=wp.int32, ndim=1) + m.light_pos = wp.array(mjm.light_pos, dtype=wp.vec3, ndim=1) + m.light_dir = wp.array(mjm.light_dir, dtype=wp.vec3, ndim=1) + m.light_poscom0 = wp.array(mjm.light_poscom0, dtype=wp.vec3, ndim=1) + m.light_pos0 = wp.array(mjm.light_pos0, dtype=wp.vec3, ndim=1) + m.dof_bodyid = wp.array(mjm.dof_bodyid, dtype=wp.int32, ndim=1) + m.dof_jntid = wp.array(mjm.dof_jntid, dtype=wp.int32, ndim=1) + m.dof_parentid = wp.array(mjm.dof_parentid, dtype=wp.int32, ndim=1) + m.dof_Madr = wp.array(mjm.dof_Madr, dtype=wp.int32, ndim=1) + m.dof_armature = wp.array(mjm.dof_armature, dtype=wp.float32, ndim=1) + m.dof_damping = wp.array(mjm.dof_damping, dtype=wp.float32, ndim=1) + m.dof_frictionloss = wp.array(mjm.dof_frictionloss, dtype=wp.float32, ndim=1) + m.dof_solimp = wp.array(mjm.dof_solimp, dtype=types.vec5, ndim=1) + m.dof_solref = wp.array(mjm.dof_solref, dtype=wp.vec2, ndim=1) + m.dof_tri_row = wp.from_numpy(dof_tri_row, dtype=wp.int32) + m.dof_tri_col = wp.from_numpy(dof_tri_col, dtype=wp.int32) + m.dof_invweight0 = wp.array(mjm.dof_invweight0, dtype=wp.float32, ndim=1) + m.actuator_trntype = wp.array(mjm.actuator_trntype, dtype=wp.int32, ndim=1) + m.actuator_trnid = wp.array(mjm.actuator_trnid, dtype=wp.int32, ndim=2) + m.actuator_ctrllimited = wp.array(mjm.actuator_ctrllimited, dtype=wp.bool, ndim=1) + m.actuator_ctrlrange = wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2, ndim=1) + m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool, ndim=1) + m.actuator_forcerange = wp.array(mjm.actuator_forcerange, dtype=wp.vec2, ndim=1) + m.actuator_gaintype = wp.array(mjm.actuator_gaintype, dtype=wp.int32, ndim=1) + m.actuator_gainprm = wp.array(mjm.actuator_gainprm, dtype=types.vec10f, ndim=1) + m.actuator_biastype = wp.array(mjm.actuator_biastype, dtype=wp.int32, ndim=1) + m.actuator_biasprm = wp.array(mjm.actuator_biasprm, dtype=types.vec10f, ndim=1) + m.actuator_gear = wp.array(mjm.actuator_gear, dtype=wp.spatial_vector, ndim=1) + m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool, ndim=1) + m.actuator_actrange = wp.array(mjm.actuator_actrange, dtype=wp.vec2, ndim=1) + m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32, ndim=1) + m.actuator_actnum = wp.array(mjm.actuator_actnum, dtype=wp.int32, ndim=1) + m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32, ndim=1) + m.actuator_dynprm = wp.array(mjm.actuator_dynprm, dtype=types.vec10f, ndim=1) + m.exclude_signature = wp.array(mjm.exclude_signature, dtype=wp.int32, ndim=1) + + # pre-compute indices of equality constraints + m.eq_connect_adr = wp.array( + np.nonzero(mjm.eq_type == types.EqType.CONNECT.value)[0], dtype=wp.int32, ndim=1 + ) + m.eq_wld_adr = wp.array( + np.nonzero(mjm.eq_type == types.EqType.WELD.value)[0], dtype=wp.int32, ndim=1 + ) + m.eq_jnt_adr = wp.array( + np.nonzero(mjm.eq_type == types.EqType.JOINT.value)[0], dtype=wp.int32, ndim=1 + ) - m.tendon_site_adr = wp.array(tendon_site_adr, dtype=wp.int32, ndim=1) - m.tendon_site_pair_adr = wp.array(tendon_site_pair_adr, dtype=wp.int32, ndim=1) - m.ten_wrapadr_site = wp.array(ten_wrapadr_site, dtype=wp.int32, ndim=1) - m.ten_wrapnum_site = wp.array(ten_wrapnum_site, dtype=wp.int32, ndim=1) - m.wrap_site_adr = wp.array(wrap_site_adr, dtype=wp.int32, ndim=1) - m.wrap_site_pair_adr = wp.array(wrap_site_pair_adr, dtype=wp.int32, ndim=1) + # short-circuiting here allows us to skip a lot of code in implicit integration + m.actuator_affine_bias_gain = bool( + np.any(mjm.actuator_biastype == types.BiasType.AFFINE.value) + or np.any(mjm.actuator_gaintype == types.GainType.AFFINE.value) + ) - # sensors - m.sensor_type = wp.array(mjm.sensor_type, dtype=wp.int32, ndim=1) - m.sensor_datatype = wp.array(mjm.sensor_datatype, dtype=wp.int32, ndim=1) - m.sensor_objtype = wp.array(mjm.sensor_objtype, dtype=wp.int32, ndim=1) - m.sensor_objid = wp.array(mjm.sensor_objid, dtype=wp.int32, ndim=1) - m.sensor_reftype = wp.array(mjm.sensor_reftype, dtype=wp.int32, ndim=1) - m.sensor_refid = wp.array(mjm.sensor_refid, dtype=wp.int32, ndim=1) - m.sensor_dim = wp.array(mjm.sensor_dim, dtype=wp.int32, ndim=1) - m.sensor_adr = wp.array(mjm.sensor_adr, dtype=wp.int32, ndim=1) - m.sensor_cutoff = wp.array(mjm.sensor_cutoff, dtype=wp.float32, ndim=1) - m.sensor_pos_adr = wp.array( - np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS)[0], - dtype=wp.int32, - ndim=1, - ) - m.sensor_vel_adr = wp.array( - np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_VEL)[0], - dtype=wp.int32, - ndim=1, - ) - m.sensor_acc_adr = wp.array( - np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_ACC)[0], - dtype=wp.int32, - ndim=1, - ) + geompair, pairid = geom_pair(mjm) + m.nxn_geom_pair = wp.array(geompair, dtype=wp.vec2i, ndim=1) + m.nxn_pairid = wp.array(pairid, dtype=wp.int32, ndim=1) + + # predefined collision pairs + m.pair_dim = wp.array(mjm.pair_dim, dtype=wp.int32, ndim=1) + m.pair_geom1 = wp.array(mjm.pair_geom1, dtype=wp.int32, ndim=1) + m.pair_geom2 = wp.array(mjm.pair_geom2, dtype=wp.int32, ndim=1) + m.pair_solref = wp.array(mjm.pair_solref, dtype=wp.vec2, ndim=1) + m.pair_solreffriction = wp.array(mjm.pair_solreffriction, dtype=wp.vec2, ndim=1) + m.pair_solimp = wp.array(mjm.pair_solimp, dtype=types.vec5, ndim=1) + m.pair_margin = wp.array(mjm.pair_margin, dtype=wp.float32, ndim=1) + m.pair_gap = wp.array(mjm.pair_gap, dtype=wp.float32, ndim=1) + m.pair_friction = wp.array(mjm.pair_friction, dtype=types.vec5, ndim=1) + m.condim_max = np.max(mjm.geom_condim) # TODO(team): get max after filtering + + # tendon + m.tendon_adr = wp.array(mjm.tendon_adr, dtype=wp.int32, ndim=1) + m.tendon_num = wp.array(mjm.tendon_num, dtype=wp.int32, ndim=1) + m.tendon_limited = wp.array(mjm.tendon_limited, dtype=wp.int32, ndim=1) + m.tendon_limited_adr = wp.array( + np.nonzero(mjm.tendon_limited)[0], dtype=wp.int32, ndim=1 + ) + m.tendon_solref_lim = wp.array(mjm.tendon_solref_lim, dtype=wp.vec2f, ndim=1) + m.tendon_solimp_lim = wp.array(mjm.tendon_solimp_lim, dtype=types.vec5, ndim=1) + m.tendon_range = wp.array(mjm.tendon_range, dtype=wp.vec2f, ndim=1) + m.tendon_margin = wp.array(mjm.tendon_margin, dtype=wp.float32, ndim=1) + m.tendon_length0 = wp.array(mjm.tendon_length0, dtype=wp.float32, ndim=1) + m.tendon_invweight0 = wp.array(mjm.tendon_invweight0, dtype=wp.float32, ndim=1) + m.wrap_objid = wp.array(mjm.wrap_objid, dtype=wp.int32, ndim=1) + m.wrap_prm = wp.array(mjm.wrap_prm, dtype=wp.float32, ndim=1) + m.wrap_type = wp.array(mjm.wrap_type, dtype=wp.int32, ndim=1) + + # fixed tendon + tendon_jnt_adr = [] + wrap_jnt_adr = [] + for i in range(mjm.ntendon): + adr = mjm.tendon_adr[i] + if mjm.wrap_type[adr] == mujoco.mjtWrap.mjWRAP_JOINT: + tendon_num = mjm.tendon_num[i] + for j in range(tendon_num): + tendon_jnt_adr.append(i) + wrap_jnt_adr.append(adr + j) + + m.tendon_jnt_adr = wp.array(tendon_jnt_adr, dtype=wp.int32, ndim=1) + m.wrap_jnt_adr = wp.array(wrap_jnt_adr, dtype=wp.int32, ndim=1) + + # spatial tendon + tendon_site_adr = [] + tendon_site_pair_adr = [] + ten_wrapadr_site = [0] + ten_wrapnum_site = [] + for i, tendon_num in enumerate(mjm.tendon_num): + adr = mjm.tendon_adr[i] + if (mjm.wrap_type[adr : adr + tendon_num] == mujoco.mjtWrap.mjWRAP_SITE).all(): + if i < mjm.ntendon: + ten_wrapadr_site.append(ten_wrapadr_site[-1] + tendon_num) + ten_wrapnum_site.append(tendon_num) + for j in range(tendon_num): + if j < tendon_num - 1: + tendon_site_pair_adr.append(i) + tendon_site_adr.append(i) + else: + if i < mjm.ntendon: + ten_wrapadr_site.append(ten_wrapadr_site[-1]) + ten_wrapnum_site.append(0) + + tendon_site_adr = np.array(tendon_site_adr) + tendon_site_pair_adr = np.array(tendon_site_pair_adr) + wrap_site_adr = np.nonzero(mjm.wrap_type == mujoco.mjtWrap.mjWRAP_SITE)[0] + wrap_site_pair_adr = np.setdiff1d( + wrap_site_adr[np.nonzero(np.diff(wrap_site_adr) == 1)[0]], mjm.tendon_adr[1:] - 1 + ) - return m + m.tendon_site_adr = wp.array(tendon_site_adr, dtype=wp.int32, ndim=1) + m.tendon_site_pair_adr = wp.array(tendon_site_pair_adr, dtype=wp.int32, ndim=1) + m.ten_wrapadr_site = wp.array(ten_wrapadr_site, dtype=wp.int32, ndim=1) + m.ten_wrapnum_site = wp.array(ten_wrapnum_site, dtype=wp.int32, ndim=1) + m.wrap_site_adr = wp.array(wrap_site_adr, dtype=wp.int32, ndim=1) + m.wrap_site_pair_adr = wp.array(wrap_site_pair_adr, dtype=wp.int32, ndim=1) + + # sensors + m.sensor_type = wp.array(mjm.sensor_type, dtype=wp.int32, ndim=1) + m.sensor_datatype = wp.array(mjm.sensor_datatype, dtype=wp.int32, ndim=1) + m.sensor_objtype = wp.array(mjm.sensor_objtype, dtype=wp.int32, ndim=1) + m.sensor_objid = wp.array(mjm.sensor_objid, dtype=wp.int32, ndim=1) + m.sensor_reftype = wp.array(mjm.sensor_reftype, dtype=wp.int32, ndim=1) + m.sensor_refid = wp.array(mjm.sensor_refid, dtype=wp.int32, ndim=1) + m.sensor_dim = wp.array(mjm.sensor_dim, dtype=wp.int32, ndim=1) + m.sensor_adr = wp.array(mjm.sensor_adr, dtype=wp.int32, ndim=1) + m.sensor_cutoff = wp.array(mjm.sensor_cutoff, dtype=wp.float32, ndim=1) + m.sensor_pos_adr = wp.array( + np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS)[0], + dtype=wp.int32, + ndim=1, + ) + m.sensor_vel_adr = wp.array( + np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_VEL)[0], + dtype=wp.int32, + ndim=1, + ) + m.sensor_acc_adr = wp.array( + np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_ACC)[0], + dtype=wp.int32, + ndim=1, + ) + + return m def _constraint( @@ -675,153 +678,158 @@ def _constraint( def make_data( - mjm: mujoco.MjModel, nworld: int = 1, nconmax: int = -1, njmax: int = -1 + mjm: mujoco.MjModel, + nworld: int = 1, + nconmax: int = -1, + njmax: int = -1, + device: Optional[wp.context.Device] = None, ) -> types.Data: - d = types.Data() - d.nworld = nworld - - # TODO(team): move to Model? - if nconmax == -1: - # TODO(team): heuristic for nconmax - nconmax = nworld * 20 - d.nconmax = nconmax - if njmax == -1: - # TODO(team): heuristic for njmax - njmax = nworld * 20 * 6 - d.njmax = njmax - - d.ncon = wp.zeros(1, dtype=wp.int32) - d.ne = wp.zeros(1, dtype=wp.int32, ndim=1) - d.ne_connect = wp.zeros(1, dtype=wp.int32, ndim=1) - d.ne_weld = wp.zeros(1, dtype=wp.int32, ndim=1) - d.ne_jnt = wp.zeros(1, dtype=wp.int32, ndim=1) - d.nefc = wp.zeros(1, dtype=wp.int32, ndim=1) - d.ne = wp.zeros(1, dtype=wp.int32) - d.nf = wp.zeros(1, dtype=wp.int32) - d.nl = wp.zeros(1, dtype=wp.int32) - - d.time = wp.zeros(nworld, dtype=wp.float32, ndim=1) - - qpos0 = np.tile(mjm.qpos0, (nworld, 1)) - d.qpos = wp.array(qpos0, dtype=wp.float32, ndim=2) - d.qvel = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) - d.qacc_warmstart = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) - d.qfrc_applied = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) - d.mocap_pos = wp.zeros((nworld, mjm.nmocap), dtype=wp.vec3) - d.mocap_quat = wp.zeros((nworld, mjm.nmocap), dtype=wp.quat) - d.qacc = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.xanchor = wp.zeros((nworld, mjm.njnt), dtype=wp.vec3) - d.xaxis = wp.zeros((nworld, mjm.njnt), dtype=wp.vec3) - d.xmat = wp.zeros((nworld, mjm.nbody), dtype=wp.mat33) - d.xpos = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) - d.xquat = wp.zeros((nworld, mjm.nbody), dtype=wp.quat) - d.xipos = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) - d.ximat = wp.zeros((nworld, mjm.nbody), dtype=wp.mat33) - d.subtree_com = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) - d.geom_xpos = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec3) - d.geom_xmat = wp.zeros((nworld, mjm.ngeom), dtype=wp.mat33) - d.site_xpos = wp.zeros((nworld, mjm.nsite), dtype=wp.vec3) - d.site_xmat = wp.zeros((nworld, mjm.nsite), dtype=wp.mat33) - d.cam_xpos = wp.zeros((nworld, mjm.ncam), dtype=wp.vec3) - d.cam_xmat = wp.zeros((nworld, mjm.ncam), dtype=wp.mat33) - d.light_xpos = wp.zeros((nworld, mjm.nlight), dtype=wp.vec3) - d.light_xdir = wp.zeros((nworld, mjm.nlight), dtype=wp.vec3) - d.cinert = wp.zeros((nworld, mjm.nbody), dtype=types.vec10) - d.cdof = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector) - d.ctrl = wp.zeros((nworld, mjm.nu), dtype=wp.float32) - d.ten_velocity = wp.zeros((nworld, mjm.ntendon), dtype=wp.float32) - d.actuator_velocity = wp.zeros((nworld, mjm.nu), dtype=wp.float32) - d.actuator_force = wp.zeros((nworld, mjm.nu), dtype=wp.float32) - d.actuator_length = wp.zeros((nworld, mjm.nu), dtype=wp.float32) - d.actuator_moment = wp.zeros((nworld, mjm.nu, mjm.nv), dtype=wp.float32) - d.crb = wp.zeros((nworld, mjm.nbody), dtype=types.vec10) - if support.is_sparse(mjm): - d.qM = wp.zeros((nworld, 1, mjm.nM), dtype=wp.float32) - d.qLD = wp.zeros((nworld, 1, mjm.nM), dtype=wp.float32) - else: - d.qM = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=wp.float32) - d.qLD = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=wp.float32) - d.act_dot = wp.zeros((nworld, mjm.na), dtype=wp.float32) - d.act = wp.zeros((nworld, mjm.na), dtype=wp.float32) - d.qLDiagInv = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.cvel = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) - d.cdof_dot = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector) - d.qfrc_bias = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.contact = types.Contact() - d.contact.dist = wp.zeros((nconmax,), dtype=wp.float32) - d.contact.pos = wp.zeros((nconmax,), dtype=wp.vec3f) - d.contact.frame = wp.zeros((nconmax,), dtype=wp.mat33f) - d.contact.includemargin = wp.zeros((nconmax,), dtype=wp.float32) - d.contact.friction = wp.zeros((nconmax,), dtype=types.vec5) - d.contact.solref = wp.zeros((nconmax,), dtype=wp.vec2f) - d.contact.solreffriction = wp.zeros((nconmax,), dtype=wp.vec2f) - d.contact.solimp = wp.zeros((nconmax,), dtype=types.vec5) - d.contact.dim = wp.zeros((nconmax,), dtype=wp.int32) - d.contact.geom = wp.zeros((nconmax,), dtype=wp.vec2i) - d.contact.efc_address = wp.zeros((nconmax, np.max(mjm.geom_condim)), dtype=wp.int32) - d.contact.worldid = wp.zeros((nconmax,), dtype=wp.int32) - d.efc = _constraint(mjm, d.nworld, d.nconmax, d.njmax) - d.qfrc_passive = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.subtree_linvel = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) - d.subtree_angmom = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) - d.subtree_bodyvel = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) - d.qfrc_spring = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qfrc_damper = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qfrc_actuator = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qfrc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qfrc_constraint = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qacc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.xfrc_applied = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) - d.eq_active = wp.array(np.tile(mjm.eq_active0, (nworld, 1)), dtype=wp.bool, ndim=2) - - # internal tmp arrays - d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qM_integration = wp.zeros_like(d.qM) - d.qLD_integration = wp.zeros_like(d.qLD) - d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv) - d.act_vel_integration = wp.zeros_like(d.ctrl) - d.qpos_t0 = wp.zeros((nworld, mjm.nq), dtype=wp.float32) - d.qvel_t0 = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.act_t0 = wp.zeros((nworld, mjm.na), dtype=wp.float32) - d.qvel_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qacc_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.act_dot_rk = wp.zeros((nworld, mjm.na), dtype=wp.float32) - - # sweep-and-prune broadphase - d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) - d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) - d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) - d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) - d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) - segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)] - d.sap_segment_index = wp.array(segment_indices_list, dtype=int) - - # collision driver - d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1) - d.collision_pairid = wp.empty(nconmax, dtype=wp.int32, ndim=1) - d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1) - d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) - - # rne_postconstraint - d.cacc = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) - d.cfrc_int = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) - d.cfrc_ext = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) - - # tendon - d.ten_length = wp.zeros((nworld, mjm.ntendon), dtype=wp.float32, ndim=2) - d.ten_J = wp.zeros((nworld, mjm.ntendon, mjm.nv), dtype=wp.float32, ndim=3) - d.ten_wrapadr = wp.zeros((nworld, mjm.ntendon), dtype=wp.int32, ndim=2) - d.ten_wrapnum = wp.zeros((nworld, mjm.ntendon), dtype=wp.int32, ndim=2) - d.wrap_obj = wp.zeros((nworld, mjm.nwrap), dtype=wp.vec2i, ndim=2) - d.wrap_xpos = wp.zeros( - (nworld, mjm.nwrap), dtype=wp.spatial_vector, ndim=2 - ) # TODO(team): vec6? - - # sensors - d.sensordata = wp.zeros((nworld, mjm.nsensordata), dtype=wp.float32) - - return d + with wp.ScopedDevice(device): + d = types.Data() + d.nworld = nworld + + # TODO(team): move to Model? + if nconmax == -1: + # TODO(team): heuristic for nconmax + nconmax = nworld * 20 + d.nconmax = nconmax + if njmax == -1: + # TODO(team): heuristic for njmax + njmax = nworld * 20 * 6 + d.njmax = njmax + + d.ncon = wp.zeros(1, dtype=wp.int32) + d.ne = wp.zeros(1, dtype=wp.int32, ndim=1) + d.ne_connect = wp.zeros(1, dtype=wp.int32, ndim=1) + d.ne_weld = wp.zeros(1, dtype=wp.int32, ndim=1) + d.ne_jnt = wp.zeros(1, dtype=wp.int32, ndim=1) + d.nefc = wp.zeros(1, dtype=wp.int32, ndim=1) + d.ne = wp.zeros(1, dtype=wp.int32) + d.nf = wp.zeros(1, dtype=wp.int32) + d.nl = wp.zeros(1, dtype=wp.int32) + + d.time = wp.zeros(nworld, dtype=wp.float32, ndim=1) + + qpos0 = np.tile(mjm.qpos0, (nworld, 1)) + d.qpos = wp.array(qpos0, dtype=wp.float32, ndim=2) + d.qvel = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) + d.qacc_warmstart = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) + d.qfrc_applied = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) + d.mocap_pos = wp.zeros((nworld, mjm.nmocap), dtype=wp.vec3) + d.mocap_quat = wp.zeros((nworld, mjm.nmocap), dtype=wp.quat) + d.qacc = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.xanchor = wp.zeros((nworld, mjm.njnt), dtype=wp.vec3) + d.xaxis = wp.zeros((nworld, mjm.njnt), dtype=wp.vec3) + d.xmat = wp.zeros((nworld, mjm.nbody), dtype=wp.mat33) + d.xpos = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) + d.xquat = wp.zeros((nworld, mjm.nbody), dtype=wp.quat) + d.xipos = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) + d.ximat = wp.zeros((nworld, mjm.nbody), dtype=wp.mat33) + d.subtree_com = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) + d.geom_xpos = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec3) + d.geom_xmat = wp.zeros((nworld, mjm.ngeom), dtype=wp.mat33) + d.site_xpos = wp.zeros((nworld, mjm.nsite), dtype=wp.vec3) + d.site_xmat = wp.zeros((nworld, mjm.nsite), dtype=wp.mat33) + d.cam_xpos = wp.zeros((nworld, mjm.ncam), dtype=wp.vec3) + d.cam_xmat = wp.zeros((nworld, mjm.ncam), dtype=wp.mat33) + d.light_xpos = wp.zeros((nworld, mjm.nlight), dtype=wp.vec3) + d.light_xdir = wp.zeros((nworld, mjm.nlight), dtype=wp.vec3) + d.cinert = wp.zeros((nworld, mjm.nbody), dtype=types.vec10) + d.cdof = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector) + d.ctrl = wp.zeros((nworld, mjm.nu), dtype=wp.float32) + d.ten_velocity = wp.zeros((nworld, mjm.ntendon), dtype=wp.float32) + d.actuator_velocity = wp.zeros((nworld, mjm.nu), dtype=wp.float32) + d.actuator_force = wp.zeros((nworld, mjm.nu), dtype=wp.float32) + d.actuator_length = wp.zeros((nworld, mjm.nu), dtype=wp.float32) + d.actuator_moment = wp.zeros((nworld, mjm.nu, mjm.nv), dtype=wp.float32) + d.crb = wp.zeros((nworld, mjm.nbody), dtype=types.vec10) + if support.is_sparse(mjm): + d.qM = wp.zeros((nworld, 1, mjm.nM), dtype=wp.float32) + d.qLD = wp.zeros((nworld, 1, mjm.nM), dtype=wp.float32) + else: + d.qM = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=wp.float32) + d.qLD = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=wp.float32) + d.act_dot = wp.zeros((nworld, mjm.na), dtype=wp.float32) + d.act = wp.zeros((nworld, mjm.na), dtype=wp.float32) + d.qLDiagInv = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.cvel = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) + d.cdof_dot = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector) + d.qfrc_bias = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.contact = types.Contact() + d.contact.dist = wp.zeros((nconmax,), dtype=wp.float32) + d.contact.pos = wp.zeros((nconmax,), dtype=wp.vec3f) + d.contact.frame = wp.zeros((nconmax,), dtype=wp.mat33f) + d.contact.includemargin = wp.zeros((nconmax,), dtype=wp.float32) + d.contact.friction = wp.zeros((nconmax,), dtype=types.vec5) + d.contact.solref = wp.zeros((nconmax,), dtype=wp.vec2f) + d.contact.solreffriction = wp.zeros((nconmax,), dtype=wp.vec2f) + d.contact.solimp = wp.zeros((nconmax,), dtype=types.vec5) + d.contact.dim = wp.zeros((nconmax,), dtype=wp.int32) + d.contact.geom = wp.zeros((nconmax,), dtype=wp.vec2i) + d.contact.efc_address = wp.zeros((nconmax, np.max(mjm.geom_condim)), dtype=wp.int32) + d.contact.worldid = wp.zeros((nconmax,), dtype=wp.int32) + d.efc = _constraint(mjm, d.nworld, d.nconmax, d.njmax) + d.qfrc_passive = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.subtree_linvel = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) + d.subtree_angmom = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) + d.subtree_bodyvel = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) + d.qfrc_spring = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qfrc_damper = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qfrc_actuator = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qfrc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qfrc_constraint = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qacc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.xfrc_applied = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) + d.eq_active = wp.array(np.tile(mjm.eq_active0, (nworld, 1)), dtype=wp.bool, ndim=2) + + # internal tmp arrays + d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qM_integration = wp.zeros_like(d.qM) + d.qLD_integration = wp.zeros_like(d.qLD) + d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv) + d.act_vel_integration = wp.zeros_like(d.ctrl) + d.qpos_t0 = wp.zeros((nworld, mjm.nq), dtype=wp.float32) + d.qvel_t0 = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.act_t0 = wp.zeros((nworld, mjm.na), dtype=wp.float32) + d.qvel_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qacc_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.act_dot_rk = wp.zeros((nworld, mjm.na), dtype=wp.float32) + + # sweep-and-prune broadphase + d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) + d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) + d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) + d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) + d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) + segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)] + d.sap_segment_index = wp.array(segment_indices_list, dtype=int) + + # collision driver + d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1) + d.collision_pairid = wp.empty(nconmax, dtype=wp.int32, ndim=1) + d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1) + d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) + + # rne_postconstraint + d.cacc = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) + d.cfrc_int = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) + d.cfrc_ext = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) + + # tendon + d.ten_length = wp.zeros((nworld, mjm.ntendon), dtype=wp.float32, ndim=2) + d.ten_J = wp.zeros((nworld, mjm.ntendon, mjm.nv), dtype=wp.float32, ndim=3) + d.ten_wrapadr = wp.zeros((nworld, mjm.ntendon), dtype=wp.int32, ndim=2) + d.ten_wrapnum = wp.zeros((nworld, mjm.ntendon), dtype=wp.int32, ndim=2) + d.wrap_obj = wp.zeros((nworld, mjm.nwrap), dtype=wp.vec2i, ndim=2) + d.wrap_xpos = wp.zeros( + (nworld, mjm.nwrap), dtype=wp.spatial_vector, ndim=2 + ) # TODO(team): vec6? + + # sensors + d.sensordata = wp.zeros((nworld, mjm.nsensordata), dtype=wp.float32) + + return d def put_data( @@ -830,302 +838,312 @@ def put_data( nworld: Optional[int] = None, nconmax: Optional[int] = None, njmax: Optional[int] = None, + device: Optional[wp.context.Device] = None, ) -> types.Data: - d = types.Data() + with wp.ScopedDevice(device): + d = types.Data() - # TODO(team): confirm that Data is set correctly for solver with elliptic friction cones + # TODO(team): confirm that Data is set correctly for solver with elliptic friction cones - nworld = nworld or 1 - # TODO(team): better heuristic for nconmax - nconmax = nconmax or max(512, mjd.ncon * nworld) - # TODO(team): better heuristic for njmax - njmax = njmax or max(512, mjd.nefc * nworld) + nworld = nworld or 1 + # TODO(team): better heuristic for nconmax + nconmax = nconmax or max(512, mjd.ncon * nworld) + # TODO(team): better heuristic for njmax + njmax = njmax or max(512, mjd.nefc * nworld) - if nworld < 1: - raise ValueError("nworld must be >= 1") + if nworld < 1: + raise ValueError("nworld must be >= 1") - if nconmax < 1: - raise ValueError("nconmax must be >= 1") + if nconmax < 1: + raise ValueError("nconmax must be >= 1") - if njmax < 1: - raise ValueError("njmax must be >= 1") + if njmax < 1: + raise ValueError("njmax must be >= 1") - if nworld * mjd.ncon > nconmax: - raise ValueError(f"nconmax overflow (nconmax must be >= {nworld * mjd.ncon})") + if nworld * mjd.ncon > nconmax: + raise ValueError(f"nconmax overflow (nconmax must be >= {nworld * mjd.ncon})") - if nworld * mjd.nefc > njmax: - raise ValueError(f"njmax overflow (njmax must be >= {nworld * mjd.nefc})") + if nworld * mjd.nefc > njmax: + raise ValueError(f"njmax overflow (njmax must be >= {nworld * mjd.nefc})") - d.nworld = nworld - # TODO(team): move nconmax and njmax to Model? - d.nconmax = nconmax - d.njmax = njmax + d.nworld = nworld + # TODO(team): move nconmax and njmax to Model? + d.nconmax = nconmax + d.njmax = njmax - d.ncon = wp.array([mjd.ncon * nworld], dtype=wp.int32, ndim=1) - d.ne = wp.array([mjd.ne * nworld], dtype=wp.int32, ndim=1) - d.ne_connect = wp.array( - [3 * np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_CONNECT) & mjd.eq_active) * nworld], - dtype=wp.int32, - ndim=1, - ) - d.ne_weld = wp.array( - [6 * np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_WELD) & mjd.eq_active) * nworld], - dtype=wp.int32, - ndim=1, - ) - d.ne_jnt = wp.array( - [np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_JOINT) & mjd.eq_active) * nworld], - dtype=wp.int32, - ndim=1, - ) - d.nf = wp.array([mjd.nf * nworld], dtype=wp.int32, ndim=1) - d.nl = wp.array([mjd.nl * nworld], dtype=wp.int32, ndim=1) - d.nefc = wp.array([mjd.nefc * nworld], dtype=wp.int32, ndim=1) - - d.time = wp.array(mjd.time * np.ones(nworld), dtype=wp.float32, ndim=1) - - # TODO(erikfrey): would it be better to tile on the gpu? - def tile(x): - return np.tile(x, (nworld,) + (1,) * len(x.shape)) + d.ncon = wp.array([mjd.ncon * nworld], dtype=wp.int32, ndim=1) + d.ne = wp.array([mjd.ne * nworld], dtype=wp.int32, ndim=1) + d.ne_connect = wp.array( + [3 * np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_CONNECT) & mjd.eq_active) * nworld], + dtype=wp.int32, + ndim=1, + ) + d.ne_weld = wp.array( + [6 * np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_WELD) & mjd.eq_active) * nworld], + dtype=wp.int32, + ndim=1, + ) + d.ne_jnt = wp.array( + [np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_JOINT) & mjd.eq_active) * nworld], + dtype=wp.int32, + ndim=1, + ) + d.nf = wp.array([mjd.nf * nworld], dtype=wp.int32, ndim=1) + d.nl = wp.array([mjd.nl * nworld], dtype=wp.int32, ndim=1) + d.nefc = wp.array([mjd.nefc * nworld], dtype=wp.int32, ndim=1) + + d.time = wp.array(mjd.time * np.ones(nworld), dtype=wp.float32, ndim=1) + + # TODO(erikfrey): would it be better to tile on the gpu? + def tile(x): + return np.tile(x, (nworld,) + (1,) * len(x.shape)) + + if support.is_sparse(mjm): + qM = np.expand_dims(mjd.qM, axis=0) + qLD = np.expand_dims(mjd.qLD, axis=0) + efc_J = np.zeros((mjd.nefc, mjm.nv)) + mujoco.mju_sparse2dense( + efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind + ) + else: + qM = np.zeros((mjm.nv, mjm.nv)) + mujoco.mj_fullM(mjm, qM, mjd.qM) + qLD = np.linalg.cholesky(qM) + efc_J = mjd.efc_J.reshape((mjd.nefc, mjm.nv)) - if support.is_sparse(mjm): - qM = np.expand_dims(mjd.qM, axis=0) - qLD = np.expand_dims(mjd.qLD, axis=0) - efc_J = np.zeros((mjd.nefc, mjm.nv)) + # TODO(taylorhowell): sparse actuator_moment + actuator_moment = np.zeros((mjm.nu, mjm.nv)) mujoco.mju_sparse2dense( - efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind + actuator_moment, + mjd.actuator_moment, + mjd.moment_rownnz, + mjd.moment_rowadr, + mjd.moment_colind, ) - else: - qM = np.zeros((mjm.nv, mjm.nv)) - mujoco.mj_fullM(mjm, qM, mjd.qM) - qLD = np.linalg.cholesky(qM) - efc_J = mjd.efc_J.reshape((mjd.nefc, mjm.nv)) - - # TODO(taylorhowell): sparse actuator_moment - actuator_moment = np.zeros((mjm.nu, mjm.nv)) - mujoco.mju_sparse2dense( - actuator_moment, - mjd.actuator_moment, - mjd.moment_rownnz, - mjd.moment_rowadr, - mjd.moment_colind, - ) - d.qpos = wp.array(tile(mjd.qpos), dtype=wp.float32, ndim=2) - d.qvel = wp.array(tile(mjd.qvel), dtype=wp.float32, ndim=2) - d.qacc_warmstart = wp.array(tile(mjd.qacc_warmstart), dtype=wp.float32, ndim=2) - d.qfrc_applied = wp.array(tile(mjd.qfrc_applied), dtype=wp.float32, ndim=2) - d.mocap_pos = wp.array(tile(mjd.mocap_pos), dtype=wp.vec3, ndim=2) - d.mocap_quat = wp.array(tile(mjd.mocap_quat), dtype=wp.quat, ndim=2) - d.qacc = wp.array(tile(mjd.qacc), dtype=wp.float32, ndim=2) - d.xanchor = wp.array(tile(mjd.xanchor), dtype=wp.vec3, ndim=2) - d.xaxis = wp.array(tile(mjd.xaxis), dtype=wp.vec3, ndim=2) - d.xmat = wp.array(tile(mjd.xmat), dtype=wp.mat33, ndim=2) - d.xpos = wp.array(tile(mjd.xpos), dtype=wp.vec3, ndim=2) - d.xquat = wp.array(tile(mjd.xquat), dtype=wp.quat, ndim=2) - d.xipos = wp.array(tile(mjd.xipos), dtype=wp.vec3, ndim=2) - d.ximat = wp.array(tile(mjd.ximat), dtype=wp.mat33, ndim=2) - d.subtree_com = wp.array(tile(mjd.subtree_com), dtype=wp.vec3, ndim=2) - d.geom_xpos = wp.array(tile(mjd.geom_xpos), dtype=wp.vec3, ndim=2) - d.geom_xmat = wp.array(tile(mjd.geom_xmat), dtype=wp.mat33, ndim=2) - d.site_xpos = wp.array(tile(mjd.site_xpos), dtype=wp.vec3, ndim=2) - d.site_xmat = wp.array(tile(mjd.site_xmat), dtype=wp.mat33, ndim=2) - d.cam_xpos = wp.array(tile(mjd.cam_xpos), dtype=wp.vec3, ndim=2) - d.cam_xmat = wp.array(tile(mjd.cam_xmat.reshape(-1, 3, 3)), dtype=wp.mat33, ndim=2) - d.light_xpos = wp.array(tile(mjd.light_xpos), dtype=wp.vec3, ndim=2) - d.light_xdir = wp.array(tile(mjd.light_xdir), dtype=wp.vec3, ndim=2) - d.cinert = wp.array(tile(mjd.cinert), dtype=types.vec10, ndim=2) - d.cdof = wp.array(tile(mjd.cdof), dtype=wp.spatial_vector, ndim=2) - d.crb = wp.array(tile(mjd.crb), dtype=types.vec10, ndim=2) - d.qM = wp.array(tile(qM), dtype=wp.float32, ndim=3) - d.qLD = wp.array(tile(qLD), dtype=wp.float32, ndim=3) - d.qLDiagInv = wp.array(tile(mjd.qLDiagInv), dtype=wp.float32, ndim=2) - d.ctrl = wp.array(tile(mjd.ctrl), dtype=wp.float32, ndim=2) - d.ten_velocity = wp.array(tile(mjd.ten_velocity), dtype=wp.float32, ndim=2) - d.actuator_velocity = wp.array(tile(mjd.actuator_velocity), dtype=wp.float32, ndim=2) - d.actuator_force = wp.array(tile(mjd.actuator_force), dtype=wp.float32, ndim=2) - d.actuator_length = wp.array(tile(mjd.actuator_length), dtype=wp.float32, ndim=2) - d.actuator_moment = wp.array(tile(actuator_moment), dtype=wp.float32, ndim=3) - d.cvel = wp.array(tile(mjd.cvel), dtype=wp.spatial_vector, ndim=2) - d.cdof_dot = wp.array(tile(mjd.cdof_dot), dtype=wp.spatial_vector, ndim=2) - d.qfrc_bias = wp.array(tile(mjd.qfrc_bias), dtype=wp.float32, ndim=2) - d.qfrc_passive = wp.array(tile(mjd.qfrc_passive), dtype=wp.float32, ndim=2) - d.subtree_linvel = wp.array(tile(mjd.subtree_linvel), dtype=wp.vec3, ndim=2) - d.subtree_angmom = wp.array(tile(mjd.subtree_angmom), dtype=wp.vec3, ndim=2) - d.subtree_bodyvel = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) - d.qfrc_spring = wp.array(tile(mjd.qfrc_spring), dtype=wp.float32, ndim=2) - d.qfrc_damper = wp.array(tile(mjd.qfrc_damper), dtype=wp.float32, ndim=2) - d.qfrc_actuator = wp.array(tile(mjd.qfrc_actuator), dtype=wp.float32, ndim=2) - d.qfrc_smooth = wp.array(tile(mjd.qfrc_smooth), dtype=wp.float32, ndim=2) - d.qfrc_constraint = wp.array(tile(mjd.qfrc_constraint), dtype=wp.float32, ndim=2) - d.qacc_smooth = wp.array(tile(mjd.qacc_smooth), dtype=wp.float32, ndim=2) - d.act = wp.array(tile(mjd.act), dtype=wp.float32, ndim=2) - d.act_dot = wp.array(tile(mjd.act_dot), dtype=wp.float32, ndim=2) - - nefc = mjd.nefc - efc_worldid = np.zeros(njmax, dtype=int) - - for i in range(nworld): - efc_worldid[i * nefc : (i + 1) * nefc] = i - - nefc_fill = njmax - nworld * nefc - - efc_J_fill = np.vstack( - [np.repeat(efc_J, nworld, axis=0), np.zeros((nefc_fill, mjm.nv))] - ) - efc_D_fill = np.concatenate( - [np.repeat(mjd.efc_D, nworld, axis=0), np.zeros(nefc_fill)] - ) - efc_pos_fill = np.concatenate( - [np.repeat(mjd.efc_pos, nworld, axis=0), np.zeros(nefc_fill)] - ) - efc_aref_fill = np.concatenate( - [np.repeat(mjd.efc_aref, nworld, axis=0), np.zeros(nefc_fill)] - ) - efc_frictionloss_fill = np.concatenate( - [np.repeat(mjd.efc_frictionloss, nworld, axis=0), np.zeros(nefc_fill)] - ) - efc_force_fill = np.concatenate( - [np.repeat(mjd.efc_force, nworld, axis=0), np.zeros(nefc_fill)] - ) - efc_margin_fill = np.concatenate( - [np.repeat(mjd.efc_margin, nworld, axis=0), np.zeros(nefc_fill)] - ) - efc_id_fill = np.concatenate( - [np.repeat(mjd.efc_id, nworld, axis=0), np.zeros(nefc_fill)] - ) + d.qpos = wp.array(tile(mjd.qpos), dtype=wp.float32, ndim=2) + d.qvel = wp.array(tile(mjd.qvel), dtype=wp.float32, ndim=2) + d.qacc_warmstart = wp.array(tile(mjd.qacc_warmstart), dtype=wp.float32, ndim=2) + d.qfrc_applied = wp.array(tile(mjd.qfrc_applied), dtype=wp.float32, ndim=2) + d.mocap_pos = wp.array(tile(mjd.mocap_pos), dtype=wp.vec3, ndim=2) + d.mocap_quat = wp.array(tile(mjd.mocap_quat), dtype=wp.quat, ndim=2) + d.qacc = wp.array(tile(mjd.qacc), dtype=wp.float32, ndim=2) + d.xanchor = wp.array(tile(mjd.xanchor), dtype=wp.vec3, ndim=2) + d.xaxis = wp.array(tile(mjd.xaxis), dtype=wp.vec3, ndim=2) + d.xmat = wp.array(tile(mjd.xmat), dtype=wp.mat33, ndim=2) + d.xpos = wp.array(tile(mjd.xpos), dtype=wp.vec3, ndim=2) + d.xquat = wp.array(tile(mjd.xquat), dtype=wp.quat, ndim=2) + d.xipos = wp.array(tile(mjd.xipos), dtype=wp.vec3, ndim=2) + d.ximat = wp.array(tile(mjd.ximat), dtype=wp.mat33, ndim=2) + d.subtree_com = wp.array(tile(mjd.subtree_com), dtype=wp.vec3, ndim=2) + d.geom_xpos = wp.array(tile(mjd.geom_xpos), dtype=wp.vec3, ndim=2) + d.geom_xmat = wp.array(tile(mjd.geom_xmat), dtype=wp.mat33, ndim=2) + d.site_xpos = wp.array(tile(mjd.site_xpos), dtype=wp.vec3, ndim=2) + d.site_xmat = wp.array(tile(mjd.site_xmat), dtype=wp.mat33, ndim=2) + d.cam_xpos = wp.array(tile(mjd.cam_xpos), dtype=wp.vec3, ndim=2) + d.cam_xmat = wp.array(tile(mjd.cam_xmat.reshape(-1, 3, 3)), dtype=wp.mat33, ndim=2) + d.light_xpos = wp.array(tile(mjd.light_xpos), dtype=wp.vec3, ndim=2) + d.light_xdir = wp.array(tile(mjd.light_xdir), dtype=wp.vec3, ndim=2) + d.cinert = wp.array(tile(mjd.cinert), dtype=types.vec10, ndim=2) + d.cdof = wp.array(tile(mjd.cdof), dtype=wp.spatial_vector, ndim=2) + d.crb = wp.array(tile(mjd.crb), dtype=types.vec10, ndim=2) + d.qM = wp.array(tile(qM), dtype=wp.float32, ndim=3) + d.qLD = wp.array(tile(qLD), dtype=wp.float32, ndim=3) + d.qLDiagInv = wp.array(tile(mjd.qLDiagInv), dtype=wp.float32, ndim=2) + d.ctrl = wp.array(tile(mjd.ctrl), dtype=wp.float32, ndim=2) + d.ten_velocity = wp.array(tile(mjd.ten_velocity), dtype=wp.float32, ndim=2) + d.actuator_velocity = wp.array( + tile(mjd.actuator_velocity), dtype=wp.float32, ndim=2 + ) + d.actuator_force = wp.array(tile(mjd.actuator_force), dtype=wp.float32, ndim=2) + d.actuator_length = wp.array(tile(mjd.actuator_length), dtype=wp.float32, ndim=2) + d.actuator_moment = wp.array(tile(actuator_moment), dtype=wp.float32, ndim=3) + d.cvel = wp.array(tile(mjd.cvel), dtype=wp.spatial_vector, ndim=2) + d.cdof_dot = wp.array(tile(mjd.cdof_dot), dtype=wp.spatial_vector, ndim=2) + d.qfrc_bias = wp.array(tile(mjd.qfrc_bias), dtype=wp.float32, ndim=2) + d.qfrc_passive = wp.array(tile(mjd.qfrc_passive), dtype=wp.float32, ndim=2) + d.subtree_linvel = wp.array(tile(mjd.subtree_linvel), dtype=wp.vec3, ndim=2) + d.subtree_angmom = wp.array(tile(mjd.subtree_angmom), dtype=wp.vec3, ndim=2) + d.subtree_bodyvel = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) + d.qfrc_spring = wp.array(tile(mjd.qfrc_spring), dtype=wp.float32, ndim=2) + d.qfrc_damper = wp.array(tile(mjd.qfrc_damper), dtype=wp.float32, ndim=2) + d.qfrc_actuator = wp.array(tile(mjd.qfrc_actuator), dtype=wp.float32, ndim=2) + d.qfrc_smooth = wp.array(tile(mjd.qfrc_smooth), dtype=wp.float32, ndim=2) + d.qfrc_constraint = wp.array(tile(mjd.qfrc_constraint), dtype=wp.float32, ndim=2) + d.qacc_smooth = wp.array(tile(mjd.qacc_smooth), dtype=wp.float32, ndim=2) + d.act = wp.array(tile(mjd.act), dtype=wp.float32, ndim=2) + d.act_dot = wp.array(tile(mjd.act_dot), dtype=wp.float32, ndim=2) + + nefc = mjd.nefc + efc_worldid = np.zeros(njmax, dtype=int) + + for i in range(nworld): + efc_worldid[i * nefc : (i + 1) * nefc] = i + + nefc_fill = njmax - nworld * nefc + + efc_J_fill = np.vstack( + [np.repeat(efc_J, nworld, axis=0), np.zeros((nefc_fill, mjm.nv))] + ) + efc_D_fill = np.concatenate( + [np.repeat(mjd.efc_D, nworld, axis=0), np.zeros(nefc_fill)] + ) + efc_pos_fill = np.concatenate( + [np.repeat(mjd.efc_pos, nworld, axis=0), np.zeros(nefc_fill)] + ) + efc_aref_fill = np.concatenate( + [np.repeat(mjd.efc_aref, nworld, axis=0), np.zeros(nefc_fill)] + ) + efc_frictionloss_fill = np.concatenate( + [np.repeat(mjd.efc_frictionloss, nworld, axis=0), np.zeros(nefc_fill)] + ) + efc_force_fill = np.concatenate( + [np.repeat(mjd.efc_force, nworld, axis=0), np.zeros(nefc_fill)] + ) + efc_margin_fill = np.concatenate( + [np.repeat(mjd.efc_margin, nworld, axis=0), np.zeros(nefc_fill)] + ) + efc_id_fill = np.concatenate( + [np.repeat(mjd.efc_id, nworld, axis=0), np.zeros(nefc_fill)] + ) - ncon = mjd.ncon - condim_max = np.max(mjm.geom_condim) - con_efc_address = np.zeros((nconmax, condim_max), dtype=int) - for i in range(nworld): - for j in range(ncon): - condim = mjd.contact.dim[j] - for k in range(condim): - con_efc_address[i * ncon + j, k] = mjd.nefc * i + mjd.contact.efc_address[j] + k + ncon = mjd.ncon + condim_max = np.max(mjm.geom_condim) + con_efc_address = np.zeros((nconmax, condim_max), dtype=int) + for i in range(nworld): + for j in range(ncon): + condim = mjd.contact.dim[j] + for k in range(condim): + con_efc_address[i * ncon + j, k] = ( + mjd.nefc * i + mjd.contact.efc_address[j] + k + ) + + con_worldid = np.zeros(nconmax, dtype=int) + for i in range(nworld): + con_worldid[i * ncon : (i + 1) * ncon] = i + + ncon_fill = nconmax - nworld * ncon + + con_dist_fill = np.concatenate( + [np.repeat(mjd.contact.dist, nworld, axis=0), np.zeros(ncon_fill)] + ) + con_pos_fill = np.vstack( + [np.repeat(mjd.contact.pos, nworld, axis=0), np.zeros((ncon_fill, 3))] + ) + con_frame_fill = np.vstack( + [np.repeat(mjd.contact.frame, nworld, axis=0), np.zeros((ncon_fill, 9))] + ) + con_includemargin_fill = np.concatenate( + [np.repeat(mjd.contact.includemargin, nworld, axis=0), np.zeros(ncon_fill)] + ) + con_friction_fill = np.vstack( + [np.repeat(mjd.contact.friction, nworld, axis=0), np.zeros((ncon_fill, 5))] + ) + con_solref_fill = np.vstack( + [np.repeat(mjd.contact.solref, nworld, axis=0), np.zeros((ncon_fill, 2))] + ) + con_solreffriction_fill = np.vstack( + [np.repeat(mjd.contact.solreffriction, nworld, axis=0), np.zeros((ncon_fill, 2))] + ) + con_solimp_fill = np.vstack( + [np.repeat(mjd.contact.solimp, nworld, axis=0), np.zeros((ncon_fill, 5))] + ) + con_dim_fill = np.concatenate( + [np.repeat(mjd.contact.dim, nworld, axis=0), np.zeros(ncon_fill)] + ) + con_geom_fill = np.vstack( + [np.repeat(mjd.contact.geom, nworld, axis=0), np.zeros((ncon_fill, 2))] + ) + con_efc_address_fill = np.vstack( + [con_efc_address, np.zeros((ncon_fill, condim_max))] + ) - con_worldid = np.zeros(nconmax, dtype=int) - for i in range(nworld): - con_worldid[i * ncon : (i + 1) * ncon] = i + d.contact.dist = wp.array(con_dist_fill, dtype=wp.float32, ndim=1) + d.contact.pos = wp.array(con_pos_fill, dtype=wp.vec3f, ndim=1) + d.contact.frame = wp.array(con_frame_fill, dtype=wp.mat33f, ndim=1) + d.contact.includemargin = wp.array(con_includemargin_fill, dtype=wp.float32, ndim=1) + d.contact.friction = wp.array(con_friction_fill, dtype=types.vec5, ndim=1) + d.contact.solref = wp.array(con_solref_fill, dtype=wp.vec2f, ndim=1) + d.contact.solreffriction = wp.array(con_solreffriction_fill, dtype=wp.vec2f, ndim=1) + d.contact.solimp = wp.array(con_solimp_fill, dtype=types.vec5, ndim=1) + d.contact.dim = wp.array(con_dim_fill, dtype=wp.int32, ndim=1) + d.contact.geom = wp.array(con_geom_fill, dtype=wp.vec2i, ndim=1) + d.contact.efc_address = wp.array(con_efc_address_fill, dtype=wp.int32, ndim=2) + d.contact.worldid = wp.array(con_worldid, dtype=wp.int32, ndim=1) + + d.efc = _constraint(mjm, d.nworld, d.nconmax, d.njmax) + d.efc.J = wp.array(efc_J_fill, dtype=wp.float32, ndim=2) + d.efc.D = wp.array(efc_D_fill, dtype=wp.float32, ndim=1) + d.efc.pos = wp.array(efc_pos_fill, dtype=wp.float32, ndim=1) + d.efc.aref = wp.array(efc_aref_fill, dtype=wp.float32, ndim=1) + d.efc.frictionloss = wp.array(efc_frictionloss_fill, dtype=wp.float32, ndim=1) + d.efc.force = wp.array(efc_force_fill, dtype=wp.float32, ndim=1) + d.efc.margin = wp.array(efc_margin_fill, dtype=wp.float32, ndim=1) + d.efc.worldid = wp.from_numpy(efc_worldid, dtype=wp.int32) + d.efc.id = wp.from_numpy(efc_id_fill, dtype=wp.int32) + + d.xfrc_applied = wp.array(tile(mjd.xfrc_applied), dtype=wp.spatial_vector, ndim=2) + d.eq_active = wp.array(tile(mjm.eq_active0), dtype=wp.bool, ndim=2) + + # internal tmp arrays + d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qM_integration = wp.zeros_like(d.qM) + d.qLD_integration = wp.zeros_like(d.qLD) + d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv) + d.act_vel_integration = wp.zeros_like(d.ctrl) + d.qpos_t0 = wp.zeros((nworld, mjm.nq), dtype=wp.float32) + d.qvel_t0 = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.act_t0 = wp.zeros((nworld, mjm.na), dtype=wp.float32) + d.qvel_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qacc_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.act_dot_rk = wp.zeros((nworld, mjm.na), dtype=wp.float32) + + # broadphase sweep and prune + d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) + d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) + d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) + d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) + d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) + d.sap_segment_index = wp.array( + [i * mjm.ngeom for i in range(nworld + 1)], dtype=int + ) - ncon_fill = nconmax - nworld * ncon + # collision driver + d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1) + d.collision_pairid = wp.empty(nconmax, dtype=wp.int32, ndim=1) + d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1) + d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) - con_dist_fill = np.concatenate( - [np.repeat(mjd.contact.dist, nworld, axis=0), np.zeros(ncon_fill)] - ) - con_pos_fill = np.vstack( - [np.repeat(mjd.contact.pos, nworld, axis=0), np.zeros((ncon_fill, 3))] - ) - con_frame_fill = np.vstack( - [np.repeat(mjd.contact.frame, nworld, axis=0), np.zeros((ncon_fill, 9))] - ) - con_includemargin_fill = np.concatenate( - [np.repeat(mjd.contact.includemargin, nworld, axis=0), np.zeros(ncon_fill)] - ) - con_friction_fill = np.vstack( - [np.repeat(mjd.contact.friction, nworld, axis=0), np.zeros((ncon_fill, 5))] - ) - con_solref_fill = np.vstack( - [np.repeat(mjd.contact.solref, nworld, axis=0), np.zeros((ncon_fill, 2))] - ) - con_solreffriction_fill = np.vstack( - [np.repeat(mjd.contact.solreffriction, nworld, axis=0), np.zeros((ncon_fill, 2))] - ) - con_solimp_fill = np.vstack( - [np.repeat(mjd.contact.solimp, nworld, axis=0), np.zeros((ncon_fill, 5))] - ) - con_dim_fill = np.concatenate( - [np.repeat(mjd.contact.dim, nworld, axis=0), np.zeros(ncon_fill)] - ) - con_geom_fill = np.vstack( - [np.repeat(mjd.contact.geom, nworld, axis=0), np.zeros((ncon_fill, 2))] - ) - con_efc_address_fill = np.vstack([con_efc_address, np.zeros((ncon_fill, condim_max))]) - - d.contact.dist = wp.array(con_dist_fill, dtype=wp.float32, ndim=1) - d.contact.pos = wp.array(con_pos_fill, dtype=wp.vec3f, ndim=1) - d.contact.frame = wp.array(con_frame_fill, dtype=wp.mat33f, ndim=1) - d.contact.includemargin = wp.array(con_includemargin_fill, dtype=wp.float32, ndim=1) - d.contact.friction = wp.array(con_friction_fill, dtype=types.vec5, ndim=1) - d.contact.solref = wp.array(con_solref_fill, dtype=wp.vec2f, ndim=1) - d.contact.solreffriction = wp.array(con_solreffriction_fill, dtype=wp.vec2f, ndim=1) - d.contact.solimp = wp.array(con_solimp_fill, dtype=types.vec5, ndim=1) - d.contact.dim = wp.array(con_dim_fill, dtype=wp.int32, ndim=1) - d.contact.geom = wp.array(con_geom_fill, dtype=wp.vec2i, ndim=1) - d.contact.efc_address = wp.array(con_efc_address_fill, dtype=wp.int32, ndim=2) - d.contact.worldid = wp.array(con_worldid, dtype=wp.int32, ndim=1) - - d.efc = _constraint(mjm, d.nworld, d.nconmax, d.njmax) - d.efc.J = wp.array(efc_J_fill, dtype=wp.float32, ndim=2) - d.efc.D = wp.array(efc_D_fill, dtype=wp.float32, ndim=1) - d.efc.pos = wp.array(efc_pos_fill, dtype=wp.float32, ndim=1) - d.efc.aref = wp.array(efc_aref_fill, dtype=wp.float32, ndim=1) - d.efc.frictionloss = wp.array(efc_frictionloss_fill, dtype=wp.float32, ndim=1) - d.efc.force = wp.array(efc_force_fill, dtype=wp.float32, ndim=1) - d.efc.margin = wp.array(efc_margin_fill, dtype=wp.float32, ndim=1) - d.efc.worldid = wp.from_numpy(efc_worldid, dtype=wp.int32) - d.efc.id = wp.from_numpy(efc_id_fill, dtype=wp.int32) - - d.xfrc_applied = wp.array(tile(mjd.xfrc_applied), dtype=wp.spatial_vector, ndim=2) - d.eq_active = wp.array(tile(mjm.eq_active0), dtype=wp.bool, ndim=2) - - # internal tmp arrays - d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qM_integration = wp.zeros_like(d.qM) - d.qLD_integration = wp.zeros_like(d.qLD) - d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv) - d.act_vel_integration = wp.zeros_like(d.ctrl) - d.qpos_t0 = wp.zeros((nworld, mjm.nq), dtype=wp.float32) - d.qvel_t0 = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.act_t0 = wp.zeros((nworld, mjm.na), dtype=wp.float32) - d.qvel_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qacc_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.act_dot_rk = wp.zeros((nworld, mjm.na), dtype=wp.float32) - - # broadphase sweep and prune - d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) - d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) - d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) - d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) - d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) - d.sap_segment_index = wp.array([i * mjm.ngeom for i in range(nworld + 1)], dtype=int) - - # collision driver - d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1) - d.collision_pairid = wp.empty(nconmax, dtype=wp.int32, ndim=1) - d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1) - d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) - - # rne_postconstraint - d.cacc = wp.array(tile(mjd.cacc), dtype=wp.spatial_vector, ndim=2) - d.cfrc_int = wp.array(tile(mjd.cfrc_int), dtype=wp.spatial_vector, ndim=2) - d.cfrc_ext = wp.array(tile(mjd.cfrc_ext), dtype=wp.spatial_vector, ndim=2) + # rne_postconstraint + d.cacc = wp.array(tile(mjd.cacc), dtype=wp.spatial_vector, ndim=2) + d.cfrc_int = wp.array(tile(mjd.cfrc_int), dtype=wp.spatial_vector, ndim=2) + d.cfrc_ext = wp.array(tile(mjd.cfrc_ext), dtype=wp.spatial_vector, ndim=2) - # tendon - d.ten_length = wp.array(tile(mjd.ten_length), dtype=wp.float32, ndim=2) + # tendon + d.ten_length = wp.array(tile(mjd.ten_length), dtype=wp.float32, ndim=2) - if support.is_sparse(mjm) and mjm.ntendon: - ten_J = np.zeros((mjm.ntendon, mjm.nv)) - mujoco.mju_sparse2dense( - ten_J, mjd.ten_J, mjd.ten_J_rownnz, mjd.ten_J_rowadr, mjd.ten_J_colind - ) - else: - ten_J = mjd.ten_J.reshape((mjm.ntendon, mjm.nv)) + if support.is_sparse(mjm) and mjm.ntendon: + ten_J = np.zeros((mjm.ntendon, mjm.nv)) + mujoco.mju_sparse2dense( + ten_J, mjd.ten_J, mjd.ten_J_rownnz, mjd.ten_J_rowadr, mjd.ten_J_colind + ) + else: + ten_J = mjd.ten_J.reshape((mjm.ntendon, mjm.nv)) - d.ten_J = wp.array(tile(ten_J), dtype=wp.float32, ndim=3) + d.ten_J = wp.array(tile(ten_J), dtype=wp.float32, ndim=3) - d.ten_wrapadr = wp.array(tile(mjd.ten_wrapadr), dtype=wp.int32, ndim=2) - d.ten_wrapnum = wp.array(tile(mjd.ten_wrapnum), dtype=wp.int32, ndim=2) - d.wrap_obj = wp.array(tile(mjd.wrap_obj), dtype=wp.vec2i, ndim=2) - d.wrap_xpos = wp.array( - tile(mjd.wrap_xpos), dtype=wp.spatial_vector, ndim=2 - ) # TODO(team): vec6? + d.ten_wrapadr = wp.array(tile(mjd.ten_wrapadr), dtype=wp.int32, ndim=2) + d.ten_wrapnum = wp.array(tile(mjd.ten_wrapnum), dtype=wp.int32, ndim=2) + d.wrap_obj = wp.array(tile(mjd.wrap_obj), dtype=wp.vec2i, ndim=2) + d.wrap_xpos = wp.array( + tile(mjd.wrap_xpos), dtype=wp.spatial_vector, ndim=2 + ) # TODO(team): vec6? - # sensors - d.sensordata = wp.array(tile(mjd.sensordata), dtype=wp.float32, ndim=2) + # sensors + d.sensordata = wp.array(tile(mjd.sensordata), dtype=wp.float32, ndim=2) - return d + return d def get_data_into( diff --git a/mujoco_warp/_src/io_test.py b/mujoco_warp/_src/io_test.py index c54f8bc7..4c953c95 100644 --- a/mujoco_warp/_src/io_test.py +++ b/mujoco_warp/_src/io_test.py @@ -15,6 +15,9 @@ """Tests for io functions.""" +import inspect +import textwrap + import mujoco import numpy as np import warp as wp @@ -290,6 +293,94 @@ def test_option_physical_constants(self): with self.assertRaises(NotImplementedError): mjwarp.put_model(mjm) + def test_scoped_device_in_public_api(self): + """Checks that all public API functions use wp.ScopedDevice.""" + mjwarp_funcs = [] + for name, obj in inspect.getmembers(mjwarp): + if inspect.isfunction(obj): + # Check if the function is defined within mujoco_warp._src + try: + module_name = inspect.getmodule(obj).__name__ + if module_name.startswith("mujoco_warp._src"): + mjwarp_funcs.append(obj) + except AttributeError: + # Some objects might not have a __module__ attribute + pass + + missing_scoped_device = [] + # List of functions to exclude from the check + excluded_functions = {"geom_pair", "get_data_into", "is_sparse"} + + for func in mjwarp_funcs: + # Skip excluded functions + if func.__name__ in excluded_functions: + continue + + try: + source = inspect.getsource(func) + # Dedent the source code to handle indentation + source = textwrap.dedent(source) + # Remove decorators to check the actual function body start + lines = source.splitlines() + line_index = 0 + + # Skip decorators + while line_index < len(lines) and lines[line_index].strip().startswith("@"): + line_index += 1 + + # Skip the function definition line itself (might span multiple lines) + while line_index < len(lines) and lines[line_index].strip().startswith("def "): + line_index += 1 + # Handle potential multi-line defs until the line ends with ":" + while line_index < len(lines) and not lines[line_index - 1].strip().endswith( + ":" + ): + line_index += 1 + + in_docstring = False + first_code_line = None + while line_index < len(lines): + line = lines[line_index].strip() + + # Handle docstrings + if line.startswith('"""') or line.startswith("'''"): + quote_type = line[:3] + # Started or ended docstring + in_docstring = not in_docstring + # If it starts and ends on the same line, reset state + if in_docstring and len(line) > 3 and line.endswith(quote_type): + in_docstring = False + # If it just ended, continue to next line + elif not in_docstring: + line_index += 1 + continue + + elif not in_docstring and line and not line.startswith("#"): + # Found the first non-comment, non-docstring, non-decorator line + first_code_line = line + break + + # If in_docstring is True, or line is empty/comment, continue + line_index += 1 + + # Strip leading whitespace from the found line before checking + if first_code_line is not None: + first_code_line = first_code_line.lstrip() + + if first_code_line is None or not first_code_line.startswith( + "with wp.ScopedDevice" + ): + missing_scoped_device.append(func.__name__) + except (TypeError, OSError) as e: + # Built-in functions or dynamically generated ones might not have source + # Or file might not be found, skip these cases. + pass + + self.assertEmpty( + missing_scoped_device, + f"Functions missing 'with wp.ScopedDevice': {', '.join(missing_scoped_device)}", + ) + if __name__ == "__main__": wp.init() diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index c8cbc30b..a80b39ab 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -27,82 +27,84 @@ @event_scope def passive(m: Model, d: Data): """Adds all passive forces.""" - if m.opt.disableflags & DisableBit.PASSIVE: - d.qfrc_passive.zero_() - # TODO(team): qfrc_gravcomp - return - @kernel - def _spring(m: Model, d: Data): - worldid, jntid = wp.tid() - stiffness = m.jnt_stiffness[jntid] - dofid = m.jnt_dofadr[jntid] - - if stiffness == 0.0: + with wp.ScopedDevice(m.qpos0.device): + if m.opt.disableflags & DisableBit.PASSIVE: + d.qfrc_passive.zero_() + # TODO(team): qfrc_gravcomp return - jnt_type = m.jnt_type[jntid] - qposid = m.jnt_qposadr[jntid] + @kernel + def _spring(m: Model, d: Data): + worldid, jntid = wp.tid() + stiffness = m.jnt_stiffness[jntid] + dofid = m.jnt_dofadr[jntid] + + if stiffness == 0.0: + return + + jnt_type = m.jnt_type[jntid] + qposid = m.jnt_qposadr[jntid] - if jnt_type == wp.static(JointType.FREE.value): - dif = wp.vec3( - d.qpos[worldid, qposid + 0] - m.qpos_spring[qposid + 0], - d.qpos[worldid, qposid + 1] - m.qpos_spring[qposid + 1], - d.qpos[worldid, qposid + 2] - m.qpos_spring[qposid + 2], - ) - d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] - d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] - d.qfrc_spring[worldid, dofid + 2] = -stiffness * dif[2] - rot = wp.quat( - d.qpos[worldid, qposid + 3], - d.qpos[worldid, qposid + 4], - d.qpos[worldid, qposid + 5], - d.qpos[worldid, qposid + 6], - ) - ref = wp.quat( - m.qpos_spring[qposid + 3], - m.qpos_spring[qposid + 4], - m.qpos_spring[qposid + 5], - m.qpos_spring[qposid + 6], - ) - dif = math.quat_sub(rot, ref) - d.qfrc_spring[worldid, dofid + 3] = -stiffness * dif[0] - d.qfrc_spring[worldid, dofid + 4] = -stiffness * dif[1] - d.qfrc_spring[worldid, dofid + 5] = -stiffness * dif[2] - elif jnt_type == wp.static(JointType.BALL.value): - rot = wp.quat( - d.qpos[worldid, qposid + 0], - d.qpos[worldid, qposid + 1], - d.qpos[worldid, qposid + 2], - d.qpos[worldid, qposid + 3], - ) - ref = wp.quat( - m.qpos_spring[qposid + 0], - m.qpos_spring[qposid + 1], - m.qpos_spring[qposid + 2], - m.qpos_spring[qposid + 3], - ) - dif = math.quat_sub(rot, ref) - d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] - d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] - d.qfrc_spring[worldid, dofid + 2] = -stiffness * dif[2] - else: # mjJNT_SLIDE, mjJNT_HINGE - fdif = d.qpos[worldid, qposid] - m.qpos_spring[qposid] - d.qfrc_spring[worldid, dofid] = -stiffness * fdif + if jnt_type == wp.static(JointType.FREE.value): + dif = wp.vec3( + d.qpos[worldid, qposid + 0] - m.qpos_spring[qposid + 0], + d.qpos[worldid, qposid + 1] - m.qpos_spring[qposid + 1], + d.qpos[worldid, qposid + 2] - m.qpos_spring[qposid + 2], + ) + d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] + d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] + d.qfrc_spring[worldid, dofid + 2] = -stiffness * dif[2] + rot = wp.quat( + d.qpos[worldid, qposid + 3], + d.qpos[worldid, qposid + 4], + d.qpos[worldid, qposid + 5], + d.qpos[worldid, qposid + 6], + ) + ref = wp.quat( + m.qpos_spring[qposid + 3], + m.qpos_spring[qposid + 4], + m.qpos_spring[qposid + 5], + m.qpos_spring[qposid + 6], + ) + dif = math.quat_sub(rot, ref) + d.qfrc_spring[worldid, dofid + 3] = -stiffness * dif[0] + d.qfrc_spring[worldid, dofid + 4] = -stiffness * dif[1] + d.qfrc_spring[worldid, dofid + 5] = -stiffness * dif[2] + elif jnt_type == wp.static(JointType.BALL.value): + rot = wp.quat( + d.qpos[worldid, qposid + 0], + d.qpos[worldid, qposid + 1], + d.qpos[worldid, qposid + 2], + d.qpos[worldid, qposid + 3], + ) + ref = wp.quat( + m.qpos_spring[qposid + 0], + m.qpos_spring[qposid + 1], + m.qpos_spring[qposid + 2], + m.qpos_spring[qposid + 3], + ) + dif = math.quat_sub(rot, ref) + d.qfrc_spring[worldid, dofid + 0] = -stiffness * dif[0] + d.qfrc_spring[worldid, dofid + 1] = -stiffness * dif[1] + d.qfrc_spring[worldid, dofid + 2] = -stiffness * dif[2] + else: # mjJNT_SLIDE, mjJNT_HINGE + fdif = d.qpos[worldid, qposid] - m.qpos_spring[qposid] + d.qfrc_spring[worldid, dofid] = -stiffness * fdif - @kernel - def _damper_passive(m: Model, d: Data): - worldid, dofid = wp.tid() - damping = m.dof_damping[dofid] - qfrc_damper = -damping * d.qvel[worldid, dofid] + @kernel + def _damper_passive(m: Model, d: Data): + worldid, dofid = wp.tid() + damping = m.dof_damping[dofid] + qfrc_damper = -damping * d.qvel[worldid, dofid] - d.qfrc_damper[worldid, dofid] = qfrc_damper - d.qfrc_passive[worldid, dofid] = qfrc_damper + d.qfrc_spring[worldid, dofid] + d.qfrc_damper[worldid, dofid] = qfrc_damper + d.qfrc_passive[worldid, dofid] = qfrc_damper + d.qfrc_spring[worldid, dofid] - # TODO(team): mj_gravcomp - # TODO(team): mj_ellipsoidFluidModel - # TODO(team): mj_inertiaBoxFluidModell + # TODO(team): mj_gravcomp + # TODO(team): mj_ellipsoidFluidModel + # TODO(team): mj_inertiaBoxFluidModell - d.qfrc_spring.zero_() - wp.launch(_spring, dim=(d.nworld, m.njnt), inputs=[m, d]) - wp.launch(_damper_passive, dim=(d.nworld, m.nv), inputs=[m, d]) + d.qfrc_spring.zero_() + wp.launch(_spring, dim=(d.nworld, m.njnt), inputs=[m, d]) + wp.launch(_damper_passive, dim=(d.nworld, m.nv), inputs=[m, d]) diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index 3ebc18e6..6b2564e5 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -176,71 +176,73 @@ def _clock(m: Model, d: Data, worldid: int) -> wp.float32: def sensor_pos(m: Model, d: Data): """Compute position-dependent sensor values.""" - @kernel - def _sensor_pos(m: Model, d: Data): - worldid, posid = wp.tid() - posadr = m.sensor_pos_adr[posid] - sensortype = m.sensor_type[posadr] - objid = m.sensor_objid[posadr] - adr = m.sensor_adr[posadr] - - if sensortype == int(SensorType.JOINTPOS.value): - d.sensordata[worldid, adr] = _joint_pos(m, d, worldid, objid) - elif sensortype == int(SensorType.TENDONPOS.value): - d.sensordata[worldid, adr] = _tendon_pos(m, d, worldid, objid) - elif sensortype == int(SensorType.ACTUATORPOS.value): - d.sensordata[worldid, adr] = _actuator_length(m, d, worldid, objid) - elif sensortype == int(SensorType.BALLQUAT.value): - quat = _ball_quat(m, d, worldid, objid) - d.sensordata[worldid, adr + 0] = quat[0] - d.sensordata[worldid, adr + 1] = quat[1] - d.sensordata[worldid, adr + 2] = quat[2] - d.sensordata[worldid, adr + 3] = quat[3] - elif sensortype == int(SensorType.FRAMEPOS.value): - objtype = m.sensor_objtype[posadr] - refid = m.sensor_refid[posadr] - framepos = _frame_pos(m, d, worldid, objid, objtype, refid) - d.sensordata[worldid, adr + 0] = framepos[0] - d.sensordata[worldid, adr + 1] = framepos[1] - d.sensordata[worldid, adr + 2] = framepos[2] - elif ( - sensortype == int(SensorType.FRAMEXAXIS.value) - or sensortype == int(SensorType.FRAMEYAXIS.value) - or sensortype == int(SensorType.FRAMEZAXIS.value) - ): - objtype = m.sensor_objtype[posadr] - refid = m.sensor_refid[posadr] - if sensortype == int(SensorType.FRAMEXAXIS.value): - axis = 0 - elif sensortype == int(SensorType.FRAMEYAXIS.value): - axis = 1 - elif sensortype == int(SensorType.FRAMEZAXIS.value): - axis = 2 - frameaxis = _frame_axis(m, d, worldid, objid, objtype, refid, axis) - d.sensordata[worldid, adr + 0] = frameaxis[0] - d.sensordata[worldid, adr + 1] = frameaxis[1] - d.sensordata[worldid, adr + 2] = frameaxis[2] - elif sensortype == int(SensorType.FRAMEQUAT.value): - objtype = m.sensor_objtype[posadr] - refid = m.sensor_refid[posadr] - quat = _frame_quat(m, d, worldid, objid, objtype, refid) - d.sensordata[worldid, adr + 0] = quat[0] - d.sensordata[worldid, adr + 1] = quat[1] - d.sensordata[worldid, adr + 2] = quat[2] - d.sensordata[worldid, adr + 3] = quat[3] - elif sensortype == int(SensorType.SUBTREECOM.value): - subtree_com = _subtree_com(m, d, worldid, objid) - d.sensordata[worldid, adr + 0] = subtree_com[0] - d.sensordata[worldid, adr + 1] = subtree_com[1] - d.sensordata[worldid, adr + 2] = subtree_com[2] - elif sensortype == int(SensorType.CLOCK.value): - clock = _clock(m, d, worldid) - d.sensordata[worldid, adr] = clock - - if (m.sensor_pos_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): - return - - wp.launch(_sensor_pos, dim=(d.nworld, m.sensor_pos_adr.size), inputs=[m, d]) + with wp.ScopedDevice(m.qpos0.device): + + @kernel + def _sensor_pos(m: Model, d: Data): + worldid, posid = wp.tid() + posadr = m.sensor_pos_adr[posid] + sensortype = m.sensor_type[posadr] + objid = m.sensor_objid[posadr] + adr = m.sensor_adr[posadr] + + if sensortype == int(SensorType.JOINTPOS.value): + d.sensordata[worldid, adr] = _joint_pos(m, d, worldid, objid) + elif sensortype == int(SensorType.TENDONPOS.value): + d.sensordata[worldid, adr] = _tendon_pos(m, d, worldid, objid) + elif sensortype == int(SensorType.ACTUATORPOS.value): + d.sensordata[worldid, adr] = _actuator_length(m, d, worldid, objid) + elif sensortype == int(SensorType.BALLQUAT.value): + quat = _ball_quat(m, d, worldid, objid) + d.sensordata[worldid, adr + 0] = quat[0] + d.sensordata[worldid, adr + 1] = quat[1] + d.sensordata[worldid, adr + 2] = quat[2] + d.sensordata[worldid, adr + 3] = quat[3] + elif sensortype == int(SensorType.FRAMEPOS.value): + objtype = m.sensor_objtype[posadr] + refid = m.sensor_refid[posadr] + framepos = _frame_pos(m, d, worldid, objid, objtype, refid) + d.sensordata[worldid, adr + 0] = framepos[0] + d.sensordata[worldid, adr + 1] = framepos[1] + d.sensordata[worldid, adr + 2] = framepos[2] + elif ( + sensortype == int(SensorType.FRAMEXAXIS.value) + or sensortype == int(SensorType.FRAMEYAXIS.value) + or sensortype == int(SensorType.FRAMEZAXIS.value) + ): + objtype = m.sensor_objtype[posadr] + refid = m.sensor_refid[posadr] + if sensortype == int(SensorType.FRAMEXAXIS.value): + axis = 0 + elif sensortype == int(SensorType.FRAMEYAXIS.value): + axis = 1 + elif sensortype == int(SensorType.FRAMEZAXIS.value): + axis = 2 + frameaxis = _frame_axis(m, d, worldid, objid, objtype, refid, axis) + d.sensordata[worldid, adr + 0] = frameaxis[0] + d.sensordata[worldid, adr + 1] = frameaxis[1] + d.sensordata[worldid, adr + 2] = frameaxis[2] + elif sensortype == int(SensorType.FRAMEQUAT.value): + objtype = m.sensor_objtype[posadr] + refid = m.sensor_refid[posadr] + quat = _frame_quat(m, d, worldid, objid, objtype, refid) + d.sensordata[worldid, adr + 0] = quat[0] + d.sensordata[worldid, adr + 1] = quat[1] + d.sensordata[worldid, adr + 2] = quat[2] + d.sensordata[worldid, adr + 3] = quat[3] + elif sensortype == int(SensorType.SUBTREECOM.value): + subtree_com = _subtree_com(m, d, worldid, objid) + d.sensordata[worldid, adr + 0] = subtree_com[0] + d.sensordata[worldid, adr + 1] = subtree_com[1] + d.sensordata[worldid, adr + 2] = subtree_com[2] + elif sensortype == int(SensorType.CLOCK.value): + clock = _clock(m, d, worldid) + d.sensordata[worldid, adr] = clock + + if (m.sensor_pos_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): + return + + wp.launch(_sensor_pos, dim=(d.nworld, m.sensor_pos_adr.size), inputs=[m, d]) @wp.func @@ -304,57 +306,59 @@ def _subtree_angmom(m: Model, d: Data, worldid: int, objid: int) -> wp.vec3: def sensor_vel(m: Model, d: Data): """Compute velocity-dependent sensor values.""" - @kernel - def _sensor_vel(m: Model, d: Data): - worldid, velid = wp.tid() - veladr = m.sensor_vel_adr[velid] - sensortype = m.sensor_type[veladr] - objid = m.sensor_objid[veladr] - adr = m.sensor_adr[veladr] - - if sensortype == int(SensorType.VELOCIMETER.value): - vel = _velocimeter(m, d, worldid, objid) - d.sensordata[worldid, adr + 0] = vel[0] - d.sensordata[worldid, adr + 1] = vel[1] - d.sensordata[worldid, adr + 2] = vel[2] - elif sensortype == int(SensorType.GYRO.value): - gyro = _gyro(m, d, worldid, objid) - d.sensordata[worldid, adr + 0] = gyro[0] - d.sensordata[worldid, adr + 1] = gyro[1] - d.sensordata[worldid, adr + 2] = gyro[2] - elif sensortype == int(SensorType.JOINTVEL.value): - d.sensordata[worldid, adr] = _joint_vel(m, d, worldid, objid) - elif sensortype == int(SensorType.TENDONVEL.value): - d.sensordata[worldid, adr] = _tendon_vel(m, d, worldid, objid) - elif sensortype == int(SensorType.ACTUATORVEL.value): - d.sensordata[worldid, adr] = _actuator_vel(m, d, worldid, objid) - elif sensortype == int(SensorType.BALLANGVEL.value): - angvel = _ball_ang_vel(m, d, worldid, objid) - d.sensordata[worldid, adr + 0] = angvel[0] - d.sensordata[worldid, adr + 1] = angvel[1] - d.sensordata[worldid, adr + 2] = angvel[2] - elif sensortype == int(SensorType.SUBTREELINVEL.value): - subtree_linvel = _subtree_linvel(m, d, worldid, objid) - d.sensordata[worldid, adr + 0] = subtree_linvel[0] - d.sensordata[worldid, adr + 1] = subtree_linvel[1] - d.sensordata[worldid, adr + 2] = subtree_linvel[2] - elif sensortype == int(SensorType.SUBTREEANGMOM.value): - subtree_angmom = _subtree_angmom(m, d, worldid, objid) - d.sensordata[worldid, adr + 0] = subtree_angmom[0] - d.sensordata[worldid, adr + 1] = subtree_angmom[1] - d.sensordata[worldid, adr + 2] = subtree_angmom[2] - - if (m.sensor_vel_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): - return - - if wp.static( - np.isin( - m.sensor_type.numpy(), [SensorType.SUBTREELINVEL, SensorType.SUBTREEANGMOM] - ).any() - ): - smooth.subtree_vel(m, d) - - wp.launch(_sensor_vel, dim=(d.nworld, m.sensor_vel_adr.size), inputs=[m, d]) + with wp.ScopedDevice(m.qpos0.device): + + @kernel + def _sensor_vel(m: Model, d: Data): + worldid, velid = wp.tid() + veladr = m.sensor_vel_adr[velid] + sensortype = m.sensor_type[veladr] + objid = m.sensor_objid[veladr] + adr = m.sensor_adr[veladr] + + if sensortype == int(SensorType.VELOCIMETER.value): + vel = _velocimeter(m, d, worldid, objid) + d.sensordata[worldid, adr + 0] = vel[0] + d.sensordata[worldid, adr + 1] = vel[1] + d.sensordata[worldid, adr + 2] = vel[2] + elif sensortype == int(SensorType.GYRO.value): + gyro = _gyro(m, d, worldid, objid) + d.sensordata[worldid, adr + 0] = gyro[0] + d.sensordata[worldid, adr + 1] = gyro[1] + d.sensordata[worldid, adr + 2] = gyro[2] + elif sensortype == int(SensorType.JOINTVEL.value): + d.sensordata[worldid, adr] = _joint_vel(m, d, worldid, objid) + elif sensortype == int(SensorType.TENDONVEL.value): + d.sensordata[worldid, adr] = _tendon_vel(m, d, worldid, objid) + elif sensortype == int(SensorType.ACTUATORVEL.value): + d.sensordata[worldid, adr] = _actuator_vel(m, d, worldid, objid) + elif sensortype == int(SensorType.BALLANGVEL.value): + angvel = _ball_ang_vel(m, d, worldid, objid) + d.sensordata[worldid, adr + 0] = angvel[0] + d.sensordata[worldid, adr + 1] = angvel[1] + d.sensordata[worldid, adr + 2] = angvel[2] + elif sensortype == int(SensorType.SUBTREELINVEL.value): + subtree_linvel = _subtree_linvel(m, d, worldid, objid) + d.sensordata[worldid, adr + 0] = subtree_linvel[0] + d.sensordata[worldid, adr + 1] = subtree_linvel[1] + d.sensordata[worldid, adr + 2] = subtree_linvel[2] + elif sensortype == int(SensorType.SUBTREEANGMOM.value): + subtree_angmom = _subtree_angmom(m, d, worldid, objid) + d.sensordata[worldid, adr + 0] = subtree_angmom[0] + d.sensordata[worldid, adr + 1] = subtree_angmom[1] + d.sensordata[worldid, adr + 2] = subtree_angmom[2] + + if (m.sensor_vel_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): + return + + if wp.static( + np.isin( + m.sensor_type.numpy(), [SensorType.SUBTREELINVEL, SensorType.SUBTREEANGMOM] + ).any() + ): + smooth.subtree_vel(m, d) + + wp.launch(_sensor_vel, dim=(d.nworld, m.sensor_vel_adr.size), inputs=[m, d]) @wp.func @@ -454,55 +458,57 @@ def _frameangacc(m: Model, d: Data, worldid: int, objid: int, objtype: int) -> w def sensor_acc(m: Model, d: Data): """Compute acceleration-dependent sensor values.""" - @kernel - def _sensor_acc(m: Model, d: Data): - worldid, accid = wp.tid() - accadr = m.sensor_acc_adr[accid] - sensortype = m.sensor_type[accadr] - objid = m.sensor_objid[accadr] - adr = m.sensor_adr[accadr] - - if sensortype == int(SensorType.ACCELEROMETER.value): - accelerometer = _accelerometer(m, d, worldid, objid) - d.sensordata[worldid, adr + 0] = accelerometer[0] - d.sensordata[worldid, adr + 1] = accelerometer[1] - d.sensordata[worldid, adr + 2] = accelerometer[2] - elif sensortype == int(SensorType.FORCE.value): - force = _force(m, d, worldid, objid) - d.sensordata[worldid, adr + 0] = force[0] - d.sensordata[worldid, adr + 1] = force[1] - d.sensordata[worldid, adr + 2] = force[2] - elif sensortype == int(SensorType.TORQUE.value): - torque = _torque(m, d, worldid, objid) - d.sensordata[worldid, adr + 0] = torque[0] - d.sensordata[worldid, adr + 1] = torque[1] - d.sensordata[worldid, adr + 2] = torque[2] - elif sensortype == int(SensorType.ACTUATORFRC.value): - d.sensordata[worldid, adr] = _actuator_force(m, d, worldid, objid) - elif sensortype == int(SensorType.JOINTACTFRC.value): - d.sensordata[worldid, adr] = _joint_actuator_force(m, d, worldid, objid) - elif sensortype == int(SensorType.FRAMELINACC.value): - objtype = m.sensor_objtype[accadr] - framelinacc = _framelinacc(m, d, worldid, objid, objtype) - d.sensordata[worldid, adr + 0] = framelinacc[0] - d.sensordata[worldid, adr + 1] = framelinacc[1] - d.sensordata[worldid, adr + 2] = framelinacc[2] - elif sensortype == int(SensorType.FRAMEANGACC.value): - objtype = m.sensor_objtype[accadr] - frameangacc = _frameangacc(m, d, worldid, objid, objtype) - d.sensordata[worldid, adr + 0] = frameangacc[0] - d.sensordata[worldid, adr + 1] = frameangacc[1] - d.sensordata[worldid, adr + 2] = frameangacc[2] - - if (m.sensor_acc_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): - return - - if wp.static( - np.isin( - m.sensor_type.numpy(), - [SensorType.ACCELEROMETER, SensorType.FORCE, SensorType.TORQUE], - ).any() - ): - smooth.rne_postconstraint(m, d) - - wp.launch(_sensor_acc, dim=(d.nworld, m.sensor_acc_adr.size), inputs=[m, d]) + with wp.ScopedDevice(m.qpos0.device): + + @kernel + def _sensor_acc(m: Model, d: Data): + worldid, accid = wp.tid() + accadr = m.sensor_acc_adr[accid] + sensortype = m.sensor_type[accadr] + objid = m.sensor_objid[accadr] + adr = m.sensor_adr[accadr] + + if sensortype == int(SensorType.ACCELEROMETER.value): + accelerometer = _accelerometer(m, d, worldid, objid) + d.sensordata[worldid, adr + 0] = accelerometer[0] + d.sensordata[worldid, adr + 1] = accelerometer[1] + d.sensordata[worldid, adr + 2] = accelerometer[2] + elif sensortype == int(SensorType.FORCE.value): + force = _force(m, d, worldid, objid) + d.sensordata[worldid, adr + 0] = force[0] + d.sensordata[worldid, adr + 1] = force[1] + d.sensordata[worldid, adr + 2] = force[2] + elif sensortype == int(SensorType.TORQUE.value): + torque = _torque(m, d, worldid, objid) + d.sensordata[worldid, adr + 0] = torque[0] + d.sensordata[worldid, adr + 1] = torque[1] + d.sensordata[worldid, adr + 2] = torque[2] + elif sensortype == int(SensorType.ACTUATORFRC.value): + d.sensordata[worldid, adr] = _actuator_force(m, d, worldid, objid) + elif sensortype == int(SensorType.JOINTACTFRC.value): + d.sensordata[worldid, adr] = _joint_actuator_force(m, d, worldid, objid) + elif sensortype == int(SensorType.FRAMELINACC.value): + objtype = m.sensor_objtype[accadr] + framelinacc = _framelinacc(m, d, worldid, objid, objtype) + d.sensordata[worldid, adr + 0] = framelinacc[0] + d.sensordata[worldid, adr + 1] = framelinacc[1] + d.sensordata[worldid, adr + 2] = framelinacc[2] + elif sensortype == int(SensorType.FRAMEANGACC.value): + objtype = m.sensor_objtype[accadr] + frameangacc = _frameangacc(m, d, worldid, objid, objtype) + d.sensordata[worldid, adr + 0] = frameangacc[0] + d.sensordata[worldid, adr + 1] = frameangacc[1] + d.sensordata[worldid, adr + 2] = frameangacc[2] + + if (m.sensor_acc_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): + return + + if wp.static( + np.isin( + m.sensor_type.numpy(), + [SensorType.ACCELEROMETER, SensorType.FORCE, SensorType.TORQUE], + ).any() + ): + smooth.rne_postconstraint(m, d) + + wp.launch(_sensor_acc, dim=(d.nworld, m.sensor_acc_adr.size), inputs=[m, d]) diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index bc6a0ae6..80ad6ec1 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -41,381 +41,388 @@ def kinematics(m: Model, d: Data): """Forward kinematics.""" - @kernel - def _root(m: Model, d: Data): - worldid = wp.tid() - d.xpos[worldid, 0] = wp.vec3(0.0) - d.xquat[worldid, 0] = wp.quat(1.0, 0.0, 0.0, 0.0) - d.xipos[worldid, 0] = wp.vec3(0.0) - d.xmat[worldid, 0] = wp.identity(n=3, dtype=wp.float32) - d.ximat[worldid, 0] = wp.identity(n=3, dtype=wp.float32) + with wp.ScopedDevice(m.qpos0.device): - @kernel - def _level(m: Model, d: Data, leveladr: int): - worldid, nodeid = wp.tid() - bodyid = m.body_tree[leveladr + nodeid] - jntadr = m.body_jntadr[bodyid] - jntnum = m.body_jntnum[bodyid] - qpos = d.qpos[worldid] + @kernel + def _root(m: Model, d: Data): + worldid = wp.tid() + d.xpos[worldid, 0] = wp.vec3(0.0) + d.xquat[worldid, 0] = wp.quat(1.0, 0.0, 0.0, 0.0) + d.xipos[worldid, 0] = wp.vec3(0.0) + d.xmat[worldid, 0] = wp.identity(n=3, dtype=wp.float32) + d.ximat[worldid, 0] = wp.identity(n=3, dtype=wp.float32) - if jntnum == 0: - # no joints - apply fixed translation and rotation relative to parent - pid = m.body_parentid[bodyid] - xpos = (d.xmat[worldid, pid] * m.body_pos[bodyid]) + d.xpos[worldid, pid] - xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[bodyid]) - elif jntnum == 1 and m.jnt_type[jntadr] == wp.static(JointType.FREE.value): - # free joint - qadr = m.jnt_qposadr[jntadr] - xpos = wp.vec3(qpos[qadr], qpos[qadr + 1], qpos[qadr + 2]) - xquat = wp.quat(qpos[qadr + 3], qpos[qadr + 4], qpos[qadr + 5], qpos[qadr + 6]) - d.xanchor[worldid, jntadr] = xpos - d.xaxis[worldid, jntadr] = m.jnt_axis[jntadr] - else: - # regular or no joints - # apply fixed translation and rotation relative to parent - pid = m.body_parentid[bodyid] - xpos = (d.xmat[worldid, pid] * m.body_pos[bodyid]) + d.xpos[worldid, pid] - xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[bodyid]) + @kernel + def _level(m: Model, d: Data, leveladr: int): + worldid, nodeid = wp.tid() + bodyid = m.body_tree[leveladr + nodeid] + jntadr = m.body_jntadr[bodyid] + jntnum = m.body_jntnum[bodyid] + qpos = d.qpos[worldid] - for _ in range(jntnum): + if jntnum == 0: + # no joints - apply fixed translation and rotation relative to parent + pid = m.body_parentid[bodyid] + xpos = (d.xmat[worldid, pid] * m.body_pos[bodyid]) + d.xpos[worldid, pid] + xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[bodyid]) + elif jntnum == 1 and m.jnt_type[jntadr] == wp.static(JointType.FREE.value): + # free joint qadr = m.jnt_qposadr[jntadr] - jnt_type = m.jnt_type[jntadr] - jnt_axis = m.jnt_axis[jntadr] - xanchor = math.rot_vec_quat(m.jnt_pos[jntadr], xquat) + xpos - xaxis = math.rot_vec_quat(jnt_axis, xquat) - - if jnt_type == wp.static(JointType.BALL.value): - qloc = wp.quat( - qpos[qadr + 0], - qpos[qadr + 1], - qpos[qadr + 2], - qpos[qadr + 3], - ) - xquat = math.mul_quat(xquat, qloc) - # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(m.jnt_pos[jntadr], xquat) - elif jnt_type == wp.static(JointType.SLIDE.value): - xpos += xaxis * (qpos[qadr] - m.qpos0[qadr]) - elif jnt_type == wp.static(JointType.HINGE.value): - qpos0 = m.qpos0[qadr] - qloc = math.axis_angle_to_quat(jnt_axis, qpos[qadr] - qpos0) - xquat = math.mul_quat(xquat, qloc) - # correct for off-center rotation - xpos = xanchor - math.rot_vec_quat(m.jnt_pos[jntadr], xquat) - - d.xanchor[worldid, jntadr] = xanchor - d.xaxis[worldid, jntadr] = xaxis - jntadr += 1 - - d.xpos[worldid, bodyid] = xpos - xquat = wp.normalize(xquat) - d.xquat[worldid, bodyid] = xquat - d.xmat[worldid, bodyid] = math.quat_to_mat(xquat) - d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat(m.body_ipos[bodyid], xquat) - d.ximat[worldid, bodyid] = math.quat_to_mat( - math.mul_quat(xquat, m.body_iquat[bodyid]) - ) + xpos = wp.vec3(qpos[qadr], qpos[qadr + 1], qpos[qadr + 2]) + xquat = wp.quat(qpos[qadr + 3], qpos[qadr + 4], qpos[qadr + 5], qpos[qadr + 6]) + d.xanchor[worldid, jntadr] = xpos + d.xaxis[worldid, jntadr] = m.jnt_axis[jntadr] + else: + # regular or no joints + # apply fixed translation and rotation relative to parent + pid = m.body_parentid[bodyid] + xpos = (d.xmat[worldid, pid] * m.body_pos[bodyid]) + d.xpos[worldid, pid] + xquat = math.mul_quat(d.xquat[worldid, pid], m.body_quat[bodyid]) + + for _ in range(jntnum): + qadr = m.jnt_qposadr[jntadr] + jnt_type = m.jnt_type[jntadr] + jnt_axis = m.jnt_axis[jntadr] + xanchor = math.rot_vec_quat(m.jnt_pos[jntadr], xquat) + xpos + xaxis = math.rot_vec_quat(jnt_axis, xquat) + + if jnt_type == wp.static(JointType.BALL.value): + qloc = wp.quat( + qpos[qadr + 0], + qpos[qadr + 1], + qpos[qadr + 2], + qpos[qadr + 3], + ) + xquat = math.mul_quat(xquat, qloc) + # correct for off-center rotation + xpos = xanchor - math.rot_vec_quat(m.jnt_pos[jntadr], xquat) + elif jnt_type == wp.static(JointType.SLIDE.value): + xpos += xaxis * (qpos[qadr] - m.qpos0[qadr]) + elif jnt_type == wp.static(JointType.HINGE.value): + qpos0 = m.qpos0[qadr] + qloc = math.axis_angle_to_quat(jnt_axis, qpos[qadr] - qpos0) + xquat = math.mul_quat(xquat, qloc) + # correct for off-center rotation + xpos = xanchor - math.rot_vec_quat(m.jnt_pos[jntadr], xquat) + + d.xanchor[worldid, jntadr] = xanchor + d.xaxis[worldid, jntadr] = xaxis + jntadr += 1 + + d.xpos[worldid, bodyid] = xpos + xquat = wp.normalize(xquat) + d.xquat[worldid, bodyid] = xquat + d.xmat[worldid, bodyid] = math.quat_to_mat(xquat) + d.xipos[worldid, bodyid] = xpos + math.rot_vec_quat(m.body_ipos[bodyid], xquat) + d.ximat[worldid, bodyid] = math.quat_to_mat( + math.mul_quat(xquat, m.body_iquat[bodyid]) + ) - @kernel - def geom_local_to_global(m: Model, d: Data): - worldid, geomid = wp.tid() - bodyid = m.geom_bodyid[geomid] - xpos = d.xpos[worldid, bodyid] - xquat = d.xquat[worldid, bodyid] - d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat(m.geom_pos[geomid], xquat) - d.geom_xmat[worldid, geomid] = math.quat_to_mat( - math.mul_quat(xquat, m.geom_quat[geomid]) - ) + @kernel + def geom_local_to_global(m: Model, d: Data): + worldid, geomid = wp.tid() + bodyid = m.geom_bodyid[geomid] + xpos = d.xpos[worldid, bodyid] + xquat = d.xquat[worldid, bodyid] + d.geom_xpos[worldid, geomid] = xpos + math.rot_vec_quat(m.geom_pos[geomid], xquat) + d.geom_xmat[worldid, geomid] = math.quat_to_mat( + math.mul_quat(xquat, m.geom_quat[geomid]) + ) - @kernel - def site_local_to_global(m: Model, d: Data): - worldid, siteid = wp.tid() - bodyid = m.site_bodyid[siteid] - xpos = d.xpos[worldid, bodyid] - xquat = d.xquat[worldid, bodyid] - d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat(m.site_pos[siteid], xquat) - d.site_xmat[worldid, siteid] = math.quat_to_mat( - math.mul_quat(xquat, m.site_quat[siteid]) - ) + @kernel + def site_local_to_global(m: Model, d: Data): + worldid, siteid = wp.tid() + bodyid = m.site_bodyid[siteid] + xpos = d.xpos[worldid, bodyid] + xquat = d.xquat[worldid, bodyid] + d.site_xpos[worldid, siteid] = xpos + math.rot_vec_quat(m.site_pos[siteid], xquat) + d.site_xmat[worldid, siteid] = math.quat_to_mat( + math.mul_quat(xquat, m.site_quat[siteid]) + ) - wp.launch(_root, dim=(d.nworld), inputs=[m, d]) + wp.launch(_root, dim=(d.nworld), inputs=[m, d]) - body_treeadr = m.body_treeadr.numpy() - for i in range(1, len(body_treeadr)): - beg = body_treeadr[i] - end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + body_treeadr = m.body_treeadr.numpy() + for i in range(1, len(body_treeadr)): + beg = body_treeadr[i] + end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] + wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) - if m.ngeom: - wp.launch(geom_local_to_global, dim=(d.nworld, m.ngeom), inputs=[m, d]) + if m.ngeom: + wp.launch(geom_local_to_global, dim=(d.nworld, m.ngeom), inputs=[m, d]) - if m.nsite: - wp.launch(site_local_to_global, dim=(d.nworld, m.nsite), inputs=[m, d]) + if m.nsite: + wp.launch(site_local_to_global, dim=(d.nworld, m.nsite), inputs=[m, d]) @event_scope def com_pos(m: Model, d: Data): """Map inertias and motion dofs to global frame centered at subtree-CoM.""" - @kernel - def subtree_com_init(m: Model, d: Data): - worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * m.body_mass[bodyid] - - @kernel - def subtree_com_acc(m: Model, d: Data, leveladr: int): - worldid, nodeid = wp.tid() - bodyid = m.body_tree[leveladr + nodeid] - pid = m.body_parentid[bodyid] - wp.atomic_add(d.subtree_com, worldid, pid, d.subtree_com[worldid, bodyid]) - - @kernel - def subtree_div(m: Model, d: Data): - worldid, bodyid = wp.tid() - d.subtree_com[worldid, bodyid] /= m.subtree_mass[bodyid] + with wp.ScopedDevice(m.qpos0.device): - @kernel - def cinert(m: Model, d: Data): - worldid, bodyid = wp.tid() - mat = d.ximat[worldid, bodyid] - inert = m.body_inertia[bodyid] - mass = m.body_mass[bodyid] - dif = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]] - # express inertia in com-based frame (mju_inertCom) - - res = vec10() - # res_rot = mat * diag(inert) * mat' - tmp = mat @ wp.diag(inert) @ wp.transpose(mat) - res[0] = tmp[0, 0] - res[1] = tmp[1, 1] - res[2] = tmp[2, 2] - res[3] = tmp[0, 1] - res[4] = tmp[0, 2] - res[5] = tmp[1, 2] - # res_rot -= mass * dif_cross * dif_cross - res[0] += mass * (dif[1] * dif[1] + dif[2] * dif[2]) - res[1] += mass * (dif[0] * dif[0] + dif[2] * dif[2]) - res[2] += mass * (dif[0] * dif[0] + dif[1] * dif[1]) - res[3] -= mass * dif[0] * dif[1] - res[4] -= mass * dif[0] * dif[2] - res[5] -= mass * dif[1] * dif[2] - # res_tran = mass * dif - res[6] = mass * dif[0] - res[7] = mass * dif[1] - res[8] = mass * dif[2] - # res_mass = mass - res[9] = mass - - d.cinert[worldid, bodyid] = res + @kernel + def subtree_com_init(m: Model, d: Data): + worldid, bodyid = wp.tid() + d.subtree_com[worldid, bodyid] = d.xipos[worldid, bodyid] * m.body_mass[bodyid] - @kernel - def cdof(m: Model, d: Data): - worldid, jntid = wp.tid() - bodyid = m.jnt_bodyid[jntid] - dofid = m.jnt_dofadr[jntid] - jnt_type = m.jnt_type[jntid] - xaxis = d.xaxis[worldid, jntid] - xmat = wp.transpose(d.xmat[worldid, bodyid]) - - # compute com-anchor vector - offset = d.subtree_com[worldid, m.body_rootid[bodyid]] - d.xanchor[worldid, jntid] - - res = d.cdof[worldid] - if jnt_type == wp.static(JointType.FREE.value): - res[dofid + 0] = wp.spatial_vector(0.0, 0.0, 0.0, 1.0, 0.0, 0.0) - res[dofid + 1] = wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 1.0, 0.0) - res[dofid + 2] = wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 1.0) - # I_3 rotation in child frame (assume no subsequent rotations) - res[dofid + 3] = wp.spatial_vector(xmat[0], wp.cross(xmat[0], offset)) - res[dofid + 4] = wp.spatial_vector(xmat[1], wp.cross(xmat[1], offset)) - res[dofid + 5] = wp.spatial_vector(xmat[2], wp.cross(xmat[2], offset)) - elif jnt_type == wp.static(JointType.BALL.value): # ball - # I_3 rotation in child frame (assume no subsequent rotations) - res[dofid + 0] = wp.spatial_vector(xmat[0], wp.cross(xmat[0], offset)) - res[dofid + 1] = wp.spatial_vector(xmat[1], wp.cross(xmat[1], offset)) - res[dofid + 2] = wp.spatial_vector(xmat[2], wp.cross(xmat[2], offset)) - elif jnt_type == wp.static(JointType.SLIDE.value): - res[dofid] = wp.spatial_vector(wp.vec3(0.0), xaxis) - elif jnt_type == wp.static(JointType.HINGE.value): # hinge - res[dofid] = wp.spatial_vector(xaxis, wp.cross(xaxis, offset)) - - wp.launch(subtree_com_init, dim=(d.nworld, m.nbody), inputs=[m, d]) + @kernel + def subtree_com_acc(m: Model, d: Data, leveladr: int): + worldid, nodeid = wp.tid() + bodyid = m.body_tree[leveladr + nodeid] + pid = m.body_parentid[bodyid] + wp.atomic_add(d.subtree_com, worldid, pid, d.subtree_com[worldid, bodyid]) - body_treeadr = m.body_treeadr.numpy() + @kernel + def subtree_div(m: Model, d: Data): + worldid, bodyid = wp.tid() + d.subtree_com[worldid, bodyid] /= m.subtree_mass[bodyid] - for i in reversed(range(len(body_treeadr))): - beg = body_treeadr[i] - end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(subtree_com_acc, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + @kernel + def cinert(m: Model, d: Data): + worldid, bodyid = wp.tid() + mat = d.ximat[worldid, bodyid] + inert = m.body_inertia[bodyid] + mass = m.body_mass[bodyid] + dif = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]] + # express inertia in com-based frame (mju_inertCom) + + res = vec10() + # res_rot = mat * diag(inert) * mat' + tmp = mat @ wp.diag(inert) @ wp.transpose(mat) + res[0] = tmp[0, 0] + res[1] = tmp[1, 1] + res[2] = tmp[2, 2] + res[3] = tmp[0, 1] + res[4] = tmp[0, 2] + res[5] = tmp[1, 2] + # res_rot -= mass * dif_cross * dif_cross + res[0] += mass * (dif[1] * dif[1] + dif[2] * dif[2]) + res[1] += mass * (dif[0] * dif[0] + dif[2] * dif[2]) + res[2] += mass * (dif[0] * dif[0] + dif[1] * dif[1]) + res[3] -= mass * dif[0] * dif[1] + res[4] -= mass * dif[0] * dif[2] + res[5] -= mass * dif[1] * dif[2] + # res_tran = mass * dif + res[6] = mass * dif[0] + res[7] = mass * dif[1] + res[8] = mass * dif[2] + # res_mass = mass + res[9] = mass + + d.cinert[worldid, bodyid] = res - wp.launch(subtree_div, dim=(d.nworld, m.nbody), inputs=[m, d]) - wp.launch(cinert, dim=(d.nworld, m.nbody), inputs=[m, d]) - wp.launch(cdof, dim=(d.nworld, m.njnt), inputs=[m, d]) + @kernel + def cdof(m: Model, d: Data): + worldid, jntid = wp.tid() + bodyid = m.jnt_bodyid[jntid] + dofid = m.jnt_dofadr[jntid] + jnt_type = m.jnt_type[jntid] + xaxis = d.xaxis[worldid, jntid] + xmat = wp.transpose(d.xmat[worldid, bodyid]) + + # compute com-anchor vector + offset = d.subtree_com[worldid, m.body_rootid[bodyid]] - d.xanchor[worldid, jntid] + + res = d.cdof[worldid] + if jnt_type == wp.static(JointType.FREE.value): + res[dofid + 0] = wp.spatial_vector(0.0, 0.0, 0.0, 1.0, 0.0, 0.0) + res[dofid + 1] = wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 1.0, 0.0) + res[dofid + 2] = wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 1.0) + # I_3 rotation in child frame (assume no subsequent rotations) + res[dofid + 3] = wp.spatial_vector(xmat[0], wp.cross(xmat[0], offset)) + res[dofid + 4] = wp.spatial_vector(xmat[1], wp.cross(xmat[1], offset)) + res[dofid + 5] = wp.spatial_vector(xmat[2], wp.cross(xmat[2], offset)) + elif jnt_type == wp.static(JointType.BALL.value): # ball + # I_3 rotation in child frame (assume no subsequent rotations) + res[dofid + 0] = wp.spatial_vector(xmat[0], wp.cross(xmat[0], offset)) + res[dofid + 1] = wp.spatial_vector(xmat[1], wp.cross(xmat[1], offset)) + res[dofid + 2] = wp.spatial_vector(xmat[2], wp.cross(xmat[2], offset)) + elif jnt_type == wp.static(JointType.SLIDE.value): + res[dofid] = wp.spatial_vector(wp.vec3(0.0), xaxis) + elif jnt_type == wp.static(JointType.HINGE.value): # hinge + res[dofid] = wp.spatial_vector(xaxis, wp.cross(xaxis, offset)) + + wp.launch(subtree_com_init, dim=(d.nworld, m.nbody), inputs=[m, d]) + + body_treeadr = m.body_treeadr.numpy() + + for i in reversed(range(len(body_treeadr))): + beg = body_treeadr[i] + end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] + wp.launch(subtree_com_acc, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + + wp.launch(subtree_div, dim=(d.nworld, m.nbody), inputs=[m, d]) + wp.launch(cinert, dim=(d.nworld, m.nbody), inputs=[m, d]) + wp.launch(cdof, dim=(d.nworld, m.njnt), inputs=[m, d]) @event_scope def camlight(m: Model, d: Data): """Computes camera and light positions and orientations.""" - @kernel - def cam_local_to_global(m: Model, d: Data): - """Fixed cameras.""" - worldid, camid = wp.tid() - bodyid = m.cam_bodyid[camid] - xpos = d.xpos[worldid, bodyid] - xquat = d.xquat[worldid, bodyid] - d.cam_xpos[worldid, camid] = xpos + math.rot_vec_quat(m.cam_pos[camid], xquat) - d.cam_xmat[worldid, camid] = math.quat_to_mat( - math.mul_quat(xquat, m.cam_quat[camid]) - ) + with wp.ScopedDevice(m.qpos0.device): - @kernel - def cam_fn(m: Model, d: Data): - worldid, camid = wp.tid() - is_target_cam = (m.cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value)) or ( - m.cam_mode[camid] == wp.static(CamLightType.TARGETBODYCOM.value) - ) - invalid_target = is_target_cam and (m.cam_targetbodyid[camid] < 0) - if invalid_target: - return - elif m.cam_mode[camid] == wp.static(CamLightType.TRACK.value): - body_xpos = d.xpos[worldid, m.cam_bodyid[camid]] - d.cam_xpos[worldid, camid] = body_xpos + m.cam_pos0[camid] - elif m.cam_mode[camid] == wp.static(CamLightType.TRACKCOM.value): - d.cam_xpos[worldid, camid] = ( - d.subtree_com[worldid, m.cam_bodyid[camid]] + m.cam_poscom0[camid] - ) - elif m.cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value) or m.cam_mode[ - camid - ] == wp.static(CamLightType.TARGETBODYCOM.value): - pos = d.xpos[worldid, m.cam_targetbodyid[camid]] - if m.cam_mode[camid] == wp.static(CamLightType.TARGETBODYCOM.value): - pos = d.subtree_com[worldid, m.cam_targetbodyid[camid]] - # zaxis = -desired camera direction, in global frame - mat_3 = wp.normalize(d.cam_xpos[worldid, camid] - pos) - # xaxis: orthogonal to zaxis and to (0,0,1) - mat_1 = wp.normalize(wp.cross(wp.vec3(0.0, 0.0, 1.0), mat_3)) - mat_2 = wp.normalize(wp.cross(mat_3, mat_1)) - # fmt: off - d.cam_xmat[worldid, camid] = wp.mat33( - mat_1[0], mat_2[0], mat_3[0], - mat_1[1], mat_2[1], mat_3[1], - mat_1[2], mat_2[2], mat_3[2] + @kernel + def cam_local_to_global(m: Model, d: Data): + """Fixed cameras.""" + worldid, camid = wp.tid() + bodyid = m.cam_bodyid[camid] + xpos = d.xpos[worldid, bodyid] + xquat = d.xquat[worldid, bodyid] + d.cam_xpos[worldid, camid] = xpos + math.rot_vec_quat(m.cam_pos[camid], xquat) + d.cam_xmat[worldid, camid] = math.quat_to_mat( + math.mul_quat(xquat, m.cam_quat[camid]) ) - # fmt: on - @kernel - def light_local_to_global(m: Model, d: Data): - """Fixed lights.""" - worldid, lightid = wp.tid() - bodyid = m.light_bodyid[lightid] - xpos = d.xpos[worldid, bodyid] - xquat = d.xquat[worldid, bodyid] - d.light_xpos[worldid, lightid] = xpos + math.rot_vec_quat( - m.light_pos[lightid], xquat - ) - d.light_xdir[worldid, lightid] = math.rot_vec_quat(m.light_dir[lightid], xquat) + @kernel + def cam_fn(m: Model, d: Data): + worldid, camid = wp.tid() + is_target_cam = ( + m.cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value) + ) or (m.cam_mode[camid] == wp.static(CamLightType.TARGETBODYCOM.value)) + invalid_target = is_target_cam and (m.cam_targetbodyid[camid] < 0) + if invalid_target: + return + elif m.cam_mode[camid] == wp.static(CamLightType.TRACK.value): + body_xpos = d.xpos[worldid, m.cam_bodyid[camid]] + d.cam_xpos[worldid, camid] = body_xpos + m.cam_pos0[camid] + elif m.cam_mode[camid] == wp.static(CamLightType.TRACKCOM.value): + d.cam_xpos[worldid, camid] = ( + d.subtree_com[worldid, m.cam_bodyid[camid]] + m.cam_poscom0[camid] + ) + elif m.cam_mode[camid] == wp.static(CamLightType.TARGETBODY.value) or m.cam_mode[ + camid + ] == wp.static(CamLightType.TARGETBODYCOM.value): + pos = d.xpos[worldid, m.cam_targetbodyid[camid]] + if m.cam_mode[camid] == wp.static(CamLightType.TARGETBODYCOM.value): + pos = d.subtree_com[worldid, m.cam_targetbodyid[camid]] + # zaxis = -desired camera direction, in global frame + mat_3 = wp.normalize(d.cam_xpos[worldid, camid] - pos) + # xaxis: orthogonal to zaxis and to (0,0,1) + mat_1 = wp.normalize(wp.cross(wp.vec3(0.0, 0.0, 1.0), mat_3)) + mat_2 = wp.normalize(wp.cross(mat_3, mat_1)) + # fmt: off + d.cam_xmat[worldid, camid] = wp.mat33( + mat_1[0], mat_2[0], mat_3[0], + mat_1[1], mat_2[1], mat_3[1], + mat_1[2], mat_2[2], mat_3[2] + ) + # fmt: on - @kernel - def light_fn(m: Model, d: Data): - worldid, lightid = wp.tid() - is_target_light = ( - m.light_mode[lightid] == wp.static(CamLightType.TARGETBODY.value) - ) or (m.light_mode[lightid] == wp.static(CamLightType.TARGETBODYCOM.value)) - invalid_target = is_target_light and (m.light_targetbodyid[lightid] < 0) - if invalid_target: - return - elif m.light_mode[lightid] == wp.static(CamLightType.TRACK.value): - body_xpos = d.xpos[worldid, m.light_bodyid[lightid]] - d.light_xpos[worldid, lightid] = body_xpos + m.light_pos0[lightid] - elif m.light_mode[lightid] == wp.static(CamLightType.TRACKCOM.value): - d.light_xpos[worldid, lightid] = ( - d.subtree_com[worldid, m.light_bodyid[lightid]] + m.light_poscom0[lightid] + @kernel + def light_local_to_global(m: Model, d: Data): + """Fixed lights.""" + worldid, lightid = wp.tid() + bodyid = m.light_bodyid[lightid] + xpos = d.xpos[worldid, bodyid] + xquat = d.xquat[worldid, bodyid] + d.light_xpos[worldid, lightid] = xpos + math.rot_vec_quat( + m.light_pos[lightid], xquat ) - elif m.light_mode[lightid] == wp.static( - CamLightType.TARGETBODY.value - ) or m.light_mode[lightid] == wp.static(CamLightType.TARGETBODYCOM.value): - pos = d.xpos[worldid, m.light_targetbodyid[lightid]] - if m.light_mode[lightid] == wp.static(CamLightType.TARGETBODYCOM.value): - pos = d.subtree_com[worldid, m.light_targetbodyid[lightid]] - d.light_xdir[worldid, lightid] = pos - d.light_xpos[worldid, lightid] - d.light_xdir[worldid, lightid] = wp.normalize(d.light_xdir[worldid, lightid]) - - if m.ncam > 0: - wp.launch(cam_local_to_global, dim=(d.nworld, m.ncam), inputs=[m, d]) - wp.launch(cam_fn, dim=(d.nworld, m.ncam), inputs=[m, d]) - if m.nlight > 0: - wp.launch(light_local_to_global, dim=(d.nworld, m.nlight), inputs=[m, d]) - wp.launch(light_fn, dim=(d.nworld, m.nlight), inputs=[m, d]) + d.light_xdir[worldid, lightid] = math.rot_vec_quat(m.light_dir[lightid], xquat) + + @kernel + def light_fn(m: Model, d: Data): + worldid, lightid = wp.tid() + is_target_light = ( + m.light_mode[lightid] == wp.static(CamLightType.TARGETBODY.value) + ) or (m.light_mode[lightid] == wp.static(CamLightType.TARGETBODYCOM.value)) + invalid_target = is_target_light and (m.light_targetbodyid[lightid] < 0) + if invalid_target: + return + elif m.light_mode[lightid] == wp.static(CamLightType.TRACK.value): + body_xpos = d.xpos[worldid, m.light_bodyid[lightid]] + d.light_xpos[worldid, lightid] = body_xpos + m.light_pos0[lightid] + elif m.light_mode[lightid] == wp.static(CamLightType.TRACKCOM.value): + d.light_xpos[worldid, lightid] = ( + d.subtree_com[worldid, m.light_bodyid[lightid]] + m.light_poscom0[lightid] + ) + elif m.light_mode[lightid] == wp.static( + CamLightType.TARGETBODY.value + ) or m.light_mode[lightid] == wp.static(CamLightType.TARGETBODYCOM.value): + pos = d.xpos[worldid, m.light_targetbodyid[lightid]] + if m.light_mode[lightid] == wp.static(CamLightType.TARGETBODYCOM.value): + pos = d.subtree_com[worldid, m.light_targetbodyid[lightid]] + d.light_xdir[worldid, lightid] = pos - d.light_xpos[worldid, lightid] + d.light_xdir[worldid, lightid] = wp.normalize(d.light_xdir[worldid, lightid]) + + if m.ncam > 0: + wp.launch(cam_local_to_global, dim=(d.nworld, m.ncam), inputs=[m, d]) + wp.launch(cam_fn, dim=(d.nworld, m.ncam), inputs=[m, d]) + if m.nlight > 0: + wp.launch(light_local_to_global, dim=(d.nworld, m.nlight), inputs=[m, d]) + wp.launch(light_fn, dim=(d.nworld, m.nlight), inputs=[m, d]) @event_scope def crb(m: Model, d: Data): """Composite rigid body inertia algorithm.""" - wp.copy(d.crb, d.cinert) + with wp.ScopedDevice(m.qpos0.device): + wp.copy(d.crb, d.cinert) - @kernel - def crb_accumulate(m: Model, d: Data, leveladr: int): - worldid, nodeid = wp.tid() - bodyid = m.body_tree[leveladr + nodeid] - pid = m.body_parentid[bodyid] - if pid == 0: - return - wp.atomic_add(d.crb, worldid, pid, d.crb[worldid, bodyid]) + @kernel + def crb_accumulate(m: Model, d: Data, leveladr: int): + worldid, nodeid = wp.tid() + bodyid = m.body_tree[leveladr + nodeid] + pid = m.body_parentid[bodyid] + if pid == 0: + return + wp.atomic_add(d.crb, worldid, pid, d.crb[worldid, bodyid]) - @kernel - def qM_sparse(m: Model, d: Data): - worldid, dofid = wp.tid() - madr_ij = m.dof_Madr[dofid] - bodyid = m.dof_bodyid[dofid] + @kernel + def qM_sparse(m: Model, d: Data): + worldid, dofid = wp.tid() + madr_ij = m.dof_Madr[dofid] + bodyid = m.dof_bodyid[dofid] - # init M(i,i) with armature inertia - d.qM[worldid, 0, madr_ij] = m.dof_armature[dofid] + # init M(i,i) with armature inertia + d.qM[worldid, 0, madr_ij] = m.dof_armature[dofid] - # precompute buf = crb_body_i * cdof_i - buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) + # precompute buf = crb_body_i * cdof_i + buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) - # sparse backward pass over ancestors - while dofid >= 0: - d.qM[worldid, 0, madr_ij] += wp.dot(d.cdof[worldid, dofid], buf) - madr_ij += 1 - dofid = m.dof_parentid[dofid] + # sparse backward pass over ancestors + while dofid >= 0: + d.qM[worldid, 0, madr_ij] += wp.dot(d.cdof[worldid, dofid], buf) + madr_ij += 1 + dofid = m.dof_parentid[dofid] - @kernel - def qM_dense(m: Model, d: Data): - worldid, dofid = wp.tid() - bodyid = m.dof_bodyid[dofid] + @kernel + def qM_dense(m: Model, d: Data): + worldid, dofid = wp.tid() + bodyid = m.dof_bodyid[dofid] - # init M(i,i) with armature inertia - M = m.dof_armature[dofid] + # init M(i,i) with armature inertia + M = m.dof_armature[dofid] - # precompute buf = crb_body_i * cdof_i - buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) - M += wp.dot(d.cdof[worldid, dofid], buf) + # precompute buf = crb_body_i * cdof_i + buf = math.inert_vec(d.crb[worldid, bodyid], d.cdof[worldid, dofid]) + M += wp.dot(d.cdof[worldid, dofid], buf) - d.qM[worldid, dofid, dofid] = M + d.qM[worldid, dofid, dofid] = M - # sparse backward pass over ancestors - dofidi = dofid - dofid = m.dof_parentid[dofid] - while dofid >= 0: - qMij = wp.dot(d.cdof[worldid, dofid], buf) - d.qM[worldid, dofidi, dofid] += qMij - d.qM[worldid, dofid, dofidi] += qMij + # sparse backward pass over ancestors + dofidi = dofid dofid = m.dof_parentid[dofid] - - body_treeadr = m.body_treeadr.numpy() - for i in reversed(range(len(body_treeadr))): - beg = body_treeadr[i] - end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(crb_accumulate, dim=(d.nworld, end - beg), inputs=[m, d, beg]) - - d.qM.zero_() - if m.opt.is_sparse: - wp.launch(qM_sparse, dim=(d.nworld, m.nv), inputs=[m, d]) - else: - wp.launch(qM_dense, dim=(d.nworld, m.nv), inputs=[m, d]) + while dofid >= 0: + qMij = wp.dot(d.cdof[worldid, dofid], buf) + d.qM[worldid, dofidi, dofid] += qMij + d.qM[worldid, dofid, dofidi] += qMij + dofid = m.dof_parentid[dofid] + + body_treeadr = m.body_treeadr.numpy() + for i in reversed(range(len(body_treeadr))): + beg = body_treeadr[i] + end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] + wp.launch(crb_accumulate, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + + d.qM.zero_() + if m.opt.is_sparse: + wp.launch(qM_sparse, dim=(d.nworld, m.nv), inputs=[m, d]) + else: + wp.launch(qM_dense, dim=(d.nworld, m.nv), inputs=[m, d]) def _factor_i_sparse_legacy(m: Model, d: Data, M: array3df, L: array3df, D: array2df): @@ -543,7 +550,9 @@ def factor_i(m: Model, d: Data, M, L, D=None): @event_scope def factor_m(m: Model, d: Data): """Factorizaton of inertia-like matrix M, assumed spd.""" - factor_i(m, d, d.qM, d.qLD, d.qLDiagInv) + + with wp.ScopedDevice(m.qpos0.device): + factor_i(m, d, d.qM, d.qLD, d.qLDiagInv) def _rne_cacc_world(m: Model, d: Data): @@ -623,335 +632,346 @@ def _cfrc(m: Model, d: Data, leveladr: int): def rne(m: Model, d: Data, flg_acc: bool = False): """Computes inverse dynamics using Newton-Euler algorithm.""" - @kernel - def qfrc_bias(m: Model, d: Data): - worldid, dofid = wp.tid() - bodyid = m.dof_bodyid[dofid] - d.qfrc_bias[worldid, dofid] = wp.dot( - d.cdof[worldid, dofid], d.cfrc_int[worldid, bodyid] - ) + with wp.ScopedDevice(m.qpos0.device): - _rne_cacc_world(m, d) - _rne_cacc_forward(m, d, flg_acc=flg_acc) - _rne_cfrc(m, d) - _rne_cfrc_backward(m, d) + @kernel + def qfrc_bias(m: Model, d: Data): + worldid, dofid = wp.tid() + bodyid = m.dof_bodyid[dofid] + d.qfrc_bias[worldid, dofid] = wp.dot( + d.cdof[worldid, dofid], d.cfrc_int[worldid, bodyid] + ) - wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d]) + _rne_cacc_world(m, d) + _rne_cacc_forward(m, d, flg_acc=flg_acc) + _rne_cfrc(m, d) + _rne_cfrc_backward(m, d) + + wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d]) @event_scope def rne_postconstraint(m: Model, d: Data): """RNE with complete data: compute cacc, cfrc_ext, cfrc_int.""" - # cfrc_ext = perturb - @kernel - def _cfrc_ext(m: Model, d: Data): - worldid, bodyid = wp.tid() + with wp.ScopedDevice(m.qpos0.device): + # cfrc_ext = perturb + @kernel + def _cfrc_ext(m: Model, d: Data): + worldid, bodyid = wp.tid() - if bodyid == 0: - d.cfrc_ext[worldid, 0] = wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) - else: - xfrc_applied = d.xfrc_applied[worldid, bodyid] - subtree_com = d.subtree_com[worldid, m.body_rootid[bodyid]] - xipos = d.xipos[worldid, bodyid] - d.cfrc_ext[worldid, bodyid] = support.transform_force( - xfrc_applied, subtree_com - xipos - ) + if bodyid == 0: + d.cfrc_ext[worldid, 0] = wp.spatial_vector(0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + else: + xfrc_applied = d.xfrc_applied[worldid, bodyid] + subtree_com = d.subtree_com[worldid, m.body_rootid[bodyid]] + xipos = d.xipos[worldid, bodyid] + d.cfrc_ext[worldid, bodyid] = support.transform_force( + xfrc_applied, subtree_com - xipos + ) - wp.launch(_cfrc_ext, dim=(d.nworld, m.nbody), inputs=[m, d]) + wp.launch(_cfrc_ext, dim=(d.nworld, m.nbody), inputs=[m, d]) - @kernel - def _cfrc_ext_equality(m: Model, d: Data): - eqid = wp.tid() + @kernel + def _cfrc_ext_equality(m: Model, d: Data): + eqid = wp.tid() - ne_connect = d.ne_connect[0] - ne_weld = d.ne_weld[0] - num_connect = ne_connect // 3 + ne_connect = d.ne_connect[0] + ne_weld = d.ne_weld[0] + num_connect = ne_connect // 3 - if eqid >= num_connect + ne_weld // 6: - return + if eqid >= num_connect + ne_weld // 6: + return - is_connect = eqid < num_connect - if is_connect: - efcid = 3 * eqid - cfrc_torque = wp.vec3(0.0, 0.0, 0.0) # no torque from connect - else: - efcid = 6 * eqid - ne_connect - cfrc_torque = wp.vec3( - d.efc.force[efcid + 3], d.efc.force[efcid + 4], d.efc.force[efcid + 5] + is_connect = eqid < num_connect + if is_connect: + efcid = 3 * eqid + cfrc_torque = wp.vec3(0.0, 0.0, 0.0) # no torque from connect + else: + efcid = 6 * eqid - ne_connect + cfrc_torque = wp.vec3( + d.efc.force[efcid + 3], d.efc.force[efcid + 4], d.efc.force[efcid + 5] + ) + + cfrc_force = wp.vec3( + d.efc.force[efcid + 0], + d.efc.force[efcid + 1], + d.efc.force[efcid + 2], ) - cfrc_force = wp.vec3( - d.efc.force[efcid + 0], - d.efc.force[efcid + 1], - d.efc.force[efcid + 2], - ) - - worldid = d.efc.worldid[efcid] - id = d.efc.id[efcid] - eq_data = m.eq_data[id] - body_semantic = m.eq_objtype[id] == wp.static(ObjType.BODY.value) + worldid = d.efc.worldid[efcid] + id = d.efc.id[efcid] + eq_data = m.eq_data[id] + body_semantic = m.eq_objtype[id] == wp.static(ObjType.BODY.value) - obj1 = m.eq_obj1id[id] - obj2 = m.eq_obj2id[id] + obj1 = m.eq_obj1id[id] + obj2 = m.eq_obj2id[id] - if body_semantic: - bodyid1 = obj1 - bodyid2 = obj2 - else: - bodyid1 = m.site_bodyid[obj1] - bodyid2 = m.site_bodyid[obj2] - - # body 1 - if bodyid1: if body_semantic: - if is_connect: - offset = wp.vec3(eq_data[0], eq_data[1], eq_data[2]) - else: - offset = wp.vec3(eq_data[3], eq_data[4], eq_data[5]) + bodyid1 = obj1 + bodyid2 = obj2 else: - offset = m.site_pos[obj1] + bodyid1 = m.site_bodyid[obj1] + bodyid2 = m.site_bodyid[obj2] + + # body 1 + if bodyid1: + if body_semantic: + if is_connect: + offset = wp.vec3(eq_data[0], eq_data[1], eq_data[2]) + else: + offset = wp.vec3(eq_data[3], eq_data[4], eq_data[5]) + else: + offset = m.site_pos[obj1] - # transform point on body1: local -> global - pos = d.xmat[worldid, bodyid1] @ offset + d.xpos[worldid, bodyid1] + # transform point on body1: local -> global + pos = d.xmat[worldid, bodyid1] @ offset + d.xpos[worldid, bodyid1] - # subtree CoM-based torque_force vector - newpos = d.subtree_com[worldid, m.body_rootid[bodyid1]] + # subtree CoM-based torque_force vector + newpos = d.subtree_com[worldid, m.body_rootid[bodyid1]] - dif = newpos - pos - cfrc_com = wp.spatial_vector(cfrc_torque - wp.cross(dif, cfrc_force), cfrc_force) + dif = newpos - pos + cfrc_com = wp.spatial_vector( + cfrc_torque - wp.cross(dif, cfrc_force), cfrc_force + ) - # apply (opposite for body 1) - wp.atomic_add(d.cfrc_ext[worldid], bodyid1, cfrc_com) + # apply (opposite for body 1) + wp.atomic_add(d.cfrc_ext[worldid], bodyid1, cfrc_com) - # body 2 - if bodyid2: - if body_semantic: - if is_connect: - offset = wp.vec3(eq_data[3], eq_data[4], eq_data[5]) + # body 2 + if bodyid2: + if body_semantic: + if is_connect: + offset = wp.vec3(eq_data[3], eq_data[4], eq_data[5]) + else: + offset = wp.vec3(eq_data[0], eq_data[1], eq_data[2]) else: - offset = wp.vec3(eq_data[0], eq_data[1], eq_data[2]) - else: - offset = m.site_pos[obj2] + offset = m.site_pos[obj2] - # transform point on body2: local -> global - pos = d.xmat[worldid, bodyid2] @ offset + d.xpos[worldid, bodyid2] + # transform point on body2: local -> global + pos = d.xmat[worldid, bodyid2] @ offset + d.xpos[worldid, bodyid2] - # subtree CoM-based torque_force vector - newpos = d.subtree_com[worldid, m.body_rootid[bodyid2]] + # subtree CoM-based torque_force vector + newpos = d.subtree_com[worldid, m.body_rootid[bodyid2]] - dif = newpos - pos - cfrc_com = wp.spatial_vector(cfrc_torque - wp.cross(dif, cfrc_force), cfrc_force) + dif = newpos - pos + cfrc_com = wp.spatial_vector( + cfrc_torque - wp.cross(dif, cfrc_force), cfrc_force + ) - # apply - wp.atomic_sub(d.cfrc_ext[worldid], bodyid2, cfrc_com) + # apply + wp.atomic_sub(d.cfrc_ext[worldid], bodyid2, cfrc_com) - wp.launch(_cfrc_ext_equality, dim=(d.nworld * m.neq,), inputs=[m, d]) + wp.launch(_cfrc_ext_equality, dim=(d.nworld * m.neq,), inputs=[m, d]) - # cfrc_ext += contacts - @kernel - def _cfrc_ext_contact(m: Model, d: Data): - conid = wp.tid() + # cfrc_ext += contacts + @kernel + def _cfrc_ext_contact(m: Model, d: Data): + conid = wp.tid() - if conid >= d.ncon[0]: - return + if conid >= d.ncon[0]: + return - geom = d.contact.geom[conid] - id1 = m.geom_bodyid[geom[0]] - id2 = m.geom_bodyid[geom[1]] + geom = d.contact.geom[conid] + id1 = m.geom_bodyid[geom[0]] + id2 = m.geom_bodyid[geom[1]] - if id1 == 0 and id2 == 0: - return + if id1 == 0 and id2 == 0: + return - # contact force in world frame - force = support.contact_force(m, d, conid, to_world_frame=True) + # contact force in world frame + force = support.contact_force(m, d, conid, to_world_frame=True) - worldid = d.contact.worldid[conid] - pos = d.contact.pos[conid] + worldid = d.contact.worldid[conid] + pos = d.contact.pos[conid] - # contact force on bodies - if id1: - com1 = d.subtree_com[worldid, m.body_rootid[id1]] - wp.atomic_sub( - d.cfrc_ext[worldid], id1, support.transform_force(force, com1 - pos) - ) + # contact force on bodies + if id1: + com1 = d.subtree_com[worldid, m.body_rootid[id1]] + wp.atomic_sub( + d.cfrc_ext[worldid], id1, support.transform_force(force, com1 - pos) + ) - if id2: - com2 = d.subtree_com[worldid, m.body_rootid[id2]] - wp.atomic_add( - d.cfrc_ext[worldid], id2, support.transform_force(force, com2 - pos) - ) + if id2: + com2 = d.subtree_com[worldid, m.body_rootid[id2]] + wp.atomic_add( + d.cfrc_ext[worldid], id2, support.transform_force(force, com2 - pos) + ) - wp.launch(_cfrc_ext_contact, dim=(d.nconmax,), inputs=[m, d]) + wp.launch(_cfrc_ext_contact, dim=(d.nconmax,), inputs=[m, d]) - # forward pass over bodies: compute cacc, cfrc_int - _rne_cacc_world(m, d) - _rne_cacc_forward(m, d, flg_acc=True) + # forward pass over bodies: compute cacc, cfrc_int + _rne_cacc_world(m, d) + _rne_cacc_forward(m, d, flg_acc=True) - # cfrc_body = cinert * cacc + cvel x (cinert * cvel) - _rne_cfrc(m, d, flg_cfrc_ext=True) + # cfrc_body = cinert * cacc + cvel x (cinert * cvel) + _rne_cfrc(m, d, flg_cfrc_ext=True) - # backward pass over bodies: accumulate cfrc_int from children - _rne_cfrc_backward(m, d) + # backward pass over bodies: accumulate cfrc_int from children + _rne_cfrc_backward(m, d) @event_scope def transmission(m: Model, d: Data): """Computes actuator/transmission lengths and moments.""" - if not m.nu: - return d - @kernel - def _transmission( - m: Model, - d: Data, - # outputs - length: array2df, - moment: array3df, - ): - worldid, actid = wp.tid() - trntype = m.actuator_trntype[actid] - gear = m.actuator_gear[actid] - if trntype == wp.static(TrnType.JOINT.value) or trntype == wp.static( - TrnType.JOINTINPARENT.value + with wp.ScopedDevice(m.qpos0.device): + if not m.nu: + return d + + @kernel + def _transmission( + m: Model, + d: Data, + # outputs + length: array2df, + moment: array3df, ): - qpos = d.qpos[worldid] - jntid = m.actuator_trnid[actid, 0] - jnt_typ = m.jnt_type[jntid] - qadr = m.jnt_qposadr[jntid] - vadr = m.jnt_dofadr[jntid] - if jnt_typ == wp.static(JointType.FREE.value): - length[worldid, actid] = 0.0 - if trntype == wp.static(TrnType.JOINTINPARENT.value): - quat_neg = math.quat_inv( - wp.quat( - qpos[qadr + 3], - qpos[qadr + 4], - qpos[qadr + 5], - qpos[qadr + 6], + worldid, actid = wp.tid() + trntype = m.actuator_trntype[actid] + gear = m.actuator_gear[actid] + if trntype == wp.static(TrnType.JOINT.value) or trntype == wp.static( + TrnType.JOINTINPARENT.value + ): + qpos = d.qpos[worldid] + jntid = m.actuator_trnid[actid, 0] + jnt_typ = m.jnt_type[jntid] + qadr = m.jnt_qposadr[jntid] + vadr = m.jnt_dofadr[jntid] + if jnt_typ == wp.static(JointType.FREE.value): + length[worldid, actid] = 0.0 + if trntype == wp.static(TrnType.JOINTINPARENT.value): + quat_neg = math.quat_inv( + wp.quat( + qpos[qadr + 3], + qpos[qadr + 4], + qpos[qadr + 5], + qpos[qadr + 6], + ) ) - ) - gearaxis = math.rot_vec_quat(wp.spatial_bottom(gear), quat_neg) - moment[worldid, actid, vadr + 0] = gear[0] - moment[worldid, actid, vadr + 1] = gear[1] - moment[worldid, actid, vadr + 2] = gear[2] - moment[worldid, actid, vadr + 3] = gearaxis[0] - moment[worldid, actid, vadr + 4] = gearaxis[1] - moment[worldid, actid, vadr + 5] = gearaxis[2] + gearaxis = math.rot_vec_quat(wp.spatial_bottom(gear), quat_neg) + moment[worldid, actid, vadr + 0] = gear[0] + moment[worldid, actid, vadr + 1] = gear[1] + moment[worldid, actid, vadr + 2] = gear[2] + moment[worldid, actid, vadr + 3] = gearaxis[0] + moment[worldid, actid, vadr + 4] = gearaxis[1] + moment[worldid, actid, vadr + 5] = gearaxis[2] + else: + for i in range(6): + moment[worldid, actid, vadr + i] = gear[i] + elif jnt_typ == wp.static(JointType.BALL.value): + q = wp.quat(qpos[qadr + 0], qpos[qadr + 1], qpos[qadr + 2], qpos[qadr + 3]) + axis_angle = math.quat_to_vel(q) + gearaxis = wp.spatial_top(gear) # [:3] + if trntype == wp.static(TrnType.JOINTINPARENT.value): + quat_neg = math.quat_inv(q) + gearaxis = math.rot_vec_quat(gearaxis, quat_neg) + length[worldid, actid] = wp.dot(axis_angle, gearaxis) + for i in range(3): + moment[worldid, actid, vadr + i] = gearaxis[i] + elif jnt_typ == wp.static(JointType.SLIDE.value) or jnt_typ == wp.static( + JointType.HINGE.value + ): + length[worldid, actid] = qpos[qadr] * gear[0] + moment[worldid, actid, vadr] = gear[0] else: - for i in range(6): - moment[worldid, actid, vadr + i] = gear[i] - elif jnt_typ == wp.static(JointType.BALL.value): - q = wp.quat(qpos[qadr + 0], qpos[qadr + 1], qpos[qadr + 2], qpos[qadr + 3]) - axis_angle = math.quat_to_vel(q) - gearaxis = wp.spatial_top(gear) # [:3] - if trntype == wp.static(TrnType.JOINTINPARENT.value): - quat_neg = math.quat_inv(q) - gearaxis = math.rot_vec_quat(gearaxis, quat_neg) - length[worldid, actid] = wp.dot(axis_angle, gearaxis) - for i in range(3): - moment[worldid, actid, vadr + i] = gearaxis[i] - elif jnt_typ == wp.static(JointType.SLIDE.value) or jnt_typ == wp.static( - JointType.HINGE.value - ): - length[worldid, actid] = qpos[qadr] * gear[0] - moment[worldid, actid, vadr] = gear[0] + wp.printf("unrecognized joint type") + elif trntype == wp.static(TrnType.TENDON.value): + tenid = m.actuator_trnid[actid, 0] + + gear0 = gear[0] + length[worldid, actid] = d.ten_length[worldid, tenid] * gear0 + + # fixed + adr = m.tendon_adr[tenid] + if m.wrap_type[adr] == wp.static(WrapType.JOINT.value): + ten_num = m.tendon_num[tenid] + for i in range(ten_num): + dofadr = m.jnt_dofadr[m.wrap_objid[adr + i]] + moment[worldid, actid, dofadr] = d.ten_J[worldid, tenid, dofadr] * gear0 + # TODO(team): spatial else: - wp.printf("unrecognized joint type") - elif trntype == wp.static(TrnType.TENDON.value): - tenid = m.actuator_trnid[actid, 0] - - gear0 = gear[0] - length[worldid, actid] = d.ten_length[worldid, tenid] * gear0 - - # fixed - adr = m.tendon_adr[tenid] - if m.wrap_type[adr] == wp.static(WrapType.JOINT.value): - ten_num = m.tendon_num[tenid] - for i in range(ten_num): - dofadr = m.jnt_dofadr[m.wrap_objid[adr + i]] - moment[worldid, actid, dofadr] = d.ten_J[worldid, tenid, dofadr] * gear0 - # TODO(team): spatial - else: - # TODO(team): site, slidercrank, body - wp.printf("unhandled transmission type %d\n", trntype) - - wp.launch( - _transmission, - dim=[d.nworld, m.nu], - inputs=[m, d], - outputs=[d.actuator_length, d.actuator_moment], - ) + # TODO(team): site, slidercrank, body + wp.printf("unhandled transmission type %d\n", trntype) + + wp.launch( + _transmission, + dim=[d.nworld, m.nu], + inputs=[m, d], + outputs=[d.actuator_length, d.actuator_moment], + ) @event_scope def com_vel(m: Model, d: Data): """Computes cvel, cdof_dot.""" - @kernel - def _root(d: Data): - worldid, elementid = wp.tid() - d.cvel[worldid, 0][elementid] = 0.0 + with wp.ScopedDevice(m.qpos0.device): - @kernel - def _level(m: Model, d: Data, leveladr: int): - worldid, nodeid = wp.tid() - bodyid = m.body_tree[leveladr + nodeid] - dofid = m.body_dofadr[bodyid] - jntid = m.body_jntadr[bodyid] - jntnum = m.body_jntnum[bodyid] - pid = m.body_parentid[bodyid] + @kernel + def _root(d: Data): + worldid, elementid = wp.tid() + d.cvel[worldid, 0][elementid] = 0.0 - if jntnum == 0: - d.cvel[worldid, bodyid] = d.cvel[worldid, pid] - return + @kernel + def _level(m: Model, d: Data, leveladr: int): + worldid, nodeid = wp.tid() + bodyid = m.body_tree[leveladr + nodeid] + dofid = m.body_dofadr[bodyid] + jntid = m.body_jntadr[bodyid] + jntnum = m.body_jntnum[bodyid] + pid = m.body_parentid[bodyid] - cvel = d.cvel[worldid, pid] - qvel = d.qvel[worldid] - cdof = d.cdof[worldid] + if jntnum == 0: + d.cvel[worldid, bodyid] = d.cvel[worldid, pid] + return - for j in range(jntid, jntid + jntnum): - jnttype = m.jnt_type[j] + cvel = d.cvel[worldid, pid] + qvel = d.qvel[worldid] + cdof = d.cdof[worldid] - if jnttype == wp.static(JointType.FREE.value): - cvel += cdof[dofid + 0] * qvel[dofid + 0] - cvel += cdof[dofid + 1] * qvel[dofid + 1] - cvel += cdof[dofid + 2] * qvel[dofid + 2] + for j in range(jntid, jntid + jntnum): + jnttype = m.jnt_type[j] - d.cdof_dot[worldid, dofid + 3] = math.motion_cross(cvel, cdof[dofid + 3]) - d.cdof_dot[worldid, dofid + 4] = math.motion_cross(cvel, cdof[dofid + 4]) - d.cdof_dot[worldid, dofid + 5] = math.motion_cross(cvel, cdof[dofid + 5]) + if jnttype == wp.static(JointType.FREE.value): + cvel += cdof[dofid + 0] * qvel[dofid + 0] + cvel += cdof[dofid + 1] * qvel[dofid + 1] + cvel += cdof[dofid + 2] * qvel[dofid + 2] - cvel += cdof[dofid + 3] * qvel[dofid + 3] - cvel += cdof[dofid + 4] * qvel[dofid + 4] - cvel += cdof[dofid + 5] * qvel[dofid + 5] + d.cdof_dot[worldid, dofid + 3] = math.motion_cross(cvel, cdof[dofid + 3]) + d.cdof_dot[worldid, dofid + 4] = math.motion_cross(cvel, cdof[dofid + 4]) + d.cdof_dot[worldid, dofid + 5] = math.motion_cross(cvel, cdof[dofid + 5]) - dofid += 6 - elif jnttype == wp.static(JointType.BALL.value): - d.cdof_dot[worldid, dofid + 0] = math.motion_cross(cvel, cdof[dofid + 0]) - d.cdof_dot[worldid, dofid + 1] = math.motion_cross(cvel, cdof[dofid + 1]) - d.cdof_dot[worldid, dofid + 2] = math.motion_cross(cvel, cdof[dofid + 2]) + cvel += cdof[dofid + 3] * qvel[dofid + 3] + cvel += cdof[dofid + 4] * qvel[dofid + 4] + cvel += cdof[dofid + 5] * qvel[dofid + 5] - cvel += cdof[dofid + 0] * qvel[dofid + 0] - cvel += cdof[dofid + 1] * qvel[dofid + 1] - cvel += cdof[dofid + 2] * qvel[dofid + 2] + dofid += 6 + elif jnttype == wp.static(JointType.BALL.value): + d.cdof_dot[worldid, dofid + 0] = math.motion_cross(cvel, cdof[dofid + 0]) + d.cdof_dot[worldid, dofid + 1] = math.motion_cross(cvel, cdof[dofid + 1]) + d.cdof_dot[worldid, dofid + 2] = math.motion_cross(cvel, cdof[dofid + 2]) - dofid += 3 - else: - d.cdof_dot[worldid, dofid] = math.motion_cross(cvel, cdof[dofid]) - cvel += cdof[dofid] * qvel[dofid] + cvel += cdof[dofid + 0] * qvel[dofid + 0] + cvel += cdof[dofid + 1] * qvel[dofid + 1] + cvel += cdof[dofid + 2] * qvel[dofid + 2] - dofid += 1 + dofid += 3 + else: + d.cdof_dot[worldid, dofid] = math.motion_cross(cvel, cdof[dofid]) + cvel += cdof[dofid] * qvel[dofid] - d.cvel[worldid, bodyid] = cvel + dofid += 1 - wp.launch(_root, dim=(d.nworld, 6), inputs=[d]) + d.cvel[worldid, bodyid] = cvel - body_treeadr = m.body_treeadr.numpy() - for i in range(1, len(body_treeadr)): - beg = body_treeadr[i] - end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch(_root, dim=(d.nworld, 6), inputs=[d]) + + body_treeadr = m.body_treeadr.numpy() + for i in range(1, len(body_treeadr)): + beg = body_treeadr[i] + end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] + wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) def _solve_LD_sparse( @@ -1044,7 +1064,9 @@ def solve_LD(m: Model, d: Data, L: array3df, D: array2df, x: array2df, y: array2 @event_scope def solve_m(m: Model, d: Data, x: array2df, y: array2df): """Computes backsubstitution: x = qLD * y.""" - solve_LD(m, d, d.qLD, d.qLDiagInv, x, y) + + with wp.ScopedDevice(m.qpos0.device): + solve_LD(m, d, d.qLD, d.qLDiagInv, x, y) def _factor_solve_i_dense(m: Model, d: Data, M: array3df, x: array2df, y: array2df): @@ -1091,193 +1113,197 @@ def factor_solve_i(m, d, M, L, D, x, y): def subtree_vel(m: Model, d: Data): """Subtree linear velocity and angular momentum.""" - # bodywise quantities - @kernel - def _forward(m: Model, d: Data): - worldid, bodyid = wp.tid() + with wp.ScopedDevice(m.qpos0.device): + # bodywise quantities + @kernel + def _forward(m: Model, d: Data): + worldid, bodyid = wp.tid() - cvel = d.cvel[worldid, bodyid] - ang = wp.spatial_top(cvel) - lin = wp.spatial_bottom(cvel) - xipos = d.xipos[worldid, bodyid] - ximat = d.ximat[worldid, bodyid] - subtree_com_root = d.subtree_com[worldid, m.body_rootid[bodyid]] - - # update linear velocity - lin -= wp.cross(xipos - subtree_com_root, ang) - - d.subtree_linvel[worldid, bodyid] = m.body_mass[bodyid] * lin - dv = wp.transpose(ximat) @ ang - dv[0] *= m.body_inertia[bodyid][0] - dv[1] *= m.body_inertia[bodyid][1] - dv[2] *= m.body_inertia[bodyid][2] - d.subtree_angmom[worldid, bodyid] = ximat @ dv - d.subtree_bodyvel[worldid, bodyid] = wp.spatial_vector(ang, lin) - - wp.launch(_forward, dim=(d.nworld, m.nbody), inputs=[m, d]) - - # sum body linear momentum recursively up the kinematic tree - @kernel - def _linear_momentum(m: Model, d: Data, leveladr: int): - worldid, nodeid = wp.tid() - bodyid = m.body_tree[leveladr + nodeid] - if bodyid: - pid = m.body_parentid[bodyid] - wp.atomic_add(d.subtree_linvel[worldid], pid, d.subtree_linvel[worldid, bodyid]) - d.subtree_linvel[worldid, bodyid] /= wp.max(MJ_MINVAL, m.body_subtreemass[bodyid]) + cvel = d.cvel[worldid, bodyid] + ang = wp.spatial_top(cvel) + lin = wp.spatial_bottom(cvel) + xipos = d.xipos[worldid, bodyid] + ximat = d.ximat[worldid, bodyid] + subtree_com_root = d.subtree_com[worldid, m.body_rootid[bodyid]] - body_treeadr = m.body_treeadr.numpy() - for i in reversed(range(len(body_treeadr))): - beg = body_treeadr[i] - end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(_linear_momentum, dim=[d.nworld, end - beg], inputs=[m, d, beg]) + # update linear velocity + lin -= wp.cross(xipos - subtree_com_root, ang) - @kernel - def _angular_momentum(m: Model, d: Data, leveladr: int): - worldid, nodeid = wp.tid() - bodyid = m.body_tree[leveladr + nodeid] + d.subtree_linvel[worldid, bodyid] = m.body_mass[bodyid] * lin + dv = wp.transpose(ximat) @ ang + dv[0] *= m.body_inertia[bodyid][0] + dv[1] *= m.body_inertia[bodyid][1] + dv[2] *= m.body_inertia[bodyid][2] + d.subtree_angmom[worldid, bodyid] = ximat @ dv + d.subtree_bodyvel[worldid, bodyid] = wp.spatial_vector(ang, lin) - if bodyid == 0: - return + wp.launch(_forward, dim=(d.nworld, m.nbody), inputs=[m, d]) - pid = m.body_parentid[bodyid] + # sum body linear momentum recursively up the kinematic tree + @kernel + def _linear_momentum(m: Model, d: Data, leveladr: int): + worldid, nodeid = wp.tid() + bodyid = m.body_tree[leveladr + nodeid] + if bodyid: + pid = m.body_parentid[bodyid] + wp.atomic_add(d.subtree_linvel[worldid], pid, d.subtree_linvel[worldid, bodyid]) + d.subtree_linvel[worldid, bodyid] /= wp.max(MJ_MINVAL, m.body_subtreemass[bodyid]) + + body_treeadr = m.body_treeadr.numpy() + for i in reversed(range(len(body_treeadr))): + beg = body_treeadr[i] + end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] + wp.launch(_linear_momentum, dim=[d.nworld, end - beg], inputs=[m, d, beg]) - xipos = d.xipos[worldid, bodyid] - com = d.subtree_com[worldid, bodyid] - com_parent = d.subtree_com[worldid, pid] - vel = d.subtree_bodyvel[worldid, bodyid] - linvel = d.subtree_linvel[worldid, bodyid] - linvel_parent = d.subtree_linvel[worldid, pid] - mass = m.body_mass[bodyid] - subtreemass = m.body_subtreemass[bodyid] - - # momentum wrt body i - dx = xipos - com - dv = wp.spatial_bottom(vel) - linvel - dp = dv * mass - dL = wp.cross(dx, dp) - - # add to subtree i - d.subtree_angmom[worldid, bodyid] += dL - - # add to parent - wp.atomic_add(d.subtree_angmom[worldid], pid, d.subtree_angmom[worldid, bodyid]) - - # momentum wrt parent - dx = com - com_parent - dv = linvel - linvel_parent - dv *= subtreemass - dL = wp.cross(dx, dv) - wp.atomic_add(d.subtree_angmom[worldid], pid, dL) + @kernel + def _angular_momentum(m: Model, d: Data, leveladr: int): + worldid, nodeid = wp.tid() + bodyid = m.body_tree[leveladr + nodeid] - body_treeadr = m.body_treeadr.numpy() - for i in reversed(range(len(body_treeadr))): - beg = body_treeadr[i] - end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(_angular_momentum, dim=[d.nworld, end - beg], inputs=[m, d, beg]) + if bodyid == 0: + return + + pid = m.body_parentid[bodyid] + + xipos = d.xipos[worldid, bodyid] + com = d.subtree_com[worldid, bodyid] + com_parent = d.subtree_com[worldid, pid] + vel = d.subtree_bodyvel[worldid, bodyid] + linvel = d.subtree_linvel[worldid, bodyid] + linvel_parent = d.subtree_linvel[worldid, pid] + mass = m.body_mass[bodyid] + subtreemass = m.body_subtreemass[bodyid] + + # momentum wrt body i + dx = xipos - com + dv = wp.spatial_bottom(vel) - linvel + dp = dv * mass + dL = wp.cross(dx, dp) + + # add to subtree i + d.subtree_angmom[worldid, bodyid] += dL + + # add to parent + wp.atomic_add(d.subtree_angmom[worldid], pid, d.subtree_angmom[worldid, bodyid]) + + # momentum wrt parent + dx = com - com_parent + dv = linvel - linvel_parent + dv *= subtreemass + dL = wp.cross(dx, dv) + wp.atomic_add(d.subtree_angmom[worldid], pid, dL) + + body_treeadr = m.body_treeadr.numpy() + for i in reversed(range(len(body_treeadr))): + beg = body_treeadr[i] + end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] + wp.launch(_angular_momentum, dim=[d.nworld, end - beg], inputs=[m, d, beg]) def tendon(m: Model, d: Data): """Computes tendon lengths and moments.""" - if not m.ntendon: - return + with wp.ScopedDevice(m.qpos0.device): + if not m.ntendon: + return - d.ten_length.zero_() - d.ten_J.zero_() + d.ten_length.zero_() + d.ten_J.zero_() - # process joint tendons - if m.wrap_jnt_adr.size: + # process joint tendons + if m.wrap_jnt_adr.size: - @kernel - def _joint_tendon(m: Model, d: Data): - worldid, wrapid = wp.tid() + @kernel + def _joint_tendon(m: Model, d: Data): + worldid, wrapid = wp.tid() - tendon_jnt_adr = m.tendon_jnt_adr[wrapid] - wrap_jnt_adr = m.wrap_jnt_adr[wrapid] + tendon_jnt_adr = m.tendon_jnt_adr[wrapid] + wrap_jnt_adr = m.wrap_jnt_adr[wrapid] - wrap_objid = m.wrap_objid[wrap_jnt_adr] - prm = m.wrap_prm[wrap_jnt_adr] + wrap_objid = m.wrap_objid[wrap_jnt_adr] + prm = m.wrap_prm[wrap_jnt_adr] - # add to length - L = prm * d.qpos[worldid, m.jnt_qposadr[wrap_objid]] - # TODO(team): compare atomic_add and for loop - wp.atomic_add(d.ten_length[worldid], tendon_jnt_adr, L) + # add to length + L = prm * d.qpos[worldid, m.jnt_qposadr[wrap_objid]] + # TODO(team): compare atomic_add and for loop + wp.atomic_add(d.ten_length[worldid], tendon_jnt_adr, L) - # add to moment - d.ten_J[worldid, tendon_jnt_adr, m.jnt_dofadr[wrap_objid]] = prm + # add to moment + d.ten_J[worldid, tendon_jnt_adr, m.jnt_dofadr[wrap_objid]] = prm - wp.launch(_joint_tendon, dim=(d.nworld, m.wrap_jnt_adr.size), inputs=[m, d]) + wp.launch(_joint_tendon, dim=(d.nworld, m.wrap_jnt_adr.size), inputs=[m, d]) - # process spatial site tendons - if m.wrap_site_adr.size: - d.wrap_xpos.zero_() - d.wrap_obj.zero_() + # process spatial site tendons + if m.wrap_site_adr.size: + d.wrap_xpos.zero_() + d.wrap_obj.zero_() - N_SITE_PAIR = m.wrap_site_pair_adr.size + N_SITE_PAIR = m.wrap_site_pair_adr.size - @kernel - def _spatial_site_tendon(m: Model, d: Data): - worldid, elementid = wp.tid() - site_adr = m.wrap_site_adr[elementid] + @kernel + def _spatial_site_tendon(m: Model, d: Data): + worldid, elementid = wp.tid() + site_adr = m.wrap_site_adr[elementid] - site_xpos = d.site_xpos[worldid, m.wrap_objid[site_adr]] + site_xpos = d.site_xpos[worldid, m.wrap_objid[site_adr]] - rowid = elementid // 2 - colid = elementid % 2 - if colid == 0: - d.wrap_xpos[worldid, rowid][0] = site_xpos[0] - d.wrap_xpos[worldid, rowid][1] = site_xpos[1] - d.wrap_xpos[worldid, rowid][2] = site_xpos[2] - else: - d.wrap_xpos[worldid, rowid][3] = site_xpos[0] - d.wrap_xpos[worldid, rowid][4] = site_xpos[1] - d.wrap_xpos[worldid, rowid][5] = site_xpos[2] - - d.wrap_obj[worldid, rowid][colid] = -1 - - if elementid < N_SITE_PAIR: - # site pairs - site_pair_adr = m.wrap_site_pair_adr[elementid] - ten_adr = m.tendon_site_pair_adr[elementid] - - id0 = m.wrap_objid[site_pair_adr + 0] - id1 = m.wrap_objid[site_pair_adr + 1] - - pnt0 = d.site_xpos[worldid, id0] - pnt1 = d.site_xpos[worldid, id1] - dif = pnt1 - pnt0 - vec, length = math.normalize_with_norm(dif) - wp.atomic_add(d.ten_length[worldid], ten_adr, length) - - if length < MJ_MINVAL: - vec = wp.vec3(1.0, 0.0, 0.0) - - body0 = m.site_bodyid[id0] - body1 = m.site_bodyid[id1] - if body0 != body1: - for i in range(m.nv): - J = float(0.0) - jacp1, _ = support.jac(m, d, pnt0, body0, i, worldid) - jacp2, _ = support.jac(m, d, pnt1, body1, i, worldid) - dif = jacp2 - jacp1 - for xyz in range(3): - J += vec[xyz] * dif[xyz] - if J: - wp.atomic_add(d.ten_J[worldid, ten_adr], i, J) - - wp.launch(_spatial_site_tendon, dim=(d.nworld, m.wrap_site_adr.size), inputs=[m, d]) + rowid = elementid // 2 + colid = elementid % 2 + if colid == 0: + d.wrap_xpos[worldid, rowid][0] = site_xpos[0] + d.wrap_xpos[worldid, rowid][1] = site_xpos[1] + d.wrap_xpos[worldid, rowid][2] = site_xpos[2] + else: + d.wrap_xpos[worldid, rowid][3] = site_xpos[0] + d.wrap_xpos[worldid, rowid][4] = site_xpos[1] + d.wrap_xpos[worldid, rowid][5] = site_xpos[2] + + d.wrap_obj[worldid, rowid][colid] = -1 + + if elementid < N_SITE_PAIR: + # site pairs + site_pair_adr = m.wrap_site_pair_adr[elementid] + ten_adr = m.tendon_site_pair_adr[elementid] + + id0 = m.wrap_objid[site_pair_adr + 0] + id1 = m.wrap_objid[site_pair_adr + 1] + + pnt0 = d.site_xpos[worldid, id0] + pnt1 = d.site_xpos[worldid, id1] + dif = pnt1 - pnt0 + vec, length = math.normalize_with_norm(dif) + wp.atomic_add(d.ten_length[worldid], ten_adr, length) + + if length < MJ_MINVAL: + vec = wp.vec3(1.0, 0.0, 0.0) + + body0 = m.site_bodyid[id0] + body1 = m.site_bodyid[id1] + if body0 != body1: + for i in range(m.nv): + J = float(0.0) + jacp1, _ = support.jac(m, d, pnt0, body0, i, worldid) + jacp2, _ = support.jac(m, d, pnt1, body1, i, worldid) + dif = jacp2 - jacp1 + for xyz in range(3): + J += vec[xyz] * dif[xyz] + if J: + wp.atomic_add(d.ten_J[worldid, ten_adr], i, J) + + wp.launch( + _spatial_site_tendon, dim=(d.nworld, m.wrap_site_adr.size), inputs=[m, d] + ) - @kernel - def _spatial_tendon(m: Model, d: Data): - worldid, tenid = wp.tid() + @kernel + def _spatial_tendon(m: Model, d: Data): + worldid, tenid = wp.tid() - d.ten_wrapnum[worldid, tenid] = m.ten_wrapnum_site[tenid] - # TODO(team): geom wrap + d.ten_wrapnum[worldid, tenid] = m.ten_wrapnum_site[tenid] + # TODO(team): geom wrap - d.ten_wrapadr[worldid, tenid] = m.ten_wrapadr_site[tenid] - # TODO(team): geom wrap + d.ten_wrapadr[worldid, tenid] = m.ten_wrapadr_site[tenid] + # TODO(team): geom wrap - wp.launch(_spatial_tendon, dim=(d.nworld, m.ntendon), inputs=[m, d]) + wp.launch(_spatial_tendon, dim=(d.nworld, m.ntendon), inputs=[m, d]) - # TODO(team): geom wrap, pulleys + # TODO(team): geom wrap, pulleys diff --git a/mujoco_warp/_src/solver.py b/mujoco_warp/_src/solver.py index c0ce2d58..f8d0b2af 100644 --- a/mujoco_warp/_src/solver.py +++ b/mujoco_warp/_src/solver.py @@ -1533,129 +1533,132 @@ def _jaref(d: types.Data): @event_scope def solve(m: types.Model, d: types.Data): """Finds forces that satisfy constraints.""" - ITERATIONS = m.opt.iterations - - @kernel - def _zero_search_dot(d: types.Data): - worldid = wp.tid() - - if wp.static(m.opt.iterations) > 1: - if d.efc.done[worldid]: - return - - d.efc.search_dot[worldid] = 0.0 - - @kernel - def _search_update(d: types.Data): - worldid, dofid = wp.tid() - - if wp.static(m.opt.iterations) > 1: - if d.efc.done[worldid]: - return - - search = -1.0 * d.efc.Mgrad[worldid, dofid] - - if wp.static(m.opt.solver == types.SolverType.CG): - search += d.efc.beta[worldid] * d.efc.search[worldid, dofid] - - d.efc.search[worldid, dofid] = search - wp.atomic_add(d.efc.search_dot, worldid, search * search) - - @kernel - def _done(m: types.Model, d: types.Data, solver_niter: int): - # TODO(team): static m? - worldid = wp.tid() - - if ITERATIONS > 1: - if d.efc.done[worldid]: - return - - improvement = _rescale(m, d.efc.prev_cost[worldid] - d.efc.cost[worldid]) - gradient = _rescale(m, wp.math.sqrt(d.efc.grad_dot[worldid])) - d.efc.done[worldid] = (improvement < m.opt.tolerance) or ( - gradient < m.opt.tolerance - ) - - if m.opt.solver == types.SolverType.CG: - - @kernel - def _prev_grad_Mgrad(d: types.Data): - worldid, dofid = wp.tid() - - if wp.static(m.opt.iterations) > 1: - if d.efc.done[worldid]: - return - d.efc.prev_grad[worldid, dofid] = d.efc.grad[worldid, dofid] - d.efc.prev_Mgrad[worldid, dofid] = d.efc.Mgrad[worldid, dofid] + with wp.ScopedDevice(m.qpos0.device): + ITERATIONS = m.opt.iterations @kernel - def _zero_beta_num_den(d: types.Data): + def _zero_search_dot(d: types.Data): worldid = wp.tid() if wp.static(m.opt.iterations) > 1: if d.efc.done[worldid]: return - d.efc.beta_num[worldid] = 0.0 - d.efc.beta_den[worldid] = 0.0 + d.efc.search_dot[worldid] = 0.0 @kernel - def _beta_num_den(d: types.Data): + def _search_update(d: types.Data): worldid, dofid = wp.tid() if wp.static(m.opt.iterations) > 1: if d.efc.done[worldid]: return - prev_Mgrad = d.efc.prev_Mgrad[worldid][dofid] - wp.atomic_add( - d.efc.beta_num, - worldid, - d.efc.grad[worldid, dofid] * (d.efc.Mgrad[worldid, dofid] - prev_Mgrad), - ) - wp.atomic_add( - d.efc.beta_den, worldid, d.efc.prev_grad[worldid, dofid] * prev_Mgrad - ) + search = -1.0 * d.efc.Mgrad[worldid, dofid] + + if wp.static(m.opt.solver == types.SolverType.CG): + search += d.efc.beta[worldid] * d.efc.search[worldid, dofid] + + d.efc.search[worldid, dofid] = search + wp.atomic_add(d.efc.search_dot, worldid, search * search) @kernel - def _beta(d: types.Data): + def _done(m: types.Model, d: types.Data, solver_niter: int): + # TODO(team): static m? worldid = wp.tid() - if wp.static(m.opt.iterations) > 1: + if ITERATIONS > 1: if d.efc.done[worldid]: return - d.efc.beta[worldid] = wp.max( - 0.0, d.efc.beta_num[worldid] / wp.max(types.MJ_MINVAL, d.efc.beta_den[worldid]) + improvement = _rescale(m, d.efc.prev_cost[worldid] - d.efc.cost[worldid]) + gradient = _rescale(m, wp.math.sqrt(d.efc.grad_dot[worldid])) + d.efc.done[worldid] = (improvement < m.opt.tolerance) or ( + gradient < m.opt.tolerance ) - # warmstart - wp.copy(d.qacc, d.qacc_warmstart) + if m.opt.solver == types.SolverType.CG: - _create_context(m, d, grad=True) + @kernel + def _prev_grad_Mgrad(d: types.Data): + worldid, dofid = wp.tid() - for i in range(m.opt.iterations): - _linesearch(m, d) + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return - if m.opt.solver == types.SolverType.CG: - wp.launch(_prev_grad_Mgrad, dim=(d.nworld, m.nv), inputs=[d]) + d.efc.prev_grad[worldid, dofid] = d.efc.grad[worldid, dofid] + d.efc.prev_Mgrad[worldid, dofid] = d.efc.Mgrad[worldid, dofid] - _update_constraint(m, d) - _update_gradient(m, d) + @kernel + def _zero_beta_num_den(d: types.Data): + worldid = wp.tid() - # polak-ribiere - if m.opt.solver == types.SolverType.CG: - wp.launch(_zero_beta_num_den, dim=(d.nworld), inputs=[d]) + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + d.efc.beta_num[worldid] = 0.0 + d.efc.beta_den[worldid] = 0.0 + + @kernel + def _beta_num_den(d: types.Data): + worldid, dofid = wp.tid() + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + prev_Mgrad = d.efc.prev_Mgrad[worldid][dofid] + wp.atomic_add( + d.efc.beta_num, + worldid, + d.efc.grad[worldid, dofid] * (d.efc.Mgrad[worldid, dofid] - prev_Mgrad), + ) + wp.atomic_add( + d.efc.beta_den, worldid, d.efc.prev_grad[worldid, dofid] * prev_Mgrad + ) + + @kernel + def _beta(d: types.Data): + worldid = wp.tid() + + if wp.static(m.opt.iterations) > 1: + if d.efc.done[worldid]: + return + + d.efc.beta[worldid] = wp.max( + 0.0, + d.efc.beta_num[worldid] / wp.max(types.MJ_MINVAL, d.efc.beta_den[worldid]), + ) + + # warmstart + wp.copy(d.qacc, d.qacc_warmstart) + + _create_context(m, d, grad=True) + + for i in range(m.opt.iterations): + _linesearch(m, d) + + if m.opt.solver == types.SolverType.CG: + wp.launch(_prev_grad_Mgrad, dim=(d.nworld, m.nv), inputs=[d]) + + _update_constraint(m, d) + _update_gradient(m, d) + + # polak-ribiere + if m.opt.solver == types.SolverType.CG: + wp.launch(_zero_beta_num_den, dim=(d.nworld), inputs=[d]) - wp.launch(_beta_num_den, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_beta_num_den, dim=(d.nworld, m.nv), inputs=[d]) - wp.launch(_beta, dim=(d.nworld,), inputs=[d]) + wp.launch(_beta, dim=(d.nworld,), inputs=[d]) - wp.launch(_zero_search_dot, dim=(d.nworld), inputs=[d]) + wp.launch(_zero_search_dot, dim=(d.nworld), inputs=[d]) - wp.launch(_search_update, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_search_update, dim=(d.nworld, m.nv), inputs=[d]) - wp.launch(_done, dim=(d.nworld,), inputs=[m, d, i]) + wp.launch(_done, dim=(d.nworld,), inputs=[m, d, i]) - wp.copy(d.qacc_warmstart, d.qacc) + wp.copy(d.qacc_warmstart, d.qacc) diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index 847bce8b..3fd72e59 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -46,136 +46,141 @@ def mul_m( ): """Multiply vector by inertia matrix.""" - if not m.opt.is_sparse: + with wp.ScopedDevice(m.qpos0.device): + if not m.opt.is_sparse: + + def tile_mul(adr: int, size: int, tilesize: int): + # TODO(team): speed up kernel compile time (14s on 2023 Macbook Pro) + @kernel + def mul( + m: Model, + d: Data, + leveladr: int, + res: array3df, + vec: array3df, + skip: wp.array(ndim=1, dtype=bool), + ): + worldid, nodeid = wp.tid() + + if skip[worldid]: + return + + dofid = m.qLD_tile[leveladr + nodeid] + qM_tile = wp.tile_load( + d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) + ) + vec_tile = wp.tile_load(vec[worldid], shape=(tilesize, 1), offset=(dofid, 0)) + res_tile = wp.tile_zeros(shape=(tilesize, 1), dtype=wp.float32) + wp.tile_matmul(qM_tile, vec_tile, res_tile) + wp.tile_store(res[worldid], res_tile, offset=(dofid, 0)) + + wp.launch_tiled( + mul, + dim=(d.nworld, size), + inputs=[ + m, + d, + adr, + res.reshape(res.shape + (1,)), + vec.reshape(vec.shape + (1,)), + skip, + ], + # TODO(team): develop heuristic for block dim, or make configurable + block_dim=32, + ) + + qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() + + for i in range(len(qLD_tileadr)): + beg = qLD_tileadr[i] + end = m.qLD_tile.shape[0] if i == len(qLD_tileadr) - 1 else qLD_tileadr[i + 1] + tile_mul(beg, end - beg, int(qLD_tilesize[i])) + + else: - def tile_mul(adr: int, size: int, tilesize: int): - # TODO(team): speed up kernel compile time (14s on 2023 Macbook Pro) @kernel - def mul( + def _mul_m_sparse_diag( m: Model, d: Data, - leveladr: int, - res: array3df, - vec: array3df, + res: wp.array(ndim=2, dtype=wp.float32), + vec: wp.array(ndim=2, dtype=wp.float32), skip: wp.array(ndim=1, dtype=bool), ): - worldid, nodeid = wp.tid() + worldid, dofid = wp.tid() if skip[worldid]: return - dofid = m.qLD_tile[leveladr + nodeid] - qM_tile = wp.tile_load( - d.qM[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) - ) - vec_tile = wp.tile_load(vec[worldid], shape=(tilesize, 1), offset=(dofid, 0)) - res_tile = wp.tile_zeros(shape=(tilesize, 1), dtype=wp.float32) - wp.tile_matmul(qM_tile, vec_tile, res_tile) - wp.tile_store(res[worldid], res_tile, offset=(dofid, 0)) - - wp.launch_tiled( - mul, - dim=(d.nworld, size), - inputs=[ - m, - d, - adr, - res.reshape(res.shape + (1,)), - vec.reshape(vec.shape + (1,)), - skip, - ], - # TODO(team): develop heuristic for block dim, or make configurable - block_dim=32, - ) - - qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() - - for i in range(len(qLD_tileadr)): - beg = qLD_tileadr[i] - end = m.qLD_tile.shape[0] if i == len(qLD_tileadr) - 1 else qLD_tileadr[i + 1] - tile_mul(beg, end - beg, int(qLD_tilesize[i])) - - else: - - @kernel - def _mul_m_sparse_diag( - m: Model, - d: Data, - res: wp.array(ndim=2, dtype=wp.float32), - vec: wp.array(ndim=2, dtype=wp.float32), - skip: wp.array(ndim=1, dtype=bool), - ): - worldid, dofid = wp.tid() - - if skip[worldid]: - return - - res[worldid, dofid] = d.qM[worldid, 0, m.dof_Madr[dofid]] * vec[worldid, dofid] + res[worldid, dofid] = d.qM[worldid, 0, m.dof_Madr[dofid]] * vec[worldid, dofid] - @kernel - def _mul_m_sparse_ij( - m: Model, - d: Data, - res: wp.array(ndim=2, dtype=wp.float32), - vec: wp.array(ndim=2, dtype=wp.float32), - skip: wp.array(ndim=1, dtype=bool), - ): - worldid, elementid = wp.tid() + @kernel + def _mul_m_sparse_ij( + m: Model, + d: Data, + res: wp.array(ndim=2, dtype=wp.float32), + vec: wp.array(ndim=2, dtype=wp.float32), + skip: wp.array(ndim=1, dtype=bool), + ): + worldid, elementid = wp.tid() - if skip[worldid]: - return + if skip[worldid]: + return - i = m.qM_mulm_i[elementid] - j = m.qM_mulm_j[elementid] - madr_ij = m.qM_madr_ij[elementid] + i = m.qM_mulm_i[elementid] + j = m.qM_mulm_j[elementid] + madr_ij = m.qM_madr_ij[elementid] - qM = d.qM[worldid, 0, madr_ij] + qM = d.qM[worldid, 0, madr_ij] - wp.atomic_add(res[worldid], i, qM * vec[worldid, j]) - wp.atomic_add(res[worldid], j, qM * vec[worldid, i]) + wp.atomic_add(res[worldid], i, qM * vec[worldid, j]) + wp.atomic_add(res[worldid], j, qM * vec[worldid, i]) - wp.launch(_mul_m_sparse_diag, dim=(d.nworld, m.nv), inputs=[m, d, res, vec, skip]) + wp.launch(_mul_m_sparse_diag, dim=(d.nworld, m.nv), inputs=[m, d, res, vec, skip]) - wp.launch( - _mul_m_sparse_ij, - dim=(d.nworld, m.qM_madr_ij.size), - inputs=[m, d, res, vec, skip], - ) + wp.launch( + _mul_m_sparse_ij, + dim=(d.nworld, m.qM_madr_ij.size), + inputs=[m, d, res, vec, skip], + ) @event_scope def xfrc_accumulate(m: Model, d: Data, qfrc: array2df): - @wp.kernel - def _accumulate(m: Model, d: Data, qfrc: array2df): - worldid, dofid = wp.tid() - cdof = d.cdof[worldid, dofid] - rotational_cdof = wp.vec3(cdof[0], cdof[1], cdof[2]) - jac = wp.spatial_vector(cdof[3], cdof[4], cdof[5], cdof[0], cdof[1], cdof[2]) - - dof_bodyid = m.dof_bodyid[dofid] - accumul = float(0.0) - - for bodyid in range(dof_bodyid, m.nbody): - # any body that is in the subtree of dof_bodyid is part of the jacobian - parentid = bodyid - while parentid != 0 and parentid != dof_bodyid: - parentid = m.body_parentid[parentid] - if parentid == 0: - continue # body is not part of the subtree - offset = d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]] - cross_term = wp.cross(rotational_cdof, offset) - accumul += wp.dot(jac, d.xfrc_applied[worldid, bodyid]) + wp.dot( - cross_term, - wp.vec3( - d.xfrc_applied[worldid, bodyid][0], - d.xfrc_applied[worldid, bodyid][1], - d.xfrc_applied[worldid, bodyid][2], - ), - ) + with wp.ScopedDevice(m.qpos0.device): + + @wp.kernel + def _accumulate(m: Model, d: Data, qfrc: array2df): + worldid, dofid = wp.tid() + cdof = d.cdof[worldid, dofid] + rotational_cdof = wp.vec3(cdof[0], cdof[1], cdof[2]) + jac = wp.spatial_vector(cdof[3], cdof[4], cdof[5], cdof[0], cdof[1], cdof[2]) + + dof_bodyid = m.dof_bodyid[dofid] + accumul = float(0.0) + + for bodyid in range(dof_bodyid, m.nbody): + # any body that is in the subtree of dof_bodyid is part of the jacobian + parentid = bodyid + while parentid != 0 and parentid != dof_bodyid: + parentid = m.body_parentid[parentid] + if parentid == 0: + continue # body is not part of the subtree + offset = ( + d.xipos[worldid, bodyid] - d.subtree_com[worldid, m.body_rootid[bodyid]] + ) + cross_term = wp.cross(rotational_cdof, offset) + accumul += wp.dot(jac, d.xfrc_applied[worldid, bodyid]) + wp.dot( + cross_term, + wp.vec3( + d.xfrc_applied[worldid, bodyid][0], + d.xfrc_applied[worldid, bodyid][1], + d.xfrc_applied[worldid, bodyid][2], + ), + ) - qfrc[worldid, dofid] += accumul + qfrc[worldid, dofid] += accumul - wp.launch(kernel=_accumulate, dim=(d.nworld, m.nv), inputs=[m, d, qfrc]) + wp.launch(kernel=_accumulate, dim=(d.nworld, m.nv), inputs=[m, d, qfrc]) @wp.func diff --git a/mujoco_warp/_src/test_util.py b/mujoco_warp/_src/test_util.py index 6d1fd397..c946a6c5 100644 --- a/mujoco_warp/_src/test_util.py +++ b/mujoco_warp/_src/test_util.py @@ -116,36 +116,38 @@ def benchmark( measure_alloc: bool = False, ) -> Tuple[float, float, dict, list, list]: """Benchmark a function of Model and Data.""" - jit_beg = time.perf_counter() - - fn(m, d) - - jit_end = time.perf_counter() - jit_duration = jit_end - jit_beg - wp.synchronize() - - trace = {} - ncon, nefc = [], [] - - with warp_util.EventTracer(enabled=event_trace) as tracer: - # capture the whole function as a CUDA graph - with wp.ScopedCapture() as capture: - fn(m, d) - graph = capture.graph - - run_beg = time.perf_counter() - for _ in range(nstep): - wp.capture_launch(graph) - if trace: - trace = _sum(trace, tracer.trace()) - else: - trace = tracer.trace() - if measure_alloc: - wp.synchronize() - ncon.append(d.ncon.numpy()[0]) - nefc.append(d.nefc.numpy()[0]) + + with wp.ScopedDevice(m.qpos0.device): + jit_beg = time.perf_counter() + + fn(m, d) + + jit_end = time.perf_counter() + jit_duration = jit_end - jit_beg wp.synchronize() - run_end = time.perf_counter() - run_duration = run_end - run_beg - return jit_duration, run_duration, trace, ncon, nefc + trace = {} + ncon, nefc = [], [] + + with warp_util.EventTracer(enabled=event_trace) as tracer: + # capture the whole function as a CUDA graph + with wp.ScopedCapture() as capture: + fn(m, d) + graph = capture.graph + + run_beg = time.perf_counter() + for _ in range(nstep): + wp.capture_launch(graph) + if trace: + trace = _sum(trace, tracer.trace()) + else: + trace = tracer.trace() + if measure_alloc: + wp.synchronize() + ncon.append(d.ncon.numpy()[0]) + nefc.append(d.nefc.numpy()[0]) + wp.synchronize() + run_end = time.perf_counter() + run_duration = run_end - run_beg + + return jit_duration, run_duration, trace, ncon, nefc