Skip to content

Commit 1d9ba7a

Browse files
authored
Merge pull request #16 from cleong110/metric_signatures
Adding basic MetricSignature functionality
2 parents 3babc9e + f51799f commit 1d9ba7a

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

pose_evaluation/metrics/base.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,70 @@
11
# pylint: disable=undefined-variable
2+
from typing import Any, Callable
23
from tqdm import tqdm
34

45

6+
class Signature:
7+
"""Represents reproducibility signatures for metrics. Inspired by sacreBLEU
8+
"""
9+
def __init__(self, name:str, args: dict):
10+
11+
self._abbreviated = {
12+
"name":"n",
13+
"higher_is_better":"hb"
14+
}
15+
16+
self.signature_info = {"name": name, **args}
17+
18+
def update(self, key: str, value: Any):
19+
self.signature_info[key] = value
20+
21+
def update_signature_and_abbr(self, key:str, abbr:str, args:dict):
22+
self._abbreviated.update({
23+
key: abbr
24+
})
25+
26+
self.signature_info.update({
27+
key: args.get(key, None)
28+
})
29+
30+
def format(self, short: bool = False) -> str:
31+
pairs = []
32+
keys = list(self.signature_info.keys())
33+
for name in keys:
34+
value = self.signature_info[name]
35+
if value is not None:
36+
# Check for nested signature objects
37+
if hasattr(value, "get_signature"):
38+
# Wrap nested signatures in brackets
39+
nested_signature = value.get_signature()
40+
if isinstance(nested_signature, Signature):
41+
nested_signature = nested_signature.format(short=short)
42+
value = f"{{{nested_signature}}}"
43+
if isinstance(value, bool):
44+
# Replace True/False with yes/no
45+
value = "yes" if value else "no"
46+
if isinstance(value, Callable):
47+
value = value.__name__
48+
49+
# if the abbreviation is not defined, use the full name as a fallback.
50+
abbreviated_name = self._abbreviated.get(name, name)
51+
final_name = abbreviated_name if short else name
52+
pairs.append(f"{final_name}:{value}")
53+
54+
return "|".join(pairs)
55+
56+
def __str__(self):
57+
return self.format()
58+
59+
def __repr__(self):
60+
return self.format()
61+
562
class BaseMetric[T]:
663
"""Base class for all metrics."""
64+
# Each metric should define its Signature class' name here
65+
_SIGNATURE_TYPE = Signature
766

8-
def __init__(self, name: str, higher_is_better: bool = True):
67+
def __init__(self, name: str, higher_is_better: bool = False):
968
self.name = name
1069
self.higher_is_better = higher_is_better
1170

@@ -38,3 +97,6 @@ def score_all(self, hypotheses: list[T], references: list[T], progress_bar=True)
3897

3998
def __str__(self):
4099
return self.name
100+
101+
def get_signature(self) -> Signature:
102+
return self._SIGNATURE_TYPE(self.name, self.__dict__)

0 commit comments

Comments
 (0)