Pose distance metrics from Ham2Pose #4
Closed
cleong110 wants to merge 31 commits into sign-language-processing:main from cleong110:ham2pose_metrics
Changes from 8 of 31 commits
Commits (all by cleong110)
0c7bb56  CDL: copying code by @j22melody, as requested.
db9056e  Merge branch 'main' into ham2pose_metrics
e5f703f  CDL: add new req for fastdtw
a7ca062  Start ndtw_mje, add common functions for pose preprocessing, etc
2dee6ce  Edit the name of a test function to avoid potential collisions
04472e1  Stubbed test file
1f5767d  A bit of pylint cleanup
07225cb  Preprocessing for poses, and some type annotations, and a bit of refa…
e931966  adding tests for local pose_utils
7bc5371  some gitignore updates
4369938  Fixing a few type issues
bdf5d73  adding test data
99e27f9  remove instead of hide legs in pose_utils
ec09e3c  Take out temp test code
66361ca  fix forgetting to assign in preprocess_pose
389abe2  some minor fixes in tests
8344d54  euclidean, not l2
e9e8cc1  Transitioning to pytest from unittest
4334fb8  Trying to figure out pytest
92ec8e0  Caught another L2
251ad26  basic scoring script
0449eba  implement ape_metric
daedc1e  Very WIP, pushing code for the day
72cec6d  Starting the move to PoseProcessors
74e016e  adding in the set_masked_values_to_zero
d2d1759  Pushing all changes as-is
d8dd461  Cleaning u and implementing separate DTW and Distance Metrics
3517a76  I can build the basic Ham2Pose metrics!
688fcee  remove unused file
57f6932  remove unused alignment_strategy
a636406  Various pylint changes
Diff hunk (@@ -4,10 +4,10 @@):
from pose_format import Pose

from pose_evaluation.metrics.base_pose_metric import PoseMetric

ValidDistanceKinds = Literal["euclidean", "manhattan"]
Review comment (pseudocode suggestion):

class PowerDistance(PointwiseDistance):
    def __init__(self, power: int = 2, default_distance=0):
        self.power = power
        self.default_distance = default_distance

    def __call__(self, p1: MaskedArray, p2: MaskedArray):
        return (p1 - p2).pow(self.power).abs().filled(self.default_distance).mean()

L2Distance = PowerDistance(power=2)
L1Distance = PowerDistance(power=1)

class APEDistance(PointwiseDistance):
    def __call__(self, p1: MaskedArray, p2: MaskedArray):
        return (p1 - p2).pow(2).sum(-1).sqrt().filled(default_distance).mean()

# Example usage
DistanceMetric(distance=L2Distance, alignment_strategy="pad"|"truncate"|"by-reference")

Review comment (pseudocode suggestion for a DTW metric):

class DTWMetric(BaseMetric):
    def __init__(self, distance: Distance, trajectory: "keypoints"|"frames", processors=List[Processor]):
        self.__super__...

    def score(self, pose1: Pose, pose2: Pose):
        tensor1 = self.process(pose1)
        tensor2 = self.process(pose2)

        if self.trajectory == "keypoints":
            tensor1 = points_perspective(tensor1)
            tensor2 = points_perspective(tensor2)
            distance = 0
            for trajectory1, trajectory2 in zip(tensor1, tensor2):
                distance += fastdtw(metric=self.distance, seq1=trajectory1, seq2=trajectory2)
            return distance

        return fastdtw(metric=self.distance, seq1=tensor1, seq2=tensor2)

def PoseSetMaskedToOrigin(pose):
    return pose.filled(0)

dtw_metric = DTWMetric(
    distance=L2Distance(),
    trajectory="keypoints",
    processors=[
        PoseNormalize(),
        HideLegs(),
        PoseSetMaskedToOrigin(),
        ReduceToCommonKeypoints(),
    ],
)

# Usage signature: dtw(distance=str(distance), processors="|".join(processors))
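The reviewer's sketch above is pseudocode (PointwiseDistance, MaskedArray.pow(), and the alignment_strategy literal are not defined anywhere in this PR). As a reference point, here is a minimal runnable version of the PowerDistance idea using numpy masked arrays; the class name and the fill-with-default behavior follow the sketch, everything else is an assumption.

import numpy as np
from numpy.ma import MaskedArray


class PowerDistance:
    """Hypothetical pointwise distance: mean of |p1 - p2| ** power, with masked entries filled by a default."""

    def __init__(self, power: int = 2, default_distance: float = 0.0):
        self.power = power
        self.default_distance = default_distance

    def __call__(self, p1: MaskedArray, p2: MaskedArray) -> float:
        # abs() and ** preserve the mask; filled() replaces masked entries before averaging
        diff = abs(p1 - p2) ** self.power
        return float(diff.filled(self.default_distance).mean())


# Two (frames, keypoints, xyz) masked arrays with any NaNs masked out
hyp = np.ma.masked_invalid(np.random.rand(10, 5, 3))
ref = np.ma.masked_invalid(np.random.rand(10, 5, 3))
print(PowerDistance(power=2)(hyp, ref))  # L2-style score
print(PowerDistance(power=1)(hyp, ref))  # L1-style score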
class DistanceMetric(PoseMetric):
-   def __init__(self, kind: Literal["l1", "l2"] = "l2"):
+   def __init__(self, kind: ValidDistanceKinds = "euclidean"):
        super().__init__(f"DistanceMetric {kind}", higher_is_better=False)
        self.kind = kind
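For context, a hedged usage sketch of the updated DistanceMetric signature in this hunk; the score() call shape is inferred from the subclass added later in this PR and may not match the repo exactly.

from pose_evaluation.metrics.distance_metric import DistanceMetric

metric = DistanceMetric(kind="euclidean")  # default changed from "l2"; "manhattan" is the other valid kind
# Assumed call shape (not shown in this hunk):
# score = metric.score(hypothesis_pose, reference_pose)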
New file (@@ -0,0 +1,207 @@):
# Originally based on https://github.com/rotem-shalev/Ham2Pose/blob/main/metrics.py,
# then adapted for MediaPipe Holistic format by @J22Melody, in another (private) repo
# and then code was copied to this repo by @cleong110

import os
from pathlib import Path
import numpy as np
from scipy.spatial.distance import euclidean
from fastdtw import fastdtw

from pose_evaluation.utils.pose_utils import preprocess_pose, load_pose_file, pose_hide_low_conf

# rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# sys.path.insert(0, rootdir)

# from pose_format.pose_header import PoseHeader
# from pose_format.utils.reader import BufferReader
# from pose_utils import pose_normalization_info, pose_hide_legs, pose_hide_low_conf
# from data.tfds_dataset import flip_pose
# from data.hamnosys.hamnosys import get_pose
# from predict import predict_pose

# PJM_FRAME_WIDTH = 1280
# with open("data/pjm_left_videos.json", 'r') as f:
#     PJM_LEFT_VIDEOS_LST = json.load(f)

# dataset_module = importlib.import_module(f"data.hamnosys.hamnosys")

# with open(dataset_module._POSE_HEADERS["openpose"], "rb") as buffer:
#     pose_header = PoseHeader.read(BufferReader(buffer.read()))
def masked_euclidean(point1, point2):
    if np.ma.is_masked(point2):  # reference label keypoint is missing
        return 0
    elif np.ma.is_masked(point1):  # reference label keypoint is not missing, other label keypoint is missing
        return euclidean((0, 0, 0), point2) / 2
    d = euclidean(point1, point2)
    return d
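To make the three branches of masked_euclidean concrete, a small worked example with hypothetical points (assuming the function above is in scope): a fully masked reference point contributes 0, a masked hypothesis point is penalized by half the reference's distance from the origin, and fully observed points use the plain Euclidean distance.

import numpy as np

ref = np.ma.masked_array([3.0, 4.0, 0.0])                  # observed reference keypoint
hyp = np.ma.masked_array([0.0, 0.0, 0.0])                  # observed hypothesis keypoint
missing = np.ma.masked_array([0.0, 0.0, 0.0], mask=True)   # fully masked keypoint

print(masked_euclidean(hyp, ref))      # 5.0 -> plain Euclidean distance
print(masked_euclidean(missing, ref))  # 2.5 -> half the reference's distance from the origin
print(masked_euclidean(hyp, missing))  # 0   -> masked reference keypoints are skipped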
def masked_mse(trajectory1, trajectory2, confidence):
    if len(trajectory1) < len(trajectory2):
        diff = len(trajectory2) - len(trajectory1)
        trajectory1 = np.concatenate((trajectory1, np.zeros((diff, 3))))
        confidence = np.concatenate((confidence, np.zeros((diff))))
    elif len(trajectory2) < len(trajectory1):
        trajectory2 = np.concatenate((trajectory2, np.zeros((len(trajectory1) - len(trajectory2), 3))))
    sq_error = np.power(trajectory1 - trajectory2, 2).sum(-1)
    return (sq_error * confidence).mean()


def mse(trajectory1, trajectory2):
    if len(trajectory1) < len(trajectory2):
        diff = len(trajectory2) - len(trajectory1)
        trajectory1 = np.concatenate((trajectory1, np.zeros((diff, 3))))
    elif len(trajectory2) < len(trajectory1):
        trajectory2 = np.concatenate((trajectory2, np.zeros((len(trajectory1) - len(trajectory2), 3))))
    pose1_mask = np.ma.getmask(trajectory1)
    pose2_mask = np.ma.getmask(trajectory2)
    trajectory1[pose1_mask] = 0
    trajectory1[pose2_mask] = 0
    trajectory2[pose1_mask] = 0
    trajectory2[pose2_mask] = 0
    sq_error = np.power(trajectory1 - trajectory2, 2).sum(-1)
    return sq_error.mean()
def masked_APE(trajectory1, trajectory2, confidence):
    if len(trajectory1) < len(trajectory2):
        diff = len(trajectory2) - len(trajectory1)
        trajectory1 = np.concatenate((trajectory1, np.zeros((diff, 3))))
        confidence = np.concatenate((confidence, np.zeros((diff))))
    elif len(trajectory2) < len(trajectory1):
        trajectory2 = np.concatenate((trajectory2, np.zeros((len(trajectory1) - len(trajectory2), 3))))
    sq_error = np.power(trajectory1 - trajectory2, 2).sum(-1)
    return np.sqrt(sq_error * confidence).mean()


def APE(trajectory1, trajectory2):
    if len(trajectory1) < len(trajectory2):
        diff = len(trajectory2) - len(trajectory1)
        trajectory1 = np.concatenate((trajectory1, np.zeros((diff, 3))))
    elif len(trajectory2) < len(trajectory1):
        trajectory2 = np.concatenate((trajectory2, np.zeros((len(trajectory1) - len(trajectory2), 3))))
    pose1_mask = np.ma.getmask(trajectory1)
    pose2_mask = np.ma.getmask(trajectory2)
    trajectory1[pose1_mask] = 0
    trajectory1[pose2_mask] = 0
    trajectory2[pose1_mask] = 0
    trajectory2[pose2_mask] = 0
    sq_error = np.power(trajectory1 - trajectory2, 2).sum(-1)
    return np.sqrt(sq_error).mean()
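For intuition about the length handling shared by masked_mse, mse, masked_APE, and APE above, a small standalone sketch with made-up trajectories: the shorter trajectory is zero-padded to the longer one's length before per-frame errors are averaged.

import numpy as np

traj1 = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])                   # 2 frames
traj2 = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])  # 3 frames

# Pad the shorter trajectory with zero frames, as mse()/APE() do
pad = len(traj2) - len(traj1)
traj1 = np.concatenate((traj1, np.zeros((pad, 3))))

sq_error = np.power(traj1 - traj2, 2).sum(-1)  # per-frame squared error: [0., 0., 12.]
print(sq_error.mean())                         # mse-style score: 4.0
print(np.sqrt(sq_error).mean())                # APE-style score: ~1.15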
def compare_pose_videos(pose1_id, pose2_id, keypoints_path, distance_function=fastdtw):
    pose1 = load_pose_file(Path(keypoints_path / pose1_id), pose1_id)
    pose1 = preprocess_pose(os.path.join(keypoints_path, pose1_id), pose1_id)
    pose2 = preprocess_pose(os.path.join(keypoints_path, pose2_id), pose2_id)
    pose2 = load_pose_file(Path(keypoints_path / pose2_id), pose2_id)
    return compare_poses(pose1, pose2, distance_function=distance_function)


def get_idx2weight(max_idx):
    idx2weight = {i: 1 for i in range(9)}
    idx2weight.update({i: 1 for i in range(95, max_idx)})
    return idx2weight


def get_idx2weight_mediapipe(pose):
    # TODO: weights
    idx2weight = {i: 1 for i in range(pose.body.data.shape[2])}
    return idx2weight
def compare_poses(pose1, pose2, distance_function='nfastdtw'):
    # reduce pose2 to the set of keypoints of pose1 (hypothesis)
    pose_components = [c.name for c in pose1.header.components]
    pose_points = {c.name: c.points for c in pose1.header.components}
    pose2 = pose2.get_components(pose_components, pose_points)

    pose_hide_low_conf(pose1)
    pose_hide_low_conf(pose2)

    # poses_data = get_pose_data([pose1, pose2])
    poses_data = [pose1.body.data, pose2.body.data]

    total_distance = 0
    idx2weight = get_idx2weight_mediapipe(pose1)

    for keypoint_idx, weight in idx2weight.items():
        pose1_keypoint_trajectory = poses_data[0][:, :, keypoint_idx, :].squeeze(1)
        pose2_keypoint_trajectory = poses_data[1][:, :, keypoint_idx, :].squeeze(1)

        if distance_function in [mse, APE]:
            dist = distance_function(pose1_keypoint_trajectory, pose2_keypoint_trajectory)
        elif distance_function in [masked_mse, masked_APE]:
            dist = distance_function(pose1_keypoint_trajectory, pose2_keypoint_trajectory,
                                     pose1.body.confidence[:, :, keypoint_idx].squeeze(1))
        elif distance_function == fastdtw:
            dist = distance_function(pose1_keypoint_trajectory, pose2_keypoint_trajectory, dist=euclidean)[0]
        elif distance_function == 'nfastdtw':
            dist = fastdtw(pose1_keypoint_trajectory, pose2_keypoint_trajectory, dist=masked_euclidean)[0]
        total_distance += dist * weight
    return total_distance / len(idx2weight)
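As a self-contained illustration of what compare_poses does with the default 'nfastdtw' option, here is a hedged sketch on synthetic data: each keypoint's trajectory is aligned with fastdtw under a mask-aware distance, and the per-keypoint DTW costs are averaged with uniform weights. Shapes, values, and the demo distance function are assumptions, not code from the PR.

import numpy as np
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean


def demo_masked_euclidean(p1, p2):
    # Same policy as masked_euclidean above: skip masked reference points,
    # penalize masked hypothesis points by half the reference's norm.
    if np.ma.is_masked(p2):
        return 0
    if np.ma.is_masked(p1):
        return euclidean((0, 0, 0), p2) / 2
    return euclidean(p1, p2)


n_frames_hyp, n_frames_ref, n_keypoints = 12, 15, 4
hyp = np.ma.masked_invalid(np.random.rand(n_frames_hyp, n_keypoints, 3))
ref = np.ma.masked_invalid(np.random.rand(n_frames_ref, n_keypoints, 3))

total = 0.0
for k in range(n_keypoints):
    # DTW-align this keypoint's trajectory across the two sequences
    cost, _path = fastdtw(hyp[:, k, :], ref[:, k, :], dist=demo_masked_euclidean)
    total += cost  # uniform weights, mirroring get_idx2weight_mediapipe
print(total / n_keypoints)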
def get_pose_data(poses):
    # return relevant pose data for trajectory distance computations - only upper body and hands
    poses_data = []
    for pose in poses:
        # Note: pose format shape is typically frames, persons (almost always 1), and then XYZ components/points.
        # The following is therefore selecting specific points via index.
        poses_data.append(np.ma.concatenate([pose.body.data[:, :, :95],
                                             pose.body.data[:, :, 95:116],
                                             pose.body.data[:, :, 116:]], axis=2))
    return poses_data
# def __compare_pred_to_video(pred, keypoints_path, pose_id, distance_function=fastdtw):
#     label_pose = get_pose(os.path.join(keypoints_path, pose_id), pose_id)
#     return compare_poses(pred, label_pose, distance_function=distance_function)


# def check_ranks(distances, index):
#     rank_1 = (index == distances[0])
#     rank_5 = (index in distances[:5])
#     rank_10 = (index in distances)
#     return rank_1, rank_5, rank_10


# def get_poses_ranks(pred, pred_id, keypoints_path, data_ids, distance_function=fastdtw, num_samples=20,
#                     model=None, pose_header=None, ds=None):
#     pred2label_distance = __compare_pred_to_video(pred, keypoints_path, pred_id, distance_function=distance_function)

#     distances_to_label = [pred2label_distance]
#     distances_to_pred = [pred2label_distance]
#     pred2label_index = 0

#     if model is not None:
#         indices = random.sample(range(len(ds)), num_samples)
#         for idx in indices:
#             if ds[idx]["id"] == pred_id:
#                 continue
#             cur_pred = predict_pose(model, ds[idx], pose_header)
#             distances_to_label.append(__compare_pred_to_video(cur_pred, keypoints_path, pred_id,
#                                                               distance_function=distance_function))
#             distances_to_pred.append(compare_poses(pred, cur_pred, distance_function=distance_function))

#     pose_ids = random.sample(data_ids, num_samples)
#     for pose_id in pose_ids:
#         distances_to_label.append(compare_pose_videos(pose_id, pred_id, keypoints_path,
#                                                       distance_function=distance_function))
#         distances_to_pred.append(__compare_pred_to_video(pred, keypoints_path, pose_id,
#                                                          distance_function=distance_function))

#     best_pred = np.argsort(distances_to_pred)[:10]
#     rank_1_pred, rank_5_pred, rank_10_pred = check_ranks(best_pred, pred2label_index)
#     best_label = np.argsort(distances_to_label)[:10]
#     rank_1_label, rank_5_label, rank_10_label = check_ranks(best_label, pred2label_index)

#     return pred2label_distance, rank_1_pred, rank_5_pred, rank_10_pred, rank_1_label, rank_5_label, rank_10_label
New file (@@ -0,0 +1,31 @@):
from typing import Literal, List

from pose_format import Pose

from pose_evaluation.metrics.distance_metric import DistanceMetric, ValidDistanceKinds
from pose_evaluation.utils.pose_utils import pose_hide_low_conf, preprocess_pose

class DynamicTimeWarpingMeanJointError(DistanceMetric):
    def __init__(self, kind: ValidDistanceKinds = "euclidean",
                 normalize_poses: bool = True,
                 reduce_poses: bool = False,
                 remove_legs: bool = True,
                 remove_world_landmarks: bool = False,
                 conf_threshold_to_drop_points: None | int = None,
                 ):
        super().__init__(kind)

        self.normalize_poses = normalize_poses
        self.reduce_reference_poses = reduce_poses
        self.remove_legs = remove_legs
        self.remove_world_landmarks = remove_world_landmarks
        self.conf_threshold_to_drop_points = conf_threshold_to_drop_points

    def score_all(self, hypotheses: List[Pose], references: List[Pose], progress_bar=True):
        # TODO:
        return super().score_all(hypotheses, references, progress_bar)

    def score(self, hypothesis: Pose, reference: Pose):
        # TODO
        return super().score(hypothesis, reference)
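A hedged instantiation sketch for the stub above (the module path is a guess); score() and score_all() are still TODO at this commit, so the commented calls only show the assumed shape.

# Module path below is a guess; adjust to wherever this class lives in the PR.
from pose_evaluation.metrics.dtw_metric import DynamicTimeWarpingMeanJointError

metric = DynamicTimeWarpingMeanJointError(
    kind="euclidean",
    normalize_poses=True,
    remove_legs=True,
    conf_threshold_to_drop_points=None,
)
# Once score()/score_all() are implemented, the calls would presumably look like:
# value = metric.score(hypothesis_pose, reference_pose)
# values = metric.score_all(hypotheses, references, progress_bar=True)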
New file (@@ -0,0 +1,4 @@):
from pose_evaluation.metrics.test_distance_metric import TestDistanceMetricGeneric
from pose_evaluation.utils.pose_utils import get_preprocessed_pose

# TODO