Skip to content

Commit 87f486d

Browse files
authored
Consistent code style (#20)
* fix(style): use webstorm + manual overview * feat(style): add black * feat(style): add black
1 parent 362466e commit 87f486d

16 files changed

+801
-775
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ and correlate to human judgments.
5656

5757
**TODO** list evaluation metrics here.
5858

59+
### Contributing
60+
61+
Please make sure to run `black pose_evaluation` before submitting a pull request.
62+
5963
## Cite
6064

6165
If you use our toolkit in your research or projects, please consider citing the work.

pose_evaluation/evaluation/evaluate_signclip.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import argparse
2-
from pathlib import Path
3-
import time
42
import json
53
import random
6-
import pandas as pd
4+
import time
5+
from pathlib import Path
6+
77
import numpy as np
8+
import pandas as pd
89
import torch
910
from tqdm import tqdm
11+
1012
from pose_evaluation.metrics.embedding_distance_metric import EmbeddingDistanceMetric
1113

14+
1215
def load_embedding(file_path: Path) -> np.ndarray:
1316
"""
1417
Load a SignCLIP embedding from a .npy file, ensuring it has the correct shape.
@@ -61,7 +64,10 @@ def get_embedding(video_file):
6164

6265

6366
def calculate_mean_distances(
64-
distance_matrix: torch.Tensor, indices_a: torch.Tensor, indices_b: torch.Tensor, exclude_self: bool = False
67+
distance_matrix: torch.Tensor,
68+
indices_a: torch.Tensor,
69+
indices_b: torch.Tensor,
70+
exclude_self: bool = False,
6571
) -> float:
6672
"""
6773
Calculate the mean of distances between two sets of indices in a 2D distance matrix.
@@ -92,7 +98,6 @@ def calculate_mean_distances(
9298

9399

94100
def generate_synthetic_data(num_items, num_classes, num_items_per_class=4):
95-
96101
torch.manual_seed(42)
97102
random.seed(42)
98103
# distance_matrix = torch.rand((num_items, num_items)) * 100
@@ -238,7 +243,7 @@ def evaluate_signclip(emb_dir: Path, split_file: Path, out_path: Path, kind: str
238243

239244
find_class_distances_end = time.perf_counter()
240245

241-
print(f"Finding within and without took {find_class_distances_end-find_class_distances_start}")
246+
print(f"Finding within and without took {find_class_distances_end - find_class_distances_start}")
242247

243248
analysis_end = time.perf_counter()
244249
analysis_duration = analysis_end - analysis_start
@@ -288,8 +293,17 @@ def evaluate_signclip(emb_dir: Path, split_file: Path, out_path: Path, kind: str
288293

289294
def main():
290295
parser = argparse.ArgumentParser(description="Evaluate SignCLIP embeddings with score_all.")
291-
parser.add_argument("emb_dir", type=Path, help="Path to the directory containing SignCLIP .npy files")
292-
parser.add_argument("--split_file", type=Path, required=True, help="Path to the split CSV file (e.g., test.csv)")
296+
parser.add_argument(
297+
"emb_dir",
298+
type=Path,
299+
help="Path to the directory containing SignCLIP .npy files",
300+
)
301+
parser.add_argument(
302+
"--split_file",
303+
type=Path,
304+
required=True,
305+
help="Path to the split CSV file (e.g., test.csv)",
306+
)
293307
parser.add_argument(
294308
"--kind",
295309
type=str,
@@ -298,7 +312,11 @@ def main():
298312
help="Type of distance metric to use (default: cosine)",
299313
)
300314

301-
parser.add_argument("--out_path", type=Path, help="Where to save output distance npz matrix+file list")
315+
parser.add_argument(
316+
"--out_path",
317+
type=Path,
318+
help="Where to save output distance npz matrix+file list",
319+
)
302320

303321
args = parser.parse_args()
304322

@@ -311,7 +329,12 @@ def main():
311329

312330
print(f"Scores will be saved to {output_file}")
313331

314-
evaluate_signclip(emb_dir=args.emb_dir, split_file=args.split_file, out_path=output_file, kind=args.kind)
332+
evaluate_signclip(
333+
emb_dir=args.emb_dir,
334+
split_file=args.split_file,
335+
out_path=output_file,
336+
kind=args.kind,
337+
)
315338

316339

317340
if __name__ == "__main__":

pose_evaluation/examples/example_metric_construction.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
from pathlib import Path
2+
23
from pose_format import Pose
3-
from pose_evaluation.metrics.distance_metric import DistanceMetric
4-
from pose_evaluation.metrics.distance_measure import AggregatedPowerDistance
4+
55
from pose_evaluation.metrics.base import BaseMetric
6+
from pose_evaluation.metrics.distance_measure import AggregatedPowerDistance
7+
from pose_evaluation.metrics.distance_metric import DistanceMetric
68
from pose_evaluation.metrics.test_distance_metric import get_poses
79
from pose_evaluation.utils.pose_utils import zero_pad_shorter_poses
810

911
if __name__ == "__main__":
1012
# Define file paths for test pose data
11-
reference_file = (
12-
Path("pose_evaluation") / "utils" / "test" / "test_data" / "colin-1-HOUSE.pose"
13-
)
14-
hypothesis_file = (
15-
Path("pose_evaluation") / "utils" / "test" / "test_data" / "colin-2-HOUSE.pose"
16-
)
13+
test_data_path = Path("pose_evaluation") / "utils" / "test" / "test_data"
14+
reference_file = test_data_path / "colin-1-HOUSE.pose"
15+
hypothesis_file = test_data_path / "colin-2-HOUSE.pose"
1716

1817
# Choose whether to load real files or generate test poses
1918
# They have different lengths, and so some metrics will crash!
@@ -33,25 +32,19 @@
3332
poses = [hypothesis, reference]
3433

3534
# Define distance metrics
36-
mean_l1_metric = DistanceMetric(
37-
"mean_l1_metric", distance_measure=AggregatedPowerDistance(1, 17)
38-
)
35+
mean_l1_metric = DistanceMetric("mean_l1_metric", distance_measure=AggregatedPowerDistance(1, 17))
3936
metrics = [
4037
BaseMetric("base"),
4138
DistanceMetric("PowerDistanceMetric", AggregatedPowerDistance(2, 1)),
4239
DistanceMetric("AnotherPowerDistanceMetric", AggregatedPowerDistance(1, 10)),
4340
mean_l1_metric,
4441
DistanceMetric(
4542
"max_l1_metric",
46-
AggregatedPowerDistance(
47-
order=1, aggregation_strategy="max", default_distance=0
48-
),
43+
AggregatedPowerDistance(order=1, aggregation_strategy="max", default_distance=0),
4944
),
5045
DistanceMetric(
5146
"MeanL2Score",
52-
AggregatedPowerDistance(
53-
order=2, aggregation_strategy="mean", default_distance=0
54-
),
47+
AggregatedPowerDistance(order=2, aggregation_strategy="mean", default_distance=0),
5548
),
5649
]
5750

pose_evaluation/metrics/base.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# pylint: disable=undefined-variable
22
from typing import Any, Callable, Sequence
3+
34
from tqdm import tqdm
45

56

@@ -83,9 +84,7 @@ def __call__(self, hypothesis: T, reference: T) -> float:
8384
def score(self, hypothesis: T, reference: T) -> float:
8485
raise NotImplementedError
8586

86-
def score_with_signature(
87-
self, hypothesis: T, reference: T, short: bool = False
88-
) -> Score:
87+
def score_with_signature(self, hypothesis: T, reference: T, short: bool = False) -> Score:
8988
return Score(
9089
name=self.name,
9190
score=self.score(hypothesis, reference),
@@ -96,29 +95,19 @@ def score_max(self, hypothesis: T, references: Sequence[T]) -> float:
9695
all_scores = self.score_all([hypothesis], references)
9796
return max(max(scores) for scores in all_scores)
9897

99-
def validate_corpus_score_input(
100-
self, hypotheses: Sequence[T], references: Sequence[Sequence[T]]
101-
):
98+
def validate_corpus_score_input(self, hypotheses: Sequence[T], references: Sequence[Sequence[T]]):
10299
# This method is designed to avoid mistakes in the use of the corpus_score method
103100
for reference in references:
104-
assert len(hypotheses) == len(
105-
reference
106-
), "Hypothesis and reference must have the same number of instances"
101+
assert len(hypotheses) == len(reference), "Hypothesis and reference must have the same number of instances"
107102

108-
def corpus_score(
109-
self, hypotheses: Sequence[T], references: Sequence[list[T]]
110-
) -> float:
103+
def corpus_score(self, hypotheses: Sequence[T], references: Sequence[list[T]]) -> float:
111104
"""Default implementation: average over sentence scores."""
112105
self.validate_corpus_score_input(hypotheses, references)
113106
transpose_references = list(zip(*references))
114-
scores = [
115-
self.score_max(h, r) for h, r in zip(hypotheses, transpose_references)
116-
]
107+
scores = [self.score_max(h, r) for h, r in zip(hypotheses, transpose_references)]
117108
return sum(scores) / len(hypotheses)
118109

119-
def score_all(
120-
self, hypotheses: Sequence[T], references: Sequence[T], progress_bar=True
121-
) -> list[list[float]]:
110+
def score_all(self, hypotheses: Sequence[T], references: Sequence[T], progress_bar=True) -> list[list[float]]:
122111
"""Call the score function for each hypothesis-reference pair."""
123112
return [
124113
[self.score(h, r) for r in references]

pose_evaluation/metrics/base_embedding_metric.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from typing import TypeVar
2+
23
import torch
3-
from pose_evaluation.metrics.base import BaseMetric
44

5+
from pose_evaluation.metrics.base import BaseMetric
56

67
# Define a type alias for embeddings (e.g., torch.Tensor)
78
Embedding = TypeVar("Embedding", bound=torch.Tensor)

pose_evaluation/metrics/conftest.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import shutil
22
from pathlib import Path
33
from typing import Callable, Union
4-
import torch
4+
55
import numpy as np
66
import pytest
7+
import torch
78

89

910
@pytest.fixture(scope="session", autouse=True)
@@ -20,11 +21,12 @@ def clean_test_artifacts():
2021
@pytest.fixture(name="distance_matrix_shape_checker")
2122
def fixture_distance_matrix_shape_checker() -> Callable[[torch.Tensor, torch.Tensor], None]:
2223
def _check_shape(hyp_count: int, ref_count: int, distance_matrix: torch.Tensor):
23-
2424
expected_shape = torch.Size([hyp_count, ref_count])
25-
assert (
26-
distance_matrix.shape == expected_shape
27-
), f"For M={hyp_count} hypotheses, N={ref_count} references, Distance Matrix should be MxN={expected_shape}. Instead, received {distance_matrix.shape}"
25+
assert distance_matrix.shape == expected_shape, (
26+
f"For M={hyp_count} hypotheses, N={ref_count} references, "
27+
f"Distance Matrix should be MxN={expected_shape}. "
28+
f"Instead, received {distance_matrix.shape}"
29+
)
2830

2931
return _check_shape
3032

pose_evaluation/metrics/distance_measure.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from typing import Literal, Dict, Any
2+
23
import numpy.ma as ma # pylint: disable=consider-using-from-import
4+
35
from pose_evaluation.metrics.base import Signature
46

57
AggregationStrategy = Literal["max", "min", "mean", "sum"]
68

9+
710
class DistanceMeasureSignature(Signature):
811
"""Signature for distance measure metrics."""
12+
913
def __init__(self, name: str, args: Dict[str, Any]) -> None:
1014
super().__init__(name=name, args=args)
1115
self.update_abbr("distance", "dist")
@@ -14,6 +18,7 @@ def __init__(self, name: str, args: Dict[str, Any]) -> None:
1418

1519
class DistanceMeasure:
1620
"""Abstract base class for distance measures."""
21+
1722
_SIGNATURE_TYPE = DistanceMeasureSignature
1823

1924
def __init__(self, name: str) -> None:
@@ -22,7 +27,7 @@ def __init__(self, name: str) -> None:
2227
def get_distance(self, hyp_data: ma.MaskedArray, ref_data: ma.MaskedArray) -> float:
2328
"""
2429
Compute the distance between hypothesis and reference data.
25-
30+
2631
This method should be implemented by subclasses.
2732
"""
2833
raise NotImplementedError
@@ -37,6 +42,7 @@ def get_signature(self) -> Signature:
3742

3843
class PowerDistanceSignature(DistanceMeasureSignature):
3944
"""Signature for power distance measures."""
45+
4046
def __init__(self, name: str, args: Dict[str, Any]) -> None:
4147
super().__init__(name=name, args=args)
4248
self.update_signature_and_abbr("order", "ord", args)
@@ -46,6 +52,7 @@ def __init__(self, name: str, args: Dict[str, Any]) -> None:
4652

4753
class AggregatedPowerDistance(DistanceMeasure):
4854
"""Aggregated power distance metric using a specified aggregation strategy."""
55+
4956
_SIGNATURE_TYPE = PowerDistanceSignature
5057

5158
def __init__(
@@ -56,7 +63,7 @@ def __init__(
5663
) -> None:
5764
"""
5865
Initialize the aggregated power distance metric.
59-
66+
6067
:param order: The exponent to which differences are raised.
6168
:param default_distance: The value to fill in for masked entries.
6269
:param aggregation_strategy: Strategy to aggregate computed distances.
@@ -69,7 +76,7 @@ def __init__(
6976
def _aggregate(self, distances: ma.MaskedArray) -> float:
7077
"""
7178
Aggregate computed distances using the specified strategy.
72-
79+
7380
:param distances: A masked array of computed distances.
7481
:return: A single aggregated distance value.
7582
"""
@@ -82,23 +89,19 @@ def _aggregate(self, distances: ma.MaskedArray) -> float:
8289
if self.aggregation_strategy in aggregation_funcs:
8390
return aggregation_funcs[self.aggregation_strategy]()
8491

85-
raise NotImplementedError(
86-
f"Aggregation Strategy {self.aggregation_strategy} not implemented"
87-
)
92+
raise NotImplementedError(f"Aggregation Strategy {self.aggregation_strategy} not implemented")
8893

89-
def _calculate_distances(
90-
self, hyp_data: ma.MaskedArray, ref_data: ma.MaskedArray
91-
) -> ma.MaskedArray:
94+
def _calculate_distances(self, hyp_data: ma.MaskedArray, ref_data: ma.MaskedArray) -> ma.MaskedArray:
9295
"""
9396
Compute element-wise distances between hypothesis and reference data.
94-
97+
9598
Steps:
9699
1. Compute the absolute differences.
97100
2. Raise the differences to the specified power.
98101
3. Sum the powered differences along the last axis.
99102
4. Extract the root corresponding to the power.
100103
5. Fill masked values with the default distance.
101-
104+
102105
:param hyp_data: Hypothesis data as a masked array.
103106
:param ref_data: Reference data as a masked array.
104107
:return: A masked array of computed distances.

pose_evaluation/metrics/distance_metric.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pose_format import Pose
2+
23
from pose_evaluation.metrics.base_pose_metric import PoseMetric
34
from pose_evaluation.metrics.distance_measure import DistanceMeasure
45

pose_evaluation/metrics/embedding_distance_metric.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
from typing import Literal, List, Union
21
import logging
2+
from typing import Literal, List, Union
33

4+
import numpy as np
45
import torch
6+
from sentence_transformers import util as st_util
57
from torch import Tensor
68
from torch.types import Number
7-
import numpy as np
8-
from sentence_transformers import util as st_util
99

1010
from pose_evaluation.metrics.base_embedding_metric import EmbeddingMetric
1111

12-
1312
# Useful reference: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/util.py#L31
1413
# * Helper functions such as batch_to_device, _convert_to_tensor, _convert_to_batch, _convert_to_batch_tensor
1514
# * a whole semantic search function, with chunking and top_k
@@ -86,11 +85,7 @@ def _to_batch_tensor_on_device(self, data: TensorConvertableType) -> Tensor:
8685

8786
return st_util._convert_to_batch_tensor(data).to(device=self.device, dtype=self.dtype)
8887

89-
def score(
90-
self,
91-
hypothesis: TensorConvertableType,
92-
reference: TensorConvertableType,
93-
) -> Number:
88+
def score(self, hypothesis: TensorConvertableType, reference: TensorConvertableType) -> Number:
9489
"""
9590
Compute the distance between two embeddings.
9691

0 commit comments

Comments
 (0)