Skip to content

Commit 5c5fcc9

Browse files
authored
neo4j NER+RE support (#464)
1 parent 22a1931 commit 5c5fcc9

File tree

7 files changed

+484
-295
lines changed

7 files changed

+484
-295
lines changed

CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## 1.0.8
2+
3+
### Enhancements
4+
5+
* **Update Neo4J Entity Support** to support NER + RE(Relationship extraction)
6+
17
## 1.0.7
28

39
### Fixes

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ test = [
143143
"vertexai",
144144
"pyiceberg",
145145
"pyarrow",
146+
"networkx"
146147
]
147148
# Add constraints needed for CI
148149
ci = [

test/unit/connectors/test_neo4j.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import networkx as nx
2+
import pytest
3+
4+
from test.integration.connectors.utils.constants import DESTINATION_TAG, GRAPH_DB_TAG
5+
from unstructured_ingest.processes.connectors import neo4j
6+
from unstructured_ingest.processes.connectors.neo4j import (
7+
CONNECTOR_TYPE,
8+
Label,
9+
Neo4jUploadStager,
10+
Relationship,
11+
)
12+
13+
14+
@pytest.mark.tags(DESTINATION_TAG, CONNECTOR_TYPE, GRAPH_DB_TAG)
15+
def test_neo4j_stager_with_entities_no_re():
16+
stager = Neo4jUploadStager()
17+
18+
graph = nx.MultiDiGraph()
19+
doc_node = neo4j._Node(id_="root", properties={"id": "root"}, labels=[Label.DOCUMENT])
20+
graph.add_node(doc_node)
21+
22+
element_node = neo4j._Node(
23+
id_="element_id",
24+
properties={"id": "element_id", "text": "This is a test"},
25+
labels=[Label.UNSTRUCTURED_ELEMENT],
26+
)
27+
graph.add_edge(element_node, doc_node, relationship=Relationship.PART_OF_DOCUMENT)
28+
29+
stager._add_entity_data(
30+
{"metadata": {"entities": [{"type": "PERSON", "entity": "Steve Jobs"}]}},
31+
graph,
32+
element_node,
33+
)
34+
assert len(graph.nodes) == 4
35+
assert len(graph.edges) == 3
36+
37+
38+
@pytest.mark.tags(DESTINATION_TAG, CONNECTOR_TYPE, GRAPH_DB_TAG)
39+
def test_neo4j_stager_with_entities():
40+
stager = Neo4jUploadStager()
41+
42+
graph = nx.MultiDiGraph()
43+
doc_node = neo4j._Node(id_="root", properties={"id": "root"}, labels=[Label.DOCUMENT])
44+
graph.add_node(doc_node)
45+
46+
element_node = neo4j._Node(
47+
id_="element_id",
48+
properties={"id": "element_id", "text": "This is a test"},
49+
labels=[Label.UNSTRUCTURED_ELEMENT],
50+
)
51+
graph.add_edge(element_node, doc_node, relationship=Relationship.PART_OF_DOCUMENT)
52+
53+
stager._add_entity_data(
54+
{
55+
"metadata": {
56+
"entities": {
57+
"items": [
58+
{"type": "PERSON", "entity": "Steve Jobs"},
59+
{"type": "COMPANY", "entity": "Apple"},
60+
],
61+
"relationships": [
62+
{"from": "Steve Jobs", "to": "Apple", "relationship": "founded"},
63+
{"from": "Steve Jobs", "to": "Apple", "relationship": "worked_for"},
64+
],
65+
}
66+
}
67+
},
68+
graph,
69+
element_node,
70+
)
71+
assert len(graph.nodes) == 6
72+
assert len(graph.edges) == 7

unstructured_ingest/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.7" # pragma: no cover
1+
__version__ = "1.0.8" # pragma: no cover
+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from pydantic import BaseModel, Field
2+
3+
4+
class Entity(BaseModel):
5+
type: str
6+
entity: str
7+
8+
9+
class EntityRelationship(BaseModel):
10+
to: str
11+
from_: str = Field(..., alias="from")
12+
relationship: str
13+
14+
15+
class EntitiesData(BaseModel):
16+
items: list[Entity] = Field(default_factory=list)
17+
relationships: list[EntityRelationship] = Field(default_factory=list)

unstructured_ingest/processes/connectors/neo4j.py

+53-22
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from pathlib import Path
1010
from typing import TYPE_CHECKING, Any, AsyncGenerator, Literal, Optional
1111

12-
from pydantic import BaseModel, ConfigDict, Field, Secret, field_validator
12+
from pydantic import BaseModel, ConfigDict, Field, Secret, ValidationError, field_validator
1313

14+
from unstructured_ingest.data_types.entities import EntitiesData, Entity, EntityRelationship
1415
from unstructured_ingest.data_types.file_data import FileData
1516
from unstructured_ingest.error import DestinationConnectionError
1617
from unstructured_ingest.interfaces import (
@@ -97,7 +98,6 @@ def run( # type: ignore
9798
**kwargs: Any,
9899
) -> Path:
99100
elements = get_json_data(elements_filepath)
100-
101101
nx_graph = self._create_lexical_graph(
102102
elements, self._create_document_node(file_data=file_data)
103103
)
@@ -109,28 +109,54 @@ def run( # type: ignore
109109

110110
return output_filepath
111111

112-
def _add_entities(self, element: dict, graph: "Graph", element_node: _Node) -> None:
113-
entities = element.get("metadata", {}).get("entities", [])
114-
if not entities:
115-
return None
116-
if not isinstance(entities, list):
117-
return None
118-
112+
def _add_entities(self, entities: list[Entity], graph: "Graph", element_node: _Node) -> None:
119113
for entity in entities:
120-
if not isinstance(entity, dict):
121-
continue
122-
if "entity" not in entity or "type" not in entity:
123-
continue
124114
entity_node = _Node(
125-
labels=[Label.ENTITY], properties={"id": entity["entity"]}, id_=entity["entity"]
115+
labels=[Label.ENTITY], properties={"id": entity.entity}, id_=entity.entity
126116
)
127117
graph.add_edge(
128118
entity_node,
129-
_Node(labels=[Label.ENTITY], properties={"id": entity["type"]}, id_=entity["type"]),
119+
_Node(labels=[Label.ENTITY], properties={"id": entity.type}, id_=entity.type),
130120
relationship=Relationship.ENTITY_TYPE,
131121
)
132122
graph.add_edge(element_node, entity_node, relationship=Relationship.HAS_ENTITY)
133123

124+
def _add_entity_relationships(
125+
self, relationships: list[EntityRelationship], graph: "Graph"
126+
) -> None:
127+
for relationship in relationships:
128+
from_node = _Node(
129+
labels=[Label.ENTITY],
130+
properties={"id": relationship.from_},
131+
id_=relationship.from_,
132+
)
133+
to_node = _Node(
134+
labels=[Label.ENTITY], properties={"id": relationship.to}, id_=relationship.to
135+
)
136+
graph.add_edge(from_node, to_node, relationship=relationship.relationship)
137+
138+
def _add_entity_data(self, element: dict, graph: "Graph", element_node: _Node) -> None:
139+
entities = element.get("metadata", {}).get("entities", {})
140+
if not entities:
141+
return None
142+
try:
143+
if isinstance(entities, list):
144+
self._add_entities(
145+
[Entity.model_validate(e) for e in entities if isinstance(e, dict)],
146+
graph,
147+
element_node,
148+
)
149+
elif isinstance(entities, dict):
150+
entity_data = EntitiesData.model_validate(entities)
151+
self._add_entities(entity_data.items, graph, element_node)
152+
self._add_entity_relationships(entity_data.relationships, graph)
153+
except ValidationError:
154+
logger.warning(
155+
"Failed to add entities to the graph. "
156+
"Please check the format of the entities in the input data."
157+
)
158+
return None
159+
134160
def _create_lexical_graph(self, elements: list[dict], document_node: _Node) -> "Graph":
135161
import networkx as nx
136162

@@ -149,7 +175,7 @@ def _create_lexical_graph(self, elements: list[dict], document_node: _Node) -> "
149175
previous_node = element_node
150176
graph.add_edge(element_node, document_node, relationship=Relationship.PART_OF_DOCUMENT)
151177

152-
self._add_entities(element, graph, element_node)
178+
self._add_entity_data(element, graph, element_node)
153179

154180
if self._is_chunk(element):
155181
for origin_element in format_and_truncate_orig_elements(element, include_text=True):
@@ -165,7 +191,7 @@ def _create_lexical_graph(self, elements: list[dict], document_node: _Node) -> "
165191
document_node,
166192
relationship=Relationship.PART_OF_DOCUMENT,
167193
)
168-
self._add_entities(origin_element, graph, origin_element_node)
194+
self._add_entity_data(origin_element, graph, origin_element_node)
169195

170196
return graph
171197

@@ -208,7 +234,9 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
208234
_Edge(
209235
source=u,
210236
destination=v,
211-
relationship=Relationship(data_dict["relationship"]),
237+
relationship=Relationship(data_dict["relationship"])
238+
if data_dict["relationship"] in Relationship
239+
else data_dict["relationship"],
212240
)
213241
for u, v, data_dict in nx_graph.edges(data=True)
214242
]
@@ -242,7 +270,7 @@ class _Edge(BaseModel):
242270

243271
source: _Node
244272
destination: _Node
245-
relationship: Relationship
273+
relationship: Relationship | str
246274

247275

248276
class Label(Enum):
@@ -380,7 +408,7 @@ async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> Non
380408
)
381409
logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")
382410

383-
edges_by_relationship: defaultdict[tuple[Relationship, Label, Label], list[_Edge]] = (
411+
edges_by_relationship: defaultdict[tuple[Relationship | str, Label, Label], list[_Edge]] = (
384412
defaultdict(list)
385413
)
386414
for edge in graph_data.edges:
@@ -463,16 +491,19 @@ def _create_nodes_query(nodes: list[_Node], label: Label) -> tuple[str, dict]:
463491
@staticmethod
464492
def _create_edges_query(
465493
edges: list[_Edge],
466-
relationship: Relationship,
494+
relationship: Relationship | str,
467495
source_label: Label,
468496
destination_label: Label,
469497
) -> tuple[str, dict]:
470498
logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
499+
relationship = (
500+
relationship.value if isinstance(relationship, Relationship) else relationship
501+
)
471502
query_string = f"""
472503
UNWIND $edges AS edge
473504
MATCH (u: `{source_label.value}` {{id: edge.source}})
474505
MATCH (v: `{destination_label.value}` {{id: edge.destination}})
475-
MERGE (u)-[:`{relationship.value}`]->(v)
506+
MERGE (u)-[:`{relationship}`]->(v)
476507
"""
477508
parameters = {
478509
"edges": [

0 commit comments

Comments
 (0)