Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions lightrag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,28 @@ async def get_nodes_edges_batch(
result[node_id] = edges if edges is not None else []
return result

async def batch_upsert_nodes(
    self, nodes: list[tuple[str, dict[str, str]]]
) -> None:
    """Insert or update many nodes.

    Default implementation delegates to ``upsert_node`` once per entry;
    storage backends may override this with a native bulk write.

    Args:
        nodes: Sequence of ``(node_id, node_data)`` pairs to upsert.
    """
    for entry in nodes:
        node_id, payload = entry
        await self.upsert_node(node_id, payload)

async def batch_upsert_edges(
    self, edges: list[tuple[str, str, dict[str, str]]]
) -> None:
    """Insert or update many edges.

    Default implementation delegates to ``upsert_edge`` once per entry;
    storage backends may override this with a native bulk write.

    Args:
        edges: Sequence of ``(source_id, target_id, edge_data)`` triples.
    """
    for source_id, target_id, payload in edges:
        await self.upsert_edge(source_id, target_id, payload)

async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
    """Return the subset of ``node_ids`` that exist in the graph.

    Default implementation probes ``has_node`` once per id; backends may
    override this with a single batched existence query.
    """
    return {nid for nid in node_ids if await self.has_node(nid)}

@abstractmethod
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
"""Insert a new node or update an existing node in the graph.
Expand Down
103 changes: 103 additions & 0 deletions lightrag/kg/neo4j_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,109 @@ async def execute_upsert(tx: AsyncManagedTransaction):
logger.error(f"[{self.workspace}] Error during edge upsert: {str(e)}")
raise

_BATCH_CHUNK_SIZE = 500

async def batch_upsert_nodes(
    self, nodes: list[tuple[str, dict[str, str]]]
) -> None:
    """UNWIND-based batch node upsert, grouped by entity_type.

    Nodes are grouped by their ``entity_type`` so each group receives the
    matching extra label via ``SET n:`type```, then written in chunks of
    ``_BATCH_CHUNK_SIZE`` rows per transaction.

    Args:
        nodes: Sequence of ``(node_id, node_data)`` pairs; ``node_data``
            may carry an ``entity_type`` key (defaults to ``"unknown"``).

    Raises:
        Exception: Re-raised after logging if the Neo4j write fails.
    """
    if not nodes:
        return

    # Group by entity_type for separate SET n:`{type}` labels
    by_type: dict[str, list[tuple[str, dict[str, str]]]] = {}
    for node_id, node_data in nodes:
        entity_type = node_data.get("entity_type", "unknown")
        by_type.setdefault(entity_type, []).append((node_id, node_data))

    workspace_label = self._get_workspace_label()

    try:
        async with self._driver.session(database=self._DATABASE) as session:
            for entity_type, typed_nodes in by_type.items():
                # Labels cannot be parameterized in Cypher, so entity_type is
                # interpolated into the statement. It originates from LLM
                # extraction (untrusted); strip backticks so it cannot break
                # out of the backtick-quoted identifier.
                safe_type = entity_type.replace("`", "")
                # The statement is invariant for a given entity_type — build
                # it once instead of once per chunk.
                query = (
                    "UNWIND $batch AS item "
                    f"MERGE (n:`{workspace_label}` {{entity_id: item.entity_id}}) "
                    "SET n += item.properties "
                    f"SET n:`{safe_type}`"
                )
                for i in range(0, len(typed_nodes), self._BATCH_CHUNK_SIZE):
                    chunk = typed_nodes[i : i + self._BATCH_CHUNK_SIZE]
                    batch_params = [
                        {"entity_id": nid, "properties": ndata}
                        for nid, ndata in chunk
                    ]

                    # Bind per-iteration values as defaults to avoid the
                    # late-binding closure pitfall.
                    async def _execute(tx, params=batch_params, q=query):
                        result = await tx.run(q, batch=params)
                        # Consume eagerly so server-side errors surface here,
                        # matching batch_upsert_edges.
                        await result.consume()

                    await session.execute_write(_execute)
    except Exception as e:
        logger.error(f"[{self.workspace}] Error during batch node upsert: {str(e)}")
        raise

async def batch_upsert_edges(
    self, edges: list[tuple[str, str, dict[str, str]]]
) -> None:
    """UNWIND-based batch edge upsert.

    Writes edges in chunks of ``_BATCH_CHUNK_SIZE`` rows, one managed
    write transaction per chunk.

    Args:
        edges: Sequence of ``(source_id, target_id, edge_data)`` triples.

    Raises:
        Exception: Re-raised after logging if the Neo4j write fails.
    """
    if not edges:
        return

    workspace_label = self._get_workspace_label()
    chunk_size = self._BATCH_CHUNK_SIZE

    # The statement does not vary between chunks; assemble it once.
    query = (
        "UNWIND $batch AS item "
        f"MATCH (source:`{workspace_label}` {{entity_id: item.source_id}}) "
        f"WITH source, item "
        f"MATCH (target:`{workspace_label}` {{entity_id: item.target_id}}) "
        "MERGE (source)-[r:DIRECTED]-(target) "
        "SET r += item.properties"
    )

    try:
        async with self._driver.session(database=self._DATABASE) as session:
            for start in range(0, len(edges), chunk_size):
                rows = [
                    {"source_id": src, "target_id": tgt, "properties": data}
                    for src, tgt, data in edges[start : start + chunk_size]
                ]

                # Bind rows as a default to avoid late-binding closure issues.
                async def _write(tx, rows=rows):
                    cursor = await tx.run(query, batch=rows)
                    await cursor.consume()

                await session.execute_write(_write)
    except Exception as e:
        logger.error(f"[{self.workspace}] Error during batch edge upsert: {str(e)}")
        raise

async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
    """UNWIND-based batch existence check.

    Returns the subset of ``node_ids`` whose nodes exist in the workspace,
    using a single read-mode query for the whole batch.
    """
    if not node_ids:
        return set()

    workspace_label = self._get_workspace_label()
    query = (
        "UNWIND $ids AS eid "
        f"MATCH (n:`{workspace_label}` {{entity_id: eid}}) "
        "RETURN n.entity_id AS entity_id"
    )

    found: set[str] = set()
    try:
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            cursor = await session.run(query, ids=node_ids)
            async for row in cursor:
                found.add(row["entity_id"])
            await cursor.consume()
    except Exception as e:
        logger.error(f"[{self.workspace}] Error during batch has_nodes: {str(e)}")
        raise

    return found

async def get_knowledge_graph(
self,
node_label: str,
Expand Down
42 changes: 24 additions & 18 deletions lightrag/operate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,7 @@ async def _merge_nodes_then_upsert(
f"Skipped `{entity_name}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}"
)
existing_node_data = dict(already_node)
existing_node_data["_skip_graph_upsert"] = True
return existing_node_data
else:
logger.error(f"Internal Error: already_node missing for `{entity_name}`")
Expand Down Expand Up @@ -1888,10 +1889,6 @@ async def _merge_nodes_then_upsert(
created_at=int(time.time()),
truncate=truncation_info,
)
await knowledge_graph_inst.upsert_node(
entity_name,
node_data=node_data,
)
node_data["entity_name"] = entity_name
if entity_vdb is not None:
entity_vdb_id = compute_mdhash_id(str(entity_name), prefix="ent-")
Expand Down Expand Up @@ -2045,6 +2042,7 @@ async def _merge_edges_then_upsert(
f"Skipped `{src_id}`~`{tgt_id}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}"
)
existing_edge_data = dict(already_edge)
existing_edge_data["_skip_graph_upsert"] = True
return existing_edge_data
else:
logger.error(
Expand Down Expand Up @@ -2377,20 +2375,6 @@ async def _merge_edges_then_upsert(
pipeline_status["history_messages"].append(status_message)

edge_created_at = int(time.time())
await knowledge_graph_inst.upsert_edge(
src_id,
tgt_id,
edge_data=dict(
weight=weight,
description=description,
keywords=keywords,
source_id=source_id,
file_path=file_path,
created_at=edge_created_at,
truncate=truncation_info,
),
)

edge_data = dict(
src_id=src_id,
tgt_id=tgt_id,
Expand Down Expand Up @@ -2618,6 +2602,15 @@ async def _locked_process_entity_name(entity_name, entities):
if first_exception is not None:
raise first_exception

# Batch-write all entity nodes to the graph in one call
nodes_to_upsert = [
(e["entity_name"], {k: v for k, v in e.items() if k != "_skip_graph_upsert"})
for e in processed_entities
if e is not None and not e.get("_skip_graph_upsert")
]
if nodes_to_upsert:
await knowledge_graph_inst.batch_upsert_nodes(nodes_to_upsert)

# ===== Phase 2: Process all relationships concurrently =====
log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
logger.info(log_message)
Expand Down Expand Up @@ -2738,6 +2731,19 @@ async def _locked_process_edges(edge_key, edges):
if first_exception is not None:
raise first_exception

# Batch-write all edges to the graph in one call
edges_to_upsert = [
(
e["src_id"],
e["tgt_id"],
{k: v for k, v in e.items() if k not in ("src_id", "tgt_id", "_skip_graph_upsert")},
)
for e in processed_edges
if e is not None and not e.get("_skip_graph_upsert")
]
if edges_to_upsert:
await knowledge_graph_inst.batch_upsert_edges(edges_to_upsert)

# ===== Phase 3: Update full_entities and full_relations storage =====
if full_entities_storage and full_relations_storage and doc_id:
try:
Expand Down
168 changes: 168 additions & 0 deletions tests/test_batch_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""Tests that merge phase uses batch upsert instead of individual calls."""

import pytest
from unittest.mock import AsyncMock, MagicMock


@pytest.fixture
def global_config():
return {
"source_ids_limit_method": "FIFO",
"max_source_ids_per_entity": 100,
"max_source_ids_per_relation": 100,
"max_file_paths": 10,
"file_path_more_placeholder": "more",
"use_llm_func": AsyncMock(),
"entity_summary_to_max_tokens": 500,
"summary_language": "English",
}


@pytest.fixture
def mock_kg():
kg = AsyncMock()
kg.get_node = AsyncMock(return_value=None)
kg.upsert_node = AsyncMock()
kg.has_edge = AsyncMock(return_value=False)
kg.get_edge = AsyncMock(return_value=None)
kg.upsert_edge = AsyncMock()
return kg


@pytest.fixture
def mock_entity_chunks():
chunks = AsyncMock()
chunks.get_by_id = AsyncMock(return_value=None)
return chunks


@pytest.mark.asyncio
async def test_merge_nodes_returns_data_without_graph_call(mock_kg, mock_entity_chunks, global_config):
    """Merging a brand-new entity returns its data without writing the graph."""
    from lightrag.operate import _merge_nodes_then_upsert

    fragment = {
        "entity_type": "PERSON",
        "description": "A test entity",
        "source_id": "chunk-1",
        "file_path": "test.txt",
    }

    merged = await _merge_nodes_then_upsert(
        entity_name="TEST_ENTITY",
        nodes_data=[fragment],
        knowledge_graph_inst=mock_kg,
        entity_vdb=None,
        global_config=global_config,
        pipeline_status=None,
        pipeline_status_lock=None,
        llm_response_cache=None,
        entity_chunks_storage=mock_entity_chunks,
    )

    assert merged is not None
    assert merged["entity_name"] == "TEST_ENTITY"
    assert merged["entity_type"] == "PERSON"
    # The graph write now happens in a later batch phase, not here.
    mock_kg.upsert_node.assert_not_called()


@pytest.mark.asyncio
async def test_merge_edges_returns_data_without_graph_call(mock_kg, global_config):
    """Merging a new relation returns edge data without writing the graph."""
    from lightrag.operate import _merge_edges_then_upsert

    # Both endpoints already exist in the graph.
    mock_kg.get_node = AsyncMock(
        return_value={
            "entity_id": "EXISTS",
            "entity_type": "THING",
            "source_id": "chunk-1",
            "description": "existing",
            "file_path": "test.txt",
        }
    )

    fragment = {
        "description": "A relates to B",
        "keywords": "test",
        "weight": 1.0,
        "source_id": "chunk-1",
        "file_path": "test.txt",
    }

    merged = await _merge_edges_then_upsert(
        src_id="ENTITY_A",
        tgt_id="ENTITY_B",
        edges_data=[fragment],
        knowledge_graph_inst=mock_kg,
        relationships_vdb=None,
        entity_vdb=None,
        global_config=global_config,
        pipeline_status=None,
        pipeline_status_lock=None,
        llm_response_cache=None,
        added_entities=None,
        relation_chunks_storage=None,
        entity_chunks_storage=None,
    )

    assert merged is not None
    assert merged["src_id"] == "ENTITY_A"
    assert merged["tgt_id"] == "ENTITY_B"
    # The graph write now happens in a later batch phase, not here.
    mock_kg.upsert_edge.assert_not_called()


@pytest.mark.asyncio
async def test_skip_graph_upsert_flag_on_early_return():
    """KEEP mode with a full chunk budget short-circuits and flags the result."""
    from lightrag.operate import _merge_nodes_then_upsert

    graph = AsyncMock()
    graph.get_node = AsyncMock(
        return_value={
            "entity_id": "TEST",
            "entity_type": "PERSON",
            "description": "existing",
            "source_id": "c1<SEP>c2<SEP>c3",
            "file_path": "test.txt",
        }
    )

    chunk_store = AsyncMock()
    chunk_store.get_by_id = AsyncMock(
        return_value={"chunk_ids": ["c1", "c2", "c3"], "count": 3}
    )

    # KEEP with max_source_ids_per_entity already reached triggers the
    # early-return path under test.
    config = {
        "source_ids_limit_method": "KEEP",
        "max_source_ids_per_entity": 3,
        "max_file_paths": 10,
        "file_path_more_placeholder": "more",
        "use_llm_func": AsyncMock(),
        "entity_summary_to_max_tokens": 500,
        "summary_language": "English",
    }

    fragment = {
        "entity_type": "PERSON",
        "description": "new desc",
        "source_id": "c-new",
        "file_path": "new.txt",
    }

    outcome = await _merge_nodes_then_upsert(
        entity_name="TEST",
        nodes_data=[fragment],
        knowledge_graph_inst=graph,
        entity_vdb=None,
        global_config=config,
        pipeline_status=None,
        pipeline_status_lock=None,
        llm_response_cache=None,
        entity_chunks_storage=chunk_store,
    )

    assert outcome is not None
    assert outcome.get("_skip_graph_upsert") is True
Loading
Loading