|
16 | 16 | "exclude_subject_prefixes", |
17 | 17 | "exclude_triples", |
18 | 18 | "keep_object_prefixes", |
| 19 | + "keep_predicates", |
19 | 20 | "keep_prefixes_both", |
20 | 21 | "keep_prefixes_either", |
21 | 22 | "keep_references_both", |
@@ -546,3 +547,47 @@ def _func(triple: TripleType) -> bool: |
546 | 547 | return triple.subject not in references and triple.object not in references |
547 | 548 |
|
548 | 549 | return _func |
| 550 | + |
| 551 | + |
| 552 | +def keep_predicates( |
| 553 | + triples: Iterable[TripleType], |
| 554 | + predicates: Reference | Collection[Reference], |
| 555 | + *, |
| 556 | + progress: bool = False, |
| 557 | +) -> Iterable[TripleType]: |
| 558 | + """Keep triples whose predicate appear in the given references. |
| 559 | +
|
| 560 | + :param triples: An iterable of triples |
| 561 | + :param predicates: A collection of references |
| 562 | + :param progress: Should a progress bar be shown? |
| 563 | +
|
| 564 | + :returns: A sub-iterable of triples whose predicate appear in the given |
| 565 | + references. |
| 566 | +
|
| 567 | + >>> from curies import Reference, Triple |
| 568 | + >>> from curies.vocabulary import exact_match, subclass_of |
| 569 | + >>> c1, c2, c3 = "DOID:0050577", "mesh:C562966", "DOID:225" |
| 570 | + >>> r1, r2, r3 = (Reference.from_curie(c) for c in (c1, c2, c3)) |
| 571 | + >>> m1 = Triple.from_curies(c1, exact_match.curie, c2) |
| 572 | + >>> m2 = Triple.from_curies(c2, exact_match.curie, c3) |
| 573 | + >>> m3 = Triple.from_curies(c1, subclass_of.curie, c3) |
| 574 | + >>> assert list(keep_predicates([m1, m2, m3], exact_match)) == [m1, m2] |
| 575 | + """ |
| 576 | + return _filter(_include_predicates(predicates), triples, progress=progress) |
| 577 | + |
| 578 | + |
| 579 | +def _include_predicates( |
| 580 | + predicates: Reference | Collection[Reference], |
| 581 | +) -> TriplePredicate[TripleType]: |
| 582 | + if isinstance(predicates, Reference): |
| 583 | + |
| 584 | + def _func(triple: TripleType) -> bool: |
| 585 | + return triple.predicate == predicates |
| 586 | + |
| 587 | + else: |
| 588 | + predicates = set(predicates) |
| 589 | + |
| 590 | + def _func(triple: TripleType) -> bool: |
| 591 | + return triple.predicate in predicates |
| 592 | + |
| 593 | + return _func |
0 commit comments