Skip to content

Commit a3da9a5

Browse files
committed
Fix remove_components crash #149
1 parent 487f92a commit a3da9a5

File tree

2 files changed

+34
-12
lines changed

2 files changed

+34
-12
lines changed

src/python/pose_format/pose.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,11 @@ def remove_components(self, components_to_remove: Union[str, List[str]], points_
215215
for component in self.header.components:
216216
if component.name not in components_to_remove:
217217
components_to_keep.append(component.name)
218-
points_dict[component.name] = []
219-
if points_to_remove is not None:
220-
for point in component.points:
221-
if point not in points_to_remove[component.name]:
222-
points_dict[component.name].append(point)
218+
if points_to_remove:
219+
points_to_remove_list = points_to_remove.get(component.name, [])
220+
points_dict[component.name] = [point for point in component.points if point not in points_to_remove_list]
221+
else:
222+
points_dict[component.name] = component.points[:]
223223

224224
return self.get_components(components_to_keep, points_dict)
225225

src/python/pose_format/utils/generic_test.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
from typing import List, get_args
23
import numpy as np
34
import pytest
@@ -154,7 +155,34 @@ def test_correct_wrists(fake_poses: List[Pose]):
154155
assert corrected_pose != pose
155156
assert np.array_equal(corrected_pose.body.data, pose.body.data) is False
156157

157-
158+
@pytest.mark.parametrize("fake_poses", ["holistic"], indirect=["fake_poses"])
159+
def test_remove_one_point_and_one_component(fake_poses: List[Pose]):
160+
component_to_drop = "POSE_WORLD_LANDMARKS"
161+
point_to_drop = "LEFT_KNEE"
162+
for pose in fake_poses:
163+
original_component_names = []
164+
original_points_dict = defaultdict(list)
165+
for component in pose.header.components:
166+
original_component_names.append(component.name)
167+
168+
for point in component.points:
169+
original_points_dict[component.name].append(point)
170+
171+
assert component_to_drop in original_component_names
172+
assert point_to_drop in original_points_dict["POSE_LANDMARKS"]
173+
reduced_pose = pose.remove_components(component_to_drop, {"POSE_LANDMARKS": [point_to_drop]})
174+
new_component_names, new_points_dict = [], defaultdict(list)
175+
new_component_names = []
176+
new_points_dict = defaultdict(list)
177+
for component in reduced_pose.header.components:
178+
new_component_names.append(component.name)
179+
180+
for point in component.points:
181+
new_points_dict[component.name].append(point)
182+
183+
184+
assert component_to_drop not in new_component_names
185+
assert point_to_drop not in new_points_dict["POSE_LANDMARKS"]
158186

159187

160188
@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"])
@@ -205,9 +233,3 @@ def test_fake_pose(known_pose_format: KnownPoseFormat):
205233
assert pose.header.num_dims() == pose.body.data.shape[-1]
206234

207235
poses = [fake_pose(25) for _ in range(5)]
208-
209-
210-
211-
212-
213-

0 commit comments

Comments
 (0)