Skip to content

Bugfix/remove components crash (and tests for pose.remove_components()) #150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/python/pose_format/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ def remove_components(self, components_to_remove: Union[str, List[str]], points_
for component in self.header.components:
if component.name not in components_to_remove:
components_to_keep.append(component.name)
points_dict[component.name] = []
if points_to_remove is not None:
for point in component.points:
if point not in points_to_remove[component.name]:
points_dict[component.name].append(point)
if points_to_remove:
points_to_remove_list = points_to_remove.get(component.name, [])
points_dict[component.name] = [point for point in component.points if point not in points_to_remove_list]
else:
points_dict[component.name] = component.points[:]

return self.get_components(components_to_keep, points_dict)

Expand Down
36 changes: 29 additions & 7 deletions src/python/pose_format/utils/generic_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from typing import List, get_args
import numpy as np
import pytest
Expand Down Expand Up @@ -154,7 +155,34 @@ def test_correct_wrists(fake_poses: List[Pose]):
assert corrected_pose != pose
assert np.array_equal(corrected_pose.body.data, pose.body.data) is False


@pytest.mark.parametrize("fake_poses", ["holistic"], indirect=["fake_poses"])
def test_remove_one_point_and_one_component(fake_poses: List[Pose]):
component_to_drop = "POSE_WORLD_LANDMARKS"
point_to_drop = "LEFT_KNEE"
for pose in fake_poses:
original_component_names = []
original_points_dict = defaultdict(list)
for component in pose.header.components:
original_component_names.append(component.name)

for point in component.points:
original_points_dict[component.name].append(point)

assert component_to_drop in original_component_names
assert point_to_drop in original_points_dict["POSE_LANDMARKS"]
reduced_pose = pose.remove_components(component_to_drop, {"POSE_LANDMARKS": [point_to_drop]})
new_component_names, new_points_dict = [], defaultdict(list)
new_component_names = []
new_points_dict = defaultdict(list)
for component in reduced_pose.header.components:
new_component_names.append(component.name)

for point in component.points:
new_points_dict[component.name].append(point)


assert component_to_drop not in new_component_names
assert point_to_drop not in new_points_dict["POSE_LANDMARKS"]


@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"])
Expand Down Expand Up @@ -205,9 +233,3 @@ def test_fake_pose(known_pose_format: KnownPoseFormat):
assert pose.header.num_dims() == pose.body.data.shape[-1]

poses = [fake_pose(25) for _ in range(5)]






131 changes: 123 additions & 8 deletions src/python/tests/pose_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ def _create_pose_header_component(name: str, num_keypoints: int) -> PoseHeaderCo
return component


def _distribute_points_among_components(component_count: int, total_keypoint_count: int):
if component_count <= 0 or total_keypoint_count < component_count + 1:
raise ValueError("Total keypoints must be at least component count+1 (so that 0 can have two), and component count must be positive")

# Step 1: Initialize with required minimum values
keypoint_counts = [2] + [1] * (component_count - 1) # Ensure first is 2, others at least 1

# Step 2: Distribute remaining points
remaining_points = total_keypoint_count - sum(keypoint_counts)
for _ in range(remaining_points):
keypoint_counts[random.randint(0, component_count - 1)] += 1 # Add randomly

return keypoint_counts

def _create_pose_header(width: int, height: int, depth: int, num_components: int, num_keypoints: int) -> PoseHeader:
"""
Create a PoseHeader with given dimensions and components.
Expand All @@ -79,8 +93,10 @@ def _create_pose_header(width: int, height: int, depth: int, num_components: int
"""
dimensions = PoseHeaderDimensions(width=width, height=height, depth=depth)

keypoints_per_component = _distribute_points_among_components(num_components, num_keypoints)

components = [
_create_pose_header_component(name=str(index), num_keypoints=num_keypoints) for index in range(num_components)
_create_pose_header_component(name=str(index), num_keypoints=keypoints_per_component[index]) for index in range(num_components)
]

header = PoseHeader(version=1.0, dimensions=dimensions, components=components)
Expand Down Expand Up @@ -134,6 +150,8 @@ def _create_random_tensorflow_data(frames_min: Optional[int] = None,
return tensor, mask, confidence




def _create_random_numpy_data(frames_min: Optional[int] = None,
frames_max: Optional[int] = None,
num_frames: Optional[int] = None,
Expand Down Expand Up @@ -286,7 +304,7 @@ def _get_random_pose_object_with_tf_posebody(num_keypoints: int, frames_min: int
return Pose(header=header, body=body)


def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: int = 1, frames_max: int = 10) -> Pose:
def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min: int = 1, frames_max: int = 10, num_components=3) -> Pose:
"""
Creates a random Pose object with Numpy pose body for testing.

Expand All @@ -313,7 +331,7 @@ def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min:

body = NumPyPoseBody(fps=10, data=masked_array, confidence=confidence)

header = _create_pose_header(width=10, height=7, depth=0, num_components=3, num_keypoints=num_keypoints)
header = _create_pose_header(width=10, height=7, depth=0, num_components=num_components, num_keypoints=num_keypoints)

return Pose(header=header, body=body)

Expand All @@ -329,6 +347,96 @@ def test_pose_object_should_be_callable(self):
"""
assert callable(Pose)

def test_pose_remove_components(self):
pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5)
assert pose.body.data.shape[-2] == 5
assert pose.body.data.shape[-1] == 2 # XY dimensions

