diff --git a/pori_python/ipr/ipr.py b/pori_python/ipr/ipr.py index 499cb28..cd1f784 100644 --- a/pori_python/ipr/ipr.py +++ b/pori_python/ipr/ipr.py @@ -4,12 +4,17 @@ """ from __future__ import annotations -from itertools import product + +from requests.exceptions import HTTPError + +import uuid from copy import copy +from itertools import product from typing import Dict, Iterable, List, Sequence, Set, Tuple, cast -import uuid + from pori_python.graphkb import GraphKBConnection from pori_python.graphkb import statement as gkb_statement +from pori_python.graphkb import util as gkb_util from pori_python.graphkb import vocab as gkb_vocab from pori_python.types import ( Hashabledict, @@ -18,12 +23,12 @@ IprGene, IprVariant, KbMatch, - Statement, - Variant, - KbVariantMatch, KbMatchedStatement, KbMatchedStatementConditionSet, KbMatchSections, + KbVariantMatch, + Statement, + Variant, ) from .constants import GERMLINE_BASE_TERMS, VARIANT_CLASSES @@ -612,7 +617,7 @@ def get_kb_statement_matched_conditions( def get_kb_matches_sections( gkb_matches: List[KbMatch] | List[Hashabledict], - allow_partial_matches=False, + allow_partial_matches: bool = False, ) -> KbMatchSections: kb_variants = get_kb_variants(gkb_matches) kb_matched_statements = get_kb_matched_statements(gkb_matches) @@ -627,25 +632,72 @@ def get_kb_matches_sections( def get_kb_disease_matches( - graphkb_conn: GraphKBConnection, kb_disease_match: str = None, verbose: bool = True + graphkb_conn: GraphKBConnection, + kb_disease_match: str = None, + verbose: bool = True, + similarToExtended: bool = True, ) -> list[str]: + disease_matches = [] + if not kb_disease_match: kb_disease_match = 'cancer' if verbose: logger.warning(f"No disease provided; will use '{kb_disease_match}'") - if verbose: - logger.info(f"Matching disease ({kb_disease_match}) to graphkb") + if similarToExtended: + if verbose: + logger.info( + f"Matching disease ({kb_disease_match}) to graphkb using 'similarToExtended' queryType." + ) - disease_matches = { - r["@rid"] - for r in gkb_vocab.get_term_tree( - graphkb_conn, - kb_disease_match, - ontology_class="Disease", + try: + # KBDEV-1306 + # Matching disease(s) from name, then tree traversal for ancestors & descendants. + # Leverage new 'similarToExtended' queryType + base_records = gkb_util.convert_to_rid_list( + graphkb_conn.query( + gkb_vocab.query_by_name( + 'Disease', + kb_disease_match, + ) + ) + ) + if base_records: + disease_matches = list( + { + r["@rid"] + for r in graphkb_conn.query( + { + "target": base_records, + "queryType": "similarToExtended", + "matchType": "Disease", + "edges": ["AliasOf", "CrossReferenceOf", "DeprecatedBy"], + "treeEdges": ["subClassOf"], + "returnProperties": ["@rid"], + } + ) + } + ) + except HTTPError: + if verbose: + logger.info("Failed at using 'similarToExtended' queryType.") + similarToExtended = False + + if not similarToExtended: + if verbose: + logger.info(f"Matching disease ({kb_disease_match}) to graphkb using get_term_tree()") + # Previous solution w/ get_term_tree() -> 'similarTo' queryType + disease_matches = list( + { + r["@rid"] + for r in gkb_vocab.get_term_tree( + graphkb_conn, + kb_disease_match, + ontology_class="Disease", + ) + } ) - } if not disease_matches: msg = f"failed to match disease ({kb_disease_match}) to graphkb" @@ -653,4 +705,4 @@ def get_kb_disease_matches( logger.error(msg) raise ValueError(msg) - return list(disease_matches) + return disease_matches diff --git a/tests/test_ipr/test_ipr.py b/tests/test_ipr/test_ipr.py index 5c55d9d..03628e0 100644 --- a/tests/test_ipr/test_ipr.py +++ b/tests/test_ipr/test_ipr.py @@ -6,6 +6,7 @@ from pori_python.ipr.ipr import ( convert_statements_to_alterations, germline_kb_matches, + get_kb_disease_matches, get_kb_matched_statements, get_kb_statement_matched_conditions, get_kb_variants, @@ -172,27 +173,34 @@ @pytest.fixture def graphkb_conn(): - class QueryMock: - return_values = [ - # get approved evidence levels - [{"@rid": v} for v in APPROVED_EVIDENCE_RIDS] - ] - index = -1 + # Mock for the 'query' method + query_mock = Mock() + query_return_values = [[{"@rid": v} for v in APPROVED_EVIDENCE_RIDS]] + query_index = {"value": -1} # Mutable index for closure + + def query_side_effect(*args, **kwargs): + if args: + # for TestGetKbDiseaseMatches + return [{'@rid': '#123:45'}] + query_index["value"] += 1 + idx = query_index["value"] + return query_return_values[idx] if idx < len(query_return_values) else [] - def __call__(self, *args, **kwargs): - self.index += 1 - ret_val = self.return_values[self.index] if self.index < len(self.return_values) else [] - return ret_val + query_mock.side_effect = query_side_effect - class PostMock: - def __call__(self, *args, **kwargs): - # custom return tailored for multi_variant_filtering() testing - return {"result": KB_MATCHES_STATEMENTS} + # Mock for the 'post' method + post_mock = Mock(return_value={"result": KB_MATCHES_STATEMENTS}) + # 'get_source' remains a plain function def mock_get_source(source): return {"@rid": 0} - conn = Mock(query=QueryMock(), cache={}, get_source=mock_get_source, post=PostMock()) + # Create the connection mock with attributes + conn = Mock() + conn.query = query_mock + conn.post = post_mock + conn.cache = {} + conn.get_source = mock_get_source return conn @@ -233,10 +241,9 @@ def base_graphkb_statement(disease_id: str = "disease", relevance_rid: str = "ot @pytest.fixture(autouse=True) def mock_get_term_tree(monkeypatch): - def mock_func(*pos, **kwargs): - return [{"@rid": d} for d in DISEASE_RIDS] - + mock_func = Mock(return_value=[{"@rid": d} for d in DISEASE_RIDS]) monkeypatch.setattr(gkb_vocab, "get_term_tree", mock_func) + return mock_func @pytest.fixture(autouse=True) @@ -255,6 +262,24 @@ def mock_func(_, relevance_id): monkeypatch.setattr(gkb_statement, "categorize_relevance", mock_func) +class TestGetKbDiseaseMatches: + def test_get_kb_disease_matches_similarToExtended(self, graphkb_conn) -> None: + get_kb_disease_matches(graphkb_conn, 'Breast Cancer') + assert graphkb_conn.query.called + assert not gkb_vocab.get_term_tree.called + + def test_get_kb_disease_matches_get_term_tree(self, graphkb_conn) -> None: + get_kb_disease_matches(graphkb_conn, 'Breast Cancer', similarToExtended=False) + assert gkb_vocab.get_term_tree.called + assert not graphkb_conn.query.called + + def test_get_kb_disease_matches_default(self, graphkb_conn) -> None: + get_kb_disease_matches(graphkb_conn) + assert graphkb_conn.query.call_args_list[0].args == ( + {'target': 'Disease', 'filters': {'name': 'cancer'}}, + ) + + class TestConvertStatementsToAlterations: def test_disease_match(self, graphkb_conn) -> None: statement = base_graphkb_statement(DISEASE_RIDS[0])