diff --git a/src/python/pose_format/pose.py b/src/python/pose_format/pose.py index 5031f64..d0480f4 100644 --- a/src/python/pose_format/pose.py +++ b/src/python/pose_format/pose.py @@ -278,7 +278,7 @@ def get_components(self, components: List[str], points: Union[Dict[str, List[str def copy(self): - return self.__class__(self.header, self.body.copy()) + return self.__class__(self.header.copy(), self.body.copy()) def bbox(self): """ diff --git a/src/python/pose_format/pose_header.py b/src/python/pose_format/pose_header.py index 225668c..dde6f60 100644 --- a/src/python/pose_format/pose_header.py +++ b/src/python/pose_format/pose_header.py @@ -65,8 +65,15 @@ def __init__(self, name: str, points: List[str], limbs: List[Tuple[int, int]], c self.relative_limbs = self.get_relative_limbs() + def copy(self) -> 'PoseHeaderComponent': + return PoseHeaderComponent(name = self.name, + points = self.points, + limbs= self.limbs, + colors=self.colors, + point_format = self.format) + @staticmethod - def read(version: float, reader: BufferReader): + def read(version: float, reader: BufferReader) -> 'PoseHeaderComponent': """ Reads pose header dimensions from reader (BufferReader). @@ -182,8 +189,11 @@ def __init__(self, width: int, height: int, depth: int = 0, *args): self.height = math.ceil(height) self.depth = math.ceil(depth) + def copy(self) -> 'PoseHeaderDimensions': + return self.__class__(self.width, self.height, self.depth) + @staticmethod - def read(version: float, reader: BufferReader): + def read(version: float, reader: BufferReader) -> 'PoseHeaderDimensions': """ Reads and returns a PoseHeaderDimensions object from a buffer reader. @@ -293,6 +303,13 @@ def __init__(self, self.components = components self.is_bbox = is_bbox + def copy(self) -> 'PoseHeader': + return PoseHeader(version=self.version, + dimensions=self.dimensions.copy(), + components=[c.copy() for c in self.components], + is_bbox=self.is_bbox + ) + @staticmethod def read(reader: BufferReader) -> 'PoseHeader': """ diff --git a/src/python/tests/pose_test.py b/src/python/tests/pose_test.py index 897a0dc..55419c1 100644 --- a/src/python/tests/pose_test.py +++ b/src/python/tests/pose_test.py @@ -479,7 +479,7 @@ def create_pose_and_frame_dropout_uniform(example: tf.Tensor) -> tf.Tensor: def test_pose_tf_posebody_copy_creates_deepcopy(self): pose = _get_random_pose_object_with_tf_posebody(num_keypoints=5) - self.assertIsInstance(pose.body, TensorflowPoseBody) + self.assertIsInstance(pose.body, TensorflowPoseBody) self.assertIsInstance(pose.body.data, TensorflowMaskedTensor) pose_copy = pose.copy() @@ -488,7 +488,9 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self): # Check that pose and pose_copy are not the same object self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original") - + self.assertNotEqual(pose.header, pose_copy.header, "headers should be new objects as well") + self.assertNotEqual(pose.header.components, pose_copy.header.components, "components should be new objects as well") + # Ensure the data tensors are equal but independent self.assertTrue(tf.reduce_all(pose.body.data == pose_copy.body.data), "Copy's data should match original") @@ -499,6 +501,9 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self): # Create another copy and ensure it matches the first copy pose = pose_copy.copy() + self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original") + self.assertNotEqual(pose.header, pose_copy.header, "headers should be new objects as well") + self.assertNotEqual(pose.header.components, pose_copy.header.components, "Components should be new objects as well") self.assertTrue(tf.reduce_all(pose.body.data == pose_copy.body.data), "Copy's data should match original again") @@ -560,8 +565,9 @@ def test_pose_numpy_posebody_copy_creates_deepcopy(self): pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5, frames_min=3) pose_copy = pose.copy() - self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original") + self.assertNotEqual(pose.header, pose_copy.header, "headers should be new objects as well") + self.assertNotEqual(pose.header.components, pose_copy.header.components, "components should be new objects as well") self.assertTrue(np.array_equal(pose.body.data, pose_copy.body.data), "Copy's data should match original") @@ -599,7 +605,9 @@ def test_pose_torch_posebody_copy_creates_deepcopy(self): self.assertIsInstance(pose_copy.body, TorchPoseBody) self.assertIsInstance(pose_copy.body.data, TorchMaskedTensor) - self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original") + self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original") + self.assertNotEqual(pose.header, pose_copy.header, "headers should be new objects as well") + self.assertNotEqual(pose.header.components, pose_copy.header.components, "components should be new objects as well") self.assertTrue(pose.body.data.tensor.equal(pose_copy.body.data.tensor), "Copy's data should match original") self.assertTrue(pose.body.data.mask.equal(pose_copy.body.data.mask), "Copy's mask should match original")