self.assertEqual(len(pose.header.components), 3)
self.assertEqual(sum(len(c.points) for c in pose.header.components), 5)
self.assertEqual(pose.header.components[0].name, "0")
self.assertEqual(pose.header.components[1].name, "1")
self.assertEqual(pose.header.components[0].points[0], "0_a")
self.assertIn("1_a", pose.header.components[1].points)
self.assertNotIn("1_f", pose.header.components[1].points)
self.assertNotIn("4", pose.header.components)

# test that we can remove a component
component_to_remove = "0"
pose_copy = pose.copy()
self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components])
pose_copy = pose_copy.remove_components(component_to_remove)
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])


# Remove a point only
point_to_remove = "0_a"
pose_copy = pose.copy()
self.assertIn(point_to_remove, pose_copy.header.components[0].points)
pose_copy = pose_copy.remove_components([], {point_to_remove[0]:[point_to_remove]})
self.assertNotIn(point_to_remove, pose_copy.header.components[0].points)


# Can we remove two things at once
component_to_remove = "1"
point_to_remove = "2_a"
component_to_remove_point_from = "2"

self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components])
self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components])
self.assertIn(point_to_remove, pose_copy.header.components[2].points)
pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove_point_from:[point_to_remove]})
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])
self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components]) # this should still be around

# can we remove a component and a point FROM that component without crashing
component_to_remove = "0"
point_to_remove = "0_a"
pose_copy = pose.copy()
self.assertIn(point_to_remove, pose_copy.header.components[0].points)
pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove:[point_to_remove]})
self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components])
self.assertNotIn(point_to_remove, pose_copy.header.components[0].points)


# can we "remove" a component that doesn't exist without crashing
component_to_remove = "NOT EXISTING"
pose_copy = pose.copy()
initial_count = len(pose_copy.header.components)
pose_copy = pose_copy.remove_components([component_to_remove])
self.assertEqual(initial_count, len(pose_copy.header.components))




# can we "remove" a point that doesn't exist from a component that does without crashing
point_to_remove = "2_x"
component_to_remove_point_from = "2"
pose_copy = pose.copy()
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]})
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)


# can we "remove" an empty list of points
component_to_remove_point_from = "2"
pose_copy = pose.copy()
initial_component_count = len(pose_copy.header.components)
initial_point_count = len(pose_copy.header.components[2].points)
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[]})
self.assertEqual(initial_component_count, len(pose_copy.header.components))
self.assertEqual(len(pose_copy.header.components[2].points), initial_point_count)


# can we remove a point from a component that doesn't exist
point_to_remove = "2_x"
component_to_remove_point_from = "NOT EXISTING"
pose_copy = pose.copy()
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)
pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]})
self.assertNotIn(point_to_remove, pose_copy.header.components[2].points)




class TestPoseTensorflowPoseBody(TestCase):
Expand Down Expand Up @@ -475,7 +583,7 @@ def create_pose_and_frame_dropout_uniform(example: tf.Tensor) -> tf.Tensor:
return example

dataset.map(create_pose_and_frame_dropout_uniform)


def test_pose_tf_posebody_copy_creates_deepcopy(self):
pose = _get_random_pose_object_with_tf_posebody(num_keypoints=5)
Expand All @@ -488,7 +596,7 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self):

# Check that pose and pose_copy are not the same object
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")

# Ensure the data tensors are equal but independent
self.assertTrue(tf.reduce_all(pose.body.data == pose_copy.body.data), "Copy's data should match original")

Expand All @@ -515,6 +623,14 @@ class TestPoseNumpyPoseBody(TestCase):
Testcases for Pose objects containing NumPy PoseBody data.
"""

def test_pose_numpy_generated_with_correct_shape(self):
pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5, frames_min=3)

# does the header match the body?
expected_keypoints_count_from_header = sum(len(c.points) for c in pose.header.components)
self.assertEqual(expected_keypoints_count_from_header, pose.body.data.shape[-2])


def test_pose_numpy_posebody_normalize_preserves_shape(self):
"""
Tests if the normalization of Pose object with NumPy PoseBody preserves array shape.
Expand Down Expand Up @@ -593,17 +709,16 @@ def test_pose_torch_posebody_copy_creates_deepcopy(self):
pose = _get_random_pose_object_with_torch_posebody(num_keypoints=5)
self.assertIsInstance(pose.body, TorchPoseBody)
self.assertIsInstance(pose.body.data, TorchMaskedTensor)


pose_copy = pose.copy()
self.assertIsInstance(pose_copy.body, TorchPoseBody)
self.assertIsInstance(pose_copy.body.data, TorchMaskedTensor)

self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
self.assertTrue(pose.body.data.tensor.equal(pose_copy.body.data.tensor), "Copy's data should match original")
self.assertTrue(pose.body.data.mask.equal(pose_copy.body.data.mask), "Copy's mask should match original")

pose.body.data = TorchMaskedTensor(tensor=torch.zeros_like(pose.body.data.tensor),
pose.body.data = TorchMaskedTensor(tensor=torch.zeros_like(pose.body.data.tensor),
mask=torch.ones_like(pose.body.data.mask))


Expand Down