Skip to content

Commit 28f20e4

Browse files
committed
Update pipeline.py
1 parent d1973ba commit 28f20e4

2 files changed

Lines changed: 73 additions & 13 deletions

File tree

src/semra/api.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,19 @@
1414
import networkx as nx
1515
import pandas as pd
1616
import ssslm
17+
from pydantic import BaseModel, Field
1718
from ssslm import LiteralMapping
1819
from tqdm.auto import tqdm
1920

2021
from semra.io.graph import _from_digraph_edge, to_digraph
21-
from semra.rules import EXACT_MATCH, FLIP, INVERSION_MAPPING, SubsetConfiguration
22+
from semra.rules import (
23+
DB_XREF,
24+
EXACT_MATCH,
25+
FLIP,
26+
INVERSION_MAPPING,
27+
KNOWLEDGE_MAPPING,
28+
SubsetConfiguration,
29+
)
2230
from semra.struct import (
2331
Evidence,
2432
Mapping,
@@ -35,6 +43,7 @@
3543
"IdentifierIndex",
3644
"Index",
3745
"M2MIndex",
46+
"Mutation",
3847
"PrefixIdentifierDict",
3948
"PrefixIdentifierDict",
4049
"PrefixPairCounter",
@@ -60,6 +69,7 @@
6069
"get_terms",
6170
"get_test_evidence",
6271
"get_test_reference",
72+
"handle_mutations",
6373
"hydrate_subsets",
6474
"keep_object_prefixes",
6575
"keep_prefixes",
@@ -1190,3 +1200,57 @@ def get_asymmetric_counter(
11901200
for (left_prefix, right_prefix), identifiers in index.items()
11911201
}
11921202
)
1203+
1204+
1205+
class Mutation(BaseModel):
    """Represents a mutation operation on a mapping set."""

    # NOTE: the class docstring and the ``description`` strings below are kept
    # verbatim because pydantic surfaces them in the generated model schema.

    # Subject prefix a mapping must have for this mutation to be considered.
    source: str = Field(..., description="The source type")
    # Optional object prefix (or list of prefixes) that narrows the mutation;
    # ``None`` means the object prefix is not checked.
    target: str | list[str] | None = Field(None, description="limit mutation to these")
    # Passed as the ``confidence_factor`` of the evidence built by ``handle_mutations``.
    confidence: float = 1.0
    # Predicate a mapping must currently have in order to be mutated.
    old: Reference = Field(default=DB_XREF)
    # Predicate assigned to the mutated mapping.
    new: Reference = Field(default=EXACT_MATCH)

    def should_apply_to(self, mapping: Mapping) -> bool:
        """Decide whether this mutation's criteria match a mapping.

        :param mapping: The mapping to test.
        :returns: ``True`` when the mapping's subject prefix equals ``source``,
            its predicate equals ``old``, and either no ``target`` is set or
            the mapping's object prefix is (one of) the configured target(s).
        :raises NotImplementedError: If ``target`` is neither ``None``, a
            string, nor a list.
        """
        if mapping.subject.prefix != self.source or mapping.predicate != self.old:
            return False

        allowed = self.target
        if allowed is None:
            # No restriction on the object side.
            return True
        if isinstance(allowed, str):
            # Treat a single prefix as a one-element list.
            allowed = [allowed]
        elif not isinstance(allowed, list):
            raise NotImplementedError
        return mapping.object.prefix in allowed
1227+
1228+
1229+
def handle_mutations(
    mappings: Iterable[Mapping], mutations: Iterable[Mutation], *, progress: bool = True
) -> Iterable[Mapping]:
    """Yield mappings, replacing those that match a configured mutation.

    Mutations are indexed by their ``source`` prefix. A mapping whose subject
    prefix has no configured mutation, or whose mutation does not apply to it
    (see :meth:`Mutation.should_apply_to`), is yielded unchanged. Otherwise a
    new mapping is yielded with the same subject and object, the mutation's
    ``new`` predicate, and a single reasoned evidence that points back at the
    original mapping.

    Because this is a generator, nothing (including the duplicate-source
    check) runs until the result is first iterated.

    :param mappings: The mappings to process.
    :param mutations: The mutations to apply; at most one per source prefix.
    :param progress: Whether to show a progress bar while iterating.
    :returns: An iterable over the (possibly rewritten) mappings, in input order.
    :raises KeyError: If two mutations share the same ``source``.
    """
    by_source: dict[str, Mutation] = {}
    for configured in mutations:
        if configured.source in by_source:
            raise KeyError(f"got multiple configured mutations for source: {configured.source}")
        by_source[configured.source] = configured

    for mapping in tqdm(mappings, disable=not progress):
        rule = by_source.get(mapping.subject.prefix)
        if not rule or not rule.should_apply_to(mapping):
            # Nothing configured for this subject prefix, or the criteria do
            # not match: pass the mapping through untouched.
            yield mapping
            continue

        # Keep provenance by wrapping the original mapping in the new evidence.
        derived = ReasonedEvidence(
            justification=KNOWLEDGE_MAPPING,
            mappings=[mapping],
            confidence_factor=rule.confidence,
        )
        yield Mapping(
            subject=mapping.subject,
            predicate=rule.new,
            object=mapping.object,
            evidence=[derived],
        )

src/semra/pipeline.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
from typing_extensions import Self
1919

2020
from semra.api import (
21+
Mutation,
2122
assemble_evidences,
2223
filter_mappings,
2324
filter_prefixes,
2425
filter_self_matches,
2526
filter_subsets,
27+
handle_mutations,
2628
hydrate_subsets,
2729
keep_prefixes,
2830
prioritize,
@@ -40,7 +42,7 @@
4042
write_neo4j,
4143
write_sssom,
4244
)
43-
from semra.rules import DB_XREF, EXACT_MATCH, IMPRECISE, SubsetConfiguration
45+
from semra.rules import IMPRECISE, SubsetConfiguration
4446
from semra.sources import SOURCE_RESOLVER
4547
from semra.sources.biopragmatics import (
4648
from_biomappings_negative,
@@ -100,16 +102,6 @@ class Input(BaseModel):
100102
extras: dict[str, Any] = Field(default_factory=dict)
101103

102104

103-
class Mutation(BaseModel):
104-
"""Represents a mutation operation on a mapping set."""
105-
106-
source: str = Field(..., description="The source type")
107-
target: str | list[str] | None = Field(None, description="limit mutation to these")
108-
confidence: float = 1.0
109-
old: Reference = Field(default=DB_XREF)
110-
new: Reference = Field(default=EXACT_MATCH)
111-
112-
113105
class Configuration(BaseModel):
114106
"""Represents the steps taken during mapping assembly."""
115107

@@ -974,7 +966,11 @@ def process(
974966
# _log_diff(before, mappings, verb="Filtered source internal", elapsed=time.time() - start)
975967

976968
if mutations:
977-
raise NotImplementedError
969+
logger.info("Applying mutations")
970+
before = len(mappings)
971+
start = time.time()
972+
mappings = list(handle_mutations(mappings, mutations, progress=progress))
973+
_log_diff(before, mappings, verb="Applied mutations", elapsed=time.time() - start)
978974

979975
if upgrade_prefixes and len(upgrade_prefixes) > 1:
980976
logger.info("Inferring mapping upgrades")

0 commit comments

Comments
 (0)