Skip to content

Commit 1762b1a

Browse files
committed
wip
1 parent ea43d66 commit 1762b1a

1 file changed

Lines changed: 84 additions & 82 deletions

File tree

go2_simulation/simulation_node.py

Lines changed: 84 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import rclpy
2+
import typing
23
from rclpy.node import Node
34
from sensor_msgs.msg import Image
45
from unitree_go.msg import LowState, LowCmd
@@ -42,67 +43,8 @@ def euler_from_quaternion(quat_angle):
4243

4344
return roll_x, pitch_y, yaw_z # in radians
4445

45-
46-
class Go2Simulation(Node):
46+
class Actor:
4747
def __init__(self):
48-
super().__init__("go2_simulation")
49-
simulator_name = self.declare_parameter("simulator", rclpy.Parameter.Type.STRING).value
50-
simulator_name = "pybullet" if simulator_name is None else simulator_name
51-
52-
########################### State publisher
53-
self.lowstate_publisher = self.create_publisher(LowState, "/lowstate", 10)
54-
self.odometry_publisher = self.create_publisher(Odometry, "/odometry/filtered", 10)
55-
self.tf_broadcaster = TransformBroadcaster(self)
56-
self.depth_publisher = self.create_publisher(Image, "/camera/depth", 10)
57-
self.clock_publisher = self.create_publisher(Clock, "/clock", 10)
58-
59-
# Timer to publish periodically
60-
self.high_level_period = 1.0 / 50 # seconds
61-
self.low_level_sub_step = 24
62-
self.timer = self.create_timer(self.high_level_period, self.update)
63-
64-
########################## Camera
65-
self.camera_period = 1.0 / 10 # seconds
66-
self.camera_decimation = int(self.camera_period / self.high_level_period)
67-
breakpoint()
68-
69-
########################## Cmd listener
70-
self.create_subscription(LowCmd, "/lowcmd", self.receive_cmd_cb, 10)
71-
self.last_cmd_msg = LowCmd()
72-
73-
########################## Simulator
74-
self.get_logger().info("go2_simulator::loading simulator")
75-
timestep = self.high_level_period / self.low_level_sub_step
76-
77-
self.simulator: AbstractSimulatorWrapper = None
78-
if simulator_name == "simple":
79-
from go2_simulation.simple_wrapper import SimpleWrapper
80-
81-
self.simulator = SimpleWrapper(self, timestep)
82-
elif simulator_name == "pybullet":
83-
from go2_simulation.bullet_wrapper import BulletWrapper
84-
85-
self.simulator = BulletWrapper(self, timestep)
86-
self.bridge = CvBridge()
87-
else:
88-
self.get_logger().error("Simulation tool not recognized")
89-
90-
self.simulator_name = simulator_name
91-
self.get_logger().info(f"go2_simulator::simulator {simulator_name} loaded")
92-
93-
########################## Initial state
94-
self.q_current = np.zeros(7 + 12)
95-
self.v_current = np.zeros(6 + 12)
96-
self.a_current = np.zeros(6 + 12)
97-
self.f_current = np.zeros(4)
98-
99-
self.i = 0
100-
self.sim_time = Time(seconds=0, nanoseconds=0)
101-
self.time_delta = Duration(seconds=0, nanoseconds=int(self.high_level_period * 1e9))
102-
103-
self.init_onnx()
104-
105-
def init_onnx(self):
10648
onnx_path = "./models/wall.onnx"
10749
onnx_path = "/home/hamlet/Workspace/reinforcement-learning/inference/" + onnx_path
10850
self.onnx_session = rt.InferenceSession(onnx_path)
@@ -134,37 +76,35 @@ def init_onnx(self):
13476
self.step_counter = np.zeros((1,), dtype=np.float32)
13577

13678
self.actions = np.zeros((1, 12), dtype=np.float32)
137-
138-
139-
def forward(self, camera: bool = False):
140-
if self.i == 0:
141-
return np.zeros(12)
142-
143-
robot_id = self.simulator.robot
144-
145-
if camera:
146-
im = self.simulator.get_camera_image().astype(np.float32)
79+
self.policy_step = 0
80+
81+
# LowCmd
82+
self.lowcmd = LowCmd()
83+
84+
def forward(self, lowstate: LowState, im: typing.Optional[Image] = None):
85+
if im:
86+
im = np.array(im.data).reshape(im.height,im.width)
14787
self.vobs[:] = (im / 255.) - 0.5
14888
self.update_yaw[:] = 1.0
14989
self.update_depth[:] = 1.0
15090
else:
15191
self.update_yaw[:] = 0.0
15292
self.update_depth[:] = 0.0
15393

154-
contact_states = self.low_msg.foot_force > 20
94+
contact_states = lowstate.foot_force > 20
15595

156-
quat = self.low_msg.imu_state.quaternion
96+
quat = lowstate.imu_state.quaternion
15797
roll, pitch, yaw = euler_from_quaternion(quat)
15898
imu_obs = np.array([roll, pitch])
15999

160-
q = np.array([ms.q for ms in self.low_msg.motor_state])[:12]
100+
q = np.array([ms.q for ms in lowstate.motor_state])[:12]
161101
q -= self.q0
162102

