Skip to content

Commit 28341e3

Browse files
committed
deepcopy for header and components also
1 parent 487f92a commit 28341e3

File tree

3 files changed

+32
-7
lines changed

3 files changed

+32
-7
lines changed

src/python/pose_format/pose.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def get_components(self, components: List[str], points: Union[Dict[str, List[str
278278

279279

280280
def copy(self):
281-
return self.__class__(self.header, self.body.copy())
281+
return self.__class__(self.header.copy(), self.body.copy())
282282

283283
def bbox(self):
284284
"""

src/python/pose_format/pose_header.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,15 @@ def __init__(self, name: str, points: List[str], limbs: List[Tuple[int, int]], c
6565

6666
self.relative_limbs = self.get_relative_limbs()
6767

68+
def copy(self) -> 'PoseHeaderComponent':
69+
return PoseHeaderComponent(name = self.name,
70+
points = self.points,
71+
limbs= self.limbs,
72+
colors=self.colors,
73+
point_format = self.format)
74+
6875
@staticmethod
69-
def read(version: float, reader: BufferReader):
76+
def read(version: float, reader: BufferReader) -> 'PoseHeaderComponent':
7077
"""
7178
Reads pose header dimensions from reader (BufferReader).
7279
@@ -182,8 +189,11 @@ def __init__(self, width: int, height: int, depth: int = 0, *args):
182189
self.height = math.ceil(height)
183190
self.depth = math.ceil(depth)
184191

192+
def copy(self) -> 'PoseHeaderDimensions':
193+
return self.__class__(self.width, self.height, self.depth)
194+
185195
@staticmethod
186-
def read(version: float, reader: BufferReader):
196+
def read(version: float, reader: BufferReader) -> 'PoseHeaderDimensions':
187197
"""
188198
Reads and returns a PoseHeaderDimensions object from a buffer reader.
189199
@@ -293,6 +303,13 @@ def __init__(self,
293303
self.components = components
294304
self.is_bbox = is_bbox
295305

306+
def copy(self) -> 'PoseHeader':
307+
return PoseHeader(version=self.version,
308+
dimensions=self.dimensions.copy(),
309+
components=[c.copy() for c in self.components],
310+
is_bbox=self.is_bbox
311+
)
312+
296313
@staticmethod
297314
def read(reader: BufferReader) -> 'PoseHeader':
298315
"""

src/python/tests/pose_test.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def create_pose_and_frame_dropout_uniform(example: tf.Tensor) -> tf.Tensor:
479479

480480
def test_pose_tf_posebody_copy_creates_deepcopy(self):
481481
pose = _get_random_pose_object_with_tf_posebody(num_keypoints=5)
482-
self.assertIsInstance(pose.body, TensorflowPoseBody)
482+
self.assertIsInstance(pose.body, TensorflowPoseBody)
483483
self.assertIsInstance(pose.body.data, TensorflowMaskedTensor)
484484

485485
pose_copy = pose.copy()
@@ -488,7 +488,9 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self):
488488

489489
# Check that pose and pose_copy are not the same object
490490
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
491-
491+
self.assertNotEqual(pose.header, pose_copy.header, "headers should be new objects as well")
492+
self.assertNotEqual(pose.header.components, pose_copy.header.components, "components should be new objects as well")
493+
492494
# Ensure the data tensors are equal but independent
493495
self.assertTrue(tf.reduce_all(pose.body.data == pose_copy.body.data), "Copy's data should match original")
494496

@@ -499,6 +501,9 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self):
499501

500502
# Create another copy and ensure it matches the first copy
501503
pose = pose_copy.copy()
504+
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
505+
self.assertNotEqual(pose.header, pose_copy.header, "headers should be new objects as well")
506+
self.assertNotEqual(pose.header.components, pose_copy.header.components, "Components should be new objects as well")
502507

503508
self.assertTrue(tf.reduce_all(pose.body.data == pose_copy.body.data), "Copy's data should match original again")
504509

@@ -560,8 +565,9 @@ def test_pose_numpy_posebody_copy_creates_deepcopy(self):
560565
pose = _get_random_pose_object_with_numpy_posebody(num_keypoints=5, frames_min=3)
561566

562567
pose_copy = pose.copy()
563-
564568
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
569+
self.assertNotEqual(pose.header, pose_copy.header, "headers should be new objects as well")
570+
self.assertNotEqual(pose.header.components, pose_copy.header.components, "components should be new objects as well")
565571

566572
self.assertTrue(np.array_equal(pose.body.data, pose_copy.body.data), "Copy's data should match original")
567573

@@ -599,7 +605,9 @@ def test_pose_torch_posebody_copy_creates_deepcopy(self):
599605
self.assertIsInstance(pose_copy.body, TorchPoseBody)
600606
self.assertIsInstance(pose_copy.body.data, TorchMaskedTensor)
601607

602-
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
608+
self.assertNotEqual(pose, pose_copy, "Copy of pose should not be 'equal' to original")
609+
self.assertNotEqual(pose.header, pose_copy.header, "headers should be new objects as well")
610+
self.assertNotEqual(pose.header.components, pose_copy.header.components, "components should be new objects as well")
603611
self.assertTrue(pose.body.data.tensor.equal(pose_copy.body.data.tensor), "Copy's data should match original")
604612
self.assertTrue(pose.body.data.mask.equal(pose_copy.body.data.mask), "Copy's mask should match original")
605613

0 commit comments

Comments
 (0)