diff --git a/src/semra/api.py b/src/semra/api.py index 15f17fe0..97fbe26a 100644 --- a/src/semra/api.py +++ b/src/semra/api.py @@ -17,7 +17,7 @@ from ssslm import LiteralMapping from tqdm.auto import tqdm -from semra.io.graph import _from_digraph_edge, to_digraph +from semra.io.graph import to_digraph, to_simple_graph from semra.rules import EXACT_MATCH, FLIP, INVERSION_MAPPING, SubsetConfiguration from semra.struct import ( Evidence, @@ -559,19 +559,31 @@ def prioritize( >>> prioritize(mappings, ["mesh", "doid", "umls"]) """ original_mappings = len(mappings) - mappings = [m for m in mappings if m.predicate == EXACT_MATCH] + mappings_by_subj_obj: dict[tuple[str, str], Mapping] = { + (mapping.subject.curie, mapping.object.curie): mapping + for mapping in mappings + if mapping.predicate == EXACT_MATCH + } + # Gather all the references by CURIE + curie_to_reference: dict[str, Reference] = { + reference.curie: reference + for mapping in mappings + for reference in (mapping.subject, mapping.object) + } + exact_mappings = len(mappings) priority = _clean_priority_prefixes(priority) - graph = to_digraph(mappings).to_undirected() + graph = to_simple_graph(mappings) rv: list[Mapping] = [] for component in tqdm( nx.connected_components(graph), unit="component", unit_scale=True, disable=not progress ): - o = get_priority_reference(component, priority) + component_references = [curie_to_reference[curie] for curie in component] + o = get_priority_reference(component_references, priority) if o is None: continue - for s in component: + for s in component_references: if s == o: # don't add self-edges continue if not graph.has_edge(s, o): @@ -583,7 +595,7 @@ def prioritize( "that in a given component, it is a full clique (i.e., there are edges " "in both directions between all nodes)" ) - rv.extend(_from_digraph_edge(graph, s, o)) + rv.append(mappings_by_subj_obj[s.curie, o.curie]) # sort such that the mappings are ordered by object by priority order # then identifier of object, then subject prefix in alphabetical order diff --git a/src/semra/io/graph.py b/src/semra/io/graph.py index 0362448a..53d8c13f 100644 --- a/src/semra/io/graph.py +++ b/src/semra/io/graph.py @@ -17,6 +17,7 @@ "from_multidigraph", "to_digraph", "to_multidigraph", + "to_simple_graph", ] #: The key inside the data dictionary for a SeMRA mapping graph @@ -56,6 +57,21 @@ def to_digraph(mappings: t.Iterable[Mapping]) -> nx.DiGraph: return graph +def to_simple_graph(mappings: t.Iterable[Mapping]) -> nx.Graph: + """Return an undirected graph capturing only the structure of mappings. + + :param mappings: An iterable of mappings + + :returns: An undirected graph in which the nodes are simple string CURIEs + corresponding to References. The edges are undirected and represent + the relationships between subject and object CURIEs in mappings. + """ + graph = nx.Graph() + edges = {(mapping.subject.curie, mapping.object.curie) for mapping in mappings} + graph.add_edges_from(edges) + return graph + + def from_digraph(graph: nx.DiGraph) -> list[Mapping]: """Extract mappings from a simple directed graph data model.""" return [mapping for s, o in graph.edges() for mapping in _from_digraph_edge(graph, s, o)]