44from unitree_go .msg import LowState , LowCmd
55from nav_msgs .msg import Odometry
66import numpy as np
7+ import pybullet as pb
78from scipy .spatial .transform import Rotation as R
89
910from tf2_ros import TransformBroadcaster
1415from rclpy .time import Time
1516from rclpy .duration import Duration
1617
18+ import onnxruntime as rt
19+ from collections import deque
20+
21+ def euler_from_quaternion (quat_angle ):
22+ """
23+ NOTE: This was copied from extreme-parkour repo
24+
25+ Convert a quaternion into euler angles (roll, pitch, yaw)
26+ roll is rotation around x in radians (counterclockwise)
27+ pitch is rotation around y in radians (counterclockwise)
28+ yaw is rotation around z in radians (counterclockwise)
29+ """
30+ x , y , z , w = quat_angle
31+ t0 = + 2.0 * (w * x + y * z )
32+ t1 = + 1.0 - 2.0 * (x * x + y * y )
33+ roll_x = np .arctan2 (t0 , t1 )
34+
35+ t2 = + 2.0 * (w * y - z * x )
36+ t2 = np .clip (t2 , - 1 , 1 )
37+ pitch_y = np .arcsin (t2 )
38+
39+ t3 = + 2.0 * (w * z + x * y )
40+ t4 = + 1.0 - 2.0 * (y * y + z * z )
41+ yaw_z = np .arctan2 (t3 , t4 )
42+
43+ return roll_x , pitch_y , yaw_z # in radians
44+
45+
1746class Go2Simulation (Node ):
1847 def __init__ (self ):
1948 super ().__init__ ("go2_simulation" )
@@ -28,13 +57,14 @@ def __init__(self):
2857 self .clock_publisher = self .create_publisher (Clock , "/clock" , 10 )
2958
3059 # Timer to publish periodically
31- self .high_level_period = 1.0 / 500 # seconds
32- self .low_level_sub_step = 4
60+ self .high_level_period = 1.0 / 50 # seconds
61+ self .low_level_sub_step = 24
3362 self .timer = self .create_timer (self .high_level_period , self .update )
3463
3564 ########################## Camera
3665 self .camera_period = 1.0 / 10 # seconds
3766 self .camera_decimation = int (self .camera_period / self .high_level_period )
67+ breakpoint ()
3868
3969 ########################## Cmd listener
4070 self .create_subscription (LowCmd , "/lowcmd" , self .receive_cmd_cb , 10 )
@@ -70,13 +100,144 @@ def __init__(self):
70100 self .sim_time = Time (seconds = 0 , nanoseconds = 0 )
71101 self .time_delta = Duration (seconds = 0 , nanoseconds = int (self .high_level_period * 1e9 ))
72102
103+ self .init_onnx ()
104+
105+ def init_onnx (self ):
106+ onnx_path = "./models/wall.onnx"
107+ onnx_path = "/home/hamlet/Workspace/reinforcement-learning/inference/" + onnx_path
108+ self .onnx_session = rt .InferenceSession (onnx_path )
109+
110+ self .w_T_b = np .eye (4 )
111+ self .joint_pos = np .zeros (12 )
112+ self .joint_vel = np .zeros (12 )
113+ self .joint_pos_policy = np .zeros (12 )
114+ self .joint_vel_policy = np .zeros (12 )
115+
116+ self .q0 = np .array ([- 0.1 , 0.8 , - 1.5 , 0.1 , 0.8 , - 1.5 , - 0.1 , 1. , - 1.5 , 0.1 , 1. , - 1.5 ])
117+ self .q_des = self .q0 .copy ()
118+
119+ # First two elements are 0, third is the forward speed
120+ forward_speed = 0.37
121+ self .vel_cmd = np .array ([0. , 0. , forward_speed ])
122+ self .env_class = np .array ([1 , 0 ])
123+
124+ self .action_buffer = deque (maxlen = 2 )
125+ self .depth_buffer = deque (maxlen = 2 )
126+
127+ self .depth_latent = np .zeros ((1 , 32 ), dtype = np .float32 )
128+ self .vobs = np .zeros ((1 , 58 , 87 ), dtype = np .float32 )
129+ self .yaws = np .zeros ((1 , 2 ), dtype = np .float32 )
130+ self .obs = np .zeros ((1 , 53 ), dtype = np .float32 )
131+ self .obs_history = np .zeros ((1 , 10 , 53 ), dtype = np .float32 )
132+ self .rnn_hidden_in = np .zeros ((1 , 1 , 512 ), dtype = np .float32 )
133+ self .update_depth = np .zeros ((1 ,1 ), dtype = np .float32 )
134+ self .update_yaw = np .ones ((1 ,1 ), dtype = np .float32 )
135+ self .step_counter = np .zeros ((1 ,), dtype = np .float32 )
136+
137+ self .actions = np .zeros ((1 , 12 ), dtype = np .float32 )
138+
139+
140+ def forward (self , camera : bool = False ):
141+ if self .i == 0 :
142+ return np .zeros (12 )
143+
144+ robot_id = self .simulator .robot
145+
146+ if camera :
147+ im = self .simulator .get_camera_image ().astype (np .float32 )
148+ self .vobs [:] = (im / 255. ) - 0.5
149+ self .update_yaw [:] = 1.0
150+ self .update_depth [:] = 1.0
151+ else :
152+ self .update_yaw [:] = 0.0
153+ self .update_depth [:] = 0.0
154+
155+ w_P_b , w_Q_b = pb .getBasePositionAndOrientation (robot_id )
156+
157+ w_P_b = np .array (w_P_b , dtype = np .float32 )
158+ w_R_b = np .array (pb .getMatrixFromQuaternion (w_Q_b ), dtype = np .float32 ).reshape (
159+ 3 , 3
160+ )
161+
162+ self .w_T_b [:3 , :3 ] = w_R_b
163+ self .w_T_b [:3 , 3 ] = w_P_b
164+
165+ _ , ang_vel_w = pb .getBaseVelocity (robot_id )
166+ ang_vel_b = w_R_b .T @ np .array (ang_vel_w )
167+ contact_states = self .low_msg .foot_force > 20
168+
169+ roll , pitch , yaw = euler_from_quaternion (w_Q_b )
170+ imu_obs = np .array ([roll , pitch ])
171+
172+ q = np .array ([ms .q for ms in self .low_msg .motor_state ])[:12 ] - self .q0
173+
174+ self .joint_vel [:] = (q - self .joint_pos ) * 50.
175+ self .joint_pos [:] = q
176+
177+ obs_data = [
178+ 1 * ang_vel_b * 0.25 , # 3
179+ 1 * imu_obs , # 2
180+ [0.0 ],
181+ 1 * self .yaws .squeeze (),
182+ 1 * self .vel_cmd , # 3
183+ 1 * self .env_class , # 2
184+ 1 * self .joint_pos ,
185+ 1 * (self .joint_vel * 0.05 ),
186+ 1 * (self .actions .squeeze ()),
187+ 1 * (contact_states - 0.5 )
188+ ]
189+
190+ clip = lambda a : np .clip (a , - 100.0 , 100.0 )
191+ self .obs [:] = (
192+ np .concatenate (obs_data ).reshape (1 , 53 ).astype (np .float32 )
193+ )
194+ self .obs [:] = clip (self .obs )
195+ self .step_counter [:] = self .i - 1
196+
197+ # Policy module
198+ inputs = {
199+ "depth" : clip (self .vobs ),
200+ "depth_latent_in" : self .depth_latent ,
201+ "yaw_in" : clip (self .yaws ),
202+ "obs_proprio" : clip (self .obs ),
203+ "obs_history_in" : clip (self .obs_history ),
204+ "update_depth" : self .update_depth ,
205+ "update_yaw" : self .update_yaw ,
206+ "hidden_states_in" : self .rnn_hidden_in ,
207+ "step_counter" : self .step_counter
208+ }
209+
210+ nn_actions , depth_latent , yaws , obs_history , _ = self .onnx_session .run (
211+ ['actions' , 'depth_latent_out' , 'yaw_out' , 'obs_history_out' , 'hidden_states_out' ], inputs
212+ )
213+ self .actions [:] = nn_actions .astype (np .float32 )
214+ self .depth_latent [:] = depth_latent
215+ self .yaws [:] = yaws
216+ self .obs_history [:] = obs_history
217+ # self.rnn_hidden_in[:] = hidden_states_out
218+
219+ return self .q0 + (np .clip (self .actions .squeeze (), - 4.8 , 4.8 ) * .25 )
220+
221+
73222 def update (self ):
74223 ## Control robot
75- q_des = np .array ([self .last_cmd_msg .motor_cmd [i ].q for i in range (12 )])
76- v_des = np .array ([self .last_cmd_msg .motor_cmd [i ].dq for i in range (12 )])
77- tau_des = np .array ([self .last_cmd_msg .motor_cmd [i ].tau for i in range (12 )])
78- kp_des = np .array ([self .last_cmd_msg .motor_cmd [i ].kp for i in range (12 )])
79- kd_des = np .array ([self .last_cmd_msg .motor_cmd [i ].kd for i in range (12 )])
224+ if False :
225+ q_des = np .array ([self .last_cmd_msg .motor_cmd [i ].q for i in range (12 )])
226+ v_des = np .array ([self .last_cmd_msg .motor_cmd [i ].dq for i in range (12 )])
227+ tau_des = np .array ([self .last_cmd_msg .motor_cmd [i ].tau for i in range (12 )])
228+ kp_des = np .array ([self .last_cmd_msg .motor_cmd [i ].kp for i in range (12 )])
229+ kd_des = np .array ([self .last_cmd_msg .motor_cmd [i ].kd for i in range (12 )])
230+ else :
231+ # Camera update
232+ if self .i % self .camera_decimation == 0 :
233+ q_des = self .forward (camera = True )
234+ else :
235+ q_des = self .forward (camera = False )
236+
237+ v_des = np .zeros (12 )
238+ tau_des = np .zeros (12 )
239+ kp_des = 40 * np .ones (12 )
240+ kd_des = 1 * np .ones (12 )
80241
81242 for _ in range (self .low_level_sub_step ):
82243 # Iterate to simulate motor internal controller
@@ -112,6 +273,7 @@ def update(self):
112273 low_msg .foot_force = (14.2 * np .ones (4 ) + 0.562 * self .f_current ).astype (np .int32 ).tolist ()
113274
114275 # Format IMU
276+ # bullet quat
115277 quat_xyzw = self .q_current [3 :7 ].tolist ()
116278 l_angular_vel = self .v_current [3 :6 ] # In local frame
117279 l_linear_acc = self .a_current [0 :3 ] # In local frame
@@ -133,6 +295,7 @@ def update(self):
133295
134296 # Publish message
135297 self .lowstate_publisher .publish (low_msg )
298+ self .low_msg = low_msg
136299
137300 ## Send robot pose
138301 # Odometry / state estimation
@@ -167,10 +330,6 @@ def update(self):
167330 transform_msg .transform .rotation .w = self .q_current [6 ]
168331 self .tf_broadcaster .sendTransform (transform_msg )
169332
170- # Camera update
171- if self .i % self .camera_decimation == 0 :
172- self .camera_update ()
173-
174333 # Check that the simulator is on time
175334 if self .timer .time_until_next_call () < 0 and self .i % self .camera_decimation != 0 :
176335 ratio = 1.0 - self .timer .time_until_next_call () * 1e-9 / self .high_level_period
0 commit comments