1+ from collections import defaultdict
12from typing import List , get_args
23import numpy as np
34import pytest
@@ -154,7 +155,34 @@ def test_correct_wrists(fake_poses: List[Pose]):
154155 assert corrected_pose != pose
155156 assert np .array_equal (corrected_pose .body .data , pose .body .data ) is False
156157
157-
158+ @pytest .mark .parametrize ("fake_poses" , ["holistic" ], indirect = ["fake_poses" ])
159+ def test_remove_one_point_and_one_component (fake_poses : List [Pose ]):
160+ component_to_drop = "POSE_WORLD_LANDMARKS"
161+ point_to_drop = "LEFT_KNEE"
162+ for pose in fake_poses :
163+ original_component_names = []
164+ original_points_dict = defaultdict (list )
165+ for component in pose .header .components :
166+ original_component_names .append (component .name )
167+
168+ for point in component .points :
169+ original_points_dict [component .name ].append (point )
170+
171+ assert component_to_drop in original_component_names
172+ assert point_to_drop in original_points_dict ["POSE_LANDMARKS" ]
173+ reduced_pose = pose .remove_components (component_to_drop , {"POSE_LANDMARKS" : [point_to_drop ]})
174+ new_component_names , new_points_dict = [], defaultdict (list )
175+ new_component_names = []
176+ new_points_dict = defaultdict (list )
177+ for component in reduced_pose .header .components :
178+ new_component_names .append (component .name )
179+
180+ for point in component .points :
181+ new_points_dict [component .name ].append (point )
182+
183+
184+ assert component_to_drop not in new_component_names
185+ assert point_to_drop not in new_points_dict ["POSE_LANDMARKS" ]
158186
159187
160188@pytest .mark .parametrize ("fake_poses" , TEST_POSE_FORMATS , indirect = ["fake_poses" ])
@@ -205,9 +233,3 @@ def test_fake_pose(known_pose_format: KnownPoseFormat):
205233 assert pose .header .num_dims () == pose .body .data .shape [- 1 ]
206234
207235 poses = [fake_pose (25 ) for _ in range (5 )]
208-
209-
210-
211-
212-
213-
0 commit comments