Skip to content

Commit 51291d5

Browse files
authored
🎨 Make search consistent with the lamindb implementation (#95)
* rework search * rm rank col * refactor * fix * test * fix * fix * check rank values * fix * explain how to get ranks with lamindb for tests * correct start rule for match * edit group
1 parent c8f48ea commit 51291d5

File tree

3 files changed

+130
-133
lines changed

3 files changed

+130
-133
lines changed

‎lamin_utils/_search.py

+81-102
Original file line numberDiff line numberDiff line change
@@ -1,129 +1,108 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Literal
4-
5-
from lamin_utils import logger
3+
from typing import TYPE_CHECKING
64

75
if TYPE_CHECKING:
8-
import pandas as pd
6+
from pandas import DataFrame, Series
7+
8+
9+
def _contains(col: Series, string: str, case_sensitive: bool, fields_convert: dict):
    """Flag the values of one column that contain the search string.

    Args:
        col: A column of the searched DataFrame; ``col.name`` selects its
            entry in ``fields_convert``.
        string: The search string. It is matched literally, not as a
            regular expression.
        case_sensitive: Whether the match is case sensitive.
        fields_convert: Maps the name of every searched column to a flag that
            tells whether the column is cast to ``str`` before matching.

    Returns:
        A list of ``False`` (one per row) if the column is not searched,
        otherwise a boolean Series aligned with ``col``.
    """
    if col.name not in fields_convert:
        return [False] * len(col)
    if fields_convert[col.name]:
        col = col.astype(str)
    # regex=False: user input such as "CD8+" or "a(b" must be treated as plain
    # text instead of being compiled as a pattern (which can raise re.error or
    # match unrelated values).
    # na=False: a missing value never counts as a hit.
    return col.str.contains(string, case=case_sensitive, regex=False, na=False)
15+
16+
17+
def _ranks(col: Series, string: str, case_sensitive: bool, fields_convert: dict):
    """Compute a match rank for every value of one column.

    The rank of a value is the sum of the weights of all rules it satisfies,
    so a higher rank means a closer match. Values may hold several synonyms
    separated by ``|``.

    Args:
        col: A column of the searched DataFrame; ``col.name`` selects its
            entry in ``fields_convert``.
        string: The search string. It is matched literally, not as a
            regular expression.
        case_sensitive: Whether the match is case sensitive.
        fields_convert: Maps the name of every searched column to a flag that
            tells whether the column is cast to ``str`` before matching.

    Returns:
        A list of zeros (one per row) if the column is not searched,
        otherwise an integer Series of ranks aligned with ``col``.
    """
    import re

    if col.name not in fields_convert:
        return [0] * len(col)
    if fields_convert[col.name]:
        col = col.astype(str)

    # Escape the input so that characters such as "+", "(" or "." are matched
    # literally instead of changing the meaning of the patterns below.
    term = re.escape(string)

    # (pattern, weight) pairs; every pattern is applied with ``str.match``,
    # i.e. anchored at the start of the value.
    weighted_patterns = (
        # the term is the whole value or one whole "|"-separated synonym
        (rf"(?:^|.*\|){term}(?:\|.*|$)", 200),
        # the term is delimited by separators / value boundaries on both sides
        (rf"(?:^|.*[ \|\.,;:]){term}(?:[ \|\.,;:].*|$)", 10),
        # the value or a synonym starts with the term and continues
        # without a space
        (rf"(?:^|.*\|){term}[^ ]*(?:\|.*|$)", 8),
        # the term starts at the value start, after a space or after "|"
        (rf"(?:^|.*[ \|]){term}.*", 2),
        # the term ends at the value end or right before a separator
        (rf".*{term}(?:$|[ \|\.,;:].*)", 2),
    )

    # na=False everywhere: missing values contribute a rank of 0.
    # The whole value equals the term.
    rank = col.str.fullmatch(term, case=case_sensitive, na=False) * 200
    for pattern, weight in weighted_patterns:
        rank = rank + col.str.match(pattern, case=case_sensitive, na=False) * weight
    # The term occurs anywhere in the value.
    contains_rank = col.str.contains(
        string, case=case_sensitive, regex=False, na=False
    ).astype("int32")
    return rank + contains_rank
947

1048

1149
def search(
    df: DataFrame,
    string: str,
    *,
    field: str | list[str] | None = None,
    limit: int | None = 20,
    case_sensitive: bool = False,
    _show_rank: bool = False,
) -> DataFrame:
    """Search a given string against a field.

    Only rows in which at least one searched field contains the string are
    returned; they are ordered by a rank that sums the per-field match ranks.

    Args:
        df: The DataFrame to search in.
        string: The input string to match against the field values.
        field: The field or fields to search. Search all fields containing strings by default.
        limit: Maximum number of top results to return. If None, returns all results.
        case_sensitive: Whether the match is case sensitive.
        _show_rank: Whether to add a ``rank`` column with the computed rank to the result.

    Returns:
        A DataFrame of ranked search results.
        This DataFrame contains the matched rows from the input DataFrame,
        sorted by the match rank in descending order.

    Raises:
        KeyError: If the specified field is not found in the DataFrame.
    """
    from pandas.api.types import is_object_dtype, is_string_dtype

    if len(df) == 0:
        return df

    # column name -> whether the column is cast to str before matching;
    # columns that are absent from this mapping are not searched
    fields_convert = {}
    if field is None:
        for f in df.columns.to_list():
            column = df[f]
            if is_object_dtype(column):
                fields_convert[f] = True
            elif is_string_dtype(column):
                fields_convert[f] = False
    else:
        fields = [field] if isinstance(field, str) else field
        for f in fields:
            fields_convert[f] = not is_string_dtype(df[f])

    matcher_args = (string, case_sensitive, fields_convert)

    # Select rows positionally (boolean array) so that duplicated index labels
    # in df cannot interfere with the selection.
    hit_mask = df.apply(_contains, args=matcher_args).any(axis=1).to_numpy(dtype=bool)
    df_contains = df.loc[hit_mask]
    if len(df_contains) == 0:
        return df_contains

    rank = df_contains.apply(_ranks, args=matcher_args).sum(axis=1)

    if _show_rank:
        # assign() returns a copy, the caller's DataFrame is left untouched
        df_contains = df_contains.assign(rank=rank.to_numpy())

    # Order by position instead of by label: looking the sorted labels up with
    # .loc would return every row several times when index labels repeat.
    order = rank.reset_index(drop=True).sort_values(ascending=False).index.to_numpy()
    df_result = df_contains.iloc[order]
    return df_result if limit is None else df_result.head(limit)

‎pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ Home = "https://github.com/laminlabs/lamin-utils"
2323
[project.optional-dependencies]
2424
dev = [
2525
"pandas", # lookup
26-
"rapidfuzz", # search
2726
"pre-commit",
2827
"nox",
2928
"pytest>=6.0",

‎tests/test_search.py

+49-30
Original file line numberDiff line numberDiff line change
@@ -10,73 +10,92 @@ def df():
1010
"ontology_id": "CL:0000084",
1111
"name": "T cell",
1212
"synonyms": "T-cell|T lymphocyte|T-lymphocyte",
13+
"description": "A Type Of Lymphocyte Whose Defining Characteristic Is The Expression Of A T Cell Receptor Complex.",
1314
"children": ["CL:0000798", "CL:0002420", "CL:0002419", "CL:0000789"],
1415
},
1516
{
1617
"ontology_id": "CL:0000236",
1718
"name": "B cell",
1819
"synonyms": "B lymphocyte|B-lymphocyte|B-cell",
20+
"description": "A Lymphocyte Of B Lineage That Is Capable Of B Cell Mediated Immunity.",
1921
"children": ["CL:0009114", "CL:0001201"],
2022
},
2123
{
2224
"ontology_id": "CL:0000696",
2325
"name": "PP cell",
2426
"synonyms": "type F enteroendocrine cell",
27+
"description": "A Cell That Stores And Secretes Pancreatic Polypeptide Hormone.",
2528
"children": ["CL:0002680"],
2629
},
2730
{
2831
"ontology_id": "CL:0002072",
2932
"name": "nodal myocyte",
3033
"synonyms": "cardiac pacemaker cell|myocytus nodalis|P cell",
34+
"description": "A Specialized Cardiac Myocyte In The Sinoatrial And Atrioventricular Nodes. The Cell Is Slender And Fusiform Confined To The Nodal Center, Circumferentially Arranged Around The Nodal Artery.",
3135
"children": ["CL:1000409", "CL:1000410"],
3236
},
3337
]
3438
return pd.DataFrame.from_records(records)
3539

