Skip to content

Commit e07ca68

Browse files
authored
Bugfix/remove components crash (and tests for pose.remove_components()) (#150)
* CDL: minor doc typo fix * Undoing some changes that got mixed in * Fix remove_components crash #149 * Add test cases for pose.remove_components, update random pose object so header matches body * Another quick test case
1 parent e22d323 commit e07ca68

File tree

3 files changed

+157
-20
lines changed

3 files changed

+157
-20
lines changed

src/python/pose_format/pose.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,11 @@ def remove_components(self, components_to_remove: Union[str, List[str]], points_
215215
for component in self.header.components:
216216
if component.name not in components_to_remove:
217217
components_to_keep.append(component.name)
218-
points_dict[component.name] = []
219-
if points_to_remove is not None:
220-
for point in component.points:
221-
if point not in points_to_remove[component.name]:
222-
points_dict[component.name].append(point)
218+
if points_to_remove:
219+
points_to_remove_list = points_to_remove.get(component.name, [])
220+
points_dict[component.name] = [point for point in component.points if point not in points_to_remove_list]
221+
else:
222+
points_dict[component.name] = component.points[:]
223223

224224
return self.get_components(components_to_keep, points_dict)
225225

src/python/pose_format/utils/generic_test.py

+29-7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
from typing import List, get_args
23
import numpy as np
34
import pytest
@@ -154,7 +155,34 @@ def test_correct_wrists(fake_poses: List[Pose]):
154155
assert corrected_pose != pose
155156
assert np.array_equal(corrected_pose.body.data, pose.body.data) is False
156157

157-
158+
@pytest.mark.parametrize("fake_poses", ["holistic"], indirect=["fake_poses"])
159+
def test_remove_one_point_and_one_component(fake_poses: List[Pose]):
160+
component_to_drop = "POSE_WORLD_LANDMARKS"
161+
point_to_drop = "LEFT_KNEE"
162+
for pose in fake_poses:
163+
original_component_names = []
164+
original_points_dict = defaultdict(list)
165+
for component in pose.header.components:
166+
original_component_names.append(component.name)
167+
168+
for point in component.points:
169+
original_points_dict[component.name].append(point)
170+
171+
assert component_to_drop in original_component_names
172+
assert point_to_drop in original_points_dict["POSE_LANDMARKS"]
173+
reduced_pose = pose.remove_components(component_to_drop, {"POSE_LANDMARKS": [point_to_drop]})
174+
new_component_names, new_points_dict = [], defaultdict(list)
175+
new_component_names = []
176+
new_points_dict = defaultdict(list)
177+
for component in reduced_pose.header.components:
178+
new_component_names.append(component.name)
179+
180+
for point in component.points:
181+
new_points_dict[component.name].append(point)
182+
183+
184+
assert component_to_drop not in new_component_names
185+
assert point_to_drop not in new_points_dict["POSE_LANDMARKS"]
158186

159187

160188
@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"])
@@ -205,9 +233,3 @@ def test_fake_pose(known_pose_format: KnownPoseFormat):
205233
assert pose.header.num_dims() == pose.body.data.shape[-1]
206234

207235
poses = [fake_pose(25) for _ in range(5)]
208-
209-
210-
211-
212-
213-

src/python/tests/pose_test.py

+123-8
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ def _create_pose_header_component(name: str, num_keypoints: int) -> PoseHeaderCo
5555
return component
5656

5757

58+
def _distribute_points_among_components(component_count: int, total_keypoint_count: int):
59+
if component_count <= 0 or total_keypoint_count < component_count + 1:
60+
raise ValueError("Total keypoints must be at least component count+1 (so that 0 can have two), and component count must be positive")
61+
62+
# Step 1: Initialize with required minimum values
63+
keypoint_counts = [2] + [1] * (component_count - 1) # Ensure first is 2, others at least 1
64+
65+
# Step 2: Distribute remaining points
66+
remaining_points = total_keypoint_count - sum(keypoint_counts)
67+
for _ in range(remaining_points):
68+
keypoint_counts[random.randint(0, component_count - 1)] += 1 # Add randomly
69+
70+
return keypoint_counts
71+
5872
def _create_pose_header(width: int, height: int, depth: int, num_components: int, num_keypoints: int) -> PoseHeader:
5973
"""
6074
Create a PoseHeader with given dimensions and components.
@@ -79,8 +93,10 @@ def _create_pose_header(width: int, height: int, depth: int, num_components: int
7993
"""
8094
dimensions = PoseHeaderDimensions(width=width, height=height, depth=depth)
8195

96+
keypoints_per_component = _distribute_points_among_components(num_components, num_keypoints)
97+
8298
components = [
83-
_create_pose_header_component(name=str(index), num_keypoints=num_keypoints) for index in range(num_components)
99+
_create_pose_header_component(name=str(index), num_keypoints=keypoints_per_component[index]) for index in range(num_components)
84100
]
85101

86102
header = PoseHeader(version=1.0, dimensions=dimensions, components=components)
@@ -134,6 +150,8 @@ def _create_random_tensorflow_data(frames_min: Optional[int] = None,
134150
return tensor, mask, confidence
135151

136152

153+
154+
137155
def _create_random_numpy_data(frames_min: Optional[int] = None,
138156
frames_max: Optional[int] = None,
139157
num_frames: Optional[int] = None,
@@ -286,7 +304,7 @@ def _get_random_pose_object_with_tf_posebody(num_keypoints: int, frames_min: int
286304
return Pose(header=header, body=body)
287305

288306

289-
def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: int = 1, frames_max: int = 10) -> Pose:
307+
def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: int = 1, frames_max: int = 10, num_components=3) -> Pose:
290308
"""
291309
Creates a random Pose object with Numpy pose body for testing.
292310
@@ -313,7 +331,7 @@ def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min:
313331

314332
body = NumPyPoseBody(fps=10, data=masked_array, confidence=confidence)
315333

316-
header = _create_pose_header(width=10, height=7, depth=0, num_components=3, num_keypoints=num_keypoints)
334+
header = _create_pose_header(width=10, height=7, depth=0, num_components=num_components, num_keypoints=num_keypoints)
317335

318336
return Pose(header=header, body=body)
319337

@@ -329,6 +347,96 @@ def test_pose_object_should_be_callable(self):
329347
"""
330348
assert callable(Pose)
331349

350+
def test_pose_remove_components(self):
351+
pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5)
352+
assert pose.body.data.shape[-2] == 5
353+
assert pose.body.data.shape[-1] == 2 # XY dimensions
354+
355+
self.assertEqual(len(pose.header.components), 3)
356+
self.assertEqual(sum(len(c.points) for c in pose.header.components), 5)
357+
self.assertEqual(pose.header.components[0].name, "0")
358+
self.assertEqual(pose.header.components[1].name, "1")
359+
self.assertEqual(pose.header.components[0].points[0], "0_a")
360+
self.assertIn("1_a", pose.header.components[1].points)
361+
self.assertNotIn("1_f", pose.header.components[1].points)
362+
self.assertNotIn("4", pose.header.components)
363+
364+
# test that we can remove a component
365+
component_to_remove = "0"
366+
pose_copy = pose.copy()
367+
self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components])
368+
pose_copy = pose_copy.remove_components(component_to_remove)
369+
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])
370+
371+
372+
# Remove a point only
373+
point_to_remove = "0_a"
374+
pose_copy = pose.copy()
375+
self.assertIn(point_to_remove, pose_copy.header.components[0].points)
376+
pose_copy = pose_copy.remove_components([], {point_to_remove[0]:[point_to_remove]})
377+
self.assertNotIn(point_to_remove, pose_copy.header.components[0].points)
378+
379+
380+
# Can we remove two things at once
381+
component_to_remove = "1"
382+
point_to_remove = "2_a"
383+
component_to_remove_point_from = "2"
384+
385+
self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components])
386+
self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components])
387+
self.assertIn(point_to_remove, pose_copy.header.components[2].points)
388+
pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove_point_from:[point_to_remove]})
389+
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])
390+
self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components]) # this should still be around
391+
392+
# can we remove a component and a point FROM that component without crashing
393+
component_to_remove = "0"
394+
point_to_remove = "0_a"
395+
pose_copy = pose.copy()
396+
self.assertIn(point_to_remove, pose_copy.header.components[0].points)
397+
pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove:[point_to_remove]})
398+
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])
399+
self.assertNotIn(point_to_remove, pose_copy.header.components[0].points)
400+
401+
402+
# can we "remove" a component that doesn't exist without crashing
403+
component_to_remove = "NOT EXISTING"
404+
pose_copy = pose.copy()
405+
initial_count = len(pose_copy.header.components)
406+
pose_copy = pose_copy.remove_components([component_to_remove])
407+
self.assertEqual(initial_count, len(pose_copy.header.components))
408+
409+
410+
411+
412+
# can we "remove" a point that doesn't exist from a component that does without crashing
413+
point_to_remove = "2_x"
414+
component_to_remove_point_from = "2"
415+
pose_copy = pose.copy()
416+
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
417+
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]})
418+
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
419+
420+
421+
# can we "remove" an empty list of points
422+
component_to_remove_point_from = "2"
423+
pose_copy = pose.copy()
424+
initial_component_count = len(pose_copy.header.components)
425+
initial_point_count = len(pose_copy.header.components[2].points)
426+
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[]})
427+
self.assertEqual(initial_component_count, len(pose_copy.header.components))
428+
self.assertEqual(len(pose_copy.header.components[2].points), initial_point_count)
429+
430+
431+
# can we remove a point from a component that doesn't exist
432+
point_to_remove = "2_x"
433+
component_to_remove_point_from = "NOT EXISTING"
434+
pose_copy = pose.copy()
435+
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
436+
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]})
437+
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
438+
439+
332440

