@@ -55,6 +55,20 @@ def _create_pose_header_component(name: str, num_keypoints: int) -> PoseHeaderCo
55
55
return component
56
56
57
57
58
+ def _distribute_points_among_components (component_count : int , total_keypoint_count : int ):
59
+ if component_count <= 0 or total_keypoint_count < component_count + 1 :
60
+ raise ValueError ("Total keypoints must be at least component count+1 (so that 0 can have two), and component count must be positive" )
61
+
62
+ # Step 1: Initialize with required minimum values
63
+ keypoint_counts = [2 ] + [1 ] * (component_count - 1 ) # Ensure first is 2, others at least 1
64
+
65
+ # Step 2: Distribute remaining points
66
+ remaining_points = total_keypoint_count - sum (keypoint_counts )
67
+ for _ in range (remaining_points ):
68
+ keypoint_counts [random .randint (0 , component_count - 1 )] += 1 # Add randomly
69
+
70
+ return keypoint_counts
71
+
58
72
def _create_pose_header (width : int , height : int , depth : int , num_components : int , num_keypoints : int ) -> PoseHeader :
59
73
"""
60
74
Create a PoseHeader with given dimensions and components.
@@ -79,8 +93,10 @@ def _create_pose_header(width: int, height: int, depth: int, num_components: int
79
93
"""
80
94
dimensions = PoseHeaderDimensions (width = width , height = height , depth = depth )
81
95
96
+ keypoints_per_component = _distribute_points_among_components (num_components , num_keypoints )
97
+
82
98
components = [
83
- _create_pose_header_component (name = str (index ), num_keypoints = num_keypoints ) for index in range (num_components )
99
+ _create_pose_header_component (name = str (index ), num_keypoints = keypoints_per_component [ index ] ) for index in range (num_components )
84
100
]
85
101
86
102
header = PoseHeader (version = 1.0 , dimensions = dimensions , components = components )
@@ -134,6 +150,8 @@ def _create_random_tensorflow_data(frames_min: Optional[int] = None,
134
150
return tensor , mask , confidence
135
151
136
152
153
+
154
+
137
155
def _create_random_numpy_data (frames_min : Optional [int ] = None ,
138
156
frames_max : Optional [int ] = None ,
139
157
num_frames : Optional [int ] = None ,
@@ -286,7 +304,7 @@ def _get_random_pose_object_with_tf_posebody(num_keypoints: int, frames_min: int
286
304
return Pose (header = header , body = body )
287
305
288
306
289
- def _get_random_pose_object_with_numpy_posebody (num_keypoints : int , frames_min : int = 1 , frames_max : int = 10 ) -> Pose :
307
+ def _get_random_pose_object_with_numpy_posebody (num_keypoints : int , frames_min : int = 1 , frames_max : int = 10 , num_components = 3 ) -> Pose :
290
308
"""
291
309
Creates a random Pose object with Numpy pose body for testing.
292
310
@@ -313,7 +331,7 @@ def _get_random_pose_object_with_numpy_posebody(num_keypoints: int, frames_min:
313
331
314
332
body = NumPyPoseBody (fps = 10 , data = masked_array , confidence = confidence )
315
333
316
- header = _create_pose_header (width = 10 , height = 7 , depth = 0 , num_components = 3 , num_keypoints = num_keypoints )
334
+ header = _create_pose_header (width = 10 , height = 7 , depth = 0 , num_components = num_components , num_keypoints = num_keypoints )
317
335
318
336
return Pose (header = header , body = body )
319
337
@@ -329,6 +347,96 @@ def test_pose_object_should_be_callable(self):
329
347
"""
330
348
assert callable (Pose )
331
349
350
+ def test_pose_remove_components (self ):
351
+ pose = _get_random_pose_object_with_numpy_posebody (num_keypoints = 5 )
352
+ assert pose .body .data .shape [- 2 ] == 5
353
+ assert pose .body .data .shape [- 1 ] == 2 # XY dimensions
354
+
355
+ self .assertEqual (len (pose .header .components ), 3 )
356
+ self .assertEqual (sum (len (c .points ) for c in pose .header .components ), 5 )
357
+ self .assertEqual (pose .header .components [0 ].name , "0" )
358
+ self .assertEqual (pose .header .components [1 ].name , "1" )
359
+ self .assertEqual (pose .header .components [0 ].points [0 ], "0_a" )
360
+ self .assertIn ("1_a" , pose .header .components [1 ].points )
361
+ self .assertNotIn ("1_f" , pose .header .components [1 ].points )
362
+ self .assertNotIn ("4" , pose .header .components )
363
+
364
+ # test that we can remove a component
365
+ component_to_remove = "0"
366
+ pose_copy = pose .copy ()
367
+ self .assertIn (component_to_remove , [c .name for c in pose_copy .header .components ])
368
+ pose_copy = pose_copy .remove_components (component_to_remove )
369
+ self .assertNotIn (component_to_remove , [c .name for c in pose_copy .header .components ])
370
+
371
+
372
+ # Remove a point only
373
+ point_to_remove = "0_a"
374
+ pose_copy = pose .copy ()
375
+ self .assertIn (point_to_remove , pose_copy .header .components [0 ].points )
376
+ pose_copy = pose_copy .remove_components ([], {point_to_remove [0 ]:[point_to_remove ]})
377
+ self .assertNotIn (point_to_remove , pose_copy .header .components [0 ].points )
378
+
379
+
380
+ # Can we remove two things at once
381
+ component_to_remove = "1"
382
+ point_to_remove = "2_a"
383
+ component_to_remove_point_from = "2"
384
+
385
+ self .assertIn (component_to_remove , [c .name for c in pose_copy .header .components ])
386
+ self .assertIn (component_to_remove_point_from , [c .name for c in pose_copy .header .components ])
387
+ self .assertIn (point_to_remove , pose_copy .header .components [2 ].points )
388
+ pose_copy = pose_copy .remove_components ([component_to_remove ], {component_to_remove_point_from :[point_to_remove ]})
389
+ self .assertNotIn (component_to_remove , [c .name for c in pose_copy .header .components ])
390
+ self .assertIn (component_to_remove_point_from , [c .name for c in pose_copy .header .components ]) # this should still be around
391
+
392
+ # can we remove a component and a point FROM that component without crashing
393
+ component_to_remove = "0"
394
+ point_to_remove = "0_a"
395
+ pose_copy = pose .copy ()
396
+ self .assertIn (point_to_remove , pose_copy .header .components [0 ].points )
397
+ pose_copy = pose_copy .remove_components ([component_to_remove ], {component_to_remove :[point_to_remove ]})
398
+ self .assertNotIn (component_to_remove , [c .name for c in pose_copy .header .components ])
399
+ self .assertNotIn (point_to_remove , pose_copy .header .components [0 ].points )
400
+
401
+
402
+ # can we "remove" a component that doesn't exist without crashing
403
+ component_to_remove = "NOT EXISTING"
404
+ pose_copy = pose .copy ()
405
+ initial_count = len (pose_copy .header .components )
406
+ pose_copy = pose_copy .remove_components ([component_to_remove ])
407
+ self .assertEqual (initial_count , len (pose_copy .header .components ))
408
+
409
+
410
+
411
+
412
+ # can we "remove" a point that doesn't exist from a component that does without crashing
413
+ point_to_remove = "2_x"
414
+ component_to_remove_point_from = "2"
415
+ pose_copy = pose .copy ()
416
+ self .assertNotIn (point_to_remove , pose_copy .header .components [2 ].points )
417
+ pose_copy = pose_copy .remove_components ([], {component_to_remove_point_from :[point_to_remove ]})
418
+ self .assertNotIn (point_to_remove , pose_copy .header .components [2 ].points )
419
+
420
+
421
+ # can we "remove" an empty list of points
422
+ component_to_remove_point_from = "2"
423
+ pose_copy = pose .copy ()
424
+ initial_component_count = len (pose_copy .header .components )
425
+ initial_point_count = len (pose_copy .header .components [2 ].points )
426
+ pose_copy = pose_copy .remove_components ([], {component_to_remove_point_from :[]})
427
+ self .assertEqual (initial_component_count , len (pose_copy .header .components ))
428
+ self .assertEqual (len (pose_copy .header .components [2 ].points ), initial_point_count )
429
+
430
+
431
+ # can we remove a point from a component that doesn't exist
432
+ point_to_remove = "2_x"
433
+ component_to_remove_point_from = "NOT EXISTING"
434
+ pose_copy = pose .copy ()
435
+ self .assertNotIn (point_to_remove , pose_copy .header .components [2 ].points )
436
+ pose_copy = pose_copy .remove_components ([], {component_to_remove_point_from :[point_to_remove ]})
437
+ self .assertNotIn (point_to_remove , pose_copy .header .components [2 ].points )
438
+
439
+
332
440
333
441
334
442
class TestPoseTensorflowPoseBody (TestCase ):
@@ -475,7 +583,7 @@ def create_pose_and_frame_dropout_uniform(example: tf.Tensor) -> tf.Tensor:
475
583
return example
476
584
477
585
dataset .map (create_pose_and_frame_dropout_uniform )
478
-
586
+
479
587
480
588
def test_pose_tf_posebody_copy_creates_deepcopy (self ):
481
589
pose = _get_random_pose_object_with_tf_posebody (num_keypoints = 5 )
@@ -488,7 +596,7 @@ def test_pose_tf_posebody_copy_creates_deepcopy(self):
488
596
489
597
# Check that pose and pose_copy are not the same object
490
598
self .assertNotEqual (pose , pose_copy , "Copy of pose should not be 'equal' to original" )
491
-
599
+
492
600
# Ensure the data tensors are equal but independent
493
601
self .assertTrue (tf .reduce_all (pose .body .data == pose_copy .body .data ), "Copy's data should match original" )
494
602
@@ -515,6 +623,14 @@ class TestPoseNumpyPoseBody(TestCase):
515
623
Testcases for Pose objects containing NumPy PoseBody data.
516
624
"""
517
625
626
+ def test_pose_numpy_generated_with_correct_shape (self ):
627
+ pose = _get_random_pose_object_with_numpy_posebody (num_keypoints = 5 , frames_min = 3 )
628
+
629
+ # does the header match the body?
630
+ expected_keypoints_count_from_header = sum (len (c .points ) for c in pose .header .components )
631
+ self .assertEqual (expected_keypoints_count_from_header , pose .body .data .shape [- 2 ])
632
+
633
+
518
634
def test_pose_numpy_posebody_normalize_preserves_shape (self ):
519
635
"""
520
636
Tests if the normalization of Pose object with NumPy PoseBody preserves array shape.
@@ -593,17 +709,16 @@ def test_pose_torch_posebody_copy_creates_deepcopy(self):
593
709
pose = _get_random_pose_object_with_torch_posebody (num_keypoints = 5 )
594
710
self .assertIsInstance (pose .body , TorchPoseBody )
595
711
self .assertIsInstance (pose .body .data , TorchMaskedTensor )
596
-
597
712
598
713
pose_copy = pose .copy ()
599
714
self .assertIsInstance (pose_copy .body , TorchPoseBody )
600
715
self .assertIsInstance (pose_copy .body .data , TorchMaskedTensor )
601
716
602
- self .assertNotEqual (pose , pose_copy , "Copy of pose should not be 'equal' to original" )
717
+ self .assertNotEqual (pose , pose_copy , "Copy of pose should not be 'equal' to original" )
603
718
self .assertTrue (pose .body .data .tensor .equal (pose_copy .body .data .tensor ), "Copy's data should match original" )
604
719
self .assertTrue (pose .body .data .mask .equal (pose_copy .body .data .mask ), "Copy's mask should match original" )
605
720
606
- pose .body .data = TorchMaskedTensor (tensor = torch .zeros_like (pose .body .data .tensor ),
721
+ pose .body .data = TorchMaskedTensor (tensor = torch .zeros_like (pose .body .data .tensor ),
607
722
mask = torch .ones_like (pose .body .data .mask ))
608
723
609
724
0 commit comments