Skip to content

Device Management in Multi-GPU systems #130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

1 change: 1 addition & 0 deletions mujoco_warp/_src/collision_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,4 +565,5 @@ def box_box_narrowphase(
dim=num_threads,
inputs=[m, d, num_threads],
block_dim=BOX_BOX_BLOCK_DIM,
device=m.device,
)
2 changes: 1 addition & 1 deletion mujoco_warp/_src/collision_convex.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,4 +777,4 @@ def gjk_narrowphase(m: Model, d: Data):
)

for collision_kernel in _collision_kernels.values():
wp.launch(collision_kernel, dim=d.nconmax, inputs=[m, d])
wp.launch(collision_kernel, dim=d.nconmax, inputs=[m, d], device=m.device)
22 changes: 16 additions & 6 deletions mujoco_warp/_src/collision_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def sap_broadphase(m: Model, d: Data):
kernel=broadphase_project_spheres_onto_sweep_direction_kernel,
dim=(d.nworld, m.ngeom),
inputs=[m, d, direction],
device=m.device,
)

tile_sort_available = False
Expand All @@ -368,7 +369,11 @@ def sap_broadphase(m: Model, d: Data):
if tile_sort_available:
segmented_sort_kernel = create_segmented_sort_kernel(m.ngeom)
wp.launch_tiled(
kernel=segmented_sort_kernel, dim=(d.nworld), inputs=[m, d], block_dim=128
kernel=segmented_sort_kernel,
dim=(d.nworld),
inputs=[m, d],
block_dim=128,
device=m.device,
)
print("tile sort available")
elif segmented_sort_available:
Expand All @@ -385,12 +390,10 @@ def sap_broadphase(m: Model, d: Data):

# Create temporary arrays for sorting
temp_box_projections_lower = wp.zeros(
m.ngeom * 2,
dtype=d.sap_projection_lower.dtype,
m.ngeom * 2, dtype=d.sap_projection_lower.dtype, device=m.device
)
temp_box_sorting_indexer = wp.zeros(
m.ngeom * 2,
dtype=d.sap_sort_index.dtype,
m.ngeom * 2, dtype=d.sap_sort_index.dtype, device=m.device
)

# Copy data to temporary arrays
Expand Down Expand Up @@ -434,12 +437,14 @@ def sap_broadphase(m: Model, d: Data):
kernel=reorder_bounding_spheres_kernel,
dim=(d.nworld, m.ngeom),
inputs=[m, d],
device=m.device,
)

wp.launch(
kernel=sap_broadphase_prepare_kernel,
dim=(d.nworld, m.ngeom),
inputs=[m, d],
device=m.device,
)

# The scan (a cumulative sum, inclusive or exclusive depending on the last argument) is used to balance the load among the threads
Expand All @@ -452,6 +457,7 @@ def sap_broadphase(m: Model, d: Data):
kernel=sap_broadphase_kernel,
dim=num_sweep_threads,
inputs=[m, d, num_sweep_threads, filter_parent],
device=m.device,
)

return d
Expand Down Expand Up @@ -510,7 +516,10 @@ def _nxn_broadphase(m: Model, d: Data):
_add_geom_pair(m, d, geom1, geom2, worldid)

wp.launch(
_nxn_broadphase, dim=(d.nworld, m.ngeom * (m.ngeom - 1) // 2), inputs=[m, d]
_nxn_broadphase,
dim=(d.nworld, m.ngeom * (m.ngeom - 1) // 2),
inputs=[m, d],
device=m.device,
)


Expand All @@ -519,6 +528,7 @@ def get_contact_solver_params(m: Model, d: Data):
get_contact_solver_params_kernel,
dim=[d.nconmax],
inputs=[m, d],
device=m.device,
)

# TODO(team): do we need condim sorting, deepest penetrating contact here?
Expand Down
2 changes: 1 addition & 1 deletion mujoco_warp/_src/collision_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,4 +384,4 @@ def _primitive_narrowphase(
def primitive_narrowphase(m: Model, d: Data):
# we need to figure out how to keep the overhead of this small - not launching anything
# for pair types without collisions, as well as updating the launch dimensions.
wp.launch(_primitive_narrowphase, dim=d.nconmax, inputs=[m, d])
wp.launch(_primitive_narrowphase, dim=d.nconmax, inputs=[m, d], device=m.device)
9 changes: 8 additions & 1 deletion mujoco_warp/_src/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def make_constraint(m: types.Model, d: types.Data):
_efc_limit_slide_hinge,
dim=(d.nworld, m.jnt_limited_slide_hinge_adr.size),
inputs=[m, d, refsafe],
device=m.device,
)

# contact
Expand All @@ -359,9 +360,15 @@ def make_constraint(m: types.Model, d: types.Data):
_efc_contact_pyramidal,
dim=(d.nconmax, 2 * (m.condim_max - 1)),
inputs=[m, d, refsafe],
device=m.device,
)
elif m.opt.cone == types.ConeType.ELLIPTIC.value:
wp.launch(_efc_contact_elliptic, dim=(d.nconmax, 3), inputs=[m, d, refsafe])
wp.launch(
_efc_contact_elliptic,
dim=(d.nconmax, 3),
inputs=[m, d, refsafe],
device=m.device,
)

