-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path2_map_terms.py
More file actions
476 lines (398 loc) · 20.2 KB
/
2_map_terms.py
File metadata and controls
476 lines (398 loc) · 20.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
#!/usr/bin/env python3
"""Map named entities to ontology terms using OLS4 llm_search + OpenAI + ZOOMA.
Reads CSV with columns study_id,named_entity from stdin.
Writes CSV with columns study_id,named_entity,result_type,term_id,term_label,mapping_source to stdout.
Mapping sources (in priority order):
EXACT_LEXICAL_MATCH – exact case-insensitive match on OLS label or synonym
ZOOMA – high-confidence ZOOMA annotation
LLM – OpenAI LLM classification
Uses a local SQLite cache (2_map_terms_cache.db) keyed on named_entity
to avoid redundant lookups.
"""
import argparse
import csv
import os
import sqlite3
import sys
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
from openai import OpenAI
# OLS4 v2 embedding ("llm") search endpoint (EBI dev host).
OLS4_EMBED_URL = "https://wwwdev.ebi.ac.uk/ols4/api/v2/classes/llm_search"
# OLS4 v2 lexical entity search endpoint (EBI dev host).
OLS4_LEXICAL_URL = "https://wwwdev.ebi.ac.uk/ols4/api/v2/entities"
# OLS4 v1 terms endpoint (production host); used to resolve an IRI to its label.
OLS4_TERMS_URL = "https://www.ebi.ac.uk/ols4/api/terms"
# ZOOMA annotation service endpoint.
ZOOMA_URL = "https://www.ebi.ac.uk/spot/zooma/v2/api/services/annotate"
# Embedding model name passed to the OLS4 llm_search endpoint.
EMBED_MODEL = "llama-embed-nemotron-8b_pca512"
# SQLite cache lives next to this script so runs from any working directory share it.
CACHE_DB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "2_map_terms_cache.db")
def _ontology_ids_key(ontology_ids: list[str]) -> str:
"""Build a stable cache key fragment from the ordered ontology IDs."""
return ",".join(o.lower() for o in ontology_ids)
def init_cache(db_path: str) -> sqlite3.Connection:
    """Open (or create) the SQLite result cache at *db_path* and ensure its schema.

    check_same_thread=False allows worker threads to share the connection;
    callers are expected to serialise writes with their own lock.
    """
    connection = sqlite3.connect(db_path, check_same_thread=False)
    connection.execute(
        """CREATE TABLE IF NOT EXISTS cache (
            named_entity TEXT NOT NULL,
            ontology_ids TEXT NOT NULL,
            result_type TEXT NOT NULL,
            term_id TEXT NOT NULL,
            term_label TEXT NOT NULL,
            mapping_source TEXT NOT NULL DEFAULT '',
            PRIMARY KEY (named_entity, ontology_ids)
        )"""
    )
    # Migration for cache files created before mapping_source existed: adding
    # the column raises OperationalError when it is already present.
    try:
        connection.execute("ALTER TABLE cache ADD COLUMN mapping_source TEXT NOT NULL DEFAULT ''")
    except sqlite3.OperationalError:
        pass  # column already exists
    connection.commit()
    return connection
def lookup_cache(
    conn: sqlite3.Connection, named_entity: str, ontology_ids: list[str]
) -> tuple[str, str, str, str] | None:
    """Return the cached (result_type, term_id, term_label, mapping_source)
    row for *named_entity* under this ontology list, or None on a miss.
    """
    key = _ontology_ids_key(ontology_ids)
    cursor = conn.execute(
        "SELECT result_type, term_id, term_label, mapping_source FROM cache WHERE named_entity = ? AND ontology_ids = ?",
        (named_entity, key),
    )
    # fetchone() already yields None when no row matches.
    return cursor.fetchone()
