Skip to content

Commit 23ae265

Browse files
ravwojdyla and claude authored
tokenize: size window + levanter batch from parquet row groups (#5158)
* size zephyr window and levanter cache `batch_size` from parquet row-group metadata so each unit of work aligns with ~half a row group end-to-end * probe first parquet file's footer via `_avg_parquet_row_group_rows`, then set `window = min(avg_rows_per_rg // 2, 64)` and `batch_size = avg_rows_per_rg // 2` * halving gives zephyr headroom to pipeline two windows per row group and caps per-worker peak memory * non-parquet inputs keep the old defaults (`window=64`, `batch_size` from `config.levanter_batch_size`) * caller-supplied `config.levanter_batch_size` still wins over the row-group-derived default * extract `_MAX_WINDOW_SIZE = 64` constant [^1] CC: @rjpower [^1]: rationale for the 64 cap lives in #2829 (comment) Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 30f6b6c commit 23ae265

1 file changed

Lines changed: 42 additions & 4 deletions

File tree

lib/marin/src/marin/processing/tokenize/tokenize.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import braceexpand
2020
import draccus
2121
import fsspec
22+
import pyarrow.parquet as pq
2223
from rigging.filesystem import open_url, url_to_fs
2324
from datasets import load_dataset_builder
2425
from fray import ResourceConfig
@@ -44,6 +45,22 @@
4445
logger = logging.getLogger(__name__)
4546

4647
MIN_GROUP_BYTES = 100_000_000 # 100 MB floor to avoid degenerate tiny shards
48+
# Empirical upper bound on the zephyr window size (see
49+
# https://github.com/marin-community/marin/issues/2829#issuecomment-3963661943).
50+
_MAX_WINDOW_SIZE = 64
51+
52+
53+
def _avg_parquet_row_group_rows(path: str) -> int | None:
    """Probe the parquet footer of ``path`` and return its mean row-group size.

    Only the file metadata (footer) is read, so this is cheap even for large
    files. Returns ``None`` when the footer reports zero row groups (an empty
    parquet file); otherwise the result is clamped to at least 1.
    """
    filesystem, local_path = url_to_fs(path)
    with filesystem.open(local_path, "rb") as handle:
        footer = pq.ParquetFile(handle).metadata
        if footer.num_row_groups == 0:
            return None
        # Integer mean; max() guards against files with fewer rows than groups.
        return max(1, footer.num_rows // footer.num_row_groups)
4764

4865

4966
def _compute_target_group_bytes(total_input_bytes: int, max_workers: int) -> int:
@@ -396,22 +413,43 @@ def run_pipeline(ctx: ZephyrContext, file_groups: list[list[str]], split_name: s
396413
prefix = os.path.join(config.cache_path, split_name)
397414
pipeline_start = time.monotonic()
398415

416+
# For parquet sources, align zephyr's window and levanter's cache batch
417+
# with the parquet row-group size so each unit of work covers roughly half
418+
# a row group end-to-end. Non-parquet inputs fall through to the defaults.
419+
sample_path = next(
420+
(p for group in file_groups for p in group if p.endswith(".parquet")),
421+
None,
422+
)
423+
window_size = _MAX_WINDOW_SIZE
424+
batch_size = config.levanter_batch_size
425+
if sample_path is not None:
426+
avg_rg_rows = _avg_parquet_row_group_rows(sample_path)
427+
if avg_rg_rows is not None:
428+
half_rg = max(avg_rg_rows // 2, 1)
429+
window_size = min(half_rg, _MAX_WINDOW_SIZE)
430+
batch_size = half_rg if config.levanter_batch_size is None else config.levanter_batch_size
431+
logger.info(
432+
"Parquet source: avg rows/row-group=%d (from %s) → window=%d, levanter batch_size=%d",
433+
avg_rg_rows,
434+
sample_path,
435+
window_size,
436+
batch_size,
437+
)
438+
399439
ds = Dataset.from_list(file_groups).flat_map(lambda file_list: file_list).flat_map(load_file)
400440

401441
if config.sample_count is not None:
402442
logger.info(f"Sampling {config.sample_count} examples from {split_name} set for tokenization")
403443
ds = ds.take_per_shard(config.sample_count)
404444

405445
temp_shards = (
406-
# NOTE: https://github.com/marin-community/marin/issues/2829#issuecomment-3963661943
407-
# Window set to 64 ^
408-
ds.window(64)
446+
ds.window(window_size)
409447
.map_shard(lambda batches, _: _tokenize_batches(config=config, batches=batches))
410448
.write_levanter_cache(
411449
f"{prefix}/part-{{shard:05d}}-of-{{total:05d}}",
412450
metadata={},
413451
skip_existing=True,
414-
batch_size=config.levanter_batch_size,
452+
batch_size=batch_size,
415453
)
416454
)
417455

0 commit comments

Comments (0)