Skip to content

Commit 3a34a7a

Browse files
authored
Pose remove legs (#154)
* CDL: minor doc typo fix * Undoing some changes that got mixed in * deepcopy for header and components also * pose_remove_legs utility for removing the leg points from mediapipe and openpose * PR requested changes: removed copy() functions, added get_index * combine hide_legs and remove_legs, and fix error * Remove unintended file * update hide_legs to not crash when trying to hide invalid points * Add some more tests * pose_header.get_point_indexRemove try/catch * iterate over points_to_remove_dict in pose_hide_legs * Adding a few more test updates * a few style fixes for generic utils
1 parent 7cdcdf4 commit 3a34a7a

File tree

6 files changed

+193
-67
lines changed

6 files changed

+193
-67
lines changed

src/python/pose_format/pose_header.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import hashlib
22
import math
33
import struct
4-
from typing import BinaryIO, List, Tuple
4+
from typing import BinaryIO, List, Tuple, Optional, Union
55

66
from .utils.reader import BufferReader, ConstStructs
77

@@ -21,7 +21,7 @@ class PoseNormalizationInfo:
2121
Third pose value. Defaults to None.
2222
"""
2323

24-
def __init__(self, p1: int, p2: int, p3: int = None):
24+
def __init__(self, p1: int, p2: int, p3: Optional[int] = None):
2525
"""Initialize a PoseNormalizationInfo instance."""
2626
self.p1 = p1
2727
self.p2 = p2
@@ -66,7 +66,7 @@ def __init__(self, name: str, points: List[str], limbs: List[Tuple[int, int]], c
6666
self.relative_limbs = self.get_relative_limbs()
6767

6868
@staticmethod
69-
def read(version: float, reader: BufferReader):
69+
def read(version: float, reader: BufferReader) -> 'PoseHeaderComponent':
7070
"""
7171
Reads pose header dimensions from reader (BufferReader).
7272
@@ -183,7 +183,7 @@ def __init__(self, width: int, height: int, depth: int = 0, *args):
183183
self.depth = math.ceil(depth)
184184

185185
@staticmethod
186-
def read(version: float, reader: BufferReader):
186+
def read(version: float, reader: BufferReader) -> 'PoseHeaderDimensions':
187187
"""
188188
Reads and returns a PoseHeaderDimensions object from a buffer reader.
189189
@@ -293,6 +293,7 @@ def __init__(self,
293293
self.components = components
294294
self.is_bbox = is_bbox
295295

296+
296297
@staticmethod
297298
def read(reader: BufferReader) -> 'PoseHeader':
298299
"""
@@ -376,9 +377,22 @@ def _get_point_index(self, component: str, point: str):
376377

377378
raise ValueError("Couldn't find component")
378379

380+
def get_point_index(self, component: str, point: str) -> int:
381+
"""
382+
Returns the index of a given point within the pose.
383+
384+
Args:
385+
component (str): The name of the component containing the point.
386+
point (str): The name of the point whose index is to be retrieved.
387+
388+
Raises:
389+
ValueError: If the specified component or point is not found.
390+
"""
391+
return self._get_point_index(component, point)
392+
379393
def normalization_info(self, p1: Tuple[str, str], p2: Tuple[str, str], p3: Tuple[str, str] = None):
380394
"""
381-
Normalizates info for given points.
395+
Normalization info for given points.
382396
383397
Parameters
384398
----------
@@ -394,9 +408,9 @@ def normalization_info(self, p1: Tuple[str, str], p2: Tuple[str, str], p3: Tuple
394408
PoseNormalizationInfo
395409
Normalization information for the points.
396410
"""
397-
return PoseNormalizationInfo(p1=self._get_point_index(*p1),
398-
p2=self._get_point_index(*p2),
399-
p3=None if p3 is None else self._get_point_index(*p3))
411+
return PoseNormalizationInfo(p1=self.get_point_index(*p1),
412+
p2=self.get_point_index(*p2),
413+
p3=None if p3 is None else self.get_point_index(*p3))
400414

401415
def bbox(self):
402416
"""

src/python/pose_format/utils/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ def fake_poses(request) -> List[Pose]:
2020
for i, pose in enumerate(fake_poses_list):
2121
for component in pose.header.components:
2222
component.name = f"unknown_component_{i}_formerly_{component.name}"
23-
return copy.deepcopy(fake_poses_list)
23+
return [pose.copy() for pose in fake_poses_list]

src/python/pose_format/utils/generic.py

+59-35
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from pathlib import Path
21
from typing import Tuple, Literal, List, Union
32
import copy
43
import numpy as np
5-
from numpy import ma
4+
import numpy.ma as ma
65
from pose_format.pose import Pose
76
from pose_format.numpy import NumPyPoseBody
87
from pose_format.pose_header import PoseHeader, PoseHeaderDimensions, PoseHeaderComponent, PoseNormalizationInfo
98
from pose_format.utils.normalization_3d import PoseNormalizer
109
from pose_format.utils.openpose import OpenPose_Components
10+
from pose_format.utils.openpose import BODY_POINTS as OPENPOSE_BODY_POINTS
1111
from pose_format.utils.openpose_135 import OpenPose_Components as OpenPose135_Components
1212

1313
# from pose_format.utils.holistic import holistic_components
@@ -62,31 +62,55 @@ def normalize_pose_size(pose: Pose, target_width: int = 512):
6262
pose.header.dimensions.height = pose.header.dimensions.width = target_width
6363

6464

65-
def pose_hide_legs(pose: Pose):
65+
def pose_hide_legs(pose: Pose, remove: bool = False) -> Pose:
66+
"""
67+
Hide or remove leg components from a pose.
68+
69+
If `remove` is True, the leg components are removed; otherwise, they are hidden (zeroed out).
70+
"""
6671
known_pose_format = detect_known_pose_format(pose)
72+
6773
if known_pose_format == "holistic":
68-
point_names = ["KNEE", "ANKLE", "HEEL", "FOOT_INDEX"]
69-
# pylint: disable=protected-access
70-
points = [
71-
pose.header._get_point_index("POSE_LANDMARKS", side + "_" + n)
72-
for n in point_names
73-
for side in ["LEFT", "RIGHT"]
74-
]
75-
pose.body.data[:, :, points, :] = 0
76-
pose.body.confidence[:, :, points] = 0
74+
point_names = ["KNEE", "ANKLE", "HEEL", "FOOT_INDEX", "HIP"]
75+
sides = ["LEFT", "RIGHT"]
76+
point_names_to_remove = [f"{side}_{name}" for side in sides for name in point_names]
77+
points_to_remove_dict = {
78+
"POSE_LANDMARKS": point_names_to_remove,
79+
"POSE_WORLD_LANDMARKS": point_names_to_remove,
80+
}
81+
7782
elif known_pose_format == "openpose":
78-
point_names = ["Hip", "Knee", "Ankle", "BigToe", "SmallToe", "Heel"]
79-
# pylint: disable=protected-access
80-
points = [
81-
pose.header._get_point_index("pose_keypoints_2d", side + n) for n in point_names for side in ["L", "R"]
82-
]
83-
pose.body.data[:, :, points, :] = 0
84-
pose.body.confidence[:, :, points] = 0
83+
words_to_look_for = ["Hip", "Knee", "Ankle", "BigToe", "SmallToe", "Heel"]
84+
point_names_to_remove = [point for point in OPENPOSE_BODY_POINTS
85+
if any(word in point for word in words_to_look_for)]
86+
87+
# if any of the items in point_
88+
points_to_remove_dict = {"pose_keypoints_2d": point_names_to_remove}
89+
8590
else:
8691
raise NotImplementedError(
8792
f"Unsupported pose header schema {known_pose_format} for {pose_hide_legs.__name__}: {pose.header}"
8893
)
8994

95+
if remove:
96+
return pose.remove_components([], points_to_remove_dict)
97+
98+
# Hide the points instead of removing them
99+
point_indices = []
100+
for component, points in points_to_remove_dict.items():
101+
for point_name in points:
102+
try:
103+
point_index = pose.header.get_point_index(component, point_name)
104+
point_indices.append(point_index)
105+
except ValueError: # point not found, maybe removed earlier in other preprocessing steps
106+
pass
107+
108+
109+
pose.body.data[:, :, point_indices, :] = 0
110+
pose.body.confidence[:, :, point_indices] = 0
111+
112+
return pose
113+
90114

91115
def pose_shoulders(pose_header: PoseHeader) -> Tuple[Tuple[str, str], Tuple[str, str]]:
92116
known_pose_format = detect_known_pose_format(pose_header)
@@ -109,14 +133,14 @@ def hands_indexes(pose_header: PoseHeader)-> List[int]:
109133
known_pose_format = detect_known_pose_format(pose_header)
110134
if known_pose_format == "holistic":
111135
return [
112-
pose_header._get_point_index("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"),
113-
pose_header._get_point_index("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"),
136+
pose_header.get_point_index("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"),
137+
pose_header.get_point_index("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"),
114138
]
115139

116140
if known_pose_format == "openpose":
117141
return [
118-
pose_header._get_point_index("hand_left_keypoints_2d", "M_CMC"),
119-
pose_header._get_point_index("hand_right_keypoints_2d", "M_CMC"),
142+
pose_header.get_point_index("hand_left_keypoints_2d", "M_CMC"),
143+
pose_header.get_point_index("hand_right_keypoints_2d", "M_CMC"),
120144
]
121145
raise NotImplementedError(
122146
f"Unsupported pose header schema {known_pose_format} for {hands_indexes.__name__}: {pose_header}"
@@ -148,12 +172,12 @@ def hands_components(pose_header: PoseHeader)-> Tuple[Tuple[str, str], Tuple[str
148172
def normalize_component_3d(pose, component_name: str, plane: Tuple[str, str, str], line: Tuple[str, str]):
149173
hand_pose = pose.get_components([component_name])
150174
plane_info = hand_pose.header.normalization_info(
151-
p1=(component_name, plane[0]),
152-
p2=(component_name, plane[1]),
175+
p1=(component_name, plane[0]),
176+
p2=(component_name, plane[1]),
153177
p3=(component_name, plane[2])
154178
)
155179
line_info = hand_pose.header.normalization_info(
156-
p1=(component_name, line[0]),
180+
p1=(component_name, line[0]),
157181
p2=(component_name, line[1])
158182
)
159183

@@ -176,10 +200,11 @@ def normalize_hands_3d(pose: Pose, left_hand=True, right_hand=True):
176200
def get_standard_components_for_known_format(known_pose_format: KnownPoseFormat) -> List[PoseHeaderComponent]:
177201
if known_pose_format == "holistic":
178202
try:
203+
# pylint: disable=import-outside-toplevel
179204
import pose_format.utils.holistic as holistic_utils
180205
return holistic_utils.holistic_components()
181206
except ImportError as e:
182-
raise e
207+
raise e
183208
if known_pose_format == "openpose":
184209
return OpenPose_Components
185210
if known_pose_format == "openpose_135":
@@ -191,7 +216,7 @@ def get_standard_components_for_known_format(known_pose_format: KnownPoseFormat)
191216
def fake_pose(num_frames: int, fps: int=25, components: Union[List[PoseHeaderComponent],None]=None)->Pose:
192217
if components is None:
193218
components = copy.deepcopy(OpenPose_Components) # fixes W0102, dangerous default value
194-
219+
195220
if components[0].format == "XYZC":
196221
dimensions = PoseHeaderDimensions(width=1, height=1, depth=1)
197222
elif components[0].format == "XYC":
@@ -204,7 +229,6 @@ def fake_pose(num_frames: int, fps: int=25, components: Union[List[PoseHeaderCom
204229
data = np.random.randn(num_frames, 1, total_points, header.num_dims())
205230
confidence = np.random.randn(num_frames, 1, total_points)
206231
masked_data = ma.masked_array(data)
207-
208232

209233
body = NumPyPoseBody(fps=int(fps), data=masked_data, confidence=confidence)
210234

@@ -214,9 +238,9 @@ def fake_pose(num_frames: int, fps: int=25, components: Union[List[PoseHeaderCom
214238
def get_hand_wrist_index(pose: Pose, hand: str)-> int:
215239
known_pose_format = detect_known_pose_format(pose)
216240
if known_pose_format == "holistic":
217-
return pose.header._get_point_index(f"{hand.upper()}_HAND_LANDMARKS", "WRIST")
241+
return pose.header.get_point_index(f"{hand.upper()}_HAND_LANDMARKS", "WRIST")
218242
if known_pose_format == "openpose":
219-
return pose.header._get_point_index(f"hand_{hand.lower()}_keypoints_2d", "BASE")
243+
return pose.header.get_point_index(f"hand_{hand.lower()}_keypoints_2d", "BASE")
220244
raise NotImplementedError(
221245
f"Unsupported pose header schema {known_pose_format} for {get_hand_wrist_index.__name__}: {pose.header}"
222246
)
@@ -225,9 +249,9 @@ def get_hand_wrist_index(pose: Pose, hand: str)-> int:
225249
def get_body_hand_wrist_index(pose: Pose, hand: str)-> int:
226250
known_pose_format = detect_known_pose_format(pose)
227251
if known_pose_format == "holistic":
228-
return pose.header._get_point_index("POSE_LANDMARKS", f"{hand.upper()}_WRIST")
252+
return pose.header.get_point_index("POSE_LANDMARKS", f"{hand.upper()}_WRIST")
229253
if known_pose_format == "openpose":
230-
return pose.header._get_point_index("pose_keypoints_2d", f"{hand.upper()[0]}Wrist")
254+
return pose.header.get_point_index("pose_keypoints_2d", f"{hand.upper()[0]}Wrist")
231255
raise NotImplementedError(
232256
f"Unsupported pose header schema {known_pose_format} for {get_body_hand_wrist_index.__name__}: {pose.header}"
233257
)
@@ -244,7 +268,7 @@ def correct_wrist(pose: Pose, hand: str) -> Pose:
244268
body_wrist_conf = pose.body.confidence[:, :, body_wrist_index]
245269

246270
point_coordinate_count = wrist.shape[-1]
247-
stacked_conf = np.stack([wrist_conf] * point_coordinate_count, axis=-1)
271+
stacked_conf = np.stack([wrist_conf] * point_coordinate_count, axis=-1)
248272
new_wrist_data = ma.where(stacked_conf == 0, body_wrist, wrist)
249273
new_wrist_conf = ma.where(wrist_conf == 0, body_wrist_conf, wrist_conf)
250274

@@ -263,7 +287,7 @@ def reduce_holistic(pose: Pose) -> Pose:
263287
known_pose_format = detect_known_pose_format(pose)
264288
if known_pose_format != "holistic":
265289
return pose
266-
290+
# pylint: disable=pointless-string-statement
267291
"""
268292
# from mediapipe.python.solutions.face_mesh_connections import FACEMESH_CONTOURS
269293
# points_set = set([p for p_tup in list(FACEMESH_CONTOURS) for p in p_tup])

0 commit comments

Comments
 (0)