def store_cache(
    conn: sqlite3.Connection,
    named_entity: str,
    ontology_ids: list[str],
    result_type: str,
    term_id: str,
    term_label: str,
    mapping_source: str,
):
    """Upsert one mapping result into the cache and commit immediately."""
    row = (
        named_entity,
        _ontology_ids_key(ontology_ids),
        result_type,
        term_id,
        term_label,
        mapping_source,
    )
    conn.execute(
        "INSERT OR REPLACE INTO cache (named_entity, ontology_ids, result_type, term_id, term_label, mapping_source) VALUES (?, ?, ?, ?, ?, ?)",
        row,
    )
    conn.commit()
# oboInOwl annotation property under which OLS4 v2 responses carry exact synonyms.
EXACT_SYNONYM_KEY = "http://www.geneontology.org/formats/oboInOwl#hasExactSynonym"
def _parse_ols_elements(data: dict, allowed: set[str]) -> list[dict]:
    """Extract candidate dicts from an OLS4 v2 response.

    Returns {curie, label, synonyms, ontology} dicts for elements that have
    both a label and a CURIE and whose ontology ID is in *allowed*. Synonyms
    are exact synonyms that differ (case-insensitively) from the label.
    """
    candidates: list[dict] = []
    for element in data.get("elements", []):
        labels = element.get("label", [])
        primary = labels[0] if labels else ""
        curie = element.get("curie", "")
        ontology = element.get("ontologyId", "").lower()
        if not primary or not curie or ontology not in allowed:
            continue
        synonyms = [
            value
            for entry in element.get(EXACT_SYNONYM_KEY, [])
            if (value := entry.get("value", "") if isinstance(entry, dict) else "")
            and value.lower() != primary.lower()
        ]
        candidates.append({"curie": curie, "label": primary, "synonyms": synonyms, "ontology": ontology})
    return candidates
def ols_search(
    query: str, num_results: int, ontology_ids: list[str]
) -> list[dict]:
    """Run OLS4 embedding and lexical searches and merge the results.

    Returns deduplicated {curie, label, synonyms, ontology} dicts restricted
    to *ontology_ids*; embedding hits precede lexical hits in the output.
    """
    allowed = {o.lower() for o in ontology_ids}
    # Comma-separated list for the lexical endpoint's server-side filter.
    ontology_filter = ",".join(o.lower() for o in ontology_ids)
    # Embedding search takes no ontology parameter; results are filtered
    # client-side inside _parse_ols_elements.
    embedding_response = requests.get(
        OLS4_EMBED_URL,
        params=[("q", query), ("page", 0), ("size", num_results), ("model", EMBED_MODEL)],
        timeout=30,
    )
    embedding_response.raise_for_status()
    lexical_response = requests.get(
        OLS4_LEXICAL_URL,
        params={"search": query, "page": 0, "size": num_results, "isDefiningOntology": "true", "ontologyId": ontology_filter},
        timeout=30,
    )
    lexical_response.raise_for_status()
    hits = _parse_ols_elements(embedding_response.json(), allowed) + _parse_ols_elements(
        lexical_response.json(), allowed
    )
    # Deduplicate by CURIE; dict insertion order keeps embedding results first.
    unique: dict[str, dict] = {}
    for hit in hits:
        unique.setdefault(hit["curie"], hit)
    return list(unique.values())
def _iri_to_curie(iri: str) -> str:
"""Convert an OBO IRI like http://purl.obolibrary.org/obo/UBERON_0002107 to UBERON:0002107."""
obo_prefix = "http://purl.obolibrary.org/obo/"
if iri.startswith(obo_prefix):
local = iri[len(obo_prefix):] # e.g. UBERON_0002107
parts = local.split("_", 1)
if len(parts) == 2:
return f"{parts[0]}:{parts[1]}"
return iri
def _curie_ontology(curie: str) -> str:
"""Extract the ontology prefix from a CURIE, lowercased."""
return curie.split(":", 1)[0].lower() if ":" in curie else ""
def _resolve_iri_label(iri: str) -> str:
    """Resolve *iri* to its term label via the OLS4 terms endpoint.

    Best-effort: returns "" on any network/parse error or when no term is found.
    """
    try:
        response = requests.get(OLS4_TERMS_URL, params={"iri": iri}, timeout=15)
        response.raise_for_status()
        terms = response.json().get("_embedded", {}).get("terms", [])
        if terms:
            return terms[0].get("label", "")
    except Exception:
        pass  # deliberate best-effort: fall through to ""
    return ""