3640

37-
def test_search_synonyms(df):
38-
res = search(df=df, string="P cells")
39-
assert res.index[0] == "nodal myocyte"
40-
41-
# without synonyms search
42-
res = search(df=df, synonyms_field=None, string="P cells")
43-
assert res.index[0] == "PP cell"
41+
# these tests also check the ranks of the search results (res["rank"] below)
42+
# this is needed to perform cross-check with lamindb search
43+
# to recompute the ranks via lamindb
44+
# change .alias to .annotate in lamindb/_record.py def _search(...)
45+
# then run the code below in an empty instance with bionty schema
46+
# import lamindb as ln
47+
# import bionty as bt
48+
# cts = ["CL:0000084", "CL:0000236", "CL:0000696", "CL:0002072"]
49+
# ln.save([bt.CellType.from_source(ontology_id=oid) for oid in cts])
50+
# results = bt.CellType.search("P cell")
51+
# print([(result.name, result.rank) for result in results.list()])
52+
# results = bt.CellType.search("b cell")
53+
# print([(result.name, result.rank) for result in results.list()])
54+
# results = bt.CellType.search("type F enteroendocrine", field="synonyms")
55+
# print([(result.name, result.rank) for result in results.list()])
56+
57+
58+
def test_search_general(df):
59+
res = search(df=df, string="P cell", _show_rank=True)
60+
assert res.iloc[0]["name"] == "nodal myocyte"
61+
assert res.iloc[0]["rank"] == 223
62+
assert len(res) == 2
63+
assert res.iloc[1]["rank"] == 3
64+
65+
# search in name, without synonyms search
66+
res = search(df=df, string="P cell", field="name", _show_rank=True)
67+
assert res.iloc[0]["name"] == "PP cell"
68+
assert res.iloc[0]["rank"] == 3
4469

