Skip to content

Commit 0d52c53

Browse files
committed
Add test cases for pose.remove_components, update random pose object so header matches body
1 parent a3da9a5 commit 0d52c53

File tree

1 file changed

+109
-8
lines changed

1 file changed

+109
-8
lines changed

src/python/tests/pose_test.py

+109-8
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ def _create_pose_header_component(name: str, num_keypoints: int) -> PoseHeaderCo
5555
return component
5656

5757

58+
def _distribute_points_among_components(component_count: int, total_keypoint_count: int):
59+
if component_count <= 0 or total_keypoint_count < component_count + 1:
60+
raise ValueError("Total keypoints must be at least component count+1 (so that 0 can have two), and component count must be positive")
61+
62+
# Step 1: Initialize with required minimum values
63+
keypoint_counts = [2] + [1] * (component_count - 1) # Ensure first is 2, others at least 1
64+
65+
# Step 2: Distribute remaining points
66+
remaining_points = total_keypoint_count - sum(keypoint_counts)
67+
for _ in range(remaining_points):
68+
keypoint_counts[random.randint(0, component_count - 1)] += 1 # Add randomly
69+
70+
return keypoint_counts
71+
5872
def _create_pose_header(width: int, height: int, depth: int, num_components: int, num_keypoints: int) -> PoseHeader:
5973
"""
6074
Create a PoseHeader with given dimensions and components.
@@ -79,8 +93,10 @@ def _create_pose_header(width: int, height: int, depth: int, num_components: int
7993
"""
8094
dimensions = PoseHeaderDimensions(width=width, height=height, depth=depth)
8195

96+
keypoints_per_component = _distribute_points_among_components(num_components, num_keypoints)
97+
8298
components = [
83-
_create_pose_header_component(name=str(index), num_keypoints=num_keypoints) for index in range(num_components)
99+
_create_pose_header_component(name=str(index), num_keypoints=keypoints_per_component[index]) for index in range(num_components)
84100
]
85101

86102
header = PoseHeader(version=1.0, dimensions=dimensions, components=components)
@@ -134,6 +150,8 @@ def _create_random_tensorflow_data(frames_min: Optional[int] = None,
134150
return tensor, mask, confidence
135151

136152

153+
154+
137155
def _create_random_numpy_data(frames_min: Optional[int] = None,
138156
frames_max: Optional[int] = None,
139157
num_frames: Optional[int] = None,
@@ -286,7 +304,7 @@ def _get_random_pose_object_with_tf_posebody(num_keypoints: int, frames_min: int
286304
return Pose(header=header, body=body)
287305

288306

289-
def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: int = 1, frames_max: int = 10) -> Pose:
307+
def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: int = 1, frames_max: int = 10, num_components=3) -> Pose:
290308
"""
291309
Creates a random Pose object with Numpy pose body for testing.
292310
@@ -313,7 +331,7 @@ def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min:
313331

314332
body = NumPyPoseBody(fps=10, data=masked_array, confidence=confidence)
315333

316-
header = _create_pose_header(width=10, height=7, depth=0, num_components=3, num_keypoints=num_keypoints)
334+
header = _create_pose_header(width=10, height=7, depth=0, num_components=num_components, num_keypoints=num_keypoints)
317335

318336
return Pose(header=header, body=body)
319337

@@ -329,6 +347,82 @@ def test_pose_object_should_be_callable(self):
329347
"""
330348
assert callable(Pose)
331349

350+
def test_pose_remove_components(self):
351+
pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5)
352+
assert pose.body.data.shape[-2] == 5
353+
assert pose.body.data.shape[-1] == 2 # XY dimensions
354+
355+
self.assertEqual(len(pose.header.components), 3)
356+
self.assertEqual(sum(len(c.points) for c in pose.header.components), 5)
357+
self.assertEqual(pose.header.components[0].name, "0")
358+
self.assertEqual(pose.header.components[1].name, "1")
359+
self.assertEqual(pose.header.components[0].points[0], "0_a")
360+
self.assertIn("1_a", pose.header.components[1].points)
361+
self.assertNotIn("1_f", pose.header.components[1].points)
362+
self.assertNotIn("4", pose.header.components)
363+
364+
# test that we can remove a component
365+
component_to_remove = "0"
366+
pose_copy = pose.copy()
367+
self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components])
368+
pose_copy = pose_copy.remove_components(component_to_remove)
369+
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])
370+
371+
372+
# Remove a point only
373+
point_to_remove = "0_a"
374+
pose_copy = pose.copy()
375+
self.assertIn(point_to_remove, pose_copy.header.components[0].points)
376+
pose_copy = pose_copy.remove_components([], {point_to_remove[0]:[point_to_remove]})
377+
self.assertNotIn(point_to_remove, pose_copy.header.components[0].points)
378+
379+
380+
# Can we remove two things at once
381+
component_to_remove = "1"
382+
point_to_remove = "2_a"
383+
component_to_remove_point_from = "2"
384+
385+
self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components])
386+
self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components])
387+
self.assertIn(point_to_remove, pose_copy.header.components[2].points)
388+
pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove_point_from:[point_to_remove]})
389+
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])
390+
self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components]) # this should still be around
391+
392+
# can we remove a component and a point FROM that component without crashing
393+
component_to_remove = "0"
394+
point_to_remove = "0_a"
395+
pose_copy = pose.copy()
396+
self.assertIn(point_to_remove, pose_copy.header.components[0].points)
397+
pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove:[point_to_remove]})
398+
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])
399+
self.assertNotIn(point_to_remove, pose_copy.header.components[0].points)
400+
401+
402+
# can we "remove" a component that doesn't exist without crashing
403+
component_to_remove = "NOT EXISTING"
404+
pose_copy = pose.copy()
405+
initial_count = len(pose_copy.header.components)
406+
pose_copy = pose_copy.remove_components([component_to_remove])
407+
self.assertEqual(initial_count, len(pose_copy.header.components))
408+
409+
# can we "remove" a point that doesn't exist from a component that does without crashing
410+
point_to_remove = "2_x"
411+
component_to_remove_point_from = "2"
412+
pose_copy = pose.copy()
413+
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
414+
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]})
415+
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
416+
417+
# can we remove a point from a component that doesn't exist
418+
point_to_remove = "2_x"
419+
component_to_remove_point_from = "NOT EXISTING"
420+
pose_copy = pose.copy()
421+
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
422+
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]})
423+
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
424+
425+
332426