def zooma_search(
    query: str, ontology_ids: list[str]
) -> tuple[str, str, str] | None:
    """Query ZOOMA for a high-confidence annotation, filtering ontologies client-side.

    Returns (result_type, curie, label) for the first HIGH/GOOD confidence
    annotation whose CURIE prefix is in *ontology_ids* and whose label can be
    resolved, otherwise None.
    """
    allowed_prefixes = {o.lower() for o in ontology_ids}
    try:
        response = requests.get(ZOOMA_URL, params={"propertyValue": query}, timeout=30)
        response.raise_for_status()
        annotations = response.json()
    except Exception as e:
        print(f" [{query}] ZOOMA error: {e}", file=sys.stderr)
        return None
    for annotation in annotations:
        confidence = annotation.get("confidence", "")
        # Only HIGH / GOOD confidence annotations are trusted.
        if confidence not in ("HIGH", "GOOD"):
            continue
        for tag_iri in annotation.get("semanticTags", []):
            curie = _iri_to_curie(tag_iri)
            if _curie_ontology(curie) not in allowed_prefixes:
                continue
            label = _resolve_iri_label(tag_iri)
            if label:
                print(f" [{query}] ZOOMA hit (confidence={confidence}): {curie} ({label})", file=sys.stderr)
                return "EXACT_MATCH_FOUND", curie, label
    return None