333441

334442
class TestPoseTensorflowPoseBody(TestCase):
@@ -475,7 +583,7 @@ def create_pose_and_frame_dropout_uniform(example: tf.Tensor) -> tf.Tensor:
475583
return example
476584

477585
dataset.map(create_pose_and_frame_dropout_uniform)
478-
586+
479587

480588
def test_pose_tf_posebody_copy_creates_deepcopy(self):
481589
pose = _get_random_pose_object_with_tf_posebody(num_keypoints=5)
@@ -488,7 +596,7 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self):
488596

489597
# Check that pose and pose_copy are not the same object
490598
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
491-
599+
492600
# Ensure the data tensors are equal but independent
493601
self.assertTrue(tf.reduce_all(pose.body.data == pose_copy.body.data), "Copy's data should match original")
494602

@@ -515,6 +623,14 @@ class TestPoseNumpyPoseBody(TestCase):
515623
Testcases for Pose objects containing NumPy PoseBody data.
516624
"""
517625

626+
def test_pose_numpy_generated_with_correct_shape(self):
627+
pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5, frames_min=3)
628+
629+
# does the header match the body?
630+
expected_keypoints_count_from_header = sum(len(c.points) for c in pose.header.components)
631+
self.assertEqual(expected_keypoints_count_from_header, pose.body.data.shape[-2])
632+
633+
518634
def test_pose_numpy_posebody_normalize_preserves_shape(self):
519635
"""
520636
Tests if the normalization of Pose object with NumPy PoseBody preserves array shape.
@@ -593,17 +709,16 @@ def test_pose_torch_posebody_copy_creates_deepcopy(self):
593709
pose = _get_random_pose_object_with_torch_posebody(num_keypoints=5)
594710
self.assertIsInstance(pose.body, TorchPoseBody)
595711
self.assertIsInstance(pose.body.data, TorchMaskedTensor)
596-
597712

598713
pose_copy = pose.copy()
599714
self.assertIsInstance(pose_copy.body, TorchPoseBody)
600715
self.assertIsInstance(pose_copy.body.data, TorchMaskedTensor)
601716

602-
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
717+
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
603718
self.assertTrue(pose.body.data.tensor.equal(pose_copy.body.data.tensor), "Copy's data should match original")
604719
self.assertTrue(pose.body.data.mask.equal(pose_copy.body.data.mask), "Copy's mask should match original")
605720

606-
pose.body.data = TorchMaskedTensor(tensor=torch.zeros_like(pose.body.data.tensor),
721+
pose.body.data = TorchMaskedTensor(tensor=torch.zeros_like(pose.body.data.tensor),
607722
mask=torch.ones_like(pose.body.data.mask))
608723

609724

0 commit comments

Comments
 (0)