# TODO(team): condim=4
# TODO(team): condim=6
63 changes: 47 additions & 16 deletions mujoco_warp/_src/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,26 @@ def integrate_joint_positions(m: Model, d: Data, qvel_in: array2df):

# skip if no stateful actuators.
if m.na:
wp.launch(next_activation, dim=(d.nworld, m.nu), inputs=[m, d, act_dot])
wp.launch(
next_activation, dim=(d.nworld, m.nu), inputs=[m, d, act_dot], device=m.device
)

wp.launch(advance_velocities, dim=(d.nworld, m.nv), inputs=[m, d, qacc])
wp.launch(
advance_velocities, dim=(d.nworld, m.nv), inputs=[m, d, qacc], device=m.device
)

# advance positions with qvel if given, d.qvel otherwise (semi-implicit)
if qvel is not None:
qvel_in = qvel
else:
qvel_in = d.qvel

wp.launch(integrate_joint_positions, dim=(d.nworld, m.njnt), inputs=[m, d, qvel_in])
wp.launch(
integrate_joint_positions,
dim=(d.nworld, m.njnt),
inputs=[m, d, qvel_in],
device=m.device,
)

d.time = d.time + m.opt.timestep

Expand All @@ -204,8 +213,13 @@ def add_damping_sum_qfrc_kernel_sparse(m: Model, d: Data):
d.qfrc_smooth[worldid, tid] + d.qfrc_constraint[worldid, tid]
)