def pick_best_match(client: OpenAI, named_entity: str, candidates: list[dict]) -> tuple[str, str, dict]:
    """Use OpenAI to pick the best matching label from the candidate list.

    Each candidate has 'label' and 'synonyms'. The LLM sees labels with
    their exact synonyms so it can make a better-informed choice.

    Returns (result_type, label, usage) where result_type is one of:
        EXACT_MATCH_FOUND – a candidate means the same thing as the entity
        BROADER_MATCH_FOUND – the label is a broader/parent concept
        NARROWER_MATCH_FOUND – the label is a narrower/child concept
        NO_MATCH_FOUND – none of the candidates are a reasonable match
    `usage` is a dict with 'prompt_tokens' and 'completion_tokens' counts.
    """
    # Render each candidate as "- label (synonyms: a, b)" for the prompt.
    lines = []
    for c in candidates:
        entry = f"- {c['label']}"
        if c.get("synonyms"):
            entry += f" (synonyms: {', '.join(c['synonyms'])})"
        lines.append(entry)
    labels_text = "\n".join(lines)
    # Chain-of-thought prompt: reasoning first, then a "---" separator and a
    # strictly two-line machine-parseable answer (result type + chosen label).
    prompt = (
        f"I need to map the term \"{named_entity}\" to the most appropriate "
        f"ontology label from the list below.\n\n"
        f"{labels_text}\n\n"
        f"Think step by step, then give your answer.\n\n"
        f"Step 1: What kind of concept is \"{named_entity}\"? (e.g. organ, chemical, disease, assay, process, etc.)\n"
        f"Step 2: For each candidate label, what kind of concept is it?\n"
        f"Step 3: Is there a candidate that is the SAME concept as \"{named_entity}\"? "
        f"(same meaning, or a synonym, or an assay that specifically measures it, "
        f"or in the case of an exposure, an exposure to it) → EXACT_MATCH_FOUND\n"
        f"Step 4: If no exact match, is there a candidate where the sentence "
        f"\"\\\"{named_entity}\\\" is a type of [candidate]\" is literally true? "
        f"The candidate must be a MORE GENERAL concept of the SAME kind. "
        f"Example: \"glucose is a type of monosaccharide\" ✓ (both are chemicals). "
        f"Counter-example: \"liver is a type of liver disease\" ✗ (organ vs disease, different kinds). "
        f"Counter-example: \"caecum metabolomics is a type of caecum morphology phenotype\" ✗ (different disciplines). "
        f"→ BROADER_MATCH_FOUND\n"
        f"Step 5: If no exact or broader match, is there a candidate where "
        f"\"[candidate] is a type of \\\"{named_entity}\\\"\" is literally true? → NARROWER_MATCH_FOUND\n"
        f"Step 6: If none of the above, → NO_MATCH_FOUND\n\n"
        f"After your reasoning, write a line containing only \"---\" and then exactly TWO lines:\n"
        f"Line 1: one of EXACT_MATCH_FOUND, BROADER_MATCH_FOUND, NARROWER_MATCH_FOUND, or NO_MATCH_FOUND\n"
        f"Line 2: ONLY the label text (the part before any parentheses), exactly as it appears in the list. "
        f"Do NOT include synonyms or anything in parentheses. "
        f"(omit this line if NO_MATCH_FOUND)"
    )
    response = client.chat.completions.create(
        # Model is overridable via the OPENAI_MODEL environment variable.
        model=os.environ.get("OPENAI_MODEL", "gpt-5.2"),
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    text = response.choices[0].message.content.strip()
    # response.usage can be absent; default token counts to 0 in that case.
    usage = {
        "prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
        "completion_tokens": response.usage.completion_tokens if response.usage else 0,
    }
    # Extract answer after "---" separator (chain-of-thought reasoning before it)
    if "---" in text:
        text = text.split("---", 1)[1].strip()
    lines = [l.strip() for l in text.splitlines() if l.strip()]
    result_type = lines[0] if lines else "NO_MATCH_FOUND"
    valid_types = ("EXACT_MATCH_FOUND", "BROADER_MATCH_FOUND", "NARROWER_MATCH_FOUND", "NO_MATCH_FOUND")
    if result_type not in valid_types:
        # Fallback: treat unexpected output as exact match with the full text as label
        return "EXACT_MATCH_FOUND", text, usage
    chosen_label = lines[1] if len(lines) > 1 else ""
    return result_type, chosen_label, usage
def _resolve_label(candidates: list[dict], chosen_label: str, result_type: str, named_entity: str) -> tuple[str, str, str]:
"""Match a chosen label back to a candidate, with fallback to first candidate."""
chosen_lower = chosen_label.lower().strip()
for c in candidates:
if c["label"].lower().strip() == chosen_lower:
return result_type, c["curie"], c["label"]
# Also check synonyms in case the LLM returned one
for syn in c.get("synonyms", []):
if syn.lower().strip() == chosen_lower:
return result_type, c["curie"], c["label"]
# If label match fails, return the first candidate as fallback
print(
f" [{named_entity}] Warning: OpenAI returned '{chosen_label}' which didn't match any candidate; using first result",
file=sys.stderr,
)
return result_type, candidates[0]["curie"], candidates[0]["label"]
def map_entity(
    client: OpenAI,
    named_entity: str,
    num_results: int,
    ontology_ids: list[str],
) -> tuple[str, str, str, str, dict]:
    """Map a single named entity to an ontology term using a priority cascade.

    Pipeline order:
        1. Exact lexical match on OLS label/synonym → EXACT_LEXICAL_MATCH
        2. ZOOMA high-confidence annotation → ZOOMA
        3. LLM classification per ontology → LLM

    For the LLM stage, ontologies in *ontology_ids* are tried in order:
        - EXACT_MATCH_FOUND → return immediately.
        - BROADER_MATCH_FOUND → remember as fallback, keep trying lower-priority
          ontologies for an exact match.
        - NARROWER_MATCH_FOUND → treated as no match (discarded).
        - NO_MATCH_FOUND → continue to next ontology.

    Returns (result_type, term_id, term_label, mapping_source, usage).
    term_id and term_label are empty strings when NO_MATCH_FOUND; usage
    accumulates LLM token counts across all ontologies tried.
    """
    all_candidates = ols_search(named_entity, num_results, ontology_ids)
    total_usage = {"prompt_tokens": 0, "completion_tokens": 0}
    if not all_candidates:
        # Still try ZOOMA even with no OLS candidates
        zooma_result = zooma_search(named_entity, ontology_ids)
        if zooma_result is not None:
            rt, tid, tl = zooma_result
            return rt, tid, tl, "ZOOMA", total_usage
        return "NO_MATCH_FOUND", "", "", "", total_usage
    # Stage 1: Exact lexical match on label or synonym (ontology priority order)
    entity_lower = named_entity.lower().strip()
    by_ontology_tmp: dict[str, list[dict]] = {}
    for c in all_candidates:
        by_ontology_tmp.setdefault(c["ontology"], []).append(c)
    for ont_id in ontology_ids:
        for c in by_ontology_tmp.get(ont_id.lower(), []):
            if c["label"].lower().strip() == entity_lower:
                print(f" [{named_entity}] Exact lexical match on label: {c['label']} ({c['curie']})", file=sys.stderr)
                return "EXACT_MATCH_FOUND", c["curie"], c["label"], "EXACT_LEXICAL_MATCH", total_usage
            for syn in c.get("synonyms", []):
                if syn.lower().strip() == entity_lower:
                    print(f" [{named_entity}] Exact lexical match on synonym '{syn}': {c['label']} ({c['curie']})", file=sys.stderr)
                    return "EXACT_MATCH_FOUND", c["curie"], c["label"], "EXACT_LEXICAL_MATCH", total_usage
    # Stage 2: ZOOMA high-confidence annotation (client-side ontology filtering)
    zooma_result = zooma_search(named_entity, ontology_ids)
    if zooma_result is not None:
        rt, tid, tl = zooma_result
        return rt, tid, tl, "ZOOMA", total_usage
    # Stage 3: LLM classification
    # Partition candidates by ontology, preserving --ontology-ids priority
    by_ontology: dict[str, list[dict]] = {}
    for c in all_candidates:
        by_ontology.setdefault(c["ontology"], []).append(c)
    broader_fallback: tuple[str, str, str, str, dict] | None = None
    for ont_id in ontology_ids:
        group = by_ontology.get(ont_id.lower(), [])
        if not group:
            continue
        labels = [c["label"] for c in group]
        print(f" [{named_entity}] Trying ontology '{ont_id}' ({len(group)} candidates): {', '.join(labels)}", file=sys.stderr)
        result_type, chosen_label, usage = pick_best_match(client, named_entity, group)
        total_usage["prompt_tokens"] += usage["prompt_tokens"]
        total_usage["completion_tokens"] += usage["completion_tokens"]
        if result_type == "EXACT_MATCH_FOUND":
            rt, tid, tl = _resolve_label(group, chosen_label, result_type, named_entity)
            return rt, tid, tl, "LLM", total_usage
        # Only the first (highest-priority) broader match is kept as fallback.
        # NOTE: total_usage is stored by reference, so later iterations' token
        # counts are reflected in the fallback tuple when it is returned.
        if result_type == "BROADER_MATCH_FOUND" and broader_fallback is None:
            rt, tid, tl = _resolve_label(group, chosen_label, result_type, named_entity)
            broader_fallback = (rt, tid, tl, "LLM", total_usage)
        # NARROWER_MATCH_FOUND and NO_MATCH_FOUND → continue
    if broader_fallback is not None:
        return broader_fallback
    return "NO_MATCH_FOUND", "", "", "", total_usage
def main():
    """CLI entry point: read study_id/named_entity CSV rows from stdin, map
    each distinct entity (cache-first, then in parallel), and write the
    mapped rows to stdout in the original input order.
    """
    parser = argparse.ArgumentParser(description="Map named entities to ontology terms via OLS4")
    parser.add_argument("--ontology-ids", type=lambda s: s.split(","), required=True, help="Comma-separated ontology IDs in priority order (e.g. --ontology-ids efo,chebi,txpo)")
    parser.add_argument("--num-results", type=int, default=20, help="Number of OLS results to retrieve (default: 20)")
    parser.add_argument("--workers", type=int, default=8, help="Number of parallel workers (default: 8)")
    args = parser.parse_args()
    client = OpenAI()
    conn = init_cache(CACHE_DB)
    # cache_lock serialises SQLite writes; token_lock guards the usage counters.
    cache_lock = threading.Lock()
    total_prompt_tokens = 0
    total_completion_tokens = 0
    num_requests = 0
    token_lock = threading.Lock()
    reader = csv.DictReader(sys.stdin)
    # Validate the input header before producing any output.
    if reader.fieldnames is None or "study_id" not in reader.fieldnames or "named_entity" not in reader.fieldnames:
        print("Error: input CSV must have columns 'study_id' and 'named_entity'", file=sys.stderr)
        sys.exit(1)
    writer = csv.writer(sys.stdout)
    writer.writerow(["study_id", "named_entity", "result_type", "term_id", "term_label", "mapping_source"])
    # Collect all rows and resolve cache hits first
    rows_to_process: list[tuple[int, str, str]] = []  # (index, study_id, named_entity)
    results: dict[int, list] = {}  # index -> output row
    idx = 0
    for row in reader:
        study_id = row["study_id"]
        named_entity = row["named_entity"]
        # Skip blank entities entirely (they get no output row and no index).
        if not named_entity or not named_entity.strip():
            continue
        cached = lookup_cache(conn, named_entity, args.ontology_ids)
        if cached:
            result_type, term_id, term_label, mapping_source = cached
            print(f" Cache hit: {named_entity} -> {result_type} {term_id} [{mapping_source}]", file=sys.stderr)
            results[idx] = [study_id, named_entity, result_type, term_id, term_label, mapping_source]
        else:
            rows_to_process.append((idx, study_id, named_entity))
        idx += 1
    def _process(item: tuple[int, str, str]) -> tuple[int, list]:
        # Worker body: map one entity, record token usage, and cache the result.
        nonlocal total_prompt_tokens, total_completion_tokens, num_requests
        i, study_id, named_entity = item
        print(f" [{named_entity}] Mapping...", file=sys.stderr)
        result_type, term_id, term_label, mapping_source, usage = map_entity(
            client, named_entity, args.num_results, args.ontology_ids
        )
        print(f" [{named_entity}] Result: {result_type} {term_id} ({term_label}) [{mapping_source}]", file=sys.stderr)
        with token_lock:
            total_prompt_tokens += usage["prompt_tokens"]
            total_completion_tokens += usage["completion_tokens"]
            # Only count entities that actually reached the LLM stage.
            if usage["prompt_tokens"] > 0:
                num_requests += 1
        with cache_lock:
            store_cache(conn, named_entity, args.ontology_ids, result_type, term_id, term_label, mapping_source)
        return i, [study_id, named_entity, result_type, term_id, term_label, mapping_source]
    # Process uncached entities in parallel
    with ThreadPoolExecutor(max_workers=args.workers) as executor:
        futures = {executor.submit(_process, item): item for item in rows_to_process}
        for future in as_completed(futures):
            i, result_row = future.result()
            results[i] = result_row
    # Write all results in original order
    for i in sorted(results):
        writer.writerow(results[i])
    conn.close()
    model = os.environ.get("OPENAI_MODEL", "gpt-5.2")
    print(f"\n--- Token Usage Summary ---", file=sys.stderr)
    print(f"Model: {model}", file=sys.stderr)
    print(f"LLM requests: {num_requests}", file=sys.stderr)
    print(f"Prompt tokens: {total_prompt_tokens}", file=sys.stderr)
    print(f"Completion tokens: {total_completion_tokens}", file=sys.stderr)
    print(f"Total tokens: {total_prompt_tokens + total_completion_tokens}", file=sys.stderr)
if __name__ == "__main__":
    main()