diff --git a/src/semra/api.py b/src/semra/api.py index 15f17fe0..53313f18 100644 --- a/src/semra/api.py +++ b/src/semra/api.py @@ -17,7 +17,7 @@ from ssslm import LiteralMapping from tqdm.auto import tqdm -from semra.io.graph import _from_digraph_edge, to_digraph +from semra.io.graph import to_digraph from semra.rules import EXACT_MATCH, FLIP, INVERSION_MAPPING, SubsetConfiguration from semra.struct import ( Evidence, @@ -55,7 +55,6 @@ "get_index", "get_many_to_many", "get_observed_terms", - "get_priority_reference", "get_symmetric_counter", "get_terms", "get_test_evidence", @@ -509,7 +508,11 @@ def assert_projection(mappings: list[Mapping]) -> None: def prioritize( - mappings: list[Mapping], priority: list[str], *, progress: bool = True + mappings: list[Mapping], + priority: list[str], + *, + progress: bool = True, + sort: bool = True, ) -> list[Mapping]: """Get a priority star graph. @@ -523,6 +526,7 @@ def prioritize( if there exists a mapping ``A, exact, B``, there must be a ``B, exact, A``. :param priority: A priority list of prefixes, where earlier in the list means the priority is higher. + :param sort: Sort by object then subject? :return: A list of mappings representing a "prioritization", meaning that each element only appears as subject once. This condition means that the prioritization mapping can be applied @@ -531,7 +535,13 @@ def prioritize( This algorithm works in the following way 1. Get the subset of exact matches from the input mapping list - 2. Convert the exact matches to an undirected mapping graph + 2. Convert the exact matches to an index that's like undirected mapping graph. + + .. warning:: + + This assumes that all evidences have been aggregated into a single mapping! + Make sure you run :func:`assemble_evidences` first + 3. Extract connected components. .. note:: @@ -558,45 +568,37 @@ def prioritize( >>> mappings = infer_chains(mappings) >>> prioritize(mappings, ["mesh", "doid", "umls"]) """ - original_mappings = len(mappings) - mappings = [m for m in mappings if m.predicate == EXACT_MATCH] - exact_mappings = len(mappings) - priority = _clean_priority_prefixes(priority) - - graph = to_digraph(mappings).to_undirected() rv: list[Mapping] = [] - for component in tqdm( - nx.connected_components(graph), unit="component", unit_scale=True, disable=not progress - ): - o = get_priority_reference(component, priority) - if o is None: + + aggregator = Aggregator(priority) + subject_object_mapping, original_mappings, exact_mappings = _get_index(mappings) + + for s, object_mapping in subject_object_mapping.items(): + o_key, o = aggregator.get_priority_reference(object_mapping) + if o == s: continue - for s in component: - if s == o: # don't add self-edges - continue - if not graph.has_edge(s, o): - # TODO should this work even if s-o edge not exists? - # can also do "inference" here, but also might be - # because of negative edge filtering - raise NotImplementedError( - "prioritize() should only be called on fully inferred graphs, meaning " - "that in a given component, it is a full clique (i.e., there are edges " - "in both directions between all nodes)" - ) - rv.extend(_from_digraph_edge(graph, s, o)) - - # sort such that the mappings are ordered by object by priority order - # then identifier of object, then subject prefix in alphabetical order - pos = {prefix: i for i, prefix in enumerate(priority)} - rv = sorted( - rv, - key=lambda m: ( - pos[m.object.prefix], - m.object.identifier, - m.subject.prefix, - m.subject.identifier, - ), - ) + + s_key = aggregator.get_reference_key(s) + + # when the object key is smaller than the subject key, + # we prioritized in the right direction + if s_key > o_key: + rv.append(subject_object_mapping[s][o]) + elif o in subject_object_mapping and s in subject_object_mapping[o]: + raise NotImplementedError + else: + flipped_mapping = flip(subject_object_mapping[s][o], strict=True) + rv.append(flipped_mapping) + + if sort: + rv = sorted( + rv, + key=lambda m: ( + aggregator.get_reference_key(m.object), + m.subject.prefix, + m.subject.identifier, + ), + ) end_mappings = len(rv) logger.info( @@ -605,46 +607,64 @@ def prioritize( return rv -def _clean_priority_prefixes(priority: list[str]) -> list[str]: - return [bioregistry.normalize_prefix(prefix, strict=True) for prefix in priority] - - -def get_priority_reference( - component: t.Iterable[Reference], priority: list[str] -) -> Reference | None: - """Get the priority reference from a component. - - :param component: A set of references with the pre-condition that they're all "equivalent" - :param priority: A priority list of prefixes, where earlier in the list means the priority is higher - :returns: - Returns the reference with the prefix that has the highest priority. - If multiple references have the highest priority prefix, returns the first one encountered. - If none have a priority prefix, return None. - - >>> from semra import Reference - >>> curies = ["DOID:0050577", "mesh:C562966", "umls:C4551571"] - >>> references = [Reference.from_curie(curie) for curie in curies] - >>> get_priority_reference(references, ["mesh", "umls"]).curie - 'mesh:C562966' - >>> get_priority_reference(references, ["DOID", "mesh", "umls"]).curie - 'doid:0050577' - >>> get_priority_reference(references, ["hpo", "ordo", "symp"]) +def _get_index( + mappings: Iterable[Mapping], +) -> tuple[dict[Reference, dict[Reference, Mapping]], int, int]: + original_mappings = 0 + exact_mappings = 0 - """ - prefix_to_references: defaultdict[str, list[Reference]] = defaultdict(list) - for reference in component: - prefix_to_references[reference.prefix].append(reference) - for prefix in _clean_priority_prefixes(priority): - references = prefix_to_references.get(prefix, []) - if not references: - continue - if len(references) == 1: - return references[0] - # TODO multiple - I guess let's just return the first - logger.debug("multiple references for %s", prefix) - return references[0] - # nothing found in priority, don't return at all. - return None + subject_object_mapping: defaultdict[Reference, dict[Reference, Mapping]] = defaultdict(dict) + for mapping in mappings: + original_mappings += 1 + if mapping.predicate == EXACT_MATCH: + exact_mappings += 1 + subject_object_mapping[mapping.subject][mapping.object] = mapping + + # need to rasterize, otherwise dictionary size could + # change during iteration in case we try and access + # an element that doesn't exist + return dict(subject_object_mapping), original_mappings, exact_mappings + + +ReferenceKey: TypeAlias = tuple[int, str, str] + + +class Aggregator: + """A class for aggregating nodes based on a priority list.""" + + def __init__(self, priority: Iterable[str]) -> None: + """Initialize an aggregator.""" + priority = [bioregistry.normalize_prefix(prefix, strict=True) for prefix in priority] + # sort such that the mappings are ordered by object by priority order + # then identifier of object, then subject prefix in alphabetical order + self.pos = {prefix: i for i, prefix in enumerate(priority)} + self.n = len(self.pos) + 1 + + def get_reference_key(self, node: Reference) -> ReferenceKey: + """Get a sort key for a node based on priority, prefix, then identifier.""" + # sort by both prefix priority, then also prefix to tiebrake + # when none are prioritized, then identifier within vocabulary + return self.pos.get(node.prefix, self.n), node.prefix, node.identifier + + def get_priority_reference(self, nodes: Iterable[Reference]) -> tuple[ReferenceKey, Reference]: + """Get a unique priority reference from a set of references. + + :param nodes: The collection of references to get the priority reference from + :returns: A pair of the "reference key" and the priority reference + + Example: + >>> from semra import Reference + >>> from semra.api import Aggregator + >>> curies = ["DOID:0050577", "mesh:C562966", "umls:C4551571"] + >>> references = [Reference.from_curie(curie) for curie in curies] + >>> Aggregator(["mesh", "umls"]).get_priority_reference(references)[1].curie + 'mesh:C562966' + >>> Aggregator(["DOID", "mesh", "umls"]).get_priority_reference(references)[1].curie + 'doid:0050577' + >>> Aggregator(["hpo", "ordo", "symp"]).get_priority_reference(references)[1].curie + 'doid:0050577' + """ + return min((self.get_reference_key(n), n) for n in nodes) def unindex(index: Index, *, progress: bool = True) -> list[Mapping]: diff --git a/tests/constants.py b/tests/constants.py index ff475127..8027819f 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,6 +1,11 @@ """Reusable assets for testing.""" -from semra import Reference +from __future__ import annotations + +import unittest + +from semra import Mapping, Reference +from semra.api import Index, get_index a1_curie = "CHEBI:10084" # Xylopinine a2_curie = "CHEBI:10100" # zafirlukast @@ -16,3 +21,33 @@ ) TEST_CURIES = {a1, a2, b1, b2} + + +class BaseTestCase(unittest.TestCase): + """A test case with functionality for testing mapping equivalence.""" + + def assert_same_triples( + self, + expected_mappings: Index | list[Mapping], + actual_mappings: Index | list[Mapping], + msg: str | None = None, + ) -> None: + """Assert that two sets of mappings are the same.""" + if not isinstance(expected_mappings, dict): + expected_mappings = get_index(expected_mappings, progress=False) + if not isinstance(actual_mappings, dict): + actual_mappings = get_index(actual_mappings, progress=False) + + self.assertEqual( + self._clean_index(expected_mappings), + self._clean_index(actual_mappings), + msg=msg, + ) + + @staticmethod + def _clean_index(index: Index) -> list[str]: + triples = sorted(set(index)) + return [ + f"<{triple.subject.curie}, {triple.predicate.curie}, {triple.object.curie}>" + for triple in triples + ] diff --git a/tests/test_api.py b/tests/test_api.py index eee62b23..e1b0c082 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -40,6 +40,7 @@ SimpleEvidence, line, ) +from tests import constants PREFIX_A = "go" PREFIX_B = "mondo" @@ -75,7 +76,7 @@ def _exact(s: Reference, o: Reference, evidence: list[SimpleEvidence] | None = N MS = MappingSet(name="test", confidence=0.95) -class TestOperations(unittest.TestCase): +class TestOperations(constants.BaseTestCase): """Test mapping operations.""" def test_path(self) -> None: @@ -166,14 +167,6 @@ def assert_same_triples( msg=msg, ) - @staticmethod - def _clean_index(index: Index) -> list[str]: - triples = sorted(set(index)) - return [ - f"<{triple.subject.curie}, {triple.predicate.curie}, {triple.object.curie}>" - for triple in triples - ] - def test_infer_exact_match(self) -> None: """Test inference through the transitivity of SKOS exact matches.""" r1, r2, r3, r4 = _get_references(4, different_prefixes=True) @@ -486,68 +479,91 @@ def test_prioritize_df(self) -> None: list(df["curie_prioritized"]), ) - def test_prioritize(self) -> None: - """Test prioritize.""" + +class TestPrioritize(constants.BaseTestCase): + """Test prioritization.""" + + def setUp(self) -> None: + """Set up the prioritization test case.""" a1 = Reference(prefix=PREFIX_A, identifier="0000001") b1 = Reference(prefix=PREFIX_B, identifier="0000002") c1 = Reference(prefix=PREFIX_C, identifier="0000003") ev = SimpleEvidence(confidence=0.95, mapping_set=MS) - m1 = Mapping(subject=a1, predicate=EXACT_MATCH, object=b1, evidence=[ev]) - m1_rev = Mapping(subject=b1, predicate=EXACT_MATCH, object=a1, evidence=[ev]) - m2 = Mapping(subject=b1, predicate=EXACT_MATCH, object=c1, evidence=[ev]) - m2_rev = Mapping(subject=c1, predicate=EXACT_MATCH, object=b1, evidence=[ev]) - m3 = Mapping(subject=a1, predicate=EXACT_MATCH, object=c1, evidence=[ev]) - m3_rev = Mapping(subject=c1, predicate=EXACT_MATCH, object=a1, evidence=[ev]) - - # can't address priority + self.m1 = Mapping(subject=a1, predicate=EXACT_MATCH, object=b1, evidence=[ev]) + self.m1_rev = Mapping(subject=b1, predicate=EXACT_MATCH, object=a1, evidence=[ev]) + self.m2 = Mapping(subject=b1, predicate=EXACT_MATCH, object=c1, evidence=[ev]) + self.m2_rev = Mapping(subject=c1, predicate=EXACT_MATCH, object=b1, evidence=[ev]) + self.m3 = Mapping(subject=a1, predicate=EXACT_MATCH, object=c1, evidence=[ev]) + self.m3_rev = Mapping(subject=c1, predicate=EXACT_MATCH, object=a1, evidence=[ev]) + self.full = [self.m1, self.m1_rev, self.m2, self.m2_rev, self.m3, self.m3_rev] + + def test_minimal_connected_component(self) -> None: + """Test a minimal (complete) component with 2 nodes and 2 edges.""" self.assert_same_triples( - [], - prioritize([m1, m1_rev, m2, m2_rev, m3, m3_rev], [PREFIX_D], progress=False), + [self.m1_rev], + prioritize([self.m1, self.m1_rev], [PREFIX_A], progress=False), ) - - # has unusable priority first, but then defaults self.assert_same_triples( - [m1_rev, m3_rev], - prioritize([m1, m1_rev, m2, m2_rev, m3, m3_rev], [PREFIX_D, PREFIX_A], progress=False), + [self.m1], + prioritize([self.m1, self.m1_rev], [PREFIX_B], progress=False), ) + self.assert_same_triples( + [self.m1_rev], + prioritize([self.m1, self.m1_rev], [PREFIX_C], progress=False), + ) + + def test_minimal_incomplete(self) -> None: + """Test an incomplete component with 2 nodes and only a single edge.""" + # test when the priority is on the subject + self.assert_same_triples([self.m1_rev], prioritize([self.m1], [PREFIX_A], progress=False)) + + # test when the priority is on the object + self.assert_same_triples([self.m1], prioritize([self.m1], [PREFIX_B], progress=False)) + # test irrelevant prioritization prefix, this relies on sort order of the prefix + p = self.m1 if self.m1.subject.prefix > self.m1.object.prefix else self.m1_rev + self.assert_same_triples([p], prioritize([self.m1], [PREFIX_C], progress=False)) + + def test_connected_component_size(self) -> None: + """Test a complete component with 3 nodes and 3 edges.""" self.assert_same_triples( - [m1_rev, m3_rev], - prioritize([m1, m1_rev, m2, m2_rev, m3, m3_rev], [PREFIX_A], progress=False), + [self.m1_rev, self.m3_rev], + prioritize(self.full, [PREFIX_A], progress=False), ) self.assert_same_triples( - [m1, m2_rev], - prioritize([m1, m1_rev, m2, m2_rev, m3, m3_rev], [PREFIX_B], progress=False), + [self.m1, self.m2_rev], + prioritize(self.full, [PREFIX_B], progress=False), ) self.assert_same_triples( - [m2, m3], - prioritize([m1, m1_rev, m2, m2_rev, m3, m3_rev], [PREFIX_C], progress=False), + [self.m2, self.m3], + prioritize(self.full, [PREFIX_C], progress=False), ) - # test on component with only 1 + # can't address priority self.assert_same_triples( - [m1_rev], - prioritize([m1, m1_rev], [PREFIX_A], progress=False), + [self.m1_rev, self.m3_rev], + prioritize(self.full, [PREFIX_D], progress=False), ) + + # has unusable priority first, but then defaults self.assert_same_triples( - [m1], - prioritize([m1, m1_rev], [PREFIX_B], progress=False), + [self.m1_rev, self.m3_rev], + prioritize(self.full, [PREFIX_D, PREFIX_A], progress=False), ) + + def test_incomplete(self) -> None: + """Test prioritize.""" + # FIXME note that in the following three cases, some mappings get thrown away. need to check if there's + # an inverse in these cases self.assert_same_triples( - [], - prioritize([m1, m1_rev], [PREFIX_C], progress=False), + [self.m2], prioritize([self.m1, self.m2], [PREFIX_A], progress=False) + ) + self.assert_same_triples( + [self.m1], prioritize([self.m1, self.m2], [PREFIX_B], progress=False) + ) + self.assert_same_triples( + [self.m2], prioritize([self.m1, self.m2], [PREFIX_C], progress=False) ) - - # the following three tests reflect that the prioritize() function - # is not implemented in cases when inference hasn't been fully done - with self.assertRaises(NotImplementedError): - prioritize([m1, m2], [PREFIX_A], progress=False) - with self.assertRaises(NotImplementedError): - prioritize([m1, m2], [PREFIX_C], progress=False) - - # this one is able to complete, by chance, but it's not part of - # the contract, so just left here for later - # self.assertEqual([m1, m2_rev], prioritize([m1, m2], [PREFIX_B], progress=False)) class TestUpgrades(unittest.TestCase):