diff --git a/src/python/pose_format/pose_header.py b/src/python/pose_format/pose_header.py index 225668c..fb9f7a1 100644 --- a/src/python/pose_format/pose_header.py +++ b/src/python/pose_format/pose_header.py @@ -1,7 +1,7 @@ import hashlib import math import struct -from typing import BinaryIO, List, Tuple +from typing import BinaryIO, List, Tuple, Optional, Union from .utils.reader import BufferReader, ConstStructs @@ -21,7 +21,7 @@ class PoseNormalizationInfo: Third pose value. Defaults to None. """ - def __init__(self, p1: int, p2: int, p3: int = None): + def __init__(self, p1: int, p2: int, p3: Optional[int] = None): """Initialize a PoseNormalizationInfo instance.""" self.p1 = p1 self.p2 = p2 @@ -66,7 +66,7 @@ def __init__(self, name: str, points: List[str], limbs: List[Tuple[int, int]], c self.relative_limbs = self.get_relative_limbs() @staticmethod - def read(version: float, reader: BufferReader): + def read(version: float, reader: BufferReader) -> 'PoseHeaderComponent': """ Reads pose header dimensions from reader (BufferReader). @@ -183,7 +183,7 @@ def __init__(self, width: int, height: int, depth: int = 0, *args): self.depth = math.ceil(depth) @staticmethod - def read(version: float, reader: BufferReader): + def read(version: float, reader: BufferReader) -> 'PoseHeaderDimensions': """ Reads and returns a PoseHeaderDimensions object from a buffer reader. @@ -293,6 +293,7 @@ def __init__(self, self.components = components self.is_bbox = is_bbox + @staticmethod def read(reader: BufferReader) -> 'PoseHeader': """ @@ -376,9 +377,22 @@ def _get_point_index(self, component: str, point: str): raise ValueError("Couldn't find component") + def get_point_index(self, component: str, point: str) -> int: + """ + Returns the index of a given point within the pose. + + Args: + component (str): The name of the component containing the point. + point (str): The name of the point whose index is to be retrieved. + + Raises: + ValueError: If the specified component or point is not found. + """ + return self._get_point_index(component, point) + def normalization_info(self, p1: Tuple[str, str], p2: Tuple[str, str], p3: Tuple[str, str] = None): """ - Normalizates info for given points. + Normalization info for given points. Parameters ---------- @@ -394,9 +408,9 @@ def normalization_info(self, p1: Tuple[str, str], p2: Tuple[str, str], p3: Tuple PoseNormalizationInfo Normalization information for the points. """ - return PoseNormalizationInfo(p1=self._get_point_index(*p1), - p2=self._get_point_index(*p2), - p3=None if p3 is None else self._get_point_index(*p3)) + return PoseNormalizationInfo(p1=self.get_point_index(*p1), + p2=self.get_point_index(*p2), + p3=None if p3 is None else self.get_point_index(*p3)) def bbox(self): """ diff --git a/src/python/pose_format/utils/conftest.py b/src/python/pose_format/utils/conftest.py index 4f4ae30..3da50ae 100644 --- a/src/python/pose_format/utils/conftest.py +++ b/src/python/pose_format/utils/conftest.py @@ -20,4 +20,4 @@ def fake_poses(request) -> List[Pose]: for i, pose in enumerate(fake_poses_list): for component in pose.header.components: component.name = f"unknown_component_{i}_formerly_{component.name}" - return copy.deepcopy(fake_poses_list) + return [pose.copy() for pose in fake_poses_list] diff --git a/src/python/pose_format/utils/generic.py b/src/python/pose_format/utils/generic.py index 940d43a..653d9ae 100644 --- a/src/python/pose_format/utils/generic.py +++ b/src/python/pose_format/utils/generic.py @@ -1,13 +1,13 @@ -from pathlib import Path from typing import Tuple, Literal, List, Union import copy import numpy as np -from numpy import ma +import numpy.ma as ma from pose_format.pose import Pose from pose_format.numpy import NumPyPoseBody from pose_format.pose_header import PoseHeader, PoseHeaderDimensions, PoseHeaderComponent, PoseNormalizationInfo from pose_format.utils.normalization_3d import PoseNormalizer from pose_format.utils.openpose import OpenPose_Components +from pose_format.utils.openpose import BODY_POINTS as OPENPOSE_BODY_POINTS from pose_format.utils.openpose_135 import OpenPose_Components as OpenPose135_Components # from pose_format.utils.holistic import holistic_components @@ -62,31 +62,55 @@ def normalize_pose_size(pose: Pose, target_width: int = 512): pose.header.dimensions.height = pose.header.dimensions.width = target_width -def pose_hide_legs(pose: Pose): +def pose_hide_legs(pose: Pose, remove: bool = False) -> Pose: + """ + Hide or remove leg components from a pose. + + If `remove` is True, the leg components are removed; otherwise, they are hidden (zeroed out). + """ known_pose_format = detect_known_pose_format(pose) + if known_pose_format == "holistic": - point_names = ["KNEE", "ANKLE", "HEEL", "FOOT_INDEX"] - # pylint: disable=protected-access - points = [ - pose.header._get_point_index("POSE_LANDMARKS", side + "_" + n) - for n in point_names - for side in ["LEFT", "RIGHT"] - ] - pose.body.data[:, :, points, :] = 0 - pose.body.confidence[:, :, points] = 0 + point_names = ["KNEE", "ANKLE", "HEEL", "FOOT_INDEX", "HIP"] + sides = ["LEFT", "RIGHT"] + point_names_to_remove = [f"{side}_{name}" for side in sides for name in point_names] + points_to_remove_dict = { + "POSE_LANDMARKS": point_names_to_remove, + "POSE_WORLD_LANDMARKS": point_names_to_remove, + } + elif known_pose_format == "openpose": - point_names = ["Hip", "Knee", "Ankle", "BigToe", "SmallToe", "Heel"] - # pylint: disable=protected-access - points = [ - pose.header._get_point_index("pose_keypoints_2d", side + n) for n in point_names for side in ["L", "R"] - ] - pose.body.data[:, :, points, :] = 0 - pose.body.confidence[:, :, points] = 0 + words_to_look_for = ["Hip", "Knee", "Ankle", "BigToe", "SmallToe", "Heel"] + point_names_to_remove = [point for point in OPENPOSE_BODY_POINTS + if any(word in point for word in words_to_look_for)] + + # if any of the items in point_ + points_to_remove_dict = {"pose_keypoints_2d": point_names_to_remove} + else: raise NotImplementedError( f"Unsupported pose header schema {known_pose_format} for {pose_hide_legs.__name__}: {pose.header}" ) + if remove: + return pose.remove_components([], points_to_remove_dict) + + # Hide the points instead of removing them + point_indices = [] + for component, points in points_to_remove_dict.items(): + for point_name in points: + try: + point_index = pose.header.get_point_index(component, point_name) + point_indices.append(point_index) + except ValueError: # point not found, maybe removed earlier in other preprocessing steps + pass + + + pose.body.data[:, :, point_indices, :] = 0 + pose.body.confidence[:, :, point_indices] = 0 + + return pose + def pose_shoulders(pose_header: PoseHeader) -> Tuple[Tuple[str, str], Tuple[str, str]]: known_pose_format = detect_known_pose_format(pose_header) @@ -109,14 +133,14 @@ def hands_indexes(pose_header: PoseHeader)-> List[int]: known_pose_format = detect_known_pose_format(pose_header) if known_pose_format == "holistic": return [ - pose_header._get_point_index("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"), - pose_header._get_point_index("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"), + pose_header.get_point_index("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"), + pose_header.get_point_index("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"), ] if known_pose_format == "openpose": return [ - pose_header._get_point_index("hand_left_keypoints_2d", "M_CMC"), - pose_header._get_point_index("hand_right_keypoints_2d", "M_CMC"), + pose_header.get_point_index("hand_left_keypoints_2d", "M_CMC"), + pose_header.get_point_index("hand_right_keypoints_2d", "M_CMC"), ] raise NotImplementedError( f"Unsupported pose header schema {known_pose_format} for {hands_indexes.__name__}: {pose_header}" @@ -148,12 +172,12 @@ def hands_components(pose_header: PoseHeader)-> Tuple[Tuple[str, str], Tuple[str def normalize_component_3d(pose, component_name: str, plane: Tuple[str, str, str], line: Tuple[str, str]): hand_pose = pose.get_components([component_name]) plane_info = hand_pose.header.normalization_info( - p1=(component_name, plane[0]), - p2=(component_name, plane[1]), + p1=(component_name, plane[0]), + p2=(component_name, plane[1]), p3=(component_name, plane[2]) ) line_info = hand_pose.header.normalization_info( - p1=(component_name, line[0]), + p1=(component_name, line[0]), p2=(component_name, line[1]) ) @@ -176,10 +200,11 @@ def normalize_hands_3d(pose: Pose, left_hand=True, right_hand=True): def get_standard_components_for_known_format(known_pose_format: KnownPoseFormat) -> List[PoseHeaderComponent]: if known_pose_format == "holistic": try: + # pylint: disable=import-outside-toplevel import pose_format.utils.holistic as holistic_utils return holistic_utils.holistic_components() except ImportError as e: - raise e + raise e if known_pose_format == "openpose": return OpenPose_Components if known_pose_format == "openpose_135": @@ -191,7 +216,7 @@ def get_standard_components_for_known_format(known_pose_format: KnownPoseFormat) def fake_pose(num_frames: int, fps: int=25, components: Union[List[PoseHeaderComponent],None]=None)->Pose: if components is None: components = copy.deepcopy(OpenPose_Components) # fixes W0102, dangerous default value - + if components[0].format == "XYZC": dimensions = PoseHeaderDimensions(width=1, height=1, depth=1) elif components[0].format == "XYC": @@ -204,7 +229,6 @@ def fake_pose(num_frames: int, fps: int=25, components: Union[List[PoseHeaderCom data = np.random.randn(num_frames, 1, total_points, header.num_dims()) confidence = np.random.randn(num_frames, 1, total_points) masked_data = ma.masked_array(data) - body = NumPyPoseBody(fps=int(fps), data=masked_data, confidence=confidence) @@ -214,9 +238,9 @@ def fake_pose(num_frames: int, fps: int=25, components: Union[List[PoseHeaderCom def get_hand_wrist_index(pose: Pose, hand: str)-> int: known_pose_format = detect_known_pose_format(pose) if known_pose_format == "holistic": - return pose.header._get_point_index(f"{hand.upper()}_HAND_LANDMARKS", "WRIST") + return pose.header.get_point_index(f"{hand.upper()}_HAND_LANDMARKS", "WRIST") if known_pose_format == "openpose": - return pose.header._get_point_index(f"hand_{hand.lower()}_keypoints_2d", "BASE") + return pose.header.get_point_index(f"hand_{hand.lower()}_keypoints_2d", "BASE") raise NotImplementedError( f"Unsupported pose header schema {known_pose_format} for {get_hand_wrist_index.__name__}: {pose.header}" ) @@ -225,9 +249,9 @@ def get_hand_wrist_index(pose: Pose, hand: str)-> int: def get_body_hand_wrist_index(pose: Pose, hand: str)-> int: known_pose_format = detect_known_pose_format(pose) if known_pose_format == "holistic": - return pose.header._get_point_index("POSE_LANDMARKS", f"{hand.upper()}_WRIST") + return pose.header.get_point_index("POSE_LANDMARKS", f"{hand.upper()}_WRIST") if known_pose_format == "openpose": - return pose.header._get_point_index("pose_keypoints_2d", f"{hand.upper()[0]}Wrist") + return pose.header.get_point_index("pose_keypoints_2d", f"{hand.upper()[0]}Wrist") raise NotImplementedError( f"Unsupported pose header schema {known_pose_format} for {get_body_hand_wrist_index.__name__}: {pose.header}" ) @@ -244,7 +268,7 @@ def correct_wrist(pose: Pose, hand: str) -> Pose: body_wrist_conf = pose.body.confidence[:, :, body_wrist_index] point_coordinate_count = wrist.shape[-1] - stacked_conf = np.stack([wrist_conf] * point_coordinate_count, axis=-1) + stacked_conf = np.stack([wrist_conf] * point_coordinate_count, axis=-1) new_wrist_data = ma.where(stacked_conf == 0, body_wrist, wrist) new_wrist_conf = ma.where(wrist_conf == 0, body_wrist_conf, wrist_conf) @@ -263,7 +287,7 @@ def reduce_holistic(pose: Pose) -> Pose: known_pose_format = detect_known_pose_format(pose) if known_pose_format != "holistic": return pose - + # pylint: disable=pointless-string-statement """ # from mediapipe.python.solutions.face_mesh_connections import FACEMESH_CONTOURS # points_set = set([p for p_tup in list(FACEMESH_CONTOURS) for p in p_tup]) diff --git a/src/python/pose_format/utils/generic_test.py b/src/python/pose_format/utils/generic_test.py index 461f55e..e1b5912 100644 --- a/src/python/pose_format/utils/generic_test.py +++ b/src/python/pose_format/utils/generic_test.py @@ -60,19 +60,24 @@ def test_get_component_names(fake_poses: List[Pose], known_pose_format: KnownPos @pytest.mark.parametrize("fake_poses", list(get_args(KnownPoseFormat)), indirect=["fake_poses"]) def test_pose_hide_legs(fake_poses: List[Pose]): for pose in fake_poses: - + pose_copy = pose.copy() orig_nonzeros_count = np.count_nonzero(pose.body.data) detected_format = detect_known_pose_format(pose) if detected_format == "openpose_135": with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): - pose_hide_legs(pose) - return + pose = pose_hide_legs(pose) else: - pose_hide_legs(pose) + pose = pose_hide_legs(pose) new_nonzeros_count = np.count_nonzero(pose.body.data) assert orig_nonzeros_count > new_nonzeros_count + assert len(pose_copy.header.components) == len(pose.header.components) + for c_orig, c_copy in zip(pose.header.components, pose_copy.header.components): + assert len(c_orig.points) == len(c_copy.points) + # what if we remove the legs before hiding them first? It should not crash. + pose = pose_hide_legs(pose, remove=True) + pose = pose_hide_legs(pose, remove=False) @pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) @@ -120,12 +125,12 @@ def test_get_hand_wrist_index(fake_poses: List[Pose]): detected_format = detect_known_pose_format(pose) for hand in ["LEFT", "RIGHT"]: if detected_format == "openpose_135": - with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): - index = get_hand_wrist_index(pose, hand) + with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): + _ = get_hand_wrist_index(pose, hand) else: - index = get_hand_wrist_index(pose, hand) + _ = get_hand_wrist_index(pose, hand) - # TODO: what are the expected values? + # TODO: what are the expected values? @pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) @@ -135,10 +140,10 @@ def test_get_body_hand_wrist_index(fake_poses: List[Pose]): detected_format = detect_known_pose_format(pose) if detected_format == "openpose_135": with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): - index = get_body_hand_wrist_index(pose, hand) - # TODO: what are the expected values? - else: - index = get_body_hand_wrist_index(pose, hand) + _ = get_body_hand_wrist_index(pose, hand) + # TODO: what are the expected values? + else: + _ = get_body_hand_wrist_index(pose, hand) @@ -153,7 +158,7 @@ def test_correct_wrists(fake_poses: List[Pose]): else: corrected_pose = correct_wrists(pose) assert corrected_pose != pose - assert np.array_equal(corrected_pose.body.data, pose.body.data) is False + assert np.array_equal(corrected_pose.body.data, pose.body.data) is False @pytest.mark.parametrize("fake_poses", ["holistic"], indirect=["fake_poses"]) def test_remove_one_point_and_one_component(fake_poses: List[Pose]): @@ -182,7 +187,50 @@ def test_remove_one_point_and_one_component(fake_poses: List[Pose]): assert component_to_drop not in new_component_names - assert point_to_drop not in new_points_dict["POSE_LANDMARKS"] + assert point_to_drop not in new_points_dict["POSE_LANDMARKS"] + +@pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) +def test_pose_remove_legs(fake_poses: List[Pose]): + for pose in fake_poses: + known_pose_format = detect_known_pose_format(pose) + if known_pose_format == "holistic": + points_that_should_be_removed = ["LEFT_KNEE", "LEFT_HEEL", "LEFT_FOOT", "LEFT_TOE", "LEFT_FOOT_INDEX", + "RIGHT_KNEE", "RIGHT_HEEL", "RIGHT_FOOT", "RIGHT_TOE", "RIGHT_FOOT_INDEX",] + c_names = [c.name for c in pose.header.components] + assert "POSE_LANDMARKS" in c_names + pose_landmarks_index = c_names.index("POSE_LANDMARKS") + assert "LEFT_KNEE" in pose.header.components[pose_landmarks_index].points + + + pose_with_legs_removed = pose_hide_legs(pose, remove=True) + assert pose_with_legs_removed != pose + new_c_names = [c.name for c in pose_with_legs_removed.header.components] + assert "POSE_LANDMARKS" in new_c_names + + for component in pose_with_legs_removed.header.components: + point_names = [point.upper() for point in component.points] + for point_name in point_names: + for point_that_should_be_hidden in points_that_should_be_removed: + assert point_that_should_be_hidden not in point_name, f"{component.name}: {point_names}" + + elif known_pose_format == "openpose": + c_names = [c.name for c in pose.header.components] + points_that_should_be_removed = ['LHip', 'RHip', 'MidHip', + 'LKnee', 'RKnee', + 'LAnkle', 'RAnkle', + 'LBigToe', 'RBigToe', + 'LSmallToe', 'RSmallToe', + 'LHeel', 'RHeel'] + component_index = c_names.index("pose_keypoints_2d") + pose_with_legs_removed = pose_hide_legs(pose, remove=True) + + for point_name in points_that_should_be_removed: + assert point_name not in pose_with_legs_removed.header.components[component_index].points, f"{pose_with_legs_removed.header.components[component_index].name},{pose_with_legs_removed.header.components[component_index].points}" + assert point_name in pose.header.components[component_index].points + else: + with pytest.raises(NotImplementedError, match="Unsupported pose header schema"): + pose = pose_hide_legs(pose, remove=True) + @pytest.mark.parametrize("fake_poses", TEST_POSE_FORMATS, indirect=["fake_poses"]) @@ -204,7 +252,7 @@ def test_fake_pose(known_pose_format: KnownPoseFormat): for frame_count in [1, 10, 100]: for fps in [1, 15, 25, 100]: standard_components = get_standard_components_for_known_format(known_pose_format) - + pose = fake_pose(frame_count, fps=fps, components=standard_components) point_formats = [c.format for c in pose.header.components] data_dimension_expected = 0 @@ -215,7 +263,6 @@ def test_fake_pose(known_pose_format: KnownPoseFormat): assert point_format == point_formats[0] data_dimension_expected = len(point_formats[0]) - 1 - detected_format = detect_known_pose_format(pose) @@ -231,5 +278,3 @@ def test_fake_pose(known_pose_format: KnownPoseFormat): assert pose.body.data.shape == (frame_count, 1, pose.header.total_points(), data_dimension_expected) assert pose.body.data.shape[0] == frame_count assert pose.header.num_dims() == pose.body.data.shape[-1] - - poses = [fake_pose(25) for _ in range(5)] diff --git a/src/python/pose_format/utils/pose_converter.py b/src/python/pose_format/utils/pose_converter.py index 04397f1..6912a3c 100644 --- a/src/python/pose_format/utils/pose_converter.py +++ b/src/python/pose_format/utils/pose_converter.py @@ -222,8 +222,8 @@ def convert_pose(pose: Pose, pose_components: List[PoseHeaderComponent]) -> Pose for (c1, p1), (c2, p2) in mapping.items(): p2 = tuple([p2]) if isinstance(p2, str) else p2 try: - p2s = [pose.header._get_point_index(c2, p) for p in list(p2)] - p1_index = pose_header._get_point_index(c1, p1) + p2s = [pose.header.get_point_index(c2, p) for p in list(p2)] + p1_index = pose_header.get_point_index(c1, p1) data[:, :, p1_index, :dims] = pose.body.data[:, :, p2s, :dims].mean(axis=2) conf[:, :, p1_index] = pose.body.confidence[:, :, p2s].mean(axis=2) except Exception as e: diff --git a/src/python/tests/pose_test.py b/src/python/tests/pose_test.py index 1215b3f..019b0d1 100644 --- a/src/python/tests/pose_test.py +++ b/src/python/tests/pose_test.py @@ -347,6 +347,23 @@ def test_pose_object_should_be_callable(self): """ assert callable(Pose) + def test_get_index(self): + pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5) + expected_index = 0 + self.assertEqual(0, pose.header.get_point_index("0", "0_a")) + for component in pose.header.components: + for point in component.points: + self.assertEqual(expected_index, pose.header.get_point_index(component.name, point)) + expected_index +=1 + + with self.assertRaises(ValueError): + pose.header.get_point_index("component that doesn't exist", "") + + with self.assertRaises(ValueError): + pose.header.get_point_index("0", "point that doesn't exist") + + + def test_pose_remove_components(self): pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5) assert pose.body.data.shape[-2] == 5 @@ -367,6 +384,10 @@ def test_pose_remove_components(self): self.assertIn(component_to_remove, [c.name for c in pose_copy.header.components]) pose_copy = pose_copy.remove_components(component_to_remove) self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components]) + self.assertEqual(pose_copy.header.components[0].name, "1") + # quickly check to make sure other components/points weren't removed + self.assertIn("1_a", pose_copy.header.components[0].points) + self.assertEqual(pose_copy.header.components[0].points, pose.header.components[1].points) # Remove a point only @@ -375,9 +396,13 @@ def test_pose_remove_components(self): self.assertIn(point_to_remove, pose_copy.header.components[0].points) pose_copy = pose_copy.remove_components([], {point_to_remove[0]:[point_to_remove]}) self.assertNotIn(point_to_remove, pose_copy.header.components[0].points) + # quickly check to make sure other components/points weren't removed + self.assertIn("1_a", pose_copy.header.components[1].points) + self.assertEqual(pose_copy.header.components[1].points, pose.header.components[1].points) # Can we remove two things at once + pose_copy = pose.copy() component_to_remove = "1" point_to_remove = "2_a" component_to_remove_point_from = "2" @@ -388,6 +413,8 @@ def test_pose_remove_components(self): pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove_point_from:[point_to_remove]}) self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components]) self.assertIn(component_to_remove_point_from, [c.name for c in pose_copy.header.components]) # this should still be around + self.assertIn("0_a", pose_copy.header.components[0].points) + self.assertEqual(pose_copy.header.components[0].points, pose.header.components[0].points) # should be unaffected # can we remove a component and a point FROM that component without crashing component_to_remove = "0" @@ -397,6 +424,7 @@ def test_pose_remove_components(self): pose_copy = pose_copy.remove_components([component_to_remove], {component_to_remove:[point_to_remove]}) self.assertNotIn(component_to_remove, [c.name for c in pose_copy.header.components]) self.assertNotIn(point_to_remove, pose_copy.header.components[0].points) + self.assertEqual(pose_copy.header.components[0].points, pose.header.components[1].points) # should be unaffected # can we "remove" a component that doesn't exist without crashing @@ -405,6 +433,10 @@ def test_pose_remove_components(self): initial_count = len(pose_copy.header.components) pose_copy = pose_copy.remove_components([component_to_remove]) self.assertEqual(initial_count, len(pose_copy.header.components)) + for c_orig, c_copy in zip(pose.header.components, pose_copy.header.components): + self.assertNotEqual(c_copy, c_orig) # should be a new object... + self.assertEqual(c_copy.name, c_orig.name) # with the same name + self.assertEqual(c_copy.points, c_orig.points) # and the same points @@ -416,6 +448,10 @@ def test_pose_remove_components(self): self.assertNotIn(point_to_remove, pose_copy.header.components[2].points) pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]}) self.assertNotIn(point_to_remove, pose_copy.header.components[2].points) + for c_orig, c_copy in zip(pose.header.components, pose_copy.header.components): + self.assertNotEqual(c_copy, c_orig) # should be a new object... + self.assertEqual(c_copy.name, c_orig.name) # with the same name + self.assertEqual(c_copy.points, c_orig.points) # and the same points # can we "remove" an empty list of points @@ -426,7 +462,10 @@ def test_pose_remove_components(self): pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[]}) self.assertEqual(initial_component_count, len(pose_copy.header.components)) self.assertEqual(len(pose_copy.header.components[2].points), initial_point_count) - + for c_orig, c_copy in zip(pose.header.components, pose_copy.header.components): + self.assertNotEqual(c_copy, c_orig) # should be a new object... + self.assertEqual(c_copy.name, c_orig.name) # with the same name + self.assertEqual(c_copy.points, c_orig.points) # and the same points # can we remove a point from a component that doesn't exist point_to_remove = "2_x" @@ -435,6 +474,10 @@ def test_pose_remove_components(self): self.assertNotIn(point_to_remove, pose_copy.header.components[2].points) pose_copy = pose_copy.remove_components([], {component_to_remove_point_from:[point_to_remove]}) self.assertNotIn(point_to_remove, pose_copy.header.components[2].points) + for c_orig, c_copy in zip(pose.header.components, pose_copy.header.components): + self.assertNotEqual(c_copy, c_orig) # should be a new object... + self.assertEqual(c_copy.name, c_orig.name) # with the same name + self.assertEqual(c_copy.points, c_orig.points) # and the same points @@ -587,7 +630,7 @@ def create_pose_and_frame_dropout_uniform(example: tf.Tensor) -> tf.Tensor: def test_pose_tf_posebody_copy_creates_deepcopy(self): pose = _get_random_pose_object_with_tf_posebody(num_keypoints=5) - self.assertIsInstance(pose.body, TensorflowPoseBody) + self.assertIsInstance(pose.body, TensorflowPoseBody) self.assertIsInstance(pose.body.data, TensorflowMaskedTensor) pose_copy = pose.copy() @@ -607,6 +650,7 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self): # Create another copy and ensure it matches the first copy pose = pose_copy.copy() + self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original") self.assertTrue(tf.reduce_all(pose.body.data == pose_copy.body.data), "Copy's data should match original again") @@ -676,7 +720,6 @@ def test_pose_numpy_posebody_copy_creates_deepcopy(self): pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5, frames_min=3) pose_copy = pose.copy() - self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original") self.assertTrue(np.array_equal(pose.body.data, pose_copy.body.data), "Copy's data should match original")