11import rclpy
2+ import typing
23from rclpy .node import Node
34from sensor_msgs .msg import Image
45from unitree_go .msg import LowState , LowCmd
@@ -42,67 +43,8 @@ def euler_from_quaternion(quat_angle):
4243
4344 return roll_x , pitch_y , yaw_z # in radians
4445
45-
46- class Go2Simulation (Node ):
46+ class Actor :
4747 def __init__ (self ):
48- super ().__init__ ("go2_simulation" )
49- simulator_name = self .declare_parameter ("simulator" , rclpy .Parameter .Type .STRING ).value
50- simulator_name = "pybullet" if simulator_name is None else simulator_name
51-
52- ########################### State publisher
53- self .lowstate_publisher = self .create_publisher (LowState , "/lowstate" , 10 )
54- self .odometry_publisher = self .create_publisher (Odometry , "/odometry/filtered" , 10 )
55- self .tf_broadcaster = TransformBroadcaster (self )
56- self .depth_publisher = self .create_publisher (Image , "/camera/depth" , 10 )
57- self .clock_publisher = self .create_publisher (Clock , "/clock" , 10 )
58-
59- # Timer to publish periodically
60- self .high_level_period = 1.0 / 50 # seconds
61- self .low_level_sub_step = 24
62- self .timer = self .create_timer (self .high_level_period , self .update )
63-
64- ########################## Camera
65- self .camera_period = 1.0 / 10 # seconds
66- self .camera_decimation = int (self .camera_period / self .high_level_period )
67- breakpoint ()
68-
69- ########################## Cmd listener
70- self .create_subscription (LowCmd , "/lowcmd" , self .receive_cmd_cb , 10 )
71- self .last_cmd_msg = LowCmd ()
72-
73- ########################## Simulator
74- self .get_logger ().info ("go2_simulator::loading simulator" )
75- timestep = self .high_level_period / self .low_level_sub_step
76-
77- self .simulator : AbstractSimulatorWrapper = None
78- if simulator_name == "simple" :
79- from go2_simulation .simple_wrapper import SimpleWrapper
80-
81- self .simulator = SimpleWrapper (self , timestep )
82- elif simulator_name == "pybullet" :
83- from go2_simulation .bullet_wrapper import BulletWrapper
84-
85- self .simulator = BulletWrapper (self , timestep )
86- self .bridge = CvBridge ()
87- else :
88- self .get_logger ().error ("Simulation tool not recognized" )
89-
90- self .simulator_name = simulator_name
91- self .get_logger ().info (f"go2_simulator::simulator { simulator_name } loaded" )
92-
93- ########################## Initial state
94- self .q_current = np .zeros (7 + 12 )
95- self .v_current = np .zeros (6 + 12 )
96- self .a_current = np .zeros (6 + 12 )
97- self .f_current = np .zeros (4 )
98-
99- self .i = 0
100- self .sim_time = Time (seconds = 0 , nanoseconds = 0 )
101- self .time_delta = Duration (seconds = 0 , nanoseconds = int (self .high_level_period * 1e9 ))
102-
103- self .init_onnx ()
104-
105- def init_onnx (self ):
10648 onnx_path = "./models/wall.onnx"
10749 onnx_path = "/home/hamlet/Workspace/reinforcement-learning/inference/" + onnx_path
10850 self .onnx_session = rt .InferenceSession (onnx_path )
@@ -134,37 +76,35 @@ def init_onnx(self):
13476 self .step_counter = np .zeros ((1 ,), dtype = np .float32 )
13577
13678 self .actions = np .zeros ((1 , 12 ), dtype = np .float32 )
137-
138-
139- def forward (self , camera : bool = False ):
140- if self .i == 0 :
141- return np .zeros (12 )
142-
143- robot_id = self .simulator .robot
144-
145- if camera :
146- im = self .simulator .get_camera_image ().astype (np .float32 )
79+ self .policy_step = 0
80+
81+ # LowCmd
82+ self .lowcmd = LowCmd ()
83+
84+ def forward (self , lowstate : LowState , im : typing .Optional [Image ] = None ):
85+ if im :
86+ im = np .array (im .data ).reshape (im .height ,im .width )
14787 self .vobs [:] = (im / 255. ) - 0.5
14888 self .update_yaw [:] = 1.0
14989 self .update_depth [:] = 1.0
15090 else :
15191 self .update_yaw [:] = 0.0
15292 self .update_depth [:] = 0.0
15393
154- contact_states = self . low_msg .foot_force > 20
94+ contact_states = lowstate .foot_force > 20
15595
156- quat = self . low_msg .imu_state .quaternion
96+ quat = lowstate .imu_state .quaternion
15797 roll , pitch , yaw = euler_from_quaternion (quat )
15898 imu_obs = np .array ([roll , pitch ])
15999
160- q = np .array ([ms .q for ms in self . low_msg .motor_state ])[:12 ]
100+ q = np .array ([ms .q for ms in lowstate .motor_state ])[:12 ]
161101 q -= self .q0
162102
163103 self .joint_vel [:] = (q - self .joint_pos ) * 50.
164104 self .joint_pos [:] = q
165105
166106 obs_data = [
167- 1 * self . low_msg .imu_state .gyroscope * 0.25 , # 3
107+ 1 * lowstate .imu_state .gyroscope * 0.25 , # 3
168108 1 * imu_obs , # 2
169109 [0.0 ],
170110 1 * self .yaws .squeeze (),
@@ -181,7 +121,8 @@ def forward(self, camera: bool = False):
181121 np .concatenate (obs_data ).reshape (1 , 53 ).astype (np .float32 )
182122 )
183123 self .obs [:] = clip (self .obs )
184- self .step_counter [:] = self .i - 1
124+ self .step_counter [:] = self .policy_step
125+ self .policy_step += 1
185126
186127 # Policy module
187128 inputs = {
@@ -208,6 +149,65 @@ def forward(self, camera: bool = False):
208149 return self .q0 + (np .clip (self .actions .squeeze (), - 4.8 , 4.8 ) * .25 )
209150
210151
152+ class Go2Simulation (Node ):
153+ def __init__ (self ):
154+ super ().__init__ ("go2_simulation" )
155+ simulator_name = self .declare_parameter ("simulator" , rclpy .Parameter .Type .STRING ).value
156+ simulator_name = "pybullet" if simulator_name is None else simulator_name
157+
158+ ########################### State publisher
159+ self .lowstate_publisher = self .create_publisher (LowState , "/lowstate" , 10 )
160+ self .odometry_publisher = self .create_publisher (Odometry , "/odometry/filtered" , 10 )
161+ self .tf_broadcaster = TransformBroadcaster (self )
162+ self .depth_publisher = self .create_publisher (Image , "/camera/depth" , 10 )
163+ self .clock_publisher = self .create_publisher (Clock , "/clock" , 10 )
164+
165+ # Timer to publish periodically
166+ self .high_level_period = 1.0 / 50 # seconds
167+ self .low_level_sub_step = 24
168+ self .timer = self .create_timer (self .high_level_period , self .update )
169+
170+ ########################## Camera
171+ self .camera_period = 1.0 / 10 # seconds
172+ self .camera_decimation = int (self .camera_period / self .high_level_period )
173+
174+ ########################## Cmd listener
175+ self .create_subscription (LowCmd , "/lowcmd" , self .receive_cmd_cb , 10 )
176+ self .last_cmd_msg = LowCmd ()
177+
178+ ########################## Simulator
179+ self .get_logger ().info ("go2_simulator::loading simulator" )
180+ timestep = self .high_level_period / self .low_level_sub_step
181+
182+ self .simulator : AbstractSimulatorWrapper = None
183+ if simulator_name == "simple" :
184+ from go2_simulation .simple_wrapper import SimpleWrapper
185+
186+ self .simulator = SimpleWrapper (self , timestep )
187+ elif simulator_name == "pybullet" :
188+ from go2_simulation .bullet_wrapper import BulletWrapper
189+
190+ self .simulator = BulletWrapper (self , timestep )
191+ self .bridge = CvBridge ()
192+ else :
193+ self .get_logger ().error ("Simulation tool not recognized" )
194+
195+ self .simulator_name = simulator_name
196+ self .get_logger ().info (f"go2_simulator::simulator { simulator_name } loaded" )
197+
198+ ########################## Initial state
199+ self .q_current = np .zeros (7 + 12 )
200+ self .v_current = np .zeros (6 + 12 )
201+ self .a_current = np .zeros (6 + 12 )
202+ self .f_current = np .zeros (4 )
203+
204+ self .i = 0
205+ self .sim_time = Time (seconds = 0 , nanoseconds = 0 )
206+ self .time_delta = Duration (seconds = 0 , nanoseconds = int (self .high_level_period * 1e9 ))
207+
208+ self .actor = Actor ()
209+
210+
211211 def update (self ):
212212 ## Control robot
213213 if False :
@@ -218,10 +218,13 @@ def update(self):
218218 kd_des = np .array ([self .last_cmd_msg .motor_cmd [i ].kd for i in range (12 )])
219219 else :
220220 # Camera update
221- if self .i % self .camera_decimation == 0 :
222- q_des = self .forward (camera = True )
221+ if self .i == 0 :
222+ q_des = np .zeros (12 )
223+ elif self .i % self .camera_decimation == 0 :
224+ im = self .camera_update ()
225+ q_des = self .actor .forward (self .low_msg , im = im )
223226 else :
224- q_des = self .forward (camera = False )
227+ q_des = self .actor . forward (self . low_msg , im = None )
225228
226229 v_des = np .zeros (12 )
227230 tau_des = np .zeros (12 )
@@ -330,12 +333,11 @@ def update(self):
330333 def camera_update (self ):
331334 if self .simulator_name == "pybullet" :
332335 im = self .simulator .get_camera_image ()
333- else :
334- self .get_logger ().warn (f"Camera not implemented for this simulator: { self .simulator_name } " )
335-
336- if im is not None :
337336 img_msg = self .bridge .cv2_to_imgmsg (im , encoding = "mono8" )
338337 self .depth_publisher .publish (img_msg )
338+ return img_msg
339+ else :
340+ self .get_logger ().warn (f"Camera not implemented for this simulator: { self .simulator_name } " )
339341
340342 def receive_cmd_cb (self , msg ):
341343 self .last_cmd_msg = msg
0 commit comments