Skip to content

Commit 4c02ce8

Browse files
authored
Merge pull request #43 from alphagov/ACW-70/similarity-distance
Calculating textual/edit distance between aliases
2 parents 809f441 + d2aef21 commit 4c02ce8

5 files changed

Lines changed: 203 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ dependencies = [
1515
"requests-aws4auth>=1.3.1",
1616
"flask[async]>=3.1.3",
1717
"uvicorn>=0.30.0",
18+
"levenshtein>=0.27.3",
1819
]
1920

2021
[dependency-groups]

src/models/graph_models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,26 @@ class Edge(BaseModel):
6262
data: EdgeData
6363

6464

65+
class SimilarAlias(BaseModel):
66+
id: str
67+
label: str
68+
similarity: int
69+
70+
71+
class OutlierAlias(BaseModel):
72+
id: str
73+
label: str
74+
similar_aliases: List[SimilarAlias] = Field(default_factory=list)
75+
76+
77+
class EntityOutlier(BaseModel):
78+
entity_id: str
79+
entity_label: str
80+
aliases: List[OutlierAlias] = Field(default_factory=list)
81+
82+
6583
class GraphOutput(BaseModel):
6684
nodes: List[Node]
6785
edges: List[Edge]
6886
relationships: List[Relationship] = Field(default_factory=list)
87+
outliers: List[EntityOutlier] = Field(default_factory=list)

src/utils/alias_distance.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from Levenshtein import distance
2+
3+
4+
def calculate_alias_distance(alias1: str, alias2: str) -> int:
5+
"""Calculate the Levenshtein distance between two aliases."""
6+
return distance(alias1, alias2)

src/visualiser_graph_generator.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
Edge,
1616
EdgeData,
1717
Entity,
18+
EntityOutlier,
1819
GraphInput,
1920
GraphOutput,
2021
Node,
2122
NodeData,
2223
Occurrence,
24+
OutlierAlias,
2325
Relationship,
26+
SimilarAlias,
2427
)
28+
from src.utils.alias_distance import calculate_alias_distance
2529

2630

2731
logger = logging.getLogger(__name__)
@@ -181,6 +185,7 @@ def build_node_structure(
181185
) -> GraphOutput:
182186
"""Constructs the final list of nodes and edges."""
183187
nodes, edges = [], []
188+
outliers = []
184189
id_to_canonical = {ent.id: ent.canonical_key for ent in entities}
185190

186191
for ent in entities:
@@ -205,7 +210,9 @@ def build_node_structure(
205210
occ.extend(occurrences)
206211

207212
# Add the deduplicated alias nodes and their edges
208-
for alias_id, node_data in alias_map.items():
213+
unique_aliases = list(alias_map.values())
214+
for node_data in unique_aliases:
215+
alias_id = node_data.id
209216
if not node_data.occurrences:
210217
node_data.occurrences = None
211218

@@ -223,6 +230,36 @@ def build_node_structure(
223230
)
224231
)
225232

233+
# Build outlier aliases structure for this entity
234+
outlier_aliases = []
235+
for curr_alias in unique_aliases:
236+
similar_aliases = []
237+
for other_alias in unique_aliases:
238+
if other_alias.id != curr_alias.id:
239+
dist = calculate_alias_distance(curr_alias.label, other_alias.label)
240+
similar_aliases.append(
241+
SimilarAlias(
242+
id=other_alias.id,
243+
label=other_alias.label,
244+
similarity=dist,
245+
)
246+
)
247+
outlier_aliases.append(
248+
OutlierAlias(
249+
id=curr_alias.id,
250+
label=curr_alias.label,
251+
similar_aliases=similar_aliases,
252+
)
253+
)
254+
255+
outliers.append(
256+
EntityOutlier(
257+
entity_id=ent_id,
258+
entity_label=human_label,
259+
aliases=outlier_aliases,
260+
)
261+
)
262+
226263
for rel in relationships or []:
227264
source = id_to_canonical.get(rel.from_, rel.from_)
228265
target = id_to_canonical.get(rel.to, rel.to)
@@ -237,7 +274,12 @@ def build_node_structure(
237274
)
238275
)
239276

240-
return GraphOutput(nodes=nodes, edges=edges, relationships=relationships or [])
277+
return GraphOutput(
278+
nodes=nodes,
279+
edges=edges,
280+
relationships=relationships or [],
281+
outliers=outliers,
282+
)
241283

242284

243285
async def generate_graph(

0 commit comments

Comments
 (0)