@@ -98,8 +98,17 @@ def get_iterators(self) -> Iterator[Iterator]:
 # ScatterReader. Sidecars are small msgpack files (a few KB) and reads are
 # GCS GET-bound, so a modest pool keeps latency low without thrashing.
 _SIDECAR_READ_CONCURRENCY = 32
-# Number of items sampled from the first flush to estimate avg_item_bytes.
+# Items sampled on the first flush to establish an avg_item_bytes baseline.
 _SCATTER_SAMPLE_SIZE = 100
+# Items sampled on each subsequent flush to track item-size drift cheaply.
+_SCATTER_ONGOING_SAMPLE_SIZE = 10
+# How often (in items written) to re-sample one item's pickle size and update
+# the EMA estimate in write(). This is independent of flush-time sampling and
+# ensures the estimate tracks drift even when no flush has fired yet.
+_ESTIMATE_WRITE_SAMPLE_INTERVAL = 10
+# EMA weight given to each new observation. 0.3 converges to a 2x step-change
+# in item size within ~3 samples while staying stable under small fluctuations.
+_ESTIMATE_EMA_ALPHA = 0.3
 # Fraction of total memory budgeted for read-side decompression buffers.
 _SCATTER_READ_BUFFER_FRACTION = 0.25
 
@@ -532,12 +541,21 @@ def _flush(self, target: int, buf: list) -> None:
         buf = _apply_combiner(buf, self._key_fn, self._combiner_fn)
         buf.sort(key=self._sort_key)
 
-        if not self._sampled_avg and buf:
-            sample = buf[: min(len(buf), _SCATTER_SAMPLE_SIZE)]
-            total_bytes = sum(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)) for item in sample)
-            self._avg_item_bytes = total_bytes / len(sample)
+        if buf:
+            # Sample a subset of the buffer to update the byte-size estimate.
+            # First flush: larger sample for a good baseline. Subsequent flushes:
+            # smaller sample to track drift cheaply via EMA. This prevents OOM
+            # when early items are small but later items are large — the estimate
+            # stays current rather than being frozen at the first-flush value.
+            n = _SCATTER_SAMPLE_SIZE if not self._sampled_avg else _SCATTER_ONGOING_SAMPLE_SIZE
+            sample = buf[: min(len(buf), n)]
+            observed = sum(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)) for item in sample) / len(sample)
+            if not self._sampled_avg:
+                self._avg_item_bytes = observed
+                self._sampled_avg = True
+            else:
+                self._avg_item_bytes = (1 - _ESTIMATE_EMA_ALPHA) * self._avg_item_bytes + _ESTIMATE_EMA_ALPHA * observed
             self._item_bytes_estimate = self._avg_item_bytes
-            self._sampled_avg = True
 
         frame = _write_chunk_frame(buf)
         offset = self._out.tell()
@@ -557,13 +575,21 @@ def _flush(self, target: int, buf: list) -> None:
 
     def write(self, item: Any) -> None:
         """Route a single item to its target shard buffer, flushing when over budget."""
-        if self._total_buffer_rows == 0:
-            # Calibrate from the first item before any batching occurs. A
-            # hardcoded default (e.g. 512 B) can be orders of magnitude off for
-            # large documents, allowing millions of rows to accumulate before the
-            # first flush fires. One real measurement is far safer.
-            self._item_bytes_estimate = float(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)))
-            self._first_item_bytes = self._item_bytes_estimate
+        if self._total_buffer_rows % _ESTIMATE_WRITE_SAMPLE_INTERVAL == 0:
+            # Periodically measure a single item's serialised size and apply EMA.
+            # This runs in write() — not just in _flush() — so the estimate tracks
+            # size drift even when no flush has fired yet (the flush EMA is a
+            # closed loop: if the estimate is too low no flush fires, so it never
+            # updates). Interval-based sampling amortises the pickle.dumps cost
+            # to 1-in-10 items while still catching step-changes within a few rows.
+            observed = float(len(pickle.dumps(item, protocol=pickle.HIGHEST_PROTOCOL)))
+            if self._total_buffer_rows == 0:
+                self._item_bytes_estimate = observed
+                self._first_item_bytes = observed
+            else:
+                self._item_bytes_estimate = (
+                    1 - _ESTIMATE_EMA_ALPHA
+                ) * self._item_bytes_estimate + _ESTIMATE_EMA_ALPHA * observed
 
         key = self._key_fn(item)
         target = deterministic_hash(key) % self._num_output_shards