@@ -903,24 +903,88 @@ def update_axis_properties_kernel(
903903
904904
905905@wp .kernel
906- def update_dof_properties_kernel (
906+ def update_joint_dof_properties_kernel (
907+ joint_qd_start : wp .array (dtype = wp .int32 ),
908+ joint_dof_dim : wp .array2d (dtype = wp .int32 ),
909+ joint_mjc_dof_start : wp .array (dtype = wp .int32 ),
910+ dof_to_mjc_joint : wp .array (dtype = wp .int32 ),
907911 joint_armature : wp .array (dtype = float ),
908912 joint_friction : wp .array (dtype = float ),
909- dofs_per_world : int ,
913+ joint_limit_ke : wp .array (dtype = float ),
914+ joint_limit_kd : wp .array (dtype = float ),
915+ solimplimit : wp .array (dtype = vec5 ),
916+ joints_per_world : int ,
910917 # outputs
911918 dof_armature : wp .array2d (dtype = float ),
912919 dof_frictionloss : wp .array2d (dtype = float ),
920+ jnt_solimp : wp .array2d (dtype = vec5 ),
921+ jnt_solref : wp .array2d (dtype = wp .vec2 ),
913922):
914- """Update DOF armature and friction loss values."""
923+ """Update joint DOF properties including armature, friction loss, joint impedance limits, and solref.
924+
925+ This kernel properly maps Newton DOFs to MuJoCo DOFs using joint_mjc_dof_start.
926+ For solimplimit and solref, we use dof_to_mjc_joint since jnt_solimp/jnt_solref are per-joint in MuJoCo.
927+ If solimplimit is None, jnt_solimp won't be updated (MuJoCo defaults will be preserved).
928+ """
915929 tid = wp .tid ()
916- worldid = tid // dofs_per_world
917- dof_in_world = tid % dofs_per_world
930+ worldid = tid // joints_per_world
931+ joint_in_world = tid % joints_per_world
932+
933+ lin_axis_count = joint_dof_dim [tid , 0 ]
934+ ang_axis_count = joint_dof_dim [tid , 1 ]
935+
936+ if lin_axis_count + ang_axis_count == 0 :
937+ return
938+
939+ newton_dof_start = joint_qd_start [tid ]
940+ mjc_dof_start = joint_mjc_dof_start [joint_in_world ]
941+
942+ # Get the DOF start for the template joint (world 0)
943+ # dof_to_mjc_joint is only populated for template DOFs (first world)
944+ template_joint_idx = joint_in_world
945+ template_dof_start = joint_qd_start [template_joint_idx ]
946+
947+ # update linear dofs
948+ for i in range (lin_axis_count ):
949+ newton_dof_index = newton_dof_start + i
950+ template_dof_index = template_dof_start + i
951+ mjc_dof_index = mjc_dof_start + i
952+ mjc_joint_index = dof_to_mjc_joint [template_dof_index ]
953+
954+ # Update armature and friction (per DOF)
955+ dof_armature [worldid , mjc_dof_index ] = joint_armature [newton_dof_index ]
956+ dof_frictionloss [worldid , mjc_dof_index ] = joint_friction [newton_dof_index ]
957+
958+ # Update joint limit solref using negative convention (per joint)
959+ if joint_limit_ke [newton_dof_index ] > 0.0 :
960+ jnt_solref [worldid , mjc_joint_index ] = wp .vec2 (
961+ - joint_limit_ke [newton_dof_index ], - joint_limit_kd [newton_dof_index ]
962+ )
963+
964+ # Update solimplimit (per joint)
965+ if solimplimit :
966+ jnt_solimp [worldid , mjc_joint_index ] = solimplimit [newton_dof_index ]
918967
919- # update armature
920- dof_armature [worldid , dof_in_world ] = joint_armature [tid ]
968+ # update angular dofs
969+ for i in range (ang_axis_count ):
970+ newton_dof_index = newton_dof_start + lin_axis_count + i
971+ template_dof_index = template_dof_start + lin_axis_count + i
972+ mjc_dof_index = mjc_dof_start + lin_axis_count + i
973+ mjc_joint_index = dof_to_mjc_joint [template_dof_index ]
974+
975+ # Update armature and friction (per DOF)
976+ dof_armature [worldid , mjc_dof_index ] = joint_armature [newton_dof_index ]
977+ dof_frictionloss [worldid , mjc_dof_index ] = joint_friction [newton_dof_index ]
978+
979+ # Update joint limit solref using negative convention (per joint)
980+ if joint_limit_ke [newton_dof_index ] > 0.0 :
981+ jnt_solref [worldid , mjc_joint_index ] = wp .vec2 (
982+ - joint_limit_ke [newton_dof_index ], - joint_limit_kd [newton_dof_index ]
983+ )
921984
922- # update friction loss
923- dof_frictionloss [worldid , dof_in_world ] = joint_friction [tid ]
985+ # Update solimplimit (per joint)
986+ if solimplimit :
987+ jnt_solimp [worldid , mjc_joint_index ] = solimplimit [newton_dof_index ]
924988
925989
926990@wp .kernel
@@ -932,16 +996,13 @@ def update_joint_transforms_kernel(
932996 joint_original_axis : wp .array (dtype = wp .vec3 ),
933997 joint_child : wp .array (dtype = wp .int32 ),
934998 joint_type : wp .array (dtype = wp .int32 ),
935- joint_limit_ke : wp .array (dtype = float ),
936- joint_limit_kd : wp .array (dtype = float ),
937- joint_mjc_dof_start : wp .array (dtype = wp .int32 ),
999+ dof_to_mjc_joint : wp .array (dtype = wp .int32 ),
9381000 body_mapping : wp .array (dtype = wp .int32 ),
9391001 newton_body_to_mocap_index : wp .array (dtype = wp .int32 ),
9401002 joints_per_world : int ,
9411003 # outputs
9421004 joint_pos : wp .array2d (dtype = wp .vec3 ),
9431005 joint_axis : wp .array2d (dtype = wp .vec3 ),
944- joint_solref : wp .array2d (dtype = wp .vec2 ),
9451006 body_pos : wp .array2d (dtype = wp .vec3 ),
9461007 body_quat : wp .array2d (dtype = wp .quat ),
9471008 mocap_pos : wp .array2d (dtype = wp .vec3 ),
@@ -958,14 +1019,6 @@ def update_joint_transforms_kernel(
9581019
9591020 child_xform = joint_X_c [tid ]
9601021 parent_xform = joint_X_p [tid ]
961- lin_axis_count = joint_dof_dim [tid , 0 ]
962- ang_axis_count = joint_dof_dim [tid , 1 ]
963- newton_dof_start = joint_dof_start [tid ]
964- mjc_dof_start = joint_mjc_dof_start [joint_in_world ]
965- if mjc_dof_start == - 1 :
966- # this should not happen
967- wp .printf ("Joint %i has no MuJoCo DOF start index\n " , joint_in_world )
968- return
9691022
9701023 # update body pos and quat from parent joint transform
9711024 child = joint_child [joint_in_world ] # Newton body id
@@ -984,27 +1037,31 @@ def update_joint_transforms_kernel(
9841037 body_pos [worldid , body_id ] = tf .p
9851038 body_quat [worldid , body_id ] = rotation
9861039
1040+ lin_axis_count = joint_dof_dim [tid , 0 ]
1041+ ang_axis_count = joint_dof_dim [tid , 1 ]
1042+
1043+ if lin_axis_count + ang_axis_count == 0 :
1044+ return
1045+
1046+ newton_dof_start = joint_dof_start [tid ]
1047+ template_dof_start = joint_dof_start [joint_in_world ]
1048+ mjc_joint_index = dof_to_mjc_joint [template_dof_start ]
1049+
9871050 # update linear dofs
9881051 for i in range (lin_axis_count ):
9891052 newton_dof_index = newton_dof_start + i
9901053 axis = joint_original_axis [newton_dof_index ]
991- ai = mjc_dof_start + i
1054+ ai = mjc_joint_index + i
9921055 joint_axis [worldid , ai ] = wp .quat_rotate (child_xform .q , axis )
9931056 joint_pos [worldid , ai ] = child_xform .p
994- # update joint limit solref using negative convention
995- if joint_limit_ke [newton_dof_index ] > 0 :
996- joint_solref [worldid , ai ] = wp .vec2 (- joint_limit_ke [newton_dof_index ], - joint_limit_kd [newton_dof_index ])
9971057
9981058 # update angular dofs
9991059 for i in range (ang_axis_count ):
10001060 newton_dof_index = newton_dof_start + lin_axis_count + i
10011061 axis = joint_original_axis [newton_dof_index ]
1002- ai = mjc_dof_start + lin_axis_count + i
1062+ ai = mjc_joint_index + lin_axis_count + i
10031063 joint_axis [worldid , ai ] = wp .quat_rotate (child_xform .q , axis )
10041064 joint_pos [worldid , ai ] = child_xform .p
1005- # update joint limit solref using negative convention
1006- if joint_limit_ke [newton_dof_index ] > 0 :
1007- joint_solref [worldid , ai ] = wp .vec2 (- joint_limit_ke [newton_dof_index ], - joint_limit_kd [newton_dof_index ])
10081065
10091066
10101067@wp .kernel (enable_backward = False )
0 commit comments