1414# limitations under the License.
1515from __future__ import annotations
1616
17+ import os
1718import logging
1819from abc import abstractmethod
1920from typing import Any , Generator , Literal , Optional
2021
2122import neo4j
2223from pydantic import validate_call
2324
25+ from neo4j_graphrag .experimental .components .filename_collision_handler import (
26+ FilenameCollisionHandler ,
27+ )
28+ from neo4j_graphrag .experimental .components .parquet_formatter import (
29+ Neo4jGraphParquetFormatter ,
30+ )
31+ from neo4j_graphrag .experimental .components .parquet_output import (
32+ ParquetOutputDestination ,
33+ )
2434from neo4j_graphrag .experimental .components .types import (
2535 LexicalGraphConfig ,
2636 Neo4jGraph ,
4353logger = logging .getLogger (__name__ )
4454
4555
def _build_columns_from_schema(
    schema: Any, primary_key_names: list[str]
) -> list[dict[str, Any]]:
    """Describe every field of a PyArrow schema as a column dict.

    Returns one dict per field with keys ``name``, ``type`` (the source type
    derived from the PyArrow type via the formatter) and ``is_primary_key``.
    """
    primary_keys = set(primary_key_names)
    fields = (schema.field(index) for index in range(len(schema)))
    return [
        {
            "name": field.name,
            "type": Neo4jGraphParquetFormatter.pyarrow_type_to_type_info(
                field.type
            ).source_type,
            "is_primary_key": field.name in primary_keys,
        }
        for field in fields
    ]
72+
73+
4674def batched (rows : list [Any ], batch_size : int ) -> Generator [list [Any ], None , None ]:
4775 index = 0
4876 for i in range (0 , len (rows ), batch_size ):
@@ -53,11 +81,46 @@ def batched(rows: list[Any], batch_size: int) -> Generator[list[Any], None, None
5381 index += 1
5482
5583
84+ def _graph_stats (
85+ graph : Neo4jGraph ,
86+ nodes_per_label : Optional [dict [str , int ]] = None ,
87+ rel_per_type : Optional [dict [str , int ]] = None ,
88+ input_files_count : int = 0 ,
89+ input_files_total_size_bytes : int = 0 ,
90+ ) -> dict [str , Any ]:
91+ """Build the statistics dict for writer metadata.
92+
93+ Schema:
94+ node_count, relationship_count, nodes_per_label, rel_per_type,
95+ input_files_count, input_files_total_size_bytes.
96+ """
97+ if nodes_per_label is None :
98+ nodes_per_label = {}
99+ for node in graph .nodes :
100+ nodes_per_label [node .label ] = nodes_per_label .get (node .label , 0 ) + 1
101+ if rel_per_type is None :
102+ rel_per_type = {}
103+ for rel in graph .relationships :
104+ rel_per_type [rel .type ] = rel_per_type .get (rel .type , 0 ) + 1
105+ return {
106+ "node_count" : len (graph .nodes ),
107+ "relationship_count" : len (graph .relationships ),
108+ "nodes_per_label" : nodes_per_label ,
109+ "rel_per_type" : rel_per_type ,
110+ "input_files_count" : input_files_count ,
111+ "input_files_total_size_bytes" : input_files_total_size_bytes ,
112+ }
113+
114+
56115class KGWriterModel (DataModel ):
57116 """Data model for the output of the Knowledge Graph writer.
58117
59118 Attributes:
60- status (Literal["SUCCESS", "FAILURE"]): Whether the write operation was successful.
119+ status: Whether the write operation was successful ("SUCCESS" or "FAILURE").
120+ metadata: Optional dict. When status is SUCCESS, contains at least:
121+ - "statistics": dict with node_count, relationship_count, nodes_per_label,
122+ rel_per_type, input_files_count, input_files_total_size_bytes.
123+ - "files": list of file descriptors with file_path, etc. (ParquetWriter).
61124 """
62125
63126 status : Literal ["SUCCESS" , "FAILURE" ]
@@ -223,10 +286,184 @@ async def run(
223286 return KGWriterModel (
224287 status = "SUCCESS" ,
225288 metadata = {
226- "node_count " : len (graph . nodes ),
227- "relationship_count " : len ( graph . relationships ) ,
289+ "statistics " : _graph_stats (graph ),
290+ "files " : [] ,
228291 },
229292 )
230293 except neo4j .exceptions .ClientError as e :
231294 logger .exception (e )
232295 return KGWriterModel (status = "FAILURE" , metadata = {"error" : str (e )})
296+
297+
class ParquetWriter(KGWriter):
    """Writes a knowledge graph to Parquet files using Neo4jGraphParquetFormatter.

    Writes one Parquet file per node label and one per (head_label, relationship_type, tail_label)
    to the given destinations, e.g. ``Person.parquet``, ``Person_KNOWS_Person.parquet``.

    Args:
        nodes_dest (ParquetOutputDestination): Destination for node Parquet files.
        relationships_dest (ParquetOutputDestination): Destination for relationship Parquet files.
        collision_handler (FilenameCollisionHandler): Handler for resolving filename collisions.
        prefix (str): Optional filename prefix for all written files. Defaults to "".

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.filename_collision_handler import FilenameCollisionHandler
        from neo4j_graphrag.experimental.components.kg_writer import ParquetWriter
        from neo4j_graphrag.experimental.components.parquet_output import ParquetOutputDestination
        from neo4j_graphrag.experimental.pipeline import Pipeline

        # Provide your own implementation of ParquetOutputDestination (local, GCS, S3, etc.)
        nodes_dest: ParquetOutputDestination = ...
        relationships_dest: ParquetOutputDestination = ...

        writer = ParquetWriter(
            nodes_dest=nodes_dest,
            relationships_dest=relationships_dest,
            collision_handler=FilenameCollisionHandler(),
        )
        pipeline = Pipeline()
        pipeline.add_component(writer, "writer")
    """

    def __init__(
        self,
        nodes_dest: ParquetOutputDestination,
        relationships_dest: ParquetOutputDestination,
        collision_handler: FilenameCollisionHandler,
        prefix: str = "",
    ) -> None:
        self.nodes_dest = nodes_dest
        self.relationships_dest = relationships_dest
        self.collision_handler = collision_handler
        self.prefix = prefix

    @staticmethod
    def _strip_parquet_suffix(filename: str) -> str:
        """Return ``filename`` without a trailing ``.parquet`` extension."""
        if filename.endswith(".parquet"):
            return filename[: -len(".parquet")]
        return filename

    async def _write_node_files(
        self,
        data: dict[str, Any],
        meta_by_filename: dict[str, Any],
    ) -> tuple[list[dict[str, Any]], dict[str, str]]:
        """Write all node Parquet files to ``self.nodes_dest``.

        Returns:
            A tuple of (file descriptors, mapping from node label to the
            resolved collision-free file stem). The mapping is needed later to
            point relationship files back at their node source files.
        """
        files: list[dict[str, Any]] = []
        node_label_to_source_name: dict[str, str] = {}
        base = self.nodes_dest.output_path.rstrip("/")
        for filename, content in data["nodes"].items():
            meta = meta_by_filename[filename]
            unique_filename = self.collision_handler.get_unique_filename(
                filename, self.nodes_dest.output_path
            )
            await self.nodes_dest.write(content, unique_filename)
            file_path = os.path.join(base, unique_filename)

            resolved_stem = self._strip_parquet_suffix(unique_filename)
            if meta.node_label is not None:
                node_label_to_source_name[meta.node_label] = resolved_stem

            columns = _build_columns_from_schema(
                meta.schema,
                meta.key_properties or [],
            )
            # Prefer the explicit node label, then the first label, then the
            # resolved file stem as a last resort.
            name = meta.node_label or (
                meta.labels[0] if meta.labels else resolved_stem
            )
            files.append(
                {
                    "name": name,
                    "file_path": file_path,
                    "columns": columns,
                    "is_node": True,
                    "labels": meta.labels or [],
                }
            )
        return files, node_label_to_source_name

    async def _write_relationship_files(
        self,
        data: dict[str, Any],
        meta_by_filename: dict[str, Any],
        node_label_to_source_name: dict[str, str],
    ) -> list[dict[str, Any]]:
        """Write all relationship Parquet files and return their descriptors."""
        files: list[dict[str, Any]] = []
        base = self.relationships_dest.output_path.rstrip("/")
        for filename, content in data["relationships"].items():
            meta = meta_by_filename[filename]
            unique_filename = self.collision_handler.get_unique_filename(
                filename, self.relationships_dest.output_path
            )
            await self.relationships_dest.write(content, unique_filename)
            file_path = os.path.join(base, unique_filename)

            # Map endpoint labels back to the node files written earlier;
            # fall back to the raw label when no node file exists for it.
            start_node_source = node_label_to_source_name.get(
                meta.relationship_head or "", meta.relationship_head or ""
            )
            end_node_source = node_label_to_source_name.get(
                meta.relationship_tail or "", meta.relationship_tail or ""
            )
            columns = _build_columns_from_schema(
                meta.schema,
                ["from", "to"],
            )
            if (
                meta.relationship_head
                and meta.relationship_type
                and meta.relationship_tail
            ):
                rel_name = (
                    f"{meta.relationship_head}_{meta.relationship_type}"
                    f"_{meta.relationship_tail}"
                )
            else:
                rel_name = self._strip_parquet_suffix(unique_filename)
            files.append(
                {
                    "name": rel_name,
                    "file_path": file_path,
                    "columns": columns,
                    "is_node": False,
                    "relationship_type": meta.relationship_type,
                    "start_node_source": start_node_source,
                    "start_node_primary_keys": meta.head_node_key_properties
                    or ["__id__"],
                    "end_node_source": end_node_source,
                    "end_node_primary_keys": meta.tail_node_key_properties
                    or ["__id__"],
                }
            )
        return files

    @validate_call
    async def run(
        self,
        graph: Neo4jGraph,
        lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(),
        schema: Optional[dict[str, Any]] = None,
    ) -> KGWriterModel:
        """Write the knowledge graph to Parquet files via Neo4jGraphParquetFormatter.

        Args:
            graph (Neo4jGraph): The knowledge graph to write.
            lexical_graph_config (LexicalGraphConfig): Used by the formatter for
                lexical graph labels (e.g. __Entity__) and key properties.
            schema (Optional[dict[str, Any]]): Optional GraphSchema as a dictionary for
                uniqueness constraints and key properties. If not provided, ``__id__`` is used.

        Returns:
            KGWriterModel: SUCCESS with ``statistics`` and ``files`` metadata,
            or FAILURE with an ``error`` message if anything raised.
        """
        try:
            formatter = Neo4jGraphParquetFormatter(schema=schema)
            data, file_metadata, stats = formatter.format_graph(
                graph, lexical_graph_config, prefix=self.prefix
            )

            meta_by_filename: dict[str, Any] = {m.filename: m for m in file_metadata}

            # Node files must be written first: relationship descriptors
            # reference the node files' resolved stems.
            node_files, node_label_to_source_name = await self._write_node_files(
                data, meta_by_filename
            )
            rel_files = await self._write_relationship_files(
                data, meta_by_filename, node_label_to_source_name
            )

            logger.info(
                "Wrote %d node files and %d relationship files",
                len(data["nodes"]),
                len(data["relationships"]),
            )
            statistics = _graph_stats(
                graph,
                nodes_per_label=stats["nodes_per_label"],
                rel_per_type=stats["rel_per_type"],
                input_files_count=0,
                input_files_total_size_bytes=0,
            )
            return KGWriterModel(
                status="SUCCESS",
                metadata={
                    "statistics": statistics,
                    "files": node_files + rel_files,
                },
            )
        except Exception as e:
            # Mirror Neo4jWriter's contract: never raise, report FAILURE.
            logger.exception(e)
            return KGWriterModel(status="FAILURE", metadata={"error": str(e)})