Skip to content

Commit a5e99c3

Browse files
authored
Trim preprocessor (#25)
* add "needs-trim" test file for testing trim_pose * add spoken-to-signed to reqs for trim_pose * add trim_pose PoseProcessor * test trim_pose PoseProcessor * pylint changes * fixing pytests to deal with new test file * run black * Turns out detect_known_format crashes when given a pose with no components
1 parent 1e6d9f7 commit a5e99c3

12 files changed

+71
-5
lines changed

pose_evaluation/metrics/conftest.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ def _check_range(
5757

5858
@pytest.fixture
5959
def real_pose_files() -> List[Pose]:
60-
test_files_folder = Path("pose_evaluation") / "utils" / "test" / "test_data"
60+
# pose_evaluation/utils/test/test_data/standard_landmarks
61+
test_files_folder = Path("pose_evaluation") / "utils" / "test" / "test_data" / "mediapipe" / "standard_landmarks"
6162
real_pose_files_list = [Pose.read(test_file.read_bytes()) for test_file in test_files_folder.glob("*.pose")]
6263
return real_pose_files_list
64+
65+
66+
@pytest.fixture
67+
def real_refined_landmark_pose_file_paths() -> List[Path]:
68+
test_files_folder = Path("pose_evaluation") / "utils" / "test" / "test_data" / "mediapipe" / "refined_landmarks"
69+
return list(test_files_folder.glob("*.pose"))

pose_evaluation/metrics/pose_processors.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
from tqdm import tqdm
44

55
from pose_format import Pose
6-
from pose_format.utils.generic import pose_hide_legs, reduce_holistic
6+
from pose_format.utils.generic import pose_hide_legs, reduce_holistic, detect_known_pose_format
7+
from spoken_to_signed.gloss_to_pose.concatenate import trim_pose
8+
79
from pose_evaluation.metrics.base import Signature
810
from pose_evaluation.utils.pose_utils import (
911
zero_pad_shorter_poses,
1012
reduce_poses_to_intersection,
1113
)
1214

15+
1316
PosesTransformerFunctionType = Callable[[Iterable[Pose]], List[Pose]]
1417

1518

@@ -117,7 +120,22 @@ def process_pose(self, pose: Pose) -> Pose:
117120
return pose
118121

119122

123+
class TrimMeaninglessFramesPoseProcessor(PoseProcessor):
124+
def __init__(self, start=True, end=True) -> None:
125+
super().__init__(name="trim_pose")
126+
self.start = start
127+
self.end = end
128+
129+
def process_pose(self, pose):
130+
if detect_known_pose_format(pose) == "holistic":
131+
132+
return trim_pose(pose.copy(), start=self.start, end=self.end)
133+
# not supported
134+
return pose
135+
136+
120137
def get_standard_pose_processors( # pylint: disable=too-many-arguments,too-many-positional-arguments
138+
trim_meaningless_frames: bool = True,
121139
normalize_poses: bool = True,
122140
reduce_poses_to_common_components: bool = True,
123141
remove_world_landmarks=True,
@@ -128,6 +146,12 @@ def get_standard_pose_processors( # pylint: disable=too-many-arguments,too-many
128146
) -> List[PoseProcessor]:
129147
pose_processors = []
130148

149+
# remove leading/trailing frames with no hands in frame.
150+
if trim_meaningless_frames:
151+
pose_processors.append(TrimMeaninglessFramesPoseProcessor())
152+
153+
# Note: by default this uses the shoulder joints,
154+
# so it should be BEFORE anything that might remove those, such as reduce poses to common components
131155
if normalize_poses:
132156
pose_processors.append(NormalizePosesProcessor())
133157

pose_evaluation/metrics/test_distance_metric.py

+2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def setUp(self):
7070
distance_measure=AggregatedPowerDistance(order=1, default_distance=0),
7171
# preprocessors that won't crash
7272
pose_preprocessors=get_standard_pose_processors(
73+
trim_meaningless_frames=False, # fake poses have no components, this crashes.
7374
normalize_poses=False,
7475
remove_world_landmarks=False,
7576
remove_legs=False,
@@ -99,6 +100,7 @@ def setUp(self):
99100
distance_measure=AggregatedPowerDistance(order=2, default_distance=self.default_distance),
100101
# preprocessors that won't crash
101102
pose_preprocessors=get_standard_pose_processors(
103+
trim_meaningless_frames=False, # fake poses don't have the necessary components
102104
normalize_poses=False,
103105
remove_world_landmarks=False,
104106
remove_legs=False,

pose_evaluation/metrics/test_dtw_metric.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def setUp(self):
1212
name="DTWPowerDistance",
1313
distance_measure=distance_measure,
1414
pose_preprocessors=get_standard_pose_processors(
15+
trim_meaningless_frames=False, # fake poses have no components, this crashes.
1516
normalize_poses=False, # no shoulders, will crash
1617
remove_world_landmarks=False, # there are none, will crash
1718
reduce_poses_to_common_components=False, # removes all components, there are none in common
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import List
2+
from pathlib import Path
3+
from pose_format import Pose
4+
5+
from pose_evaluation.metrics.pose_processors import TrimMeaninglessFramesPoseProcessor
6+
7+
8+
def test_trim_pose(real_refined_landmark_pose_file_paths: List[Path]):
9+
10+
for pose_path in real_refined_landmark_pose_file_paths:
11+
if "needs-trim" in pose_path.name:
12+
pose = Pose.read(pose_path.read_bytes())
13+
14+
original_frame_count = pose.body.data.shape[0]
15+
16+
processor = TrimMeaninglessFramesPoseProcessor(start=True, end=True)
17+
18+
processed_pose = processor.process_pose(pose)
19+
20+
# not expecting it to edit the original
21+
assert (
22+
pose.body.data.shape[0] == original_frame_count
23+
), f"Original data changed! Frames before: {original_frame_count}. Now: {pose.body.data.shape[0]}"
24+
25+
# should have fewer frames
26+
assert (
27+
processed_pose.body.data.shape[0] < pose.body.data.shape[0]
28+
), f"{pose_path}, {pose.body}, {processed_pose.body}"

pose_evaluation/utils/conftest.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
from pose_evaluation.utils.pose_utils import load_pose_file
1414

1515

16-
utils_test_data_dir = Path(__file__).parent / "test" / "test_data"
16+
utils_standard_mediapipe_landmarks_test_data_dir = (
17+
Path(__file__).parent / "test" / "test_data" / "mediapipe" / "standard_landmarks"
18+
)
1719

1820

1921
@pytest.fixture(scope="function")
2022
def mediapipe_poses_test_data_paths() -> List[Path]:
21-
pose_file_paths = list(utils_test_data_dir.glob("*.pose"))
23+
pose_file_paths = list(utils_standard_mediapipe_landmarks_test_data_dir.glob("*.pose"))
2224
return pose_file_paths
2325

2426

@@ -33,7 +35,7 @@ def mediapipe_poses_test_data(mediapipe_poses_test_data_paths) -> List[Pose]: #
3335

3436
@pytest.fixture
3537
def standard_mediapipe_components_dict() -> Dict[str, List[str]]:
36-
format_json = utils_test_data_dir / "mediapipe_components_and_points.json"
38+
format_json = utils_standard_mediapipe_landmarks_test_data_dir / "mediapipe_components_and_points.json"
3739
with open(format_json, "r", encoding="utf-8") as f:
3840
return json.load(f)
3941

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ dependencies = [
2222
"fastdtw",
2323
# alternative to fastdtw
2424
"dtaidistance",
25+
# so that we can have the "trim_pose" preprocessor
26+
"spoken-to-signed @ git+https://github.com/ZurichNLP/spoken-to-signed-translation.git",
2527
]
2628

2729
[project.optional-dependencies]

0 commit comments

Comments
 (0)