Skip to content

Commit 82d5049

Browse files
ruff
Signed-off-by: Alain Denzler <adenzler@nvidia.com>
2 parents 0404542 + 8e9cd80 commit 82d5049

File tree

9 files changed

+877
-148
lines changed

9 files changed

+877
-148
lines changed

newton/_src/sim/builder.py

Lines changed: 91 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -735,11 +735,15 @@ def _process_joint_custom_attributes(
735735
736736
Joint attributes are processed based on their declared frequency:
737737
- JOINT frequency: Single value per joint
738-
- JOINT_DOF frequency: List of values with length equal to joint DOF count
739-
- JOINT_COORD frequency: List of values with length equal to joint coordinate count
738+
- JOINT_DOF frequency: List or dict of values for each DOF
739+
- JOINT_COORD frequency: List or dict of values for each coordinate
740740
741-
For DOF and COORD attributes, values must always be provided as lists with length
742-
matching the joint's DOF or coordinate count.
741+
For DOF and COORD attributes, values can be:
742+
- A list with length matching the joint's DOF/coordinate count (all DOFs get values)
743+
- A dict mapping DOF/coord indices to values (only specified indices get values, rest use defaults)
744+
- For single-DOF joints with JOINT_DOF frequency: a single Warp vector/matrix value
745+
746+
When using dict format, unspecified indices will be filled with the attribute's default value during finalization.
743747
744748
Args:
745749
joint_index: Index of the joint
@@ -764,7 +768,7 @@ def _process_joint_custom_attributes(
764768
)
765769

766770
elif custom_attr.frequency == ModelAttributeFrequency.JOINT_DOF:
767-
# List of values, one per DOF
771+
# Values per DOF - can be list or dict
768772
dof_start = self.joint_qd_start[joint_index]
769773
if joint_index + 1 < len(self.joint_qd_start):
770774
dof_end = self.joint_qd_start[joint_index + 1]
@@ -773,27 +777,54 @@ def _process_joint_custom_attributes(
773777

774778
dof_count = dof_end - dof_start
775779

776-
if not isinstance(value, (list, tuple)):
777-
raise TypeError(
778-
f"JOINT_DOF attribute '{attr_key}' must be a list with length equal to joint DOF count ({dof_count})"
779-
)
780-
781-
if len(value) != dof_count:
782-
raise ValueError(
783-
f"JOINT_DOF attribute '{attr_key}' has {len(value)} values but joint has {dof_count} DOFs"
784-
)
780+
# Check if value is a dict (mapping DOF index to value)
781+
if isinstance(value, dict):
782+
# Dict format: only specified DOF indices have values, rest use defaults
783+
for dof_offset, dof_value in value.items():
784+
if not isinstance(dof_offset, int):
785+
raise TypeError(
786+
f"JOINT_DOF attribute '{attr_key}' dict keys must be integers (DOF indices), got {type(dof_offset)}"
787+
)
788+
if dof_offset < 0 or dof_offset >= dof_count:
789+
raise ValueError(
790+
f"JOINT_DOF attribute '{attr_key}' has invalid DOF index {dof_offset} (joint has {dof_count} DOFs)"
791+
)
792+
single_attr = {attr_key: dof_value}
793+
self._process_custom_attributes(
794+
entity_index=dof_start + dof_offset,
795+
custom_attrs=single_attr,
796+
expected_frequency=ModelAttributeFrequency.JOINT_DOF,
797+
)
798+
else:
799+
# List format or single value for single-DOF joints
800+
value_sanitized = value
801+
if not isinstance(value_sanitized, (list, tuple)):
802+
# Check if it's a Warp vector/matrix type
803+
if wp.types.type_is_vector(type(value_sanitized)) or wp.types.type_is_matrix(
804+
type(value_sanitized)
805+
):
806+
value_sanitized = [value_sanitized]
807+
else:
808+
raise TypeError(
809+
f"JOINT_DOF attribute '{attr_key}' must be a list with length equal to joint DOF count ({dof_count}), "
810+
f"a dict mapping DOF indices to values, or a single Warp vector/matrix value for single-DOF joints"
811+
)
785812

786-
# Apply each value to its corresponding DOF
787-
for i, dof_value in enumerate(value):
788-
single_attr = {attr_key: dof_value}
789-
self._process_custom_attributes(
790-
entity_index=dof_start + i,
791-
custom_attrs=single_attr,
792-
expected_frequency=ModelAttributeFrequency.JOINT_DOF,
793-
)
813+
actual = len(value_sanitized)
814+
if actual != dof_count:
815+
raise ValueError(f"JOINT_DOF '{attr_key}': got {actual}, expected {dof_count}")
816+
817+
# Apply each value to its corresponding DOF
818+
for i, dof_value in enumerate(value_sanitized):
819+
single_attr = {attr_key: dof_value}
820+
self._process_custom_attributes(
821+
entity_index=dof_start + i,
822+
custom_attrs=single_attr,
823+
expected_frequency=ModelAttributeFrequency.JOINT_DOF,
824+
)
794825

795826
elif custom_attr.frequency == ModelAttributeFrequency.JOINT_COORD:
796-
# List of values, one per coordinate
827+
# Values per coordinate - can be list or dict
797828
coord_start = self.joint_q_start[joint_index]
798829
if joint_index + 1 < len(self.joint_q_start):
799830
coord_end = self.joint_q_start[joint_index + 1]
@@ -802,24 +833,45 @@ def _process_joint_custom_attributes(
802833

803834
coord_count = coord_end - coord_start
804835

805-
if not isinstance(value, (list, tuple)):
806-
raise TypeError(
807-
f"JOINT_COORD attribute '{attr_key}' must be a list with length equal to joint coordinate count ({coord_count})"
808-
)
836+
# Check if value is a dict (mapping coord index to value)
837+
if isinstance(value, dict):
838+
# Dict format: only specified coord indices have values, rest use defaults
839+
for coord_offset, coord_value in value.items():
840+
if not isinstance(coord_offset, int):
841+
raise TypeError(
842+
f"JOINT_COORD attribute '{attr_key}' dict keys must be integers (coord indices), got {type(coord_offset)}"
843+
)
844+
if coord_offset < 0 or coord_offset >= coord_count:
845+
raise ValueError(
846+
f"JOINT_COORD attribute '{attr_key}' has invalid coord index {coord_offset} (joint has {coord_count} coordinates)"
847+
)
848+
single_attr = {attr_key: coord_value}
849+
self._process_custom_attributes(
850+
entity_index=coord_start + coord_offset,
851+
custom_attrs=single_attr,
852+
expected_frequency=ModelAttributeFrequency.JOINT_COORD,
853+
)
854+
else:
855+
# List format
856+
if not isinstance(value, (list, tuple)):
857+
raise TypeError(
858+
f"JOINT_COORD attribute '{attr_key}' must be a list with length equal to joint coordinate count ({coord_count}) "
859+
f"or a dict mapping coordinate indices to values"
860+
)
809861

810-
if len(value) != coord_count:
811-
raise ValueError(
812-
f"JOINT_COORD attribute '{attr_key}' has {len(value)} values but joint has {coord_count} coordinates"
813-
)
862+
if len(value) != coord_count:
863+
raise ValueError(
864+
f"JOINT_COORD attribute '{attr_key}' has {len(value)} values but joint has {coord_count} coordinates"
865+
)
814866

815-
# Apply each value to its corresponding coordinate
816-
for i, coord_value in enumerate(value):
817-
single_attr = {attr_key: coord_value}
818-
self._process_custom_attributes(
819-
entity_index=coord_start + i,
820-
custom_attrs=single_attr,
821-
expected_frequency=ModelAttributeFrequency.JOINT_COORD,
822-
)
867+
# Apply each value to its corresponding coordinate
868+
for i, coord_value in enumerate(value):
869+
single_attr = {attr_key: coord_value}
870+
self._process_custom_attributes(
871+
entity_index=coord_start + i,
872+
custom_attrs=single_attr,
873+
expected_frequency=ModelAttributeFrequency.JOINT_COORD,
874+
)
823875

824876
else:
825877
raise ValueError(

newton/_src/solvers/mujoco/kernels.py

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -903,24 +903,88 @@ def update_axis_properties_kernel(
903903

904904

905905
@wp.kernel
906-
def update_dof_properties_kernel(
906+
def update_joint_dof_properties_kernel(
907+
joint_qd_start: wp.array(dtype=wp.int32),
908+
joint_dof_dim: wp.array2d(dtype=wp.int32),
909+
joint_mjc_dof_start: wp.array(dtype=wp.int32),
910+
dof_to_mjc_joint: wp.array(dtype=wp.int32),
907911
joint_armature: wp.array(dtype=float),
908912
joint_friction: wp.array(dtype=float),
909-
dofs_per_world: int,
913+
joint_limit_ke: wp.array(dtype=float),
914+
joint_limit_kd: wp.array(dtype=float),
915+
solimplimit: wp.array(dtype=vec5),
916+
joints_per_world: int,
910917
# outputs
911918
dof_armature: wp.array2d(dtype=float),
912919
dof_frictionloss: wp.array2d(dtype=float),
920+
jnt_solimp: wp.array2d(dtype=vec5),
921+
jnt_solref: wp.array2d(dtype=wp.vec2),
913922
):
914-
"""Update DOF armature and friction loss values."""
923+
"""Update joint DOF properties including armature, friction loss, joint impedance limits, and solref.
924+
925+
This kernel properly maps Newton DOFs to MuJoCo DOFs using joint_mjc_dof_start.
926+
For solimplimit and solref, we use dof_to_mjc_joint since jnt_solimp/jnt_solref are per-joint in MuJoCo.
927+
If solimplimit is None, jnt_solimp won't be updated (MuJoCo defaults will be preserved).
928+
"""
915929
tid = wp.tid()
916-
worldid = tid // dofs_per_world
917-
dof_in_world = tid % dofs_per_world
930+
worldid = tid // joints_per_world
931+
joint_in_world = tid % joints_per_world
932+
933+
lin_axis_count = joint_dof_dim[tid, 0]
934+
ang_axis_count = joint_dof_dim[tid, 1]
935+
936+
if lin_axis_count + ang_axis_count == 0:
937+
return
938+
939+
newton_dof_start = joint_qd_start[tid]
940+
mjc_dof_start = joint_mjc_dof_start[joint_in_world]
941+
942+
# Get the DOF start for the template joint (world 0)
943+
# dof_to_mjc_joint is only populated for template DOFs (first world)
944+
template_joint_idx = joint_in_world
945+
template_dof_start = joint_qd_start[template_joint_idx]
946+
947+
# update linear dofs
948+
for i in range(lin_axis_count):
949+
newton_dof_index = newton_dof_start + i
950+
template_dof_index = template_dof_start + i
951+
mjc_dof_index = mjc_dof_start + i
952+
mjc_joint_index = dof_to_mjc_joint[template_dof_index]
953+
954+
# Update armature and friction (per DOF)
955+
dof_armature[worldid, mjc_dof_index] = joint_armature[newton_dof_index]
956+
dof_frictionloss[worldid, mjc_dof_index] = joint_friction[newton_dof_index]
957+
958+
# Update joint limit solref using negative convention (per joint)
959+
if joint_limit_ke[newton_dof_index] > 0.0:
960+
jnt_solref[worldid, mjc_joint_index] = wp.vec2(
961+
-joint_limit_ke[newton_dof_index], -joint_limit_kd[newton_dof_index]
962+
)
963+
964+
# Update solimplimit (per joint)
965+
if solimplimit:
966+
jnt_solimp[worldid, mjc_joint_index] = solimplimit[newton_dof_index]
918967

919-
# update armature
920-
dof_armature[worldid, dof_in_world] = joint_armature[tid]
968+
# update angular dofs
969+
for i in range(ang_axis_count):
970+
newton_dof_index = newton_dof_start + lin_axis_count + i
971+
template_dof_index = template_dof_start + lin_axis_count + i
972+
mjc_dof_index = mjc_dof_start + lin_axis_count + i
973+
mjc_joint_index = dof_to_mjc_joint[template_dof_index]
974+
975+
# Update armature and friction (per DOF)
976+
dof_armature[worldid, mjc_dof_index] = joint_armature[newton_dof_index]
977+
dof_frictionloss[worldid, mjc_dof_index] = joint_friction[newton_dof_index]
978+
979+
# Update joint limit solref using negative convention (per joint)
980+
if joint_limit_ke[newton_dof_index] > 0.0:
981+
jnt_solref[worldid, mjc_joint_index] = wp.vec2(
982+
-joint_limit_ke[newton_dof_index], -joint_limit_kd[newton_dof_index]
983+
)
921984

922-
# update friction loss
923-
dof_frictionloss[worldid, dof_in_world] = joint_friction[tid]
985+
# Update solimplimit (per joint)
986+
if solimplimit:
987+
jnt_solimp[worldid, mjc_joint_index] = solimplimit[newton_dof_index]
924988

925989

926990
@wp.kernel
@@ -932,16 +996,13 @@ def update_joint_transforms_kernel(
932996
joint_original_axis: wp.array(dtype=wp.vec3),
933997
joint_child: wp.array(dtype=wp.int32),
934998
joint_type: wp.array(dtype=wp.int32),
935-
joint_limit_ke: wp.array(dtype=float),
936-
joint_limit_kd: wp.array(dtype=float),
937-
joint_mjc_dof_start: wp.array(dtype=wp.int32),
999+
dof_to_mjc_joint: wp.array(dtype=wp.int32),
9381000
body_mapping: wp.array(dtype=wp.int32),
9391001
newton_body_to_mocap_index: wp.array(dtype=wp.int32),
9401002
joints_per_world: int,
9411003
# outputs
9421004
joint_pos: wp.array2d(dtype=wp.vec3),
9431005
joint_axis: wp.array2d(dtype=wp.vec3),
944-
joint_solref: wp.array2d(dtype=wp.vec2),
9451006
body_pos: wp.array2d(dtype=wp.vec3),
9461007
body_quat: wp.array2d(dtype=wp.quat),
9471008
mocap_pos: wp.array2d(dtype=wp.vec3),
@@ -958,14 +1019,6 @@ def update_joint_transforms_kernel(
9581019

9591020
child_xform = joint_X_c[tid]
9601021
parent_xform = joint_X_p[tid]
961-
lin_axis_count = joint_dof_dim[tid, 0]
962-
ang_axis_count = joint_dof_dim[tid, 1]
963-
newton_dof_start = joint_dof_start[tid]
964-
mjc_dof_start = joint_mjc_dof_start[joint_in_world]
965-
if mjc_dof_start == -1:
966-
# this should not happen
967-
wp.printf("Joint %i has no MuJoCo DOF start index\n", joint_in_world)
968-
return
9691022

9701023
# update body pos and quat from parent joint transform
9711024
child = joint_child[joint_in_world] # Newton body id
@@ -984,27 +1037,31 @@ def update_joint_transforms_kernel(
9841037
body_pos[worldid, body_id] = tf.p
9851038
body_quat[worldid, body_id] = rotation
9861039

1040+
lin_axis_count = joint_dof_dim[tid, 0]
1041+
ang_axis_count = joint_dof_dim[tid, 1]
1042+
1043+
if lin_axis_count + ang_axis_count == 0:
1044+
return
1045+
1046+
newton_dof_start = joint_dof_start[tid]
1047+
template_dof_start = joint_dof_start[joint_in_world]
1048+
mjc_joint_index = dof_to_mjc_joint[template_dof_start]
1049+
9871050
# update linear dofs
9881051
for i in range(lin_axis_count):
9891052
newton_dof_index = newton_dof_start + i
9901053
axis = joint_original_axis[newton_dof_index]
991-
ai = mjc_dof_start + i
1054+
ai = mjc_joint_index + i
9921055
joint_axis[worldid, ai] = wp.quat_rotate(child_xform.q, axis)
9931056
joint_pos[worldid, ai] = child_xform.p
994-
# update joint limit solref using negative convention
995-
if joint_limit_ke[newton_dof_index] > 0:
996-
joint_solref[worldid, ai] = wp.vec2(-joint_limit_ke[newton_dof_index], -joint_limit_kd[newton_dof_index])
9971057

9981058
# update angular dofs
9991059
for i in range(ang_axis_count):
10001060
newton_dof_index = newton_dof_start + lin_axis_count + i
10011061
axis = joint_original_axis[newton_dof_index]
1002-
ai = mjc_dof_start + lin_axis_count + i
1062+
ai = mjc_joint_index + lin_axis_count + i
10031063
joint_axis[worldid, ai] = wp.quat_rotate(child_xform.q, axis)
10041064
joint_pos[worldid, ai] = child_xform.p
1005-
# update joint limit solref using negative convention
1006-
if joint_limit_ke[newton_dof_index] > 0:
1007-
joint_solref[worldid, ai] = wp.vec2(-joint_limit_ke[newton_dof_index], -joint_limit_kd[newton_dof_index])
10081065

10091066

10101067
@wp.kernel(enable_backward=False)

0 commit comments

Comments
 (0)