diff --git a/lightrag/base.py b/lightrag/base.py
index bfbeca2133..0e3d4fa85c 100644
--- a/lightrag/base.py
+++ b/lightrag/base.py
@@ -565,6 +565,39 @@ async def get_nodes_edges_batch(
             result[node_id] = edges if edges is not None else []
         return result
 
+    async def batch_upsert_nodes(
+        self, nodes: list[tuple[str, dict[str, str]]]
+    ) -> None:
+        """Insert or update several nodes.
+
+        Default implementation: awaits ``upsert_node`` once per
+        ``(node_id, node_data)`` pair, in list order.
+        """
+        for entity_id, properties in nodes:
+            await self.upsert_node(entity_id, properties)
+
+    async def batch_upsert_edges(
+        self, edges: list[tuple[str, str, dict[str, str]]]
+    ) -> None:
+        """Insert or update several edges.
+
+        Default implementation: awaits ``upsert_edge`` once per
+        ``(source, target, edge_data)`` triple, in list order.
+        """
+        for source_id, target_id, properties in edges:
+            await self.upsert_edge(source_id, target_id, properties)
+
+    async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
+        """Return the ids from ``node_ids`` for which ``has_node`` is truthy.
+
+        Default implementation: awaits ``has_node`` once per id, in order.
+        """
+        return {
+            candidate
+            for candidate in node_ids
+            if await self.has_node(candidate)
+        }
+
     @abstractmethod
     async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
         """Insert a new node or update an existing node in the graph.
diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py
index 172e1a3694..0e100f7e83 100644
--- a/lightrag/kg/neo4j_impl.py
+++ b/lightrag/kg/neo4j_impl.py
@@ -1133,6 +1133,155 @@ async def execute_upsert(tx: AsyncManagedTransaction):
             logger.error(f"[{self.workspace}] Error during edge upsert: {str(e)}")
             raise
 
+    # Maximum number of rows sent to Neo4j in a single UNWIND statement.
+    _BATCH_CHUNK_SIZE = 500
+
+    async def batch_upsert_nodes(
+        self, nodes: list[tuple[str, dict[str, str]]]
+    ) -> None:
+        """Upsert many nodes with one UNWIND statement per chunk.
+
+        Nodes are grouped by ``entity_type`` because a label cannot be passed
+        as a query parameter; each group is written in chunks of at most
+        ``_BATCH_CHUNK_SIZE`` rows, one write transaction per chunk.
+
+        Args:
+            nodes: ``(node_id, node_data)`` pairs. ``node_data`` is merged into
+                the node's properties; a missing or empty ``entity_type``
+                falls back to the ``unknown`` label.
+
+        Raises:
+            Exception: Any failure is logged and re-raised. Chunks written
+                before the failure are not rolled back.
+        """
+        if not nodes:
+            return
+
+        by_type: dict[str, list[dict]] = {}
+        for node_id, node_data in nodes:
+            # An empty label is not valid Cypher, so falsy types share "unknown".
+            entity_type = str(node_data.get("entity_type") or "unknown")
+            by_type.setdefault(entity_type, []).append(
+                {"entity_id": node_id, "properties": node_data}
+            )
+
+        workspace_label = self._get_workspace_label()
+        chunk_size = self._BATCH_CHUNK_SIZE
+
+        try:
+            async with self._driver.session(database=self._DATABASE) as session:
+                for entity_type, rows in by_type.items():
+                    # The label is interpolated into the query text, so double
+                    # any backtick to keep it inside the quoted identifier.
+                    safe_label = entity_type.replace("`", "``")
+                    query = (
+                        "UNWIND $batch AS item "
+                        f"MERGE (n:`{workspace_label}` {{entity_id: item.entity_id}}) "
+                        "SET n += item.properties "
+                        f"SET n:`{safe_label}`"
+                    )
+                    for start in range(0, len(rows), chunk_size):
+                        chunk = rows[start : start + chunk_size]
+
+                        # Defaults bind the current query/chunk to this callback.
+                        async def _execute(tx, query=query, params=chunk):
+                            result = await tx.run(query, batch=params)
+                            # Drain the result, as batch_upsert_edges does.
+                            await result.consume()
+
+                        await session.execute_write(_execute)
+        except Exception as e:
+            logger.error(f"[{self.workspace}] Error during batch node upsert: {str(e)}")
+            raise
+
+    async def batch_upsert_edges(
+        self, edges: list[tuple[str, str, dict[str, str]]]
+    ) -> None:
+        """Upsert many edges with one UNWIND statement per chunk.
+
+        Both endpoints are looked up with MATCH, so an edge whose source or
+        target node does not exist produces no row and is skipped silently.
+
+        Args:
+            edges: ``(source_id, target_id, edge_data)`` triples; ``edge_data``
+                is merged into the relationship's properties.
+
+        Raises:
+            Exception: Any failure is logged and re-raised. Chunks written
+                before the failure are not rolled back.
+        """
+        if not edges:
+            return
+
+        workspace_label = self._get_workspace_label()
+        chunk_size = self._BATCH_CHUNK_SIZE
+        query = (
+            "UNWIND $batch AS item "
+            f"MATCH (source:`{workspace_label}` {{entity_id: item.source_id}}) "
+            "WITH source, item "
+            f"MATCH (target:`{workspace_label}` {{entity_id: item.target_id}}) "
+            "MERGE (source)-[r:DIRECTED]-(target) "
+            "SET r += item.properties"
+        )
+
+        try:
+            async with self._driver.session(database=self._DATABASE) as session:
+                for start in range(0, len(edges), chunk_size):
+                    chunk = [
+                        {"source_id": src, "target_id": tgt, "properties": data}
+                        for src, tgt, data in edges[start : start + chunk_size]
+                    ]
+
+                    async def _execute(tx, params=chunk):
+                        result = await tx.run(query, batch=params)
+                        await result.consume()
+
+                    await session.execute_write(_execute)
+        except Exception as e:
+            logger.error(f"[{self.workspace}] Error during batch edge upsert: {str(e)}")
+            raise
+
+    async def has_nodes_batch(self, node_ids: list[str]) -> set[str]:
+        """Return the subset of ``node_ids`` that exist in this workspace.
+
+        Ids are de-duplicated and queried in chunks of at most
+        ``_BATCH_CHUNK_SIZE`` so the parameter list stays bounded, matching
+        the batch upsert methods.
+
+        Raises:
+            Exception: Any failure is logged and re-raised.
+        """
+        if not node_ids:
+            return set()
+
+        workspace_label = self._get_workspace_label()
+        chunk_size = self._BATCH_CHUNK_SIZE
+        query = (
+            "UNWIND $ids AS eid "
+            f"MATCH (n:`{workspace_label}` {{entity_id: eid}}) "
+            "RETURN n.entity_id AS entity_id"
+        )
+        # The result is a set, so repeated ids only add work; drop them.
+        unique_ids = list(dict.fromkeys(node_ids))
+        found: set[str] = set()
+
+        try:
+            async with self._driver.session(
+                database=self._DATABASE, default_access_mode="READ"
+            ) as session:
+                for start in range(0, len(unique_ids), chunk_size):
+                    records = await session.run(
+                        query, ids=unique_ids[start : start + chunk_size]
+                    )
+                    async for record in records:
+                        found.add(record["entity_id"])
+                    await records.consume()
+        except Exception as e:
+            logger.error(f"[{self.workspace}] Error during batch has_nodes: {str(e)}")
+            raise
+
+        return found
+
     async def get_knowledge_graph(
         self,
         node_label: str,
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 20af067908..a432fe36a2 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -1726,6 +1726,7 @@ async def _merge_nodes_then_upsert(
                 f"Skipped `{entity_name}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}"
             )
             existing_node_data = dict(already_node)
+            existing_node_data["_skip_graph_upsert"] = True
             return existing_node_data
         else:
             logger.error(f"Internal Error: already_node missing for `{entity_name}`")
@@ -1888,10 +1889,6 @@ async def 
_merge_nodes_then_upsert(
         created_at=int(time.time()),
         truncate=truncation_info,
     )
-    await knowledge_graph_inst.upsert_node(
-        entity_name,
-        node_data=node_data,
-    )
     node_data["entity_name"] = entity_name
     if entity_vdb is not None:
         entity_vdb_id = compute_mdhash_id(str(entity_name), prefix="ent-")
@@ -2045,6 +2042,7 @@ async def _merge_edges_then_upsert(
                 f"Skipped `{src_id}`~`{tgt_id}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}"
             )
             existing_edge_data = dict(already_edge)
+            existing_edge_data["_skip_graph_upsert"] = True
             return existing_edge_data
         else:
             logger.error(
@@ -2377,20 +2375,6 @@ async def _merge_edges_then_upsert(
                     pipeline_status["history_messages"].append(status_message)
 
     edge_created_at = int(time.time())
-    await knowledge_graph_inst.upsert_edge(
-        src_id,
-        tgt_id,
-        edge_data=dict(
-            weight=weight,
-            description=description,
-            keywords=keywords,
-            source_id=source_id,
-            file_path=file_path,
-            created_at=edge_created_at,
-            truncate=truncation_info,
-        ),
-    )
-
     edge_data = dict(
         src_id=src_id,
         tgt_id=tgt_id,
@@ -2618,6 +2602,25 @@ async def _locked_process_entity_name(entity_name, entities):
         if first_exception is not None:
             raise first_exception
 
+    # Batch-write all entity nodes to the graph in one call.
+    # "entity_name" is set on the payload only after the point where the
+    # per-node upsert_node call used to run, so it was never a stored node
+    # property; drop it here along with the internal skip marker.
+    nodes_to_upsert = [
+        (
+            e["entity_name"],
+            {
+                k: v
+                for k, v in e.items()
+                if k not in ("entity_name", "_skip_graph_upsert")
+            },
+        )
+        for e in processed_entities
+        if e is not None and not e.get("_skip_graph_upsert")
+    ]
+    if nodes_to_upsert:
+        await knowledge_graph_inst.batch_upsert_nodes(nodes_to_upsert)
+
     # ===== Phase 2: Process all relationships concurrently =====
     log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})"
     logger.info(log_message)
@@ -2738,6 +2741,19 @@ async def _locked_process_edges(edge_key, edges):
         if first_exception is not None:
             raise first_exception
 
+    # Batch-write all edges to the graph in one call
+    edges_to_upsert = [
+        (
+            e["src_id"],
+            e["tgt_id"],
+            {k: v for k, v in e.items() if k not in ("src_id", "tgt_id", "_skip_graph_upsert")},
+        )
+        for e in processed_edges
+        if e is not None and not e.get("_skip_graph_upsert")
+    ]
+    if edges_to_upsert:
+        await knowledge_graph_inst.batch_upsert_edges(edges_to_upsert)
+
     # ===== Phase 3: Update full_entities and full_relations storage =====
     if full_entities_storage and full_relations_storage and doc_id:
         try:
diff --git a/tests/test_batch_merge.py b/tests/test_batch_merge.py
new file mode 100644
index 0000000000..6a7e33aa85
--- /dev/null
+++ b/tests/test_batch_merge.py
@@ -0,0 +1,168 @@
+"""Tests that merge phase uses batch upsert instead of individual calls."""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+
+@pytest.fixture
+def global_config():
+    return {
+        "source_ids_limit_method": "FIFO",
+        "max_source_ids_per_entity": 100,
+        "max_source_ids_per_relation": 100,
+        "max_file_paths": 10,
+        "file_path_more_placeholder": "more",
+        "use_llm_func": AsyncMock(),
+        "entity_summary_to_max_tokens": 500,
+        "summary_language": "English",
+    }
+
+
+@pytest.fixture
+def mock_kg():
+    kg = AsyncMock()
+    kg.get_node = AsyncMock(return_value=None)
+    kg.upsert_node = AsyncMock()
+    kg.has_edge = AsyncMock(return_value=False)
+    kg.get_edge = AsyncMock(return_value=None)
+    kg.upsert_edge = AsyncMock()
+    return kg
+
+
+@pytest.fixture
+def mock_entity_chunks():
+    chunks = AsyncMock()
+    chunks.get_by_id = AsyncMock(return_value=None)
+    return chunks
+
+
+@pytest.mark.asyncio
+async def test_merge_nodes_returns_data_without_graph_call(mock_kg, mock_entity_chunks, global_config):
+    from lightrag.operate import _merge_nodes_then_upsert
+
+    nodes_data = [
+        {
+            "entity_type": "PERSON",
+            "description": "A test entity",
+            "source_id": "chunk-1",
+            "file_path": "test.txt",
+        }
+    ]
+
+    result = await _merge_nodes_then_upsert(
+        entity_name="TEST_ENTITY",
+        nodes_data=nodes_data,
+        knowledge_graph_inst=mock_kg,
+        entity_vdb=None,
+        global_config=global_config,
+        pipeline_status=None,
+        pipeline_status_lock=None,
+        
llm_response_cache=None,
+        entity_chunks_storage=mock_entity_chunks,
+    )
+
+    assert result is not None
+    assert result["entity_name"] == "TEST_ENTITY"
+    assert result["entity_type"] == "PERSON"
+    mock_kg.upsert_node.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_merge_edges_returns_data_without_graph_call(mock_kg, global_config):
+    from lightrag.operate import _merge_edges_then_upsert
+
+    mock_kg.get_node = AsyncMock(
+        return_value={
+            "entity_id": "EXISTS",
+            "entity_type": "THING",
+            "source_id": "chunk-1",
+            "description": "existing",
+            "file_path": "test.txt",
+        }
+    )
+
+    incoming_edges = [
+        {
+            "description": "A relates to B",
+            "keywords": "test",
+            "weight": 1.0,
+            "source_id": "chunk-1",
+            "file_path": "test.txt",
+        }
+    ]
+
+    merged_edge = await _merge_edges_then_upsert(
+        src_id="ENTITY_A",
+        tgt_id="ENTITY_B",
+        edges_data=incoming_edges,
+        knowledge_graph_inst=mock_kg,
+        relationships_vdb=None,
+        entity_vdb=None,
+        global_config=global_config,
+        pipeline_status=None,
+        pipeline_status_lock=None,
+        llm_response_cache=None,
+        added_entities=None,
+        relation_chunks_storage=None,
+        entity_chunks_storage=None,
+    )
+
+    assert merged_edge is not None
+    assert merged_edge["src_id"] == "ENTITY_A"
+    assert merged_edge["tgt_id"] == "ENTITY_B"
+    mock_kg.upsert_edge.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_skip_graph_upsert_flag_on_early_return():
+    from lightrag.operate import _merge_nodes_then_upsert
+
+    graph = AsyncMock()
+    graph.get_node = AsyncMock(
+        return_value={
+            "entity_id": "TEST",
+            "entity_type": "PERSON",
+            "description": "existing",
+            "source_id": "c1c2c3",
+            "file_path": "test.txt",
+        }
+    )
+
+    chunk_store = AsyncMock()
+    chunk_store.get_by_id = AsyncMock(
+        return_value={"chunk_ids": ["c1", "c2", "c3"], "count": 3}
+    )
+
+    config = {
+        "source_ids_limit_method": "KEEP",
+        "max_source_ids_per_entity": 3,
+        "max_file_paths": 10,
+        "file_path_more_placeholder": "more",
+        "use_llm_func": AsyncMock(),
+        "entity_summary_to_max_tokens": 500,
+        "summary_language": "English",
+    }
+
+    incoming_nodes = [
+        {
+            "entity_type": "PERSON",
+            "description": "new desc",
+            "source_id": "c-new",
+            "file_path": "new.txt",
+        }
+    ]
+
+    returned_node = await _merge_nodes_then_upsert(
+        entity_name="TEST",
+        nodes_data=incoming_nodes,
+        knowledge_graph_inst=graph,
+        entity_vdb=None,
+        global_config=config,
+        pipeline_status=None,
+        pipeline_status_lock=None,
+        llm_response_cache=None,
+        entity_chunks_storage=chunk_store,
+    )
+
+    assert returned_node is not None
+    assert returned_node.get("_skip_graph_upsert") is True
diff --git a/tests/test_batch_neo4j.py b/tests/test_batch_neo4j.py
new file mode 100644
index 0000000000..b8c76a3b54
--- /dev/null
+++ b/tests/test_batch_neo4j.py
@@ -0,0 +1,151 @@
+"""Tests for Neo4JStorage batch upsert methods."""
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+
+@pytest.fixture
+def neo4j_storage():
+    """Build a Neo4JStorage without __init__, backed by a mocked driver."""
+    from lightrag.kg.neo4j_impl import Neo4JStorage
+
+    storage = Neo4JStorage.__new__(Neo4JStorage)
+    storage._DATABASE = "neo4j"
+    storage.workspace = "test"
+
+    # Transaction handed to every execute_write callback
+    mock_tx = AsyncMock()
+
+    async def _run_in_tx(fn):
+        return await fn(mock_tx)
+
+    # Session doubles as its own async context manager
+    mock_session = AsyncMock()
+    mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+    mock_session.__aexit__ = AsyncMock(return_value=False)
+    mock_session.execute_write = AsyncMock(side_effect=_run_in_tx)
+
+    mock_driver = MagicMock()
+    mock_driver.session = MagicMock(return_value=mock_session)
+
+    storage._driver = mock_driver
+    storage._get_workspace_label = MagicMock(return_value="test_workspace")
+
+    return storage, mock_tx
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_nodes_unwind(neo4j_storage):
+    storage, mock_tx = neo4j_storage
+    mock_tx.run = AsyncMock()
+
+    nodes = [
+        ("alice", {"entity_id": "alice", "entity_type": "PERSON", "description": "Alice"}),
+        ("bob", {"entity_id": "bob", "entity_type": "PERSON", "description": "Bob"}),
+        ("acme",
{"entity_id": "acme", "entity_type": "ORG", "description": "Acme Corp"}),
+    ]
+    await storage.batch_upsert_nodes(nodes)
+
+    # One statement per entity_type group: PERSON (2 rows) and ORG (1 row)
+    assert mock_tx.run.call_count == 2
+    for recorded in mock_tx.run.call_args_list:
+        cypher = recorded[0][0]
+        assert "UNWIND" in cypher
+        assert "MERGE" in cypher
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_nodes_empty(neo4j_storage):
+    storage, tx = neo4j_storage
+    tx.run = AsyncMock()
+    await storage.batch_upsert_nodes([])
+    tx.run.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_edges_unwind(neo4j_storage):
+    storage, tx = neo4j_storage
+    edge_result = AsyncMock()
+    edge_result.consume = AsyncMock()
+    tx.run = AsyncMock(return_value=edge_result)
+
+    edge_rows = [
+        ("alice", "bob", {"weight": "1.0", "description": "knows"}),
+        ("alice", "acme", {"weight": "0.5", "description": "works_at"}),
+    ]
+    await storage.batch_upsert_edges(edge_rows)
+
+    assert tx.run.call_count == 1
+    cypher = tx.run.call_args[0][0]
+    assert "UNWIND" in cypher
+    assert "MERGE" in cypher
+    assert "DIRECTED" in cypher
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_edges_empty(neo4j_storage):
+    storage, tx = neo4j_storage
+    tx.run = AsyncMock()
+    await storage.batch_upsert_edges([])
+    tx.run.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_nodes_large_batch_chunked(neo4j_storage):
+    """Batches larger than _BATCH_CHUNK_SIZE should be split into sub-batches."""
+    storage, tx = neo4j_storage
+    tx.run = AsyncMock()
+
+    many_nodes = [
+        (f"node_{i}", {"entity_id": f"node_{i}", "entity_type": "ITEM", "description": f"Item {i}"})
+        for i in range(1200)
+    ]
+    await storage.batch_upsert_nodes(many_nodes)
+
+    # 1200 rows at 500 per chunk -> 3 statements
+    assert tx.run.call_count == 3
+
+
+@pytest.mark.asyncio
+async def test_has_nodes_batch(neo4j_storage):
+    storage, _ = neo4j_storage
+
+    # Replace the session with one that serves read results
+    existing_rows = [{"entity_id": "alice"}, {"entity_id": "charlie"}]
+
+    class MockResult:
+        def __init__(self, records):
+            self._records = iter(records)
+
+        def __aiter__(self):
+            return self
+
+        async def __anext__(self):
+            try:
+                return next(self._records)
+            except StopIteration:
+                raise StopAsyncIteration
+
+        async def consume(self):
+            pass
+
+    read_result = MockResult(existing_rows)
+
+    read_session = AsyncMock()
+    read_session.__aenter__ = AsyncMock(return_value=read_session)
+    read_session.__aexit__ = AsyncMock(return_value=False)
+    read_session.run = AsyncMock(return_value=read_result)
+
+    storage._driver.session = MagicMock(return_value=read_session)
+
+    present = await storage.has_nodes_batch(["alice", "bob", "charlie"])
+    assert present == {"alice", "charlie"}
+    # The lookup must go through an UNWIND statement
+    cypher = read_session.run.call_args[0][0]
+    assert "UNWIND" in cypher
+
+
+@pytest.mark.asyncio
+async def test_has_nodes_batch_empty(neo4j_storage):
+    storage, _ = neo4j_storage
+    present = await storage.has_nodes_batch([])
+    assert present == set()
diff --git a/tests/test_batch_upsert_base.py b/tests/test_batch_upsert_base.py
new file mode 100644
index 0000000000..4fd910d779
--- /dev/null
+++ b/tests/test_batch_upsert_base.py
@@ -0,0 +1,81 @@
+"""Tests for batch upsert base class default implementations."""
+import pytest
+from unittest.mock import AsyncMock
+from lightrag.base import BaseGraphStorage
+
+
+class ConcreteGraphStorage(BaseGraphStorage):
+    """Minimal concrete subclass for testing."""
+
+    async def delete_node(self, node_id): ...
+    async def drop(self): ...
+    async def edge_degree(self, src, tgt): ...
+    async def get_all_edges(self): ...
+    async def get_all_labels(self): ...
+    async def get_all_nodes(self): ...
+    async def get_edge(self, src, tgt): ...
+    async def get_knowledge_graph(self, node_label, max_depth): ...
+    async def get_node(self, node_id): ...
+    async def get_node_edges(self, node_id): ...
+    async def get_popular_labels(self, num): ...
+    async def has_edge(self, src, tgt): ...
+    async def has_node(self, node_id): ...
+    async def index_done_callback(self): ...
+    async def node_degree(self, node_id): ...
+    async def remove_edges(self, edges): ...
+    async def remove_nodes(self, nodes): ...
+    async def search_labels(self, query): ...
+    async def upsert_edge(self, src, tgt, data): ...
+    async def upsert_node(self, node_id, data): ...
+
+
+@pytest.fixture
+def storage():
+    instance = object.__new__(ConcreteGraphStorage)
+    instance.upsert_node = AsyncMock()
+    instance.upsert_edge = AsyncMock()
+    return instance
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_nodes_calls_upsert_node(storage):
+    node_rows = [
+        ("entity1", {"entity_id": "entity1", "entity_type": "PERSON", "description": "A person"}),
+        ("entity2", {"entity_id": "entity2", "entity_type": "ORG", "description": "An org"}),
+    ]
+    await storage.batch_upsert_nodes(node_rows)
+    assert storage.upsert_node.call_count == 2
+    storage.upsert_node.assert_any_call("entity1", node_rows[0][1])
+    storage.upsert_node.assert_any_call("entity2", node_rows[1][1])
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_edges_calls_upsert_edge(storage):
+    edge_rows = [
+        ("src1", "tgt1", {"weight": "1.0", "description": "related"}),
+        ("src2", "tgt2", {"weight": "0.5", "description": "similar"}),
+    ]
+    await storage.batch_upsert_edges(edge_rows)
+    assert storage.upsert_edge.call_count == 2
+    storage.upsert_edge.assert_any_call("src1", "tgt1", edge_rows[0][2])
+    storage.upsert_edge.assert_any_call("src2", "tgt2", edge_rows[1][2])
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_nodes_empty_list(storage):
+    await storage.batch_upsert_nodes([])
+    storage.upsert_node.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_edges_empty_list(storage):
+    await storage.batch_upsert_edges([])
+    storage.upsert_edge.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_has_nodes_batch_default(storage):
+    storage.has_node = AsyncMock(side_effect=[True, False, True])
+    existing = await storage.has_nodes_batch(["a", "b", "c"])
+    assert existing == {"a", "c"}
+    assert storage.has_node.call_count == 3
diff --git a/tests/test_description_api_validation.py b/tests/test_description_api_validation.py
index ef9c1a72fd..d8bc56a449 100644
--- a/tests/test_description_api_validation.py
+++ b/tests/test_description_api_validation.py
@@ -63,7 +63,8 @@ async def test_merge_nodes_then_upsert_handles_missing_legacy_description():
     )
 
     assert result["description"] == "Entity LegacyEntity"
-    assert graph.upserted_nodes[-1][1]["description"] == "Entity LegacyEntity"
+    # The graph write is deferred to the batch step, so nothing is upserted here
+    assert len(graph.upserted_nodes) == 0
 
 
 @pytest.mark.asyncio