diff --git a/src/python/pose_format/pose.py b/src/python/pose_format/pose.py index 5031f64..084c5fd 100644 --- a/src/python/pose_format/pose.py +++ b/src/python/pose_format/pose.py @@ -215,11 +215,11 @@ def remove_components(self, components_to_remove: Union[str, List[str]], points_ for component in self.header.components: if component.name not in components_to_remove: components_to_keep.append(component.name) - points_dict[component.name] = [] - if points_to_remove is not None: - for point in component.points: - if point not in points_to_remove[component.name]: - points_dict[component.name].append(point) + if points_to_remove: + points_to_remove_list = points_to_remove.get(component.name, []) + points_dict[component.name] = [point for point in component.points if point not in points_to_remove_list] + else: + points_dict[component.name] = component.points[:] return self.get_components(components_to_keep, points_dict) diff --git a/src/python/pose_format/utils/generic_test.py b/src/python/pose_format/utils/generic_test.py index 4ef8070..461f55e 100644 --- a/src/python/pose_format/utils/generic_test.py +++ b/src/python/pose_format/utils/generic_test.py @@ -1,3 +1,4 @@ +from collections import defaultdict from typing import List, get_args import numpy as np import pytest @@ -154,7 +155,34 @@ def test_correct_wrists(fake_poses: List[Pose]): assert corrected_pose != pose assert np.array_equal(corrected_pose.body.data, pose.body.data) is False - +@pytest.mark.parametrize("fake_poses", ["holistic"], indirect=["fake_poses"]) +def test_remove_one_point_and_one_component(fake_poses: List[Pose]): + component_to_drop = "POSE_WORLD_LANDMARKS" + point_to_drop = "LEFT_KNEE" + for pose in fake_poses: + original_component_names = [] + original_points_dict = defaultdict(list) + for component in pose.header.components: + original_component_names.append(component.name) + + for point in component.points: + original_points_dict[component.name].append(point) + + assert component_to_drop in original_component_names + assert point_to_drop in original_points_dict["POSE_LANDMARKS"] + reduced_pose = pose.remove_components(component_to_drop, {"POSE_LANDMARKS": [point_to_drop]}) + new_component_names, new_points_dict = [], defaultdict(list) + new_component_names = [] + new_points_dict = defaultdict(list) + for component in reduced_pose.header.components: + new_component_names.append(component.name) + + for point in component.points: + new_points_dict[component.name].append(point) + + + assert component_to_drop not in new_component_names + assert point_to_drop not in new_points_dict["POSE_LANDMARKS"] @pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) @@ -205,9 +233,3 @@ def test_fake_pose(known_pose_format: KnownPoseFormat): assert pose.header.num_dims() == pose.body.data.shape[-1] poses = [fake_pose(25) for _ in range(5)] - - - - - - \ No newline at end of file diff --git a/src/python/tests/pose_test.py b/src/python/tests/pose_test.py index 897a0dc..1215b3f 100644 --- a/src/python/tests/pose_test.py +++ b/src/python/tests/pose_test.py @@ -55,6 +55,20 @@ def _create_pose_header_component(name: str, num_keypoints: int) -> PoseHeaderCo return component +def _distribute_points_among_components(component_count: int, total_keypoint_count: int): + if component_count <= 0 or total_keypoint_count < component_count + 1: + raise ValueError("Total keypoints must be at least component count+1 (so that 0 can have two), and component count must be positive") + + # Step 1: Initialize with required minimum values + keypoint_counts = [2] + [1] * (component_count - 1) # Ensure first is 2, others at least 1 + + # Step 2: Distribute remaining points + remaining_points = total_keypoint_count - sum(keypoint_counts) + for _ in range(remaining_points): + keypoint_counts[random.randint(0, component_count - 1)] += 1 # Add randomly + + return keypoint_counts + def _create_pose_header(width: int, height: int, depth: int, num_components: int, num_keypoints: int) -> PoseHeader: """ Create a PoseHeader with given dimensions and components. @@ -79,8 +93,10 @@ def _create_pose_header(width: int, height: int, depth: int, num_components: int """ dimensions = PoseHeaderDimensions(width=width, height=height, depth=depth) + keypoints_per_component = _distribute_points_among_components(num_components, num_keypoints) + components = [ - _create_pose_header_component(name=str(index), num_keypoints=num_keypoints) for index in range(num_components) + _create_pose_header_component(name=str(index), num_keypoints=keypoints_per_component[index]) for index in range(num_components) ] header = PoseHeader(version=1.0, dimensions=dimensions, components=components) @@ -134,6 +150,8 @@ def _create_random_tensorflow_data(frames_min: Optional[int] = None, return tensor, mask, confidence + + def _create_random_numpy_data(frames_min: Optional[int] = None, frames_max: Optional[int] = None, num_frames: Optional[int] = None, @@ -286,7 +304,7 @@ def _get_random_pose_object_with_tf_posebody(num_keypoints: int, frames_min: int return Pose(header=header, body=body) -def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: int = 1, frames_max: int = 10) -> Pose: +def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: int = 1, frames_max: int = 10, num_components=3) -> Pose: """ Creates a random Pose object with Numpy pose body for testing. @@ -313,7 +331,7 @@ def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: body = NumPyPoseBody(fps=10, data=masked_array, confidence=confidence) - header = _create_pose_header(width=10, height=7, depth=0, num_components=3, num_keypoints=num_keypoints) + header = _create_pose_header(width=10, height=7, depth=0, num_components=num_components, num_keypoints=num_keypoints) return Pose(header=header, body=body) @@ -329,6 +347,96 @@ def test_pose_object_should_be_callable(self): """ assert callable(Pose) + def test_pose_remove_components(self): + pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5) + assert pose.body.data.shape[-2] == 5 + assert pose.body.data.shape[-1] == 2 # XY dimensions + + self.assertEqual(len(pose.header.components), 3) + self.assertEqual(sum(len(c.points) for c in pose.header.components), 5) + self.assertEqual(pose.header.components[0].name, "0") + self.assertEqual(pose.header.components[1].name, "1") + self.assertEqual(pose.header.components[0].points[0], "0_a") + self.assertIn("1_a", pose.header.components[1].points) + self.assertNotIn("1_f", pose.header.components[1].points) + self.assertNotIn("4", pose.header.components) + + # test that we can remove a component + component_to_remove = "0" + pose_copy = pose.copy() + self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components]) + pose_copy = pose_copy.remove_components(component_to_remove) + self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components]) + + + # Remove a point only + point_to_remove = "0_a" + pose_copy = pose.copy() + self.assertIn(point_to_remove, pose_copy.header.components[0].points) + pose_copy = pose_copy.remove_components([], {point_to_remove[0]:[point_to_remove]}) + self.assertNotIn(point_to_remove, pose_copy.header.components[0].points) + + + # Can we remove two things at once + component_to_remove = "1" + point_to_remove = "2_a" + component_to_remove_point_from = "2" + + self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components]) + self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components]) + self.assertIn(point_to_remove, pose_copy.header.components[2].points) + pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove_point_from:[point_to_remove]}) + self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components]) + self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components]) # this should still be around + + # can we remove a component and a point FROM that component without crashing + component_to_remove = "0" + point_to_remove = "0_a" + pose_copy = pose.copy() + self.assertIn(point_to_remove, pose_copy.header.components[0].points) + pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove:[point_to_remove]}) + self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components]) + self.assertNotIn(point_to_remove, pose_copy.header.components[0].points) + + + # can we "remove" a component that doesn't exist without crashing + component_to_remove = "NOT EXISTING" + pose_copy = pose.copy() + initial_count = len(pose_copy.header.components) + pose_copy = pose_copy.remove_components([component_to_remove]) + self.assertEqual(initial_count, len(pose_copy.header.components)) + + + + + # can we "remove" a point that doesn't exist from a component that does without crashing + point_to_remove = "2_x" + component_to_remove_point_from = "2" + pose_copy = pose.copy() + self.assertNotIn(point_to_remove, pose_copy.header.components[2].points) + pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]}) + self.assertNotIn(point_to_remove, pose_copy.header.components[2].points) + + + # can we "remove" an empty list of points + component_to_remove_point_from = "2" + pose_copy = pose.copy() + initial_component_count = len(pose_copy.header.components) + initial_point_count = len(pose_copy.header.components[2].points) + pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[]}) + self.assertEqual(initial_component_count, len(pose_copy.header.components)) + self.assertEqual(len(pose_copy.header.components[2].points), initial_point_count) + + + # can we remove a point from a component that doesn't exist + point_to_remove = "2_x" + component_to_remove_point_from = "NOT EXISTING" + pose_copy = pose.copy() + self.assertNotIn(point_to_remove, pose_copy.header.components[2].points) + pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]}) + self.assertNotIn(point_to_remove, pose_copy.header.components[2].points) + + class TestPoseTensorflowPoseBody(TestCase): @@ -475,7 +583,7 @@ def create_pose_and_frame_dropout_uniform(example: tf.Tensor) -> tf.Tensor: return example dataset.map(create_pose_and_frame_dropout_uniform) - + def test_pose_tf_posebody_copy_creates_deepcopy(self): pose = _get_random_pose_object_with_tf_posebody(num_keypoints=5) @@ -488,7 +596,7 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self): # Check that pose and pose_copy are not the same object self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original") - + # Ensure the data tensors are equal but independent self.assertTrue(tf.reduce_all(pose.body.data == pose_copy.body.data), "Copy's data should match original") @@ -515,6 +623,14 @@ class TestPoseNumpyPoseBody(TestCase): Testcases for Pose objects containing NumPy PoseBody data. """ + def test_pose_numpy_generated_with_correct_shape(self): + pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5, frames_min=3) + + # does the header match the body? + expected_keypoints_count_from_header = sum(len(c.points) for c in pose.header.components) + self.assertEqual(expected_keypoints_count_from_header, pose.body.data.shape[-2]) + + def test_pose_numpy_posebody_normalize_preserves_shape(self): """ Tests if the normalization of Pose object with NumPy PoseBody preserves array shape. @@ -593,17 +709,16 @@ def test_pose_torch_posebody_copy_creates_deepcopy(self): pose = _get_random_pose_object_with_torch_posebody(num_keypoints=5) self.assertIsInstance(pose.body, TorchPoseBody) self.assertIsInstance(pose.body.data, TorchMaskedTensor) - pose_copy = pose.copy() self.assertIsInstance(pose_copy.body, TorchPoseBody) self.assertIsInstance(pose_copy.body.data, TorchMaskedTensor) - self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original") + self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original") self.assertTrue(pose.body.data.tensor.equal(pose_copy.body.data.tensor), "Copy's data should match original") self.assertTrue(pose.body.data.mask.equal(pose_copy.body.data.mask), "Copy's mask should match original") - pose.body.data = TorchMaskedTensor(tensor=torch.zeros_like(pose.body.data.tensor), + pose.body.data = TorchMaskedTensor(tensor=torch.zeros_like(pose.body.data.tensor), mask=torch.ones_like(pose.body.data.mask))