forked from open-edge-platform/edge-ai-suites
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path0006-add-ros2-node-and-use-fixed-cube-pose.patch
More file actions
106 lines (96 loc) · 4.38 KB
/
0006-add-ros2-node-and-use-fixed-cube-pose.patch
File metadata and controls
106 lines (96 loc) · 4.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
From 1317f50458fbf572b4ad63320a93bca6a2d58149 Mon Sep 17 00:00:00 2001
From: AnnikaWU <xian.wu@intel.com>
Date: Thu, 13 Nov 2025 04:41:52 -0500
Subject: [PATCH] add ros2 node and use fixed cube pose
---
imitate_episodes.py | 45 ++++++++++++++++++++++++++++++++++++++++++---
1 file changed, 42 insertions(+), 3 deletions(-)
diff --git a/imitate_episodes.py b/imitate_episodes.py
index a695204..b32578d 100644
--- a/imitate_episodes.py
+++ b/imitate_episodes.py
@@ -23,6 +23,19 @@ from sim_env import BOX_POSE
import IPython
e = IPython.embed
+# import ros2
+import rclpy
+from rclpy.node import Node
+from std_msgs.msg import Float32MultiArray
+
+# create ros2 node
+class ActAlohaNode(Node):
+ def __init__(self):
+ super().__init__('act_aloha')
+ self.pub = self.create_publisher(Float32MultiArray, 'act_aloha_target_qpos', 10)
+ self.pub_left_gripper = self.create_publisher(Float32MultiArray, '/left_arm/act_target_gripper_qpos', 10)
+ self.pub_right_gripper = self.create_publisher(Float32MultiArray, '/right_arm/act_target_gripper_qpos', 10)
+
def main(args):
set_seed(1)
# command line parameters
@@ -217,7 +230,9 @@ def eval_bc(config, ckpt_name, save_episode=True):
from sim_env import make_sim_env
env = make_sim_env(task_name)
env_max_reward = env.task.max_reward
-
+ # initialize ros node in sim
+ rclpy.init()
+ ros_node = ActAlohaNode()
query_frequency = policy_config['num_queries']
if temporal_agg:
query_frequency = 1
@@ -225,14 +240,15 @@ def eval_bc(config, ckpt_name, save_episode=True):
max_timesteps = int(max_timesteps * 1) # may increase for real-world tasks
- num_rollouts = 10
+ num_rollouts = 1
episode_returns = []
highest_rewards = []
for rollout_id in range(num_rollouts):
rollout_id += 0
### set task
if 'sim_transfer_cube' in task_name:
- BOX_POSE[0] = sample_box_pose() # used in sim reset
+ # BOX_POSE[0] = sample_box_pose() # used in sim reset
+ BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0] # use fixed box pose and it needs to be consistent with the mojoco model file.
elif 'sim_insertion' in task_name:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
@@ -332,10 +348,28 @@ def eval_bc(config, ckpt_name, save_episode=True):
qpos_list.append(qpos_numpy)
target_qpos_list.append(target_qpos)
rewards.append(ts.reward)
+
+ # for ros topic publish
+ ros_time_target_qpos = target_qpos.tolist()
+ time_now = ros_node.get_clock().now().nanoseconds / 1e9
+ ros_time_target_qpos.append(time_now)
+
+ if not real_robot:
+ # ros2 publish messages
+ ros_target_qpos_msg = Float32MultiArray(data=ros_time_target_qpos)
+ ros_node.pub.publish(ros_target_qpos_msg)
+ ros_target_left_gripper_msg = Float32MultiArray(data=[ros_time_target_qpos[6], ros_time_target_qpos[6]])
+ ros_node.pub_left_gripper.publish(ros_target_left_gripper_msg)
+ ros_target_right_gripper_msg = Float32MultiArray(data=[ros_time_target_qpos[13], ros_time_target_qpos[13]])
+ ros_node.pub_right_gripper.publish(ros_target_right_gripper_msg)
+
if print_time:
print(f'screen render:{latencies[0]:.9f}s, process image:{(latencies[1]):.9f}s, model inference:{latencies[2]:.9f}, query policy:{(latencies[3]):.9f}s, post process:{(latencies[4]):.9f}, env:{(latencies[5]):.9f}')
latencies_all.append(latencies)
+ # ros2 node spin
+ rclpy.spin_once(ros_node, timeout_sec=0.001)
+
print(f'Avg fps {max_timesteps / (time.time() - time0)}')
if print_time:
latencies_all = np.array(latencies_all)
@@ -356,6 +390,11 @@ def eval_bc(config, ckpt_name, save_episode=True):
if save_episode:
save_videos(image_list, DT, video_path=os.path.join(ckpt_dir, f'video{rollout_id}.mp4'))
+
+ # shut down ros2 node
+ if not real_robot:
+ ros_node.destroy_node()
+ rclpy.shutdown()
success_rate = np.mean(np.array(highest_rewards) == env_max_reward)
avg_return = np.mean(episode_returns)
--
2.34.1