Skip to content

Commit 6e1d9e2

Browse files
committed
IMPORTANT: IT CLIMBS
1 parent ee47f3f commit 6e1d9e2

2 files changed

Lines changed: 203 additions & 46 deletions

File tree

go2_simulation/bullet_wrapper.py

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,7 @@ def __init__(self, node, timestep):
3737
self.load_obstacles()
3838

3939
def init_pybullet(self, timestep):
40-
cid = pybullet.connect(pybullet.SHARED_MEMORY)
41-
if cid < 0:
42-
pybullet.connect(pybullet.GUI, options="--opengl3")
43-
else:
44-
pybullet.connect(pybullet.GUI)
45-
46-
# Load robot
47-
self.robot = pybullet.loadURDF(GO2_DESCRIPTION_URDF_PATH, [0, 0, 0.4])
48-
print('URDF loaded from:', GO2_DESCRIPTION_URDF_PATH)
49-
self.localInertiaPos = pybullet.getDynamicsInfo(self.robot, -1)[3]
40+
pybullet.connect(pybullet.GUI, options="--opengl3")
5041

5142
# Load ground plane and other obstacles
5243
self.env_ids = [] # Keep track of all obstacles
@@ -56,10 +47,11 @@ def init_pybullet(self, timestep):
5647
self.env_ids.append(self.plane_id)
5748
pybullet.resetBasePositionAndOrientation(self.plane_id, [0, 0, 0], [0, 0, 0, 1])
5849

59-
self.ramp_id = pybullet.loadURDF(
60-
os.path.join(get_package_share_directory("go2_simulation"), "data/assets/obstacles.urdf")
61-
)
62-
self.env_ids.append(self.ramp_id)
50+
# Load robot
51+
GO2_DESCRIPTION_URDF_PATH = '/home/hamlet/Workspace/reinforcement-learning/inference/assets/go2/go2.urdf'
52+
self.robot = pybullet.loadURDF(GO2_DESCRIPTION_URDF_PATH, [0, 0, 0.4])
53+
print('URDF loaded from:', GO2_DESCRIPTION_URDF_PATH)
54+
self.localInertiaPos = pybullet.getDynamicsInfo(self.robot, -1)[3]
6355

6456
# Set time step
6557
pybullet.setTimeStep(timestep)
@@ -97,7 +89,7 @@ def init_pybullet(self, timestep):
9789
self.feet_idx[foot_id] = (i, link_name)
9890

9991
# Set robot initial config on the ground
100-
initial_q = [0.0, 1.00, -2.1, 0.0, 1.00, -2.1, 0, 1.00, -2.1, 0, 1.00, -2.1]
92+
initial_q = [-0.1, 0.8, -1.5, 0.1, 0.8, -1.5, -0.1, 1., -1.5, 0.1, 1., -1.5]
10193
for i, id in enumerate(self.j_idx):
10294
pybullet.resetJointState(self.robot, id, initial_q[i])
10395

@@ -148,7 +140,7 @@ def load_obstacles(self):
148140
pybullet.GEOM_BOX, halfExtents=half_extents, rgbaColor=[1, 0, 0, 1]
149141
)
150142

