diff --git a/mujoco_warp/_src/collision_box.py b/mujoco_warp/_src/collision_box.py index a7ba6270..bff02cf1 100644 --- a/mujoco_warp/_src/collision_box.py +++ b/mujoco_warp/_src/collision_box.py @@ -565,4 +565,5 @@ def box_box_narrowphase( dim=num_threads, inputs=[m, d, num_threads], block_dim=BOX_BOX_BLOCK_DIM, + device=m.device, ) diff --git a/mujoco_warp/_src/collision_convex.py b/mujoco_warp/_src/collision_convex.py index 1e6df84c..beda82f0 100644 --- a/mujoco_warp/_src/collision_convex.py +++ b/mujoco_warp/_src/collision_convex.py @@ -777,4 +777,4 @@ def gjk_narrowphase(m: Model, d: Data): ) for collision_kernel in _collision_kernels.values(): - wp.launch(collision_kernel, dim=d.nconmax, inputs=[m, d]) + wp.launch(collision_kernel, dim=d.nconmax, inputs=[m, d], device=m.device) diff --git a/mujoco_warp/_src/collision_driver.py b/mujoco_warp/_src/collision_driver.py index 0b64bada..8c50b61f 100644 --- a/mujoco_warp/_src/collision_driver.py +++ b/mujoco_warp/_src/collision_driver.py @@ -360,6 +360,7 @@ def sap_broadphase(m: Model, d: Data): kernel=broadphase_project_spheres_onto_sweep_direction_kernel, dim=(d.nworld, m.ngeom), inputs=[m, d, direction], + device=m.device, ) tile_sort_available = False @@ -368,7 +369,11 @@ def sap_broadphase(m: Model, d: Data): if tile_sort_available: segmented_sort_kernel = create_segmented_sort_kernel(m.ngeom) wp.launch_tiled( - kernel=segmented_sort_kernel, dim=(d.nworld), inputs=[m, d], block_dim=128 + kernel=segmented_sort_kernel, + dim=(d.nworld), + inputs=[m, d], + block_dim=128, + device=m.device, ) print("tile sort available") elif segmented_sort_available: @@ -385,12 +390,10 @@ def sap_broadphase(m: Model, d: Data): # Create temporary arrays for sorting temp_box_projections_lower = wp.zeros( - m.ngeom * 2, - dtype=d.sap_projection_lower.dtype, + m.ngeom * 2, dtype=d.sap_projection_lower.dtype, device=m.device ) temp_box_sorting_indexer = wp.zeros( - m.ngeom * 2, - dtype=d.sap_sort_index.dtype, + m.ngeom * 2, dtype=d.sap_sort_index.dtype, device=m.device ) # Copy data to temporary arrays @@ -434,12 +437,14 @@ def sap_broadphase(m: Model, d: Data): kernel=reorder_bounding_spheres_kernel, dim=(d.nworld, m.ngeom), inputs=[m, d], + device=m.device, ) wp.launch( kernel=sap_broadphase_prepare_kernel, dim=(d.nworld, m.ngeom), inputs=[m, d], + device=m.device, ) # The scan (scan = cumulative sum, either inclusive or exclusive depending on the last argument) is used for load balancing among the threads @@ -452,6 +457,7 @@ def sap_broadphase(m: Model, d: Data): kernel=sap_broadphase_kernel, dim=num_sweep_threads, inputs=[m, d, num_sweep_threads, filter_parent], + device=m.device, ) return d @@ -510,7 +516,10 @@ def _nxn_broadphase(m: Model, d: Data): _add_geom_pair(m, d, geom1, geom2, worldid) wp.launch( - _nxn_broadphase, dim=(d.nworld, m.ngeom * (m.ngeom - 1) // 2), inputs=[m, d] + _nxn_broadphase, + dim=(d.nworld, m.ngeom * (m.ngeom - 1) // 2), + inputs=[m, d], + device=m.device, ) @@ -519,6 +528,7 @@ def get_contact_solver_params(m: Model, d: Data): get_contact_solver_params_kernel, dim=[d.nconmax], inputs=[m, d], + device=m.device, ) # TODO(team): do we need condim sorting, deepest penetrating contact here? diff --git a/mujoco_warp/_src/collision_primitive.py b/mujoco_warp/_src/collision_primitive.py index 6a997b34..779a4a5f 100644 --- a/mujoco_warp/_src/collision_primitive.py +++ b/mujoco_warp/_src/collision_primitive.py @@ -384,4 +384,4 @@ def _primitive_narrowphase( def primitive_narrowphase(m: Model, d: Data): # we need to figure out how to keep the overhead of this small - not launching anything # for pair types without collisions, as well as updating the launch dimensions. - wp.launch(_primitive_narrowphase, dim=d.nconmax, inputs=[m, d]) + wp.launch(_primitive_narrowphase, dim=d.nconmax, inputs=[m, d], device=m.device) diff --git a/mujoco_warp/_src/constraint.py b/mujoco_warp/_src/constraint.py index 206d8d50..8a93a5ac 100644 --- a/mujoco_warp/_src/constraint.py +++ b/mujoco_warp/_src/constraint.py @@ -350,6 +350,7 @@ def make_constraint(m: types.Model, d: types.Data): _efc_limit_slide_hinge, dim=(d.nworld, m.jnt_limited_slide_hinge_adr.size), inputs=[m, d, refsafe], + device=m.device, ) # contact @@ -359,9 +360,15 @@ def make_constraint(m: types.Model, d: types.Data): _efc_contact_pyramidal, dim=(d.nconmax, 2 * (m.condim_max - 1)), inputs=[m, d, refsafe], + device=m.device, ) elif m.opt.cone == types.ConeType.ELLIPTIC.value: - wp.launch(_efc_contact_elliptic, dim=(d.nconmax, 3), inputs=[m, d, refsafe]) + wp.launch( + _efc_contact_elliptic, + dim=(d.nconmax, 3), + inputs=[m, d, refsafe], + device=m.device, + ) # TODO(team): condim=4 # TODO(team): condim=6 diff --git a/mujoco_warp/_src/forward.py b/mujoco_warp/_src/forward.py index 03d4b288..45ecb600 100644 --- a/mujoco_warp/_src/forward.py +++ b/mujoco_warp/_src/forward.py @@ -171,9 +171,13 @@ def integrate_joint_positions(m: Model, d: Data, qvel_in: array2df): # skip if no stateful actuators. if m.na: - wp.launch(next_activation, dim=(d.nworld, m.nu), inputs=[m, d, act_dot]) + wp.launch( + next_activation, dim=(d.nworld, m.nu), inputs=[m, d, act_dot], device=m.device + ) - wp.launch(advance_velocities, dim=(d.nworld, m.nv), inputs=[m, d, qacc]) + wp.launch( + advance_velocities, dim=(d.nworld, m.nv), inputs=[m, d, qacc], device=m.device + ) # advance positions with qvel if given, d.qvel otherwise (semi-implicit) if qvel is not None: @@ -181,7 +185,12 @@ def integrate_joint_positions(m: Model, d: Data, qvel_in: array2df): else: qvel_in = d.qvel - wp.launch(integrate_joint_positions, dim=(d.nworld, m.njnt), inputs=[m, d, qvel_in]) + wp.launch( + integrate_joint_positions, + dim=(d.nworld, m.njnt), + inputs=[m, d, qvel_in], + device=m.device, + ) d.time = d.time + m.opt.timestep @@ -204,8 +213,13 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data): d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid] ) - kernel_copy(d.qM_integration, d.qM) - wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(d.nworld, m.nv), inputs=[m, d]) + kernel_copy(d.qM_integration, d.qM, m.device) + wp.launch( + add_damping_sum_qfrc_kernel_sparse, + dim=(d.nworld, m.nv), + inputs=[m, d], + device=m.device, + ) smooth.factor_solve_i( m, d, @@ -245,7 +259,11 @@ def eulerdamp( wp.tile_store(d.qacc_integration[worldid], qacc_tile, offset=(dofid)) wp.launch_tiled( - eulerdamp, dim=(d.nworld, size), inputs=[m, d, m.dof_damping, adr], block_dim=32 + eulerdamp, + dim=(d.nworld, size), + inputs=[m, d, m.dof_damping, adr], + block_dim=32, + device=m.device, ) qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() @@ -290,8 +308,8 @@ def _act_dot(d: Data, b: float): worldId, tid = wp.tid() d.act_dot_rk[worldId, tid] += b * d.act_dot[worldId, tid] - wp.launch(_qvel_acc, dim=(d.nworld, m.nv), inputs=[d, b]) - wp.launch(_act_dot, dim=(d.nworld, m.na), inputs=[d, b]) + wp.launch(_qvel_acc, dim=(d.nworld, m.nv), inputs=[d, b], device=m.device) + wp.launch(_act_dot, dim=(d.nworld, m.na), inputs=[d, b], device=m.device) def perturb_state(m: Model, d: Data, a: float): @kernel @@ -312,9 +330,9 @@ def _qvel(m: Model, d: Data): dqacc = a * d.qacc[worldId, tid] d.qvel[worldId, tid] = d.qvel_t0[worldId, tid] + dqacc * m.opt.timestep - wp.launch(_qpos, dim=(d.nworld, m.njnt), inputs=[m, d]) - wp.launch(_act, dim=(d.nworld, m.na), inputs=[m, d]) - wp.launch(_qvel, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(_qpos, dim=(d.nworld, m.njnt), inputs=[m, d], device=m.device) + wp.launch(_act, dim=(d.nworld, m.na), inputs=[m, d], device=m.device) + wp.launch(_qvel, dim=(d.nworld, m.nv), inputs=[m, d], device=m.device) rk_accumulate(d, B[0]) for i in range(3): @@ -457,6 +475,7 @@ def qderiv_actuator_fused_kernel( dim=(d.nworld, size), inputs=[m, d, damping, adr], block_dim=block_dim, + device=m.device, ) qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy() @@ -479,6 +498,7 @@ def qderiv_actuator_fused_kernel( actuator_bias_gain_vel, dim=(d.nworld, m.nu), inputs=[m, d], + device=m.device, ) qderiv_actuator_damping_fused(m, d, m.dof_damping) @@ -522,7 +542,9 @@ def _actuator_velocity(d: Data): qvel = d.qvel[worldid] wp.atomic_add(d.actuator_velocity[worldid], actid, moment[dofid] * qvel[dofid]) - wp.launch(_actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d]) + wp.launch( + _actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d], device=m.device + ) else: def actuator_velocity( @@ -561,6 +583,7 @@ def _actuator_velocity( d.qvel.reshape(d.qvel.shape + (1,)), ], block_dim=32, + device=m.device, ) actuator_moment_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() @@ -655,7 +678,13 @@ def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df): s = wp.clamp(s, r[0], r[1]) qfrc[worldid, vid] = s - wp.launch(_force, dim=[d.nworld, m.nu], inputs=[m, d], outputs=[d.actuator_force]) + wp.launch( + _force, + dim=[d.nworld, m.nu], + inputs=[m, d], + outputs=[d.actuator_force], + device=m.device, + ) if m.opt.is_sparse: # TODO(team): sparse version @@ -665,6 +694,7 @@ def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df): dim=(d.nworld, m.nv), inputs=[m, d.actuator_moment, d.actuator_force], outputs=[d.qfrc_actuator], + device=m.device, ) else: @@ -706,6 +736,7 @@ def qfrc_actuator_kernel( d.actuator_force.reshape(d.actuator_force.shape + (1,)), ], block_dim=32, + device=m.device, ) qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy() @@ -722,7 +753,7 @@ def qfrc_actuator_kernel( beg, end - beg, int(qderiv_tilesize_nu[i]), int(qderiv_tilesize_nv[i]) ) - wp.launch(_qfrc_limited, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(_qfrc_limited, dim=(d.nworld, m.nv), inputs=[m, d], device=m.device) # TODO actuator-level gravity compensation, skip if added as passive force @@ -741,7 +772,7 @@ def _qfrc_smooth(d: Data): + d.qfrc_applied[worldid, dofid] ) - wp.launch(_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d], device=m.device) xfrc_accumulate(m, d, d.qfrc_smooth) smooth.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth) @@ -760,7 +791,7 @@ def forward(m: Model, d: Data): sensor.sensor_acc(m, d) if d.njmax == 0: - kernel_copy(d.qacc, d.qacc_smooth) + kernel_copy(d.qacc, d.qacc_smooth, m.device) else: solver.solve(m, d) diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index 315e3c5b..35f0fa66 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -24,7 +24,9 @@ from . import types -def put_model(mjm: mujoco.MjModel) -> types.Model: +def put_model( + mjm: mujoco.MjModel, device: Optional[wp.context.Device] = None +) -> types.Model: # check supported features for field, field_types, field_str in ( (mjm.actuator_trntype, types.TrnType, "Actuator transmission type"), @@ -72,398 +74,403 @@ def put_model(mjm: mujoco.MjModel) -> types.Model: m = types.Model() - m.nq = mjm.nq - m.nv = mjm.nv - m.na = mjm.na - m.nu = mjm.nu - m.nbody = mjm.nbody - m.njnt = mjm.njnt - m.ngeom = mjm.ngeom - m.nsite = mjm.nsite - m.ncam = mjm.ncam - m.nlight = mjm.nlight - m.nmocap = mjm.nmocap - m.nM = mjm.nM - m.ntendon = mjm.ntendon - m.nwrap = mjm.nwrap - m.nsensor = mjm.nsensor - m.nsensordata = mjm.nsensordata - m.nlsp = mjm.opt.ls_iterations # TODO(team): how to set nlsp? - m.nexclude = mjm.nexclude - m.opt.timestep = mjm.opt.timestep - m.opt.tolerance = mjm.opt.tolerance - m.opt.ls_tolerance = mjm.opt.ls_tolerance - m.opt.gravity = wp.vec3(mjm.opt.gravity) - m.opt.cone = mjm.opt.cone - m.opt.solver = mjm.opt.solver - m.opt.iterations = mjm.opt.iterations - m.opt.ls_iterations = mjm.opt.ls_iterations - m.opt.integrator = mjm.opt.integrator - m.opt.disableflags = mjm.opt.disableflags - m.opt.impratio = wp.float32(mjm.opt.impratio) - m.opt.is_sparse = support.is_sparse(mjm) - m.opt.ls_parallel = False - # TODO(team) Figure out good default parameters - m.opt.gjk_iteration_count = wp.int32(1) # warp only - m.opt.epa_iteration_count = wp.int32(12) # warp only - m.opt.epa_exact_neg_distance = wp.bool(False) # warp only - m.opt.depth_extension = wp.float32(0.1) # warp only - m.stat.meaninertia = mjm.stat.meaninertia - - m.qpos0 = wp.array(mjm.qpos0, dtype=wp.float32, ndim=1) - m.qpos_spring = wp.array(mjm.qpos_spring, dtype=wp.float32, ndim=1) - - # dof lower triangle row and column indices - dof_tri_row, dof_tri_col = np.tril_indices(mjm.nv) - - # indices for sparse qM full_m - is_, js = [], [] - for i in range(mjm.nv): - j = i - while j > -1: - is_.append(i) - js.append(j) - j = mjm.dof_parentid[j] - qM_fullm_i = is_ - qM_fullm_j = js - - # indices for sparse qM mul_m - is_, js, madr_ijs = [], [], [] - for i in range(mjm.nv): - madr_ij, j = mjm.dof_Madr[i], i - - while True: - madr_ij, j = madr_ij + 1, mjm.dof_parentid[j] - if j == -1: - break - is_, js, madr_ijs = is_ + [i], js + [j], madr_ijs + [madr_ij] - - qM_mulm_i, qM_mulm_j, qM_madr_ij = ( - np.array(x, dtype=np.int32) for x in (is_, js, madr_ijs) - ) - - jnt_limited_slide_hinge_adr = np.nonzero( - mjm.jnt_limited - & ( - (mjm.jnt_type == mujoco.mjtJoint.mjJNT_SLIDE) - | (mjm.jnt_type == mujoco.mjtJoint.mjJNT_HINGE) - ) - )[0] - - # body_tree is BFS ordering of body ids - # body_treeadr contains starting index of each body tree level - bodies, body_depth = {}, np.zeros(mjm.nbody, dtype=int) - 1 - for i in range(mjm.nbody): - body_depth[i] = body_depth[mjm.body_parentid[i]] + 1 - bodies.setdefault(body_depth[i], []).append(i) - body_tree = np.concatenate([bodies[i] for i in range(len(bodies))]) - tree_off = [0] + [len(bodies[i]) for i in range(len(bodies))] - body_treeadr = np.cumsum(tree_off)[:-1] - - m.body_tree = wp.array(body_tree, dtype=wp.int32, ndim=1) - m.body_treeadr = wp.array(body_treeadr, dtype=wp.int32, ndim=1, device="cpu") - - qLD_update_tree = np.empty(shape=(0, 3), dtype=int) - qLD_update_treeadr = np.empty(shape=(0,), dtype=int) - qLD_tile = np.empty(shape=(0,), dtype=int) - qLD_tileadr = np.empty(shape=(0,), dtype=int) - qLD_tilesize = np.empty(shape=(0,), dtype=int) - - if support.is_sparse(mjm): - # qLD_update_tree has dof tree ordering of qLD updates for sparse factor m - # qLD_update_treeadr contains starting index of each dof tree level - mjd = mujoco.MjData(mjm) - if version.parse(mujoco.__version__) > version.parse("3.2.7"): - m.M_rownnz = wp.array(mjd.M_rownnz, dtype=wp.int32, ndim=1) - m.M_rowadr = wp.array(mjd.M_rowadr, dtype=wp.int32, ndim=1) - m.M_colind = wp.array(mjd.M_colind, dtype=wp.int32, ndim=1) - m.mapM2M = wp.array(mjd.mapM2M, dtype=wp.int32, ndim=1) - qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1 - - rownnz = mjd.M_rownnz - rowadr = mjd.M_rowadr - - for k in range(mjm.nv): - dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1 - i = mjm.dof_parentid[k] - diag_k = rowadr[k] + rownnz[k] - 1 - Madr_ki = diag_k - 1 - while i > -1: - qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki)) - i = mjm.dof_parentid[i] - Madr_ki -= 1 - - qLD_update_tree = np.concatenate( - [qLD_updates[i] for i in range(len(qLD_updates))] - ) - tree_off = [0] + [len(qLD_updates[i]) for i in range(len(qLD_updates))] - qLD_update_treeadr = np.cumsum(tree_off)[:-1] - else: - qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1 - for k in range(mjm.nv): - dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1 - i = mjm.dof_parentid[k] - Madr_ki = mjm.dof_Madr[k] + 1 - while i > -1: - qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki)) - i = mjm.dof_parentid[i] - Madr_ki += 1 - - # qLD_treeadr contains starting indicies of each level of sparse updates - qLD_update_tree = np.concatenate( - [qLD_updates[i] for i in range(len(qLD_updates))] - ) - tree_off = [0] + [len(qLD_updates[i]) for i in range(len(qLD_updates))] - qLD_update_treeadr = np.cumsum(tree_off)[:-1] + with wp.ScopedDevice(device): + m.nq = mjm.nq + m.nv = mjm.nv + m.na = mjm.na + m.nu = mjm.nu + m.nbody = mjm.nbody + m.njnt = mjm.njnt + m.ngeom = mjm.ngeom + m.nsite = mjm.nsite + m.ncam = mjm.ncam + m.nlight = mjm.nlight + m.nmocap = mjm.nmocap + m.nM = mjm.nM + m.ntendon = mjm.ntendon + m.nwrap = mjm.nwrap + m.nsensor = mjm.nsensor + m.nsensordata = mjm.nsensordata + m.nlsp = mjm.opt.ls_iterations # TODO(team): how to set nlsp? + m.nexclude = mjm.nexclude + m.opt.timestep = mjm.opt.timestep + m.opt.tolerance = mjm.opt.tolerance + m.opt.ls_tolerance = mjm.opt.ls_tolerance + m.opt.gravity = wp.vec3(mjm.opt.gravity) + m.opt.cone = mjm.opt.cone + m.opt.solver = mjm.opt.solver + m.opt.iterations = mjm.opt.iterations + m.opt.ls_iterations = mjm.opt.ls_iterations + m.opt.integrator = mjm.opt.integrator + m.opt.disableflags = mjm.opt.disableflags + m.opt.impratio = wp.float32(mjm.opt.impratio) + m.opt.is_sparse = support.is_sparse(mjm) + m.opt.ls_parallel = False + # TODO(team) Figure out good default parameters + m.opt.gjk_iteration_count = wp.int32(1) # warp only + m.opt.epa_iteration_count = wp.int32(12) # warp only + m.opt.epa_exact_neg_distance = wp.bool(False) # warp only + m.opt.depth_extension = wp.float32(0.1) # warp only + m.stat.meaninertia = mjm.stat.meaninertia + + m.qpos0 = wp.array(mjm.qpos0, dtype=wp.float32, ndim=1) + m.qpos_spring = wp.array(mjm.qpos_spring, dtype=wp.float32, ndim=1) + + # dof lower triangle row and column indices + dof_tri_row, dof_tri_col = np.tril_indices(mjm.nv) + + # indices for sparse qM full_m + is_, js = [], [] + for i in range(mjm.nv): + j = i + while j > -1: + is_.append(i) + js.append(j) + j = mjm.dof_parentid[j] + qM_fullm_i = is_ + qM_fullm_j = js - else: - # qLD_tile has the dof id of each tile in qLD for dense factor m - # qLD_tileadr contains starting index in qLD_tile of each tile group - # qLD_tilesize has the square tile size of each tile group - tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1] - tiles = {} - for i in range(len(tile_corners)): - tile_beg = tile_corners[i] - tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1] - tiles.setdefault(tile_end - tile_beg, []).append(tile_beg) - qLD_tile = np.concatenate([tiles[sz] for sz in sorted(tiles.keys())]) - tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())] - qLD_tileadr = np.cumsum(tile_off)[:-1] - qLD_tilesize = np.array(sorted(tiles.keys())) - - # tiles for actuator_moment - needs nu + nv tile size and offset - actuator_moment_offset_nv = np.empty(shape=(0,), dtype=int) - actuator_moment_offset_nu = np.empty(shape=(0,), dtype=int) - actuator_moment_tileadr = np.empty(shape=(0,), dtype=int) - actuator_moment_tilesize_nv = np.empty(shape=(0,), dtype=int) - actuator_moment_tilesize_nu = np.empty(shape=(0,), dtype=int) - - if not support.is_sparse(mjm): - # how many actuators for each tree - tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1] - tree_id = mjm.dof_treeid[tile_corners] - num_trees = int(np.max(tree_id)) - tree = mjm.body_treeid[mjm.jnt_bodyid[mjm.actuator_trnid[:, 0]]] - counts, ids = np.histogram(tree, bins=np.arange(0, num_trees + 2)) - acts_per_tree = dict(zip([int(i) for i in ids], [int(i) for i in counts])) - - tiles = {} - act_beg = 0 - for i in range(len(tile_corners)): - tile_beg = tile_corners[i] - tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1] - tree = int(tree_id[i]) - act_num = acts_per_tree[tree] - tiles.setdefault((tile_end - tile_beg, act_num), []).append((tile_beg, act_beg)) - act_beg += act_num - - sorted_keys = sorted(tiles.keys()) - actuator_moment_offset_nv = [ - t[0] for key in sorted_keys for t in tiles.get(key, []) - ] - actuator_moment_offset_nu = [ - t[1] for key in sorted_keys for t in tiles.get(key, []) - ] - tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())] - actuator_moment_tileadr = np.cumsum(tile_off)[:-1] # offset - actuator_moment_tilesize_nv = np.array( - [a[0] for a in sorted_keys] - ) # for this level - actuator_moment_tilesize_nu = np.array( - [int(a[1]) for a in sorted_keys] - ) # for this level - - m.qM_fullm_i = wp.array(qM_fullm_i, dtype=wp.int32, ndim=1) - m.qM_fullm_j = wp.array(qM_fullm_j, dtype=wp.int32, ndim=1) - m.qM_mulm_i = wp.array(qM_mulm_i, dtype=wp.int32, ndim=1) - m.qM_mulm_j = wp.array(qM_mulm_j, dtype=wp.int32, ndim=1) - m.qM_madr_ij = wp.array(qM_madr_ij, dtype=wp.int32, ndim=1) - m.qLD_update_tree = wp.array(qLD_update_tree, dtype=wp.vec3i, ndim=1) - m.qLD_update_treeadr = wp.array( - qLD_update_treeadr, dtype=wp.int32, ndim=1, device="cpu" - ) - m.qLD_tile = wp.array(qLD_tile, dtype=wp.int32, ndim=1) - m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, ndim=1, device="cpu") - m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, ndim=1, device="cpu") - m.actuator_moment_offset_nv = wp.array( - actuator_moment_offset_nv, dtype=wp.int32, ndim=1 - ) - m.actuator_moment_offset_nu = wp.array( - actuator_moment_offset_nu, dtype=wp.int32, ndim=1 - ) - m.actuator_moment_tileadr = wp.array( - actuator_moment_tileadr, dtype=wp.int32, ndim=1, device="cpu" - ) - m.actuator_moment_tilesize_nv = wp.array( - actuator_moment_tilesize_nv, dtype=wp.int32, ndim=1, device="cpu" - ) - m.actuator_moment_tilesize_nu = wp.array( - actuator_moment_tilesize_nu, dtype=wp.int32, ndim=1, device="cpu" - ) - m.alpha_candidate = wp.array(np.linspace(0.0, 1.0, m.nlsp), dtype=wp.float32) - m.body_dofadr = wp.array(mjm.body_dofadr, dtype=wp.int32, ndim=1) - m.body_dofnum = wp.array(mjm.body_dofnum, dtype=wp.int32, ndim=1) - m.body_jntadr = wp.array(mjm.body_jntadr, dtype=wp.int32, ndim=1) - m.body_jntnum = wp.array(mjm.body_jntnum, dtype=wp.int32, ndim=1) - m.body_parentid = wp.array(mjm.body_parentid, dtype=wp.int32, ndim=1) - m.body_mocapid = wp.array(mjm.body_mocapid, dtype=wp.int32, ndim=1) - m.body_weldid = wp.array(mjm.body_weldid, dtype=wp.int32, ndim=1) - m.body_pos = wp.array(mjm.body_pos, dtype=wp.vec3, ndim=1) - m.body_quat = wp.array(mjm.body_quat, dtype=wp.quat, ndim=1) - m.body_ipos = wp.array(mjm.body_ipos, dtype=wp.vec3, ndim=1) - m.body_iquat = wp.array(mjm.body_iquat, dtype=wp.quat, ndim=1) - m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32, ndim=1) - m.body_inertia = wp.array(mjm.body_inertia, dtype=wp.vec3, ndim=1) - m.body_mass = wp.array(mjm.body_mass, dtype=wp.float32, ndim=1) - - subtree_mass = np.copy(mjm.body_mass) - # TODO(team): should this be [mjm.nbody - 1, 0) ? - for i in range(mjm.nbody - 1, -1, -1): - subtree_mass[mjm.body_parentid[i]] += subtree_mass[i] - - m.subtree_mass = wp.array(subtree_mass, dtype=wp.float32, ndim=1) - m.body_invweight0 = wp.array(mjm.body_invweight0, dtype=wp.float32, ndim=2) - m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32, ndim=1) - m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32, ndim=1) - m.body_contype = wp.array(mjm.body_contype, dtype=wp.int32, ndim=1) - m.body_conaffinity = wp.array(mjm.body_conaffinity, dtype=wp.int32, ndim=1) - m.jnt_bodyid = wp.array(mjm.jnt_bodyid, dtype=wp.int32, ndim=1) - m.jnt_limited = wp.array(mjm.jnt_limited, dtype=wp.int32, ndim=1) - m.jnt_limited_slide_hinge_adr = wp.array( - jnt_limited_slide_hinge_adr, dtype=wp.int32, ndim=1 - ) - m.jnt_type = wp.array(mjm.jnt_type, dtype=wp.int32, ndim=1) - m.jnt_solref = wp.array(mjm.jnt_solref, dtype=wp.vec2f, ndim=1) - m.jnt_solimp = wp.array(mjm.jnt_solimp, dtype=types.vec5, ndim=1) - m.jnt_qposadr = wp.array(mjm.jnt_qposadr, dtype=wp.int32, ndim=1) - m.jnt_dofadr = wp.array(mjm.jnt_dofadr, dtype=wp.int32, ndim=1) - m.jnt_axis = wp.array(mjm.jnt_axis, dtype=wp.vec3, ndim=1) - m.jnt_pos = wp.array(mjm.jnt_pos, dtype=wp.vec3, ndim=1) - m.jnt_range = wp.array(mjm.jnt_range, dtype=wp.float32, ndim=2) - m.jnt_margin = wp.array(mjm.jnt_margin, dtype=wp.float32, ndim=1) - m.jnt_stiffness = wp.array(mjm.jnt_stiffness, dtype=wp.float32, ndim=1) - m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool, ndim=1) - m.jnt_actfrcrange = wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2, ndim=1) - m.geom_type = wp.array(mjm.geom_type, dtype=wp.int32, ndim=1) - m.geom_bodyid = wp.array(mjm.geom_bodyid, dtype=wp.int32, ndim=1) - m.geom_conaffinity = wp.array(mjm.geom_conaffinity, dtype=wp.int32, ndim=1) - m.geom_contype = wp.array(mjm.geom_contype, dtype=wp.int32, ndim=1) - m.geom_condim = wp.array(mjm.geom_condim, dtype=wp.int32, ndim=1) - m.geom_pos = wp.array(mjm.geom_pos, dtype=wp.vec3, ndim=1) - m.geom_quat = wp.array(mjm.geom_quat, dtype=wp.quat, ndim=1) - m.geom_size = wp.array(mjm.geom_size, dtype=wp.vec3, ndim=1) - m.geom_priority = wp.array(mjm.geom_priority, dtype=wp.int32, ndim=1) - m.geom_solmix = wp.array(mjm.geom_solmix, dtype=wp.float32, ndim=1) - m.geom_solref = wp.array(mjm.geom_solref, dtype=wp.vec2, ndim=1) - m.geom_solimp = wp.array(mjm.geom_solimp, dtype=types.vec5, ndim=1) - m.geom_friction = wp.array(mjm.geom_friction, dtype=wp.vec3, ndim=1) - m.geom_margin = wp.array(mjm.geom_margin, dtype=wp.float32, ndim=1) - m.geom_gap = wp.array(mjm.geom_gap, dtype=wp.float32, ndim=1) - m.geom_aabb = wp.array(mjm.geom_aabb, dtype=wp.vec3, ndim=3) - m.geom_rbound = wp.array(mjm.geom_rbound, dtype=wp.float32, ndim=1) - m.geom_dataid = wp.array(mjm.geom_dataid, dtype=wp.int32, ndim=1) - m.mesh_vertadr = wp.array(mjm.mesh_vertadr, dtype=wp.int32, ndim=1) - m.mesh_vertnum = wp.array(mjm.mesh_vertnum, dtype=wp.int32, ndim=1) - m.mesh_vert = wp.array(mjm.mesh_vert, dtype=wp.vec3, ndim=1) - m.site_pos = wp.array(mjm.site_pos, dtype=wp.vec3, ndim=1) - m.site_quat = wp.array(mjm.site_quat, dtype=wp.quat, ndim=1) - m.site_bodyid = wp.array(mjm.site_bodyid, dtype=wp.int32, ndim=1) - m.cam_mode = wp.array(mjm.cam_mode, dtype=wp.int32, ndim=1) - m.cam_bodyid = wp.array(mjm.cam_bodyid, dtype=wp.int32, ndim=1) - m.cam_targetbodyid = wp.array(mjm.cam_targetbodyid, dtype=wp.int32, ndim=1) - m.cam_pos = wp.array(mjm.cam_pos, dtype=wp.vec3, ndim=1) - m.cam_quat = wp.array(mjm.cam_quat, dtype=wp.quat, ndim=1) - m.cam_poscom0 = wp.array(mjm.cam_poscom0, dtype=wp.vec3, ndim=1) - m.cam_pos0 = wp.array(mjm.cam_pos0, dtype=wp.vec3, ndim=1) - m.cam_mat0 = wp.array(mjm.cam_mat0.reshape(-1, 3, 3), dtype=wp.mat33, ndim=1) - m.light_mode = wp.array(mjm.light_mode, dtype=wp.int32, ndim=1) - m.light_bodyid = wp.array(mjm.light_bodyid, dtype=wp.int32, ndim=1) - m.light_targetbodyid = wp.array(mjm.light_targetbodyid, dtype=wp.int32, ndim=1) - m.light_pos = wp.array(mjm.light_pos, dtype=wp.vec3, ndim=1) - m.light_dir = wp.array(mjm.light_dir, dtype=wp.vec3, ndim=1) - m.light_poscom0 = wp.array(mjm.light_poscom0, dtype=wp.vec3, ndim=1) - m.light_pos0 = wp.array(mjm.light_pos0, dtype=wp.vec3, ndim=1) - m.light_dir0 = wp.array(mjm.light_dir0, dtype=wp.vec3, ndim=1) - m.dof_bodyid = wp.array(mjm.dof_bodyid, dtype=wp.int32, ndim=1) - m.dof_jntid = wp.array(mjm.dof_jntid, dtype=wp.int32, ndim=1) - m.dof_parentid = wp.array(mjm.dof_parentid, dtype=wp.int32, ndim=1) - m.dof_Madr = wp.array(mjm.dof_Madr, dtype=wp.int32, ndim=1) - m.dof_armature = wp.array(mjm.dof_armature, dtype=wp.float32, ndim=1) - m.dof_damping = wp.array(mjm.dof_damping, dtype=wp.float32, ndim=1) - m.dof_tri_row = wp.from_numpy(dof_tri_row, dtype=wp.int32) - m.dof_tri_col = wp.from_numpy(dof_tri_col, dtype=wp.int32) - m.dof_invweight0 = wp.array(mjm.dof_invweight0, dtype=wp.float32, ndim=1) - m.actuator_trntype = wp.array(mjm.actuator_trntype, dtype=wp.int32, ndim=1) - m.actuator_trnid = wp.array(mjm.actuator_trnid, dtype=wp.int32, ndim=2) - m.actuator_ctrllimited = wp.array(mjm.actuator_ctrllimited, dtype=wp.bool, ndim=1) - m.actuator_ctrlrange = wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2, ndim=1) - m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool, ndim=1) - m.actuator_forcerange = wp.array(mjm.actuator_forcerange, dtype=wp.vec2, ndim=1) - m.actuator_gaintype = wp.array(mjm.actuator_gaintype, dtype=wp.int32, ndim=1) - m.actuator_gainprm = wp.array(mjm.actuator_gainprm, dtype=wp.float32, ndim=2) - m.actuator_biastype = wp.array(mjm.actuator_biastype, dtype=wp.int32, ndim=1) - m.actuator_biasprm = wp.array(mjm.actuator_biasprm, dtype=wp.float32, ndim=2) - m.actuator_gear = wp.array(mjm.actuator_gear, dtype=wp.spatial_vector, ndim=1) - m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool, ndim=1) - m.actuator_actrange = wp.array(mjm.actuator_actrange, dtype=wp.vec2, ndim=1) - m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32, ndim=1) - m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32, ndim=1) - m.actuator_dynprm = wp.array(mjm.actuator_dynprm, dtype=types.vec10f, ndim=1) - m.exclude_signature = wp.array(mjm.exclude_signature, dtype=wp.int32, ndim=1) - - # short-circuiting here allows us to skip a lot of code in implicit integration - m.actuator_affine_bias_gain = bool( - np.any(mjm.actuator_biastype == types.BiasType.AFFINE.value) - or np.any(mjm.actuator_gaintype == types.GainType.AFFINE.value) - ) + # indices for sparse qM mul_m + is_, js, madr_ijs = [], [], [] + for i in range(mjm.nv): + madr_ij, j = mjm.dof_Madr[i], i - m.condim_max = np.max(mjm.geom_condim) # TODO(team): get max after filtering + while True: + madr_ij, j = madr_ij + 1, mjm.dof_parentid[j] + if j == -1: + break + is_, js, madr_ijs = is_ + [i], js + [j], madr_ijs + [madr_ij] - # tendon - m.tendon_adr = wp.array(mjm.tendon_adr, dtype=wp.int32, ndim=1) - m.tendon_num = wp.array(mjm.tendon_num, dtype=wp.int32, ndim=1) - m.wrap_objid = wp.array(mjm.wrap_objid, dtype=wp.int32, ndim=1) - m.wrap_prm = wp.array(mjm.wrap_prm, dtype=wp.float32, ndim=1) - m.wrap_type = wp.array(mjm.wrap_type, dtype=wp.int32, ndim=1) + qM_mulm_i, qM_mulm_j, qM_madr_ij = ( + np.array(x, dtype=np.int32) for x in (is_, js, madr_ijs) + ) - tendon_jnt_adr = [] - wrap_jnt_adr = [] - for i in range(mjm.ntendon): - adr = mjm.tendon_adr[i] - if mjm.wrap_type[adr] == mujoco.mjtWrap.mjWRAP_JOINT: - tendon_num = mjm.tendon_num[i] - for j in range(tendon_num): - tendon_jnt_adr.append(i) - wrap_jnt_adr.append(adr + j) + jnt_limited_slide_hinge_adr = np.nonzero( + mjm.jnt_limited + & ( + (mjm.jnt_type == mujoco.mjtJoint.mjJNT_SLIDE) + | (mjm.jnt_type == mujoco.mjtJoint.mjJNT_HINGE) + ) + )[0] + + # body_tree is BFS ordering of body ids + # body_treeadr contains starting index of each body tree level + bodies, body_depth = {}, np.zeros(mjm.nbody, dtype=int) - 1 + for i in range(mjm.nbody): + body_depth[i] = body_depth[mjm.body_parentid[i]] + 1 + bodies.setdefault(body_depth[i], []).append(i) + body_tree = np.concatenate([bodies[i] for i in range(len(bodies))]) + tree_off = [0] + [len(bodies[i]) for i in range(len(bodies))] + body_treeadr = np.cumsum(tree_off)[:-1] + + m.body_tree = wp.array(body_tree, dtype=wp.int32, ndim=1) + m.body_treeadr = wp.array(body_treeadr, dtype=wp.int32, ndim=1, device="cpu") + + qLD_update_tree = np.empty(shape=(0, 3), dtype=int) + qLD_update_treeadr = np.empty(shape=(0,), dtype=int) + qLD_tile = np.empty(shape=(0,), dtype=int) + qLD_tileadr = np.empty(shape=(0,), dtype=int) + qLD_tilesize = np.empty(shape=(0,), dtype=int) + + if support.is_sparse(mjm): + # qLD_update_tree has dof tree ordering of qLD updates for sparse factor m + # qLD_update_treeadr contains starting index of each dof tree level + mjd = mujoco.MjData(mjm) + if version.parse(mujoco.__version__) > version.parse("3.2.7"): + m.M_rownnz = wp.array(mjd.M_rownnz, dtype=wp.int32, ndim=1) + m.M_rowadr = wp.array(mjd.M_rowadr, dtype=wp.int32, ndim=1) + m.M_colind = wp.array(mjd.M_colind, dtype=wp.int32, ndim=1) + m.mapM2M = wp.array(mjd.mapM2M, dtype=wp.int32, ndim=1) + qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1 + + rownnz = mjd.M_rownnz + rowadr = mjd.M_rowadr + + for k in range(mjm.nv): + dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1 + i = mjm.dof_parentid[k] + diag_k = rowadr[k] + rownnz[k] - 1 + Madr_ki = diag_k - 1 + while i > -1: + qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki)) + i = mjm.dof_parentid[i] + Madr_ki -= 1 + + qLD_update_tree = np.concatenate( + [qLD_updates[i] for i in range(len(qLD_updates))] + ) + tree_off = [0] + [len(qLD_updates[i]) for i in range(len(qLD_updates))] + qLD_update_treeadr = np.cumsum(tree_off)[:-1] + else: + qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1 + for k in range(mjm.nv): + dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1 + i = mjm.dof_parentid[k] + Madr_ki = mjm.dof_Madr[k] + 1 + while i > -1: + qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki)) + i = mjm.dof_parentid[i] + Madr_ki += 1 + + # qLD_treeadr contains starting indicies of each level of sparse updates + qLD_update_tree = np.concatenate( + [qLD_updates[i] for i in range(len(qLD_updates))] + ) + tree_off = [0] + [len(qLD_updates[i]) for i in range(len(qLD_updates))] + qLD_update_treeadr = np.cumsum(tree_off)[:-1] - m.tendon_jnt_adr = wp.array(tendon_jnt_adr, dtype=wp.int32, ndim=1) - m.wrap_jnt_adr = wp.array(wrap_jnt_adr, dtype=wp.int32, ndim=1) + else: + # qLD_tile has the dof id of each tile in qLD for dense factor m + # qLD_tileadr contains starting index in qLD_tile of each tile group + # qLD_tilesize has the square tile size of each tile group + tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1] + tiles = {} + for i in range(len(tile_corners)): + tile_beg = tile_corners[i] + tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1] + tiles.setdefault(tile_end - tile_beg, []).append(tile_beg) + qLD_tile = np.concatenate([tiles[sz] for sz in sorted(tiles.keys())]) + tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())] + qLD_tileadr = np.cumsum(tile_off)[:-1] + qLD_tilesize = np.array(sorted(tiles.keys())) + + # tiles for actuator_moment - needs nu + nv tile size and offset + actuator_moment_offset_nv = np.empty(shape=(0,), dtype=int) + actuator_moment_offset_nu = np.empty(shape=(0,), dtype=int) + actuator_moment_tileadr = np.empty(shape=(0,), dtype=int) + actuator_moment_tilesize_nv = np.empty(shape=(0,), dtype=int) + actuator_moment_tilesize_nu = np.empty(shape=(0,), dtype=int) + + if not support.is_sparse(mjm): + # how many actuators for each tree + tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1] + tree_id = mjm.dof_treeid[tile_corners] + num_trees = int(np.max(tree_id)) + tree = mjm.body_treeid[mjm.jnt_bodyid[mjm.actuator_trnid[:, 0]]] + counts, ids = np.histogram(tree, bins=np.arange(0, num_trees + 2)) + acts_per_tree = dict(zip([int(i) for i in ids], [int(i) for i in counts])) + + tiles = {} + act_beg = 0 + for i in range(len(tile_corners)): + tile_beg = tile_corners[i] + tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1] + tree = int(tree_id[i]) + act_num = acts_per_tree[tree] + tiles.setdefault((tile_end - tile_beg, act_num), []).append((tile_beg, act_beg)) + act_beg += act_num + + sorted_keys = sorted(tiles.keys()) + actuator_moment_offset_nv = [ + t[0] for key in sorted_keys for t in tiles.get(key, []) + ] + actuator_moment_offset_nu = [ + t[1] for key in sorted_keys for t in tiles.get(key, []) + ] + tile_off = [0] + [len(tiles[sz]) for sz in sorted(tiles.keys())] + actuator_moment_tileadr = np.cumsum(tile_off)[:-1] # offset + actuator_moment_tilesize_nv = np.array( + [a[0] for a in sorted_keys] + ) # for this level + actuator_moment_tilesize_nu = np.array( + [int(a[1]) for a in sorted_keys] + ) # for this level + + m.qM_fullm_i = wp.array(qM_fullm_i, dtype=wp.int32, ndim=1) + m.qM_fullm_j = wp.array(qM_fullm_j, dtype=wp.int32, ndim=1) + m.qM_mulm_i = wp.array(qM_mulm_i, dtype=wp.int32, ndim=1) + m.qM_mulm_j = wp.array(qM_mulm_j, dtype=wp.int32, ndim=1) + m.qM_madr_ij = wp.array(qM_madr_ij, dtype=wp.int32, ndim=1) + m.qLD_update_tree = wp.array(qLD_update_tree, dtype=wp.vec3i, ndim=1) + m.qLD_update_treeadr = wp.array( + qLD_update_treeadr, dtype=wp.int32, ndim=1, device="cpu" + ) + m.qLD_tile = wp.array(qLD_tile, dtype=wp.int32, ndim=1) + m.qLD_tileadr = wp.array(qLD_tileadr, dtype=wp.int32, ndim=1, device="cpu") + m.qLD_tilesize = wp.array(qLD_tilesize, dtype=wp.int32, ndim=1, device="cpu") + m.actuator_moment_offset_nv = wp.array( + actuator_moment_offset_nv, dtype=wp.int32, ndim=1 + ) + m.actuator_moment_offset_nu = wp.array( + actuator_moment_offset_nu, dtype=wp.int32, ndim=1 + ) + m.actuator_moment_tileadr = wp.array( + actuator_moment_tileadr, dtype=wp.int32, ndim=1, device="cpu" + ) + m.actuator_moment_tilesize_nv = wp.array( + actuator_moment_tilesize_nv, dtype=wp.int32, ndim=1, device="cpu" + ) + m.actuator_moment_tilesize_nu = wp.array( + actuator_moment_tilesize_nu, dtype=wp.int32, ndim=1, device="cpu" + ) + m.alpha_candidate = wp.array(np.linspace(0.0, 1.0, m.nlsp), dtype=wp.float32) + m.body_dofadr = wp.array(mjm.body_dofadr, dtype=wp.int32, ndim=1) + m.body_dofnum = wp.array(mjm.body_dofnum, dtype=wp.int32, ndim=1) + m.body_jntadr = wp.array(mjm.body_jntadr, dtype=wp.int32, ndim=1) + m.body_jntnum = wp.array(mjm.body_jntnum, dtype=wp.int32, ndim=1) + m.body_parentid = wp.array(mjm.body_parentid, dtype=wp.int32, ndim=1) + m.body_mocapid = wp.array(mjm.body_mocapid, dtype=wp.int32, ndim=1) + m.body_weldid = wp.array(mjm.body_weldid, dtype=wp.int32, ndim=1) + m.body_pos = wp.array(mjm.body_pos, dtype=wp.vec3, ndim=1) + m.body_quat = wp.array(mjm.body_quat, dtype=wp.quat, ndim=1) + m.body_ipos = wp.array(mjm.body_ipos, dtype=wp.vec3, ndim=1) + m.body_iquat = wp.array(mjm.body_iquat, dtype=wp.quat, ndim=1) + m.body_rootid = wp.array(mjm.body_rootid, dtype=wp.int32, ndim=1) + m.body_inertia = wp.array(mjm.body_inertia, dtype=wp.vec3, ndim=1) + m.body_mass = wp.array(mjm.body_mass, dtype=wp.float32, ndim=1) + + subtree_mass = np.copy(mjm.body_mass) + # TODO(team): should this be [mjm.nbody - 1, 0) ? + for i in range(mjm.nbody - 1, -1, -1): + subtree_mass[mjm.body_parentid[i]] += subtree_mass[i] + + m.subtree_mass = wp.array(subtree_mass, dtype=wp.float32, ndim=1) + m.body_invweight0 = wp.array(mjm.body_invweight0, dtype=wp.float32, ndim=2) + m.body_geomnum = wp.array(mjm.body_geomnum, dtype=wp.int32, ndim=1) + m.body_geomadr = wp.array(mjm.body_geomadr, dtype=wp.int32, ndim=1) + m.body_contype = wp.array(mjm.body_contype, dtype=wp.int32, ndim=1) + m.body_conaffinity = wp.array(mjm.body_conaffinity, dtype=wp.int32, ndim=1) + m.jnt_bodyid = wp.array(mjm.jnt_bodyid, dtype=wp.int32, ndim=1) + m.jnt_limited = wp.array(mjm.jnt_limited, dtype=wp.int32, ndim=1) + m.jnt_limited_slide_hinge_adr = wp.array( + jnt_limited_slide_hinge_adr, dtype=wp.int32, ndim=1 + ) + m.jnt_type = wp.array(mjm.jnt_type, dtype=wp.int32, ndim=1) + m.jnt_solref = wp.array(mjm.jnt_solref, dtype=wp.vec2f, ndim=1) + m.jnt_solimp = wp.array(mjm.jnt_solimp, dtype=types.vec5, ndim=1) + m.jnt_qposadr = wp.array(mjm.jnt_qposadr, dtype=wp.int32, ndim=1) + m.jnt_dofadr = wp.array(mjm.jnt_dofadr, dtype=wp.int32, ndim=1) + m.jnt_axis = wp.array(mjm.jnt_axis, dtype=wp.vec3, ndim=1) + m.jnt_pos = wp.array(mjm.jnt_pos, dtype=wp.vec3, ndim=1) + m.jnt_range = wp.array(mjm.jnt_range, dtype=wp.float32, ndim=2) + m.jnt_margin = wp.array(mjm.jnt_margin, dtype=wp.float32, ndim=1) + m.jnt_stiffness = wp.array(mjm.jnt_stiffness, dtype=wp.float32, ndim=1) + m.jnt_actfrclimited = wp.array(mjm.jnt_actfrclimited, dtype=wp.bool, ndim=1) + m.jnt_actfrcrange = wp.array(mjm.jnt_actfrcrange, dtype=wp.vec2, ndim=1) + m.geom_type = wp.array(mjm.geom_type, dtype=wp.int32, ndim=1) + m.geom_bodyid = wp.array(mjm.geom_bodyid, dtype=wp.int32, ndim=1) + m.geom_conaffinity = wp.array(mjm.geom_conaffinity, dtype=wp.int32, ndim=1) + m.geom_contype = wp.array(mjm.geom_contype, dtype=wp.int32, ndim=1) + m.geom_condim = wp.array(mjm.geom_condim, dtype=wp.int32, ndim=1) + m.geom_pos = wp.array(mjm.geom_pos, dtype=wp.vec3, ndim=1) + m.geom_quat = wp.array(mjm.geom_quat, dtype=wp.quat, ndim=1) + m.geom_size = wp.array(mjm.geom_size, dtype=wp.vec3, ndim=1) + m.geom_priority = wp.array(mjm.geom_priority, dtype=wp.int32, ndim=1) + m.geom_solmix = wp.array(mjm.geom_solmix, dtype=wp.float32, ndim=1) + m.geom_solref = wp.array(mjm.geom_solref, dtype=wp.vec2, ndim=1) + m.geom_solimp = wp.array(mjm.geom_solimp, dtype=types.vec5, ndim=1) + m.geom_friction = wp.array(mjm.geom_friction, dtype=wp.vec3, ndim=1) + m.geom_margin = wp.array(mjm.geom_margin, dtype=wp.float32, ndim=1) + m.geom_gap = wp.array(mjm.geom_gap, dtype=wp.float32, ndim=1) + m.geom_aabb = wp.array(mjm.geom_aabb, dtype=wp.vec3, ndim=3) + m.geom_rbound = wp.array(mjm.geom_rbound, dtype=wp.float32, ndim=1) + m.geom_dataid = wp.array(mjm.geom_dataid, dtype=wp.int32, ndim=1) + m.mesh_vertadr = wp.array(mjm.mesh_vertadr, dtype=wp.int32, ndim=1) + m.mesh_vertnum = wp.array(mjm.mesh_vertnum, dtype=wp.int32, ndim=1) + m.mesh_vert = wp.array(mjm.mesh_vert, dtype=wp.vec3, ndim=1) + m.site_pos = wp.array(mjm.site_pos, dtype=wp.vec3, ndim=1) + m.site_quat = wp.array(mjm.site_quat, dtype=wp.quat, ndim=1) + m.site_bodyid = wp.array(mjm.site_bodyid, dtype=wp.int32, ndim=1) + m.cam_mode = wp.array(mjm.cam_mode, dtype=wp.int32, ndim=1) + m.cam_bodyid = wp.array(mjm.cam_bodyid, dtype=wp.int32, ndim=1) + m.cam_targetbodyid = wp.array(mjm.cam_targetbodyid, dtype=wp.int32, ndim=1) + m.cam_pos = wp.array(mjm.cam_pos, dtype=wp.vec3, ndim=1) + m.cam_quat = wp.array(mjm.cam_quat, dtype=wp.quat, ndim=1) + m.cam_poscom0 = wp.array(mjm.cam_poscom0, dtype=wp.vec3, ndim=1) + m.cam_pos0 = wp.array(mjm.cam_pos0, dtype=wp.vec3, ndim=1) + m.cam_mat0 = wp.array(mjm.cam_mat0.reshape(-1, 3, 3), dtype=wp.mat33, ndim=1) + m.light_mode = wp.array(mjm.light_mode, dtype=wp.int32, ndim=1) + m.light_bodyid = wp.array(mjm.light_bodyid, dtype=wp.int32, ndim=1) + m.light_targetbodyid = wp.array(mjm.light_targetbodyid, dtype=wp.int32, ndim=1) + m.light_pos = wp.array(mjm.light_pos, dtype=wp.vec3, ndim=1) + m.light_dir = wp.array(mjm.light_dir, dtype=wp.vec3, ndim=1) + m.light_poscom0 = wp.array(mjm.light_poscom0, dtype=wp.vec3, ndim=1) + m.light_pos0 = wp.array(mjm.light_pos0, dtype=wp.vec3, ndim=1) + m.light_dir0 = wp.array(mjm.light_dir0, dtype=wp.vec3, ndim=1) + m.dof_bodyid = wp.array(mjm.dof_bodyid, dtype=wp.int32, ndim=1) + m.dof_jntid = wp.array(mjm.dof_jntid, dtype=wp.int32, ndim=1) + m.dof_parentid = wp.array(mjm.dof_parentid, dtype=wp.int32, ndim=1) + m.dof_Madr = wp.array(mjm.dof_Madr, dtype=wp.int32, ndim=1) + m.dof_armature = wp.array(mjm.dof_armature, dtype=wp.float32, ndim=1) + m.dof_damping = wp.array(mjm.dof_damping, dtype=wp.float32, ndim=1) + m.dof_tri_row = wp.from_numpy(dof_tri_row, dtype=wp.int32) + m.dof_tri_col = wp.from_numpy(dof_tri_col, dtype=wp.int32) + m.dof_invweight0 = wp.array(mjm.dof_invweight0, dtype=wp.float32, ndim=1) + m.actuator_trntype = wp.array(mjm.actuator_trntype, dtype=wp.int32, ndim=1) + m.actuator_trnid = wp.array(mjm.actuator_trnid, dtype=wp.int32, ndim=2) + m.actuator_ctrllimited = wp.array(mjm.actuator_ctrllimited, dtype=wp.bool, ndim=1) + m.actuator_ctrlrange = wp.array(mjm.actuator_ctrlrange, dtype=wp.vec2, ndim=1) + m.actuator_forcelimited = wp.array(mjm.actuator_forcelimited, dtype=wp.bool, ndim=1) + m.actuator_forcerange = wp.array(mjm.actuator_forcerange, dtype=wp.vec2, ndim=1) + m.actuator_gaintype = wp.array(mjm.actuator_gaintype, dtype=wp.int32, ndim=1) + m.actuator_gainprm = wp.array(mjm.actuator_gainprm, dtype=wp.float32, ndim=2) + m.actuator_biastype = wp.array(mjm.actuator_biastype, dtype=wp.int32, ndim=1) + m.actuator_biasprm = wp.array(mjm.actuator_biasprm, dtype=wp.float32, ndim=2) + m.actuator_gear = wp.array(mjm.actuator_gear, dtype=wp.spatial_vector, ndim=1) + m.actuator_actlimited = wp.array(mjm.actuator_actlimited, dtype=wp.bool, ndim=1) + m.actuator_actrange = wp.array(mjm.actuator_actrange, dtype=wp.vec2, ndim=1) + m.actuator_actadr = wp.array(mjm.actuator_actadr, dtype=wp.int32, ndim=1) + m.actuator_dyntype = wp.array(mjm.actuator_dyntype, dtype=wp.int32, ndim=1) + m.actuator_dynprm = wp.array(mjm.actuator_dynprm, dtype=types.vec10f, ndim=1) + m.exclude_signature = wp.array(mjm.exclude_signature, dtype=wp.int32, ndim=1) + + # short-circuiting here allows us to skip a lot of code in implicit integration + m.actuator_affine_bias_gain = bool( + np.any(mjm.actuator_biastype == types.BiasType.AFFINE.value) + or np.any(mjm.actuator_gaintype == types.GainType.AFFINE.value) + ) - # sensors - m.sensor_type = wp.array(mjm.sensor_type, dtype=wp.int32, ndim=1) - m.sensor_datatype = wp.array(mjm.sensor_datatype, dtype=wp.int32, ndim=1) - m.sensor_objtype = wp.array(mjm.sensor_objtype, dtype=wp.int32, ndim=1) - m.sensor_objid = wp.array(mjm.sensor_objid, dtype=wp.int32, ndim=1) - m.sensor_reftype = wp.array(mjm.sensor_reftype, dtype=wp.int32, ndim=1) - m.sensor_refid = wp.array(mjm.sensor_refid, dtype=wp.int32, ndim=1) - m.sensor_dim = wp.array(mjm.sensor_dim, dtype=wp.int32, ndim=1) - m.sensor_adr = wp.array(mjm.sensor_adr, dtype=wp.int32, ndim=1) - m.sensor_cutoff = wp.array(mjm.sensor_cutoff, dtype=wp.float32, ndim=1) - m.sensor_pos_adr = wp.array( - np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS)[0], - dtype=wp.int32, - ndim=1, - ) - m.sensor_vel_adr = wp.array( - np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_VEL)[0], - dtype=wp.int32, - ndim=1, - ) - m.sensor_acc_adr = wp.array( - np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_ACC)[0], - dtype=wp.int32, - ndim=1, - ) + # tendon + m.tendon_adr = wp.array(mjm.tendon_adr, dtype=wp.int32, ndim=1) + m.tendon_num = wp.array(mjm.tendon_num, dtype=wp.int32, ndim=1) + m.wrap_objid = wp.array(mjm.wrap_objid, dtype=wp.int32, ndim=1) + m.wrap_prm = wp.array(mjm.wrap_prm, dtype=wp.float32, ndim=1) + m.wrap_type = wp.array(mjm.wrap_type, dtype=wp.int32, ndim=1) + + tendon_jnt_adr = [] + wrap_jnt_adr = [] + for i in range(mjm.ntendon): + adr = mjm.tendon_adr[i] + if mjm.wrap_type[adr] == mujoco.mjtWrap.mjWRAP_JOINT: + tendon_num = mjm.tendon_num[i] + for j in range(tendon_num): + tendon_jnt_adr.append(i) + wrap_jnt_adr.append(adr + j) + + m.condim_max = np.max(mjm.geom_condim) # TODO(team): get max after filtering + + # tendon + m.tendon_adr = wp.array(mjm.tendon_adr, dtype=wp.int32, ndim=1) + m.tendon_num = wp.array(mjm.tendon_num, dtype=wp.int32, ndim=1) + m.wrap_objid = wp.array(mjm.wrap_objid, dtype=wp.int32, ndim=1) + m.wrap_prm = wp.array(mjm.wrap_prm, dtype=wp.float32, ndim=1) + m.wrap_type = wp.array(mjm.wrap_type, dtype=wp.int32, ndim=1) + + # sensors + m.sensor_type = wp.array(mjm.sensor_type, dtype=wp.int32, ndim=1) + m.sensor_datatype = wp.array(mjm.sensor_datatype, dtype=wp.int32, ndim=1) + m.sensor_objtype = wp.array(mjm.sensor_objtype, dtype=wp.int32, ndim=1) + m.sensor_objid = wp.array(mjm.sensor_objid, dtype=wp.int32, ndim=1) + m.sensor_reftype = wp.array(mjm.sensor_reftype, dtype=wp.int32, ndim=1) + m.sensor_refid = wp.array(mjm.sensor_refid, dtype=wp.int32, ndim=1) + m.sensor_dim = wp.array(mjm.sensor_dim, dtype=wp.int32, ndim=1) + m.sensor_adr = wp.array(mjm.sensor_adr, dtype=wp.int32, ndim=1) + m.sensor_cutoff = wp.array(mjm.sensor_cutoff, dtype=wp.float32, ndim=1) + m.sensor_pos_adr = wp.array( + np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS)[0], + dtype=wp.int32, + ndim=1, + ) + m.sensor_vel_adr = wp.array( + np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_VEL)[0], + dtype=wp.int32, + ndim=1, + ) + m.sensor_acc_adr = wp.array( + np.nonzero(mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_ACC)[0], + dtype=wp.int32, + ndim=1, + ) return m @@ -527,7 +534,11 @@ def _constraint(mjm: mujoco.MjModel, nworld: int, njmax: int) -> types.Constrain def make_data( - mjm: mujoco.MjModel, nworld: int = 1, nconmax: int = -1, njmax: int = -1 + mjm: mujoco.MjModel, + nworld: int = 1, + nconmax: int = -1, + njmax: int = -1, + device: Optional[wp.context.Device] = None, ) -> types.Data: d = types.Data() d.nworld = nworld @@ -542,118 +553,119 @@ def make_data( njmax = 512 d.njmax = njmax - d.ncon = wp.zeros(1, dtype=wp.int32) - d.nefc = wp.zeros(1, dtype=wp.int32, ndim=1) - d.nl = 0 - d.time = 0.0 - - qpos0 = np.tile(mjm.qpos0, (nworld, 1)) - d.qpos = wp.array(qpos0, dtype=wp.float32, ndim=2) - d.qvel = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) - d.qacc_warmstart = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) - d.qfrc_applied = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) - d.mocap_pos = wp.zeros((nworld, mjm.nmocap), dtype=wp.vec3) - d.mocap_quat = wp.zeros((nworld, mjm.nmocap), dtype=wp.quat) - d.qacc = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.xanchor = wp.zeros((nworld, mjm.njnt), dtype=wp.vec3) - d.xaxis = wp.zeros((nworld, mjm.njnt), dtype=wp.vec3) - d.xmat = wp.zeros((nworld, mjm.nbody), dtype=wp.mat33) - d.xpos = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) - d.xquat = wp.zeros((nworld, mjm.nbody), dtype=wp.quat) - d.xipos = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) - d.ximat = wp.zeros((nworld, mjm.nbody), dtype=wp.mat33) - d.subtree_com = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) - d.geom_xpos = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec3) - d.geom_xmat = wp.zeros((nworld, mjm.ngeom), dtype=wp.mat33) - d.site_xpos = wp.zeros((nworld, mjm.nsite), dtype=wp.vec3) - d.site_xmat = wp.zeros((nworld, mjm.nsite), dtype=wp.mat33) - d.cam_xpos = wp.zeros((nworld, mjm.ncam), dtype=wp.vec3) - d.cam_xmat = wp.zeros((nworld, mjm.ncam), dtype=wp.mat33) - d.light_xpos = wp.zeros((nworld, mjm.nlight), dtype=wp.vec3) - d.light_xdir = wp.zeros((nworld, mjm.nlight), dtype=wp.vec3) - d.cinert = wp.zeros((nworld, mjm.nbody), dtype=types.vec10) - d.cdof = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector) - d.ctrl = wp.zeros((nworld, mjm.nu), dtype=wp.float32) - d.actuator_velocity = wp.zeros((nworld, mjm.nu), dtype=wp.float32) - d.actuator_force = wp.zeros((nworld, mjm.nu), dtype=wp.float32) - d.actuator_length = wp.zeros((nworld, mjm.nu), dtype=wp.float32) - d.actuator_moment = wp.zeros((nworld, mjm.nu, mjm.nv), dtype=wp.float32) - d.crb = wp.zeros((nworld, mjm.nbody), dtype=types.vec10) - if support.is_sparse(mjm): - d.qM = wp.zeros((nworld, 1, mjm.nM), dtype=wp.float32) - d.qLD = wp.zeros((nworld, 1, mjm.nM), dtype=wp.float32) - else: - d.qM = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=wp.float32) - d.qLD = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=wp.float32) - d.act_dot = wp.zeros((nworld, mjm.na), dtype=wp.float32) - d.act = wp.zeros((nworld, mjm.na), dtype=wp.float32) - d.qLDiagInv = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.cvel = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) - d.cdof_dot = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector) - d.qfrc_bias = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.contact = types.Contact() - d.contact.dist = wp.zeros((nconmax,), dtype=wp.float32) - d.contact.pos = wp.zeros((nconmax,), dtype=wp.vec3f) - d.contact.frame = wp.zeros((nconmax,), dtype=wp.mat33f) - d.contact.includemargin = wp.zeros((nconmax,), dtype=wp.float32) - d.contact.friction = wp.zeros((nconmax,), dtype=types.vec5) - d.contact.solref = wp.zeros((nconmax,), dtype=wp.vec2f) - d.contact.solreffriction = wp.zeros((nconmax,), dtype=wp.vec2f) - d.contact.solimp = wp.zeros((nconmax,), dtype=types.vec5) - d.contact.dim = wp.zeros((nconmax,), dtype=wp.int32) - d.contact.geom = wp.zeros((nconmax,), dtype=wp.vec2i) - d.contact.efc_address = wp.zeros((nconmax, np.max(mjm.geom_condim)), dtype=wp.int32) - d.contact.worldid = wp.zeros((nconmax,), dtype=wp.int32) - d.efc = _constraint(mjm, d.nworld, d.njmax) - d.qfrc_passive = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qfrc_spring = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qfrc_damper = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qfrc_actuator = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qfrc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qfrc_constraint = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qacc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.xfrc_applied = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) - - # internal tmp arrays - d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qM_integration = wp.zeros_like(d.qM) - d.qLD_integration = wp.zeros_like(d.qLD) - d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv) - d.act_vel_integration = wp.zeros_like(d.ctrl) - d.qpos_t0 = wp.zeros((nworld, mjm.nq), dtype=wp.float32) - d.qvel_t0 = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.act_t0 = wp.zeros((nworld, mjm.na), dtype=wp.float32) - d.qvel_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qacc_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.act_dot_rk = wp.zeros((nworld, mjm.na), dtype=wp.float32) - - # sweep-and-prune broadphase - d.sap_geom_sort = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec4) - d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) - d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) - d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) - d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) - d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) - segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)] - d.sap_segment_index = wp.array(segment_indices_list, dtype=int) - - # collision driver - d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1) - d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1) - d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) - - # rne_postconstraint - d.cacc = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) - d.cfrc_int = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) - d.cfrc_ext = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) - - # tendon - d.ten_length = wp.zeros((nworld, mjm.ntendon), dtype=wp.float32, ndim=2) - d.ten_J = wp.zeros((nworld, mjm.ntendon, mjm.nv), dtype=wp.float32, ndim=3) - - # sensors - d.sensordata = wp.zeros((nworld, mjm.nsensordata), dtype=wp.float32) + with wp.ScopedDevice(device): + d.ncon = wp.zeros(1, dtype=wp.int32) + d.nefc = wp.zeros(1, dtype=wp.int32, ndim=1) + d.nl = 0 + d.time = 0.0 + + qpos0 = np.tile(mjm.qpos0, (nworld, 1)) + d.qpos = wp.array(qpos0, dtype=wp.float32, ndim=2) + d.qvel = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) + d.qacc_warmstart = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) + d.qfrc_applied = wp.zeros((nworld, mjm.nv), dtype=wp.float32, ndim=2) + d.mocap_pos = wp.zeros((nworld, mjm.nmocap), dtype=wp.vec3) + d.mocap_quat = wp.zeros((nworld, mjm.nmocap), dtype=wp.quat) + d.qacc = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.xanchor = wp.zeros((nworld, mjm.njnt), dtype=wp.vec3) + d.xaxis = wp.zeros((nworld, mjm.njnt), dtype=wp.vec3) + d.xmat = wp.zeros((nworld, mjm.nbody), dtype=wp.mat33) + d.xpos = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) + d.xquat = wp.zeros((nworld, mjm.nbody), dtype=wp.quat) + d.xipos = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) + d.ximat = wp.zeros((nworld, mjm.nbody), dtype=wp.mat33) + d.subtree_com = wp.zeros((nworld, mjm.nbody), dtype=wp.vec3) + d.geom_xpos = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec3) + d.geom_xmat = wp.zeros((nworld, mjm.ngeom), dtype=wp.mat33) + d.site_xpos = wp.zeros((nworld, mjm.nsite), dtype=wp.vec3) + d.site_xmat = wp.zeros((nworld, mjm.nsite), dtype=wp.mat33) + d.cam_xpos = wp.zeros((nworld, mjm.ncam), dtype=wp.vec3) + d.cam_xmat = wp.zeros((nworld, mjm.ncam), dtype=wp.mat33) + d.light_xpos = wp.zeros((nworld, mjm.nlight), dtype=wp.vec3) + d.light_xdir = wp.zeros((nworld, mjm.nlight), dtype=wp.vec3) + d.cinert = wp.zeros((nworld, mjm.nbody), dtype=types.vec10) + d.cdof = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector) + d.ctrl = wp.zeros((nworld, mjm.nu), dtype=wp.float32) + d.actuator_velocity = wp.zeros((nworld, mjm.nu), dtype=wp.float32) + d.actuator_force = wp.zeros((nworld, mjm.nu), dtype=wp.float32) + d.actuator_length = wp.zeros((nworld, mjm.nu), dtype=wp.float32) + d.actuator_moment = wp.zeros((nworld, mjm.nu, mjm.nv), dtype=wp.float32) + d.crb = wp.zeros((nworld, mjm.nbody), dtype=types.vec10) + if support.is_sparse(mjm): + d.qM = wp.zeros((nworld, 1, mjm.nM), dtype=wp.float32) + d.qLD = wp.zeros((nworld, 1, mjm.nM), dtype=wp.float32) + else: + d.qM = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=wp.float32) + d.qLD = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=wp.float32) + d.act_dot = wp.zeros((nworld, mjm.na), dtype=wp.float32) + d.act = wp.zeros((nworld, mjm.na), dtype=wp.float32) + d.qLDiagInv = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.cvel = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) + d.cdof_dot = wp.zeros((nworld, mjm.nv), dtype=wp.spatial_vector) + d.qfrc_bias = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.contact = types.Contact() + d.contact.dist = wp.zeros((nconmax,), dtype=wp.float32) + d.contact.pos = wp.zeros((nconmax,), dtype=wp.vec3f) + d.contact.frame = wp.zeros((nconmax,), dtype=wp.mat33f) + d.contact.includemargin = wp.zeros((nconmax,), dtype=wp.float32) + d.contact.friction = wp.zeros((nconmax,), dtype=types.vec5) + d.contact.solref = wp.zeros((nconmax,), dtype=wp.vec2f) + d.contact.solreffriction = wp.zeros((nconmax,), dtype=wp.vec2f) + d.contact.solimp = wp.zeros((nconmax,), dtype=types.vec5) + d.contact.dim = wp.zeros((nconmax,), dtype=wp.int32) + d.contact.geom = wp.zeros((nconmax,), dtype=wp.vec2i) + d.contact.efc_address = wp.zeros((nconmax, np.max(mjm.geom_condim)), dtype=wp.int32) + d.contact.worldid = wp.zeros((nconmax,), dtype=wp.int32) + d.efc = _constraint(mjm, d.nworld, d.njmax) + d.qfrc_passive = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qfrc_spring = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qfrc_damper = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qfrc_actuator = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qfrc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qfrc_constraint = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qacc_smooth = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.xfrc_applied = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector) + + # internal tmp arrays + d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qM_integration = wp.zeros_like(d.qM) + d.qLD_integration = wp.zeros_like(d.qLD) + d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv) + d.act_vel_integration = wp.zeros_like(d.ctrl) + d.qpos_t0 = wp.zeros((nworld, mjm.nq), dtype=wp.float32) + d.qvel_t0 = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.act_t0 = wp.zeros((nworld, mjm.na), dtype=wp.float32) + d.qvel_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qacc_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.act_dot_rk = wp.zeros((nworld, mjm.na), dtype=wp.float32) + + # sweep-and-prune broadphase + d.sap_geom_sort = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec4) + d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) + d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) + d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) + d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) + d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) + segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)] + d.sap_segment_index = wp.array(segment_indices_list, dtype=int) + + # collision driver + d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1) + d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1) + d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) + + # rne_postconstraint + d.cacc = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) + d.cfrc_int = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) + d.cfrc_ext = wp.zeros((nworld, mjm.nbody), dtype=wp.spatial_vector, ndim=2) + + # tendon + d.ten_length = wp.zeros((nworld, mjm.ntendon), dtype=wp.float32, ndim=2) + d.ten_J = wp.zeros((nworld, mjm.ntendon, mjm.nv), dtype=wp.float32, ndim=3) + + # sensors + d.sensordata = wp.zeros((nworld, mjm.nsensordata), dtype=wp.float32) return d @@ -664,6 +676,7 @@ def put_data( nworld: Optional[int] = None, nconmax: Optional[int] = None, njmax: Optional[int] = None, + device: Optional[wp.context.Device] = None, ) -> types.Data: d = types.Data() @@ -688,237 +701,254 @@ def put_data( if nworld * mjd.nefc > njmax: raise ValueError(f"njmax overflow (njmax must be >= {nworld * mjd.nefc})") - d.nworld = nworld - # TODO(team): move nconmax and njmax to Model? - d.nconmax = nconmax - d.njmax = njmax - - d.ncon = wp.array([mjd.ncon * nworld], dtype=wp.int32, ndim=1) - d.nl = mjd.nl - d.nefc = wp.array([mjd.nefc * nworld], dtype=wp.int32, ndim=1) - d.time = mjd.time - - # TODO(erikfrey): would it be better to tile on the gpu? - def tile(x): - return np.tile(x, (nworld,) + (1,) * len(x.shape)) + with wp.ScopedDevice(device): + d.nworld = nworld + # TODO(team): move nconmax and njmax to Model? + d.nconmax = nconmax + d.njmax = njmax + + d.ncon = wp.array([mjd.ncon * nworld], dtype=wp.int32, ndim=1) + d.nl = mjd.nl + d.nefc = wp.array([mjd.nefc * nworld], dtype=wp.int32, ndim=1) + d.time = mjd.time + + # TODO(erikfrey): would it be better to tile on the gpu? + def tile(x): + return np.tile(x, (nworld,) + (1,) * len(x.shape)) + + if support.is_sparse(mjm): + qM = np.expand_dims(mjd.qM, axis=0) + qLD = np.expand_dims(mjd.qLD, axis=0) + efc_J = np.zeros((mjd.nefc, mjm.nv)) + mujoco.mju_sparse2dense( + efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind + ) + else: + qM = np.zeros((mjm.nv, mjm.nv)) + mujoco.mj_fullM(mjm, qM, mjd.qM) + qLD = np.linalg.cholesky(qM) + efc_J = mjd.efc_J.reshape((mjd.nefc, mjm.nv)) - if support.is_sparse(mjm): - qM = np.expand_dims(mjd.qM, axis=0) - qLD = np.expand_dims(mjd.qLD, axis=0) - efc_J = np.zeros((mjd.nefc, mjm.nv)) + # TODO(taylorhowell): sparse actuator_moment + actuator_moment = np.zeros((mjm.nu, mjm.nv)) mujoco.mju_sparse2dense( - efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind + actuator_moment, + mjd.actuator_moment, + mjd.moment_rownnz, + mjd.moment_rowadr, + mjd.moment_colind, ) - else: - qM = np.zeros((mjm.nv, mjm.nv)) - mujoco.mj_fullM(mjm, qM, mjd.qM) - qLD = np.linalg.cholesky(qM) - efc_J = mjd.efc_J.reshape((mjd.nefc, mjm.nv)) - - # TODO(taylorhowell): sparse actuator_moment - actuator_moment = np.zeros((mjm.nu, mjm.nv)) - mujoco.mju_sparse2dense( - actuator_moment, - mjd.actuator_moment, - mjd.moment_rownnz, - mjd.moment_rowadr, - mjd.moment_colind, - ) - - d.qpos = wp.array(tile(mjd.qpos), dtype=wp.float32, ndim=2) - d.qvel = wp.array(tile(mjd.qvel), dtype=wp.float32, ndim=2) - d.qacc_warmstart = wp.array(tile(mjd.qacc_warmstart), dtype=wp.float32, ndim=2) - d.qfrc_applied = wp.array(tile(mjd.qfrc_applied), dtype=wp.float32, ndim=2) - d.mocap_pos = wp.array(tile(mjd.mocap_pos), dtype=wp.vec3, ndim=2) - d.mocap_quat = wp.array(tile(mjd.mocap_quat), dtype=wp.quat, ndim=2) - d.qacc = wp.array(tile(mjd.qacc), dtype=wp.float32, ndim=2) - d.xanchor = wp.array(tile(mjd.xanchor), dtype=wp.vec3, ndim=2) - d.xaxis = wp.array(tile(mjd.xaxis), dtype=wp.vec3, ndim=2) - d.xmat = wp.array(tile(mjd.xmat), dtype=wp.mat33, ndim=2) - d.xpos = wp.array(tile(mjd.xpos), dtype=wp.vec3, ndim=2) - d.xquat = wp.array(tile(mjd.xquat), dtype=wp.quat, ndim=2) - d.xipos = wp.array(tile(mjd.xipos), dtype=wp.vec3, ndim=2) - d.ximat = wp.array(tile(mjd.ximat), dtype=wp.mat33, ndim=2) - d.subtree_com = wp.array(tile(mjd.subtree_com), dtype=wp.vec3, ndim=2) - d.geom_xpos = wp.array(tile(mjd.geom_xpos), dtype=wp.vec3, ndim=2) - d.geom_xmat = wp.array(tile(mjd.geom_xmat), dtype=wp.mat33, ndim=2) - d.site_xpos = wp.array(tile(mjd.site_xpos), dtype=wp.vec3, ndim=2) - d.site_xmat = wp.array(tile(mjd.site_xmat), dtype=wp.mat33, ndim=2) - d.cam_xpos = wp.array(tile(mjd.cam_xpos), dtype=wp.vec3, ndim=2) - d.cam_xmat = wp.array(tile(mjd.cam_xmat.reshape(-1, 3, 3)), dtype=wp.mat33, ndim=2) - d.light_xpos = wp.array(tile(mjd.light_xpos), dtype=wp.vec3, ndim=2) - d.light_xdir = wp.array(tile(mjd.light_xdir), dtype=wp.vec3, ndim=2) - d.cinert = wp.array(tile(mjd.cinert), dtype=types.vec10, ndim=2) - d.cdof = wp.array(tile(mjd.cdof), dtype=wp.spatial_vector, ndim=2) - d.crb = wp.array(tile(mjd.crb), dtype=types.vec10, ndim=2) - d.qM = wp.array(tile(qM), dtype=wp.float32, ndim=3) - d.qLD = wp.array(tile(qLD), dtype=wp.float32, ndim=3) - d.qLDiagInv = wp.array(tile(mjd.qLDiagInv), dtype=wp.float32, ndim=2) - d.ctrl = wp.array(tile(mjd.ctrl), dtype=wp.float32, ndim=2) - d.actuator_velocity = wp.array(tile(mjd.actuator_velocity), dtype=wp.float32, ndim=2) - d.actuator_force = wp.array(tile(mjd.actuator_force), dtype=wp.float32, ndim=2) - d.actuator_length = wp.array(tile(mjd.actuator_length), dtype=wp.float32, ndim=2) - d.actuator_moment = wp.array(tile(actuator_moment), dtype=wp.float32, ndim=3) - d.cvel = wp.array(tile(mjd.cvel), dtype=wp.spatial_vector, ndim=2) - d.cdof_dot = wp.array(tile(mjd.cdof_dot), dtype=wp.spatial_vector, ndim=2) - d.qfrc_bias = wp.array(tile(mjd.qfrc_bias), dtype=wp.float32, ndim=2) - d.qfrc_passive = wp.array(tile(mjd.qfrc_passive), dtype=wp.float32, ndim=2) - d.qfrc_spring = wp.array(tile(mjd.qfrc_spring), dtype=wp.float32, ndim=2) - d.qfrc_damper = wp.array(tile(mjd.qfrc_damper), dtype=wp.float32, ndim=2) - d.qfrc_actuator = wp.array(tile(mjd.qfrc_actuator), dtype=wp.float32, ndim=2) - d.qfrc_smooth = wp.array(tile(mjd.qfrc_smooth), dtype=wp.float32, ndim=2) - d.qfrc_constraint = wp.array(tile(mjd.qfrc_constraint), dtype=wp.float32, ndim=2) - d.qacc_smooth = wp.array(tile(mjd.qacc_smooth), dtype=wp.float32, ndim=2) - d.act = wp.array(tile(mjd.act), dtype=wp.float32, ndim=2) - d.act_dot = wp.array(tile(mjd.act_dot), dtype=wp.float32, ndim=2) - - nefc = mjd.nefc - efc_worldid = np.zeros(njmax, dtype=int) - - for i in range(nworld): - efc_worldid[i * nefc : (i + 1) * nefc] = i - - nefc_fill = njmax - nworld * nefc - - efc_J_fill = np.vstack( - [np.repeat(efc_J, nworld, axis=0), np.zeros((nefc_fill, mjm.nv))] - ) - efc_D_fill = np.concatenate( - [np.repeat(mjd.efc_D, nworld, axis=0), np.zeros(nefc_fill)] - ) - efc_pos_fill = np.concatenate( - [np.repeat(mjd.efc_pos, nworld, axis=0), np.zeros(nefc_fill)] - ) - efc_aref_fill = np.concatenate( - [np.repeat(mjd.efc_aref, nworld, axis=0), np.zeros(nefc_fill)] - ) - efc_force_fill = np.concatenate( - [np.repeat(mjd.efc_force, nworld, axis=0), np.zeros(nefc_fill)] - ) - efc_margin_fill = np.concatenate( - [np.repeat(mjd.efc_margin, nworld, axis=0), np.zeros(nefc_fill)] - ) - ncon = mjd.ncon - condim_max = np.max(mjm.geom_condim) - con_efc_address = np.zeros((nconmax, condim_max), dtype=int) - for i in range(nworld): - for j in range(ncon): - condim = mjd.contact.dim[j] - for k in range(condim): - con_efc_address[i * ncon + j, k] = mjd.nefc * i + mjd.contact.efc_address[j] + k - - con_worldid = np.zeros(nconmax, dtype=int) - for i in range(nworld): - con_worldid[i * ncon : (i + 1) * ncon] = i + # TODO(taylorhowell): sparse actuator_moment + actuator_moment = np.zeros((mjm.nu, mjm.nv)) + mujoco.mju_sparse2dense( + actuator_moment, + mjd.actuator_moment, + mjd.moment_rownnz, + mjd.moment_rowadr, + mjd.moment_colind, + ) - ncon_fill = nconmax - nworld * ncon + d.qpos = wp.array(tile(mjd.qpos), dtype=wp.float32, ndim=2) + d.qvel = wp.array(tile(mjd.qvel), dtype=wp.float32, ndim=2) + d.qacc_warmstart = wp.array(tile(mjd.qacc_warmstart), dtype=wp.float32, ndim=2) + d.qfrc_applied = wp.array(tile(mjd.qfrc_applied), dtype=wp.float32, ndim=2) + d.mocap_pos = wp.array(tile(mjd.mocap_pos), dtype=wp.vec3, ndim=2) + d.mocap_quat = wp.array(tile(mjd.mocap_quat), dtype=wp.quat, ndim=2) + d.qacc = wp.array(tile(mjd.qacc), dtype=wp.float32, ndim=2) + d.xanchor = wp.array(tile(mjd.xanchor), dtype=wp.vec3, ndim=2) + d.xaxis = wp.array(tile(mjd.xaxis), dtype=wp.vec3, ndim=2) + d.xmat = wp.array(tile(mjd.xmat), dtype=wp.mat33, ndim=2) + d.xpos = wp.array(tile(mjd.xpos), dtype=wp.vec3, ndim=2) + d.xquat = wp.array(tile(mjd.xquat), dtype=wp.quat, ndim=2) + d.xipos = wp.array(tile(mjd.xipos), dtype=wp.vec3, ndim=2) + d.ximat = wp.array(tile(mjd.ximat), dtype=wp.mat33, ndim=2) + d.subtree_com = wp.array(tile(mjd.subtree_com), dtype=wp.vec3, ndim=2) + d.geom_xpos = wp.array(tile(mjd.geom_xpos), dtype=wp.vec3, ndim=2) + d.geom_xmat = wp.array(tile(mjd.geom_xmat), dtype=wp.mat33, ndim=2) + d.site_xpos = wp.array(tile(mjd.site_xpos), dtype=wp.vec3, ndim=2) + d.site_xmat = wp.array(tile(mjd.site_xmat), dtype=wp.mat33, ndim=2) + d.cam_xpos = wp.array(tile(mjd.cam_xpos), dtype=wp.vec3, ndim=2) + d.cam_xmat = wp.array(tile(mjd.cam_xmat.reshape(-1, 3, 3)), dtype=wp.mat33, ndim=2) + d.light_xpos = wp.array(tile(mjd.light_xpos), dtype=wp.vec3, ndim=2) + d.light_xdir = wp.array(tile(mjd.light_xdir), dtype=wp.vec3, ndim=2) + d.cinert = wp.array(tile(mjd.cinert), dtype=types.vec10, ndim=2) + d.cdof = wp.array(tile(mjd.cdof), dtype=wp.spatial_vector, ndim=2) + d.crb = wp.array(tile(mjd.crb), dtype=types.vec10, ndim=2) + d.qM = wp.array(tile(qM), dtype=wp.float32, ndim=3) + d.qLD = wp.array(tile(qLD), dtype=wp.float32, ndim=3) + d.qLDiagInv = wp.array(tile(mjd.qLDiagInv), dtype=wp.float32, ndim=2) + d.ctrl = wp.array(tile(mjd.ctrl), dtype=wp.float32, ndim=2) + d.actuator_velocity = wp.array( + tile(mjd.actuator_velocity), dtype=wp.float32, ndim=2 + ) + d.actuator_force = wp.array(tile(mjd.actuator_force), dtype=wp.float32, ndim=2) + d.actuator_length = wp.array(tile(mjd.actuator_length), dtype=wp.float32, ndim=2) + d.actuator_moment = wp.array(tile(actuator_moment), dtype=wp.float32, ndim=3) + d.cvel = wp.array(tile(mjd.cvel), dtype=wp.spatial_vector, ndim=2) + d.cdof_dot = wp.array(tile(mjd.cdof_dot), dtype=wp.spatial_vector, ndim=2) + d.qfrc_bias = wp.array(tile(mjd.qfrc_bias), dtype=wp.float32, ndim=2) + d.qfrc_passive = wp.array(tile(mjd.qfrc_passive), dtype=wp.float32, ndim=2) + d.qfrc_spring = wp.array(tile(mjd.qfrc_spring), dtype=wp.float32, ndim=2) + d.qfrc_damper = wp.array(tile(mjd.qfrc_damper), dtype=wp.float32, ndim=2) + d.qfrc_actuator = wp.array(tile(mjd.qfrc_actuator), dtype=wp.float32, ndim=2) + d.qfrc_smooth = wp.array(tile(mjd.qfrc_smooth), dtype=wp.float32, ndim=2) + d.qfrc_constraint = wp.array(tile(mjd.qfrc_constraint), dtype=wp.float32, ndim=2) + d.qacc_smooth = wp.array(tile(mjd.qacc_smooth), dtype=wp.float32, ndim=2) + d.act = wp.array(tile(mjd.act), dtype=wp.float32, ndim=2) + d.act_dot = wp.array(tile(mjd.act_dot), dtype=wp.float32, ndim=2) + + nefc = mjd.nefc + efc_worldid = np.zeros(njmax, dtype=int) + + for i in range(nworld): + efc_worldid[i * nefc : (i + 1) * nefc] = i + + nefc_fill = njmax - nworld * nefc + + efc_J_fill = np.vstack( + [np.repeat(efc_J, nworld, axis=0), np.zeros((nefc_fill, mjm.nv))] + ) + efc_D_fill = np.concatenate( + [np.repeat(mjd.efc_D, nworld, axis=0), np.zeros(nefc_fill)] + ) + efc_pos_fill = np.concatenate( + [np.repeat(mjd.efc_pos, nworld, axis=0), np.zeros(nefc_fill)] + ) + efc_aref_fill = np.concatenate( + [np.repeat(mjd.efc_aref, nworld, axis=0), np.zeros(nefc_fill)] + ) + efc_force_fill = np.concatenate( + [np.repeat(mjd.efc_force, nworld, axis=0), np.zeros(nefc_fill)] + ) + efc_margin_fill = np.concatenate( + [np.repeat(mjd.efc_margin, nworld, axis=0), np.zeros(nefc_fill)] + ) - con_dist_fill = np.concatenate( - [np.repeat(mjd.contact.dist, nworld, axis=0), np.zeros(ncon_fill)] - ) - con_pos_fill = np.vstack( - [np.repeat(mjd.contact.pos, nworld, axis=0), np.zeros((ncon_fill, 3))] - ) - con_frame_fill = np.vstack( - [np.repeat(mjd.contact.frame, nworld, axis=0), np.zeros((ncon_fill, 9))] - ) - con_includemargin_fill = np.concatenate( - [np.repeat(mjd.contact.includemargin, nworld, axis=0), np.zeros(ncon_fill)] - ) - con_friction_fill = np.vstack( - [np.repeat(mjd.contact.friction, nworld, axis=0), np.zeros((ncon_fill, 5))] - ) - con_solref_fill = np.vstack( - [np.repeat(mjd.contact.solref, nworld, axis=0), np.zeros((ncon_fill, 2))] - ) - con_solreffriction_fill = np.vstack( - [np.repeat(mjd.contact.solreffriction, nworld, axis=0), np.zeros((ncon_fill, 2))] - ) - con_solimp_fill = np.vstack( - [np.repeat(mjd.contact.solimp, nworld, axis=0), np.zeros((ncon_fill, 5))] - ) - con_dim_fill = np.concatenate( - [np.repeat(mjd.contact.dim, nworld, axis=0), np.zeros(ncon_fill)] - ) - con_geom_fill = np.vstack( - [np.repeat(mjd.contact.geom, nworld, axis=0), np.zeros((ncon_fill, 2))] - ) - con_efc_address_fill = np.vstack([con_efc_address, np.zeros((ncon_fill, condim_max))]) - - d.contact.dist = wp.array(con_dist_fill, dtype=wp.float32, ndim=1) - d.contact.pos = wp.array(con_pos_fill, dtype=wp.vec3f, ndim=1) - d.contact.frame = wp.array(con_frame_fill, dtype=wp.mat33f, ndim=1) - d.contact.includemargin = wp.array(con_includemargin_fill, dtype=wp.float32, ndim=1) - d.contact.friction = wp.array(con_friction_fill, dtype=types.vec5, ndim=1) - d.contact.solref = wp.array(con_solref_fill, dtype=wp.vec2f, ndim=1) - d.contact.solreffriction = wp.array(con_solreffriction_fill, dtype=wp.vec2f, ndim=1) - d.contact.solimp = wp.array(con_solimp_fill, dtype=types.vec5, ndim=1) - d.contact.dim = wp.array(con_dim_fill, dtype=wp.int32, ndim=1) - d.contact.geom = wp.array(con_geom_fill, dtype=wp.vec2i, ndim=1) - d.contact.efc_address = wp.array(con_efc_address_fill, dtype=wp.int32, ndim=2) - d.contact.worldid = wp.array(con_worldid, dtype=wp.int32, ndim=1) - - d.efc = _constraint(mjm, d.nworld, d.njmax) - d.efc.J = wp.array(efc_J_fill, dtype=wp.float32, ndim=2) - d.efc.D = wp.array(efc_D_fill, dtype=wp.float32, ndim=1) - d.efc.pos = wp.array(efc_pos_fill, dtype=wp.float32, ndim=1) - d.efc.aref = wp.array(efc_aref_fill, dtype=wp.float32, ndim=1) - d.efc.force = wp.array(efc_force_fill, dtype=wp.float32, ndim=1) - d.efc.margin = wp.array(efc_margin_fill, dtype=wp.float32, ndim=1) - d.efc.worldid = wp.from_numpy(efc_worldid, dtype=wp.int32) - - d.xfrc_applied = wp.array(tile(mjd.xfrc_applied), dtype=wp.spatial_vector, ndim=2) - - # internal tmp arrays - d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qM_integration = wp.zeros_like(d.qM) - d.qLD_integration = wp.zeros_like(d.qLD) - d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv) - d.act_vel_integration = wp.zeros_like(d.ctrl) - d.qpos_t0 = wp.zeros((nworld, mjm.nq), dtype=wp.float32) - d.qvel_t0 = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.act_t0 = wp.zeros((nworld, mjm.na), dtype=wp.float32) - d.qvel_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.qacc_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) - d.act_dot_rk = wp.zeros((nworld, mjm.na), dtype=wp.float32) - - # broadphase sweep and prune - d.sap_geom_sort = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec4) - d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) - d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) - d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) - d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) - d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) - segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)] - d.sap_segment_index = wp.array(segment_indices_list, dtype=int) - - # collision driver - d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1) - d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1) - d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) - - # rne_postconstraint - d.cacc = wp.array(tile(mjd.cacc), dtype=wp.spatial_vector, ndim=2) - d.cfrc_int = wp.array(tile(mjd.cfrc_int), dtype=wp.spatial_vector, ndim=2) - d.cfrc_ext = wp.array(tile(mjd.cfrc_ext), dtype=wp.spatial_vector, ndim=2) - - # tendon - d.ten_length = wp.array(tile(mjd.ten_length), dtype=wp.float32, ndim=2) - - if support.is_sparse(mjm) and mjm.ntendon: - ten_J = np.zeros((mjm.ntendon, mjm.nv)) - mujoco.mju_sparse2dense( - ten_J, mjd.ten_J, mjd.ten_J_rownnz, mjd.ten_J_rowadr, mjd.ten_J_colind + ncon = mjd.ncon + condim_max = np.max(mjm.geom_condim) + con_efc_address = np.zeros((nconmax, condim_max), dtype=int) + for i in range(nworld): + for j in range(ncon): + condim = mjd.contact.dim[j] + for k in range(condim): + con_efc_address[i * ncon + j, k] = ( + mjd.nefc * i + mjd.contact.efc_address[j] + k + ) + + con_worldid = np.zeros(nconmax, dtype=int) + for i in range(nworld): + con_worldid[i * ncon : (i + 1) * ncon] = i + + ncon_fill = nconmax - nworld * ncon + + con_dist_fill = np.concatenate( + [np.repeat(mjd.contact.dist, nworld, axis=0), np.zeros(ncon_fill)] ) - else: - ten_J = mjd.ten_J.reshape((mjm.ntendon, mjm.nv)) + con_pos_fill = np.vstack( + [np.repeat(mjd.contact.pos, nworld, axis=0), np.zeros((ncon_fill, 3))] + ) + con_frame_fill = np.vstack( + [np.repeat(mjd.contact.frame, nworld, axis=0), np.zeros((ncon_fill, 9))] + ) + con_includemargin_fill = np.concatenate( + [np.repeat(mjd.contact.includemargin, nworld, axis=0), np.zeros(ncon_fill)] + ) + con_friction_fill = np.vstack( + [np.repeat(mjd.contact.friction, nworld, axis=0), np.zeros((ncon_fill, 5))] + ) + con_solref_fill = np.vstack( + [np.repeat(mjd.contact.solref, nworld, axis=0), np.zeros((ncon_fill, 2))] + ) + con_solreffriction_fill = np.vstack( + [np.repeat(mjd.contact.solreffriction, nworld, axis=0), np.zeros((ncon_fill, 2))] + ) + con_solimp_fill = np.vstack( + [np.repeat(mjd.contact.solimp, nworld, axis=0), np.zeros((ncon_fill, 5))] + ) + con_dim_fill = np.concatenate( + [np.repeat(mjd.contact.dim, nworld, axis=0), np.zeros(ncon_fill)] + ) + con_geom_fill = np.vstack( + [np.repeat(mjd.contact.geom, nworld, axis=0), np.zeros((ncon_fill, 2))] + ) + con_efc_address_fill = np.vstack( + [con_efc_address, np.zeros((ncon_fill, condim_max))] + ) + + d.contact.dist = wp.array(con_dist_fill, dtype=wp.float32, ndim=1) + d.contact.pos = wp.array(con_pos_fill, dtype=wp.vec3f, ndim=1) + d.contact.frame = wp.array(con_frame_fill, dtype=wp.mat33f, ndim=1) + d.contact.includemargin = wp.array(con_includemargin_fill, dtype=wp.float32, ndim=1) + d.contact.friction = wp.array(con_friction_fill, dtype=types.vec5, ndim=1) + d.contact.solref = wp.array(con_solref_fill, dtype=wp.vec2f, ndim=1) + d.contact.solreffriction = wp.array(con_solreffriction_fill, dtype=wp.vec2f, ndim=1) + d.contact.solimp = wp.array(con_solimp_fill, dtype=types.vec5, ndim=1) + d.contact.dim = wp.array(con_dim_fill, dtype=wp.int32, ndim=1) + d.contact.geom = wp.array(con_geom_fill, dtype=wp.vec2i, ndim=1) + d.contact.efc_address = wp.array(con_efc_address_fill, dtype=wp.int32, ndim=2) + d.contact.worldid = wp.array(con_worldid, dtype=wp.int32, ndim=1) + + d.efc = _constraint(mjm, d.nworld, d.njmax) + d.efc.J = wp.array(efc_J_fill, dtype=wp.float32, ndim=2) + d.efc.D = wp.array(efc_D_fill, dtype=wp.float32, ndim=1) + d.efc.pos = wp.array(efc_pos_fill, dtype=wp.float32, ndim=1) + d.efc.aref = wp.array(efc_aref_fill, dtype=wp.float32, ndim=1) + d.efc.force = wp.array(efc_force_fill, dtype=wp.float32, ndim=1) + d.efc.margin = wp.array(efc_margin_fill, dtype=wp.float32, ndim=1) + d.efc.worldid = wp.from_numpy(efc_worldid, dtype=wp.int32) + + d.xfrc_applied = wp.array(tile(mjd.xfrc_applied), dtype=wp.spatial_vector, ndim=2) + + # internal tmp arrays + d.qfrc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qacc_integration = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qM_integration = wp.zeros_like(d.qM) + d.qLD_integration = wp.zeros_like(d.qLD) + d.qLDiagInv_integration = wp.zeros_like(d.qLDiagInv) + d.act_vel_integration = wp.zeros_like(d.ctrl) + d.qpos_t0 = wp.zeros((nworld, mjm.nq), dtype=wp.float32) + d.qvel_t0 = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.act_t0 = wp.zeros((nworld, mjm.na), dtype=wp.float32) + d.qvel_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.qacc_rk = wp.zeros((nworld, mjm.nv), dtype=wp.float32) + d.act_dot_rk = wp.zeros((nworld, mjm.na), dtype=wp.float32) + + # broadphase sweep and prune + d.sap_geom_sort = wp.zeros((nworld, mjm.ngeom), dtype=wp.vec4) + d.sap_projection_lower = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.float32) + d.sap_projection_upper = wp.zeros((nworld, mjm.ngeom), dtype=wp.float32) + d.sap_sort_index = wp.zeros((2 * nworld, mjm.ngeom), dtype=wp.int32) + d.sap_range = wp.zeros((nworld, mjm.ngeom), dtype=wp.int32) + d.sap_cumulative_sum = wp.zeros(nworld * mjm.ngeom, dtype=wp.int32) + segment_indices_list = [i * mjm.ngeom for i in range(nworld + 1)] + d.sap_segment_index = wp.array(segment_indices_list, dtype=int) + + # collision driver + d.collision_pair = wp.empty(nconmax, dtype=wp.vec2i, ndim=1) + d.collision_worldid = wp.empty(nconmax, dtype=wp.int32, ndim=1) + d.ncollision = wp.zeros(1, dtype=wp.int32, ndim=1) + + # rne_postconstraint + d.cacc = wp.array(tile(mjd.cacc), dtype=wp.spatial_vector, ndim=2) + d.cfrc_int = wp.array(tile(mjd.cfrc_int), dtype=wp.spatial_vector, ndim=2) + d.cfrc_ext = wp.array(tile(mjd.cfrc_ext), dtype=wp.spatial_vector, ndim=2) + + # tendon + d.ten_length = wp.array(tile(mjd.ten_length), dtype=wp.float32, ndim=2) + + if support.is_sparse(mjm) and mjm.ntendon: + ten_J = np.zeros((mjm.ntendon, mjm.nv)) + mujoco.mju_sparse2dense( + ten_J, mjd.ten_J, mjd.ten_J_rownnz, mjd.ten_J_rowadr, mjd.ten_J_colind + ) + else: + ten_J = mjd.ten_J.reshape((mjm.ntendon, mjm.nv)) - d.ten_J = wp.array(tile(ten_J), dtype=wp.float32, ndim=3) + d.ten_J = wp.array(tile(ten_J), dtype=wp.float32, ndim=3) - # sensors - d.sensordata = wp.array(tile(mjd.sensordata), dtype=wp.float32, ndim=2) + # sensors + d.sensordata = wp.array(tile(mjd.sensordata), dtype=wp.float32, ndim=2) return d diff --git a/mujoco_warp/_src/passive.py b/mujoco_warp/_src/passive.py index c8cbc30b..af182dca 100644 --- a/mujoco_warp/_src/passive.py +++ b/mujoco_warp/_src/passive.py @@ -104,5 +104,5 @@ def _damper_passive(m: Model, d: Data): # TODO(team): mj_inertiaBoxFluidModell d.qfrc_spring.zero_() - wp.launch(_spring, dim=(d.nworld, m.njnt), inputs=[m, d]) - wp.launch(_damper_passive, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(_spring, dim=(d.nworld, m.njnt), inputs=[m, d], device=m.device) + wp.launch(_damper_passive, dim=(d.nworld, m.nv), inputs=[m, d], device=m.device) diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index 8d496bb2..4f998583 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -46,7 +46,9 @@ def _sensor_pos(m: Model, d: Data): if (m.sensor_pos_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): return - wp.launch(_sensor_pos, dim=(d.nworld, m.sensor_pos_adr.size), inputs=[m, d]) + wp.launch( + _sensor_pos, dim=(d.nworld, m.sensor_pos_adr.size), inputs=[m, d], device=m.device + ) @wp.func @@ -72,7 +74,9 @@ def _sensor_vel(m: Model, d: Data): if (m.sensor_vel_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): return - wp.launch(_sensor_vel, dim=(d.nworld, m.sensor_vel_adr.size), inputs=[m, d]) + wp.launch( + _sensor_vel, dim=(d.nworld, m.sensor_vel_adr.size), inputs=[m, d], device=m.device + ) @wp.func @@ -98,4 +102,6 @@ def _sensor_acc(m: Model, d: Data): if (m.sensor_acc_adr.size == 0) or (m.opt.disableflags & DisableBit.SENSOR): return - wp.launch(_sensor_acc, dim=(d.nworld, m.sensor_acc_adr.size), inputs=[m, d]) + wp.launch( + _sensor_acc, dim=(d.nworld, m.sensor_acc_adr.size), inputs=[m, d], device=m.device + ) diff --git a/mujoco_warp/_src/smooth.py b/mujoco_warp/_src/smooth.py index e5251807..b0951870 100644 --- a/mujoco_warp/_src/smooth.py +++ b/mujoco_warp/_src/smooth.py @@ -134,19 +134,23 @@ def site_local_to_global(m: Model, d: Data): math.mul_quat(xquat, m.site_quat[siteid]) ) - wp.launch(_root, dim=(d.nworld), inputs=[m, d]) + wp.launch(_root, dim=(d.nworld), inputs=[m, d], device=m.device) body_treeadr = m.body_treeadr.numpy() for i in range(1, len(body_treeadr)): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg], device=m.device) if m.ngeom: - wp.launch(geom_local_to_global, dim=(d.nworld, m.ngeom), inputs=[m, d]) + wp.launch( + geom_local_to_global, dim=(d.nworld, m.ngeom), inputs=[m, d], device=m.device + ) if m.nsite: - wp.launch(site_local_to_global, dim=(d.nworld, m.nsite), inputs=[m, d]) + wp.launch( + site_local_to_global, dim=(d.nworld, m.nsite), inputs=[m, d], device=m.device + ) @event_scope @@ -235,18 +239,20 @@ def cdof(m: Model, d: Data): elif jnt_type == wp.static(JointType.HINGE.value): # hinge res[dofid] = wp.spatial_vector(xaxis, wp.cross(xaxis, offset)) - wp.launch(subtree_com_init, dim=(d.nworld, m.nbody), inputs=[m, d]) + wp.launch(subtree_com_init, dim=(d.nworld, m.nbody), inputs=[m, d], device=m.device) body_treeadr = m.body_treeadr.numpy() for i in reversed(range(len(body_treeadr))): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(subtree_com_acc, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch( + subtree_com_acc, dim=(d.nworld, end - beg), inputs=[m, d, beg], device=m.device + ) - wp.launch(subtree_div, dim=(d.nworld, m.nbody), inputs=[m, d]) - wp.launch(cinert, dim=(d.nworld, m.nbody), inputs=[m, d]) - wp.launch(cdof, dim=(d.nworld, m.njnt), inputs=[m, d]) + wp.launch(subtree_div, dim=(d.nworld, m.nbody), inputs=[m, d], device=m.device) + wp.launch(cinert, dim=(d.nworld, m.nbody), inputs=[m, d], device=m.device) + wp.launch(cdof, dim=(d.nworld, m.njnt), inputs=[m, d], device=m.device) @event_scope @@ -338,18 +344,22 @@ def light_fn(m: Model, d: Data): d.light_xdir[worldid, lightid] = wp.normalize(d.light_xdir[worldid, lightid]) if m.ncam > 0: - wp.launch(cam_local_to_global, dim=(d.nworld, m.ncam), inputs=[m, d]) - wp.launch(cam_fn, dim=(d.nworld, m.ncam), inputs=[m, d]) + wp.launch( + cam_local_to_global, dim=(d.nworld, m.ncam), inputs=[m, d], device=m.device + ) + wp.launch(cam_fn, dim=(d.nworld, m.ncam), inputs=[m, d], device=m.device) if m.nlight > 0: - wp.launch(light_local_to_global, dim=(d.nworld, m.nlight), inputs=[m, d]) - wp.launch(light_fn, dim=(d.nworld, m.nlight), inputs=[m, d]) + wp.launch( + light_local_to_global, dim=(d.nworld, m.nlight), inputs=[m, d], device=m.device + ) + wp.launch(light_fn, dim=(d.nworld, m.nlight), inputs=[m, d], device=m.device) @event_scope def crb(m: Model, d: Data): """Composite rigid body inertia algorithm.""" - kernel_copy(d.crb, d.cinert) + kernel_copy(d.crb, d.cinert, m.device) @kernel def crb_accumulate(m: Model, d: Data, leveladr: int): @@ -405,13 +415,15 @@ def qM_dense(m: Model, d: Data): for i in reversed(range(len(body_treeadr))): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(crb_accumulate, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch( + crb_accumulate, dim=(d.nworld, end - beg), inputs=[m, d, beg], device=m.device + ) d.qM.zero_() if m.opt.is_sparse: - wp.launch(qM_sparse, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(qM_sparse, dim=(d.nworld, m.nv), inputs=[m, d], device=m.device) else: - wp.launch(qM_dense, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(qM_dense, dim=(d.nworld, m.nv), inputs=[m, d], device=m.device) def _factor_i_sparse_legacy(m: Model, d: Data, M: array3df, L: array3df, D: array2df): @@ -436,7 +448,7 @@ def qLDiag_div(m: Model, L: array3df, D: array2df): worldid, dofid = wp.tid() D[worldid, dofid] = 1.0 / L[worldid, 0, m.dof_Madr[dofid]] - kernel_copy(L, M) + kernel_copy(L, M, m.device) qLD_update_treeadr = m.qLD_update_treeadr.numpy() @@ -445,9 +457,9 @@ def qLDiag_div(m: Model, L: array3df, D: array2df): beg, end = qLD_update_treeadr[i], m.qLD_update_tree.shape[0] else: beg, end = qLD_update_treeadr[i], qLD_update_treeadr[i + 1] - wp.launch(qLD_acc, dim=(d.nworld, end - beg), inputs=[m, beg, L]) + wp.launch(qLD_acc, dim=(d.nworld, end - beg), inputs=[m, beg, L], device=m.device) - wp.launch(qLDiag_div, dim=(d.nworld, m.nv), inputs=[m, L, D]) + wp.launch(qLDiag_div, dim=(d.nworld, m.nv), inputs=[m, L, D], device=m.device) def _factor_i_sparse(m: Model, d: Data, M: array3df, L: array3df, D: array2df): @@ -483,7 +495,7 @@ def copy_CSR(L: array3df, M: array3df, mapM2M: wp.array(dtype=wp.int32, ndim=1)) worldid, ind = wp.tid() L[worldid, 0, ind] = M[worldid, 0, mapM2M[ind]] - wp.launch(copy_CSR, dim=(d.nworld, m.nM), inputs=[L, M, m.mapM2M]) + wp.launch(copy_CSR, dim=(d.nworld, m.nM), inputs=[L, M, m.mapM2M], device=m.device) qLD_update_treeadr = m.qLD_update_treeadr.numpy() @@ -492,9 +504,9 @@ def copy_CSR(L: array3df, M: array3df, mapM2M: wp.array(dtype=wp.int32, ndim=1)) beg, end = qLD_update_treeadr[i], m.qLD_update_tree.shape[0] else: beg, end = qLD_update_treeadr[i], qLD_update_treeadr[i + 1] - wp.launch(qLD_acc, dim=(d.nworld, end - beg), inputs=[m, beg, L]) + wp.launch(qLD_acc, dim=(d.nworld, end - beg), inputs=[m, beg, L], device=m.device) - wp.launch(qLDiag_div, dim=(d.nworld, m.nv), inputs=[m, L, D]) + wp.launch(qLDiag_div, dim=(d.nworld, m.nv), inputs=[m, L, D], device=m.device) def _factor_i_dense(m: Model, d: Data, M: wp.array, L: wp.array): @@ -512,10 +524,14 @@ def cholesky(m: Model, leveladr: int, M: array3df, L: array3df): M[worldid], shape=(tilesize, tilesize), offset=(dofid, dofid) ) L_tile = wp.tile_cholesky(M_tile) - wp.tile_store(L[worldid], L_tile, offset=(dofid, dofid)) + wp.tile_store(L[worldid], L_tile, offset=(dofid, dofid), device=m.device) wp.launch_tiled( - cholesky, dim=(d.nworld, size), inputs=[m, adr, M, L], block_dim=block_dim + cholesky, + dim=(d.nworld, size), + inputs=[m, adr, M, L], + block_dim=block_dim, + device=m.device, ) qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() @@ -551,7 +567,7 @@ def _cacc_world(m: Model, d: Data): if m.opt.disableflags & DisableBit.GRAVITY: d.cacc.zero_() else: - wp.launch(_cacc_world, dim=[d.nworld], inputs=[m, d]) + wp.launch(_cacc_world, dim=[d.nworld], inputs=[m, d], device=m.device) def _rne_cacc_forward(m: Model, d: Data, flg_acc: bool = False): @@ -577,7 +593,7 @@ def _cacc( for i in range(len(body_treeadr)): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(_cacc, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch(_cacc, dim=(d.nworld, end - beg), inputs=[m, d, beg], device=m.device) def _rne_cfrc(m: Model, d: Data, flg_cfrc_ext: bool = False): @@ -596,7 +612,7 @@ def _cfrc(d: Data): d.cfrc_int[worldid, bodyid] = frc - wp.launch(_cfrc, dim=[d.nworld, m.nbody - 1], inputs=[d]) + wp.launch(_cfrc, dim=[d.nworld, m.nbody - 1], inputs=[d], device=m.device) def _rne_cfrc_backward(m: Model, d: Data): @@ -612,7 +628,7 @@ def _cfrc(m: Model, d: Data, leveladr: int): for i in reversed(range(len(body_treeadr))): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(_cfrc, dim=[d.nworld, end - beg], inputs=[m, d, beg]) + wp.launch(_cfrc, dim=[d.nworld, end - beg], inputs=[m, d, beg], device=m.device) @event_scope @@ -632,7 +648,7 @@ def qfrc_bias(m: Model, d: Data): _rne_cfrc(m, d) _rne_cfrc_backward(m, d) - wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d]) + wp.launch(qfrc_bias, dim=[d.nworld, m.nv], inputs=[m, d], device=m.device) @event_scope @@ -654,7 +670,7 @@ def _cfrc_ext(m: Model, d: Data): xfrc_applied, subtree_com - xipos ) - wp.launch(_cfrc_ext, dim=(d.nworld, m.nbody), inputs=[m, d]) + wp.launch(_cfrc_ext, dim=(d.nworld, m.nbody), inputs=[m, d], device=m.device) # TODO(team): cfrc_ext += contacts # TODO(team): cfrc_ext += equality @@ -737,6 +753,7 @@ def _transmission( dim=[d.nworld, m.nu], inputs=[m, d], outputs=[d.actuator_length, d.actuator_moment], + device=m.device, ) @@ -801,13 +818,13 @@ def _level(m: Model, d: Data, leveladr: int): d.cvel[worldid, bodyid] = cvel - wp.launch(_root, dim=(d.nworld, 6), inputs=[d]) + wp.launch(_root, dim=(d.nworld, 6), inputs=[d], device=m.device) body_treeadr = m.body_treeadr.numpy() for i in range(1, len(body_treeadr)): beg = body_treeadr[i] end = m.nbody if i == len(body_treeadr) - 1 else body_treeadr[i + 1] - wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg]) + wp.launch(_level, dim=(d.nworld, end - beg), inputs=[m, d, beg], device=m.device) def _solve_LD_sparse( @@ -834,7 +851,7 @@ def x_acc_down(m: Model, L: array3df, x: array2df, leveladr: int): i, k, Madr_ki = update[0], update[1], update[2] wp.atomic_sub(x[worldid], k, L[worldid, 0, Madr_ki] * x[worldid, i]) - kernel_copy(x, y) + kernel_copy(x, y, m.device) qLD_update_treeadr = m.qLD_update_treeadr.numpy() @@ -843,16 +860,20 @@ def x_acc_down(m: Model, L: array3df, x: array2df, leveladr: int): beg, end = qLD_update_treeadr[i], m.qLD_update_tree.shape[0] else: beg, end = qLD_update_treeadr[i], qLD_update_treeadr[i + 1] - wp.launch(x_acc_up, dim=(d.nworld, end - beg), inputs=[m, L, x, beg]) + wp.launch( + x_acc_up, dim=(d.nworld, end - beg), inputs=[m, L, x, beg], device=m.device + ) - wp.launch(qLDiag_mul, dim=(d.nworld, m.nv), inputs=[D, x]) + wp.launch(qLDiag_mul, dim=(d.nworld, m.nv), inputs=[D, x], device=m.device) for i in range(len(qLD_update_treeadr)): if i == len(qLD_update_treeadr) - 1: beg, end = qLD_update_treeadr[i], m.qLD_update_tree.shape[0] else: beg, end = qLD_update_treeadr[i], qLD_update_treeadr[i + 1] - wp.launch(x_acc_down, dim=(d.nworld, end - beg), inputs=[m, L, x, beg]) + wp.launch( + x_acc_down, dim=(d.nworld, end - beg), inputs=[m, L, x, beg], device=m.device + ) def _solve_LD_dense(m: Model, d: Data, L: array3df, x: array2df, y: array2df): @@ -874,7 +895,11 @@ def cho_solve(m: Model, L: array3df, x: array2df, y: array2df, leveladr: int): wp.tile_store(x[worldid], x_slice, offset=(dofid,)) wp.launch_tiled( - cho_solve, dim=(d.nworld, size), inputs=[m, L, x, y, adr], block_dim=block_dim + cho_solve, + dim=(d.nworld, size), + inputs=[m, L, x, y, adr], + block_dim=block_dim, + device=m.device, ) qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() @@ -919,7 +944,11 @@ def cholesky(m: Model, leveladr: int, M: array3df, x: array2df, y: array2df): wp.tile_store(x[worldid], x_slice, offset=(dofid,)) wp.launch_tiled( - cholesky, dim=(d.nworld, size), inputs=[m, adr, M, x, y], block_dim=block_dim + cholesky, + dim=(d.nworld, size), + inputs=[m, adr, M, x, y], + block_dim=block_dim, + device=m.device, ) qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() @@ -968,6 +997,8 @@ def _joint_tendon(m: Model, d: Data): # add to moment d.ten_J[worldid, tendon_jnt_adr, m.jnt_dofadr[wrap_jnt_adr]] = prm - wp.launch(_joint_tendon, dim=(d.nworld, m.wrap_jnt_adr.size), inputs=[m, d]) + wp.launch( + _joint_tendon, dim=(d.nworld, m.wrap_jnt_adr.size), inputs=[m, d], device=m.device + ) # TODO(team): spatial diff --git a/mujoco_warp/_src/solver.py b/mujoco_warp/_src/solver.py index dafee604..d960ff66 100644 --- a/mujoco_warp/_src/solver.py +++ b/mujoco_warp/_src/solver.py @@ -54,22 +54,22 @@ def _search(d: types.Data): d.efc.search[worldid, dofid] = search wp.atomic_add(d.efc.search_dot, worldid, search * search) - wp.launch(_init_context, dim=(d.nworld), inputs=[d]) + wp.launch(_init_context, dim=(d.nworld), inputs=[d], device=m.device) # jaref = d.efc_J @ d.qacc - d.efc_aref - d.efc.Jaref.zero_() + d.efc.Jaref.zero_(device=m.device) - wp.launch(_jaref, dim=(d.njmax, m.nv), inputs=[m, d]) + wp.launch(_jaref, dim=(d.njmax, m.nv), inputs=[m, d], device=m.device) # Ma = qM @ qacc - support.mul_m(m, d, d.efc.Ma, d.qacc, d.efc.done) + support.mul_m(m, d, d.efc.Ma, d.qacc, d.efc.done, device=m.device) - _update_constraint(m, d) + _update_constraint(m, d, device=m.device) if grad: _update_gradient(m, d) # search = -Mgrad - wp.launch(_search, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_search, dim=(d.nworld, m.nv), inputs=[d], device=m.device) def _update_constraint(m: types.Model, d: types.Data): @@ -159,21 +159,21 @@ def _gauss(d: types.Data): wp.atomic_add(d.efc.gauss, worldid, gauss_cost) wp.atomic_add(d.efc.cost, worldid, gauss_cost) - wp.launch(_init_cost, dim=(d.nworld), inputs=[d]) + wp.launch(_init_cost, dim=(d.nworld), inputs=[d], device=m.device) - wp.launch(_efc_kernel, dim=(d.njmax,), inputs=[d]) + wp.launch(_efc_kernel, dim=(d.njmax,), inputs=[d], device=m.device) # qfrc_constraint = efc_J.T @ efc_force - wp.launch(_zero_qfrc_constraint, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_zero_qfrc_constraint, dim=(d.nworld, m.nv), inputs=[d], device=m.device) - wp.launch(_qfrc_constraint, dim=(m.nv, d.njmax), inputs=[d]) + wp.launch(_qfrc_constraint, dim=(m.nv, d.njmax), inputs=[d], device=m.device) # gauss = 0.5 * (Ma - qfrc_smooth).T @ (qacc - qacc_smooth) - wp.launch(_gauss, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_gauss, dim=(d.nworld, m.nv), inputs=[d], device=m.device) -def _update_gradient(m: types.Model, d: types.Data): +def _update_gradient(m: types.Model, d: types.Data, device=None): TILE = m.nv ITERATIONS = m.opt.iterations @@ -315,30 +315,46 @@ def _cholesky(d: types.Data): wp.tile_store(d.efc.Mgrad[worldid], output_tile) # grad = Ma - qfrc_smooth - qfrc_constraint - wp.launch(_zero_grad_dot, dim=(d.nworld), inputs=[d]) + wp.launch(_zero_grad_dot, dim=(d.nworld), inputs=[d], device=m.device) - wp.launch(_grad, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_grad, dim=(d.nworld, m.nv), inputs=[d], device=m.device) if m.opt.solver == types.SolverType.CG: - smooth.solve_m(m, d, d.efc.Mgrad, d.efc.grad) + smooth.solve_m(m, d, d.efc.Mgrad, d.efc.grad, device=m.device) elif m.opt.solver == types.SolverType.NEWTON: # h = qM + (efc_J.T * efc_D * active) @ efc_J if m.opt.is_sparse: - wp.launch(_zero_h_lower, dim=(d.nworld, m.dof_tri_row.size), inputs=[m, d]) + wp.launch( + _zero_h_lower, + dim=(d.nworld, m.dof_tri_row.size), + inputs=[m, d], + device=m.device, + ) wp.launch( - _set_h_qM_lower_sparse, dim=(d.nworld, m.qM_fullm_i.size), inputs=[m, d] + _set_h_qM_lower_sparse, + dim=(d.nworld, m.qM_fullm_i.size), + inputs=[m, d], + device=m.device, ) else: - wp.launch(_copy_lower_triangle, dim=(d.nworld, m.dof_tri_row.size), inputs=[m, d]) + wp.launch( + _copy_lower_triangle, + dim=(d.nworld, m.dof_tri_row.size), + inputs=[m, d], + device=m.device, + ) wp.launch( _JTDAJ, dim=(dim_x, m.dof_tri_row.size), inputs=[m, d], + device=m.device, ) - wp.launch_tiled(_cholesky, dim=(d.nworld,), inputs=[d], block_dim=32) + wp.launch_tiled( + _cholesky, dim=(d.nworld,), inputs=[d], block_dim=32, device=m.device + ) @wp.func @@ -645,11 +661,11 @@ def _swap( alpha = wp.where(improved and not plo_better, phi_alpha, alpha) d.efc.alpha[worldid] = alpha - wp.launch(_gtol, dim=(d.nworld,), inputs=[m, d]) + wp.launch(_gtol, dim=(d.nworld,), inputs=[m, d], device=m.device) # linesearch points done = d.efc.ls_done - done.zero_() + done.zero_(device=m.device) p0 = d.efc.p0 lo = d.efc.lo lo_alpha = d.efc.lo_alpha @@ -664,17 +680,24 @@ def _swap( # initialize interval - wp.launch(_init_p0_gauss, dim=(d.nworld,), inputs=[p0, d]) + wp.launch(_init_p0_gauss, dim=(d.nworld,), inputs=[p0, d], device=m.device) - wp.launch(_init_p0, dim=(d.njmax,), inputs=[p0, d]) + wp.launch(_init_p0, dim=(d.njmax,), inputs=[p0, d], device=m.device) - wp.launch(_init_lo_gauss, dim=(d.nworld,), inputs=[p0, lo, lo_alpha, d]) + wp.launch( + _init_lo_gauss, dim=(d.nworld,), inputs=[p0, lo, lo_alpha, d], device=m.device + ) - wp.launch(_init_lo, dim=(d.njmax,), inputs=[lo, lo_alpha, d]) + wp.launch(_init_lo, dim=(d.njmax,), inputs=[lo, lo_alpha, d], device=m.device) # set the lo/hi interval bounds - wp.launch(_init_bounds, dim=(d.nworld,), inputs=[p0, lo, lo_alpha, hi, hi_alpha, d]) + wp.launch( + _init_bounds, + dim=(d.nworld,), + inputs=[p0, lo, lo_alpha, hi, hi_alpha, d], + device=m.device, + ) for _ in range(m.opt.ls_iterations): # note: we always launch ls_iterations kernels, but the kernels may early exit if done is true @@ -682,15 +705,15 @@ def _swap( # of extra launches inputs = [done, lo, lo_alpha, hi, hi_alpha, lo_next, lo_next_alpha, hi_next] inputs += [hi_next_alpha, mid, mid_alpha, d] - wp.launch(_next_alpha_gauss, dim=(d.nworld,), inputs=inputs) + wp.launch(_next_alpha_gauss, dim=(d.nworld,), inputs=inputs, device=m.device) inputs = [done, lo_next, lo_next_alpha, hi_next, hi_next_alpha, mid, mid_alpha] inputs += [d] - wp.launch(_next_quad, dim=(d.njmax,), inputs=inputs) + wp.launch(_next_quad, dim=(d.njmax,), inputs=inputs, device=m.device) inputs = [done, p0, lo, lo_alpha, hi, hi_alpha, lo_next, lo_next_alpha, hi_next] inputs += [hi_next_alpha, mid, mid_alpha, d] - wp.launch(_swap, dim=(d.nworld,), inputs=inputs) + wp.launch(_swap, dim=(d.nworld,), inputs=inputs, device=m.device) def _linesearch_parallel(m: types.Model, d: types.Data): @@ -757,10 +780,12 @@ def _best_alpha(d: types.Data): bestid = wp.argmin(d.efc.cost_candidate[worldid]) d.efc.alpha[worldid] = m.alpha_candidate[bestid] - wp.launch(_quad_total, dim=(d.nworld, m.nlsp), inputs=[m, d]) - wp.launch(_quad_total_candidate, dim=(d.njmax, m.nlsp), inputs=[m, d]) - wp.launch(_cost_alpha, dim=(d.nworld, m.nlsp), inputs=[m, d]) - wp.launch(_best_alpha, dim=(d.nworld), inputs=[d]) + wp.launch(_quad_total, dim=(d.nworld, m.nlsp), inputs=[m, d], device=m.device) + wp.launch( + _quad_total_candidate, dim=(d.njmax, m.nlsp), inputs=[m, d], device=m.device + ) + wp.launch(_cost_alpha, dim=(d.nworld, m.nlsp), inputs=[m, d], device=m.device) + wp.launch(_best_alpha, dim=(d.nworld), inputs=[d], device=m.device) @event_scope @@ -873,32 +898,32 @@ def _jaref(d: types.Data): d.efc.Jaref[efcid] += d.efc.alpha[worldid] * d.efc.jv[efcid] # mv = qM @ search - support.mul_m(m, d, d.efc.mv, d.efc.search, d.efc.done) + support.mul_m(m, d, d.efc.mv, d.efc.search, d.efc.done, device=m.device) # jv = efc_J @ search # TODO(team): is there a better way of doing batched matmuls with dynamic array sizes? - wp.launch(_zero_jv, dim=(d.njmax), inputs=[d]) + wp.launch(_zero_jv, dim=(d.njmax), inputs=[d], device=m.device) - wp.launch(_jv, dim=(d.njmax, m.nv), inputs=[d]) + wp.launch(_jv, dim=(d.njmax, m.nv), inputs=[d], device=m.device) # prepare quadratics # quad_gauss = [gauss, search.T @ Ma - search.T @ qfrc_smooth, 0.5 * search.T @ mv] - wp.launch(_zero_quad_gauss, dim=(d.nworld), inputs=[d]) + wp.launch(_zero_quad_gauss, dim=(d.nworld), inputs=[d], device=m.device) - wp.launch(_init_quad_gauss, dim=(d.nworld, m.nv), inputs=[m, d]) + wp.launch(_init_quad_gauss, dim=(d.nworld, m.nv), inputs=[m, d], device=m.device) # quad = [0.5 * Jaref * Jaref * efc_D, jv * Jaref * efc_D, 0.5 * jv * jv * efc_D] - wp.launch(_init_quad, dim=(d.njmax), inputs=[d]) + wp.launch(_init_quad, dim=(d.njmax), inputs=[d], device=m.device) if m.opt.ls_parallel: _linesearch_parallel(m, d) else: _linesearch_iterative(m, d) - wp.launch(_qacc_ma, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_qacc_ma, dim=(d.nworld, m.nv), inputs=[d], device=m.device) - wp.launch(_jaref, dim=(d.njmax,), inputs=[d]) + wp.launch(_jaref, dim=(d.njmax,), inputs=[d], device=m.device) @event_scope @@ -1002,7 +1027,7 @@ def _beta(d: types.Data): ) # warmstart - kernel_copy(d.qacc, d.qacc_warmstart) + kernel_copy(d.qacc, d.qacc_warmstart, m.device) _create_context(m, d, grad=True) @@ -1010,23 +1035,23 @@ def _beta(d: types.Data): _linesearch(m, d) if m.opt.solver == types.SolverType.CG: - wp.launch(_prev_grad_Mgrad, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_prev_grad_Mgrad, dim=(d.nworld, m.nv), inputs=[d], device=m.device) - _update_constraint(m, d) - _update_gradient(m, d) + _update_constraint(m, d, device=m.device) + _update_gradient(m, d, device=m.device) # polak-ribiere if m.opt.solver == types.SolverType.CG: - wp.launch(_zero_beta_num_den, dim=(d.nworld), inputs=[d]) + wp.launch(_zero_beta_num_den, dim=(d.nworld), inputs=[d], device=m.device) - wp.launch(_beta_num_den, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_beta_num_den, dim=(d.nworld, m.nv), inputs=[d], device=m.device) - wp.launch(_beta, dim=(d.nworld,), inputs=[d]) + wp.launch(_beta, dim=(d.nworld,), inputs=[d], device=m.device) - wp.launch(_zero_search_dot, dim=(d.nworld), inputs=[d]) + wp.launch(_zero_search_dot, dim=(d.nworld), inputs=[d], device=m.device) - wp.launch(_search_update, dim=(d.nworld, m.nv), inputs=[d]) + wp.launch(_search_update, dim=(d.nworld, m.nv), inputs=[d], device=m.device) - wp.launch(_done, dim=(d.nworld,), inputs=[m, d, i]) + wp.launch(_done, dim=(d.nworld,), inputs=[m, d, i], device=m.device) - kernel_copy(d.qacc_warmstart, d.qacc) + kernel_copy(d.qacc_warmstart, d.qacc, m.device) diff --git a/mujoco_warp/_src/support.py b/mujoco_warp/_src/support.py index 34878b5a..efe41303 100644 --- a/mujoco_warp/_src/support.py +++ b/mujoco_warp/_src/support.py @@ -82,6 +82,7 @@ def mul( ], # TODO(team): develop heuristic for block dim, or make configurable block_dim=32, + device=m.device, ) qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy() @@ -130,10 +131,18 @@ def _mul_m_sparse_ij( wp.atomic_add(res[worldid], i, qM * vec[worldid, j]) wp.atomic_add(res[worldid], j, qM * vec[worldid, i]) - wp.launch(_mul_m_sparse_diag, dim=(d.nworld, m.nv), inputs=[m, d, res, vec, skip]) + wp.launch( + _mul_m_sparse_diag, + dim=(d.nworld, m.nv), + inputs=[m, d, res, vec, skip], + device=m.device, + ) wp.launch( - _mul_m_sparse_ij, dim=(d.nworld, m.qM_madr_ij.size), inputs=[m, d, res, vec, skip] + _mul_m_sparse_ij, + dim=(d.nworld, m.qM_madr_ij.size), + inputs=[m, d, res, vec, skip], + device=m.device, ) @@ -169,7 +178,9 @@ def _accumulate(m: Model, d: Data, qfrc: array2df): qfrc[worldid, dofid] += accumul - wp.launch(kernel=_accumulate, dim=(d.nworld, m.nv), inputs=[m, d, qfrc]) + wp.launch( + kernel=_accumulate, dim=(d.nworld, m.nv), inputs=[m, d, qfrc], device=m.device + ) @wp.func diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 4b8d39a0..b43bd454 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -612,6 +612,7 @@ class Model: nlsp: int # warp only opt: Option stat: Statistic + device: wp.context.Device # warp only qpos0: wp.array(dtype=wp.float32, ndim=1) qpos_spring: wp.array(dtype=wp.float32, ndim=1) body_tree: wp.array(dtype=wp.int32, ndim=1) # warp only @@ -854,8 +855,6 @@ class Data: nworld: number of worlds () nconmax: maximum number of contacts () njmax: maximum number of constraints () - rne_cacc: arrays used for smooth.rne (nworld, nbody, 6) - rne_cfrc: arrays used for smooth.rne (nworld, nbody, 6) qpos_t0: temporary array for rk4 (nworld, nq) qvel_t0: temporary array for rk4 (nworld, nv) act_t0: temporary array for rk4 (nworld, na) diff --git a/mujoco_warp/_src/warp_util.py b/mujoco_warp/_src/warp_util.py index 5b2d437c..568d82ce 100644 --- a/mujoco_warp/_src/warp_util.py +++ b/mujoco_warp/_src/warp_util.py @@ -215,7 +215,9 @@ def _copy_2dspatialf( # TODO(team): remove kernel_copy once wp.copy is supported in cuda subgraphs -def kernel_copy(dest: wp.array, src: wp.array): +def kernel_copy( + dest: wp.array, src: wp.array, device: Optional[wp.context.Device] = None +): if src.shape != dest.shape: raise ValueError("only same shape copying allowed") @@ -243,4 +245,4 @@ def kernel_copy(dest: wp.array, src: wp.array): else: raise NotImplementedError("copy not supported for these array types") - wp.launch(kernel=kernel, dim=src.shape, inputs=[dest, src]) + wp.launch(kernel=kernel, dim=src.shape, inputs=[dest, src], device=device)