163103
self.joint_vel[:] = (q - self.joint_pos) * 50.
164104
self.joint_pos[:] = q
165105

166106
obs_data = [
167-
1 * self.low_msg.imu_state.gyroscope * 0.25, # 3
107+
1 * lowstate.imu_state.gyroscope * 0.25, # 3
168108
1 * imu_obs, # 2
169109
[0.0],
170110
1 * self.yaws.squeeze(),
@@ -181,7 +121,8 @@ def forward(self, camera: bool = False):
181121
np.concatenate(obs_data).reshape(1, 53).astype(np.float32)
182122
)
183123
self.obs[:] = clip(self.obs)
184-
self.step_counter[:] = self.i - 1
124+
self.step_counter[:] = self.policy_step
125+
self.policy_step += 1
185126

186127
# Policy module
187128
inputs = {
@@ -208,6 +149,65 @@ def forward(self, camera: bool = False):
208149
return self.q0 + (np.clip(self.actions.squeeze(), -4.8, 4.8) * .25)
209150

210151

152+
class Go2Simulation(Node):
153+
def __init__(self):
154+
super().__init__("go2_simulation")
155+
simulator_name = self.declare_parameter("simulator", rclpy.Parameter.Type.STRING).value
156+
simulator_name = "pybullet" if simulator_name is None else simulator_name
157+
158+
########################### State publisher
159+
self.lowstate_publisher = self.create_publisher(LowState, "/lowstate", 10)
160+
self.odometry_publisher = self.create_publisher(Odometry, "/odometry/filtered", 10)
161+
self.tf_broadcaster = TransformBroadcaster(self)
162+
self.depth_publisher = self.create_publisher(Image, "/camera/depth", 10)
163+
self.clock_publisher = self.create_publisher(Clock, "/clock", 10)
164+
165+
# Timer to publish periodically
166+
self.high_level_period = 1.0 / 50 # seconds
167+
self.low_level_sub_step = 24
168+
self.timer = self.create_timer(self.high_level_period, self.update)
169+
170+
########################## Camera
171+
self.camera_period = 1.0 / 10 # seconds
172+
self.camera_decimation = int(self.camera_period / self.high_level_period)
173+
174+
########################## Cmd listener
175+
self.create_subscription(LowCmd, "/lowcmd", self.receive_cmd_cb, 10)
176+
self.last_cmd_msg = LowCmd()
177+
178+
########################## Simulator
179+
self.get_logger().info("go2_simulator::loading simulator")
180+
timestep = self.high_level_period / self.low_level_sub_step
181+
182+
self.simulator: AbstractSimulatorWrapper = None
183+
if simulator_name == "simple":
184+
from go2_simulation.simple_wrapper import SimpleWrapper
185+
186+
self.simulator = SimpleWrapper(self, timestep)
187+
elif simulator_name == "pybullet":
188+
from go2_simulation.bullet_wrapper import BulletWrapper
189+
190+
self.simulator = BulletWrapper(self, timestep)
191+
self.bridge = CvBridge()
192+
else:
193+
self.get_logger().error("Simulation tool not recognized")
194+
195+
self.simulator_name = simulator_name
196+
self.get_logger().info(f"go2_simulator::simulator {simulator_name} loaded")
197+
198+
########################## Initial state
199+
self.q_current = np.zeros(7 + 12)
200+
self.v_current = np.zeros(6 + 12)
201+
self.a_current = np.zeros(6 + 12)
202+
self.f_current = np.zeros(4)
203+
204+
self.i = 0
205+
self.sim_time = Time(seconds=0, nanoseconds=0)
206+
self.time_delta = Duration(seconds=0, nanoseconds=int(self.high_level_period * 1e9))
207+
208+
self.actor = Actor()
209+
210+
211211
def update(self):
212212
## Control robot
213213
if False:
@@ -218,10 +218,13 @@ def update(self):
218218
kd_des = np.array([self.last_cmd_msg.motor_cmd[i].kd for i in range(12)])
219219
else:
220220
# Camera update
221-
if self.i % self.camera_decimation == 0:
222-
q_des = self.forward(camera=True)
221+
if self.i == 0:
222+
q_des = np.zeros(12)
223+
elif self.i % self.camera_decimation == 0:
224+
im = self.camera_update()
225+
q_des = self.actor.forward(self.low_msg, im=im)
223226
else:
224-
q_des = self.forward(camera=False)
227+
q_des = self.actor.forward(self.low_msg, im=None)
225228

226229
v_des = np.zeros(12)
227230
tau_des = np.zeros(12)
@@ -330,12 +333,11 @@ def update(self):
330333
def camera_update(self):
331334
if self.simulator_name == "pybullet":
332335
im = self.simulator.get_camera_image()
333-
else:
334-
self.get_logger().warn(f"Camera not implemented for this simulator: {self.simulator_name}")
335-
336-
if im is not None:
337336
img_msg = self.bridge.cv2_to_imgmsg(im, encoding="mono8")
338337
self.depth_publisher.publish(img_msg)
338+
return img_msg
339+
else:
340+
self.get_logger().warn(f"Camera not implemented for this simulator: {self.simulator_name}")
339341

340342
def receive_cmd_cb(self, msg):
341343
self.last_cmd_msg = msg

0 commit comments

Comments
 (0)