Skip to content

Commit aecafa5

Browse files
committed
remove copy_pose in favor of pose.copy() from sign-language-processing/pose#148
1 parent 457c004 commit aecafa5

File tree

3 files changed

+60
-15
lines changed

3 files changed

+60
-15
lines changed
+52-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,55 @@
1+
from typing import Iterable, List, cast, Union
12
from pose_format import Pose
3+
from pose_evaluation.metrics.pose_processors import PoseProcessor
24

3-
from pose_evaluation.metrics.base import BaseMetric
5+
from pose_evaluation.metrics.base import BaseMetric, Signature
46

5-
PoseMetric = BaseMetric[Pose]
7+
8+
class PoseMetricSignature(Signature):
9+
10+
def __init__(self, args: dict):
11+
super().__init__(args)
12+
13+
self._abbreviated.update({"pose_preprocessers": "pre"})
14+
15+
pose_preprocessors = args.get("pose_preprocessers", None)
16+
prep_string = ""
17+
if pose_preprocessors is not None:
18+
prep_string = (
19+
"{" + "|".join([f"{prep}" for prep in pose_preprocessors]) + "}"
20+
)
21+
22+
self.signature_info.update(
23+
{"pose_preprocessers": prep_string if pose_preprocessors else None}
24+
)
25+
26+
27+
class PoseMetric(BaseMetric[Pose]):
28+
_SIGNATURE_TYPE = PoseMetricSignature
29+
30+
def __init__(
31+
self,
32+
name: str = "PoseMetric",
33+
higher_is_better: bool = False,
34+
pose_preprocessors: Union[None, List[PoseProcessor]] = None,
35+
):
36+
37+
super().__init__(name, higher_is_better)
38+
if pose_preprocessors is None:
39+
self.pose_preprocessers = []
40+
else:
41+
self.pose_preprocessers = pose_preprocessors
42+
43+
def score(self, hypothesis: Pose, reference: Pose) -> float:
44+
hypothesis, reference = self.process_poses([hypothesis, reference])
45+
return self.score(hypothesis, reference)
46+
47+
def process_poses(self, poses: Iterable[Pose]) -> List[Pose]:
48+
poses = list(poses)
49+
for preprocessor in self.pose_preprocessers:
50+
preprocessor = cast(PoseProcessor, preprocessor)
51+
poses = preprocessor.process_poses(poses)
52+
return poses
53+
54+
def add_preprocessor(self, processor: PoseProcessor):
55+
self.pose_preprocessers.append(processor)

pose_evaluation/utils/pose_utils.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,6 @@ def pose_remove_legs(pose: Pose) -> Pose:
7171
return pose
7272

7373

74-
# TODO: remove, once https://github.com/sign-language-processing/pose/pull/148 is added to pip version
75-
def copy_pose(pose: Pose) -> Pose:
76-
return pose.get_components([component.name for component in pose.header.components])
77-
78-
7974
def get_face_and_hands_from_pose(pose: Pose) -> Pose:
8075
# based on MediaPipe Holistic format.
8176
components_to_keep = [
@@ -96,7 +91,7 @@ def load_pose_file(pose_path: Path) -> Pose:
9691
def reduce_pose_components_and_points_to_intersection(
9792
poses: Iterable[Pose],
9893
) -> List[Pose]:
99-
poses = [copy_pose(pose) for pose in poses]
94+
poses = [pose.copy() for pose in poses]
10095
component_names_for_each_pose = []
10196
point_dict_for_each_pose = []
10297
for pose in poses:
@@ -130,7 +125,7 @@ def reduce_pose_components_and_points_to_intersection(
130125

131126

132127
def zero_pad_shorter_poses(poses: Iterable[Pose]) -> List[Pose]:
133-
poses = [copy_pose(pose) for pose in poses]
128+
poses = [pose.copy() for pose in poses]
134129
# arrays = [pose.body.data for pose in poses]
135130

136131
# first dimension is frames. Then People, joint-points, XYZ or XY
@@ -150,7 +145,7 @@ def zero_pad_shorter_poses(poses: Iterable[Pose]) -> List[Pose]:
150145

151146

152147
def set_masked_to_origin_position(pose: Pose) -> Pose:
153-
pose = copy_pose(pose)
148+
pose = pose.copy()
154149
# frames, person, keypoint, xyz
155150
data_copy = ma.copy(pose.body.data)
156151
data_copy[data_copy.mask]=0

pose_evaluation/utils/test_pose_utils.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_remove_specific_landmarks_mediapipe(
7474

7575
def test_pose_copy(mediapipe_poses_test_data: List[Pose]):
7676
for pose in mediapipe_poses_test_data:
77-
copy = copy_pose(pose)
77+
copy = pose.copy()
7878

7979
assert copy != pose # Not the same object
8080
assert (
@@ -117,7 +117,7 @@ def test_reduce_pose_components_to_intersection(
117117
standard_mediapipe_components_dict: Dict[str, List[str]],
118118
):
119119

120-
test_poses_with_one_reduced = [copy_pose(pose) for pose in mediapipe_poses_test_data]
120+
test_poses_with_one_reduced = [pose.copy() for pose in mediapipe_poses_test_data]
121121

122122
pose_with_only_face_and_hands_and_no_wrist = get_face_and_hands_from_pose(
123123
test_poses_with_one_reduced.pop()
@@ -226,7 +226,7 @@ def test_detect_format(
226226

227227
def test_set_masked_to_origin_pos(mediapipe_poses_test_data: List[Pose]):
228228
# Create a copy of the original poses for comparison
229-
originals = [copy_pose(pose) for pose in mediapipe_poses_test_data]
229+
originals = [pose.copy() for pose in mediapipe_poses_test_data]
230230

231231
# Apply the transformation
232232
poses = [set_masked_to_origin_position(pose) for pose in mediapipe_poses_test_data]
@@ -252,15 +252,15 @@ def test_set_masked_to_origin_pos(mediapipe_poses_test_data: List[Pose]):
252252

253253

254254
def test_hide_low_conf(mediapipe_poses_test_data: List[Pose]):
255-
copies = [copy_pose(pose) for pose in mediapipe_poses_test_data]
255+
copies = [pose.copy() for pose in mediapipe_poses_test_data]
256256
for pose, copy in zip(mediapipe_poses_test_data, copies):
257257
pose_hide_low_conf(pose, 1.0)
258258

259259
assert np.array_equal(pose.body.confidence, copy.body.confidence) is False
260260

261261

262262
def test_zero_pad_shorter_poses(mediapipe_poses_test_data: List[Pose]):
263-
copies = [copy_pose(pose) for pose in mediapipe_poses_test_data]
263+
copies = [pose.copy() for pose in mediapipe_poses_test_data]
264264

265265
max_len = max(len(pose.body.data) for pose in mediapipe_poses_test_data)
266266
padded_poses = zero_pad_shorter_poses(mediapipe_poses_test_data)

0 commit comments

Comments
 (0)