Skip to content

Commit 96c9b23

Browse files
authored
Add keep_predicates filter (#222)
1 parent 03add59 commit 96c9b23

3 files changed

Lines changed: 56 additions & 0 deletions

File tree

src/curies/triples/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
exclude_subject_prefixes,
140140
exclude_triples,
141141
keep_object_prefixes,
142+
keep_predicates,
142143
keep_prefixes_both,
143144
keep_prefixes_either,
144145
keep_references_both,
@@ -164,6 +165,7 @@
164165
"exclude_triples",
165166
"hash_triple",
166167
"keep_object_prefixes",
168+
"keep_predicates",
167169
"keep_prefixes_both",
168170
"keep_prefixes_either",
169171
"keep_references_both",

src/curies/triples/filters.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"exclude_subject_prefixes",
1717
"exclude_triples",
1818
"keep_object_prefixes",
19+
"keep_predicates",
1920
"keep_prefixes_both",
2021
"keep_prefixes_either",
2122
"keep_references_both",
@@ -546,3 +547,47 @@ def _func(triple: TripleType) -> bool:
546547
return triple.subject not in references and triple.object not in references
547548

548549
return _func
550+
551+
552+
def keep_predicates(
    triples: Iterable[TripleType],
    predicates: Reference | Collection[Reference],
    *,
    progress: bool = False,
) -> Iterable[TripleType]:
    """Keep only the triples whose predicate is among the given references.

    :param triples: An iterable of triples
    :param predicates: Either a single reference or a collection of references;
        a triple is retained when its predicate matches one of them
    :param progress: Should a progress bar be shown?

    :returns: A sub-iterable of triples whose predicate appears in the given
        references.

    >>> from curies import Reference, Triple
    >>> from curies.vocabulary import exact_match, subclass_of
    >>> c1, c2, c3 = "DOID:0050577", "mesh:C562966", "DOID:225"
    >>> r1, r2, r3 = (Reference.from_curie(c) for c in (c1, c2, c3))
    >>> m1 = Triple.from_curies(c1, exact_match.curie, c2)
    >>> m2 = Triple.from_curies(c2, exact_match.curie, c3)
    >>> m3 = Triple.from_curies(c1, subclass_of.curie, c3)
    >>> assert list(keep_predicates([m1, m2, m3], exact_match)) == [m1, m2]
    """
    # Build the per-triple test once, then hand it to the shared filter helper.
    predicate_test = _include_predicates(predicates)
    return _filter(predicate_test, triples, progress=progress)
577+
578+
579+
def _include_predicates(
    predicates: Reference | Collection[Reference],
) -> TriplePredicate[TripleType]:
    """Build a function that tests whether a triple's predicate is accepted.

    :param predicates: Either a single reference or a collection of references

    :returns: A function taking a triple and returning ``True`` when the
        triple's predicate equals the single given reference, or is a member
        of the given collection.
    """
    if not isinstance(predicates, Reference):
        # Materialize the collection as a set once, so each per-triple check
        # is a membership test against it.
        allowed = set(predicates)

        def _is_allowed(triple: TripleType) -> bool:
            return triple.predicate in allowed

        return _is_allowed

    def _is_same(triple: TripleType) -> bool:
        return triple.predicate == predicates

    return _is_same

tests/test_triples/test_filter.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
exclude_subject_prefixes,
1313
exclude_triples,
1414
keep_object_prefixes,
15+
keep_predicates,
1516
keep_prefixes_both,
1617
keep_prefixes_either,
1718
keep_references_both,
@@ -149,3 +150,11 @@ def test_exclude_references(self) -> None:
149150
self.assert_triple_lists([m2], exclude_references_both(M123, [r1]))
150151
self.assert_triple_lists([m3], exclude_references_both(M123, [r2]))
151152
self.assert_triple_lists([m1], exclude_references_both(M123, [r3]))
153+
154+
def test_keep_predicate(self) -> None:
    """Test keeping triples by predicate, given as a single reference or a collection."""
    # Each case pairs the ``predicates`` argument with the triples expected to survive.
    cases = [
        (exact_match, [m1, m2, m3]),
        ([exact_match], [m1, m2, m3]),
        ([exact_match, subclass_of], [m1, m2, m3, m4]),
    ]
    for predicates, expected in cases:
        self.assert_triple_lists(expected, keep_predicates([m1, m2, m3, m4], predicates))

0 commit comments

Comments
 (0)