-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgraph_models.py
More file actions
175 lines (130 loc) · 4.97 KB
/
graph_models.py
File metadata and controls
175 lines (130 loc) · 4.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import re
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field
class Alias(BaseModel):
    """An alternative name for an entity, with the sources it came from."""

    # The alias text itself.
    name: str
    # Source references for this alias; defaults to an empty list.
    # Entity.model_post_init fills this from the entity's "sourceUrls"
    # property for the alias it derives from the label.
    source_files: List[str] = Field(default_factory=list)
class Entity(BaseModel):
    """A graph entity together with its known aliases.

    After validation, ``model_post_init`` appends an alias derived from
    ``label`` to ``aliases`` unless an alias with that exact name is already
    present.  Unknown extra fields are accepted and kept (``extra="allow"``).
    """

    id: str
    canonical_key: str
    label: Optional[str] = None
    aliases: List[Alias] = Field(default_factory=list)
    properties: Dict[str, Any] = Field(default_factory=dict)
    type: Optional[str] = None
    description: Optional[str] = None

    model_config = ConfigDict(extra="allow")

    def model_post_init(self, __context: Any) -> None:
        """Append a lower-cased, space-separated alias derived from ``label``.

        The label is split before every capital letter (``"fooBar"`` becomes
        ``"foo bar"``) and on whitespace, so labels that already contain
        spaces do not yield doubled or leading spaces in the alias name.
        The new alias takes its ``source_files`` from
        ``properties["sourceUrls"]``, which may be a list of strings or a
        single comma-separated string; any other value gives an empty list.

        Nothing is added when the label is empty, when it contains no
        words, or when an alias with the derived name already exists.

        Args:
            __context: Pydantic post-init context; unused.
        """
        if not self.label:
            return

        # Zero-width split in front of each capital letter; each piece is
        # then split on whitespace so no piece carries stray spaces along.
        camel_pieces = re.split(r"(?=[A-Z])", self.label)
        words = [word.lower() for piece in camel_pieces for word in piece.split()]
        alias_name = " ".join(words)

        # A whitespace-only label produces no words: do not add a blank alias.
        if not alias_name:
            return
        # Exact, case-sensitive match against the aliases already present.
        if any(alias.name == alias_name for alias in self.aliases):
            return

        raw_urls = self.properties.get("sourceUrls", "")
        if isinstance(raw_urls, str):
            raw_urls = raw_urls.split(",")
        elif not isinstance(raw_urls, list):
            raw_urls = []
        # Keep only non-blank string entries, trimmed of surrounding spaces.
        source_files = [
            url.strip() for url in raw_urls if isinstance(url, str) and url.strip()
        ]

        self.aliases.append(Alias(name=alias_name, source_files=source_files))
class Relationship(BaseModel):
    """A typed, directed link described by a ``from_`` and a ``to`` endpoint.

    Unknown extra fields are accepted and kept (``extra="allow"``).
    """

    # Relationship kind, as a free-form string.
    type: str
    # Trailing underscore because ``from`` is a reserved word in Python.
    # NOTE(review): no alias is declared, so input has to use the key
    # "from_"; a plain "from" key would not populate this field - confirm
    # against the producers of this payload.
    from_: str
    # Presumably entity ids on both ends - TODO confirm against callers.
    to: str

    model_config = ConfigDict(extra="allow")
class GraphInput(BaseModel):
    """Input payload: a required list of entities plus optional relationships.

    Unknown extra fields are accepted and kept (``extra="allow"``).
    """

    entities: List[Entity]
    # Defaults to an empty list when the payload has no relationships.
    relationships: List[Relationship] = Field(default_factory=list)

    model_config = ConfigDict(extra="allow")
class Occurrence(BaseModel):
    """One occurrence record attached to a node: a link and a context string."""

    # Presumably a URL or path locating the occurrence - TODO confirm.
    link: str
    # Presumably the text surrounding the occurrence - TODO confirm.
    context: str
class NodeData(BaseModel):
    """Payload of a graph node."""

    id: str
    label: str
    # Distinguishes entity nodes from alias nodes.
    type: Literal["entity", "alias"]
    # Optional list of occurrences for this node; None when not provided.
    occurrences: Optional[List[Occurrence]] = None
class Node(BaseModel):
    """Graph node wrapper that nests its payload under ``data``."""

    # NOTE(review): the {"data": {...}} nesting looks like the element
    # format of a graph-rendering front end - confirm the consumer.
    data: NodeData
class EdgeData(BaseModel):
    """Payload of a graph edge between two node ids."""

    # Ids of the nodes this edge connects (GraphOutput matches them
    # against ``NodeData.id`` values when pruning).
    source: str
    target: str
    label: str
    # Optional edge category; None when not provided.
    edge_type: Optional[Literal["alias", "relationship"]] = None