333427

334428
class TestPoseTensorflowPoseBody(TestCase):
@@ -475,7 +569,7 @@ def create_pose_and_frame_dropout_uniform(example: tf.Tensor) -> tf.Tensor:
475569
return example
476570

477571
dataset.map(create_pose_and_frame_dropout_uniform)
478-
572+
479573

480574
def test_pose_tf_posebody_copy_creates_deepcopy(self):
481575
pose = _get_random_pose_object_with_tf_posebody(num_keypoints=5)
@@ -488,7 +582,7 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self):
488582

489583
# Check that pose and pose_copy are not the same object
490584
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
491-
585+
492586
# Ensure the data tensors are equal but independent
493587
self.assertTrue(tf.reduce_all(pose.body.data == pose_copy.body.data), "Copy's data should match original")
494588

@@ -515,6 +609,14 @@ class TestPoseNumpyPoseBody(TestCase):
515609
Testcases for Pose objects containing NumPy PoseBody data.
516610
"""
517611

612+
def test_pose_numpy_generated_with_correct_shape(self):
613+
pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5, frames_min=3)
614+
615+
# does the header match the body?
616+
expected_keypoints_count_from_header = sum(len(c.points) for c in pose.header.components)
617+
self.assertEqual(expected_keypoints_count_from_header, pose.body.data.shape[-2])
618+
619+
518620
def test_pose_numpy_posebody_normalize_preserves_shape(self):
519621
"""
520622
Tests if the normalization of Pose object with NumPy PoseBody preserves array shape.
@@ -593,17 +695,16 @@ def test_pose_torch_posebody_copy_creates_deepcopy(self):
593695
pose = _get_random_pose_object_with_torch_posebody(num_keypoints=5)
594696
self.assertIsInstance(pose.body, TorchPoseBody)
595697
self.assertIsInstance(pose.body.data, TorchMaskedTensor)
596-
597698

598699
pose_copy = pose.copy()
599700
self.assertIsInstance(pose_copy.body, TorchPoseBody)
600701
self.assertIsInstance(pose_copy.body.data, TorchMaskedTensor)
601702

602-
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
703+
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
603704
self.assertTrue(pose.body.data.tensor.equal(pose_copy.body.data.tensor), "Copy's data should match original")
604705
self.assertTrue(pose.body.data.mask.equal(pose_copy.body.data.mask), "Copy's mask should match original")
605706

606-
pose.body.data = TorchMaskedTensor(tensor=torch.zeros_like(pose.body.data.tensor),
707+
pose.body.data = TorchMaskedTensor(tensor=torch.zeros_like(pose.body.data.tensor),
607708
mask=torch.ones_like(pose.body.data.mask))
608709

609710

0 commit comments

Comments
 (0)