151-
num_boxes = 8
143+
num_boxes = 2
152144
for i in range(num_boxes):
153145
box_id = pybullet.createMultiBody(
154146
baseMass=0,
@@ -178,6 +170,30 @@ def get_joint_id(self, joint_name):
178170
return i
179171
return None # Joint name not found
180172

173+
174+
def get_feet_contact_states(self):
175+
f_current = np.zeros(4)
176+
for i, foot_name in enumerate(self.foot_link_names):
177+
for collision_id in self.env_ids:
178+
foot_link_id = self.feet_idx[i][0]
179+
180+
# Get contact points between foot and ground
181+
contact_points = pybullet.getContactPoints(
182+
bodyA=self.robot,
183+
bodyB=collision_id,
184+
linkIndexA=foot_link_id
185+
)
186+
187+
# Check if there are any contacts
188+
is_in_contact = len(contact_points) > 0
189+
190+
if is_in_contact:
191+
f_current[i] = 39.4 # roughly 1/4 of the robot mass (0th order approx)
192+
break # No need to check other obstacles for this foot
193+
194+
return f_current
195+
196+
181197
def step(self, tau_cmd):
182198
# Set actuation
183199
pybullet.setJointMotorControlArray(
@@ -212,27 +228,9 @@ def step(self, tau_cmd):
212228
q_current = np.concatenate((np.array(linear_pose), np.array(angular_pose), joint_position))
213229
v_current = np.concatenate((np.array(linear_vel), np.array(angular_vel), joint_velocity))
214230
a_current = ((v_current - self.v_last) / self.dt) if self.v_last is not None else np.zeros(6 + 12)
215-
f_current = np.zeros(4)
216-
217231
self.v_last = v_current
218232

219-
for i, foot_name in enumerate(self.foot_link_names):
220-
for collision_id in self.env_ids:
221-
foot_link_id = self.feet_idx[i][0]
222-
223-
# Get contact points between foot and ground
224-
contact_points = pybullet.getContactPoints(
225-
bodyA=self.robot,
226-
bodyB=collision_id,
227-
linkIndexA=foot_link_id
228-
)
229-
230-
# Check if there are any contacts
231-
is_in_contact = len(contact_points) > 0
232-
233-
if is_in_contact:
234-
f_current[i] = 39.4 # roughly 1/4 of the robot mass (0th order approx)
235-
break # No need to check other obstacles for this foot
233+
f_current = self.get_feet_contact_states()
236234

237235
return q_current, v_current, a_current, f_current
238236

go2_simulation/simulation_node.py

Lines changed: 170 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from unitree_go.msg import LowState, LowCmd
55
from nav_msgs.msg import Odometry
66
import numpy as np
7+
import pybullet as pb
78
from scipy.spatial.transform import Rotation as R
89

910
from tf2_ros import TransformBroadcaster
@@ -14,6 +15,34 @@
1415
from rclpy.time import Time
1516
from rclpy.duration import Duration
1617

18+
import onnxruntime as rt
19+
from collections import deque
20+
21+
def euler_from_quaternion(quat_angle):
22+
"""
23+
NOTE: This was copied from extreme-parkour repo
24+
25+
Convert a quaternion into euler angles (roll, pitch, yaw)
26+
roll is rotation around x in radians (counterclockwise)
27+
pitch is rotation around y in radians (counterclockwise)
28+
yaw is rotation around z in radians (counterclockwise)
29+
"""
30+
x, y, z, w = quat_angle
31+
t0 = +2.0 * (w * x + y * z)
32+
t1 = +1.0 - 2.0 * (x * x + y * y)
33+
roll_x = np.arctan2(t0, t1)
34+
35+
t2 = +2.0 * (w * y - z * x)
36+
t2 = np.clip(t2, -1, 1)
37+
pitch_y = np.arcsin(t2)
38+
39+
t3 = +2.0 * (w * z + x * y)
40+
t4 = +1.0 - 2.0 * (y * y + z * z)
41+
yaw_z = np.arctan2(t3, t4)
42+
43+
return roll_x, pitch_y, yaw_z # in radians
44+
45+
1746
class Go2Simulation(Node):
1847
def __init__(self):
1948
super().__init__("go2_simulation")
@@ -28,13 +57,14 @@ def __init__(self):
2857
self.clock_publisher = self.create_publisher(Clock, "/clock", 10)
2958

3059
# Timer to publish periodically
31-
self.high_level_period = 1.0 / 500 # seconds
32-
self.low_level_sub_step = 4
60+
self.high_level_period = 1.0 / 50 # seconds
61+
self.low_level_sub_step = 24
3362
self.timer = self.create_timer(self.high_level_period, self.update)
3463

3564
########################## Camera
3665
self.camera_period = 1.0 / 10 # seconds
3766
self.camera_decimation = int(self.camera_period / self.high_level_period)
67+
breakpoint()
3868

3969
########################## Cmd listener
4070
self.create_subscription(LowCmd, "/lowcmd", self.receive_cmd_cb, 10)
@@ -70,13 +100,144 @@ def __init__(self):
70100
self.sim_time = Time(seconds=0, nanoseconds=0)
71101
self.time_delta = Duration(seconds=0, nanoseconds=int(self.high_level_period * 1e9))
72102

103+
self.init_onnx()
104+
105+
def init_onnx(self):
106+
onnx_path = "./models/wall.onnx"
107+
onnx_path = "/home/hamlet/Workspace/reinforcement-learning/inference/" + onnx_path
108+
self.onnx_session = rt.InferenceSession(onnx_path)
109+
110+
self.w_T_b = np.eye(4)
111+
self.joint_pos = np.zeros(12)
112+
self.joint_vel = np.zeros(12)
113+
self.joint_pos_policy = np.zeros(12)
114+
self.joint_vel_policy = np.zeros(12)
115+
116+
self.q0 = np.array([-0.1, 0.8, -1.5, 0.1, 0.8, -1.5, -0.1, 1., -1.5, 0.1, 1., -1.5])
117+
self.q_des = self.q0.copy()
118+
119+
# First two elements are 0, third is the forward speed
120+
forward_speed = 0.37
121+
self.vel_cmd = np.array([0., 0., forward_speed])
122+
self.env_class = np.array([1, 0])
123+
124+
self.action_buffer = deque(maxlen=2)
125+
self.depth_buffer = deque(maxlen=2)
126+
127+
self.depth_latent = np.zeros((1, 32), dtype=np.float32)
128+
self.vobs = np.zeros((1, 58, 87), dtype=np.float32)
129+
self.yaws = np.zeros((1, 2), dtype=np.float32)
130+
self.obs = np.zeros((1, 53), dtype=np.float32)
131+
self.obs_history = np.zeros((1, 10, 53), dtype=np.float32)
132+
self.rnn_hidden_in = np.zeros((1, 1, 512), dtype=np.float32)
133+
self.update_depth = np.zeros((1,1), dtype=np.float32)
134+
self.update_yaw = np.ones((1,1), dtype=np.float32)
135+
self.step_counter = np.zeros((1,), dtype=np.float32)
136+
137+
self.actions = np.zeros((1, 12), dtype=np.float32)
138+
139+
140+
def forward(self, camera: bool = False):
141+
if self.i == 0:
142+
return np.zeros(12)
143+
144+
robot_id = self.simulator.robot
145+
146+
if camera:
147+
im = self.simulator.get_camera_image().astype(np.float32)
148+
self.vobs[:] = (im / 255.) - 0.5
149+
self.update_yaw[:] = 1.0
150+
self.update_depth[:] = 1.0
151+
else:
152+
self.update_yaw[:] = 0.0
153+
self.update_depth[:] = 0.0
154+
155+
w_P_b, w_Q_b = pb.getBasePositionAndOrientation(robot_id)
156+
157+
w_P_b = np.array(w_P_b, dtype=np.float32)
158+
w_R_b = np.array(pb.getMatrixFromQuaternion(w_Q_b), dtype=np.float32).reshape(
159+
3, 3
160+
)
161+
162+
self.w_T_b[:3, :3] = w_R_b
163+
self.w_T_b[:3, 3] = w_P_b
164+
165+
_, ang_vel_w = pb.getBaseVelocity(robot_id)
166+
ang_vel_b = w_R_b.T @ np.array(ang_vel_w)
167+
contact_states = self.low_msg.foot_force > 20
168+
169+
roll, pitch, yaw = euler_from_quaternion(w_Q_b)
170+
imu_obs = np.array([roll, pitch])
171+
172+
q = np.array([ms.q for ms in self.low_msg.motor_state])[:12] - self.q0
173+
174+
self.joint_vel[:] = (q - self.joint_pos) * 50.
175+
self.joint_pos[:] = q
176+
177+
obs_data = [
178+
1 * ang_vel_b * 0.25, # 3
179+
1 * imu_obs, # 2
180+
[0.0],
181+
1 * self.yaws.squeeze(),
182+
1 * self.vel_cmd, # 3
183+
1 * self.env_class, # 2
184+
1 * self.joint_pos,
185+
1 * (self.joint_vel * 0.05),
186+
1 * (self.actions.squeeze()),
187+
1 * (contact_states - 0.5)
188+
]
189+
190+
clip = lambda a: np.clip(a, -100.0, 100.0)
191+
self.obs[:] = (
192+
np.concatenate(obs_data).reshape(1, 53).astype(np.float32)
193+
)
194+
self.obs[:] = clip(self.obs)
195+
self.step_counter[:] = self.i - 1
196+
197+
# Policy module
198+
inputs = {
199+
"depth": clip(self.vobs),
200+
"depth_latent_in": self.depth_latent,
201+
"yaw_in": clip(self.yaws),
202+
"obs_proprio": clip(self.obs),
203+
"obs_history_in": clip(self.obs_history),
204+
"update_depth": self.update_depth,
205+
"update_yaw": self.update_yaw,
206+
"hidden_states_in": self.rnn_hidden_in,
207+
"step_counter": self.step_counter
208+
}
209+
210+
nn_actions, depth_latent, yaws, obs_history, _ = self.onnx_session.run(
211+
['actions', 'depth_latent_out', 'yaw_out', 'obs_history_out', 'hidden_states_out'], inputs
212+
)
213+
self.actions[:] = nn_actions.astype(np.float32)
214+
self.depth_latent[:] = depth_latent
215+
self.yaws[:] = yaws
216+
self.obs_history[:] = obs_history
217+
# self.rnn_hidden_in[:] = hidden_states_out
218+
219+
return self.q0 + (np.clip(self.actions.squeeze(), -4.8, 4.8) * .25)
220+
221+
73222
def update(self):
74223
## Control robot
75-
q_des = np.array([self.last_cmd_msg.motor_cmd[i].q for i in range(12)])
76-
v_des = np.array([self.last_cmd_msg.motor_cmd[i].dq for i in range(12)])
77-
tau_des = np.array([self.last_cmd_msg.motor_cmd[i].tau for i in range(12)])
78-
kp_des = np.array([self.last_cmd_msg.motor_cmd[i].kp for i in range(12)])
79-
kd_des = np.array([self.last_cmd_msg.motor_cmd[i].kd for i in range(12)])
224+
if False:
225+
q_des = np.array([self.last_cmd_msg.motor_cmd[i].q for i in range(12)])
226+
v_des = np.array([self.last_cmd_msg.motor_cmd[i].dq for i in range(12)])
227+
tau_des = np.array([self.last_cmd_msg.motor_cmd[i].tau for i in range(12)])
228+
kp_des = np.array([self.last_cmd_msg.motor_cmd[i].kp for i in range(12)])
229+
kd_des = np.array([self.last_cmd_msg.motor_cmd[i].kd for i in range(12)])
230+
else:
231+
# Camera update
232+
if self.i % self.camera_decimation == 0:
233+
q_des = self.forward(camera=True)
234+
else:
235+
q_des = self.forward(camera=False)
236+
237+
v_des = np.zeros(12)
238+
tau_des = np.zeros(12)
239+
kp_des = 40 * np.ones(12)
240+
kd_des = 1 * np.ones(12)
80241

81242
for _ in range(self.low_level_sub_step):
82243
# Iterate to simulate motor internal controller
@@ -112,6 +273,7 @@ def update(self):
112273
low_msg.foot_force = (14.2 * np.ones(4) + 0.562 * self.f_current).astype(np.int32).tolist()
113274

114275
# Format IMU
276+
# bullet quat
115277
quat_xyzw = self.q_current[3:7].tolist()
116278
l_angular_vel = self.v_current[3:6] # In local frame
117279
l_linear_acc = self.a_current[0:3] # In local frame
@@ -133,6 +295,7 @@ def update(self):
133295

134296
# Publish message
135297
self.lowstate_publisher.publish(low_msg)
298+
self.low_msg = low_msg
136299

137300
## Send robot pose
138301
# Odometry / state estimation
@@ -167,10 +330,6 @@ def update(self):
167330
transform_msg.transform.rotation.w = self.q_current[6]
168331
self.tf_broadcaster.sendTransform(transform_msg)
169332

170-
# Camera update
171-
if self.i % self.camera_decimation == 0:
172-
self.camera_update()
173-
174333
# Check that the simulator is on time
175334
if self.timer.time_until_next_call() < 0 and self.i % self.camera_decimation != 0:
176335
ratio = 1.0 - self.timer.time_until_next_call() * 1e-9 / self.high_level_period

0 commit comments

Comments
 (0)