class Edge(BaseModel):
    """Graph edge wrapper that nests its payload under ``data``."""

    data: EdgeData
class SimilarAlias(BaseModel):
    """Reference to another alias together with an integer similarity score."""

    id: str
    label: str
    # NOTE(review): the scale is not visible here (presumably a 0-100
    # percentage) - confirm against the code that computes it.
    similarity: int
class OutlierAlias(BaseModel):
    """An alias listed under an EntityOutlier, with its similar aliases."""

    id: str
    label: str
    # EntityOutlier.model_post_init drops aliases whose count is not > 0.
    occurrence_count: int = 0
    # Defaults to an empty list.
    similar_aliases: List[SimilarAlias] = Field(default_factory=list)
class AliasImbalanceStats(BaseModel):
    """Per-alias occurrence statistics used for imbalance reporting."""

    alias_id: str
    alias_label: str
    # EntityOutlier.model_post_init drops entries whose count is not > 0.
    occurrence_count: int = 0
    # Left as None here; EntityOutlier.model_post_init fills it in with a
    # value rounded to two decimals.
    z_score: Optional[float] = None
class EntityOutlier(BaseModel):
    """Outlier report for one entity: its aliases and their occurrence spread.

    After validation, entries with no occurrences are dropped and the
    z-score of every remaining ``alias_imbalance`` entry is computed.
    """

    entity_id: str
    entity_label: str
    occurrence_std_dev: float = 0.0
    aliases: List[OutlierAlias] = Field(default_factory=list)
    alias_imbalance: List[AliasImbalanceStats] = Field(default_factory=list)

    def model_post_init(self, __context: Any) -> None:
        """Drop zero-occurrence entries and fill in per-alias z-scores.

        ``alias_imbalance`` and ``aliases`` are both reduced to entries whose
        ``occurrence_count`` is positive.  With two or more remaining
        counts, ``occurrence_std_dev`` becomes their population standard
        deviation and each entry gets ``(count - mean) / std_dev``, both
        rounded to two decimals; a zero deviation gives every entry a
        z-score of 0.0.  A single remaining entry gets deviation 0.0 and
        z-score 0.0.  With no remaining entries nothing else is modified.

        Args:
            __context: Pydantic post-init context; unused.
        """
        import statistics

        self.alias_imbalance = [
            entry for entry in self.alias_imbalance if entry.occurrence_count > 0
        ]
        self.aliases = [item for item in self.aliases if item.occurrence_count > 0]

        counts = [entry.occurrence_count for entry in self.alias_imbalance]
        if not counts:
            return

        if len(counts) == 1:
            # One sample has no spread by definition.
            self.occurrence_std_dev = 0.0
            self.alias_imbalance[0].z_score = 0.0
            return

        mean = statistics.mean(counts)
        std_dev = statistics.pstdev(counts)
        self.occurrence_std_dev = round(std_dev, 2)
        for entry in self.alias_imbalance:
            # Identical counts give a zero deviation; avoid dividing by it.
            entry.z_score = (
                round((entry.occurrence_count - mean) / std_dev, 2)
                if std_dev > 0
                else 0.0
            )
class GraphOutput(BaseModel):
    """Graph result: nodes, edges, relationships and per-entity outliers.

    When outliers are present, the graph is pruned after validation so that
    only outlier entities with alias-imbalance data, and their aliases,
    remain (see ``model_post_init``).
    """

    nodes: List[Node]
    edges: List[Edge]
    relationships: List[Relationship] = Field(default_factory=list)
    outliers: List[EntityOutlier] = Field(default_factory=list)

    def model_post_init(self, __context: Any) -> None:
        """Prune outliers, nodes and edges down to the imbalanced entities.

        Does nothing when ``outliers`` is empty.  Otherwise an entity id is
        kept when at least one outlier carrying that id has a non-empty
        ``alias_imbalance``; kept node ids are those entity ids plus the
        ids of the aliases listed on such outliers.  Edges survive only
        when both endpoints are kept node ids.  ``relationships`` is left
        untouched.

        Args:
            __context: Pydantic post-init context; unused.
        """
        if not self.outliers:
            return

        imbalanced = [entry for entry in self.outliers if entry.alias_imbalance]
        valid_entity_ids = {entry.entity_id for entry in imbalanced}
        alias_ids = {alias.id for entry in imbalanced for alias in entry.aliases}
        valid_node_ids = valid_entity_ids | alias_ids

        # Filter by id rather than reusing `imbalanced`: an outlier without
        # imbalance data is still kept if another outlier shares its id.
        self.outliers = [
            entry for entry in self.outliers if entry.entity_id in valid_entity_ids
        ]
        self.nodes = [item for item in self.nodes if item.data.id in valid_node_ids]
        self.edges = [
            item
            for item in self.edges
            if item.data.source in valid_node_ids and item.data.target in valid_node_ids
        ]