9
9
from pathlib import Path
10
10
from typing import TYPE_CHECKING , Any , AsyncGenerator , Literal , Optional
11
11
12
- from pydantic import BaseModel , ConfigDict , Field , Secret , field_validator
12
+ from pydantic import BaseModel , ConfigDict , Field , Secret , ValidationError , field_validator
13
13
14
+ from unstructured_ingest .data_types .entities import EntitiesData , Entity , EntityRelationship
14
15
from unstructured_ingest .data_types .file_data import FileData
15
16
from unstructured_ingest .error import DestinationConnectionError
16
17
from unstructured_ingest .interfaces import (
@@ -97,7 +98,6 @@ def run( # type: ignore
97
98
** kwargs : Any ,
98
99
) -> Path :
99
100
elements = get_json_data (elements_filepath )
100
-
101
101
nx_graph = self ._create_lexical_graph (
102
102
elements , self ._create_document_node (file_data = file_data )
103
103
)
@@ -109,28 +109,54 @@ def run( # type: ignore
109
109
110
110
return output_filepath
111
111
112
def _add_entities(self, entities: list[Entity], graph: "Graph", element_node: _Node) -> None:
    """Link every entity in ``entities`` to ``element_node`` inside ``graph``.

    Two edges are added per entity:

    * entity node -> entity-type node, tagged ``Relationship.ENTITY_TYPE``
    * ``element_node`` -> entity node, tagged ``Relationship.HAS_ENTITY``

    Both the entity node and its type node carry the ``Label.ENTITY`` label and
    use the entity's text / type value as their id.

    Args:
        entities: Already-validated entities to attach.
        graph: Graph that receives the new edges (mutated in place).
        element_node: Node of the element the entities were found in.
    """
    for item in entities:
        name_node = _Node(labels=[Label.ENTITY], properties={"id": item.entity}, id_=item.entity)
        type_node = _Node(labels=[Label.ENTITY], properties={"id": item.type}, id_=item.type)
        graph.add_edge(name_node, type_node, relationship=Relationship.ENTITY_TYPE)
        graph.add_edge(element_node, name_node, relationship=Relationship.HAS_ENTITY)
133
123
124
def _add_entity_relationships(
    self, relationships: list[EntityRelationship], graph: "Graph"
) -> None:
    """Add one entity-to-entity edge to ``graph`` per item in ``relationships``.

    Each endpoint is represented by a ``Label.ENTITY`` node whose id is the
    relationship's ``from_`` / ``to`` value. The edge is tagged with the
    relationship's own ``relationship`` value rather than a ``Relationship``
    enum member.

    Args:
        relationships: Already-validated entity relationships.
        graph: Graph that receives the new edges (mutated in place).
    """
    for rel in relationships:
        source_node, target_node = (
            _Node(labels=[Label.ENTITY], properties={"id": name}, id_=name)
            for name in (rel.from_, rel.to)
        )
        graph.add_edge(source_node, target_node, relationship=rel.relationship)
137
+
138
def _add_entity_data(self, element: dict, graph: "Graph", element_node: _Node) -> None:
    """Read ``element["metadata"]["entities"]`` and add its content to ``graph``.

    Two input shapes are accepted:

    * a list of entity dicts -- each one is validated with ``Entity`` and linked
      to ``element_node``;
    * a dict -- validated as ``EntitiesData``; its ``items`` are linked to
      ``element_node`` and its ``relationships`` are added as entity-to-entity
      edges.

    Any other shape, or missing/empty data, is ignored. This is best-effort:
    validation problems are logged as warnings and never raised.

    Args:
        element: Element dictionary, possibly carrying entity metadata.
        graph: Graph that receives the new edges (mutated in place).
        element_node: Node of the element the entities belong to.
    """
    # ``metadata`` may be present but explicitly None; treat that like "absent".
    entities = (element.get("metadata") or {}).get("entities")
    if not entities:
        return None

    try:
        if isinstance(entities, list):
            valid_entities: list[Entity] = []
            for raw_entity in entities:
                if not isinstance(raw_entity, dict):
                    continue
                # Validate one by one so a single malformed entry does not
                # discard the well-formed entities of the same element.
                try:
                    valid_entities.append(Entity.model_validate(raw_entity))
                except ValidationError as error:
                    logger.warning("Skipping malformed entity. Details: %s", error)
            self._add_entities(valid_entities, graph, element_node)
        elif isinstance(entities, dict):
            entity_data = EntitiesData.model_validate(entities)
            self._add_entities(entity_data.items, graph, element_node)
            self._add_entity_relationships(entity_data.relationships, graph)
    except ValidationError as error:
        logger.warning(
            "Failed to add entities to the graph. "
            "Please check the format of the entities in the input data. "
            "Details: %s",
            error,
        )
    return None
159
+
134
160
def _create_lexical_graph (self , elements : list [dict ], document_node : _Node ) -> "Graph" :
135
161
import networkx as nx
136
162
@@ -149,7 +175,7 @@ def _create_lexical_graph(self, elements: list[dict], document_node: _Node) -> "
149
175
previous_node = element_node
150
176
graph .add_edge (element_node , document_node , relationship = Relationship .PART_OF_DOCUMENT )
151
177
152
- self ._add_entities (element , graph , element_node )
178
+ self ._add_entity_data (element , graph , element_node )
153
179
154
180
if self ._is_chunk (element ):
155
181
for origin_element in format_and_truncate_orig_elements (element , include_text = True ):
@@ -165,7 +191,7 @@ def _create_lexical_graph(self, elements: list[dict], document_node: _Node) -> "
165
191
document_node ,
166
192
relationship = Relationship .PART_OF_DOCUMENT ,
167
193
)
168
- self ._add_entities (origin_element , graph , origin_element_node )
194
+ self ._add_entity_data (origin_element , graph , origin_element_node )
169
195
170
196
return graph
171
197
@@ -208,7 +234,9 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
208
234
_Edge (
209
235
source = u ,
210
236
destination = v ,
211
- relationship = Relationship (data_dict ["relationship" ]),
237
+ relationship = Relationship (data_dict ["relationship" ])
238
+ if data_dict ["relationship" ] in Relationship
239
+ else data_dict ["relationship" ],
212
240
)
213
241
for u , v , data_dict in nx_graph .edges (data = True )
214
242
]
@@ -242,7 +270,7 @@ class _Edge(BaseModel):
242
270
243
271
source : _Node
244
272
destination : _Node
245
- relationship : Relationship
273
+ relationship : Relationship | str
246
274
247
275
248
276
class Label (Enum ):
@@ -380,7 +408,7 @@ async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> Non
380
408
)
381
409
logger .info (f"Finished merging { len (graph_data .nodes )} graph nodes." )
382
410
383
- edges_by_relationship : defaultdict [tuple [Relationship , Label , Label ], list [_Edge ]] = (
411
+ edges_by_relationship : defaultdict [tuple [Relationship | str , Label , Label ], list [_Edge ]] = (
384
412
defaultdict (list )
385
413
)
386
414
for edge in graph_data .edges :
@@ -463,16 +491,19 @@ def _create_nodes_query(nodes: list[_Node], label: Label) -> tuple[str, dict]:
463
491
@staticmethod
464
492
def _create_edges_query (
465
493
edges : list [_Edge ],
466
- relationship : Relationship ,
494
+ relationship : Relationship | str ,
467
495
source_label : Label ,
468
496
destination_label : Label ,
469
497
) -> tuple [str , dict ]:
470
498
logger .info (f"Preparing MERGE query for { len (edges )} { relationship } relationships." )
499
+ relationship = (
500
+ relationship .value if isinstance (relationship , Relationship ) else relationship
501
+ )
471
502
query_string = f"""
472
503
UNWIND $edges AS edge
473
504
MATCH (u: `{ source_label .value } ` {{id: edge.source}})
474
505
MATCH (v: `{ destination_label .value } ` {{id: edge.destination}})
475
- MERGE (u)-[:`{ relationship . value } `]->(v)
506
+ MERGE (u)-[:`{ relationship } `]->(v)
476
507
"""
477
508
parameters = {
478
509
"edges" : [
0 commit comments