@@ -89,40 +89,24 @@ def format_input(self, qpos_cur, qpos_goal, qpos_start, envs_idx):
8989 return qpos_cur , qpos_goal , qpos_start , envs_idx
9090
9191 def get_exclude_geom_pairs (self , qpos_goal , qpos_start , envs_idx ):
92- if self ._solver .n_envs > 0 :
93- self ._entity .set_qpos (qpos_goal , envs_idx = envs_idx )
94- else :
95- self ._entity .set_qpos (qpos_goal [0 ])
96- self ._solver ._kernel_detect_collision ()
97- scene_contact_info = self ._entity .get_contacts ()
98- geom_a_goal = scene_contact_info ["geom_a" ]
99- geom_b_goal = scene_contact_info ["geom_b" ]
100- if self ._solver .n_envs > 0 :
101- valid_mask = scene_contact_info ["valid_mask" ]
102- geom_a_goal = geom_a_goal [valid_mask ]
103- geom_b_goal = geom_b_goal [valid_mask ]
104-
105- if self ._solver .n_envs > 0 :
106- self ._entity .set_qpos (qpos_start , envs_idx = envs_idx )
107- else :
108- self ._entity .set_qpos (qpos_start [0 ])
109- self ._solver ._kernel_detect_collision ()
110- scene_contact_info = self ._entity .get_contacts ()
111- geom_a_start = scene_contact_info ["geom_a" ]
112- geom_b_start = scene_contact_info ["geom_b" ]
113- if self ._solver .n_envs > 0 :
114- valid_mask = scene_contact_info ["valid_mask" ]
115- geom_a_start = geom_a_start [valid_mask ]
116- geom_b_start = geom_b_start [valid_mask ]
92+ collision_pairs = []
93+ for qpos in [qpos_start , qpos_goal ]:
94+ if self ._solver .n_envs > 0 :
95+ self ._entity .set_qpos (qpos , envs_idx = envs_idx )
96+ else :
97+ self ._entity .set_qpos (qpos [0 ])
98+ self ._solver ._kernel_detect_collision ()
99+ scene_contact_info = self ._entity .get_contacts ()
100+ geom_a = scene_contact_info ["geom_a" ]
101+ geom_b = scene_contact_info ["geom_b" ]
102+ if self ._solver .n_envs > 0 :
103+ valid_mask = scene_contact_info ["valid_mask" ]
104+ geom_a = geom_a [valid_mask ]
105+ geom_b = geom_b [valid_mask ]
106+ collision_pairs .append (torch .stack ((geom_a , geom_b ), dim = 1 ))
117107
118- # NOTE: we will reduce the contacts in batch dim assuming internal geom collisions are the same for a robot
119108 unique_pairs = torch .unique (
120- torch .cat (
121- [
122- torch .stack ((geom_a_start , geom_b_start ), dim = 1 ),
123- torch .stack ((geom_a_goal , geom_b_goal ), dim = 1 ),
124- ]
125- ),
109+ torch .cat (collision_pairs , dim = 0 ),
126110 dim = 0 ,
127111 ) # N', 2
128112 return unique_pairs
0 commit comments