4570

4671
def test_search_limit(df):
47-
res = search(df=df, string="P cells", limit=1)
72+
res = search(df=df, string="P cell", limit=1)
4873
assert res.shape[0] == 1
4974

5075

51-
def test_search_keep(df):
52-
# TODO: better test here
53-
res = search(df=df, string="enteroendocrine", keep=False)
54-
assert res.index[0] == "PP cell"
55-
56-
5776
def test_search_return_df(df):
58-
res = search(df=df, string="P cells")
59-
assert res.shape == (4, 4)
60-
assert res.iloc[0].name == "nodal myocyte"
77+
res = search(df=df, string="P cell")
78+
assert res.shape == (2, 5)
79+
assert res.iloc[0]["name"] == "nodal myocyte"
6180

6281

63-
def test_search_return_tie_results(df):
64-
res = search(df=df, string="A cell", synonyms_field=None)
65-
assert res.iloc[0].__ratio__ == res.iloc[1].__ratio__
66-
67-
68-
def test_search_non_default_field(df):
69-
res = search(df=df, string="type F enteroendocrine", field="synonyms")
70-
assert res.index[0] == "type F enteroendocrine cell"
82+
def test_search_pass_fields(df):
83+
res = search(
84+
df=df,
85+
string="type F enteroendocrine",
86+
field=["synonyms", "children"],
87+
_show_rank=True,
88+
)
89+
assert res.iloc[0]["synonyms"] == "type F enteroendocrine cell"
90+
assert res.iloc[0]["rank"] == 15
7191

7292

7393
def test_search_case_sensitive(df):
7494
res = search(df=df, string="b cell", case_sensitive=True)
75-
assert res.iloc[0].__ratio__ < 100
76-
77-
res = search(df=df, string="b cell", case_sensitive=False)
78-
assert res.index[0] == "B cell"
79-
assert res.iloc[0].__ratio__ == 100
95+
assert len(res) == 0
96+
res = search(df=df, string="b cell", case_sensitive=False, _show_rank=True)
97+
assert res.iloc[0]["name"] == "B cell"
98+
assert res.iloc[0]["rank"] == 438
8099

81100

82101
def test_search_empty_df():

0 commit comments

Comments
 (0)