Skip to content

Commit d158242

Browse files
Fix njmax overflow in JTDAJ (#243)
* small fix * ruff
1 parent d78c8b2 commit d158242

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

mujoco_warp/_src/solver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2013,6 +2013,7 @@ def update_gradient_JTDAJ(
20132013
dof_tri_row: wp.array(dtype=int),
20142014
dof_tri_col: wp.array(dtype=int),
20152015
# Data in:
2016+
njmax_in: int,
20162017
nefc_in: wp.array(dtype=int),
20172018
efc_worldid_in: wp.array(dtype=int),
20182019
efc_J_in: wp.array2d(dtype=float),
@@ -2033,7 +2034,7 @@ def update_gradient_JTDAJ(
20332034
for i in range(nblocks_perblock):
20342035
efcid = efcid_temp + i * dim_x
20352036

2036-
if efcid >= nefc:
2037+
if efcid >= min(nefc, njmax_in):
20372038
return
20382039

20392040
worldid = efc_worldid_in[efcid]
@@ -2256,6 +2257,7 @@ def _update_gradient(m: types.Model, d: types.Data):
22562257
inputs=[
22572258
m.dof_tri_row,
22582259
m.dof_tri_col,
2260+
d.njmax,
22592261
d.nefc,
22602262
d.efc.worldid,
22612263
d.efc.J,

0 commit comments

Comments
 (0)