diff --git a/src/curies/triples/__init__.py b/src/curies/triples/__init__.py index e440777..8bbac3c 100644 --- a/src/curies/triples/__init__.py +++ b/src/curies/triples/__init__.py @@ -139,6 +139,7 @@ exclude_subject_prefixes, exclude_triples, keep_object_prefixes, + keep_predicates, keep_prefixes_both, keep_prefixes_either, keep_references_both, @@ -164,6 +165,7 @@ "exclude_triples", "hash_triple", "keep_object_prefixes", + "keep_predicates", "keep_prefixes_both", "keep_prefixes_either", "keep_references_both", diff --git a/src/curies/triples/filters.py b/src/curies/triples/filters.py index 0e2e6f3..f65b301 100644 --- a/src/curies/triples/filters.py +++ b/src/curies/triples/filters.py @@ -16,6 +16,7 @@ "exclude_subject_prefixes", "exclude_triples", "keep_object_prefixes", + "keep_predicates", "keep_prefixes_both", "keep_prefixes_either", "keep_references_both", @@ -546,3 +547,47 @@ def _func(triple: TripleType) -> bool: return triple.subject not in references and triple.object not in references return _func + + +def keep_predicates( + triples: Iterable[TripleType], + predicates: Reference | Collection[Reference], + *, + progress: bool = False, +) -> Iterable[TripleType]: + """Keep triples whose predicate appear in the given references. + + :param triples: An iterable of triples + :param predicates: A collection of references + :param progress: Should a progress bar be shown? + + :returns: A sub-iterable of triples whose predicate appear in the given + references. + + >>> from curies import Reference, Triple + >>> from curies.vocabulary import exact_match, subclass_of + >>> c1, c2, c3 = "DOID:0050577", "mesh:C562966", "DOID:225" + >>> r1, r2, r3 = (Reference.from_curie(c) for c in (c1, c2, c3)) + >>> m1 = Triple.from_curies(c1, exact_match.curie, c2) + >>> m2 = Triple.from_curies(c2, exact_match.curie, c3) + >>> m3 = Triple.from_curies(c1, subclass_of.curie, c3) + >>> assert list(keep_predicates([m1, m2, m3], exact_match)) == [m1, m2] + """ + return _filter(_include_predicates(predicates), triples, progress=progress) + + +def _include_predicates( + predicates: Reference | Collection[Reference], +) -> TriplePredicate[TripleType]: + if isinstance(predicates, Reference): + + def _func(triple: TripleType) -> bool: + return triple.predicate == predicates + + else: + predicates = set(predicates) + + def _func(triple: TripleType) -> bool: + return triple.predicate in predicates + + return _func diff --git a/tests/test_triples/test_filter.py b/tests/test_triples/test_filter.py index 900db6a..4b440b5 100644 --- a/tests/test_triples/test_filter.py +++ b/tests/test_triples/test_filter.py @@ -12,6 +12,7 @@ exclude_subject_prefixes, exclude_triples, keep_object_prefixes, + keep_predicates, keep_prefixes_both, keep_prefixes_either, keep_references_both, @@ -149,3 +150,11 @@ def test_exclude_references(self) -> None: self.assert_triple_lists([m2], exclude_references_both(M123, [r1])) self.assert_triple_lists([m3], exclude_references_both(M123, [r2])) self.assert_triple_lists([m1], exclude_references_both(M123, [r3])) + + def test_keep_predicate(self) -> None: + """Test keep predicate.""" + self.assert_triple_lists([m1, m2, m3], keep_predicates([m1, m2, m3, m4], exact_match)) + self.assert_triple_lists([m1, m2, m3], keep_predicates([m1, m2, m3, m4], [exact_match])) + self.assert_triple_lists( + [m1, m2, m3, m4], keep_predicates([m1, m2, m3, m4], [exact_match, subclass_of]) + )