diff --git a/.gitignore b/.gitignore index 96ee987..4edea12 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,10 @@ .idea/ build/ pose_evaluation.egg-info/ -**/__pycache__/ \ No newline at end of file +**/__pycache__/ +.coverage +.vscode/ +coverage.lcov +**/test_data/ +*.npz +*.code-workspace \ No newline at end of file diff --git a/pose_evaluation/metrics/.gitignore b/pose_evaluation/metrics/.gitignore index cd78447..3598c30 100644 --- a/pose_evaluation/metrics/.gitignore +++ b/pose_evaluation/metrics/.gitignore @@ -1 +1 @@ -temp/ \ No newline at end of file +tests \ No newline at end of file diff --git a/pose_evaluation/metrics/aggregate_distances.py b/pose_evaluation/metrics/aggregate_distances.py new file mode 100644 index 0000000..c14807e --- /dev/null +++ b/pose_evaluation/metrics/aggregate_distances.py @@ -0,0 +1,46 @@ +from typing import Callable, Iterable, Literal +from numpy.ma import MaskedArray +from pose_evaluation.metrics.base import Signature, SignatureMixin + + +AggregationStrategy = Literal["sum", "mean", "max"] +DistanceAggregatorFunction = Callable[[Iterable[float]], float] + + +class DistanceAggregatorSignature(Signature): + def __init__(self, args: dict): + super().__init__(args) + self.update_signature_and_abbr("aggregation_strategy", "s", args) + + +class DistanceAggregator(SignatureMixin): + _SIGNATURE_TYPE = DistanceAggregatorSignature + + def __init__(self, aggregation_strategy: AggregationStrategy) -> None: + self.aggregator_function = get_aggregator_function( + strategy=aggregation_strategy + ) + self.aggregation_strategy = aggregation_strategy + + def __call__(self, distances: Iterable[float]) -> float: + return self.aggregator_function(distances) + + +def create_maskedarrray_and_cast_result_to_float( + callable_to_wrap: Callable, +) -> DistanceAggregatorFunction: + return lambda a: float(callable_to_wrap(MaskedArray(a))) + + +def get_aggregator_function( + strategy: AggregationStrategy, +) -> DistanceAggregatorFunction: + + if strategy == "max": + return create_maskedarrray_and_cast_result_to_float(MaskedArray.max) + + if strategy == "mean": + return create_maskedarrray_and_cast_result_to_float(MaskedArray.mean) + + if strategy == "sum": + return create_maskedarrray_and_cast_result_to_float(MaskedArray.sum) diff --git a/pose_evaluation/metrics/base.py b/pose_evaluation/metrics/base.py index 75f8803..3238e58 100644 --- a/pose_evaluation/metrics/base.py +++ b/pose_evaluation/metrics/base.py @@ -1,11 +1,76 @@ # pylint: disable=undefined-variable from tqdm import tqdm +from typing import Any, Callable +class Signature: + """Represents reproducibility signatures for metrics. Inspired by sacreBLEU + """ + def __init__(self, args: dict): + + self._abbreviated = { + "name":"n", + "higher_is_better":"hb" + } + + self.signature_info = { + "name": args.get("name", None), + "higher_is_better": args.get("higher_is_better", None) + } + + def update(self, key: str, value: Any): + self.signature_info[key] = value + + def update_signature_and_abbr(self, key:str, abbr:str, args:dict): + self._abbreviated.update({ + key: abbr + }) + + self.signature_info.update({ + key: args.get(key, None) + }) + + def format(self, short: bool = False) -> str: + pairs = [] + keys = list(self.signature_info.keys()) + for name in keys: + value = self.signature_info[name] + if value is not None: + # Check for nested signature objects + if hasattr(value, "get_signature"): + + # Wrap nested signatures in brackets + nested_signature = value.get_signature() + if isinstance(nested_signature, Signature): + nested_signature = nested_signature.format(short=short) + value = f"{{{nested_signature}}}" + if isinstance(value, bool): + # Replace True/False with yes/no + value = "yes" if value else "no" + if isinstance(value, Callable): + value = value.__name__ + final_name = self._abbreviated[name] if short else name + pairs.append(f"{final_name}:{value}") + + return "|".join(pairs) + + def __str__(self): + return self.format() + + def __repr__(self): + return self.format() + + +class SignatureMixin: + _SIGNATURE_TYPE = Signature + def get_signature(self) -> Signature: + return self._SIGNATURE_TYPE(self.__dict__) class BaseMetric[T]: """Base class for all metrics.""" + # Each metric should define its Signature class' name here + _SIGNATURE_TYPE = Signature - def __init__(self, name: str, higher_is_better: bool = True): + def __init__(self, name: str, higher_is_better: bool = False): self.name = name self.higher_is_better = higher_is_better @@ -38,3 +103,8 @@ def score_all(self, hypotheses: list[T], references: list[T], progress_bar=True) def __str__(self): return self.name + + def get_signature(self) -> Signature: + return self._SIGNATURE_TYPE(self.__dict__) + + diff --git a/pose_evaluation/metrics/base_pose_metric.py b/pose_evaluation/metrics/base_pose_metric.py index 903bf01..07aceb9 100644 --- a/pose_evaluation/metrics/base_pose_metric.py +++ b/pose_evaluation/metrics/base_pose_metric.py @@ -1,5 +1,55 @@ +from typing import Iterable, List, cast, Union from pose_format import Pose +from pose_evaluation.metrics.pose_processors import PoseProcessor -from pose_evaluation.metrics.base import BaseMetric +from pose_evaluation.metrics.base import BaseMetric, Signature -PoseMetric = BaseMetric[Pose] + +class PoseMetricSignature(Signature): + + def __init__(self, args: dict): + super().__init__(args) + + self._abbreviated.update({"pose_preprocessers": "pre"}) + + pose_preprocessors = args.get("pose_preprocessers", None) + prep_string = "" + if pose_preprocessors is not None: + prep_string = ( + "{" + "|".join([f"{prep}" for prep in pose_preprocessors]) + "}" + ) + + self.signature_info.update( + {"pose_preprocessers": prep_string if pose_preprocessors else None} + ) + + +class PoseMetric(BaseMetric[Pose]): + _SIGNATURE_TYPE = PoseMetricSignature + + def __init__( + self, + name: str = "PoseMetric", + higher_is_better: bool = False, + pose_preprocessors: Union[None, List[PoseProcessor]] = None, + ): + + super().__init__(name, higher_is_better) + if pose_preprocessors is None: + self.pose_preprocessers = [] + else: + self.pose_preprocessers = pose_preprocessors + + def score(self, hypothesis: Pose, reference: Pose) -> float: + hypothesis, reference = self.process_poses([hypothesis, reference]) + return self.score(hypothesis, reference) + + def process_poses(self, poses: Iterable[Pose]) -> List[Pose]: + poses = list(poses) + for preprocessor in self.pose_preprocessers: + preprocessor = cast(PoseProcessor, preprocessor) + poses = preprocessor.process_poses(poses) + return poses + + def add_preprocessor(self, processor: PoseProcessor): + self.pose_preprocessers.append(processor) diff --git a/pose_evaluation/metrics/build_ham2pose_metrics.py b/pose_evaluation/metrics/build_ham2pose_metrics.py new file mode 100644 index 0000000..c42945a --- /dev/null +++ b/pose_evaluation/metrics/build_ham2pose_metrics.py @@ -0,0 +1,94 @@ +from pose_evaluation.metrics.distance_measure import ( + PowerDistance, +) +from pose_evaluation.metrics.distance_metric import DistanceMetric +from pose_evaluation.metrics.ham2pose_distances import ( + Ham2PoseMSEDistance, + Ham2PoseMaskedEuclideanDistance, + Ham2PoseAPEDistance, +) +from pose_evaluation.metrics.mje_metric import MeanJointErrorMetric +from pose_evaluation.metrics.dynamic_time_warping_metric import DTWMetric +from pose_evaluation.metrics.pose_processors import ( + get_standard_pose_processors, +) + +if __name__ == "__main__": + + metrics = [] + MJEMetric = ( + MeanJointErrorMetric() + ) # automatically sets distance measure, zero-padding. + metrics.append(MJEMetric) + + Ham2Pose_DTW_MJE_Metric = DTWMetric( + name="DTW_MJE", + distance_measure=PowerDistance(2, 0), + pose_preprocessors=get_standard_pose_processors(zero_pad_shorter=False), + ) + + Ham2Pose_nDTW_MJE_Metric = DTWMetric( + name="nDTW_MJE", + distance_measure=Ham2PoseMaskedEuclideanDistance(), + pose_preprocessors=get_standard_pose_processors(zero_pad_shorter=False), + ) + + metrics.append(Ham2Pose_DTW_MJE_Metric) + metrics.append(Ham2Pose_nDTW_MJE_Metric) + + # Ham2Pose APE is a PowerDistance. But with a few preprocessors. + # 1. standard preprocessors + # 2. then these: basically this is "zero_pad_shorter", and also setting masked values to zero. + # if len(trajectory1) < len(trajectory2): + # diff = len(trajectory2) - len(trajectory1) + # trajectory1 = np.concatenate((trajectory1, np.zeros((diff, 3)))) + # elif len(trajectory2) < len(trajectory1): + # trajectory2 = np.concatenate((trajectory2, np.zeros((len(trajectory1) - len(trajectory2), 3)))) + # pose1_mask = np.ma.getmask(trajectory1) + # pose2_mask = np.ma.getmask(trajectory2) + # trajectory1[pose1_mask] = 0 + # trajectory1[pose2_mask] = 0 + # trajectory2[pose1_mask] = 0 + # trajectory2[pose2_mask] = 0 + # 3. pointwise aggregate by SUM + # sq_error = np.power(trajectory1 - trajectory2, 2).sum(-1) + # 4. trajectorywise aggregate by MEAN + # np.sqrt(sq_error).mean() + Ham2PoseAPEMetric = DistanceMetric( + name="Ham2PoseAPEMetric", + distance_measure=Ham2PoseAPEDistance(), + pose_preprocessors=get_standard_pose_processors( + zero_pad_shorter=True, set_masked_values_to_origin=True + ), + ) + metrics.append(Ham2PoseAPEMetric) + + # MSE from Ham2Pose is zero-padding, plus set to origin, and then squared error. + # if len(trajectory1) < len(trajectory2): + # diff = len(trajectory2) - len(trajectory1) + # trajectory1 = np.concatenate((trajectory1, np.zeros((diff, 3)))) + # elif len(trajectory2) < len(trajectory1): + # trajectory2 = np.concatenate((trajectory2, np.zeros((len(trajectory1) - len(trajectory2), 3)))) + # pose1_mask = np.ma.getmask(trajectory1) + # pose2_mask = np.ma.getmask(trajectory2) + # trajectory1[pose1_mask] = 0 + # trajectory1[pose2_mask] = 0 + # trajectory2[pose1_mask] = 0 + # trajectory2[pose2_mask] = 0 + # sq_error = np.power(trajectory1 - trajectory2, 2).sum(-1) + # return sq_error.mean() + Ham2PoseMSEMetric = DistanceMetric( + name="Ham2PoseMSEMetric", + distance_measure=Ham2PoseMSEDistance(), + pose_preprocessors=get_standard_pose_processors( + zero_pad_shorter=True, set_masked_values_to_origin=True + ), + ) + metrics.append(Ham2PoseMSEMetric) + + for metric in metrics: + print("*" * 30) + print(f"METRIC: {metric}") + print(metric.get_signature()) + print(metric.get_signature().format(short=True)) + print() diff --git a/pose_evaluation/metrics/conftest.py b/pose_evaluation/metrics/conftest.py index c04f587..2c7ad1f 100644 --- a/pose_evaluation/metrics/conftest.py +++ b/pose_evaluation/metrics/conftest.py @@ -1,9 +1,24 @@ import shutil from pathlib import Path -from typing import Callable, Union +from typing import Callable, Union, Tuple, List import torch import numpy as np +import numpy.ma as ma import pytest +import os +import copy + +from pose_evaluation.metrics.distance_measure import PowerDistance +from pose_evaluation.metrics.distance_metric import DistanceMetric +from pose_evaluation.metrics.dynamic_time_warping_metric import DTWMetric +from pose_evaluation.metrics.ham2pose_distances import Ham2PoseAPEDistance, Ham2PoseMSEDistance, Ham2PoseMaskedEuclideanDistance +from pose_evaluation.metrics.mje_metric import MeanJointErrorMetric +from pose_evaluation.metrics.pose_processors import get_standard_pose_processors +from pose_format import Pose +from pose_format.utils.generic import fake_pose +from pose_format.numpy import NumPyPoseBody + +from pose_evaluation.utils.pose_utils import load_pose_file, copy_pose @pytest.fixture(scope="session", autouse=True) @@ -18,7 +33,7 @@ def clean_test_artifacts(): @pytest.fixture(name="distance_matrix_shape_checker") -def fixture_distance_matrix_shape_checker() -> Callable[[torch.Tensor, torch.Tensor], None]: +def fixture_distance_matrix_shape_checker() -> Callable[[int, int, torch.Tensor], None]: def _check_shape(hyp_count: int, ref_count: int, distance_matrix: torch.Tensor): expected_shape = torch.Size([hyp_count, ref_count]) @@ -48,3 +63,138 @@ def _check_range( ), f"Maximum distance ({max_distance}) is outside the expected range [{min_val}, {max_val}]" return _check_range + + + +utils_test_data_dir = Path(os.path.dirname(os.path.realpath(__file__))).parent/'utils' / 'test'/'test_data' + +@pytest.fixture(scope="function") +def test_mediapipe_poses_paths()->List[Path]: + pose_file_paths = list(utils_test_data_dir.glob("*.pose")) + return pose_file_paths + +@pytest.fixture(scope="function") +def test_mediapipe_poses(test_mediapipe_poses_paths)->List[Pose]: + original_poses = [load_pose_file(pose_path) for pose_path in test_mediapipe_poses_paths] + # I ran into issues where if one test would modify a Pose, it would affect other tests. + # specifically, pose.header.components[0].name = unsupported_component_name in test_detect_format + # this ensures we get a fresh object each time. + return copy.deepcopy(original_poses) + + + + +@pytest.fixture(scope="function") +def test_mediapipe_poses_zeros_and_ones_different_length(test_mediapipe_poses)->List[Pose]: + hypothesis = copy_pose(test_mediapipe_poses[0]) + + reference = copy_pose(test_mediapipe_poses[1]) + + + zeros_data = ma.array(np.zeros_like(hypothesis.body.data), mask=hypothesis.body.data.mask) + hypothesis_body = NumPyPoseBody(fps=hypothesis.body.fps, data=zeros_data, confidence= hypothesis.body.confidence) + + ones_data = ma.array(np.ones_like(reference.body.data), mask=reference.body.data.mask) + reference_body = NumPyPoseBody(fps= reference.body.fps, data=ones_data, confidence=reference.body.confidence) + + + hypothesis = Pose(hypothesis.header, hypothesis_body) + + reference = Pose(reference.header, reference_body) + + + return copy.deepcopy([hypothesis, reference]) + + +@pytest.fixture(scope="function") +def test_mediapipe_poses_zeros_and_ones_same_length(test_mediapipe_poses)->List[Pose]: + hypothesis = copy_pose(test_mediapipe_poses[0]) + reference = copy_pose(test_mediapipe_poses[0]) + + zeros_data = ma.array(np.zeros_like(hypothesis.body.data), mask=hypothesis.body.data.mask) + hypothesis_body = NumPyPoseBody(fps=hypothesis.body.fps, data=zeros_data, confidence= hypothesis.body.confidence) + + ones_data = ma.array(np.ones_like(reference.body.data), mask=reference.body.data.mask) + reference_body = NumPyPoseBody(fps= reference.body.fps, data=ones_data, confidence=reference.body.confidence) + + + hypothesis = Pose(hypothesis.header, hypothesis_body) + reference = Pose(reference.header, reference_body) + + return copy.deepcopy([hypothesis, reference]) + + + +@pytest.fixture +def ham2pose_metrics_for_testing()->List[DistanceMetric]: + metrics =[] + MJEMetric = ( + MeanJointErrorMetric() + ) # automatically sets distance measure, zero-padding. + metrics.append(MJEMetric) + + Ham2Pose_DTW_MJE_Metric = DTWMetric( + name="DTW_MJE", + distance_measure=PowerDistance(2, 0), + pose_preprocessors=get_standard_pose_processors(zero_pad_shorter=False), + ) + + Ham2Pose_nDTW_MJE_Metric = DTWMetric( + name="nDTW_MJE", + distance_measure=Ham2PoseMaskedEuclideanDistance(), + pose_preprocessors=get_standard_pose_processors(zero_pad_shorter=False), + ) + + metrics.append(Ham2Pose_DTW_MJE_Metric) + metrics.append(Ham2Pose_nDTW_MJE_Metric) + + # Ham2Pose APE is a PowerDistance. But with a few preprocessors. + # 1. standard preprocessors + # 2. then these: basically this is "zero_pad_shorter", and also setting masked values to zero. + # if len(trajectory1) < len(trajectory2): + # diff = len(trajectory2) - len(trajectory1) + # trajectory1 = np.concatenate((trajectory1, np.zeros((diff, 3)))) + # elif len(trajectory2) < len(trajectory1): + # trajectory2 = np.concatenate((trajectory2, np.zeros((len(trajectory1) - len(trajectory2), 3)))) + # pose1_mask = np.ma.getmask(trajectory1) + # pose2_mask = np.ma.getmask(trajectory2) + # trajectory1[pose1_mask] = 0 + # trajectory1[pose2_mask] = 0 + # trajectory2[pose1_mask] = 0 + # trajectory2[pose2_mask] = 0 + # 3. pointwise aggregate by SUM + # sq_error = np.power(trajectory1 - trajectory2, 2).sum(-1) + # 4. trajectorywise aggregate by MEAN + # np.sqrt(sq_error).mean() + Ham2PoseAPEMetric = DistanceMetric( + name="Ham2PoseAPEMetric", + distance_measure=Ham2PoseAPEDistance(), + pose_preprocessors=get_standard_pose_processors( + zero_pad_shorter=True, set_masked_values_to_origin=True + ), + ) + metrics.append(Ham2PoseAPEMetric) + + # MSE from Ham2Pose is zero-padding, plus set to origin, and then squared error. + # if len(trajectory1) < len(trajectory2): + # diff = len(trajectory2) - len(trajectory1) + # trajectory1 = np.concatenate((trajectory1, np.zeros((diff, 3)))) + # elif len(trajectory2) < len(trajectory1): + # trajectory2 = np.concatenate((trajectory2, np.zeros((len(trajectory1) - len(trajectory2), 3)))) + # pose1_mask = np.ma.getmask(trajectory1) + # pose2_mask = np.ma.getmask(trajectory2) + # trajectory1[pose1_mask] = 0 + # trajectory1[pose2_mask] = 0 + # trajectory2[pose1_mask] = 0 + # trajectory2[pose2_mask] = 0 + # sq_error = np.power(trajectory1 - trajectory2, 2).sum(-1) + # return sq_error.mean() + Ham2PoseMSEMetric = DistanceMetric( + name="Ham2PoseMSEMetric", + distance_measure=Ham2PoseMSEDistance(), + pose_preprocessors=get_standard_pose_processors( + zero_pad_shorter=True, set_masked_values_to_origin=True + ), + ) + metrics.append(Ham2PoseMSEMetric) + return metrics \ No newline at end of file diff --git a/pose_evaluation/metrics/distance_measure.py b/pose_evaluation/metrics/distance_measure.py new file mode 100644 index 0000000..9d2f3e3 --- /dev/null +++ b/pose_evaluation/metrics/distance_measure.py @@ -0,0 +1,128 @@ +from typing import Iterable, Callable, Optional +import numpy as np +from numpy.ma import MaskedArray +from pose_evaluation.metrics.aggregate_distances import DistanceAggregator +from pose_evaluation.metrics.base import Signature, SignatureMixin + + +PointwiseDistanceFunction = Callable[[MaskedArray, MaskedArray], float] + + +class DistanceMeasureSignature(Signature): + def __init__(self, args: dict): + super().__init__(args) + self.update_signature_and_abbr("name", "n", args) + + +class DistanceMeasure(SignatureMixin): + __SIGNATURE_TYPE = DistanceMeasureSignature + + def get_distance( + self, hyp_data: np.ma.MaskedArray, ref_data: np.ma.MaskedArray + ) -> float: + raise NotImplementedError + + def __call__( + self, hyp_data: np.ma.MaskedArray, ref_data: np.ma.MaskedArray + ) -> float: + return self.get_distance(hyp_data, ref_data) + + +class PowerDistanceSignature(Signature): + def __init__(self, args: dict): + super().__init__(args) + self.update_signature_and_abbr("power", "pow", args) + self.update_signature_and_abbr("default_distance", "def_d", args) + + +class PowerDistance(DistanceMeasure): + _SIGNATURE_TYPE = PowerDistanceSignature + + def __init__(self, power: int = 2, default_distance=0): + self.name = "power_distance" + self.power = power + self.default_distance = default_distance + + def get_distance( + self, hyp_data: np.ma.MaskedArray, ref_data: np.ma.MaskedArray + ) -> float: + return ( + (hyp_data - ref_data) + .pow(self.power) + .abs() + .filled(self.default_distance) + .mean() + ) + + +class AggregatePointWiseThenAggregateTrajectorywiseDistanceSignature( + DistanceMeasureSignature +): + + def __init__(self, args: dict): + super().__init__(args) + self.update_signature_and_abbr("pointwise_distance_function", "pwd", args) + self.update_signature_and_abbr("pointwise_aggregator", "pt_agg", args) + self.update_signature_and_abbr("trajectorywise_aggregator", "tw_agg", args) + + +class AggregatePointWiseThenAggregateTrajectorywiseDistance(DistanceMeasure): + _SIGNATURE_TYPE = AggregatePointWiseThenAggregateTrajectorywiseDistanceSignature + + def __init__( + self, + pointwise_distance_function: Optional[PointwiseDistanceFunction], + pointwise_aggregator: Optional[DistanceAggregator], + trajectorywise_aggregator: DistanceAggregator, + ) -> None: + super().__init__() + self.pointwise_distance_function = pointwise_distance_function + self.pointwise_aggregator = pointwise_aggregator + self.trajectorywise_aggregator = trajectorywise_aggregator + + def pointwise_distance( + self, hyp_point: MaskedArray, ref_point: MaskedArray + ) -> float: + if self.pointwise_distance_function is not None: + return self.pointwise_distance_function(hyp_point, ref_point) + raise NotImplementedError(f"Undefined pointwise distance function for {self}") + + def get_pointwise_distances( + self, hyp_traj: MaskedArray, ref_traj: MaskedArray + ) -> Iterable[float]: + pointwise_distances = [] + for hyp_point, ref_point in zip(hyp_traj, ref_traj): + pointwise_distances.append(self.pointwise_distance(hyp_point, ref_point)) + return pointwise_distances + + def get_trajectory_pair_distance( + self, hyp_traj: MaskedArray, ref_traj: MaskedArray + ) -> float: + pointwise_distances = self.get_pointwise_distances(hyp_traj, ref_traj) + return self.aggregate_pointwise_distances(pointwise_distances) + + def get_trajectory_pair_distances( + self, hyp_trajectories: MaskedArray, ref_trajectories: MaskedArray + ) -> Iterable[float]: + return [ + self.get_trajectory_pair_distance(hyp_traj, ref_traj) + for hyp_traj, ref_traj in zip(hyp_trajectories, ref_trajectories) + ] + + def aggregate_pointwise_distances( + self, pointwise_distances: Iterable[float] + ) -> float: + if self.pointwise_aggregator is not None: + return self.pointwise_aggregator(pointwise_distances) + raise NotImplementedError(f"No pointwise aggregator for {self}") + + def aggregate_trajectory_distances( + self, trajectory_distances: Iterable[float] + ) -> float: + return self.trajectorywise_aggregator(trajectory_distances) + + def get_distance(self, hyp_data: MaskedArray, ref_data: MaskedArray) -> float: + keypoint_trajectory_distances = self.get_trajectory_pair_distances( + hyp_data, ref_data + ) + return self.aggregate_trajectory_distances(keypoint_trajectory_distances) diff --git a/pose_evaluation/metrics/distance_metric.py b/pose_evaluation/metrics/distance_metric.py index c3cbc16..dd79edf 100644 --- a/pose_evaluation/metrics/distance_metric.py +++ b/pose_evaluation/metrics/distance_metric.py @@ -1,38 +1,64 @@ -from typing import Literal - -from numpy import ma +from typing import Literal, Tuple, List, Union, Optional from pose_format import Pose +from pose_evaluation.metrics.base_pose_metric import PoseMetric, PoseMetricSignature +from pose_evaluation.metrics.distance_measure import DistanceMeasure, PowerDistance +from pose_evaluation.metrics.pose_processors import PoseProcessor + +BuildTrajectoryStrategy = Literal["keypoint", "frame"] +TrajectoryAlignmentStrategy = Literal[ + "zero_pad_shorter", "truncate_longer", "by_reference" +] +MaskedKeypointPositionStrategy = Literal[ + "skip_masked", + "return_zero", + "masked_to_origin", + "ref_return_zero_hyp_to_origin", + "undefined", +] + +KeypointPositionType = Union[ + Tuple[float, float, float], Tuple[float, float] +] # XYZ or XY +ValidPointDistanceKinds = Literal["euclidean", "manhattan"] -from pose_evaluation.metrics.base_pose_metric import PoseMetric + +class DistanceMetricSignature(PoseMetricSignature): + def __init__(self, args: dict): + super().__init__(args) + self.update_signature_and_abbr("distance_measure", "dist", args) + self.update_signature_and_abbr("trajectory", "trj", args) class DistanceMetric(PoseMetric): - def __init__(self, kind: Literal["l1", "l2"] = "l2"): - super().__init__(f"DistanceMetric {kind}", higher_is_better=False) - self.kind = kind + """Metrics that compute some sort of distance""" + + _SIGNATURE_TYPE = DistanceMetricSignature + + def __init__( + self, + name="DistanceMetric", + distance_measure: Optional[DistanceMeasure] = None, + pose_preprocessors: None | List[PoseProcessor] = None, + trajectory: BuildTrajectoryStrategy = "keypoint", + ): + super().__init__( + name=name, higher_is_better=False, pose_preprocessors=pose_preprocessors + ) + + if distance_measure is None: + self.distance_measure = PowerDistance() - def score(self, hypothesis: Pose, reference: Pose) -> float: - arrays = [hypothesis.body.data, reference.body.data] - max_length = max(len(array) for array in arrays) - # Pad the shorter array with zeros - for i, array in enumerate(arrays): - if len(array) < max_length: - shape = list(array.shape) - shape[0] = max_length - len(array) - padding_tensor = ma.zeros(shape) - arrays[i] = ma.concatenate([array, padding_tensor], axis=0) - - # Calculate the error - error = arrays[0] - arrays[1] - - # for l2, we need to calculate the error for each point - if self.kind == "l2": - # the last dimension is the 3D coordinates - error = ma.power(error, 2) - error = error.sum(axis=-1) - error = ma.sqrt(error) else: - error = ma.abs(error) + self.distance_measure = distance_measure + + self.trajectory = trajectory + + def score(self, hypothesis: Pose, reference: Pose) -> float: + hypothesis, reference = self.process_poses([hypothesis, reference]) + if self.trajectory == "keypoint": + return self.distance_measure( + hypothesis.body.points_perspective(), + reference.body.points_perspective(), + ) - error = error.filled(0) - return error.sum() + return self.distance_measure(hypothesis.body.data, reference.body.data) diff --git a/pose_evaluation/metrics/dynamic_time_warping_metric.py b/pose_evaluation/metrics/dynamic_time_warping_metric.py new file mode 100644 index 0000000..e860af4 --- /dev/null +++ b/pose_evaluation/metrics/dynamic_time_warping_metric.py @@ -0,0 +1,75 @@ +from typing import Literal, List, Optional + +import numpy as np +from fastdtw import fastdtw + +from pose_format import Pose + +from pose_evaluation.metrics.distance_measure import DistanceMeasure, PowerDistance +from pose_evaluation.metrics.base_pose_metric import PoseMetric, PoseMetricSignature +from pose_evaluation.metrics.pose_processors import PoseProcessor + + +class DTWSignature(PoseMetricSignature): + def __init__(self, args: dict): + super().__init__(args) + + self._abbreviated.update( + { + "radius": "rad", + "distance_measure": "dist", + "trajectory": "traj", + } + ) + + self.signature_info.update( + { + "radius": args.get("radius", None), + "distance_measure": args.get("distance_measure", None), + "trajectory": args.get("trajectory", None), + } + ) + + +class DTWMetric(PoseMetric): + _SIGNATURE_TYPE = DTWSignature + + def __init__( + self, + name: str = "DTWMetric", + radius: int = 1, + distance_measure: Optional[DistanceMeasure] = None, + trajectory: Literal["keypoints", "frames"] = "keypoints", + higher_is_better: bool = False, + pose_preprocessors: None | List[PoseProcessor] = None, + ): + super().__init__(name, higher_is_better, pose_preprocessors) + + self.radius = radius + + if distance_measure is None: + self.distance_measure = PowerDistance() + else: + self.distance_measure = distance_measure + self.trajectory = trajectory + + def score(self, hypothesis: Pose, reference: Pose): + hypothesis, reference = self.process_poses([hypothesis, reference]) + if self.trajectory == "keypoints": + keypoint_trajectory_distances = [] + tensor1 = hypothesis.body.points_perspective() + tensor2 = reference.body.points_perspective() + for keypoint_trajectory1, keypoint_trajectory2 in zip(tensor1, tensor2): + keypoint_trajectory_distances.append( + fastdtw( + keypoint_trajectory1, + keypoint_trajectory2, + radius=self.radius, + dist=self.distance_measure, + ) + ) + return float(np.mean(keypoint_trajectory_distances)) + + tensor1 = hypothesis.body.data + tensor2 = reference.body.data + return fastdtw(tensor1, tensor2, radius=self.radius, dist=self.distance_measure) diff --git a/pose_evaluation/metrics/ham2pose_distances.py b/pose_evaluation/metrics/ham2pose_distances.py new file mode 100644 index 0000000..9c0c533 --- /dev/null +++ b/pose_evaluation/metrics/ham2pose_distances.py @@ -0,0 +1,76 @@ +import numpy as np +from numpy.ma.core import MaskedArray +from scipy.spatial.distance import euclidean +from pose_evaluation.metrics.distance_measure import ( + AggregatePointWiseThenAggregateTrajectorywiseDistance, + DistanceAggregator, +) + + +class Ham2PoseAPEDistance(AggregatePointWiseThenAggregateTrajectorywiseDistance): + def __init__(self) -> None: + super().__init__( + None, + pointwise_aggregator=None, + trajectorywise_aggregator=DistanceAggregator("sum"), + ) + self.name = "ham2pose_ape" + + def get_trajectory_pair_distance( + self, hyp_traj: MaskedArray, ref_traj: MaskedArray + ) -> float: + return ham2pose_ape(hyp_traj, ref_traj) + + +class Ham2PoseMSEDistance(AggregatePointWiseThenAggregateTrajectorywiseDistance): + def __init__(self) -> None: + super().__init__( + None, + pointwise_aggregator=None, + trajectorywise_aggregator=DistanceAggregator("sum"), + ) + self.name = "ham2pose_mse" + + def get_trajectory_pair_distance( + self, hyp_traj: MaskedArray, ref_traj: MaskedArray + ) -> float: + return ham2pose_mse(hyp_traj, ref_traj) + + +class Ham2PoseMaskedEuclideanDistance( + AggregatePointWiseThenAggregateTrajectorywiseDistance +): + + def __init__(self) -> None: + super().__init__( + pointwise_distance_function=ham2pose_masked_euclidean, + pointwise_aggregator=DistanceAggregator("mean"), + trajectorywise_aggregator=DistanceAggregator("mean"), + ) + self.name = "ham2pose_masked_euclidean" + self.mask_strategy = "ref_return_0,hyp_set_to_origin" + self.pointwise_dist = "euclidean" + + # def get_signature(self) -> str: + # return "name:ham2pose_masked_euclidean|mask_strategy:ref_return_0,hyp_set_to_origin|dist:euclidean" + + +def ham2pose_masked_euclidean(hyp_point: MaskedArray, ref_point: MaskedArray) -> float: + if np.ma.is_masked(ref_point): # reference label keypoint is missing + return 0 + elif np.ma.is_masked( + hyp_point + ): # reference label keypoint is not missing, other label keypoint is missing + return euclidean((0, 0, 0), ref_point) / 2 + d = euclidean(hyp_point, ref_point) + return d + + +def ham2pose_ape(hyp_traj, ref_traj): + sq_error = np.power(hyp_traj - ref_traj, 2).sum(-1) + return np.sqrt(sq_error).mean() + + +def ham2pose_mse(hyp_traj, ref_traj): + sq_error = np.power(hyp_traj - ref_traj, 2).sum(-1) + return sq_error.mean() diff --git a/pose_evaluation/metrics/mje_metric.py b/pose_evaluation/metrics/mje_metric.py new file mode 100644 index 0000000..f03b9b4 --- /dev/null +++ b/pose_evaluation/metrics/mje_metric.py @@ -0,0 +1,14 @@ +from pose_evaluation.metrics.distance_measure import PowerDistance +from pose_evaluation.metrics.distance_metric import DistanceMetric +from pose_evaluation.metrics.pose_processors import get_standard_pose_processors + + +class MeanJointErrorMetric(DistanceMetric): + def __init__(self): + pose_preprocessors = get_standard_pose_processors() + super().__init__( + distance_measure=PowerDistance(2), + pose_preprocessors=pose_preprocessors, + trajectory="keypoint", + ) + self.name = "MJE" diff --git a/pose_evaluation/metrics/pose_processors.py b/pose_evaluation/metrics/pose_processors.py new file mode 100644 index 0000000..6b2155d --- /dev/null +++ b/pose_evaluation/metrics/pose_processors.py @@ -0,0 +1,160 @@ +from typing import Any, List, Union, Iterable, Callable +from pose_format import Pose + +from pose_evaluation.metrics.base import SignatureMixin +from pose_evaluation.utils.pose_utils import ( + remove_components, + pose_remove_legs, + get_face_and_hands_from_pose, + reduce_pose_components_and_points_to_intersection, + zero_pad_shorter_poses, + copy_pose, + pose_hide_low_conf, + set_masked_to_origin_position, +) + +PosesTransformerFunctionType = Callable[[Iterable[Pose]], List[Pose]] + + +class PoseProcessor(SignatureMixin): + def __init__(self, name="PoseProcessor") -> None: + self.name = name + + def __call__(self, pose_or_poses: Union[Iterable[Pose], Pose]) -> Any: + if isinstance(pose_or_poses, Iterable): + return self.process_poses(pose_or_poses) + else: + return self.process_pose(pose_or_poses) + + def __repr__(self) -> str: + return self.name + + def __str__(self) -> str: + return self.name + + def process_pose(self, pose: Pose) -> Pose: + return pose + + def process_poses(self, poses: Iterable[Pose]) -> List[Pose]: + return [self.process_pose(pose) for pose in poses] + + +class RemoveComponentsProcessor(PoseProcessor): + def __init__(self, landmarks: List[str]) -> None: + super().__init__(f"remove_landmarks[landmarks{landmarks}]") + self.landmarks = landmarks + + def process_pose(self, pose: Pose) -> Pose: + return remove_components(pose, self.landmarks) + + +class RemoveWorldLandmarksProcessor(RemoveComponentsProcessor): + def __init__(self) -> None: + landmarks = ["POSE_WORLD_LANDMARKS"] + super().__init__(landmarks) + + +class RemoveLegsPosesProcessor(PoseProcessor): + def __init__(self, name="remove_legs") -> None: + super().__init__(name) + + def process_pose(self, pose: Pose) -> Pose: + return pose_remove_legs(pose) + + +class GetFaceAndHandsProcessor(PoseProcessor): + def __init__(self, name="face_and_hands") -> None: + super().__init__(name) + + def process_pose(self, pose: Pose) -> Pose: + return get_face_and_hands_from_pose(pose) + + +class ReducePosesToCommonComponentsProcessor(PoseProcessor): + def __init__(self, name="reduce_pose_components") -> None: + super().__init__(name) + + def process_pose(self, pose: Pose) -> Pose: + return self.process_poses([pose])[0] + + def process_poses(self, poses: Iterable[Pose]) -> List[Pose]: + return reduce_pose_components_and_points_to_intersection(poses) + + +class ZeroPadShorterPosesProcessor(PoseProcessor): + def __init__(self) -> None: + super().__init__(name="zero_pad_shorter_sequence") + + def process_poses(self, poses: Iterable[Pose]) -> List[Pose]: + return zero_pad_shorter_poses(poses) + + +class PadOrTruncateByReferencePosesProcessor(PoseProcessor): + def __init__(self) -> None: + super().__init__(name="by_reference") + + def process_poses(self, poses: Iterable[Pose]) -> List[Pose]: + raise NotImplementedError # TODO + + +class NormalizePosesProcessor(PoseProcessor): + def __init__(self, info=None, scale_factor=1) -> None: + super().__init__(f"normalize_poses[info:{info},scale_factor:{scale_factor}]") + self.info = info + self.scale_factor = scale_factor + + def process_pose(self, pose: Pose) -> Pose: + return pose.normalize(self.info, self.scale_factor) + + +class HideLowConfProcessor(PoseProcessor): + def __init__(self, conf_threshold: float = 0.2) -> None: + + super().__init__(f"hide_low_conf[{conf_threshold}]") + self.conf_threshold = conf_threshold + + def process_pose(self, pose: Pose) -> Pose: + pose = copy_pose(pose) + pose_hide_low_conf(pose, self.conf_threshold) + return pose + + +class SetMaskedValuesToOriginPositionProcessor(PoseProcessor): + def __init__( + self, + ) -> None: + super().__init__(name="set_masked_to_origin") + + def process_pose(self, pose: Pose) -> Pose: + return set_masked_to_origin_position(pose) + + +def get_standard_pose_processors( + normalize_poses: bool = True, + reduce_poses_to_common_components: bool = True, + remove_world_landmarks=True, + remove_legs=True, + zero_pad_shorter=True, + set_masked_values_to_origin=False, +) -> List[PoseProcessor]: + pose_processors = [] + + if normalize_poses: + pose_processors.append(NormalizePosesProcessor()) + + if reduce_poses_to_common_components: + pose_processors.append(ReducePosesToCommonComponentsProcessor()) + + if remove_world_landmarks: + pose_processors.append(RemoveWorldLandmarksProcessor()) + + if remove_legs: + pose_processors.append(RemoveLegsPosesProcessor()) + + if zero_pad_shorter: + pose_processors.append(ZeroPadShorterPosesProcessor()) + + if set_masked_values_to_origin: + pose_processors.append(SetMaskedValuesToOriginPositionProcessor()) + + return pose_processors diff --git a/pose_evaluation/metrics/test_distance_metric.py b/pose_evaluation/metrics/test_distance_metric.py index e1d7d39..0bab116 100644 --- a/pose_evaluation/metrics/test_distance_metric.py +++ b/pose_evaluation/metrics/test_distance_metric.py @@ -1,72 +1,224 @@ import math import unittest +from typing import Tuple, List, get_args import numpy as np +import pytest +import random + from pose_format import Pose +from pose_format.utils.generic import fake_pose, detect_known_pose_format from pose_format.numpy import NumPyPoseBody -from pose_evaluation.metrics.distance_metric import DistanceMetric - - -def get_poses(length1: int, length2: int): - data_tensor = np.full((length1, 3, 4, 3), fill_value=2) - zeros_tensor = np.zeros((length2, 3, 4, 3)) - data_confidence = np.ones(data_tensor.shape[:-1]) - zeros_confidence = np.ones(zeros_tensor.shape[:-1]) - - hypothesis = Pose(header=None, body=NumPyPoseBody(fps=1, data=data_tensor, confidence=data_confidence)) - reference = Pose(header=None, body=NumPyPoseBody(fps=1, data=zeros_tensor, confidence=zeros_confidence)) - return hypothesis, reference - -class TestDistanceMetricGeneric(unittest.TestCase): - def setUp(self): - self.metric = DistanceMetric("l2") - - def test_scores_are_symmetric(self): - hypothesis, reference = get_poses(2, 2) - - score1 = self.metric.score(hypothesis, reference) - # pylint: disable=arguments-out-of-order - score2 = self.metric.score(reference, hypothesis) - self.assertAlmostEqual(score1, score2) - - def test_score_different_length(self): - hypothesis, reference = get_poses(3, 2) +from pose_evaluation.metrics.distance_metric import DistanceMetric, ValidPointDistanceKinds +from pose_evaluation.metrics import dynamic_time_warping_metric +# from pose_evaluation.metrics import ape_metric, mse_metric, ndtw_mje_metric + +# pytest --cov --cov-report lcov . + +DISTANCE_KINDS_TO_CHECK = get_args(ValidPointDistanceKinds) + + +# def get_test_poses(frame_count1: int, frame_count2: int, people_count:int=3, point_count:int=5, fill_value=2, fps=25)-> Tuple[Pose, Pose]: +# # point_coordinate_count = 3 # x, y, z + +# hypothesis = fake_pose(num_frames=frame_count1, fps=fps) +# reference = fake_pose(num_frames=frame_count2, fps=fps) + +# data_tensor = np.full_like(hypothesis.body.data, fill_value=fill_value) +# zeros_tensor = np.zeros_like(reference.body.data) +# data_confidence = np.ones(data_tensor.shape[:-1]) +# zeros_confidence = np.ones(zeros_tensor.shape[:-1]) + +# hypothesis.body.data = data_tensor +# reference.body.data = zeros_tensor + +# # fake_pose_header = fake_pose(1).header + + +# # hypothesis = Pose(header=fake_pose_1.header, body=NumPyPoseBody(fps=fps, data=data_tensor, confidence=data_confidence)) +# # reference = Pose(header=fake_pose_2.header, body=NumPyPoseBody(fps=fps, data=zeros_tensor, confidence=zeros_confidence)) +# return hypothesis, reference + + +# def test_get_test_poses(): +# # hyp, ref = test_pose_pair +# frame_count1 = 20 +# frame_count2 = 30 +# people_count = 1 +# point_count = 137 +# coordinate_count_per_point = 3 +# fill_value = 2 +# hyp, ref = get_test_poses(frame_count1=frame_count1,frame_count2=frame_count2, point_count=point_count, fill_value=fill_value) + +# assert hyp.body.data.shape == (frame_count1, people_count, point_count, coordinate_count_per_point) +# assert ref.body.data.shape == (frame_count2, people_count, point_count, coordinate_count_per_point) +# # assert np.sum(hyp.body.data) == frame_count1*people_count*point_count*coordinate_count_per_point*fill_value +# assert np.sum(ref.body.data) == 0 + +def test_data_validity(test_mediapipe_poses:List[Pose], test_mediapipe_poses_zeros_and_ones_different_length, test_mediapipe_poses_zeros_and_ones_same_length): + assert len(test_mediapipe_poses) == 3 + assert len(test_mediapipe_poses_zeros_and_ones_different_length) == 2 + assert len(test_mediapipe_poses_zeros_and_ones_same_length) == 2 + test_mediapipe_poses.extend(test_mediapipe_poses_zeros_and_ones_same_length) + test_mediapipe_poses.extend(test_mediapipe_poses_zeros_and_ones_different_length) + + for pose in test_mediapipe_poses: + assert np.count_nonzero(np.isnan(pose.body.data)) == 0 + assert pose.header.num_dims() == pose.body.data.shape[-1] + assert pose.body.confidence.shape == pose.body.data.shape[:-1] + + assert detect_known_pose_format(pose) == "holistic" + +def test_data_zeros_and_ones(test_mediapipe_poses_zeros_and_ones_different_length): + hyp, ref = test_mediapipe_poses_zeros_and_ones_different_length[0], test_mediapipe_poses_zeros_and_ones_different_length[1] + assert np.ma.sum(hyp.body.data) == 0 + + for frame_data in ref.body.data: + for person_data in frame_data: + for keypoint_data in person_data: + if np.ma.count_masked(keypoint_data) < len(keypoint_data): + + assert np.ma.sum(keypoint_data) == len(keypoint_data) - np.ma.count_masked(keypoint_data) + + +@pytest.mark.parametrize("metric_name", DISTANCE_KINDS_TO_CHECK) +def test_preprocessing(metric_name: ValidPointDistanceKinds, test_mediapipe_poses:List[Pose]): + metric = DistanceMetric(point_distance_calculation_kind=metric_name) + for pose in test_mediapipe_poses: + assert np.count_nonzero(np.isnan(pose.body.data)) == 0 + poses = metric.process_poses(test_mediapipe_poses) + + + for pose in poses: + data = pose.body.data + assert np.count_nonzero(np.isnan(data)) ==0 + + assert isinstance(data, np.ma.MaskedArray) + assert len(data) == len(poses[0].body.data) + +@pytest.mark.parametrize("metric_name", DISTANCE_KINDS_TO_CHECK) +def test_preprocessing_with_zeros_and_ones_different_length(metric_name: ValidPointDistanceKinds, test_mediapipe_poses_zeros_and_ones_different_length:List[Pose]): + metric = DistanceMetric(point_distance_calculation_kind=metric_name, normalize_poses=False) # normalizing when they're all zeros gives an error + + for pose in test_mediapipe_poses_zeros_and_ones_different_length: + assert np.count_nonzero(np.isnan(pose.body.data)) == 0 + poses = metric.process_poses(test_mediapipe_poses_zeros_and_ones_different_length) + + + for pose in poses: + data = pose.body.data + assert np.count_nonzero(np.isnan(data)) ==0 + + assert isinstance(data, np.ma.MaskedArray) + assert len(data) == len(poses[0].body.data) + + + + + + + +@pytest.mark.parametrize("metric_name", DISTANCE_KINDS_TO_CHECK) +def test_base_distance_metric_scores_equal_length(metric_name:ValidPointDistanceKinds, test_mediapipe_poses_zeros_and_ones_same_length): + + # hypothesis, reference = get_test_poses(2, 3) + # hypothesis, reference = test_mediapipe_poses[0], test_mediapipe_poses[1] + metric = DistanceMetric(point_distance_calculation_kind=metric_name, normalize_poses=False) # gives me nans in this case + fill_value =1 + + hyp, ref = test_mediapipe_poses_zeros_and_ones_same_length[0], test_mediapipe_poses_zeros_and_ones_same_length[1] + coordinates_per_point = hyp.body.data.shape[-1] + point_count = np.prod(hyp.body.confidence.shape) + + assert np.count_nonzero(np.isnan(hyp.body.data)) == 0 + assert np.count_nonzero(np.isnan(ref.body.data)) == 0 + + if metric_name == 'euclidean': + # euclidean difference per point: + # between (2,2,2) and (0,0,0) is 3.4641, + # aka sqrt((2-0)^2 +(2-0)^2 +(2-0^2) + expected_difference_per_point = np.sqrt(fill_value*fill_value*coordinates_per_point) + expected_distance = expected_difference_per_point # it's a mean value, they should all be the same + score = metric.score(hyp, ref) + + + + + elif metric_name == 'manhattan': + + expected_difference_per_point = fill_value * coordinates_per_point + expected_distance = expected_difference_per_point + + score = metric.score(hyp, ref) + # point_count = np.prod(hyp.body.confidence.shape) + assert score == expected_difference_per_point # mean error for every pair of spatial points is the same + assert np.isclose(score, expected_distance) + assert isinstance(score, float) # Check if the score is a float + + +def get_all_subclasses(base_class): + """Recursively discover all subclasses of a given base class.""" + subclasses = set(base_class.__subclasses__()) + for subclass in base_class.__subclasses__(): + subclasses.update(get_all_subclasses(subclass)) + return subclasses + +# @pytest.mark.parametrize('distance_metric_to_test',get_all_subclasses(DistanceMetric)) +# def test_all_distance_metrics_and_kinds(distance_metric_to_test): +# for kind in DISTANCE_KINDS_TO_CHECK: +# metric = distance_metric_to_test(kind) +# assert isinstance(metric, DistanceMetric) + + + +def test_get_subclasses_for_distance_metric(): + distance_metrics= get_all_subclasses(DistanceMetric) + assert len(distance_metrics) > 0 - difference = 6 * np.prod(hypothesis.body.confidence.shape) +def generate_test_cases(base_class, kinds): + """Generate tuples of (metric_class, kind) for parameterization.""" + subclasses = get_all_subclasses(base_class) + return [(subclass, kind) for subclass in subclasses for kind in kinds] - score = self.metric.score(hypothesis, reference) - self.assertIsInstance(score, float) - self.assertAlmostEqual(score, difference) +# Parameterize with (metric_class, kind) +@pytest.mark.parametrize("metric_class,kind", generate_test_cases(DistanceMetric, DISTANCE_KINDS_TO_CHECK)) +def test_distance_metric_calculations(metric_class, kind): + """Test all DistanceMetric subclasses with various 'kinds'.""" + metric = metric_class(kind) + + # if kind == "default": + # result = metric.calculate(3, 7) + # elif kind == "weighted" and hasattr(metric, "calculate"): # Check for additional arguments + # result = metric.calculate(3, 7) # Modify if weighted args are supported + # else: + # pytest.skip(f"{metric_class} does not support kind '{kind}'") + # assert result is not None # Example assertion -class TestDistanceMetricL1(unittest.TestCase): - def setUp(self): - self.metric = DistanceMetric("l1") - def test_score_equal_length(self): - hypothesis, reference = get_poses(2, 2) +def test_distance_metric_invalid_kind(test_mediapipe_poses_zeros_and_ones_same_length): + normalize_poses = False # our test data makes normalize divide by zero + - # calculate what the difference should be - difference = 6 * np.prod(hypothesis.body.confidence.shape) + with pytest.raises(NotImplementedError, match="invalid distance function"): + metric = DistanceMetric(point_distance_calculation_kind="invalid", normalize_poses=normalize_poses) # type: ignore + metric.score(test_mediapipe_poses_zeros_and_ones_same_length[0], test_mediapipe_poses_zeros_and_ones_same_length[1]) - score = self.metric.score(hypothesis, reference) - self.assertIsInstance(score, float) # Check if the score is a float - self.assertAlmostEqual(score, difference) -class TestDistanceMetricL2(unittest.TestCase): - def setUp(self): - self.metric = DistanceMetric("l2") + # what if we do one that is in scipy? + metric = DistanceMetric(point_distance_calculation_kind="chebyshev", normalize_poses=normalize_poses) # type: ignore + metric.score(test_mediapipe_poses_zeros_and_ones_same_length[0], test_mediapipe_poses_zeros_and_ones_same_length[1]) - def test_score_equal_length(self): - hypothesis, reference = get_poses(2, 2) - # calculate what the difference should be - difference = math.sqrt(12) * np.prod(hypothesis.body.confidence.shape) +@pytest.mark.parametrize("metric_class,kind", generate_test_cases(DistanceMetric, DISTANCE_KINDS_TO_CHECK)) +def test_scores_are_symmetric(metric_class, kind, test_mediapipe_poses:List[Pose]): + metric = metric_class(spatial_distance_function_kind=kind) - score = self.metric.score(hypothesis, reference) - self.assertIsInstance(score, float) # Check if the score is a float - self.assertAlmostEqual(score, difference) + # hypothesis, reference = get_test_poses(2, 3) + hypothesis, reference = test_mediapipe_poses[0], test_mediapipe_poses[1] -if __name__ == '__main__': - unittest.main() + score1 = metric.score(hypothesis, reference) + # pylint: disable=arguments-out-of-order + score2 = metric.score(reference, hypothesis) + assert np.isclose(score1, score2) \ No newline at end of file diff --git a/pose_evaluation/metrics/test_embedding_distance_metric.py b/pose_evaluation/metrics/test_embedding_distance_metric.py index ab275c6..fc34ae6 100644 --- a/pose_evaluation/metrics/test_embedding_distance_metric.py +++ b/pose_evaluation/metrics/test_embedding_distance_metric.py @@ -67,7 +67,7 @@ def call_with_both_input_orders_and_do_standard_checks( scoring_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], distance_range_checker, distance_matrix_shape_checker, - expected_shape: Tuple = None, + expected_shape: Tuple|None = None, ): scores, scores2 = call_and_call_with_inputs_swapped(hyps, refs, scoring_function) if expected_shape is not None: @@ -87,7 +87,7 @@ def save_and_plot_distances(distances, matrix_name, num_points, dim): """Helper function to save distance matrix and plot distances.""" distances = distances.cpu() - test_artifacts_dir = Path(__file__).parent / "temp" + test_artifacts_dir = Path(__file__).parent / "tests" # TODO: use a proper temp dir output_path = test_artifacts_dir / f"distance_matrix_{matrix_name}_{num_points}_{dim}D.csv" np.savetxt(output_path, distances.numpy(), delimiter=",", fmt="%.4f") print(f"Distance matrix saved to {output_path}") diff --git a/pose_evaluation/metrics/test_signatures.py b/pose_evaluation/metrics/test_signatures.py new file mode 100644 index 0000000..963ffaf --- /dev/null +++ b/pose_evaluation/metrics/test_signatures.py @@ -0,0 +1,31 @@ +from typing import List + +import pytest + +from pose_evaluation.metrics.base_pose_metric import PoseMetric +from pose_evaluation.metrics.distance_measure import PowerDistance +from pose_evaluation.metrics.distance_metric import DistanceMetric +from pose_evaluation.metrics.dynamic_time_warping_metric import DTWMetric +from pose_evaluation.metrics.ham2pose_distances import Ham2PoseAPEDistance, Ham2PoseMSEDistance, Ham2PoseMaskedEuclideanDistance +from pose_evaluation.metrics.mje_metric import MeanJointErrorMetric +from pose_evaluation.metrics.pose_processors import ( + NormalizePosesProcessor, + get_standard_pose_processors, +) + + + +def test_pose_metric_signature_has_preprocessor_information(): + metric = PoseMetric("PoseMetric", pose_preprocessors=[NormalizePosesProcessor()]) + + assert "pose_preprocessers" in metric.get_signature().format() + assert "pre" in metric.get_signature().format(short=True) + + metric = PoseMetric("PoseMetric") + assert "pose_preprocessers" not in metric.get_signature().format() + assert "pre" not in metric.get_signature().format(short=True) + +def test_pose_metric_signature_has_distance_measure_information(ham2pose_metrics_for_testing:List[DistanceMetric]): + for metric in ham2pose_metrics_for_testing: + assert "distance_measure:{" in metric.get_signature().format(short=False) + assert "dist:{" in metric.get_signature().format(short=True) \ No newline at end of file diff --git a/pose_evaluation/utils/conftest.py b/pose_evaluation/utils/conftest.py new file mode 100644 index 0000000..eaa5708 --- /dev/null +++ b/pose_evaluation/utils/conftest.py @@ -0,0 +1,49 @@ +import os +import json +import copy +import itertools +from pathlib import Path +from typing import List, Dict, Tuple +from pose_format import Pose +import pytest +from pose_evaluation.utils.pose_utils import load_pose_file +from pose_format.utils.generic import fake_pose +from pose_format.utils.openpose_135 import OpenPose_Components as openpose_135_components + + +utils_test_data_dir = Path(os.path.dirname(os.path.realpath(__file__))) / 'test'/'test_data' + +@pytest.fixture(scope="function") +def test_mediapipe_poses_paths()->List[Path]: + pose_file_paths = list(utils_test_data_dir.glob("*.pose")) + return pose_file_paths + +@pytest.fixture(scope="function") +def test_mediapipe_poses(test_mediapipe_poses_paths)->List[Pose]: + original_poses = [load_pose_file(pose_path) for pose_path in test_mediapipe_poses_paths] + # I ran into issues where if one test would modify a Pose, it would affect other tests. + # specifically, pose.header.components[0].name = unsupported_component_name in test_detect_format + # this ensures we get a fresh object each time. + return copy.deepcopy(original_poses) + +# @pytest.fixture +# def pairs_of_identical_test_mediapipe_poses(test_mediapipe_poses)->List[Tuple[Pose, Pose]]: +# poses =[] +# for pose in test_mediapipe_poses: +# poses.append(pose, pose) + + +@pytest.fixture +def standard_mediapipe_components_dict()->Dict[str, List[str]]: + format_json = utils_test_data_dir/"mediapipe_components_and_points.json" + with open(format_json, "r") as f: + return json.load(f) + +@pytest.fixture +def fake_openpose_poses(count:int=3)->List[Pose]: + return [fake_pose(30) for _ in range(count)] + +@pytest.fixture +def fake_openpose_135_poses(count:int=3)->List[Pose]: + return [fake_pose(30, components=openpose_135_components) for _ in range(count)] + diff --git a/pose_evaluation/utils/pose_utils.py b/pose_evaluation/utils/pose_utils.py new file mode 100644 index 0000000..b251fe0 --- /dev/null +++ b/pose_evaluation/utils/pose_utils.py @@ -0,0 +1,282 @@ +from pathlib import Path +from typing import List, Tuple, Dict, Union, Iterable +import numpy as np +from pose_format import Pose +from pose_format.utils.openpose import OpenPose_Components +from pose_format.utils.openpose_135 import OpenPose_Components as OpenPose135_Components +# from pose_format.utils.holistic import holistic_components # creates an error: ImportError: Please install mediapipe with: pip install mediapipe +from collections import defaultdict +from pose_format.utils.generic import pose_normalization_info, pose_hide_legs, fake_pose + + +def pose_remove_world_landmarks(pose: Pose)->Pose: + return remove_components(pose, ["POSE_WORLD_LANDMARKS"]) + +# TODO: remove, and use the one implemented in the latest pose_format +def detect_format(pose: Pose) -> str: + component_names = [c.name for c in pose.header.components] + mediapipe_components = [ + "POSE_LANDMARKS", + "FACE_LANDMARKS", + "LEFT_HAND_LANDMARKS", + "RIGHT_HAND_LANDMARKS", + "POSE_WORLD_LANDMARKS", + ] + + openpose_components = [c.name for c in OpenPose_Components] + openpose_135_components = [c.name for c in OpenPose135_Components] + for component_name in component_names: + if component_name in mediapipe_components: + return "mediapipe" + if component_name in openpose_components: + return "openpose" + if component_name in openpose_135_components: + return "openpose_135" + + raise ValueError( + f"Unknown pose header schema with component names: {component_names}" + ) + +def get_component_names_and_points_dict(pose:Pose)->Tuple[List[str], Dict[str, List[str]]]: + component_names = [] + points_dict = defaultdict(list) + for component in pose.header.components: + component_names.append(component.name) + + for point in component.points: + points_dict[component.name].append(point) + + return component_names, points_dict + +def remove_components( + pose: Pose, components_to_remove: List[str]|str, points_to_remove: List[str]|str|None=None +): + if points_to_remove is None: + points_to_remove = [] + if isinstance(components_to_remove, str): + components_to_remove = [components_to_remove] + if isinstance(points_to_remove, str): + points_to_remove = [points_to_remove] + components_to_keep = [] + points_dict = {} + + for component in pose.header.components: + if component.name not in components_to_remove: + components_to_keep.append(component.name) + points_dict[component.name] = [] + for point in component.points: + if point not in points_to_remove: + points_dict[component.name].append(point) + + return pose.get_components(components_to_keep, points_dict) + + + + +def pose_remove_legs(pose: Pose) -> Pose: + detected_format = detect_format(pose) + if detected_format == "mediapipe": + mediapipe_point_names = ["KNEE", "ANKLE", "HEEL", "FOOT_INDEX"] + mediapipe_sides = ["LEFT", "RIGHT"] + point_names_to_remove = [ + side + "_" + name + for name in mediapipe_point_names + for side in mediapipe_sides + ] + else: + raise NotImplementedError( + f"Remove legs not implemented yet for pose header schema {detected_format}" + ) + + pose = remove_components(pose, [], point_names_to_remove) + return pose + + + +def copy_pose(pose: Pose) -> Pose: + return pose.get_components([component.name for component in pose.header.components]) + + + +def get_face_and_hands_from_pose(pose: Pose) -> Pose: + # based on MediaPipe Holistic format. + components_to_keep = [ + "FACE_LANDMARKS", + "LEFT_HAND_LANDMARKS", + "RIGHT_HAND_LANDMARKS", + ] + return pose.get_components(components_to_keep) + +def load_pose_file(pose_path: Path) -> Pose: + pose_path = Path(pose_path).resolve() + with pose_path.open("rb") as f: + pose = Pose.read(f.read()) + return pose + + +def reduce_pose_components_and_points_to_intersection(poses: Iterable[Pose]) -> List[Pose]: + poses = [copy_pose(pose) for pose in poses] + component_names_for_each_pose = [] + point_dict_for_each_pose = [] + for pose in poses: + names, points_dict = get_component_names_and_points_dict(pose) + component_names_for_each_pose.append(set(names)) + point_dict_for_each_pose.append(points_dict) + + set_of_common_components = list(set.intersection(*component_names_for_each_pose)) + + common_points = {} + for component_name in set_of_common_components: + max_length = 0 + min_length = np.inf + points_for_each_pose = [] + for point_dict in point_dict_for_each_pose: + points_list = point_dict.get(component_name) + if points_list is None: + min_length =0 + max_length = max(max_length, len(points_list)) + min_length = min(min_length, len(points_list)) + points_for_each_pose.append(set(points_list)) + set_of_common_points = list(set.intersection(*points_for_each_pose)) + + if min_length < max_length and min_length>0: + common_points[component_name] = set_of_common_points + + + + poses = [pose.get_components(set_of_common_components, common_points) for pose in poses] + return poses + +def zero_pad_shorter_poses(poses:Iterable[Pose]) -> List[Pose]: + poses = [copy_pose(pose) for pose in poses] + # arrays = [pose.body.data for pose in poses] + + + # first dimension is frames. Then People, joint-points, XYZ or XY + max_frame_count = max(len(pose.body.data) for pose in poses) + # Pad the shorter array with zeros + for pose in poses: + if len(pose.body.data) < max_frame_count: + desired_shape = list(pose.body.data.shape) + desired_shape[0] = max_frame_count - len(pose.body.data) + padding_tensor = np.ma.zeros(desired_shape) + padding_tensor_conf = np.ones(desired_shape[:-1]) + pose.body.data = np.ma.concatenate([pose.body.data, padding_tensor], axis=0) + pose.body.confidence = np.concatenate([pose.body.confidence, padding_tensor_conf]) + return poses + + + +# def preprocess_poses( +# poses: List[Pose], +# normalize_poses: bool = True, +# reduce_poses_to_common_points: bool = False, +# remove_legs: bool = True, +# remove_world_landmarks: bool = False, +# conf_threshold_to_drop_points: None | float = None, +# zero_pad_shorter_pose = True, +# ) -> List[Pose]: +# for pose in poses: +# assert np.count_nonzero(np.isnan(pose.body.data)) == 0 +# # NOTE: this is a lot of arguments. Perhaps a list may be better? +# if reduce_poses_to_common_points: + +# poses = reduce_pose_components_and_points_to_intersection(poses) + +# poses = [ +# preprocess_pose( +# pose, +# normalize_poses=normalize_poses, +# remove_legs=remove_legs, +# remove_world_landmarks=remove_world_landmarks, +# conf_threshold_to_drop_points=conf_threshold_to_drop_points, +# ) +# for pose in poses +# ] +# for pose in poses: +# assert np.count_nonzero(np.isnan(pose.body.data)) == 0 + +# if zero_pad_shorter_pose: +# poses = zero_pad_shorter_poses(poses) +# return poses + + + +def set_masked_to_origin_position(pose:Pose)->Pose: + pose = copy_pose(pose) + # frames, person, keypoint, xyz + + pose.body.data = np.ma.array(pose.body.data.filled(0), mask=False) + + return pose + + +# def pre_align_with_dtw(hyp: Pose, ref:Pose): + +# x = hyp_trajectory +# y = ref_trajectory +# _, path = fastdtw(x.data, y.data) # Use the raw data for DTW computation + +# # Initialize lists for aligned data and masks +# aligned_x_data = [] +# aligned_y_data = [] + +# aligned_x_mask = [] +# aligned_y_mask = [] + +# # Loop through the DTW path +# for xi, yi in path: +# # Append aligned data +# aligned_x_data.append(x.data[xi]) +# aligned_y_data.append(y.data[yi]) + +# # Append aligned masks (directly use .mask) +# aligned_x_mask.append(x.mask[xi]) +# aligned_y_mask.append(y.mask[yi]) + +# # Create aligned masked arrays +# aligned_x = np.ma.array(aligned_x_data, mask=aligned_x_mask) +# aligned_y = np.ma.array(aligned_y_data, mask=aligned_y_mask) +# return aligned_x, aligned_y + +# def preprocess_pose( +# pose: Pose, +# normalize_poses: bool = True, +# remove_legs: bool = True, +# remove_world_landmarks: bool = False, +# conf_threshold_to_drop_points: None | float = None, +# ) -> Pose: +# assert np.count_nonzero(np.isnan(pose.body.data)) == 0 +# pose = copy_pose(pose) +# if normalize_poses: +# # note: latest version (not yet released) does it automatically +# pose = pose.normalize(pose_normalization_info(pose.header)) +# # TODO: https://github.com/sign-language-processing/pose/issues/146 + +# # Drop legs +# if remove_legs: +# try: +# pose = pose_remove_legs(pose) +# except NotImplementedError as e: +# print(f"Could not remove legs: {e}") +# # raise Warning(f"Could not remove legs: {e}") + +# # not used, typically. +# if remove_world_landmarks: +# pose = pose_remove_world_landmarks(pose) +# assert np.count_nonzero(np.isnan(pose.body.data)) == 0 + +# # hide low conf +# if conf_threshold_to_drop_points is not None: +# pose_hide_low_conf(pose, confidence_threshold=conf_threshold_to_drop_points) +# assert np.count_nonzero(np.isnan(pose.body.data)) == 0 + +# return pose + + +def pose_hide_low_conf(pose: Pose, confidence_threshold: float = 0.2) -> None: + mask = pose.body.confidence <= confidence_threshold + pose.body.confidence[mask] = 0 + stacked_confidence = np.stack([mask, mask, mask], axis=3) + masked_data = np.ma.masked_array(pose.body.data, mask=stacked_confidence) + pose.body.data = masked_data diff --git a/pose_evaluation/utils/test/test_data/colin_20240904_12_20_56-SAD.pose b/pose_evaluation/utils/test/test_data/colin_20240904_12_20_56-SAD.pose new file mode 100644 index 0000000..f00fc7f Binary files /dev/null and b/pose_evaluation/utils/test/test_data/colin_20240904_12_20_56-SAD.pose differ diff --git a/pose_evaluation/utils/test/test_data/colin_20240904_12_24_12-HOUSE.pose b/pose_evaluation/utils/test/test_data/colin_20240904_12_24_12-HOUSE.pose new file mode 100644 index 0000000..8f74698 Binary files /dev/null and b/pose_evaluation/utils/test/test_data/colin_20240904_12_24_12-HOUSE.pose differ diff --git a/pose_evaluation/utils/test/test_data/colin_20240904_12_56_42-HOUSE.pose b/pose_evaluation/utils/test/test_data/colin_20240904_12_56_42-HOUSE.pose new file mode 100644 index 0000000..67d12e3 Binary files /dev/null and b/pose_evaluation/utils/test/test_data/colin_20240904_12_56_42-HOUSE.pose differ diff --git a/pose_evaluation/utils/test/test_data/mediapipe_components_and_points.json b/pose_evaluation/utils/test/test_data/mediapipe_components_and_points.json new file mode 100644 index 0000000..d42dddc --- /dev/null +++ b/pose_evaluation/utils/test/test_data/mediapipe_components_and_points.json @@ -0,0 +1,588 @@ +{ + "POSE_LANDMARKS": [ + "NOSE", + "LEFT_EYE_INNER", + "LEFT_EYE", + "LEFT_EYE_OUTER", + "RIGHT_EYE_INNER", + "RIGHT_EYE", + "RIGHT_EYE_OUTER", + "LEFT_EAR", + "RIGHT_EAR", + "MOUTH_LEFT", + "MOUTH_RIGHT", + "LEFT_SHOULDER", + "RIGHT_SHOULDER", + "LEFT_ELBOW", + "RIGHT_ELBOW", + "LEFT_WRIST", + "RIGHT_WRIST", + "LEFT_PINKY", + "RIGHT_PINKY", + "LEFT_INDEX", + "RIGHT_INDEX", + "LEFT_THUMB", + "RIGHT_THUMB", + "LEFT_HIP", + "RIGHT_HIP", + "LEFT_KNEE", + "RIGHT_KNEE", + "LEFT_ANKLE", + "RIGHT_ANKLE", + "LEFT_HEEL", + "RIGHT_HEEL", + "LEFT_FOOT_INDEX", + "RIGHT_FOOT_INDEX" + ], + "FACE_LANDMARKS": [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "10", + "11", + "12", + "13", + "14", + "15", + "16", + "17", + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + "26", + "27", + "28", + "29", + "30", + "31", + "32", + "33", + "34", + "35", + "36", + "37", + "38", + "39", + "40", + "41", + "42", + "43", + "44", + "45", + "46", + "47", + "48", + "49", + "50", + "51", + "52", + "53", + "54", + "55", + "56", + "57", + "58", + "59", + "60", + "61", + "62", + "63", + "64", + "65", + "66", + "67", + "68", + "69", + "70", + "71", + "72", + "73", + "74", + "75", + "76", + "77", + "78", + "79", + "80", + "81", + "82", + "83", + "84", + "85", + "86", + "87", + "88", + "89", + "90", + "91", + "92", + "93", + "94", + "95", + "96", + "97", + "98", + "99", + "100", + "101", + "102", + "103", + "104", + "105", + "106", + "107", + "108", + "109", + "110", + "111", + "112", + "113", + "114", + "115", + "116", + "117", + "118", + "119", + "120", + "121", + "122", + "123", + "124", + "125", + "126", + "127", + "128", + "129", + "130", + "131", + "132", + "133", + "134", + "135", + "136", + "137", + "138", + "139", + "140", + "141", + "142", + "143", + "144", + "145", + "146", + "147", + "148", + "149", + "150", + "151", + "152", + "153", + "154", + "155", + "156", + "157", + "158", + "159", + "160", + "161", + "162", + "163", + "164", + "165", + "166", + "167", + "168", + "169", + "170", + "171", + "172", + "173", + "174", + "175", + "176", + "177", + "178", + "179", + "180", + "181", + "182", + "183", + "184", + "185", + "186", + "187", + "188", + "189", + "190", + "191", + "192", + "193", + "194", + "195", + "196", + "197", + "198", + "199", + "200", + "201", + "202", + "203", + "204", + "205", + "206", + "207", + "208", + "209", + "210", + "211", + "212", + "213", + "214", + "215", + "216", + "217", + "218", + "219", + "220", + "221", + "222", + "223", + "224", + "225", + "226", + "227", + "228", + "229", + "230", + "231", + "232", + "233", + "234", + "235", + "236", + "237", + "238", + "239", + "240", + "241", + "242", + "243", + "244", + "245", + "246", + "247", + "248", + "249", + "250", + "251", + "252", + "253", + "254", + "255", + "256", + "257", + "258", + "259", + "260", + "261", + "262", + "263", + "264", + "265", + "266", + "267", + "268", + "269", + "270", + "271", + "272", + "273", + "274", + "275", + "276", + "277", + "278", + "279", + "280", + "281", + "282", + "283", + "284", + "285", + "286", + "287", + "288", + "289", + "290", + "291", + "292", + "293", + "294", + "295", + "296", + "297", + "298", + "299", + "300", + "301", + "302", + "303", + "304", + "305", + "306", + "307", + "308", + "309", + "310", + "311", + "312", + "313", + "314", + "315", + "316", + "317", + "318", + "319", + "320", + "321", + "322", + "323", + "324", + "325", + "326", + "327", + "328", + "329", + "330", + "331", + "332", + "333", + "334", + "335", + "336", + "337", + "338", + "339", + "340", + "341", + "342", + "343", + "344", + "345", + "346", + "347", + "348", + "349", + "350", + "351", + "352", + "353", + "354", + "355", + "356", + "357", + "358", + "359", + "360", + "361", + "362", + "363", + "364", + "365", + "366", + "367", + "368", + "369", + "370", + "371", + "372", + "373", + "374", + "375", + "376", + "377", + "378", + "379", + "380", + "381", + "382", + "383", + "384", + "385", + "386", + "387", + "388", + "389", + "390", + "391", + "392", + "393", + "394", + "395", + "396", + "397", + "398", + "399", + "400", + "401", + "402", + "403", + "404", + "405", + "406", + "407", + "408", + "409", + "410", + "411", + "412", + "413", + "414", + "415", + "416", + "417", + "418", + "419", + "420", + "421", + "422", + "423", + "424", + "425", + "426", + "427", + "428", + "429", + "430", + "431", + "432", + "433", + "434", + "435", + "436", + "437", + "438", + "439", + "440", + "441", + "442", + "443", + "444", + "445", + "446", + "447", + "448", + "449", + "450", + "451", + "452", + "453", + "454", + "455", + "456", + "457", + "458", + "459", + "460", + "461", + "462", + "463", + "464", + "465", + "466", + "467" + ], + "LEFT_HAND_LANDMARKS": [ + "WRIST", + "THUMB_CMC", + "THUMB_MCP", + "THUMB_IP", + "THUMB_TIP", + "INDEX_FINGER_MCP", + "INDEX_FINGER_PIP", + "INDEX_FINGER_DIP", + "INDEX_FINGER_TIP", + "MIDDLE_FINGER_MCP", + "MIDDLE_FINGER_PIP", + "MIDDLE_FINGER_DIP", + "MIDDLE_FINGER_TIP", + "RING_FINGER_MCP", + "RING_FINGER_PIP", + "RING_FINGER_DIP", + "RING_FINGER_TIP", + "PINKY_MCP", + "PINKY_PIP", + "PINKY_DIP", + "PINKY_TIP" + ], + "RIGHT_HAND_LANDMARKS": [ + "WRIST", + "THUMB_CMC", + "THUMB_MCP", + "THUMB_IP", + "THUMB_TIP", + "INDEX_FINGER_MCP", + "INDEX_FINGER_PIP", + "INDEX_FINGER_DIP", + "INDEX_FINGER_TIP", + "MIDDLE_FINGER_MCP", + "MIDDLE_FINGER_PIP", + "MIDDLE_FINGER_DIP", + "MIDDLE_FINGER_TIP", + "RING_FINGER_MCP", + "RING_FINGER_PIP", + "RING_FINGER_DIP", + "RING_FINGER_TIP", + "PINKY_MCP", + "PINKY_PIP", + "PINKY_DIP", + "PINKY_TIP" + ], + "POSE_WORLD_LANDMARKS": [ + "NOSE", + "LEFT_EYE_INNER", + "LEFT_EYE", + "LEFT_EYE_OUTER", + "RIGHT_EYE_INNER", + "RIGHT_EYE", + "RIGHT_EYE_OUTER", + "LEFT_EAR", + "RIGHT_EAR", + "MOUTH_LEFT", + "MOUTH_RIGHT", + "LEFT_SHOULDER", + "RIGHT_SHOULDER", + "LEFT_ELBOW", + "RIGHT_ELBOW", + "LEFT_WRIST", + "RIGHT_WRIST", + "LEFT_PINKY", + "RIGHT_PINKY", + "LEFT_INDEX", + "RIGHT_INDEX", + "LEFT_THUMB", + "RIGHT_THUMB", + "LEFT_HIP", + "RIGHT_HIP", + "LEFT_KNEE", + "RIGHT_KNEE", + "LEFT_ANKLE", + "RIGHT_ANKLE", + "LEFT_HEEL", + "RIGHT_HEEL", + "LEFT_FOOT_INDEX", + "RIGHT_FOOT_INDEX" + ] +} \ No newline at end of file diff --git a/pose_evaluation/utils/test_pose_utils.py b/pose_evaluation/utils/test_pose_utils.py new file mode 100644 index 0000000..03d4806 --- /dev/null +++ b/pose_evaluation/utils/test_pose_utils.py @@ -0,0 +1,324 @@ +import copy +import numpy as np +from typing import List, Dict, Tuple +import pytest +from pathlib import Path +from pose_format import Pose +from pose_format.utils.generic import pose_hide_legs +from pose_evaluation.utils.pose_utils import ( + load_pose_file, + pose_remove_world_landmarks, + remove_components, + pose_remove_legs, + pose_hide_low_conf, + copy_pose, + get_face_and_hands_from_pose, + reduce_pose_components_and_points_to_intersection, + # preprocess_pose, + get_component_names_and_points_dict, + # preprocess_poses, + detect_format, + zero_pad_shorter_poses, + set_masked_to_origin_position +) + + +def test_load_poses_mediapipe( + test_mediapipe_poses_paths: List[Path], + standard_mediapipe_components_dict: Dict[str, List[str]], +): + + poses = [load_pose_file(pose_path) for pose_path in test_mediapipe_poses_paths] + + assert len(poses) == 3 + + for pose in poses: + # do they all have headers? + assert pose.header is not None + + # check if the expected components are there. + for component in pose.header.components: + # should have specific expected components + assert component.name in standard_mediapipe_components_dict + + # should have specific expected points + assert sorted(component.points) == sorted( + standard_mediapipe_components_dict[component.name] + ) + + # checking the data: + # Frames, People, Points, Dims + assert pose.body.data.ndim == 4 + + # all frames have the standard shape? + assert all(frame.shape == (1, 576, 3) for frame in pose.body.data) + + +def test_remove_specific_landmarks_mediapipe( + test_mediapipe_poses: List[Pose], + standard_mediapipe_components_dict: Dict[str, List[str]], +): + for pose in test_mediapipe_poses: + component_count = len(pose.header.components) + assert component_count == len(standard_mediapipe_components_dict.keys()) + for component_name in standard_mediapipe_components_dict.keys(): + pose_with_component_removed = remove_components(pose, [str(component_name)]) + assert component_name not in pose_with_component_removed.header.components + assert ( + len(pose_with_component_removed.header.components) + == component_count - 1 + ) + + +def test_pose_copy(test_mediapipe_poses: List[Pose]): + for pose in test_mediapipe_poses: + copy = copy_pose(pose) + + assert copy != pose # Not the same object + assert ( + pose.header.components != copy.header.components + ) # header is also not the same object + assert pose.body != copy.body # also not the same + assert np.array_equal( + copy.body.data, pose.body.data + ) # the data should have the same values + + assert sorted([c.name for c in pose.header.components]) == sorted( + [c.name for c in copy.header.components] + ) # same components + assert ( + copy.header.total_points() == pose.header.total_points() + ) # same number of points + + +def test_pose_remove_legs(test_mediapipe_poses: List[Pose]): + points_that_should_be_hidden = ["KNEE", "HEEL", "FOOT", "TOE"] + for pose in test_mediapipe_poses: + # pose_hide_legs(pose) + pose = pose_remove_legs(pose) + + for component in pose.header.components: + point_names = [point.upper() for point in component.points] + for point_name in point_names: + for point_that_should_be_hidden in points_that_should_be_hidden: + assert point_that_should_be_hidden not in point_name + + +def test_pose_remove_legs_openpose(fake_openpose_poses): + for pose in fake_openpose_poses: + with pytest.raises(NotImplementedError): + pose_remove_legs(pose) + + +def test_reduce_pose_components_to_intersection( + test_mediapipe_poses: List[Pose], + standard_mediapipe_components_dict: Dict[str, List[str]], +): + + test_poses_with_one_reduced = [copy_pose(pose) for pose in test_mediapipe_poses] + + + pose_with_only_face_and_hands_and_no_wrist = get_face_and_hands_from_pose( + test_poses_with_one_reduced.pop() + ) + + c_names, p_dict = get_component_names_and_points_dict(pose_with_only_face_and_hands_and_no_wrist) + + new_p_dict = {} + for c_name, p_list in p_dict.items(): + new_p_dict[c_name] = [point_name for point_name in p_list if "WRIST" not in point_name] + + + pose_with_only_face_and_hands_and_no_wrist = pose_with_only_face_and_hands_and_no_wrist.get_components(c_names, new_p_dict) + + test_poses_with_one_reduced.append(pose_with_only_face_and_hands_and_no_wrist) + assert len(test_mediapipe_poses) == len(test_poses_with_one_reduced) + + original_component_count = len( + standard_mediapipe_components_dict.keys() + ) # 5, at time of writing + + target_component_count = 3 # face, left hand, right hand + assert ( + len(pose_with_only_face_and_hands_and_no_wrist.header.components) == target_component_count + ) + + target_point_count = pose_with_only_face_and_hands_and_no_wrist.header.total_points() + + + + reduced_poses = reduce_pose_components_and_points_to_intersection(test_poses_with_one_reduced) + for reduced_pose in reduced_poses: + assert len(reduced_pose.header.components) == target_component_count + assert reduced_pose.header.total_points() == target_point_count + + # check if the originals are unaffected + assert all( + [ + len(pose.header.components) == original_component_count + for pose in test_mediapipe_poses + ] + ) + + + +def test_remove_world_landmarks(test_mediapipe_poses: List[Pose]): + for pose in test_mediapipe_poses: + component_names = [c.name for c in pose.header.components] + starting_component_count = len(pose.header.components) + assert "POSE_WORLD_LANDMARKS" in component_names + + pose = pose_remove_world_landmarks(pose) + component_names = [c.name for c in pose.header.components] + assert "POSE_WORLD_LANDMARKS" not in component_names + ending_component_count = len(pose.header.components) + + assert ending_component_count == starting_component_count - 1 + + +def test_remove_one_point_and_one_component(test_mediapipe_poses: List[Pose]): + component_to_drop = "POSE_WORLD_LANDMARKS" + point_to_drop = "LEFT_KNEE" + for pose in test_mediapipe_poses: + original_component_names, original_points_dict = ( + get_component_names_and_points_dict(pose) + ) + + assert component_to_drop in original_component_names + assert point_to_drop in original_points_dict["POSE_LANDMARKS"] + reduced_pose = remove_components(pose, component_to_drop, point_to_drop) + new_component_names, new_points_dict = get_component_names_and_points_dict( + reduced_pose + ) + assert component_to_drop not in new_component_names + assert point_to_drop not in new_points_dict["POSE_LANDMARKS"] + + +def test_detect_format( + fake_openpose_poses, fake_openpose_135_poses, test_mediapipe_poses +): + for pose in fake_openpose_poses: + assert detect_format(pose) == "openpose" + + for pose in fake_openpose_135_poses: + assert detect_format(pose) == "openpose_135" + + for pose in test_mediapipe_poses: + assert detect_format(pose) == "mediapipe" + + for pose in test_mediapipe_poses: + unsupported_component_name = "UNSUPPORTED" + pose.header.components[0].name = unsupported_component_name + pose = pose.get_components(["UNSUPPORTED"]) + component_names, _ = get_component_names_and_points_dict(pose) + assert len(pose.header.components) == 1 + + # pose.header.components[0] + # assert pose.header.components[0] == changing_component + with pytest.raises( + ValueError, match="Unknown pose header schema with component names" + ): + detect_format(pose) + + +# def test_preprocess_pose(test_mediapipe_poses_paths: List[Path]): +# poses = [load_pose_file(pose_path) for pose_path in test_mediapipe_poses_paths] +# preprocessed_poses = [] + +# data_arrays = [pose.body.data for pose in poses] + +# for pose in poses: +# processed_pose = preprocess_pose(pose, +# normalize_poses=True, +# remove_legs=True, +# remove_world_landmarks=True, +# conf_threshold_to_drop_points=0.2) +# #TODO: check expected result + + +# def test_preprocess_poses(test_mediapipe_poses: List[Pose]): + +# nan_counts = [np.count_nonzero(np.isnan(pose.body.data)) for pose in test_mediapipe_poses] + +# preprocessed_poses = preprocess_poses( +# test_mediapipe_poses, +# normalize_poses=True, +# reduce_poses_to_common_points=True, +# remove_world_landmarks=True, +# remove_legs=True, +# zero_pad_shorter_pose=True, +# conf_threshold_to_drop_points=0.2, +# ) + +# for i, pose in enumerate(preprocessed_poses): +# component_names, points_dict = get_component_names_and_points_dict(pose) +# assert "LEFT_KNEE" not in points_dict["POSE_LANDMARKS"] +# assert "POSE_WORLD_LANDMARKS" not in component_names + +# # zero-padded properly? Should all be the same number of frames +# assert pose.body.data.shape[0] == preprocessed_poses[0].body.data.shape[0] + +# # do we have nan values? +# nan_count = np.count_nonzero(np.isnan(pose.body.data)) +# assert np.count_nonzero(np.isnan(pose.body.data)) == nan_counts[i] +# assert nan_count == 0 + + +def test_set_masked_to_origin_pos(test_mediapipe_poses: List[Pose]): + # Create a copy of the original poses for comparison + originals = [copy_pose(pose) for pose in test_mediapipe_poses] + + # Apply the transformation + poses = [set_masked_to_origin_position(pose) for pose in test_mediapipe_poses] + + for original, transformed in zip(originals, poses): + # 1. Ensure the transformed data is still a MaskedArray + assert isinstance(transformed.body.data, np.ma.MaskedArray) + + # 2. Ensure the mask is now all False + assert np.all(transformed.body.data.mask == False) + + # 3. Check the shape matches the original + assert transformed.body.data.shape == original.body.data.shape + + # 4. Validate masked positions in the original are now zeros + assert np.all( + transformed.body.data.data[original.body.data.mask] == 0 + ) + + # 5. Validate unmasked positions in the original remain unchanged + assert np.all( + transformed.body.data.data[~original.body.data.mask] + == original.body.data.data[~original.body.data.mask] + ) + + + +def test_hide_low_conf(test_mediapipe_poses: List[Pose]): + copies = [copy_pose(pose) for pose in test_mediapipe_poses] + for pose, copy in zip(test_mediapipe_poses, copies): + pose_hide_low_conf(pose, 1.0) + + assert np.array_equal(pose.body.confidence, copy.body.confidence) == False + + +def test_zero_pad_shorter_poses(test_mediapipe_poses: List[Pose]): + copies = [copy_pose(pose) for pose in test_mediapipe_poses] + + max_len = max([len(pose.body.data) for pose in test_mediapipe_poses]) + padded_poses = zero_pad_shorter_poses(test_mediapipe_poses) + + for i, padded_pose in enumerate(padded_poses): + assert ( + test_mediapipe_poses[i] != padded_poses[i] + ) # shouldn't be the same object + old_length = len(copies[i].body.data) + new_length = len(padded_pose.body.data) + assert new_length == max_len + if old_length == new_length: + assert old_length == max_len + + # does the confidence match? + assert padded_pose.body.confidence.shape == padded_pose.body.data.shape[:-1] + + diff --git a/pyproject.toml b/pyproject.toml index 893fa3d..a74a052 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,8 @@ dependencies = [ "scipy", "torch", "numpy", # possibly could replace all with torch + "cython", # used to accelerate fastdtw + "fastdtw", # for various vector/tensor similarities and distances in torch "sentence-transformers", # For reading .csv files, etc