Skip to content

Commit 37bdb51

Browse files
committed
chore(map): add/correct type annotations
1 parent d8d672b commit 37bdb51

File tree

3 files changed

+11
-13
lines changed

3 files changed

+11
-13
lines changed

src/copairs/map/average_precision.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Functions to compute average precision."""
22

3-
import itertools
43
import logging
54
from typing import List
65

src/copairs/map/map.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
from pathlib import Path
5-
from typing import Optional, Union
5+
from typing import Optional, Union, List
66

77
import numpy as np
88
import pandas as pd
@@ -16,7 +16,7 @@
1616

1717
def mean_average_precision(
1818
ap_scores: pd.DataFrame,
19-
sameby,
19+
sameby: List[str],
2020
null_size: int,
2121
threshold: float,
2222
seed: int,

src/copairs/map/multilabel.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Functions to compute mAP with multilabel support."""
22

3-
import itertools
43
import logging
4+
from typing import List
55

66
import numpy as np
77
import pandas as pd
@@ -68,12 +68,12 @@ def _build_rank_lists_multi(pos_pairs, pos_sims, pos_counts, negs_for):
6868

6969

7070
def average_precision(
71-
meta,
72-
feats,
73-
pos_sameby,
74-
pos_diffby,
75-
neg_sameby,
76-
neg_diffby,
71+
meta: pd.DataFrame,
72+
feats: pd.DataFrame,
73+
pos_sameby: List[str],
74+
pos_diffby: List[str],
75+
neg_sameby: List[str],
76+
neg_diffby: List[str],
7777
multilabel_col,
7878
batch_size=20000,
7979
distance="cosine",
@@ -98,7 +98,6 @@ def average_precision(
9898
pos_pairs, keys, pos_counts = find_pairs_multilabel(
9999
meta, sameby=pos_sameby, diffby=pos_diffby, multilabel_col=multilabel_col
100100
)
101-
total_counts = sum(pos_counts)
102101
if len(pos_pairs) == 0:
103102
raise UnpairedException("Unable to find positive pairs.")
104103

@@ -142,7 +141,7 @@ def average_precision(
142141
results = pd.concat(results).reset_index(drop=True)
143142
meta = meta.drop(multilabel_col, axis=1)
144143
results = meta.merge(results, right_on="ix", left_index=True).drop("ix", axis=1)
145-
results["n_pos_pairs"] = results["n_pos_pairs"].fillna(0).astype(np.int32)
146-
results["n_total_pairs"] = results["n_total_pairs"].fillna(0).astype(np.int32)
144+
results["n_pos_pairs"] = results["n_pos_pairs"].fillna(0).astype(np.uint32)
145+
results["n_total_pairs"] = results["n_total_pairs"].fillna(0).astype(np.uint32)
147146
logger.info("Finished.")
148147
return results

0 commit comments

Comments
 (0)