Skip to content

Commit 6963fe5

Browse files
feat: Add Parquet writer (#486)
1 parent 25f3d2d commit 6963fe5

File tree

12 files changed

+1658
-22
lines changed

12 files changed

+1658
-22
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## Next
44

5+
### Added
6+
7+
- Parquet export (experimental): `ParquetWriter` (extends `KGWriter`), `Neo4jGraphParquetFormatter`, and `FilenameCollisionHandler` for writing knowledge graphs to Parquet (one file per node label and per relationship type).
8+
59
### Changed
610

711
- Updated examples, default values, and documentation to use `gpt-4.1` / `gpt-4.1-mini` instead of deprecated GPT-4* models (e.g. `gpt-4o`, `gpt-4`).

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ experimental = [
5858
"langchain-text-splitters>=0.3.0,<0.4.0",
5959
"neo4j-viz>=0.4.2,<0.5.0",
6060
"llama-index>=0.13.0,<0.14.0",
61+
"pyarrow>=20.0.0", # ParquetWriter, Neo4jGraphParquetFormatter
6162
]
6263
examples = [
6364
"langchain-openai>=0.2.2,<0.3.0",
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Filename collision handler for Parquet file writing."""
16+
17+
from __future__ import annotations
18+
19+
from pathlib import Path
20+
from typing import Union
21+
22+
23+
class FilenameCollisionHandler:
    """Handles filename collisions by adding numeric suffixes.

    Tracks filename collisions per output path and generates unique filenames
    by appending _n suffixes when the same base filename is requested more
    than once for the same output path.

    Example:

    .. code-block:: python

        handler = FilenameCollisionHandler()
        filename1 = handler.get_unique_filename("Person.parquet", Path("./out"))
        # Returns: "Person.parquet"
        filename2 = handler.get_unique_filename("Person.parquet", Path("./out"))
        # Returns: "Person_1.parquet"
        filename3 = handler.get_unique_filename("Person.parquet", Path("./out"))
        # Returns: "Person_2.parquet"
    """

    # Class-level dictionary to track filename collisions across all instances.
    # Keyed by the (resolved output path, base filename) tuple rather than a
    # plain string concatenation, so distinct path/filename pairs can never be
    # conflated (e.g. ("/out/ab", "c.parquet") vs ("/out/a", "bc.parquet")
    # would produce the same concatenated key).
    _filename_counts: dict[tuple[str, str], int] = {}

    def get_unique_filename(
        self,
        base_filename: str,
        output_path: Union[str, Path],
    ) -> str:
        """Return a unique filename by adding a _n suffix if a collision is detected.

        Args:
            base_filename: The original filename (e.g. "Person.parquet").
            output_path: The output directory path; collisions are tracked per path.

        Returns:
            A unique filename (e.g. "Person.parquet" or "Person_1.parquet").
        """
        # Resolve the path so equivalent spellings ("./out", absolute form)
        # share one counter.
        key = (str(Path(output_path).resolve()), base_filename)

        if key not in self._filename_counts:
            # First request for this path/filename pair: no suffix needed.
            self._filename_counts[key] = 0
            return base_filename

        self._filename_counts[key] += 1
        count = self._filename_counts[key]
        if base_filename.endswith(".parquet"):
            # Insert the suffix before the extension:
            # Person.parquet -> Person_1.parquet.
            name_without_ext = base_filename[: -len(".parquet")]
            return f"{name_without_ext}_{count}.parquet"
        return f"{base_filename}_{count}"

    @classmethod
    def reset(cls) -> None:
        """Clear the collision-tracking state.

        Intended for tests so each run starts with a clean state.
        """
        cls._filename_counts.clear()

src/neo4j_graphrag/experimental/components/kg_writer.py

Lines changed: 240 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,23 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import os
1718
import logging
1819
from abc import abstractmethod
1920
from typing import Any, Generator, Literal, Optional
2021

2122
import neo4j
2223
from pydantic import validate_call
2324

25+
from neo4j_graphrag.experimental.components.filename_collision_handler import (
26+
FilenameCollisionHandler,
27+
)
28+
from neo4j_graphrag.experimental.components.parquet_formatter import (
29+
Neo4jGraphParquetFormatter,
30+
)
31+
from neo4j_graphrag.experimental.components.parquet_output import (
32+
ParquetOutputDestination,
33+
)
2434
from neo4j_graphrag.experimental.components.types import (
2535
LexicalGraphConfig,
2636
Neo4jGraph,
@@ -43,6 +53,24 @@
4353
logger = logging.getLogger(__name__)
4454

4555

56+
def _build_columns_from_schema(
    schema: Any, primary_key_names: list[str]
) -> list[dict[str, Any]]:
    """Describe every field of a PyArrow schema as a column dict.

    Each entry holds the column name, its source type as mapped by
    ``Neo4jGraphParquetFormatter.pyarrow_type_to_type_info``, and whether the
    column is one of the given primary keys.
    """

    def _describe(field: Any) -> dict[str, Any]:
        # Translate the PyArrow type into the formatter's type-info record.
        type_info = Neo4jGraphParquetFormatter.pyarrow_type_to_type_info(field.type)
        return {
            "name": field.name,
            "type": type_info.source_type,
            "is_primary_key": field.name in primary_key_names,
        }

    return [_describe(schema.field(idx)) for idx in range(len(schema))]
72+
73+
4674
def batched(rows: list[Any], batch_size: int) -> Generator[list[Any], None, None]:
4775
index = 0
4876
for i in range(0, len(rows), batch_size):
@@ -53,11 +81,46 @@ def batched(rows: list[Any], batch_size: int) -> Generator[list[Any], None, None
5381
index += 1
5482

5583

84+
def _graph_stats(
85+
graph: Neo4jGraph,
86+
nodes_per_label: Optional[dict[str, int]] = None,
87+
rel_per_type: Optional[dict[str, int]] = None,
88+
input_files_count: int = 0,
89+
input_files_total_size_bytes: int = 0,
90+
) -> dict[str, Any]:
91+
"""Build the statistics dict for writer metadata.
92+
93+
Schema:
94+
node_count, relationship_count, nodes_per_label, rel_per_type,
95+
input_files_count, input_files_total_size_bytes.
96+
"""
97+
if nodes_per_label is None:
98+
nodes_per_label = {}
99+
for node in graph.nodes:
100+
nodes_per_label[node.label] = nodes_per_label.get(node.label, 0) + 1
101+
if rel_per_type is None:
102+
rel_per_type = {}
103+
for rel in graph.relationships:
104+
rel_per_type[rel.type] = rel_per_type.get(rel.type, 0) + 1
105+
return {
106+
"node_count": len(graph.nodes),
107+
"relationship_count": len(graph.relationships),
108+
"nodes_per_label": nodes_per_label,
109+
"rel_per_type": rel_per_type,
110+
"input_files_count": input_files_count,
111+
"input_files_total_size_bytes": input_files_total_size_bytes,
112+
}
113+
114+
56115
class KGWriterModel(DataModel):
57116
"""Data model for the output of the Knowledge Graph writer.
58117
59118
Attributes:
60-
status (Literal["SUCCESS", "FAILURE"]): Whether the write operation was successful.
119+
status: Whether the write operation was successful ("SUCCESS" or "FAILURE").
120+
metadata: Optional dict. When status is SUCCESS, contains at least:
121+
- "statistics": dict with node_count, relationship_count, nodes_per_label,
122+
rel_per_type, input_files_count, input_files_total_size_bytes.
123+
- "files": list of file descriptors with file_path, etc. (ParquetWriter).
61124
"""
62125

63126
status: Literal["SUCCESS", "FAILURE"]
@@ -223,10 +286,184 @@ async def run(
223286
return KGWriterModel(
224287
status="SUCCESS",
225288
metadata={
226-
"node_count": len(graph.nodes),
227-
"relationship_count": len(graph.relationships),
289+
"statistics": _graph_stats(graph),
290+
"files": [],
228291
},
229292
)
230293
except neo4j.exceptions.ClientError as e:
231294
logger.exception(e)
232295
return KGWriterModel(status="FAILURE", metadata={"error": str(e)})
296+
297+
298+
class ParquetWriter(KGWriter):
    """Writes a knowledge graph to Parquet files using Neo4jGraphParquetFormatter.

    Writes one Parquet file per node label and one per (head_label, relationship_type, tail_label)
    to the given destinations, e.g. ``Person.parquet``, ``Person_KNOWS_Person.parquet``.

    Args:
        nodes_dest (ParquetOutputDestination): Destination for node Parquet files.
        relationships_dest (ParquetOutputDestination): Destination for relationship Parquet files.
        collision_handler (FilenameCollisionHandler): Handler for resolving filename collisions.
        prefix (str): Optional filename prefix for all written files. Defaults to "".

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.filename_collision_handler import FilenameCollisionHandler
        from neo4j_graphrag.experimental.components.kg_writer import ParquetWriter
        from neo4j_graphrag.experimental.components.parquet_output import ParquetOutputDestination
        from neo4j_graphrag.experimental.pipeline import Pipeline

        # Provide your own implementation of ParquetOutputDestination (local, GCS, S3, etc.)
        nodes_dest: ParquetOutputDestination = ...
        relationships_dest: ParquetOutputDestination = ...

        writer = ParquetWriter(
            nodes_dest=nodes_dest,
            relationships_dest=relationships_dest,
            collision_handler=FilenameCollisionHandler(),
        )
        pipeline = Pipeline()
        pipeline.add_component(writer, "writer")
    """

    def __init__(
        self,
        nodes_dest: ParquetOutputDestination,
        relationships_dest: ParquetOutputDestination,
        collision_handler: FilenameCollisionHandler,
        prefix: str = "",
    ) -> None:
        # Plain attribute assignment; destinations are caller-provided
        # implementations of ParquetOutputDestination (local, GCS, S3, ...).
        self.nodes_dest = nodes_dest
        self.relationships_dest = relationships_dest
        self.collision_handler = collision_handler
        self.prefix = prefix

    @validate_call
    async def run(
        self,
        graph: Neo4jGraph,
        lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(),
        schema: Optional[dict[str, Any]] = None,
    ) -> KGWriterModel:
        """Write the knowledge graph to Parquet files via Neo4jGraphParquetFormatter.

        Args:
            graph (Neo4jGraph): The knowledge graph to write.
            lexical_graph_config (LexicalGraphConfig): Used by the formatter for
                lexical graph labels (e.g. __Entity__) and key properties.
            schema (Optional[dict[str, Any]]): Optional GraphSchema as a dictionary for
                uniqueness constraints and key properties. If not provided, ``__id__`` is used.
        """
        try:
            # NOTE(review): format_graph is assumed to return (per-kind file
            # contents, per-file metadata list, precomputed stats) with
            # data["nodes"] / data["relationships"] mapping filename -> content
            # — confirm against Neo4jGraphParquetFormatter.
            formatter = Neo4jGraphParquetFormatter(schema=schema)
            data, file_metadata, stats = formatter.format_graph(
                graph, lexical_graph_config, prefix=self.prefix
            )

            # Formatter metadata is keyed by the original (pre-collision)
            # filename; lookups below must use `filename`, not the possibly
            # suffixed `unique_filename`.
            meta_by_filename: dict[str, Any] = {m.filename: m for m in file_metadata}
            files: list[dict[str, Any]] = []
            # Maps a node label to the stem of the file it was actually written
            # to (after collision resolution); consumed by the relationship
            # pass below to fill start/end node sources.
            node_label_to_source_name: dict[str, str] = {}

            # --- Node files ---------------------------------------------------
            base_nodes = self.nodes_dest.output_path.rstrip("/")
            for filename, content in data["nodes"].items():
                meta = meta_by_filename[filename]
                unique_filename = self.collision_handler.get_unique_filename(
                    filename, self.nodes_dest.output_path
                )
                await self.nodes_dest.write(content, unique_filename)
                file_path = os.path.join(base_nodes, unique_filename)

                # Strip the ".parquet" extension (8 chars) to get the stem of
                # the file as actually written.
                resolved_stem = (
                    unique_filename[:-8]
                    if unique_filename.endswith(".parquet")
                    else unique_filename
                )
                if meta.node_label is not None:
                    node_label_to_source_name[meta.node_label] = resolved_stem

                columns = _build_columns_from_schema(
                    meta.schema,
                    meta.key_properties or [],
                )
                # Prefer the single node label; fall back to the first of the
                # label list, then to the file stem.
                name = meta.node_label or (
                    meta.labels[0] if meta.labels else resolved_stem
                )
                files.append(
                    {
                        "name": name,
                        "file_path": file_path,
                        "columns": columns,
                        "is_node": True,
                        "labels": meta.labels or [],
                    }
                )

            # --- Relationship files ------------------------------------------
            base_rel = self.relationships_dest.output_path.rstrip("/")
            for filename, content in data["relationships"].items():
                meta = meta_by_filename[filename]
                unique_filename = self.collision_handler.get_unique_filename(
                    filename, self.relationships_dest.output_path
                )
                await self.relationships_dest.write(content, unique_filename)
                file_path = os.path.join(base_rel, unique_filename)

                # Resolve endpoints to the node-file stems recorded in the node
                # pass; fall back to the raw label (or "") when no node file
                # was written for that label.
                start_node_source = node_label_to_source_name.get(
                    meta.relationship_head or "", meta.relationship_head or ""
                )
                end_node_source = node_label_to_source_name.get(
                    meta.relationship_tail or "", meta.relationship_tail or ""
                )
                # Relationship files always key on the "from"/"to" columns.
                columns = _build_columns_from_schema(
                    meta.schema,
                    ["from", "to"],
                )
                # Head_TYPE_Tail when fully known, otherwise the file stem
                # (again stripping the 8-char ".parquet" extension).
                rel_name = (
                    f"{meta.relationship_head}_{meta.relationship_type}_{meta.relationship_tail}"
                    if meta.relationship_head
                    and meta.relationship_type
                    and meta.relationship_tail
                    else unique_filename[:-8]
                    if unique_filename.endswith(".parquet")
                    else unique_filename
                )
                files.append(
                    {
                        "name": rel_name,
                        "file_path": file_path,
                        "columns": columns,
                        "is_node": False,
                        "relationship_type": meta.relationship_type,
                        # Fall back to the synthetic "__id__" key when the
                        # schema declared no key properties for the endpoint.
                        "start_node_source": start_node_source,
                        "start_node_primary_keys": meta.head_node_key_properties
                        or ["__id__"],
                        "end_node_source": end_node_source,
                        "end_node_primary_keys": meta.tail_node_key_properties
                        or ["__id__"],
                    }
                )

            logger.info(
                "Wrote %d node files and %d relationship files",
                len(data["nodes"]),
                len(data["relationships"]),
            )
            # input_files_* are not tracked by this writer; reported as 0.
            statistics = _graph_stats(
                graph,
                nodes_per_label=stats["nodes_per_label"],
                rel_per_type=stats["rel_per_type"],
                input_files_count=0,
                input_files_total_size_bytes=0,
            )
            return KGWriterModel(
                status="SUCCESS",
                metadata={
                    "statistics": statistics,
                    "files": files,
                },
            )
        except Exception as e:
            # Broad catch mirrors the KGWriter failure contract: any error
            # (formatter, destination I/O, metadata lookup) is reported as a
            # FAILURE result rather than raised to the pipeline.
            logger.exception(e)
            return KGWriterModel(status="FAILURE", metadata={"error": str(e)})

0 commit comments

Comments
 (0)