@@ -28,12 +28,12 @@
 import pyarrow as pa
 import pyarrow.compute as pc
 import pyarrow.dataset as pad
-import pyarrow.parquet as pq
 from iris.env_resources import TaskResources as _TaskResources
 from rigging.filesystem import open_url, url_to_fs
 from rigging.timing import log_time
 
 from zephyr.plan import deterministic_hash
+from zephyr.spill_writer import SpillWriter
 from zephyr.writers import ensure_parent_dir
 
 logger = logging.getLogger(__name__)
@@ -571,12 +571,7 @@ def _write_parquet_scatter(
     seg_idx = 0
     seg_paths: list[str] = []
     schema: pa.Schema | None = None
-    writer: pq.ParquetWriter | None = None
-    seg_file = ""
-
-    pending_chunk: pa.RecordBatch | None = None
-    pending_target: int = -1
-    pending_cnt: int = 0
+    spill_writer: SpillWriter | None = None
 
     avg_item_bytes: float = 0.0
     _sampled_avg = False
@@ -586,57 +581,36 @@ def _get_buffer(target: int) -> _ShardBuffer:
         buffers[target] = _ShardBuffer(shard_idx=target, pickled=pickled, has_sort=sort_fn is not None)
         return buffers[target]
 
-    def _flush_pending() -> None:
-        nonlocal n_chunks_flushed, pending_chunk
-        if pending_chunk is None:
-            return
-        writer.write_batch(pending_chunk)
-        seg_shard_counts[seg_idx][pending_target] = seg_shard_counts[seg_idx].get(pending_target, 0) + 1
-        n_chunks_flushed += 1
-        pending_chunk = None
-        if n_chunks_flushed % 10 == 0:
-            logger.info(
-                "[shard %d segment %d] Wrote %d parquet chunks so far (latest chunk size: %d items)",
-                source_shard,
-                seg_idx,
-                n_chunks_flushed,
-                pending_cnt,
-            )
-
     def _ensure_writer(chunk_schema: pa.Schema) -> pa.Schema:
-        nonlocal schema, writer, seg_file, seg_idx
+        nonlocal schema, spill_writer, seg_idx
         if schema is None:
             schema = chunk_schema
             seg_file = _segment_path(parquet_path, seg_idx)
             seg_paths.append(seg_file)
             ensure_parent_dir(seg_file)
-            writer = pq.ParquetWriter(seg_file, schema, compression="zstd", compression_level=1)
+            spill_writer = SpillWriter(seg_file, schema)
         elif chunk_schema != schema:
-            _flush_pending()
-            writer.close()
+            spill_writer.close()
             schema = pa.unify_schemas([schema, chunk_schema])
             seg_idx += 1
             for buf in buffers.values():
                 buf.chunk_idx = 0
             seg_file = _segment_path(parquet_path, seg_idx)
             seg_paths.append(seg_file)
             ensure_parent_dir(seg_file)
-            writer = pq.ParquetWriter(seg_file, schema, compression="zstd", compression_level=1)
+            spill_writer = SpillWriter(seg_file, schema)
             logger.info(
                 "[shard %d] Schema evolved after %d chunks; starting segment %d",
                 source_shard,
                 n_chunks_flushed,
                 seg_idx,
             )
-        else:
-            _flush_pending()
         return schema
 
     def _flush_buffer(buf: _ShardBuffer) -> None:
-        nonlocal pending_chunk, pending_target, pending_cnt, avg_item_bytes, _sampled_avg
+        nonlocal n_chunks_flushed, avg_item_bytes, _sampled_avg
 
         if combiner_fn is not None:
-            # Combiner path: drain buffer to Python, apply combiner, re-sort in Arrow
            buf._flush_micro()
             if not buf.tables:
                 return
@@ -664,9 +638,21 @@ def _flush_buffer(buf: _ShardBuffer) -> None:
         write_schema = _ensure_writer(batch.schema)
         if batch.schema != write_schema:
             batch = batch.cast(write_schema)
-        pending_chunk = batch
-        pending_target = buf.shard_idx
-        pending_cnt = len(batch)
+
+        # Each sorted chunk is its own row group (distinct shard/chunk metadata).
+        batch_table = pa.Table.from_batches([batch])
+        spill_writer.write_row_group(batch_table)
+        seg_shard_counts[seg_idx][buf.shard_idx] = seg_shard_counts[seg_idx].get(buf.shard_idx, 0) + 1
+        n_chunks_flushed += 1
+
+        if n_chunks_flushed % 10 == 0:
+            logger.info(
+                "[shard %d segment %d] Wrote %d parquet chunks so far (latest chunk size: %d items)",
+                source_shard,
+                seg_idx,
+                n_chunks_flushed,
+                len(batch),
+            )
 
         if not _sampled_avg and len(batch) > 0:
             avg_item_bytes = batch.nbytes / len(batch)
@@ -682,16 +668,14 @@ def _flush_buffer(buf: _ShardBuffer) -> None:
             _flush_buffer(buf)
 
     with log_time(f"Flushing remaining buffers for {parquet_path}"):
-        _flush_pending()
         for target in sorted(buffers.keys()):
             buf = buffers[target]
             if buf.item_count == 0:
                 continue
             _flush_buffer(buf)
-        _flush_pending()
 
-    if writer is not None:
-        writer.close()
+    if spill_writer is not None:
+        spill_writer.close()
 
     per_shard_max_rows: dict[int, int] = {target: buf.max_rows for target, buf in buffers.items() if buf.max_rows > 0}
 
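For context: the diff only shows the call sites, not `zephyr.spill_writer.SpillWriter` itself. From the calls above, the class takes `(path, schema)`, exposes `write_row_group(table)` and `close()`, and replaces the old inline `pq.ParquetWriter` (zstd, level 1). A minimal sketch consistent with those calls, under the assumption that it simply wraps `pyarrow.parquet.ParquetWriter` with the same compression settings, might look like:

```python
# Hypothetical minimal SpillWriter; the real zephyr.spill_writer.SpillWriter
# is not shown in this diff and may differ.
import pyarrow as pa
import pyarrow.parquet as pq


class SpillWriter:
    def __init__(self, path: str, schema: pa.Schema) -> None:
        # Same compression the old inline writer used (zstd, level 1).
        self._writer = pq.ParquetWriter(path, schema, compression="zstd", compression_level=1)

    def write_row_group(self, table: pa.Table) -> None:
        # Writing the whole table in one call with row_group_size >= len(table)
        # keeps each sorted chunk in a single row group, matching the
        # "each sorted chunk is its own row group" comment in the diff.
        self._writer.write_table(table, row_group_size=max(len(table), 1))

    def close(self) -> None:
        self._writer.close()
```

Under that assumption, each (shard, chunk) pair maps to exactly one row group, and `seg_shard_counts[seg_idx]` records how many row groups each target shard contributed to the segment, so a reader could recover a single sorted chunk with, e.g., `pq.ParquetFile(seg_file).read_row_group(i)`.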