kernel_copy(d.qM_integration, d.qM)
wp.launch(add_damping_sum_qfrc_kernel_sparse, dim=(d.nworld, m.nv), inputs=[m, d])
kernel_copy(d.qM_integration, d.qM, m.device)
wp.launch(
add_damping_sum_qfrc_kernel_sparse,
dim=(d.nworld, m.nv),
inputs=[m, d],
device=m.device,
)
smooth.factor_solve_i(
m,
d,
Expand Down Expand Up @@ -245,7 +259,11 @@ def eulerdamp(
wp.tile_store(d.qacc_integration[worldid], qacc_tile, offset=(dofid))

wp.launch_tiled(
eulerdamp, dim=(d.nworld, size), inputs=[m, d, m.dof_damping, adr], block_dim=32
eulerdamp,
dim=(d.nworld, size),
inputs=[m, d, m.dof_damping, adr],
block_dim=32,
device=m.device,
)

qLD_tileadr, qLD_tilesize = m.qLD_tileadr.numpy(), m.qLD_tilesize.numpy()
Expand Down Expand Up @@ -290,8 +308,8 @@ def _act_dot(d: Data, b: float):
worldId, tid = wp.tid()
d.act_dot_rk[worldId, tid] += b * d.act_dot[worldId, tid]

wp.launch(_qvel_acc, dim=(d.nworld, m.nv), inputs=[d, b])
wp.launch(_act_dot, dim=(d.nworld, m.na), inputs=[d, b])
wp.launch(_qvel_acc, dim=(d.nworld, m.nv), inputs=[d, b], device=m.device)
wp.launch(_act_dot, dim=(d.nworld, m.na), inputs=[d, b], device=m.device)

def perturb_state(m: Model, d: Data, a: float):
@kernel
Expand All @@ -312,9 +330,9 @@ def _qvel(m: Model, d: Data):
dqacc = a * d.qacc[worldId, tid]
d.qvel[worldId, tid] = d.qvel_t0[worldId, tid] + dqacc * m.opt.timestep

wp.launch(_qpos, dim=(d.nworld, m.njnt), inputs=[m, d])
wp.launch(_act, dim=(d.nworld, m.na), inputs=[m, d])
wp.launch(_qvel, dim=(d.nworld, m.nv), inputs=[m, d])
wp.launch(_qpos, dim=(d.nworld, m.njnt), inputs=[m, d], device=m.device)
wp.launch(_act, dim=(d.nworld, m.na), inputs=[m, d], device=m.device)
wp.launch(_qvel, dim=(d.nworld, m.nv), inputs=[m, d], device=m.device)

rk_accumulate(d, B[0])
for i in range(3):
Expand Down Expand Up @@ -457,6 +475,7 @@ def qderiv_actuator_fused_kernel(
dim=(d.nworld, size),
inputs=[m, d, damping, adr],
block_dim=block_dim,
device=m.device,
)

qderiv_tilesize_nv = m.actuator_moment_tilesize_nv.numpy()
Expand All @@ -479,6 +498,7 @@ def qderiv_actuator_fused_kernel(
actuator_bias_gain_vel,
dim=(d.nworld, m.nu),
inputs=[m, d],
device=m.device,
)

qderiv_actuator_damping_fused(m, d, m.dof_damping)
Expand Down Expand Up @@ -522,7 +542,9 @@ def _actuator_velocity(d: Data):
qvel = d.qvel[worldid]
wp.atomic_add(d.actuator_velocity[worldid], actid, moment[dofid] * qvel[dofid])

wp.launch(_actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d])
wp.launch(
_actuator_velocity, dim=(d.nworld, m.nu, m.nv), inputs=[d], device=m.device
)
else:

def actuator_velocity(
Expand Down Expand Up @@ -561,6 +583,7 @@ def _actuator_velocity(
d.qvel.reshape(d.qvel.shape + (1,)),
],
block_dim=32,
device=m.device,
)

actuator_moment_tilesize_nu = m.actuator_moment_tilesize_nu.numpy()
Expand Down Expand Up @@ -655,7 +678,13 @@ def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df):
s = wp.clamp(s, r[0], r[1])
qfrc[worldid, vid] = s

wp.launch(_force, dim=[d.nworld, m.nu], inputs=[m, d], outputs=[d.actuator_force])
wp.launch(
_force,
dim=[d.nworld, m.nu],
inputs=[m, d],
outputs=[d.actuator_force],
device=m.device,
)

if m.opt.is_sparse:
# TODO(team): sparse version
Expand All @@ -665,6 +694,7 @@ def _qfrc(m: Model, moment: array3df, force: array2df, qfrc: array2df):
dim=(d.nworld, m.nv),
inputs=[m, d.actuator_moment, d.actuator_force],
outputs=[d.qfrc_actuator],
device=m.device,
)

else:
Expand Down Expand Up @@ -706,6 +736,7 @@ def qfrc_actuator_kernel(
d.actuator_force.reshape(d.actuator_force.shape + (1,)),
],
block_dim=32,
device=m.device,
)

qderiv_tilesize_nu = m.actuator_moment_tilesize_nu.numpy()
Expand All @@ -722,7 +753,7 @@ def qfrc_actuator_kernel(
beg, end - beg, int(qderiv_tilesize_nu[i]), int(qderiv_tilesize_nv[i])
)

wp.launch(_qfrc_limited, dim=(d.nworld, m.nv), inputs=[m, d])
wp.launch(_qfrc_limited, dim=(d.nworld, m.nv), inputs=[m, d], device=m.device)

# TODO actuator-level gravity compensation, skip if added as passive force

Expand All @@ -741,7 +772,7 @@ def _qfrc_smooth(d: Data):
+ d.qfrc_applied[worldid, dofid]
)

wp.launch(_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d])
wp.launch(_qfrc_smooth, dim=(d.nworld, m.nv), inputs=[d], device=m.device)
xfrc_accumulate(m, d, d.qfrc_smooth)

smooth.solve_m(m, d, d.qacc_smooth, d.qfrc_smooth)
Expand All @@ -760,7 +791,7 @@ def forward(m: Model, d: Data):
sensor.sensor_acc(m, d)

if d.njmax == 0:
kernel_copy(d.qacc, d.qacc_smooth)
kernel_copy(d.qacc, d.qacc_smooth, m.device)
else:
solver.solve(m, d)

Expand Down
Loading
Loading