Skip to content

Commit 7d6dc48

Browse files
committed
feat: add has_nodes_batch with UNWIND implementation for Neo4j
1 parent ef7620d commit 7d6dc48

File tree

4 files changed

+100
-0
lines changed

4 files changed

+100
-0
lines changed

lightrag/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,25 @@ async def batch_upsert_edges(
595595
for src, tgt, edge_data in edges:
596596
await self.upsert_edge(src, tgt, edge_data)
597597

598+
async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
    """Return the subset of ``node_ids`` that exist in the graph.

    Default implementation awaits ``has_node`` once per distinct ID, in
    first-seen order. Override this method for better performance in
    storage backends that support batch operations.

    Args:
        node_ids: List of node IDs to check. Repeated IDs are looked up
            only once.

    Returns:
        Set of node_ids for which ``has_node`` returned a truthy value
        (an empty set for empty input).
    """
    # dict.fromkeys drops repeated IDs while keeping the caller's order, so
    # a duplicated ID does not cost an extra has_node round trip.
    unique_ids = dict.fromkeys(node_ids)
    return {node_id for node_id in unique_ids if await self.has_node(node_id)}
616+
598617
@abstractmethod
599618
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
600619
"""Insert a new node or update an existing node in the graph.

lightrag/kg/neo4j_impl.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,33 @@ async def _execute(tx, params=batch_params):
12131213
logger.error(f"[{self.workspace}] Error during batch edge upsert: {str(e)}")
12141214
raise
12151215

1216+
async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
    """Batch check node existence using a single UNWIND query.

    Args:
        node_ids: Entity IDs to look up. Repeated IDs are sent to the
            database only once.

    Returns:
        Set of the ``entity_id`` values returned by the query, i.e. the
        requested IDs matched under the workspace label. Empty input
        returns an empty set without opening a session.

    Raises:
        Exception: Any error raised while opening the session or running
            the query is logged and re-raised unchanged.
    """
    if not node_ids:
        return set()

    # De-duplicate (order-preserving) so UNWIND does not MATCH the same
    # entity_id several times; the result is a set either way.
    unique_ids = list(dict.fromkeys(node_ids))

    workspace_label = self._get_workspace_label()
    query = (
        "UNWIND $ids AS eid "
        f"MATCH (n:`{workspace_label}` {{entity_id: eid}}) "
        "RETURN n.entity_id AS entity_id"
    )
    result = set()

    try:
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            records = await session.run(query, ids=unique_ids)
            try:
                async for record in records:
                    result.add(record["entity_id"])
            finally:
                # Consume even when iteration fails so the result is
                # always released before the session closes.
                await records.consume()
    except Exception as e:
        logger.error(f"[{self.workspace}] Error during batch has_nodes: {str(e)}")
        raise

    return result
1242+
12161243
async def get_knowledge_graph(
12171244
self,
12181245
node_label: str,

tests/test_batch_neo4j.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,49 @@ async def test_batch_upsert_nodes_large_batch_chunked(neo4j_storage):
101101

102102
# 1200 / 500 = 3 sub-batches
103103
assert mock_tx.run.call_count == 3
104+
105+
106+
@pytest.mark.asyncio
async def test_has_nodes_batch(neo4j_storage):
    """has_nodes_batch returns exactly the IDs yielded by the UNWIND query."""
    storage, _ = neo4j_storage

    class MockResult:
        """Minimal async-iterable stand-in for the driver's query result."""

        def __init__(self, records):
            self._records = iter(records)
            self.consumed = False

        def __aiter__(self):
            return self

        async def __anext__(self):
            try:
                return next(self._records)
            except StopIteration:
                raise StopAsyncIteration

        async def consume(self):
            self.consumed = True

    # Only two of the three requested IDs come back from the "database".
    mock_result = MockResult([{"entity_id": "alice"}, {"entity_id": "charlie"}])

    # Mock the session for read access.
    mock_session = AsyncMock()
    mock_session.__aenter__ = AsyncMock(return_value=mock_session)
    mock_session.__aexit__ = AsyncMock(return_value=False)
    mock_session.run = AsyncMock(return_value=mock_result)
    storage._driver.session = MagicMock(return_value=mock_session)

    result = await storage.has_nodes_batch(["alice", "bob", "charlie"])

    assert result == {"alice", "charlie"}

    # Verify a single UNWIND query was used and received every requested ID.
    assert mock_session.run.call_count == 1
    query = mock_session.run.call_args[0][0]
    assert "UNWIND" in query
    assert mock_session.run.call_args.kwargs["ids"] == ["alice", "bob", "charlie"]

    # An existence check must not open a write session.
    assert storage._driver.session.call_args.kwargs["default_access_mode"] == "READ"

    # The result must be consumed before the session is released.
    assert mock_result.consumed
143+
144+
145+
@pytest.mark.asyncio
async def test_has_nodes_batch_empty(neo4j_storage):
    """Empty input yields an empty set without opening a database session."""
    storage, _ = neo4j_storage
    # Replace the session factory so any attempt to open a session is visible.
    storage._driver.session = MagicMock()

    result = await storage.has_nodes_batch([])

    assert result == set()
    storage._driver.session.assert_not_called()

tests/test_batch_upsert_base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,11 @@ async def test_batch_upsert_nodes_empty_list(storage):
7171
async def test_batch_upsert_edges_empty_list(storage):
7272
await storage.batch_upsert_edges([])
7373
storage.upsert_edge.assert_not_called()
74+
75+
76+
@pytest.mark.asyncio
async def test_has_nodes_batch_default(storage):
    """The default has_nodes_batch falls back to one has_node call per ID."""
    storage.has_node = AsyncMock(side_effect=[True, False, True])

    result = await storage.has_nodes_batch(["a", "b", "c"])

    assert result == {"a", "c"}
    assert storage.has_node.call_count == 3
    # Each requested ID must be the one actually checked, in request order.
    checked = [call.args[0] for call in storage.has_node.await_args_list]
    assert checked == ["a", "b", "c"]

0 commit comments

Comments
 (0)