Skip to content

Commit 5dc3a65

Browse files
committed
Improve typing and testing
1 parent 1360c51 commit 5dc3a65

2 files changed

Lines changed: 88 additions & 65 deletions

File tree

src/curies/triples/filters.py

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,6 @@
2626
logger = logging.getLogger(__name__)
2727

2828

29-
def _cleanup_prefixes(prefixes: str | Iterable[str]) -> set[str]:
30-
if isinstance(prefixes, str):
31-
return {prefixes}
32-
return set(prefixes)
33-
34-
3529
def _filter(
3630
func: TriplePredicate[TripleType], triples: Iterable[TripleType], progress: bool = False
3731
) -> Iterable[TripleType]:
@@ -46,7 +40,7 @@ def _filter(
4640

4741

4842
def keep_prefixes(
49-
triples: Iterable[TripleType], prefixes: str | Iterable[str], *, progress: bool = False
43+
triples: Iterable[TripleType], prefixes: Iterable[str], *, progress: bool = False
5044
) -> Iterable[TripleType]:
5145
"""Keep triples whose subjects' and objects' prefixes are in the given prefixes.
5246
@@ -68,8 +62,10 @@ def keep_prefixes(
6862
return _filter(_keep_prefixes_filter(prefixes), triples, progress=progress)
6963

7064

71-
def _keep_prefixes_filter(prefixes: str | Iterable[str]) -> TriplePredicate[TripleType]:
72-
prefixes = _cleanup_prefixes(prefixes)
65+
def _keep_prefixes_filter(prefixes: Iterable[str]) -> TriplePredicate[TripleType]:
66+
prefixes = set(prefixes)
67+
if len(prefixes) < 2:
68+
raise ValueError
7369

7470
def _func(triple: TripleType) -> bool:
7571
return triple.subject.prefix in prefixes and triple.object.prefix in prefixes
@@ -101,10 +97,16 @@ def keep_subject_prefixes(
10197

10298

10399
def _keep_subject_prefixes_filter(prefixes: str | Iterable[str]) -> TriplePredicate[TripleType]:
104-
prefixes = _cleanup_prefixes(prefixes)
100+
if isinstance(prefixes, str):
105101

106-
def _func(triple: TripleType) -> bool:
107-
return triple.subject.prefix in prefixes
102+
def _func(triple: TripleType) -> bool:
103+
return triple.subject.prefix == prefixes
104+
105+
else:
106+
prefixes = set(prefixes)
107+
108+
def _func(triple: TripleType) -> bool:
109+
return triple.subject.prefix in prefixes
108110

109111
return _func
110112

@@ -134,10 +136,15 @@ def keep_object_prefixes(
134136

135137

136138
def _keep_object_prefixes_filter(prefixes: str | Iterable[str]) -> TriplePredicate[TripleType]:
137-
prefixes = _cleanup_prefixes(prefixes)
139+
if isinstance(prefixes, str):
138140

139-
def _func(triple: TripleType) -> bool:
140-
return triple.object.prefix in prefixes
141+
def _func(triple: TripleType) -> bool:
142+
return triple.object.prefix == prefixes
143+
else:
144+
prefixes = set(prefixes)
145+
146+
def _func(triple: TripleType) -> bool:
147+
return triple.object.prefix in prefixes
141148

142149
return _func
143150

@@ -168,10 +175,16 @@ def exclude_prefixes(
168175

169176

170177
def _exclude_prefixes_filter(prefixes: str | Iterable[str]) -> TriplePredicate[TripleType]:
171-
prefixes = _cleanup_prefixes(prefixes)
178+
if isinstance(prefixes, str):
172179

173-
def _func(triple: TripleType) -> bool:
174-
return triple.subject.prefix not in prefixes and triple.object.prefix not in prefixes
180+
def _func(triple: TripleType) -> bool:
181+
return triple.subject.prefix != prefixes and triple.object.prefix != prefixes
182+
183+
else:
184+
prefixes = set(prefixes)
185+
186+
def _func(triple: TripleType) -> bool:
187+
return triple.subject.prefix not in prefixes and triple.object.prefix not in prefixes
175188

176189
return _func
177190

@@ -202,10 +215,16 @@ def exclude_subject_prefixes(
202215

203216

204217
def _exclude_subject_prefixes_filter(prefixes: str | Iterable[str]) -> TriplePredicate[TripleType]:
205-
prefixes = _cleanup_prefixes(prefixes)
218+
if isinstance(prefixes, str):
206219

207-
def _func(triple: TripleType) -> bool:
208-
return triple.subject.prefix not in prefixes
220+
def _func(triple: TripleType) -> bool:
221+
return triple.subject.prefix != prefixes
222+
223+
else:
224+
prefixes = set(prefixes)
225+
226+
def _func(triple: TripleType) -> bool:
227+
return triple.subject.prefix not in prefixes
209228

210229
return _func
211230

@@ -236,10 +255,16 @@ def exclude_object_prefixes(
236255

237256

238257
def _exclude_object_prefixes_filter(prefixes: str | Iterable[str]) -> TriplePredicate[TripleType]:
239-
prefixes = _cleanup_prefixes(prefixes)
258+
if isinstance(prefixes, str):
240259

241-
def _func(triple: TripleType) -> bool:
242-
return triple.object.prefix not in prefixes
260+
def _func(triple: TripleType) -> bool:
261+
return triple.object.prefix != prefixes
262+
263+
else:
264+
prefixes = set(prefixes)
265+
266+
def _func(triple: TripleType) -> bool:
267+
return triple.object.prefix not in prefixes
243268

244269
return _func
245270

tests/test_triples/test_filter.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
m2 = Triple.from_curies(c2, exact_match.curie, c3)
2727
m3 = Triple.from_curies(c1, exact_match.curie, c3)
2828
m4 = Triple.from_curies(c1, subclass_of.curie, c4)
29+
M123 = [m1, m2, m3]
2930
converter = Converter.from_prefix_map(
3031
{
3132
"DOID": "http://purl.obolibrary.org/obo/DOID_",
@@ -46,84 +47,81 @@ class TestFilters(unittest.TestCase):
4647
def assert_triple_lists(self, expected: list[Triple], actual: Iterable[Triple]) -> None:
4748
"""Test two triple lists are the same."""
4849
actual = list(actual)
49-
self.assertEqual(
50-
expected, list(actual), msg=f"\nExpected: {_x(expected)}\nActual: {_x(actual)}"
51-
)
50+
self.assertEqual(expected, actual, msg=f"\nExpected: {_x(expected)}\nActual: {_x(actual)}")
5251

5352
def test_exclude_object_prefixes(self) -> None:
5453
"""Test excluding object prefixes."""
55-
self.assertEqual(
56-
[m1],
57-
list(exclude_object_prefixes([m1, m2, m3], {"umls"})),
58-
)
59-
self.assertEqual([m2, m3], list(exclude_object_prefixes([m1, m2, m3], {"mesh"})))
60-
self.assertEqual([m1, m2, m3], list(exclude_object_prefixes([m1, m2, m3], {"DOID"})))
54+
self.assert_triple_lists([m1], exclude_object_prefixes(M123, {"umls"}))
55+
self.assert_triple_lists([m2, m3], exclude_object_prefixes(M123, {"mesh"}))
56+
self.assert_triple_lists(M123, exclude_object_prefixes(M123, {"DOID"}))
6157

6258
def test_exclude_prefixes(self) -> None:
6359
"""Test excluding prefixes."""
64-
self.assertEqual([m1], list(exclude_prefixes([m1, m2, m3], {"umls"})))
65-
self.assertEqual([m2], list(exclude_prefixes([m1, m2, m3], {"DOID"})))
66-
self.assertEqual([m3], list(exclude_prefixes([m1, m2, m3], {"mesh"})))
60+
self.assert_triple_lists([m1], exclude_prefixes(M123, {"umls"}))
61+
self.assert_triple_lists([m2], exclude_prefixes(M123, {"DOID"}))
62+
self.assert_triple_lists([m3], exclude_prefixes(M123, {"mesh"}))
6763

6864
def test_exclude_subject_prefixes(self) -> None:
6965
"""Test excluding subject prefixes."""
70-
self.assertEqual([m2], list(exclude_subject_prefixes([m1, m2, m3], {"DOID"})))
71-
self.assertEqual([m1, m2, m3], list(exclude_subject_prefixes([m1, m2, m3], {"umls"})))
72-
self.assertEqual([m1, m3], list(exclude_subject_prefixes([m1, m2, m3], {"mesh"})))
66+
self.assert_triple_lists([m2], exclude_subject_prefixes(M123, {"DOID"}))
67+
self.assert_triple_lists(M123, exclude_subject_prefixes(M123, {"umls"}))
68+
self.assert_triple_lists([m1, m3], exclude_subject_prefixes(M123, {"mesh"}))
7369

7470
def test_exclude_same_prefixes(self) -> None:
7571
"""Test excluding same prefixes."""
76-
self.assertEqual([m1, m2, m3], list(exclude_same_prefixes([m1, m2, m3, m4])))
72+
self.assert_triple_lists(M123, exclude_same_prefixes([m1, m2, m3, m4]))
7773

7874
def test_exclude_triples(self) -> None:
7975
"""Test excluding triples."""
80-
self.assertEqual([m1, m2], list(exclude_triples([m1, m2, m3], m3)))
81-
self.assertEqual([m1, m2], list(exclude_triples([m1, m2, m3], [m3])))
76+
self.assert_triple_lists([m1, m2], exclude_triples(M123, m3))
77+
self.assert_triple_lists([m1, m2], exclude_triples(M123, [m3]))
8278

8379
def test_keep_object_prefixes(self) -> None:
8480
"""Test keeping object prefixes."""
85-
self.assertEqual([m2, m3], list(keep_object_prefixes([m1, m2, m3], {"umls"})))
81+
self.assert_triple_lists([m2, m3], keep_object_prefixes(M123, {"umls"}))
8682

8783
def test_keep_prefixes(self) -> None:
8884
"""Test keeping prefixes."""
89-
self.assertEqual([m1], list(keep_prefixes([m1, m2, m3], {"DOID", "mesh"})))
85+
self.assert_triple_lists([], keep_prefixes(M123, {"NOPE", "also nope"}))
86+
self.assert_triple_lists([m1], keep_prefixes(M123, {"DOID", "mesh"}))
87+
self.assert_triple_lists([m1], keep_prefixes(M123, {"DOID", "umls"}))
88+
self.assert_triple_lists([m1], keep_prefixes(M123, {"mesh", "umls"}))
89+
self.assert_triple_lists(M123, keep_prefixes(M123, {"DOID", "umls", "mesh"}))
9090

9191
def test_keep_subject_prefixes(self) -> None:
9292
"""Test keeping subject prefixes."""
93-
self.assertEqual([m1, m3], list(keep_subject_prefixes([m1, m2, m3], {"DOID"})))
93+
self.assert_triple_lists([m1, m3], keep_subject_prefixes(M123, {"DOID"}))
9494

9595
def test_keep_triple_by_hash(self) -> None:
9696
"""Test keeping triples by hash."""
97-
self.assertEqual(
98-
[m1], list(keep_triples_by_hash([m1, m2, m3], converter, converter.hash_triple(m1)))
97+
self.assert_triple_lists(
98+
[m1], keep_triples_by_hash(M123, converter, converter.hash_triple(m1))
9999
)
100-
self.assertEqual(
100+
self.assert_triple_lists(
101101
[m1, m2],
102-
list(
103-
keep_triples_by_hash(
104-
[m1, m2, m3], converter, [converter.hash_triple(m2), converter.hash_triple(m1)]
105-
)
102+
keep_triples_by_hash(
103+
M123, converter, [converter.hash_triple(m2), converter.hash_triple(m1)]
106104
),
107105
)
108106

109107
def test_keep_references_either(self) -> None:
110108
"""Test keeping references."""
111-
self.assert_triple_lists([m1, m3], keep_references_either([m1, m2, m3], [r1]))
112-
self.assert_triple_lists([m1, m2], keep_references_either([m1, m2, m3], [r2]))
113-
self.assert_triple_lists([m2, m3], keep_references_either([m1, m2, m3], [r3]))
114-
self.assert_triple_lists([m1, m2, m3], keep_references_either([m1, m2, m3], [r1, r2]))
115-
self.assert_triple_lists([m1, m2, m3], keep_references_either([m1, m2, m3], [r2, r3]))
116-
self.assert_triple_lists([m1, m2, m3], keep_references_either([m1, m2, m3], [r1, r2, r3]))
109+
self.assert_triple_lists([m1, m3], keep_references_either(M123, [r1]))
110+
self.assert_triple_lists([m1, m2], keep_references_either(M123, [r2]))
111+
self.assert_triple_lists([m2, m3], keep_references_either(M123, [r3]))
112+
self.assert_triple_lists(M123, keep_references_either(M123, [r1, r2]))
113+
self.assert_triple_lists(M123, keep_references_either(M123, [r2, r3]))
114+
self.assert_triple_lists(M123, keep_references_either(M123, [r1, r2, r3]))
117115

118116
def test_keep_references_both(self) -> None:
119117
"""Test keeping references."""
120-
self.assert_triple_lists([m1], keep_references_both([m1, m2, m3], [r1, r2]))
121-
self.assert_triple_lists([m2], keep_references_both([m1, m2, m3], [r2, r3]))
122-
self.assert_triple_lists([m3], keep_references_both([m1, m2, m3], [r1, r3]))
123-
self.assert_triple_lists([m1, m2, m3], keep_references_both([m1, m2, m3], [r1, r2, r3]))
118+
self.assert_triple_lists([m1], keep_references_both(M123, [r1, r2]))
119+
self.assert_triple_lists([m2], keep_references_both(M123, [r2, r3]))
120+
self.assert_triple_lists([m3], keep_references_both(M123, [r1, r3]))
121+
self.assert_triple_lists(M123, keep_references_both(M123, [r1, r2, r3]))
124122

125123
def test_exclude_references(self) -> None:
126124
"""Test exclude references."""
127-
self.assertEqual([m2], list(exclude_references([m1, m2, m3], [r1])))
128-
self.assertEqual([m3], list(exclude_references([m1, m2, m3], [r2])))
129-
self.assertEqual([m1], list(exclude_references([m1, m2, m3], [r3])))
125+
self.assert_triple_lists([m2], exclude_references(M123, [r1]))
126+
self.assert_triple_lists([m3], exclude_references(M123, [r2]))
127+
self.assert_triple_lists([m1], exclude_references(M123, [r3]))

0 commit comments

Comments
 (0)