generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 80
Expand file tree
/
Copy pathentity_linker.py
More file actions
111 lines (93 loc) · 4.14 KB
/
entity_linker.py
File metadata and controls
111 lines (93 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from abc import ABC, abstractmethod
from typing import List
class Linker(ABC):
    """Abstract base class defining the query-to-entity linking interface.

    Subclasses implement :meth:`link`; the base implementation provides an
    empty-result fallback that subclasses may delegate to.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the Linker instance (no-op in the base class)."""
        pass

    @abstractmethod
    def link(self, queries: List[str], return_dict=True, **kwargs):
        """Link the given queries to graph nodes/edges.

        Args:
            queries: List of input query texts to perform graph linking on
            return_dict: Whether to return a dictionary of linking results
                or linked entities only
            **kwargs: Additional keyword arguments for graph linking
                configuration

        Returns:
            If ``return_dict`` is True, a list with one result dict per
            query, shaped as::

                {'hits': [{'document_id': [...],    # matched entity IDs
                           'document': [...],       # matched entity documents
                           'match_score': [...]}]}  # matching scores

            If ``return_dict`` is False, a list of matched nodes
            (documents or entities) per query. The base implementation
            always returns empty results.
        """
        if not return_dict:
            return [[] for _ in queries]
        # Build a fresh dict per query so callers can mutate each result
        # independently.
        return [
            {'hits': [{'document_id': [], 'document': [], 'match_score': []}]}
            for _ in queries
        ]
class EntityLinker(Linker):
    """The EntityLinker instance which performs two step linking.

    If entity_extractor is passed then step 1 is to use the entity
    extractors to extract entities. Step 2 is to use retriever i.e entity
    matcher to retrieve most similar entities from the index.
    """

    def __init__(self, retriever=None, topk=3, **kwargs):
        """Initialize the EntityLinker instance.

        Args:
            retriever: An indexing.EntityMatcher object
            topk: How many items to return per extracted entity per query
            **kwargs: Additional keyword arguments for graph linking
                configuration
        """
        self.retriever = retriever  # default retriever; may be overridden per call
        self.topk = topk            # default number of hits per entity

    def link(self, query_extracted_entities, retriever=None, topk=None, id_selector=None, return_dict=True):
        """Process to link the given or extracted query entities to graph entities.

        Args:
            query_extracted_entities: List of entity lists to perform graph
                linking on
            retriever: A retriever object to use for entity lookup.
                If None, the default retriever configured for this instance
                will be used.
            topk: The number of items to return per extracted entity
            id_selector: A list of ids to retrieve the topk from (allowlist)
            return_dict: Whether to return a dictionary of linking results
                or linked entities only

        Returns:
            If return_dict is True:
                List[Dict]: A list of dictionaries containing linking results
                for each query
            If return_dict is False:
                List[str]: A list of matched entities

        Raises:
            ValueError: If no retriever was passed here or at construction.

        Note:
            topk is applied per entity
        """
        # Fall back to instance defaults before validating.
        if retriever is None:
            retriever = self.retriever
        if retriever is None:
            raise ValueError("Error: Either 'retriever' or 'self.retriever' must be provided")
        if topk is None:
            topk = self.topk
        # NOTE(review): id_selector is accepted but never forwarded to the
        # retriever — confirm whether retriever.retrieve supports an
        # allowlist parameter before relying on it.
        # Single retrieval call shared by both return shapes (the original
        # duplicated it in each branch).
        results = retriever.retrieve(queries=query_extracted_entities, topk=topk)
        if return_dict:
            return results
        # NOTE(review): this assumes retrieve() returns a dict shaped as
        # {'hits': [{'document_id': [...]}, ...]}, which differs from the
        # per-query list-of-dicts shape documented on Linker.link — verify
        # against the retriever contract.
        return [hit['document_id'] for hit in results['hits']]