diff --git a/experiments/dedup/bench_external_merge.py b/experiments/dedup/bench_external_merge.py new file mode 100644 index 0000000000..80d3067de5 --- /dev/null +++ b/experiments/dedup/bench_external_merge.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""A/B benchmark: run nemotron_1slice_fuzzy on main vs arrow-scatter-reduce. + +Creates a worktree for main, patches both branches with ZEPHYR_FORCE_EXTERNAL_MERGE, +deletes stale output buckets, and submits both jobs via iris. +""" + +import os +import shutil +import subprocess +import tempfile + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +WORKTREE_DIR = os.path.join(tempfile.gettempdir(), "marin-main-bench") +EXPERIMENT = "experiments/dedup/nemotron_1slice_fuzzy.py" + +# Output prefix names (used in marin_temp_bucket calls inside the experiment scripts) +MAIN_PREFIX = "arrow-scatter-bench-main" +BRANCH_PREFIX = "arrow-scatter-bench-fast" + +IRIS_BASE_CMD = [ + "uv", + "run", + "iris", + "--config=lib/iris/examples/marin.yaml", + "job", + "run", + "--no-wait", + "--memory=0.5g", + "--cpu=0", + "--region=europe-west4", +] + +# The env var that forces external merge in both old and new code paths +FORCE_ENV = ("ZEPHYR_FORCE_EXTERNAL_MERGE", "1") + + +def run(cmd: list[str], cwd: str | None = None, check: bool = True) -> subprocess.CompletedProcess: + print(f" $ {' '.join(cmd)}") + return subprocess.run(cmd, cwd=cwd, check=check, text=True, capture_output=True) + + +def run_live(cmd: list[str], cwd: str | None = None, check: bool = True) -> subprocess.CompletedProcess: + """Run with stdout/stderr going to terminal.""" + print(f" $ {' '.join(cmd)}") + return subprocess.run(cmd, cwd=cwd, check=check, text=True) + + +def delete_old_outputs() -> None: + """Delete previous benchmark outputs from GCS temp buckets.""" + print("\n=== Deleting previous benchmark outputs ===") + for prefix in (MAIN_PREFIX, 
BRANCH_PREFIX): + bucket_path = f"gs://marin-tmp-eu-west4/ttl=1d/{prefix}" + result = subprocess.run( + ["gcloud", "storage", "rm", "-r", bucket_path], + text=True, + capture_output=True, + ) + if result.returncode == 0: + print(f" Deleted {bucket_path}") + else: + print(f" Nothing to delete at {bucket_path} (or already gone)") + + +def setup_main_worktree() -> str: + """Create a git worktree for main, return path.""" + print("\n=== Setting up main worktree ===") + if os.path.exists(WORKTREE_DIR): + print(f" Removing existing worktree at {WORKTREE_DIR}") + run(["git", "worktree", "remove", "--force", WORKTREE_DIR], cwd=REPO_ROOT, check=False) + if os.path.exists(WORKTREE_DIR): + shutil.rmtree(WORKTREE_DIR) + + BENCH_BRANCH = "arrow-scatter-test" + run(["git", "branch", "-D", BENCH_BRANCH], cwd=REPO_ROOT, check=False) + run(["git", "worktree", "add", "-b", BENCH_BRANCH, WORKTREE_DIR, "main"], cwd=REPO_ROOT) + print(f" Worktree created at {WORKTREE_DIR} (branch {BENCH_BRANCH})") + return WORKTREE_DIR + + +def patch_main_worktree(worktree: str) -> None: + """Apply ZEPHYR_FORCE_EXTERNAL_MERGE env var check to main's plan.py.""" + print("\n=== Patching main worktree plan.py ===") + plan_py = os.path.join(worktree, "lib/zephyr/src/zephyr/plan.py") + + with open(plan_py) as f: + content = f.read() + + # Main's code has: + # use_external = ( + # external_sort_dir is not None + # and isinstance(shard, ScatterShard) + # and shard.needs_external_sort(_TaskResources.from_environment().memory_bytes) + # ) + old = ( + " use_external = (\n" + " external_sort_dir is not None\n" + " and isinstance(shard, ScatterShard)\n" + " and shard.needs_external_sort(_TaskResources.from_environment().memory_bytes)\n" + " )" + ) + new = ( + ' force_external = os.environ.get("ZEPHYR_FORCE_EXTERNAL_MERGE", "").lower() in ("1", "true", "yes")\n' + " use_external = (\n" + " external_sort_dir is not None\n" + " and isinstance(shard, ScatterShard)\n" + " and (force_external or 
shard.needs_external_sort(_TaskResources.from_environment().memory_bytes))\n" + " )" + ) + + if old not in content: + print(" WARNING: Could not find expected code pattern in main's plan.py") + print(" Searching for alternative patterns...") + # Check if already patched + if "ZEPHYR_FORCE_EXTERNAL_MERGE" in content: + print(" Already patched, skipping") + return + raise RuntimeError("Cannot patch main's plan.py — expected code not found") + + content = content.replace(old, new) + with open(plan_py, "w") as f: + f.write(content) + print(" Patched plan.py with ZEPHYR_FORCE_EXTERNAL_MERGE support") + + +def copy_experiment_to_worktree(worktree: str) -> None: + """Copy the experiment script to main worktree, adjusting the output prefix.""" + print("\n=== Copying experiment script to main worktree ===") + src = os.path.join(REPO_ROOT, EXPERIMENT) + dst_dir = os.path.join(worktree, os.path.dirname(EXPERIMENT)) + os.makedirs(dst_dir, exist_ok=True) + dst = os.path.join(worktree, EXPERIMENT) + + with open(src) as f: + content = f.read() + + # Replace the branch prefix with the main prefix + content = content.replace(BRANCH_PREFIX, MAIN_PREFIX) + with open(dst, "w") as f: + f.write(content) + print(f" Copied {EXPERIMENT} → {dst} (prefix={MAIN_PREFIX})") + + +def submit_job(cwd: str, label: str) -> str: + """Submit an iris job and return the job ID.""" + print(f"\n=== Submitting job: {label} ===") + cmd = [ + *IRIS_BASE_CMD, + "-e", + FORCE_ENV[0], + FORCE_ENV[1], + "--", + "python", + EXPERIMENT, + ] + run_live(cmd, cwd=cwd) + return label + + +def main() -> None: + print("=" * 60) + print("A/B Benchmark: main vs arrow-scatter-reduce") + print("=" * 60) + + # 1. Delete old outputs + delete_old_outputs() + + # 2. Set up main worktree + worktree = setup_main_worktree() + + # 3. Patch main worktree + patch_main_worktree(worktree) + + # 4. Copy experiment to main worktree + copy_experiment_to_worktree(worktree) + + # 5. Submit main job + submit_job(worktree, "main") + + # 6. 
Submit branch job (from repo root) + submit_job(REPO_ROOT, "arrow-scatter-reduce") + + print("\n" + "=" * 60) + print("Both jobs submitted. Monitor via:") + print(" uv run iris --config=lib/iris/examples/marin.yaml job list") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/experiments/dedup/bench_external_merge_exact.py b/experiments/dedup/bench_external_merge_exact.py new file mode 100644 index 0000000000..85de4cc9b3 --- /dev/null +++ b/experiments/dedup/bench_external_merge_exact.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""A/B benchmark: exact paragraph dedup on main vs arrow-scatter-reduce. + +Submits 4 jobs total: + - main @ 10% files (prefix: exact-bench-main-10pct) + - branch @ 10% files (prefix: exact-bench-fast-10pct) + - main @ full (prefix: exact-bench-main-full) + - branch @ full (prefix: exact-bench-fast-full) + +Creates a worktree for main, patches both branches with ZEPHYR_FORCE_EXTERNAL_MERGE, +deletes stale output buckets, and submits all jobs via iris. 
+""" + +import os +import shutil +import subprocess +import tempfile + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +WORKTREE_DIR = os.path.join(tempfile.gettempdir(), "marin-main-bench") +EXPERIMENT = "experiments/dedup/nemotron_1split_exact.py" + +# 10% of ~11108 files ≈ 1111 +TEN_PCT_FILES = 1111 + +VARIANTS = [ + {"label": "10pct", "max_files": str(TEN_PCT_FILES)}, +] + +MAIN_PREFIX_FMT = "exact-bench-main-{label}" +BRANCH_PREFIX_FMT = "exact-bench-fast-{label}" + +IRIS_BASE_CMD = [ + "uv", + "run", + "iris", + "--config=lib/iris/examples/marin-dev.yaml", + "job", + "run", + "--no-wait", + "--memory=4g", + "--cpu=0", + "--region=europe-west4", +] + +FORCE_ENV = ("ZEPHYR_FORCE_EXTERNAL_MERGE", "1") + + +def run(cmd: list[str], cwd: str | None = None, check: bool = True) -> subprocess.CompletedProcess: + print(f" $ {' '.join(cmd)}") + return subprocess.run(cmd, cwd=cwd, check=check, text=True, capture_output=True) + + +def run_live(cmd: list[str], cwd: str | None = None, check: bool = True) -> subprocess.CompletedProcess: + print(f" $ {' '.join(cmd)}") + return subprocess.run(cmd, cwd=cwd, check=check, text=True) + + +def delete_old_outputs() -> None: + print("\n=== Deleting previous benchmark outputs ===") + for variant in VARIANTS: + for fmt in (MAIN_PREFIX_FMT, BRANCH_PREFIX_FMT): + prefix = fmt.format(**variant) + bucket_path = f"gs://marin-tmp-eu-west4/ttl=1d/{prefix}" + result = subprocess.run( + ["gcloud", "storage", "rm", "-r", bucket_path], + text=True, + capture_output=True, + ) + if result.returncode == 0: + print(f" Deleted {bucket_path}") + else: + print(f" Nothing to delete at {bucket_path} (or already gone)") + + +def setup_main_worktree() -> str: + print("\n=== Setting up main worktree ===") + if os.path.exists(WORKTREE_DIR): + print(f" Removing existing worktree at {WORKTREE_DIR}") + run(["git", "worktree", "remove", "--force", WORKTREE_DIR], cwd=REPO_ROOT, check=False) + if 
os.path.exists(WORKTREE_DIR): + shutil.rmtree(WORKTREE_DIR) + + BENCH_BRANCH = "arrow-scatter-test" + run(["git", "branch", "-D", BENCH_BRANCH], cwd=REPO_ROOT, check=False) + run(["git", "worktree", "add", "-b", BENCH_BRANCH, WORKTREE_DIR, "main"], cwd=REPO_ROOT) + print(f" Worktree created at {WORKTREE_DIR} (branch {BENCH_BRANCH})") + return WORKTREE_DIR + + +def patch_main_worktree(worktree: str) -> None: + print("\n=== Patching main worktree plan.py ===") + plan_py = os.path.join(worktree, "lib/zephyr/src/zephyr/plan.py") + + with open(plan_py) as f: + content = f.read() + + old = ( + " use_external = (\n" + " external_sort_dir is not None\n" + " and isinstance(shard, ScatterShard)\n" + " and shard.needs_external_sort(_TaskResources.from_environment().memory_bytes)\n" + " )" + ) + new = ( + ' force_external = os.environ.get("ZEPHYR_FORCE_EXTERNAL_MERGE", "").lower() in ("1", "true", "yes")\n' + " use_external = (\n" + " external_sort_dir is not None\n" + " and isinstance(shard, ScatterShard)\n" + " and (force_external or shard.needs_external_sort(_TaskResources.from_environment().memory_bytes))\n" + " )" + ) + + if old not in content: + if "ZEPHYR_FORCE_EXTERNAL_MERGE" in content: + print(" Already patched, skipping") + return + raise RuntimeError("Cannot patch main's plan.py — expected code not found") + + content = content.replace(old, new) + with open(plan_py, "w") as f: + f.write(content) + print(" Patched plan.py with ZEPHYR_FORCE_EXTERNAL_MERGE support") + + +def copy_experiment_to_worktree(worktree: str) -> None: + print("\n=== Copying experiment script to main worktree ===") + src = os.path.join(REPO_ROOT, EXPERIMENT) + dst_dir = os.path.join(worktree, os.path.dirname(EXPERIMENT)) + os.makedirs(dst_dir, exist_ok=True) + dst = os.path.join(worktree, EXPERIMENT) + + with open(src) as f: + content = f.read() + + with open(dst, "w") as f: + f.write(content) + print(f" Copied {EXPERIMENT} → {dst}") + + +def submit_job(cwd: str, label: str, output_prefix: str, 
max_files: str) -> None: + print(f"\n=== Submitting job: {label} (prefix={output_prefix}, max_files={max_files}) ===") + cmd = [ + *IRIS_BASE_CMD, + "-e", + FORCE_ENV[0], + FORCE_ENV[1], + "-e", + "OUTPUT_PREFIX", + output_prefix, + "-e", + "MAX_FILES", + max_files, + "--", + "python", + EXPERIMENT, + ] + run_live(cmd, cwd=cwd) + + +def main() -> None: + print("=" * 70) + print("A/B Benchmark (exact dedup): main vs arrow-scatter-reduce") + print(" Variants: 10% files only") + print("=" * 70) + + delete_old_outputs() + worktree = setup_main_worktree() + patch_main_worktree(worktree) + copy_experiment_to_worktree(worktree) + + for variant in VARIANTS: + label = variant["label"] + max_files = variant["max_files"] + + main_prefix = MAIN_PREFIX_FMT.format(**variant) + branch_prefix = BRANCH_PREFIX_FMT.format(**variant) + + submit_job(worktree, f"main-{label}", main_prefix, max_files) + submit_job(REPO_ROOT, f"branch-{label}", branch_prefix, max_files) + + print("\n" + "=" * 70) + print("2 jobs submitted. Monitor via:") + print(" uv run iris --config=lib/iris/examples/marin-dev.yaml job list") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/experiments/dedup/bench_local_pipeline.py b/experiments/dedup/bench_local_pipeline.py new file mode 100644 index 0000000000..3ca7bd0af3 --- /dev/null +++ b/experiments/dedup/bench_local_pipeline.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""A/B benchmark: local dedup pipeline on main vs arrow-scatter-reduce. + +Generates shared input files once, then runs benchmark_dedup_pipeline.py +in both the current branch and a main worktree, comparing throughput. 
+ +Usage: + uv run python experiments/dedup/bench_local_pipeline.py + uv run python experiments/dedup/bench_local_pipeline.py --num-docs 500000 --backends threadpool sync +""" + +import os +import shutil +import subprocess +import tempfile + +import click + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +WORKTREE_DIR = os.path.join(tempfile.gettempdir(), "marin-main-bench-local") +BENCHMARK_SCRIPT = "lib/zephyr/tests/benchmark_dedup_pipeline.py" + + +def run(cmd: list[str], cwd: str | None = None, check: bool = True) -> subprocess.CompletedProcess: + print(f" $ {' '.join(cmd)}") + return subprocess.run(cmd, cwd=cwd, check=check, text=True, capture_output=True) + + +def run_live(cmd: list[str], cwd: str | None = None, check: bool = True) -> subprocess.CompletedProcess: + print(f" $ {' '.join(cmd)}") + return subprocess.run(cmd, cwd=cwd, check=check, text=True) + + +def setup_main_worktree() -> str: + print("\n=== Setting up main worktree ===") + if os.path.exists(WORKTREE_DIR): + print(f" Removing existing worktree at {WORKTREE_DIR}") + run(["git", "worktree", "remove", "--force", WORKTREE_DIR], cwd=REPO_ROOT, check=False) + if os.path.exists(WORKTREE_DIR): + shutil.rmtree(WORKTREE_DIR) + + # Create a new branch from main so we don't pin main itself to the worktree + BENCH_BRANCH = "arrow-scatter-test" + run(["git", "branch", "-D", BENCH_BRANCH], cwd=REPO_ROOT, check=False) + run(["git", "worktree", "add", "-b", BENCH_BRANCH, WORKTREE_DIR, "main"], cwd=REPO_ROOT) + print(f" Worktree created at {WORKTREE_DIR} (branch {BENCH_BRANCH})") + return WORKTREE_DIR + + +def generate_input( + cwd: str, + input_dir: str, + num_docs: int, + words_per_doc: int, + num_input_files: int, +) -> None: + """Generate shared input files using the benchmark script's write-input command.""" + print(f"\n=== Generating input ({num_docs:,} docs) ===") + cmd = [ + "uv", + "run", + "python", + BENCHMARK_SCRIPT, + "write-input", + "--output-dir", + 
input_dir, + "--num-docs", + str(num_docs), + "--words-per-doc", + str(words_per_doc), + "--num-input-files", + str(num_input_files), + ] + run_live(cmd, cwd=cwd) + + +def run_benchmark( + cwd: str, + label: str, + input_dir: str, + num_docs: int, + words_per_doc: int, + num_input_files: int, + backends: list[str], +) -> None: + """Run the benchmark in the given directory.""" + print(f"\n{'=' * 60}") + print(f"Running benchmark: {label}") + print(f" cwd: {cwd}") + print(f"{'=' * 60}") + + cmd = [ + "uv", + "run", + "python", + BENCHMARK_SCRIPT, + "benchmark", + "--input-dir", + input_dir, + "--num-docs", + str(num_docs), + "--words-per-doc", + str(words_per_doc), + "--num-input-files", + str(num_input_files), + ] + for backend in backends: + cmd.extend(["--backends", backend]) + + run_live(cmd, cwd=cwd) + + +@click.command() +@click.option("--num-docs", type=int, default=1_000_000, help="Number of documents to generate") +@click.option("--words-per-doc", type=int, default=1000, help="Words per document") +@click.option("--num-input-files", type=int, default=10, help="Number of input parquet files") +@click.option( + "--backends", + multiple=True, + type=click.Choice(["sync", "threadpool", "ray"]), + default=["threadpool"], + help="Backends to benchmark", +) +def main( + num_docs: int, + words_per_doc: int, + num_input_files: int, + backends: tuple[str, ...], +) -> None: + """A/B benchmark: local dedup pipeline on main vs current branch.""" + backends_list = list(backends) + + print("=" * 60) + print("A/B Local Benchmark: main vs arrow-scatter-reduce") + print(f" docs={num_docs:,} words/doc={words_per_doc} files={num_input_files}") + print(f" backends: {', '.join(backends_list)}") + print("=" * 60) + + worktree = setup_main_worktree() + input_dir = tempfile.mkdtemp(prefix="zephyr_ab_input_") + + try: + # Generate input once from current branch + generate_input(REPO_ROOT, input_dir, num_docs, words_per_doc, num_input_files) + + # Run on main + run_benchmark(worktree, 
"main", input_dir, num_docs, words_per_doc, num_input_files, backends_list) + + # Run on current branch + run_benchmark( + REPO_ROOT, "arrow-scatter-reduce", input_dir, num_docs, words_per_doc, num_input_files, backends_list + ) + + print("\n" + "=" * 60) + print("A/B benchmark complete. Compare results above.") + print("=" * 60) + + finally: + print(f"\nCleaning up input directory {input_dir}...") + shutil.rmtree(input_dir, ignore_errors=True) + print(f"Cleaning up worktree {worktree}...") + run(["git", "worktree", "remove", "--force", worktree], cwd=REPO_ROOT, check=False) + if os.path.exists(worktree): + shutil.rmtree(worktree, ignore_errors=True) + + +if __name__ == "__main__": + main() diff --git a/experiments/dedup/monitor_bench.py b/experiments/dedup/monitor_bench.py new file mode 100644 index 0000000000..8326372704 --- /dev/null +++ b/experiments/dedup/monitor_bench.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Monitor the A/B exact-dedup benchmark jobs (4 variants). 
+ +Usage: + uv run python experiments/dedup/monitor_bench.py # one-shot status + uv run python experiments/dedup/monitor_bench.py --loop # poll every 60s +""" + +import json +import re +import subprocess +import sys +import time +from datetime import datetime, timezone + +IRIS_CMD = ["uv", "run", "iris", "--config", "lib/iris/examples/marin-dev.yaml"] + +# Job IDs from bench_external_merge_exact.py run on 2026-04-01 +JOBS = { + "main-10pct": "/power/iris-run-nemotron_1split_exact-20260402-005308", + "branch-10pct": "/power/iris-run-nemotron_1split_exact-20260402-005322", + "main-full": "/power/iris-run-nemotron_1split_exact-20260402-003956", + "branch-full": "/power/iris-run-nemotron_1split_exact-20260402-004004", +} + + +def run_quiet(cmd: list[str], timeout: int = 60) -> str: + r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout) + return "\n".join(l for l in r.stdout.splitlines() if not re.match(r"^I\d{8}", l)) + + +def get_job_state(job_id: str) -> str: + """Get the top-level job state from iris job list.""" + raw = run_quiet([*IRIS_CMD, "job", "list"], timeout=90) + for line in raw.splitlines(): + if job_id in line: + parts = line.split() + for p in parts: + if p in ("running", "succeeded", "failed", "killed", "pending", "queued"): + return p + return "unknown" + + +def discover_coord_actor(job_id: str) -> str | None: + """Find the coordinator actor endpoint from child job names.""" + raw = run_quiet([*IRIS_CMD, "job", "list"], timeout=90) + # Look for the coord child: .../zephyr-*-p0-a0/zephyr-*-p0-coord-0 + for line in raw.splitlines(): + if job_id in line and "coord-0" in line: + # First token is the job ID / task path + path = line.split()[0] + return path + # Try alternative: the coordinator is a child of the zephyr pipeline job + for line in raw.splitlines(): + if job_id in line and "-p0-a0" in line and "workers" not in line and "coord" not in line: + # This is the pipeline child — construct coord path + pipeline_id = line.split()[0] 
+ # Coordinator actor is named like the pipeline but with -coord-0 appended + base = pipeline_id.rsplit("/", 1)[-1] # e.g. zephyr-exact-para-dedup-XXX-p0-a0 + coord_name = base.replace("-a0", "-coord-0") + return f"{pipeline_id}/{coord_name}" + return None + + +def actor_call(endpoint: str, method: str, timeout: int = 60) -> str: + return run_quiet([*IRIS_CMD, "actor", "call", endpoint, method], timeout=timeout) + + +def parse_status(raw: str) -> dict: + m = re.search( + r"stage='([^']*)'.*?completed=(\d+).*?total=(\d+).*?retries=(\d+)" + r".*?in_flight=(\d+).*?queue_depth=(\d+).*?done=(\w+).*?fatal_error=(\w+)", + raw, + ) + if not m: + return {"error": raw[:200]} + return { + "stage": m.group(1), + "completed": int(m.group(2)), + "total": int(m.group(3)), + "retries": int(m.group(4)), + "in_flight": int(m.group(5)), + "queued": int(m.group(6)), + "done": m.group(7), + "fatal": m.group(8), + "busy": raw.count("'state': 'busy'"), + "idle": raw.count("'state': 'idle'"), + "dead": raw.count("'state': 'dead'"), + "ready": raw.count("'state': 'ready'"), + } + + +def parse_counters(raw: str) -> dict: + try: + return json.loads(raw) + except json.JSONDecodeError: + return {"error": raw[:200]} + + +def fmt_count(n: int) -> str: + if n >= 1_000_000_000: + return f"{n / 1e9:.1f}B" + if n >= 1_000_000: + return f"{n / 1e6:.0f}M" + if n >= 1_000: + return f"{n / 1e3:.0f}K" + return str(n) + + +def print_status(): + now = datetime.now(timezone.utc).strftime("%H:%M:%S UTC") + print(f"\n{'=' * 72}") + print(f" Benchmark Monitor — {now}") + print(f" Jobs: {len(JOBS)} (main vs branch x 10%/full)") + print(f"{'=' * 72}\n") + + all_done = True + for label, job_id in JOBS.items(): + state = get_job_state(job_id) + print(f" [{label}] {job_id}") + print(f" State: {state}") + + if state not in ("running",): + if state not in ("succeeded",): + all_done = False + else: + print(" ✓ Completed") + print() + continue + + all_done = False + coord = discover_coord_actor(job_id) + if not coord: 
+ print(" (coordinator not yet discoverable)") + print() + continue + + try: + status = parse_status(actor_call(coord, "get_status", timeout=90)) + except Exception as e: + print(f" (status query failed: {e})") + print() + continue + + if "error" in status: + print(f" Status: {status['error']}") + print() + continue + + pct = status["completed"] / status["total"] * 100 if status["total"] else 0 + print(f" Stage: {status['stage']}") + print(f" Progress: {status['completed']}/{status['total']} ({pct:.0f}%)") + print(f" In-flight: {status['in_flight']} Queued: {status['queued']} Retries: {status['retries']}") + print(f" Workers: {status['busy']} busy, {status['idle']} idle, {status['dead']} dead") + + try: + counters = parse_counters(actor_call(coord, "get_counters")) + if "error" not in counters: + counter_str = ", ".join(f"{k}={fmt_count(v)}" for k, v in counters.items()) + print(f" Counters: {counter_str}") + except Exception: + pass + + print() + + if all_done: + print(" *** ALL JOBS FINISHED ***\n") + return all_done + + +def main(): + loop = "--loop" in sys.argv + if loop: + while True: + done = print_status() + if done: + break + print(" (next check in 10m, Ctrl-C to stop)\n") + time.sleep(600) + else: + print_status() + + +if __name__ == "__main__": + main() diff --git a/experiments/dedup/nemotron_1slice_fuzzy.py b/experiments/dedup/nemotron_1slice_fuzzy.py new file mode 100644 index 0000000000..c12f0f18bf --- /dev/null +++ b/experiments/dedup/nemotron_1slice_fuzzy.py @@ -0,0 +1,63 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Fuzzy dedup on a single CC crawl slice of Nemotron-CC (quality=high). + +Runs the full fuzzy dedup pipeline (MinHash LSH -> connected components -> +dedup tagging) on CC-MAIN-2013-20 (~43 files, ~15 GB compressed). Use this +to validate the Arrow scatter/reduce optimization on real data. 
+ +Usage: + # Submit as an Iris job (requires cluster connection): + uv run lib/marin/src/marin/run/ray_run.py -- python experiments/dedup/nemotron_1slice_fuzzy.py + + # Or run directly if gcloud auth is configured: + uv run python experiments/dedup/nemotron_1slice_fuzzy.py +""" + +import logging + +from fray.v2 import ResourceConfig +from rigging.log_setup import configure_logging +from rigging.filesystem import marin_temp_bucket + +from marin.execution.step_runner import StepRunner +from marin.execution.step_spec import StepSpec +from marin.processing.classification.deduplication.fuzzy import dedup_fuzzy_document +from marin.processing.classification.deduplication.dedup_commons import _collect_input_files, DEFAULT_FILETYPES + +logger = logging.getLogger(__name__) + +NEMOTRON_HIGH = "gs://marin-eu-west4/raw/nemotro-cc-eeb783/contrib/Nemotron/Nemotron-CC/data-jsonl/quality=high" + +# Single CC crawl slice prefix +SLICE_PREFIX = "CC-MAIN-2013-20" + + +def _collect_slice_files() -> list[str]: + """Collect only files matching SLICE_PREFIX from the quality=high directory.""" + all_files = _collect_input_files(input_paths=NEMOTRON_HIGH, filetypes=DEFAULT_FILETYPES) + slice_files = [f for f in all_files if SLICE_PREFIX in f] + logger.info("Selected %d files for slice %s (out of %d total)", len(slice_files), SLICE_PREFIX, len(all_files)) + return slice_files + + +def build_steps() -> list[StepSpec]: + slice_files = _collect_slice_files() + + dedup_step = StepSpec( + name="fuzzy_dedup_nemotron_1slice_rust_arrow", + output_path_prefix=marin_temp_bucket(ttl_days=1, prefix="arrow-scatter-bench-fast"), + fn=lambda op: dedup_fuzzy_document( + input_paths=slice_files, + output_path=op, + max_parallelism=32, + worker_resources=ResourceConfig(cpu=5, ram="32g", disk="5g"), + ), + ) + return [dedup_step] + + +if __name__ == "__main__": + configure_logging(logging.INFO) + StepRunner().run(build_steps()) diff --git a/experiments/dedup/nemotron_1split_exact.py 
b/experiments/dedup/nemotron_1split_exact.py new file mode 100644 index 0000000000..17c416f753 --- /dev/null +++ b/experiments/dedup/nemotron_1split_exact.py @@ -0,0 +1,61 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Exact paragraph dedup on quality=high of Nemotron-CC. + +Runs the full exact paragraph dedup pipeline on the quality=high split. +Faster wall-time than fuzzy dedup while still exercising shuffle at scale. + +Usage: + uv run iris --config=lib/iris/examples/marin.yaml job run -- python experiments/dedup/nemotron_1split_exact.py + MAX_FILES=1000 uv run iris ... -- python experiments/dedup/nemotron_1split_exact.py +""" + +import logging +import os + +from rigging.log_setup import configure_logging +from rigging.filesystem import marin_temp_bucket + +from marin.execution.step_runner import StepRunner +from marin.execution.step_spec import StepSpec +from marin.processing.classification.deduplication.exact import dedup_exact_paragraph + +logger = logging.getLogger(__name__) + +NEMOTRON_HIGH = "gs://marin-eu-west4/raw/nemotro-cc-eeb783/contrib/Nemotron/Nemotron-CC/data-jsonl/quality=high" + +OUTPUT_PREFIX = os.environ.get("OUTPUT_PREFIX", "arrow-scatter-exact-bench-fast") +MAX_FILES = int(os.environ.get("MAX_FILES", "0")) # 0 = all files + + +def _maybe_truncate_inputs(input_path: str, max_files: int) -> str | list[str]: + """If max_files > 0, glob the input path and return a truncated file list.""" + if max_files <= 0: + return input_path + from marin.utils import fsspec_glob + + files = sorted(fsspec_glob(f"{input_path.rstrip('/')}/**/*.{{jsonl.gz,jsonl,json.gz,json,parquet,vortex}}")) + truncated = files[:max_files] + logger.info("Truncated input to %d / %d files (max_files=%d)", len(truncated), len(files), max_files) + return truncated + + +def build_steps() -> list[StepSpec]: + input_paths = _maybe_truncate_inputs(NEMOTRON_HIGH, MAX_FILES) + + dedup_step = StepSpec( + name="exact_dedup_nemotron_high_arrow", + 
output_path_prefix=marin_temp_bucket(ttl_days=1, prefix=OUTPUT_PREFIX), + fn=lambda op: dedup_exact_paragraph( + input_paths=input_paths, + output_path=op, + max_parallelism=2048, + ), + ) + return [dedup_step] + + +if __name__ == "__main__": + configure_logging(logging.INFO) + StepRunner().run(build_steps()) diff --git a/lib/marin/src/marin/processing/classification/deduplication/exact.py b/lib/marin/src/marin/processing/classification/deduplication/exact.py index 8779a806f3..a18fef0a78 100644 --- a/lib/marin/src/marin/processing/classification/deduplication/exact.py +++ b/lib/marin/src/marin/processing/classification/deduplication/exact.py @@ -170,11 +170,13 @@ def _flat_map_paragraph_hashes(paths: list[str]) -> Iterator[dict]: # NOTE: selecting the canonical record is deterministic via this sort sort_by=lambda record: record["id"], reducer=annotate_dups, + max_hot_shard_splits=8, ) .group_by( lambda r: r["file_idx"], sort_by=lambda r: r["id"], reducer=aggregate_and_write_to_corresponding_files, + max_hot_shard_splits=8, ), verbose=True, ), @@ -256,8 +258,14 @@ def _flat_map_document_hashes(paths: list[str]) -> Iterator[dict]: # NOTE: selecting the canonical record is deterministic via this sort sort_by=lambda record: record["id"], reducer=annotate_dups, + max_hot_shard_splits=8, ) - .group_by(lambda r: r["file_idx"], sort_by=lambda r: r["id"], reducer=aggregate_and_write), + .group_by( + lambda r: r["file_idx"], + sort_by=lambda r: r["id"], + reducer=aggregate_and_write, + max_hot_shard_splits=8, + ), verbose=True, ), ) diff --git a/lib/marin/src/marin/processing/classification/deduplication/fuzzy.py b/lib/marin/src/marin/processing/classification/deduplication/fuzzy.py index a101880158..ebb7c7c5d9 100644 --- a/lib/marin/src/marin/processing/classification/deduplication/fuzzy.py +++ b/lib/marin/src/marin/processing/classification/deduplication/fuzzy.py @@ -95,7 +95,11 @@ def compute_minhash_lsh_batches(batch: pa.RecordBatch) -> Iterator[dict]: doc_id_val = 
doc_id.as_py() for b in doc_buckets.as_py(): counters.increment("minhash/buckets") - yield {"bucket": str(b), "id": doc_id_val} + # Reinterpret u64 as signed int64 so Arrow infers int64 instead of + # failing on values >= 2^63. The bucket is only a grouping key so + # the sign bit doesn't matter. + bucket = b if b < (1 << 63) else b - (1 << 64) + yield {"bucket": bucket, "id": doc_id_val} ctx = ZephyrContext( name="fuzzy-dedup", diff --git a/lib/zephyr/src/zephyr/dataset.py b/lib/zephyr/src/zephyr/dataset.py index a936a3f64f..11708f6072 100644 --- a/lib/zephyr/src/zephyr/dataset.py +++ b/lib/zephyr/src/zephyr/dataset.py @@ -233,6 +233,7 @@ class GroupByOp: num_output_shards: int | None = None # None = auto-detect from current shard count sort_fn: Callable | None = None # Optional secondary sort within each group combiner_fn: Callable | None = None # Optional local pre-aggregation during scatter + max_hot_shard_splits: int = 0 # 0 = disabled; >0 = max sub-tasks per hot shard def __repr__(self): return f"GroupByOp(key={_get_fn_name(self.key_fn)})" @@ -757,6 +758,7 @@ def group_by( sort_by: Callable[[T], Any] | None = None, num_output_shards: int | None = None, combiner: Callable[[K, Iterator[T]], Iterator[T]] | None = None, + max_hot_shard_splits: int = 0, ) -> Dataset[R]: ... @overload @@ -768,6 +770,7 @@ def group_by( sort_by: Callable[[T], Any] | None = None, num_output_shards: int | None = None, combiner: Callable[[K, Iterator[T]], Iterator[T]] | None = None, + max_hot_shard_splits: int = 0, ) -> Dataset[R]: ... def group_by( @@ -778,6 +781,7 @@ def group_by( sort_by: Callable[[T], Any] | None = None, num_output_shards: int | None = None, combiner: Callable[[K, Iterator[T]], Iterator[T]] | None = None, + max_hot_shard_splits: int = 0, ) -> Dataset[R]: """Group items by key and apply reducer function. 
@@ -824,7 +828,17 @@ def group_by( """ return Dataset( self.source, - [*self.operations, GroupByOp(key, reducer, num_output_shards, sort_fn=sort_by, combiner_fn=combiner)], + [ + *self.operations, + GroupByOp( + key, + reducer, + num_output_shards, + sort_fn=sort_by, + combiner_fn=combiner, + max_hot_shard_splits=max_hot_shard_splits, + ), + ], ) def deduplicate(self, key: Callable[[T], object], num_output_shards: int | None = None) -> Dataset[T]: diff --git a/lib/zephyr/src/zephyr/execution.py b/lib/zephyr/src/zephyr/execution.py index 0aba9e04f3..805ead7ceb 100644 --- a/lib/zephyr/src/zephyr/execution.py +++ b/lib/zephyr/src/zephyr/execution.py @@ -17,7 +17,6 @@ import itertools import logging import os -import pickle import re from datetime import datetime, timezone import threading @@ -44,6 +43,7 @@ Join, PhysicalOp, PhysicalPlan, + Reduce, Scatter, Shard, SourceItem, @@ -57,45 +57,65 @@ logger = logging.getLogger(__name__) +_PARQUET_CHUNK_VALUE_COL = "_zephyr_value" + + @dataclass(frozen=True) -class PickleDiskChunk: - """Reference to a pickle chunk stored on disk. +class ParquetDiskChunk: + """Reference to a Parquet chunk stored on disk. Each write goes to a UUID-unique path to avoid collisions when multiple workers race on the same shard. No coordinator-side rename is needed; the winning result's paths are used directly and the entire execution directory is cleaned up after the pipeline completes. + + Items that are dicts are stored as Arrow columns directly. Non-dict items + (scalars, frozensets, etc.) are wrapped in a ``_zephyr_value`` column via + cloudpickle so that arbitrary Python objects can round-trip through Parquet. """ path: str count: int + wrapped: bool = False def __iter__(self) -> Iterator: return iter(self.read()) @classmethod - def write(cls, path: str, data: list) -> PickleDiskChunk: - """Write *data* to a UUID-unique path derived from *path*. 
+ def write(cls, path: str, data: list) -> ParquetDiskChunk: + """Write *data* as a Parquet file at a UUID-unique path derived from *path*.""" + import pyarrow.parquet as pq - The UUID suffix avoids collisions when multiple workers race on - the same shard. The resulting path is used directly for reads — - no rename step is required. - """ from zephyr.writers import unique_temp_path ensure_parent_dir(path) data = list(data) count = len(data) - unique_path = unique_temp_path(path) - with open_url(unique_path, "wb") as f: - pickle.dump(data, f) - return cls(path=unique_path, count=count) + + wrapped = False + if not data or not isinstance(data[0], dict): + wrapped = True + else: + try: + table = pa.Table.from_pylist(data) + except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError): + wrapped = True + + if wrapped: + table = pa.table({_PARQUET_CHUNK_VALUE_COL: [cloudpickle.dumps(item) for item in data]}) + pq.write_table(table, unique_path, compression="zstd") + return cls(path=unique_path, count=count, wrapped=wrapped) def read(self) -> list: - """Load chunk data from disk.""" - with open_url(self.path, "rb") as f: - return pickle.load(f) + import pickle + + import pyarrow.parquet as pq + + table = pq.read_table(self.path) + if _PARQUET_CHUNK_VALUE_COL in table.column_names: + return [pickle.loads(b) for b in table.column(_PARQUET_CHUNK_VALUE_COL).to_pylist()] + return table.to_pylist() # --------------------------------------------------------------------------- @@ -103,12 +123,15 @@ def read(self) -> list: # --------------------------------------------------------------------------- from zephyr.shuffle import ( # noqa: E402 + KeyRange, # noqa: F401 — re-exported for callers ListShard, MemChunk, ScatterParquetIterator, # noqa: F401 — re-exported for external callers ScatterShard, # noqa: F401 — re-exported for plan.py and external callers _build_scatter_shard_from_manifest, # noqa: F401 — re-exported for plan.py - _make_envelope, + 
_compute_key_range_boundaries, + _read_scatter_manifest, + make_envelope_batch, _write_parquet_scatter, _write_scatter_manifest, _SCATTER_MANIFEST_NAME, @@ -141,7 +164,7 @@ class TaskResult: """Result of a single worker task. Always contains a ListShard. For non-scatter stages, refs are - PickleDiskChunks. For scatter stages, refs contain file paths + ParquetDiskChunks. For scatter stages, refs contain file paths (the actual metadata lives in ``.scatter_meta`` sidecar files read lazily by reducers). """ @@ -177,16 +200,12 @@ def _cleanup_execution(prefix: str, execution_id: str) -> None: logger.info(f"Cleaned up execution directory {exec_dir} in {elapsed:.1f}s") -def _write_pickle_chunks( +def _write_parquet_chunks( items: Iterator, source_shard: int, chunk_path_fn: Callable[[int], str], ) -> ListShard: - """Batch a plain item stream into pickle chunk files. - - Returns a ListShard containing PickleDiskChunk references. - """ - # TODO: make chunk_size configurable per writer + """Batch a plain item stream into Parquet chunk files.""" chunk_size = 100_000 chunks: list[Iterable] = [] batch: list = [] @@ -195,20 +214,20 @@ def _write_pickle_chunks( for item in items: batch.append(item) if chunk_size > 0 and len(batch) >= chunk_size: - chunk_ref = PickleDiskChunk.write(chunk_path_fn(pidx), batch) + chunk_ref = ParquetDiskChunk.write(chunk_path_fn(pidx), batch) chunks.append(chunk_ref) pidx += 1 batch = [] if pidx % 10 == 0: logger.info( - "[shard %d] Wrote %d pickle chunks so far (latest: %d items)", + "[shard %d] Wrote %d parquet chunks so far (latest: %d items)", source_shard, pidx, chunk_ref.count, ) if batch: - chunks.append(PickleDiskChunk.write(chunk_path_fn(pidx), batch)) + chunks.append(ParquetDiskChunk.write(chunk_path_fn(pidx), batch)) return ListShard(refs=chunks) @@ -227,7 +246,7 @@ def _write_stage_output( wrapping and ``.scatter_meta`` sidecars. Returns TaskResult with compact scatter metadata. 
- For non-scatter stages, batches items into pickle chunk files. Returns + For non-scatter stages, batches items into Parquet chunk files. Returns TaskResult with a ListShard. """ if scatter_op is not None: @@ -240,8 +259,8 @@ def _write_stage_output( use_pickle_envelope = False try: - test_envelope = _make_envelope([first_item], 0, 0) - pa.RecordBatch.from_pylist(test_envelope) + test_key = scatter_op.key_fn(first_item) + make_envelope_batch([first_item], 0, 0, key_values=[test_key], sort_values=None, pickled=False) logger.info("Using Parquet for scatter serialization for shard %d", source_shard) except Exception: use_pickle_envelope = True @@ -265,9 +284,9 @@ def _write_stage_output( return TaskResult(shard=shard) def chunk_path_fn(idx: int) -> str: - return f"{stage_dir}/shard-{shard_idx:04d}/chunk-{idx:04d}.pkl" + return f"{stage_dir}/shard-{shard_idx:04d}/chunk-{idx:04d}.parquet" - return TaskResult(shard=_write_pickle_chunks(stage_gen, source_shard, chunk_path_fn)) + return TaskResult(shard=_write_parquet_chunks(stage_gen, source_shard, chunk_path_fn)) class WorkerState(enum.Enum): @@ -288,6 +307,8 @@ class ShardTask: operations: list[PhysicalOp] stage_name: str = "output" aux_shards: dict[int, Shard] | None = None + logical_shard_idx: int | None = None + key_range: Any | None = None class ZephyrWorkerError(RuntimeError): @@ -791,6 +812,8 @@ def run_pipeline( default=-1, ) + last_scatter_manifest: str | None = None + for stage_idx, stage in enumerate(plan.stages): stage_label = f"stage{stage_idx}-{stage.stage_name(max_length=40)}" @@ -801,8 +824,24 @@ def run_pipeline( # Compute aux data for joins aux_per_shard = self._compute_join_aux(stage.operations, shards, stage_idx) - # Build and submit tasks - tasks = _compute_tasks_from_shards(shards, stage, aux_per_shard, stage_name=stage_label) + # Check if this reduce stage should use hot shard splitting + reduce_op = next((op for op in stage.operations if isinstance(op, Reduce)), None) + split_map: dict[int, 
list[int]] = {} + task_to_logical: dict[int, int] = {} + + if reduce_op is not None and reduce_op.max_hot_shard_splits > 0 and last_scatter_manifest is not None: + tasks, split_map = _expand_hot_shard_tasks( + shards, + stage, + stage_label, + last_scatter_manifest, + max_splits=reduce_op.max_hot_shard_splits, + aux_per_shard=aux_per_shard, + ) + task_to_logical = {t.shard_idx: t.logical_shard_idx for t in tasks} + else: + tasks = _compute_tasks_from_shards(shards, stage, aux_per_shard, stage_name=stage_label) + output_stage_name = tasks[0].stage_name if tasks else stage_label logger.info("[%s] Starting stage %s with %d tasks", self._execution_id, stage_label, len(tasks)) self._start_stage(stage_label, tasks, is_last_stage=(stage_idx == last_worker_stage_idx)) @@ -812,15 +851,27 @@ def run_pipeline( # Collect and regroup results for next stage result_refs = self._collect_results() + + if split_map: + result_refs = _merge_split_results(result_refs, split_map, task_to_logical) + stage_is_scatter = any(isinstance(op, Scatter) for op in stage.operations) + scatter_manifest_dir = f"{self._chunk_prefix}/{self._execution_id}/{output_stage_name}" shards = _regroup_result_refs( result_refs, len(shards), output_shard_count=stage.output_shards, is_scatter=stage_is_scatter, - scatter_manifest_dir=f"{self._chunk_prefix}/{self._execution_id}/{output_stage_name}", + scatter_manifest_dir=scatter_manifest_dir, ) + # Track manifest path for hot shard detection on the next reduce stage + if stage_is_scatter: + manifest_name = _SCATTER_MANIFEST_NAME + last_scatter_manifest = f"{scatter_manifest_dir}/{manifest_name}" + else: + last_scatter_manifest = None + # Flatten final results flat_result = [] for shard in shards: @@ -971,7 +1022,7 @@ def __init__(self, coordinator_handle: ActorHandle): self._host_shutdown_event = actor_ctx.shutdown_event self._worker_id = f"{actor_ctx.group_name}-{actor_ctx.index}" - # Register with coordinator - wait is not stricly necessary, but it reduces the 
complexity + # Register with coordinator - wait is not strictly necessary, but it reduces the complexity self._coordinator.register_worker.remote(self._worker_id, actor_ctx.handle).result(timeout=60.0) # Start polling in a background thread @@ -1133,7 +1184,7 @@ def _poll_loop(self, coordinator: ActorHandle) -> None: logger.info("[%s] Executing task for shard %d (attempt %d)", self._worker_id, task.shard_idx, attempt) try: t_0 = time.monotonic() - result = self._execute_shard(task, config) + result = self._execute_shard(task, config, attempt) logger.info( "[%s] Task for shard %d completed in %.2f seconds", self._worker_id, @@ -1162,7 +1213,7 @@ def _poll_loop(self, coordinator: ActorHandle) -> None: "".join(traceback.format_exc()), ).result() - def _execute_shard(self, task: ShardTask, config: dict) -> TaskResult: + def _execute_shard(self, task: ShardTask, config: dict, attempt: int = 0) -> TaskResult: """Execute a stage's operations on a single shard. Returns list[TaskResult]. @@ -1190,10 +1241,12 @@ def _execute_shard(self, task: ShardTask, config: dict) -> TaskResult: shard_idx=task.shard_idx, total_shards=task.total_shards, aux_shards=task.aux_shards, + logical_shard_idx=task.logical_shard_idx, + key_range=task.key_range, ) stage_dir = f"{self._chunk_prefix}/{self._execution_id}/{task.stage_name}" - external_sort_dir = f"{stage_dir}-external-sort/shard-{task.shard_idx:04d}" + external_sort_dir = f"{stage_dir}-external-sort/shard-{task.shard_idx:04d}/attempt-{attempt}" scatter_op = next((op for op in task.operations if isinstance(op, Scatter)), None) result = _write_stage_output( @@ -1252,6 +1305,165 @@ def _regroup_result_refs( return [result_refs[idx].shard if idx in result_refs else ListShard(refs=[]) for idx in range(num_output)] +# --------------------------------------------------------------------------- +# Hot shard splitting +# --------------------------------------------------------------------------- + + +def _detect_hot_shards( + manifest_path: str, 
+ num_output_shards: int, + max_splits: int, + split_threshold: float = 3.0, +) -> dict[int, int]: + """Detect hot shards from scatter manifest. + + Returns {shard_idx: num_splits} for shards whose total chunk count exceeds + split_threshold * median. Only returns entries where splitting is worthwhile + (num_splits >= 2). + """ + entries = _read_scatter_manifest(manifest_path) + + chunk_totals: dict[int, int] = defaultdict(int) + for entry in entries: + for shard_str, count in entry["chunk_counts"].items(): + chunk_totals[int(shard_str)] += count + + if not chunk_totals: + return {} + + counts = sorted(chunk_totals.values()) + n = len(counts) + median = counts[n // 2] if n % 2 == 1 else (counts[n // 2 - 1] + counts[n // 2]) / 2 + if median <= 0: + median = 1 + + threshold = split_threshold * median + hot: dict[int, int] = {} + for shard_idx, total in chunk_totals.items(): + if shard_idx >= num_output_shards: + continue + if total > threshold: + splits = min(max_splits, max(2, int(total / median))) + hot[shard_idx] = splits + + if hot: + logger.info( + "Hot shard detection: median_chunks=%.0f, threshold=%.0f, hot_shards=%s", + median, + threshold, + {k: f"{chunk_totals[k]} chunks -> {v} splits" for k, v in hot.items()}, + ) + return hot + + +def _expand_hot_shard_tasks( + shards: list, + stage, + stage_label: str, + manifest_path: str, + max_splits: int, + aux_per_shard: list | None = None, +) -> tuple[list, dict[int, list[int]]]: + """Detect hot shards and expand into sub-tasks with key ranges. 
+ + Returns: + tasks: expanded ShardTask list + split_map: {logical_shard_idx: [task_idx, ...]} for split shards + """ + hot = _detect_hot_shards(manifest_path, len(shards), max_splits) + + if not hot: + tasks = _compute_tasks_from_shards(shards, stage, aux_per_shard, stage_name=stage_label) + return tasks, {} + + tasks: list[ShardTask] = [] + split_map: dict[int, list[int]] = {} + task_idx = 0 + raw_name = stage_label or stage.stage_name(max_length=60) + output_stage_name = re.sub(r"[^a-zA-Z0-9_.-]+", "-", raw_name).strip("-") + + for shard_idx, shard in enumerate(shards): + aux_shards = None + if aux_per_shard and aux_per_shard[shard_idx]: + aux_shards = aux_per_shard[shard_idx] + + if shard_idx not in hot: + tasks.append( + ShardTask( + shard_idx=task_idx, + total_shards=0, # patched below + shard=shard, + operations=stage.operations, + stage_name=output_stage_name, + aux_shards=aux_shards, + logical_shard_idx=shard_idx, + ) + ) + task_idx += 1 + else: + num_splits = hot[shard_idx] + key_ranges = _compute_key_range_boundaries(manifest_path, shard_idx, num_splits) + sub_task_ids = [] + for kr in key_ranges: + tasks.append( + ShardTask( + shard_idx=task_idx, + total_shards=0, # patched below + shard=shard, + operations=stage.operations, + stage_name=output_stage_name, + aux_shards=aux_shards, + logical_shard_idx=shard_idx, + key_range=kr, + ) + ) + sub_task_ids.append(task_idx) + task_idx += 1 + split_map[shard_idx] = sub_task_ids + + total = len(tasks) + for t in tasks: + t.total_shards = total + + logger.info( + "Hot shard expansion: %d logical shards -> %d tasks (%d shards split)", + len(shards), + total, + len(split_map), + ) + return tasks, split_map + + +def _merge_split_results( + result_refs: dict[int, TaskResult], + split_map: dict[int, list[int]], + task_to_logical: dict[int, int], +) -> dict[int, TaskResult]: + """Merge sub-task results back into logical shard results. + + For each split shard, concatenates the ListShard refs from all sub-tasks. 
+ Non-split tasks are re-keyed from their task_idx to their logical shard. + """ + merged: dict[int, TaskResult] = {} + + # Handle split shards + for logical_idx, sub_task_ids in split_map.items(): + combined_refs = [] + for tid in sub_task_ids: + if tid in result_refs: + combined_refs.extend(result_refs[tid].shard.refs) + merged[logical_idx] = TaskResult(shard=ListShard(refs=combined_refs)) + + # Handle non-split tasks + for tid, result in result_refs.items(): + logical = task_to_logical.get(tid) + if logical is not None and logical not in split_map: + merged[logical] = result + + return merged + + # --------------------------------------------------------------------------- # Coordinator-as-Job infrastructure # --------------------------------------------------------------------------- diff --git a/lib/zephyr/src/zephyr/external_sort.py b/lib/zephyr/src/zephyr/external_sort.py index 9e9a30efa0..4fc515607e 100644 --- a/lib/zephyr/src/zephyr/external_sort.py +++ b/lib/zephyr/src/zephyr/external_sort.py @@ -1,163 +1,374 @@ # Copyright The Marin Authors # SPDX-License-Identifier: Apache-2.0 -"""Two-pass external merge sort for large k-way merges. +"""Two-pass external merge sort using Parquet spill files. -Used by the reduce stage when the number of sorted chunk iterators exceeds -``EXTERNAL_SORT_FAN_IN``, to avoid opening O(k) scanners simultaneously and -exhausting worker memory. +Used by the reduce stage when scatter chunks don't fit in memory. -Pass 1: batch the k iterators into groups of EXTERNAL_SORT_FAN_IN, merge each -group with heapq.merge, and spill items in batches of ``_WRITE_BATCH_SIZE`` to -a zstd-compressed pickle run file under -``{external_sort_dir}/run-{i:04d}.pkl.zst``. Items are streamed to disk -rather than accumulated in a list, so peak memory per batch is bounded by the -number of open iterators rather than their total item count. +Each scatter chunk is already sorted by (sort_key, sort_secondary). 
+Pass 1 streams batches of pre-sorted chunk tables through a k-way merge +(no re-sort needed), writing merged runs as Parquet spill files. -Pass 2: heapq.merge over the (much smaller) set of run file iterators. Each -iterator reads one batch at a time and yields items one-by-one; the read batch -size is computed from the cgroup memory limit so that all concurrent batches -together stay within ``_READ_MEMORY_FRACTION`` of available memory. +Pass 2: streaming k-way merge over the sorted run files, reading one row +group at a time per run. -Run files are deleted after the final merge completes. +All buffer sizes are derived from the worker's memory budget, probed from +actual data sizes. """ import heapq +import itertools import logging -import pickle -from collections.abc import Callable, Iterator +from collections.abc import Iterator +from dataclasses import dataclass, field from itertools import islice import fsspec -import zstandard as zstd -from iris.env_resources import TaskResources -from rigging.filesystem import url_to_fs +import pyarrow as pa +import pyarrow.compute as pc +import pyarrow.parquet as pq +from iris.env_resources import TaskResources as _TaskResources + +from zephyr.spill_writer import SpillWriter, TableAccumulator logger = logging.getLogger(__name__) -# Maximum simultaneous chunk iterators per pass-1 batch. -EXTERNAL_SORT_FAN_IN = 500 +# Fraction of worker memory available for merge (pass 1 and pass 2 are +# sequential, so both use the same budget). +_MERGE_MEMORY_FRACTION = 0.5 + +# Target spill row group size in bytes. Each run holds one row group in +# memory during merge, so this controls per-run memory footprint. +_SPILL_ROW_GROUP_TARGET_BYTES = 8 * 1024 * 1024 # 8 MB + -# Items per pickle.dump in pass-1. Larger batches compress better (zstd -# dictionary spans the whole batch) and reduce per-call overhead. 
-_WRITE_BATCH_SIZE = 10_000 +@dataclass +class _MergeBudget: + """Derived budget parameters, computed once from probed data sizes.""" -# Fraction of container memory budgeted for pass-2 read buffers. -_READ_MEMORY_FRACTION = 0.25 + merge_budget_bytes: int + fan_in: int -def _safe_read_batch_size(n_runs: int, sample_run_path: str) -> int: - """Compute a pass-2 read batch size that fits within the memory budget. +def _compute_budget(chunk_bytes: int) -> _MergeBudget: + """Compute merge budget from probed chunk size and worker memory. - Probes the first batch from ``sample_run_path`` to estimate in-memory - bytes per item, then divides the memory budget by ``n_runs * item_bytes`` - so that all concurrent run-file buffers together stay within - ``_READ_MEMORY_FRACTION`` of available container memory. + fan_in = how many chunk tables we can hold in memory simultaneously + during the k-way merge. """ - dctx = zstd.ZstdDecompressor() - try: - with fsspec.open(sample_run_path, "rb") as raw_f: - with dctx.stream_reader(raw_f) as f: - sample_batch: list = pickle.load(f) - except Exception: - return _WRITE_BATCH_SIZE - - sample = sample_batch[:100] - if not sample: - return _WRITE_BATCH_SIZE - # pickle size x 3 approximates Python object overhead (dicts are ~3x larger - # in memory than their serialised form). 
- item_bytes = max(64, len(pickle.dumps(sample)) // len(sample) * 3) - - available = TaskResources.from_environment().memory_bytes - budget = int(available * _READ_MEMORY_FRACTION) - size = budget // max(1, n_runs * item_bytes) - result = max(100, min(size, _WRITE_BATCH_SIZE)) + memory_bytes = _TaskResources.from_environment().memory_bytes + if memory_bytes <= 0: + memory_bytes = 16 * 1024**3 + merge_budget = int(memory_bytes * _MERGE_MEMORY_FRACTION) + + fan_in = max(1, min(1000, merge_budget // max(chunk_bytes, 1))) + + budget = _MergeBudget(merge_budget_bytes=merge_budget, fan_in=fan_in) logger.info( - "External sort pass-2: %d runs x ~%d bytes/item, budget=%.1f GB -> read_batch_size=%d", - n_runs, - item_bytes, - budget / 1e9, - result, + "External sort budget: memory=%dMB, merge_budget=%dMB, fan_in=%d, chunk_bytes=%dMB", + memory_bytes // (1024 * 1024), + merge_budget // (1024 * 1024), + fan_in, + chunk_bytes // (1024 * 1024), ) - return result + return budget + + +def _write_spill_file(table: pa.Table, path: str) -> None: + """Write a sorted table as a Parquet file with byte-budgeted row groups.""" + with SpillWriter(path, table.schema, row_group_bytes=_SPILL_ROW_GROUP_TARGET_BYTES) as w: + w.write_table(table) + + +def _promote_to_large_string(table: pa.Table) -> pa.Table: + """Cast string/binary columns to large_string/large_binary to avoid 2GB offset overflow on concat.""" + new_fields = [] + needs_cast = False + for f in table.schema: + if f.type == pa.string(): + new_fields.append(f.with_type(pa.large_string())) + needs_cast = True + elif f.type == pa.binary(): + new_fields.append(f.with_type(pa.large_binary())) + needs_cast = True + else: + new_fields.append(f) + if not needs_cast: + return table + return table.cast(pa.schema(new_fields)) + + +# --------------------------------------------------------------------------- +# Merge sources +# --------------------------------------------------------------------------- + + +@dataclass(order=True) +class 
_MergeEntry: + """Heap entry keyed by the sort value at the current cursor position.""" + + sort_value: tuple + source_idx: int = field(compare=True) + source: "_MergeSource" = field(compare=False, repr=False) + + +class _MergeSource: + """Abstract source for the k-way merge. Provides a cursor over sorted rows.""" + + sort_key_columns: list[str] + table: pa.Table | None + cursor: int + + def advance(self) -> bool: + raise NotImplementedError + + def current_sort_value(self) -> tuple: + return tuple(self.table.column(c)[self.cursor].as_py() for c in self.sort_key_columns) + + def remaining(self) -> int: + return len(self.table) - self.cursor + + def take(self, count: int) -> pa.Table: + sliced = self.table.slice(self.cursor, count) + self.cursor += count + if self.cursor >= len(self.table): + self.advance() + return sliced + + def rows_le(self, threshold: tuple) -> int: + """Count rows from cursor whose sort key <= threshold. + + Exploits sorted data — uses vectorized comparison to find the first + row exceeding the threshold. 
+ """ + remaining_rows = self.remaining() + primary_col = self.table.column(self.sort_key_columns[0]).slice(self.cursor, remaining_rows) + + if len(self.sort_key_columns) == 1: + gt_mask = pc.greater(primary_col, pa.scalar(threshold[0], type=primary_col.type)) + gt_count = pc.sum(gt_mask).as_py() or 0 + return max(1, remaining_rows - gt_count) + + secondary_col = self.table.column(self.sort_key_columns[1]).slice(self.cursor, remaining_rows) + primary_gt = pc.greater(primary_col, pa.scalar(threshold[0], type=primary_col.type)) + primary_eq = pc.equal(primary_col, pa.scalar(threshold[0], type=primary_col.type)) + secondary_gt = pc.greater(secondary_col, pa.scalar(threshold[1], type=secondary_col.type)) + exceeds = pc.or_(primary_gt, pc.and_(primary_eq, secondary_gt)) + exceed_count = pc.sum(exceeds).as_py() or 0 + return max(1, remaining_rows - exceed_count) + + @property + def has_data(self) -> bool: + return self.table is not None + + +class _RunSource(_MergeSource): + """Read position within a single sorted Parquet run file.""" + + def __init__(self, idx: int, pf: pq.ParquetFile, sort_key_columns: list[str]): + self.idx = idx + self.pf = pf + self.sort_key_columns = sort_key_columns + self._rg_idx = 0 + self.table = None + self.cursor = 0 + + def advance(self) -> bool: + while self._rg_idx < self.pf.metadata.num_row_groups: + self.table = self.pf.read_row_group(self._rg_idx) + self._rg_idx += 1 + self.cursor = 0 + if len(self.table) > 0: + return True + self.table = None + return False + + +class _TableSource(_MergeSource): + """Merge source backed by a single in-memory Arrow table (one scatter chunk).""" + + def __init__(self, idx: int, table: pa.Table, sort_key_columns: list[str]): + self.idx = idx + self.sort_key_columns = sort_key_columns + self._loaded = table + self.table = None + self.cursor = 0 + + def advance(self) -> bool: + if self._loaded is not None: + self.table = self._loaded + self._loaded = None + self.cursor = 0 + return len(self.table) > 0 + 
self.table = None + return False + + +# --------------------------------------------------------------------------- +# K-way merge +# --------------------------------------------------------------------------- + + +def _streaming_k_way_merge( + sources: list[_MergeSource], + sort_keys: list[tuple[str, str]], + output_batch_bytes: int = _SPILL_ROW_GROUP_TARGET_BYTES, +) -> Iterator[pa.Table]: + """Streaming k-way merge over pre-sorted sources. + + Uses a min-heap to pick the source with the smallest current key, + and yields batches of sorted rows. + """ + if not sources: + return + + heap: list[_MergeEntry] = [] + for i, src in enumerate(sources): + heapq.heappush(heap, _MergeEntry(src.current_sort_value(), i, src)) + + accumulator = TableAccumulator(output_batch_bytes) + + while heap: + entry = heapq.heappop(heap) + winner = entry.source + + if heap: + next_key = heap[0].sort_value + take_count = winner.rows_le(next_key) + else: + take_count = winner.remaining() + + chunk = winner.take(take_count) + + if winner.has_data: + heapq.heappush(heap, _MergeEntry(winner.current_sort_value(), winner.idx, winner)) + + merged = accumulator.add(chunk) + if merged is not None: + yield merged + + remaining = accumulator.flush() + if remaining is not None: + yield remaining + + +# --------------------------------------------------------------------------- +# External sort entry point +# --------------------------------------------------------------------------- def external_sort_merge( - chunk_iterators_gen: Iterator[Iterator], # lazy — consumed in batches - merge_key: Callable, + chunk_tables_gen: Iterator[pa.Table], + sort_keys: list[tuple[str, str]], external_sort_dir: str, -) -> Iterator: - """Merge ``chunk_iterators_gen`` via a two-pass external sort. - - Args: - chunk_iterators_gen: Lazy iterator of sorted iterators (one per scatter chunk). - Consumed in batches of EXTERNAL_SORT_FAN_IN to avoid opening all file - handles simultaneously. 
- merge_key: Key function passed to heapq.merge. - external_sort_dir: GCS prefix for spill files, e.g. - ``gs://bucket/.../stage1-external-sort/shard-0042``. - - Yields: - Items in merged sort order. +) -> Iterator[pa.Table]: + """Two-pass external merge yielding sorted Arrow tables. + + Input chunk tables are assumed to be pre-sorted by the scatter writer. + + Pass 1: batch pre-sorted chunk tables into groups of fan_in, stream + through a k-way merge, and write merged runs as Parquet spill files. + No re-sort is needed — the merge is O(n) per batch. + + Pass 2: streaming k-way merge over the (much smaller) set of run files. """ - cctx = zstd.ZstdCompressor(level=3) + from zephyr.writers import ensure_parent_dir + + first = next(chunk_tables_gen, None) + if first is None: + return + budget = _compute_budget(first.nbytes) + + chunk_tables_gen = itertools.chain([first], chunk_tables_gen) + + sort_key_columns = [col for col, _ in sort_keys] run_paths: list[str] = [] batch_idx = 0 while True: - batch = list(islice(chunk_iterators_gen, EXTERNAL_SORT_FAN_IN)) - if not batch: + batch_tables = list(islice(chunk_tables_gen, budget.fan_in)) + if not batch_tables: break - run_path = f"{external_sort_dir}/run-{batch_idx:04d}.pkl.zst" - item_count = 0 - pending: list = [] - with fsspec.open(run_path, "wb") as raw_f: - with cctx.stream_writer(raw_f, closefd=False) as f: - for item in heapq.merge(*batch, key=merge_key): - pending.append(item) - if len(pending) >= _WRITE_BATCH_SIZE: - pickle.dump(pending, f, protocol=pickle.HIGHEST_PROTOCOL) - item_count += len(pending) - pending = [] - if pending: - pickle.dump(pending, f, protocol=pickle.HIGHEST_PROTOCOL) - item_count += len(pending) + + # Build merge sources from pre-sorted chunk tables + sources: list[_MergeSource] = [] + for i, t in enumerate(batch_tables): + t = _promote_to_large_string(t) + src = _TableSource(idx=i, table=t, sort_key_columns=sort_key_columns) + if src.advance(): + sources.append(src) + + if not sources: + 
continue + + run_path = f"{external_sort_dir}/run-{batch_idx:04d}.parquet" + ensure_parent_dir(run_path) + + # Stream k-way merge directly to Parquet spill file + merged_iter = _streaming_k_way_merge(sources, sort_keys) + first_merged = next(merged_iter, None) + if first_merged is None: + continue + + merged_rows = len(first_merged) + writer = SpillWriter(run_path, first_merged.schema, row_group_bytes=_SPILL_ROW_GROUP_TARGET_BYTES) + try: + writer.write_table(first_merged) + for merged_table in merged_iter: + writer.write_table(merged_table) + merged_rows += len(merged_table) + finally: + writer.close() + + # Free the batch tables now that they've been merged to disk + del batch_tables, sources + run_paths.append(run_path) logger.info( - "External sort: wrote run %d (%d items) to %s", + "External sort: wrote run %d (%d rows) to %s", batch_idx + 1, - item_count, + merged_rows, run_path, ) batch_idx += 1 - read_batch_size = _safe_read_batch_size(len(run_paths), run_paths[0]) if run_paths else _WRITE_BATCH_SIZE - - def _read_run(path: str) -> Iterator: - with fsspec.open(path, "rb") as raw_f: - with zstd.ZstdDecompressor().stream_reader(raw_f) as f: - while True: - try: - items: list = pickle.load(f) - # Yield in read_batch_size chunks and delete consumed - # items in-place so memory is released progressively - # even while the generator is suspended in heapq.merge. - while items: - chunk = items[:read_batch_size] - del items[:read_batch_size] - yield from chunk - except EOFError: - break - - run_iters = [_read_run(p) for p in run_paths] + if not run_paths: + return + + # Pass 2: verify merge memory fits in budget using actual Parquet metadata. 
+ num_runs = len(run_paths) + max_rg_bytes = 0 + for rp in run_paths: + meta = pq.read_metadata(rp) + for i in range(meta.num_row_groups): + max_rg_bytes = max(max_rg_bytes, meta.row_group(i).total_byte_size) + merge_estimate = num_runs * max_rg_bytes + if merge_estimate > budget.merge_budget_bytes: + logger.warning( + "External sort merge may exceed budget: %d runs x %.0fMB/rg = %.0fMB > %dMB budget", + num_runs, + max_rg_bytes / (1024 * 1024), + merge_estimate / (1024 * 1024), + budget.merge_budget_bytes // (1024 * 1024), + ) + try: - yield from heapq.merge(*run_iters, key=merge_key) + if len(run_paths) == 1: + pf = pq.ParquetFile(run_paths[0]) + for i in range(pf.metadata.num_row_groups): + yield pf.read_row_group(i) + else: + run_sources: list[_MergeSource] = [] + for i, path in enumerate(run_paths): + src = _RunSource(idx=i, pf=pq.ParquetFile(path), sort_key_columns=sort_key_columns) + if src.advance(): + run_sources.append(src) + yield from _streaming_k_way_merge(run_sources, sort_keys) finally: fs, _ = fsspec.core.url_to_fs(external_sort_dir) for path in run_paths: try: - _, fs_path = url_to_fs(path) + _, fs_path = fsspec.core.url_to_fs(path) fs.rm(fs_path) except Exception: pass diff --git a/lib/zephyr/src/zephyr/plan.py b/lib/zephyr/src/zephyr/plan.py index c6b2842926..62c78ae84f 100644 --- a/lib/zephyr/src/zephyr/plan.py +++ b/lib/zephyr/src/zephyr/plan.py @@ -22,10 +22,12 @@ from typing import Any, Protocol import msgspec +import pyarrow as pa +import pyarrow.compute as pc from iris.env_resources import TaskResources as _TaskResources from rigging.filesystem import url_to_fs -from zephyr.external_sort import EXTERNAL_SORT_FAN_IN, external_sort_merge +from zephyr.external_sort import _promote_to_large_string, external_sort_merge from zephyr.dataset import ( Dataset, @@ -123,9 +125,8 @@ class Scatter: class Reduce: """Merge sorted chunks and reduce per key.""" - key_fn: Callable[[Any], Any] reducer_fn: Callable[[Any, Iterator], Any] - sort_fn: 
Callable[[Any], Any] | None = None # Must match Scatter's sort_fn + max_hot_shard_splits: int = 0 @dataclass @@ -176,19 +177,142 @@ def _flatmap_gen(stream: Iterator, fn: Callable) -> Iterator: yield from fn(item) -def _reduce_gen( +def _find_group_boundaries(key_col: pa.ChunkedArray) -> Iterator[tuple[int, int, Any]]: + """Yield (start, end, key_value) for each contiguous group in a sorted key column. + + Uses Arrow compute to find boundaries vectorized instead of per-element + Python scalar extraction, which matters for high-cardinality keys. + """ + arr = key_col.combine_chunks() + n = len(arr) + if n == 0: + return + if n == 1: + yield (0, 1, arr[0].as_py()) + return + + # Vectorized boundary detection: compare adjacent elements + ne_mask = pc.not_equal(arr[:-1], arr[1:]) + boundary_indices = pc.add(pc.indices_nonzero(ne_mask), 1).to_pylist() + + prev = 0 + for idx in boundary_indices: + yield (prev, idx, arr[prev].as_py()) + prev = idx + yield (prev, n, arr[prev].as_py()) + + +def _arrow_merge_sorted_chunks(shard: Any) -> pa.Table: + """Concatenate all chunks and sort in Arrow. 
Returns sorted table.""" + from zephyr.shuffle import _ZEPHYR_SORT_KEY, _ZEPHYR_SORT_SECONDARY + + all_tables: list[pa.Table] = [] + for it in shard.iterators: + for table in it.get_chunk_tables(): + all_tables.append(table) + if not all_tables: + return pa.table({}) + combined = pa.concat_tables([_promote_to_large_string(t) for t in all_tables], promote_options="default") + sort_keys: list[tuple[str, str]] = [(_ZEPHYR_SORT_KEY, "ascending")] + if _ZEPHYR_SORT_SECONDARY in combined.column_names: + sort_keys.append((_ZEPHYR_SORT_SECONDARY, "ascending")) + indices = pc.sort_indices(combined, sort_keys=sort_keys) + return combined.take(indices) + + +def _arrow_reduce_gen( shard: Any, - key_fn: Callable, reducer_fn: Callable, - sort_fn: Callable | None = None, external_sort_dir: str | None = None, ) -> Iterator: + """Arrow-native reduce: sort in Arrow, group by sort key, unwrap items for reducer_fn. + + Handles both flat (dict) and pickled envelope formats via unwrap_items / is_zephyr_column. 
+ """ + from zephyr.shuffle import ( + ScatterShard, + _ZEPHYR_PAYLOAD, + _ZEPHYR_SORT_KEY, + _ZEPHYR_SORT_SECONDARY, + unwrap_items, + ) + + force_external = os.environ.get("ZEPHYR_FORCE_EXTERNAL_MERGE", "").lower() in ("1", "true", "yes") + use_external = ( + external_sort_dir is not None + and isinstance(shard, ScatterShard) + and (force_external or shard.needs_external_sort(_TaskResources.from_environment().memory_bytes)) + ) + + if use_external: + sort_keys: list[tuple[str, str]] = [(_ZEPHYR_SORT_KEY, "ascending")] + first_tables = list(islice((t for it in shard.iterators for t in it.get_chunk_tables()), 1)) + if first_tables and _ZEPHYR_SORT_SECONDARY in first_tables[0].column_names: + sort_keys.append((_ZEPHYR_SORT_SECONDARY, "ascending")) + + logger.info( + "Arrow external sort triggered for shard with %d iterators, spilling to %s", + sum(it.chunk_count for it in shard.iterators), + external_sort_dir, + ) + + def _chunk_tables() -> Iterator[pa.Table]: + for it in shard.iterators: + yield from it.get_chunk_tables() + + # Stream through the merge, grouping by sort key across batch boundaries. + # Only one batch + one group's accumulated rows are in memory at a time. 
+ is_gen = inspect.isgeneratorfunction(reducer_fn) + current_key = None + current_group_tables: list[pa.Table] = [] + + for batch_table in external_sort_merge(_chunk_tables(), sort_keys, external_sort_dir): + pickled = _ZEPHYR_PAYLOAD in batch_table.column_names + key_col = batch_table.column(_ZEPHYR_SORT_KEY) + + for start, end, key_value in _find_group_boundaries(key_col): + group_slice = batch_table.slice(start, end - start) + + if current_key is None: + current_key = key_value + current_group_tables = [group_slice] + elif key_value == current_key: + current_group_tables.append(group_slice) + else: + group_table = pa.concat_tables(current_group_tables, promote_options="default") + group_items = unwrap_items(group_table, pickled) + if is_gen: + yield from reducer_fn(current_key, iter(group_items)) + else: + yield reducer_fn(current_key, iter(group_items)) + current_key = key_value + current_group_tables = [group_slice] + + if current_group_tables: + group_table = pa.concat_tables(current_group_tables, promote_options="default") + pickled = _ZEPHYR_PAYLOAD in group_table.column_names + group_items = unwrap_items(group_table, pickled) + if is_gen: + yield from reducer_fn(current_key, iter(group_items)) + else: + yield reducer_fn(current_key, iter(group_items)) + return + + sorted_table = _arrow_merge_sorted_chunks(shard) + if len(sorted_table) == 0: + return + + key_col = sorted_table.column(_ZEPHYR_SORT_KEY) + pickled = _ZEPHYR_PAYLOAD in sorted_table.column_names + is_gen = inspect.isgeneratorfunction(reducer_fn) - for key, items_iter in _merge_sorted_chunks(shard, key_fn, sort_fn, external_sort_dir=external_sort_dir): + for start, end, key_value in _find_group_boundaries(key_col): + group_table = sorted_table.slice(start, end - start) + group_items = unwrap_items(group_table, pickled) if is_gen: - yield from reducer_fn(key, items_iter) + yield from reducer_fn(key_value, iter(group_items)) else: - yield reducer_fn(key, items_iter) + yield reducer_fn(key_value, 
iter(group_items)) def _select_gen(stream: Iterator, columns: tuple[str, ...]) -> Iterator: @@ -437,7 +561,7 @@ def _fuse_operations(operations: list) -> list[PhysicalStage]: output_shards=num_shards if num_shards > 0 else None, ) state.end_stage() - state.add_op(Reduce(key_fn=op.key_fn, reducer_fn=op.reducer_fn, sort_fn=op.sort_fn)) + state.add_op(Reduce(reducer_fn=op.reducer_fn, max_hot_shard_splits=op.max_hot_shard_splits)) elif isinstance(op, ReduceOp): state.add_op(Fold(fn=op.local_reducer)) @@ -584,64 +708,6 @@ def make_windows( yield window -def _merge_sorted_chunks( - shard: Shard, key_fn: Callable, sort_fn: Callable | None = None, external_sort_dir: str | None = None -) -> Iterator[tuple[object, Iterator]]: - """Merge sorted chunks using k-way merge, yielding (key, items_iterator) groups. - - Each chunk is assumed to be sorted by key (and optionally by sort_fn within key). - This function performs a k-way merge across all chunks and groups consecutive - items with the same key. - - Args: - shard: Shard containing sorted chunks (iterable of chunk lists) - key_fn: Function to extract grouping key from item - sort_fn: Optional secondary sort key. When provided, the merge uses - (key_fn, sort_fn) for ordering but still groups by key_fn alone. - - Yields: - Tuples of (key, iterator_of_items) for each unique key - """ - # Merge by composite key when sort_fn is provided, but group by key_fn only. - # Rebind to captured_sort_fn so pyrefly narrows the type inside the closure. - if sort_fn is not None: - captured_sort_fn = sort_fn - - def merge_key(item): - return (key_fn(item), captured_sort_fn(item)) - - else: - merge_key = key_fn - - # Check if external sort is needed BEFORE materializing all iterators. - # ScatterShard can decide using manifest stats (no file opens needed). 
- from zephyr.shuffle import ScatterShard - - use_external = ( - external_sort_dir is not None - and isinstance(shard, ScatterShard) - and shard.needs_external_sort(_TaskResources.from_environment().memory_bytes) - ) - - if use_external: - logger.info( - "External sort triggered for shard with %d iterators, spilling to %s", - sum(it.chunk_count for it in shard.iterators), - external_sort_dir, - ) - # Pass lazy generator — external_sort_merge consumes in batches without opening all files - merged_stream = external_sort_merge(shard.get_iterators(), merge_key, external_sort_dir) - else: - chunk_iterators = list(shard.get_iterators()) - logger.info(f"Merging {len(chunk_iterators):,} sorted chunk iterators") - if external_sort_dir is not None and len(chunk_iterators) > EXTERNAL_SORT_FAN_IN: - # Fallback: stats unavailable, use fan_in threshold - merged_stream = external_sort_merge(iter(chunk_iterators), merge_key, external_sort_dir) - else: - merged_stream = heapq.merge(*chunk_iterators, key=merge_key) - yield from groupby(merged_stream, key=key_fn) - - def _sorted_merge_join( left_stream: Iterable, right_stream: Iterable, @@ -703,12 +769,16 @@ class StageContext: shard_idx: Index of this shard total_shards: Total number of shards aux_shards: Auxiliary shards for joins, keyed by op index + logical_shard_idx: Original shard index before hot-shard splitting (None if unsplit) + key_range: Key range filter for hot-shard sub-tasks (None if unsplit) """ shard: Iterable[Any] shard_idx: int total_shards: int aux_shards: dict[int, Iterable[Any]] = field(default_factory=dict) + logical_shard_idx: int | None = None + key_range: Any | None = None def get_right_shard(self, op_index: int) -> Iterable[Any]: """Get right shard for join at given op index. 
@@ -812,10 +882,9 @@ def run_stage( # Shard contains a single manifest path — read it to build ScatterShard paths = list(shard) assert len(paths) == 1, f"Expected single scatter manifest path, got {len(paths)}" - shard = _build_scatter_shard_from_manifest(paths[0], ctx.shard_idx) - stream = _reduce_gen( - shard, op.key_fn, op.reducer_fn, sort_fn=op.sort_fn, external_sort_dir=external_sort_dir - ) + target = ctx.logical_shard_idx if ctx.logical_shard_idx is not None else ctx.shard_idx + shard = _build_scatter_shard_from_manifest(paths[0], target, key_range=ctx.key_range) + stream = _arrow_reduce_gen(shard, op.reducer_fn, external_sort_dir=external_sort_dir) op_index += 1 elif isinstance(op, Fold): diff --git a/lib/zephyr/src/zephyr/shuffle.py b/lib/zephyr/src/zephyr/shuffle.py index 868ada0e9a..a8287e99c0 100644 --- a/lib/zephyr/src/zephyr/shuffle.py +++ b/lib/zephyr/src/zephyr/shuffle.py @@ -20,7 +20,7 @@ import pickle from collections import defaultdict from collections.abc import Callable, Iterable, Iterator -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any import cloudpickle @@ -28,12 +28,12 @@ import pyarrow as pa import pyarrow.compute as pc import pyarrow.dataset as pad -import pyarrow.parquet as pq from iris.env_resources import TaskResources as _TaskResources from rigging.filesystem import open_url, url_to_fs from rigging.timing import log_time from zephyr.plan import deterministic_hash +from zephyr.spill_writer import SpillWriter from zephyr.writers import ensure_parent_dir logger = logging.getLogger(__name__) @@ -44,6 +44,18 @@ # --------------------------------------------------------------------------- +@dataclass(frozen=True) +class KeyRange: + """Half-open key range [lo, hi) for splitting a shard's key space. + + lo=None means unbounded below; hi=None means unbounded above. + Values are compared using the natural ordering of _zephyr_sort_key. 
+ """ + + lo: Any | None = None + hi: Any | None = None + + @dataclass class MemChunk: """In-memory chunk.""" @@ -56,7 +68,7 @@ def __iter__(self) -> Iterator: @dataclass class ListShard: - """Shard backed by a list of iterable references (PickleDiskChunk, MemChunk, etc.).""" + """Shard backed by a list of iterable references (MemChunk, etc.).""" refs: list[Iterable] @@ -73,20 +85,76 @@ def get_iterators(self) -> Iterator[Iterator]: # Column names and constants # --------------------------------------------------------------------------- -_ZEPHYR_SHUFFLE_SHARD_IDX_COL = "shard_idx" -_ZEPHYR_SHUFFLE_CHUNK_IDX_COL = "chunk_idx" -_ZEPHYR_SHUFFLE_ITEM_COL = "item" -_ZEPHYR_SHUFFLE_PICKLED_COL = "pickled" +_ZEPHYR_SHARD_IDX = "_zephyr_shard_idx" +_ZEPHYR_CHUNK_IDX = "_zephyr_chunk_idx" +_ZEPHYR_PAYLOAD = "_zephyr_payload" +_ZEPHYR_SORT_KEY = "_zephyr_sort_key" +_ZEPHYR_SORT_SECONDARY = "_zephyr_sort_secondary" + _SCATTER_META_SUFFIX = ".scatter_meta" _SCATTER_MANIFEST_NAME = "scatter_metadata" _SCATTER_META_READ_CONCURRENCY = 256 -# Number of items sampled from the first flush to estimate avg_item_bytes at scatter-write time -_SCATTER_SAMPLE_SIZE = 100 -# Conservative item-bytes fallback when avg_item_bytes is not in the manifest -_ITEM_BYTES_FALLBACK = 500.0 -# Fraction of total memory limit to budget for scatter read buffers _SCATTER_READ_BUFFER_FRACTION = 0.25 +_SCATTER_MICRO_BATCH_SIZE = 64 +_SCATTER_ROW_GROUP_BYTES = 8 * 1024 * 1024 # 8 MB + + +def is_zephyr_column(name: str) -> bool: + """Return True if the column name is a zephyr metadata column.""" + return name.startswith("_zephyr_") + + +# --------------------------------------------------------------------------- +# Envelope helpers +# --------------------------------------------------------------------------- + + +def make_envelope_batch( + items: list, + shard_idx: int, + chunk_idx: int, + key_values: list, + sort_values: list | None, + pickled: bool, +) -> pa.RecordBatch: + """Build an Arrow batch wrapping 
items with zephyr metadata columns. + + Flat mode (pickled=False): each item dict's fields become columns alongside _zephyr_* metadata. + Pickle mode (pickled=True): items become _zephyr_payload bytes alongside _zephyr_* metadata. + """ + n = len(items) + rows = [] + for i in range(n): + row: dict[str, Any] = { + _ZEPHYR_SHARD_IDX: shard_idx, + _ZEPHYR_CHUNK_IDX: chunk_idx, + _ZEPHYR_SORT_KEY: key_values[i], + } + if sort_values is not None: + row[_ZEPHYR_SORT_SECONDARY] = sort_values[i] + if pickled: + row[_ZEPHYR_PAYLOAD] = cloudpickle.dumps(items[i]) + else: + row.update(items[i]) + rows.append(row) + return pa.RecordBatch.from_pylist(rows) + + +def unwrap_items(table_or_batch: pa.Table | pa.RecordBatch, pickled: bool) -> list: + """Extract user items from an envelope table/batch. + + Flat mode: drop _zephyr_* columns, return remaining fields as dicts. + Pickle mode: deserialize _zephyr_payload column. + """ + if pickled: + payload_col = table_or_batch.column(_ZEPHYR_PAYLOAD) + return [pickle.loads(b) for b in payload_col.to_pylist()] + + if isinstance(table_or_batch, pa.RecordBatch): + table_or_batch = pa.Table.from_batches([table_or_batch]) + user_cols = [name for name in table_or_batch.column_names if not is_zephyr_column(name)] + return table_or_batch.select(user_cols).to_pylist() # --------------------------------------------------------------------------- @@ -112,14 +180,14 @@ def _get_scatter_read_fs(num_files: int, sample_path: str, memory_fraction: floa budget = int(total_mem * memory_fraction) per_file = max(budget // num_files, 64 * 1024) # floor at 64 KB - # Only override when we would meaningfully reduce the default (~5 MB). if per_file >= 5 * 1024 * 1024: return default_fs - if not hasattr(fs, "blocksize"): + # blocksize is a gcsfs/s3fs attribute; not all fsspec implementations have it. + # We check via the storage_options dict rather than hasattr to stay explicit. 
+ if "block_size" not in getattr(fs, "storage_options", {}): return default_fs - # Recreate the filesystem with the budgeted block_size. fsspec_fs = type(fs)(block_size=per_file, **{k: v for k, v in fs.storage_options.items() if k != "block_size"}) logger.info( "Scatter read: %d files, per-file block_size=%d KB (total budget=%.1f GB)", @@ -148,40 +216,53 @@ class ScatterParquetIterator: chunk_count: int is_pickled: bool filesystem: pa.fs.FileSystem + key_range: KeyRange | None = None def __iter__(self) -> Iterator: - for chunk_iter in self.get_chunk_iterators(): - yield from chunk_iter + for table in self.get_chunk_tables(): + yield from unwrap_items(table, self.is_pickled) def get_chunk_iterators(self, batch_size: int = 1024) -> Iterator[Iterator]: - """Yield one lazy iterator per sorted chunk. + """Yield one lazy iterator per sorted chunk, each backed by get_chunk_tables + unwrap_items.""" + for table in self.get_chunk_tables(batch_size=batch_size): + yield iter(unwrap_items(table, self.is_pickled)) + + def get_chunk_tables(self, batch_size: int = 1024) -> Iterator[pa.Table]: + """Yield Arrow tables per sorted chunk (no Python materialization). - Opens the file once via ``pyarrow.dataset`` and creates a Scanner - per chunk with predicate pushdown on ``(shard_idx, chunk_idx)``. + Always selects all columns; the caller (unwrap_items or the Arrow reduce + path) handles column filtering. 
""" _, fs_path = url_to_fs(self.path) dataset: pad.FileSystemDataset = pad.dataset(fs_path, format="parquet", filesystem=self.filesystem) - col = _ZEPHYR_SHUFFLE_PICKLED_COL if self.is_pickled else _ZEPHYR_SHUFFLE_ITEM_COL + + # Select item columns + sort columns; filter on shard/chunk metadata + if self.is_pickled: + item_cols = [_ZEPHYR_PAYLOAD] + else: + item_cols = [name for name in dataset.schema.names if not is_zephyr_column(name)] + + columns = [*item_cols, _ZEPHYR_SORT_KEY] + if _ZEPHYR_SORT_SECONDARY in dataset.schema.names: + columns.append(_ZEPHYR_SORT_SECONDARY) for chunk_idx in range(self.chunk_count): + chunk_filter = (pc.field(_ZEPHYR_SHARD_IDX) == self.shard_idx) & (pc.field(_ZEPHYR_CHUNK_IDX) == chunk_idx) + if self.key_range is not None: + if self.key_range.lo is not None: + chunk_filter = chunk_filter & (pc.field(_ZEPHYR_SORT_KEY) >= self.key_range.lo) + if self.key_range.hi is not None: + chunk_filter = chunk_filter & (pc.field(_ZEPHYR_SORT_KEY) < self.key_range.hi) + scanner = dataset.scanner( - columns=[col], - filter=( - (pc.field(_ZEPHYR_SHUFFLE_SHARD_IDX_COL) == self.shard_idx) - & (pc.field(_ZEPHYR_SHUFFLE_CHUNK_IDX_COL) == chunk_idx) - ), + columns=columns, + filter=chunk_filter, batch_size=batch_size, use_threads=False, ) - yield self._iter_scanner(scanner, col) - - def _iter_scanner(self, scanner: pad.Scanner, col: str) -> Iterator: - for batch in scanner.to_batches(): - items = batch.column(col).to_pylist() - if self.is_pickled: - yield from (pickle.loads(b) for b in items) - else: - yield from items + batches = list(scanner.to_batches()) + if batches: + yield pa.Table.from_batches(batches) # --------------------------------------------------------------------------- @@ -214,10 +295,9 @@ def get_iterators(self) -> Iterator[Iterator]: def needs_external_sort(self, memory_limit: int, memory_fraction: float = 0.5) -> bool: """Return True if opening all chunk iterators simultaneously would exceed memory_fraction of memory_limit.""" 
total_chunks = sum(it.chunk_count for it in self.iterators) - if total_chunks == 0: + if total_chunks == 0 or memory_limit <= 0 or self.avg_item_bytes <= 0: return False - item_bytes = self.avg_item_bytes if self.avg_item_bytes > 0 else _ITEM_BYTES_FALLBACK - estimated = total_chunks * self.max_row_group_rows * item_bytes + estimated = total_chunks * self.max_row_group_rows * self.avg_item_bytes return estimated > memory_limit * memory_fraction def _compute_batch_size(self) -> int: @@ -229,9 +309,9 @@ def _compute_batch_size(self) -> int: We cap this at _SCATTER_READ_BUFFER_FRACTION of the worker's memory limit. """ total_chunks = sum(it.chunk_count for it in self.iterators) - if total_chunks == 0: + if total_chunks == 0 or self.avg_item_bytes <= 0: return 1024 - bytes_per_item = self.avg_item_bytes if self.avg_item_bytes > 0 else _ITEM_BYTES_FALLBACK + bytes_per_item = self.avg_item_bytes memory_limit = _TaskResources.from_environment().memory_bytes buffer_budget = int(memory_limit * _SCATTER_READ_BUFFER_FRACTION) safe = max(1, int(buffer_budget // (total_chunks * bytes_per_item))) @@ -255,7 +335,7 @@ def _compute_batch_size(self) -> int: def _scatter_meta_path(parquet_path: str) -> str: """Return the sidecar metadata path for a scatter Parquet file. - Replaces the ``.parquet`` extension: ``shard-0000-seg0000.parquet`` → + Replaces the ``.parquet`` extension: ``shard-0000-seg0000.parquet`` -> ``shard-0000-seg0000.scatter_meta``. 
""" stem, _ = os.path.splitext(parquet_path) @@ -286,7 +366,6 @@ def _write_scatter_meta( f.write(payload) -# Per-worker cache for scatter sidecar metadata (populated on first read, shared across tasks) _scatter_meta_cache: dict[str, dict] = {} @@ -337,7 +416,6 @@ def _read_entry(path: str) -> tuple[str, dict]: f.write(payload) -# Per-worker cache for scatter manifests (populated on first read, shared across tasks) _scatter_manifest_cache: dict[str, list[dict]] = {} @@ -349,12 +427,13 @@ def _read_scatter_manifest(manifest_path: str) -> list[dict]: return _scatter_manifest_cache[manifest_path] -def _build_scatter_shard_from_manifest(manifest_path: str, target_shard: int) -> ScatterShard: +def _build_scatter_shard_from_manifest( + manifest_path: str, target_shard: int, key_range: KeyRange | None = None +) -> ScatterShard: """Build a ScatterShard for one target shard from a consolidated scatter manifest.""" entries = _read_scatter_manifest(manifest_path) iterators: list[ScatterParquetIterator] = [] with log_time(f"Building ScatterShard for target shard {target_shard} from manifest ({len(entries)} files)"): - # First pass: count files that have data for this shard file_entries = [] for entry in entries: count = entry["chunk_counts"].get(str(target_shard), 0) @@ -373,24 +452,17 @@ def _build_scatter_shard_from_manifest(manifest_path: str, target_shard: int) -> chunk_count=count, is_pickled=entry.get("is_pickled", False), filesystem=filesystem, + key_range=key_range, ) ) - # Aggregate stats from manifest entries for this shard. - # max_chunk_rows is a per-shard dict so we only look at target_shard's value. - # Fall back to the old scalar max_row_group_rows for pre-migration manifests. 
max_rg_rows = 0 for entry in file_entries: per_shard = entry.get("max_chunk_rows", {}) - if per_shard: - max_rg_rows = max(max_rg_rows, per_shard.get(str(target_shard), 0)) - else: - # old manifest: scalar max across all shards — use as conservative fallback - max_rg_rows = max(max_rg_rows, entry.get("max_row_group_rows", 0)) + max_rg_rows = max(max_rg_rows, per_shard.get(str(target_shard), 0)) if max_rg_rows == 0: - max_rg_rows = 100_000 # fallback for old manifests without stats + max_rg_rows = 100_000 - # Weighted avg item bytes (weight by chunk_count for this shard) total_chunks_for_avg = 0 weighted_bytes = 0.0 for entry in file_entries: @@ -404,43 +476,168 @@ def _build_scatter_shard_from_manifest(manifest_path: str, target_shard: int) -> return ScatterShard(iterators=iterators, max_row_group_rows=max_rg_rows, avg_item_bytes=avg_item_bytes) -# --------------------------------------------------------------------------- -# Envelope helpers -# --------------------------------------------------------------------------- +def _compute_key_range_boundaries( + manifest_path: str, + target_shard: int, + num_splits: int, + max_sample_keys: int = 50_000, +) -> list[KeyRange]: + """Sample sort keys from scatter files and compute equi-spaced split boundaries. + Reads a sample of _zephyr_sort_key values from the scatter files for the + target shard, sorts them, and picks split points to divide the key space + into num_splits ranges. -def _make_envelope(items: list, target_shard: int, chunk_idx: int) -> list[dict]: - return [ - { - _ZEPHYR_SHUFFLE_SHARD_IDX_COL: target_shard, - _ZEPHYR_SHUFFLE_CHUNK_IDX_COL: chunk_idx, - _ZEPHYR_SHUFFLE_ITEM_COL: item, - } - for item in items - ] + Returns a list of num_splits KeyRange objects covering the full key space. + Falls back to fewer ranges if fewer distinct keys are found than num_splits. 
+ """ + entries = _read_scatter_manifest(manifest_path) + file_entries = [e for e in entries if e["chunk_counts"].get(str(target_shard), 0) > 0] + if not file_entries: + return [KeyRange()] + + total_chunks = sum(e["chunk_counts"].get(str(target_shard), 0) for e in file_entries) + # Sample roughly max_sample_keys keys spread across files proportionally + keys_per_chunk = max(1, max_sample_keys // max(total_chunks, 1)) + + sample_path = file_entries[0]["path"] + filesystem = _get_scatter_read_fs(len(file_entries), sample_path) + + sampled_keys: list = [] + for entry in file_entries: + _, fs_path = url_to_fs(entry["path"]) + dataset = pad.dataset(fs_path, format="parquet", filesystem=filesystem) + shard_filter = pc.field(_ZEPHYR_SHARD_IDX) == target_shard + scanner = dataset.scanner( + columns=[_ZEPHYR_SORT_KEY], + filter=shard_filter, + use_threads=False, + ) + for batch in scanner.to_batches(): + col = batch.column(_ZEPHYR_SORT_KEY) + if len(col) <= keys_per_chunk: + sampled_keys.extend(col.to_pylist()) + else: + step = max(1, len(col) // keys_per_chunk) + sampled_keys.extend(col.to_pylist()[::step]) + if len(sampled_keys) >= max_sample_keys * 2: + break + if len(sampled_keys) >= max_sample_keys * 2: + break -def _make_pickle_envelope(items: list, target_shard: int, chunk_idx: int) -> list[dict]: - """Wrap items as pickle-serialized bytes for Arrow-incompatible types.""" - return [ - { - _ZEPHYR_SHUFFLE_SHARD_IDX_COL: target_shard, - _ZEPHYR_SHUFFLE_CHUNK_IDX_COL: chunk_idx, - _ZEPHYR_SHUFFLE_PICKLED_COL: cloudpickle.dumps(item), - } - for item in items - ] + if not sampled_keys: + return [KeyRange()] + sampled_keys.sort() -def _segment_path(base_path: str, seg_idx: int) -> str: - """Return the file path for a given segment index. 
+ # Deduplicate to find distinct keys for split points + distinct = [] + prev = object() + for k in sampled_keys: + if k != prev: + distinct.append(k) + prev = k - ``shard-0000.parquet`` → ``shard-0000-seg0000.parquet`` - """ + effective_splits = min(num_splits, len(distinct)) + if effective_splits <= 1: + return [KeyRange()] + + # Pick equi-spaced split points from the sorted distinct keys + split_points = [] + for i in range(1, effective_splits): + idx = int(len(distinct) * i / effective_splits) + split_points.append(distinct[idx]) + + ranges: list[KeyRange] = [] + for i in range(effective_splits): + lo = split_points[i - 1] if i > 0 else None + hi = split_points[i] if i < len(split_points) else None + ranges.append(KeyRange(lo=lo, hi=hi)) + + return ranges + + +# --------------------------------------------------------------------------- +# Internal write machinery +# --------------------------------------------------------------------------- + + +def _segment_path(base_path: str, seg_idx: int) -> str: + """``shard-0000.parquet`` -> ``shard-0000-seg0000.parquet``""" stem, ext = os.path.splitext(base_path) return f"{stem}-seg{seg_idx:04d}{ext}" +@dataclass +class _ShardBuffer: + """Per-shard buffer that accumulates Arrow micro-batches and flushes when a byte threshold is reached. + + Items are appended one at a time with their sort key (and optional secondary sort value). + Every _SCATTER_MICRO_BATCH_SIZE items, they are converted to an Arrow RecordBatch via + make_envelope_batch. When total buffered bytes exceed _SCATTER_ROW_GROUP_BYTES, + take_sorted_batch() drains the buffer, sorts in Arrow, and returns a single RecordBatch. 
+ """ + + shard_idx: int + pickled: bool = False + has_sort: bool = False + pending: list[tuple[Any, Any, Any | None]] = field(default_factory=list) + tables: list[pa.RecordBatch] = field(default_factory=list) + nbytes: int = 0 + chunk_idx: int = 0 + schema: pa.Schema | None = None + max_rows: int = 0 + + def append(self, item: Any, key_value: Any, sort_value: Any | None = None) -> None: + self.pending.append((item, key_value, sort_value)) + if len(self.pending) >= _SCATTER_MICRO_BATCH_SIZE: + self._flush_micro() + + def _flush_micro(self) -> None: + if not self.pending: + return + items, keys, sorts = zip(*self.pending, strict=True) + batch = make_envelope_batch( + list(items), + self.shard_idx, + self.chunk_idx, + list(keys), + list(sorts) if self.has_sort else None, + pickled=self.pickled, + ) + if self.schema is None: + self.schema = batch.schema + self.tables.append(batch) + self.nbytes += batch.nbytes + self.pending = [] + + def should_flush(self) -> bool: + return self.nbytes >= _SCATTER_ROW_GROUP_BYTES + + def take_sorted_batch(self) -> pa.RecordBatch | None: + """Drain buffer, sort by _zephyr_sort_key in Arrow, return single batch.""" + self._flush_micro() + if not self.tables: + return None + table = pa.concat_tables([pa.Table.from_batches([b]) for b in self.tables], promote_options="default") + sort_cols: list[tuple[str, str]] = [(_ZEPHYR_SORT_KEY, "ascending")] + if _ZEPHYR_SORT_SECONDARY in table.column_names: + sort_cols.append((_ZEPHYR_SORT_SECONDARY, "ascending")) + indices = pc.sort_indices(table, sort_keys=sort_cols) + sorted_table = table.take(indices) + num_rows = len(sorted_table) + self.max_rows = max(self.max_rows, num_rows) + self.chunk_idx += 1 + self.tables = [] + self.nbytes = 0 + return sorted_table.to_batches()[0] + + @property + def item_count(self) -> int: + return len(self.pending) + sum(len(t) for t in self.tables) + + def _apply_combiner(buffer: list, key_fn: Callable, combiner_fn: Callable) -> list: """Apply combiner to a buffer, 
grouping by key and reducing locally.""" by_key: dict[object, list] = defaultdict(list) @@ -465,155 +662,132 @@ def _write_parquet_scatter( ) -> ListShard: """Route items to target shards, buffer, sort, and write as Parquet row groups. - Handles the full scatter pipeline: hash-routing each item to a target shard, - buffering per-shard, applying an optional combiner, sorting each buffer, and - writing sorted chunks as Parquet row groups with envelope wrapping. + Items are accumulated incrementally in per-shard _ShardBuffer instances, + converted to Arrow micro-batches every _SCATTER_MICRO_BATCH_SIZE items. + When a buffer exceeds _SCATTER_ROW_GROUP_BYTES, it is drained, sorted in + Arrow by _zephyr_sort_key (and optionally _zephyr_sort_secondary), and + written as a Parquet row group. Writes ``.scatter_meta`` sidecar files alongside each Parquet segment. - - Returns: - A ListShard containing the segment file paths. """ - if sort_fn is not None: - captured_sort_fn = sort_fn - - def _sort_key(item): - return (key_fn(item), captured_sort_fn(item)) - - else: - _sort_key = key_fn - - # TODO: make chunk_size configurable per writer - chunk_size = 100_000 - - # Per-segment per-shard chunk counts seg_shard_counts: dict[int, dict[int, int]] = defaultdict(lambda: defaultdict(int)) - per_shard_chunk_cnt: dict[int, int] = defaultdict(int) - buffers: dict[int, list] = defaultdict(list) + buffers: dict[int, _ShardBuffer] = {} n_chunks_flushed = 0 seg_idx = 0 seg_paths: list[str] = [] schema: pa.Schema | None = None - writer: pq.ParquetWriter | None = None - seg_file = "" + spill_writer: SpillWriter | None = None - pending_chunk: pa.RecordBatch | None = None - pending_target: int = -1 - pending_cnt: int = 0 - - per_shard_max_rows: dict[int, int] = defaultdict(int) avg_item_bytes: float = 0.0 _sampled_avg = False - def _flush_pending(): - nonlocal n_chunks_flushed, pending_chunk - if pending_chunk is None: - return - writer.write_batch(pending_chunk) - 
seg_shard_counts[seg_idx][pending_target] = seg_shard_counts[seg_idx].get(pending_target, 0) + 1 - n_chunks_flushed += 1 - pending_chunk = None - if n_chunks_flushed % 10 == 0: - logger.info( - "[shard %d segment %d] Wrote %d parquet chunks so far (latest chunk size: %d items)", - source_shard, - seg_idx, - n_chunks_flushed, - pending_cnt, - ) - - def _prepare_batch(target_shard: int, buf: list) -> list[dict]: - """Apply combiner, sort, envelope a buffer. Returns enveloped rows.""" - if combiner_fn is not None: - buf = _apply_combiner(buf, key_fn, combiner_fn) - buf.sort(key=_sort_key) - shard_chunk_idx = per_shard_chunk_cnt[target_shard] - per_shard_chunk_cnt[target_shard] += 1 - envelope_fn = _make_pickle_envelope if pickled else _make_envelope - return envelope_fn(buf, target_shard, shard_chunk_idx) + def _get_buffer(target: int) -> _ShardBuffer: + if target not in buffers: + buffers[target] = _ShardBuffer(shard_idx=target, pickled=pickled, has_sort=sort_fn is not None) + return buffers[target] def _ensure_writer(chunk_schema: pa.Schema) -> pa.Schema: - """Ensure Parquet writer is open and compatible. 
Returns the active write schema.""" - nonlocal schema, writer, seg_file, seg_idx, per_shard_chunk_cnt + nonlocal schema, spill_writer, seg_idx if schema is None: schema = chunk_schema seg_file = _segment_path(parquet_path, seg_idx) seg_paths.append(seg_file) ensure_parent_dir(seg_file) - writer = pq.ParquetWriter(seg_file, schema) + spill_writer = SpillWriter(seg_file, schema) elif chunk_schema != schema: - _flush_pending() - writer.close() + spill_writer.close() schema = pa.unify_schemas([schema, chunk_schema]) seg_idx += 1 - per_shard_chunk_cnt = defaultdict(int) # chunk_idx restarts at 0 in new segment + for buf in buffers.values(): + buf.chunk_idx = 0 seg_file = _segment_path(parquet_path, seg_idx) seg_paths.append(seg_file) ensure_parent_dir(seg_file) - writer = pq.ParquetWriter(seg_file, schema) + spill_writer = SpillWriter(seg_file, schema) logger.info( "[shard %d] Schema evolved after %d chunks; starting segment %d", source_shard, n_chunks_flushed, seg_idx, ) - else: - _flush_pending() return schema - def _write_buffer(target_shard: int, buf: list) -> None: - """Sort a buffer and write it as a Parquet row group.""" - nonlocal pending_chunk, pending_target, pending_cnt, avg_item_bytes, _sampled_avg - enveloped = _prepare_batch(target_shard, buf) - chunk_arrow = pa.RecordBatch.from_pylist(enveloped) - write_schema = _ensure_writer(chunk_arrow.schema) - if chunk_arrow.schema != write_schema: - chunk_arrow = chunk_arrow.cast(write_schema) - pending_chunk = chunk_arrow - pending_target = target_shard - pending_cnt = len(buf) - per_shard_max_rows[target_shard] = max(per_shard_max_rows[target_shard], len(buf)) - - # Sample avg_item_bytes once on first flush - if not _sampled_avg and len(enveloped) > 0: - sample_size = min(len(enveloped), _SCATTER_SAMPLE_SIZE) - sample_rows = enveloped[:sample_size] - if pickled: - total_bytes = sum(len(row[_ZEPHYR_SHUFFLE_PICKLED_COL]) for row in sample_rows) - else: - total_bytes = 
sum(len(pickle.dumps(row[_ZEPHYR_SHUFFLE_ITEM_COL])) for row in sample_rows) - avg_item_bytes = total_bytes / len(sample_rows) + def _flush_buffer(buf: _ShardBuffer) -> None: + nonlocal n_chunks_flushed, avg_item_bytes, _sampled_avg + + if combiner_fn is not None: + buf._flush_micro() + if not buf.tables: + return + table = pa.concat_tables([pa.Table.from_batches([b]) for b in buf.tables], promote_options="default") + py_items = unwrap_items(table, pickled) + combined = _apply_combiner(py_items, key_fn, combiner_fn) + combined_buf = _ShardBuffer(shard_idx=buf.shard_idx, pickled=pickled, has_sort=sort_fn is not None) + combined_buf.chunk_idx = buf.chunk_idx + for item in combined: + k = key_fn(item) + sv = sort_fn(item) if sort_fn else None + combined_buf.append(item, k, sv) + batch = combined_buf.take_sorted_batch() + buf.chunk_idx = combined_buf.chunk_idx + buf.tables = [] + buf.nbytes = 0 + buf.pending = [] + buf.max_rows = max(buf.max_rows, combined_buf.max_rows) + else: + batch = buf.take_sorted_batch() + + if batch is None: + return + + write_schema = _ensure_writer(batch.schema) + if batch.schema != write_schema: + batch = batch.cast(write_schema) + + # Each sorted chunk is its own row group (distinct shard/chunk metadata). 
+ batch_table = pa.Table.from_batches([batch]) + spill_writer.write_row_group(batch_table) + seg_shard_counts[seg_idx][buf.shard_idx] = seg_shard_counts[seg_idx].get(buf.shard_idx, 0) + 1 + n_chunks_flushed += 1 + + if n_chunks_flushed % 10 == 0: + logger.info( + "[shard %d segment %d] Wrote %d parquet chunks so far (latest chunk size: %d items)", + source_shard, + seg_idx, + n_chunks_flushed, + len(batch), + ) + + if not _sampled_avg and len(batch) > 0: + avg_item_bytes = batch.nbytes / len(batch) _sampled_avg = True - # Route items to target shards, flush buffers at chunk_size for item in items: key = key_fn(item) target = deterministic_hash(key) % num_output_shards - buffers[target].append(item) - if chunk_size > 0 and len(buffers[target]) >= chunk_size: - _write_buffer(target, buffers[target]) - buffers[target] = [] + sort_val = sort_fn(item) if sort_fn else None + buf = _get_buffer(target) + buf.append(item, key, sort_val) + if buf.should_flush(): + _flush_buffer(buf) - # Flush remaining buffers — write each shard as its own row group so PyArrow - # can use min/max statistics on shard_idx to skip non-matching row groups on read. with log_time(f"Flushing remaining buffers for {parquet_path}"): - _flush_pending() - for target, buf in sorted(buffers.items()): - if not buf: + for target in sorted(buffers.keys()): + buf = buffers[target] + if buf.item_count == 0: continue - _write_buffer(target, buf) - _flush_pending() + _flush_buffer(buf) + + if spill_writer is not None: + spill_writer.close() - if writer is not None: - writer.close() + per_shard_max_rows: dict[int, int] = {target: buf.max_rows for target, buf in buffers.items() if buf.max_rows > 0} - # Write sidecar metadata for each segment. - # chunk_offsets track where each segment's chunks start in the global - # chunk_idx space (cumulative across segments from this source shard). 
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Byte-budgeted Parquet writer with background GCS uploads.

SpillWriter wraps pq.ParquetWriter and accumulates Arrow tables, flushing
them as row groups when the accumulated bytes reach a configurable threshold.
Writes happen in a background thread so the caller can overlap production
with I/O (pq.ParquetWriter.write_table releases the GIL).

TableAccumulator is a standalone helper that accumulates Arrow tables until
a byte threshold is reached, then yields the concatenated result.
"""

import logging
import queue
import threading

import pyarrow as pa
import pyarrow.parquet as pq

logger = logging.getLogger(__name__)

# Queue marker telling the background writer thread to exit its drain loop.
_SENTINEL = object()


class TableAccumulator:
    """Accumulates Arrow tables and yields merged results when a byte threshold is reached.

    Unlike row-count batching, byte-budgeted batching produces uniformly-sized
    output regardless of row width, which matters for write performance and
    memory predictability.
    """

    def __init__(self, byte_threshold: int) -> None:
        self._byte_threshold = byte_threshold
        self._tables: list[pa.Table] = []
        self._nbytes: int = 0

    def add(self, table: pa.Table) -> pa.Table | None:
        """Accumulate a table. Returns a merged table once the threshold is reached, else None."""
        self._tables.append(table)
        self._nbytes += table.nbytes
        if self._nbytes >= self._byte_threshold:
            return self._take()
        return None

    def flush(self) -> pa.Table | None:
        """Return any remaining accumulated data, or None if empty."""
        if not self._tables:
            return None
        return self._take()

    def _take(self) -> pa.Table:
        # promote_options="default" lets tables with evolved-but-compatible
        # schemas concatenate (nulls filled for missing columns).
        result = pa.concat_tables(self._tables, promote_options="default")
        self._tables.clear()
        self._nbytes = 0
        return result

    def pending_bytes(self) -> int:
        """Bytes accumulated so far (Arrow in-memory size, not on-disk size)."""
        return self._nbytes

    def __len__(self) -> int:
        return sum(len(t) for t in self._tables)


def _background_writer_loop(
    write_queue: "queue.Queue[pa.Table | object]",
    writer: pq.ParquetWriter,
    error_box: list[BaseException],
) -> None:
    """Drain write_queue, writing each table as a row group.

    Stops on _SENTINEL, or on the first write error (recorded in error_box
    for the producer thread to re-raise).
    """
    while True:
        item = write_queue.get()
        if item is _SENTINEL:
            return
        try:
            writer.write_table(item)
        except BaseException as exc:
            error_box.append(exc)
            return


class SpillWriter:
    """Byte-budgeted ParquetWriter with background I/O.

    Row groups are accumulated via an internal TableAccumulator and flushed
    to a pq.ParquetWriter in a background thread, overlapping one write
    with the next produce cycle.

    Two write modes:
    - write_table(table): accumulates rows, flushes a row group when
      accumulated bytes exceed row_group_bytes.
    - write_row_group(table): writes the table as its own row group immediately
      (no accumulation). Used by the scatter path where each sorted chunk must
      be a separate row group.
    """

    def __init__(
        self,
        path: str,
        schema: pa.Schema,
        *,
        row_group_bytes: int = 8 * 1024 * 1024,
        compression: str = "zstd",
        compression_level: int = 1,
    ) -> None:
        self._writer = pq.ParquetWriter(path, schema, compression=compression, compression_level=compression_level)
        self._accumulator = TableAccumulator(row_group_bytes)
        # Single-slot queue: allows one write to be in-flight while the caller
        # produces the next batch. Backpressure is automatic — put() blocks when
        # the slot is occupied.
        self._queue: queue.Queue[pa.Table | object] = queue.Queue(maxsize=1)
        self._error_box: list[BaseException] = []
        self._thread = threading.Thread(
            target=_background_writer_loop,
            args=(self._queue, self._writer, self._error_box),
            daemon=True,
        )
        self._thread.start()
        self._closed = False

    def _check_error(self) -> None:
        """Re-raise the first error captured by the background writer thread."""
        if self._error_box:
            raise self._error_box[0]

    def _put(self, item: "pa.Table | object") -> None:
        """Enqueue an item for the background thread without risking deadlock.

        A plain blocking put() can hang forever if the writer thread has died
        on an error (nothing will ever drain the single-slot queue), so poll
        with a timeout and re-check the error box on every round.
        """
        while True:
            self._check_error()
            try:
                self._queue.put(item, timeout=1.0)
                return
            except queue.Full:
                continue

    def write_table(self, table: pa.Table) -> None:
        """Accumulate rows; flush a row group to the background writer when threshold is exceeded."""
        merged = self._accumulator.add(table)
        if merged is not None:
            self._put(merged)
        self._check_error()

    def write_row_group(self, table: pa.Table) -> None:
        """Write the table as its own row group immediately (no accumulation)."""
        self._put(table)
        self._check_error()

    def close(self) -> None:
        """Flush remaining accumulated data and wait for the background thread to finish.

        Idempotent. Raises any error captured by the background writer.
        """
        if self._closed:
            return
        self._closed = True
        try:
            remaining = self._accumulator.flush()
            if remaining is not None:
                self._put(remaining)
            # _put raises if the writer thread already died; in that case the
            # thread has returned and there is nothing left to join.
            self._put(_SENTINEL)
            self._thread.join()
        finally:
            try:
                self._writer.close()
            except Exception:
                # Don't mask the original error from the background thread.
                logger.exception("Failed to close underlying ParquetWriter")
        self._check_error()

    def __enter__(self) -> "SpillWriter":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()
mode 100644 index 0000000000..ce5db6eda9 --- /dev/null +++ b/lib/zephyr/tests/benchmark_scatter_reduce.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Benchmark scatter/reduce in isolation: serialization, sort, and Arrow merge. + +Directly exercises _write_parquet_scatter and _arrow_reduce_gen with synthetic +data to measure the performance of the scatter/reduce code paths without the +overhead of file loading, mapping, or writing final output. + +Usage: + cd lib/zephyr + uv run python tests/benchmark_scatter_reduce.py + uv run python tests/benchmark_scatter_reduce.py --num-items 1000000 --num-shards 128 +""" + +import logging +import os +import random +import resource +import shutil +import sys +import tempfile +import time +from collections.abc import Iterator +from dataclasses import dataclass + +import click + +from zephyr.plan import _arrow_reduce_gen +from zephyr.shuffle import ( + _build_scatter_shard_from_manifest, + _write_parquet_scatter, + _write_scatter_manifest, +) + +WORDS = """ +the be to of and a in that have I it for not on with he as you do at this but his by from +they we say her she or an will my one all would there their what so up out if about who get +which go me when make can like time no just him know take people into year your good some +could them see other than then now look only come its over think also back after use two how +our work first well way even new want because any these give day most us data system process +compute memory network storage algorithm function variable method class object interface +""".split() + + +def generate_items(n: int, num_keys: int = 1000) -> list[dict]: + """Generate n items resembling real dedup records (~150 bytes each). + + Realistic shape: hash key, document ID, file index, and a short field. + Matches the typical item size in exact/fuzzy dedup pipelines. 
+ """ + return [ + { + "key": random.randint(0, num_keys - 1), + "id": f"doc-{i:08d}", + "file_idx": i % 100, + "score": random.random(), + } + for i in range(n) + ] + + +def peak_rss_mb() -> float: + """Return peak RSS in MB (macOS returns bytes, Linux returns KB).""" + usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if sys.platform == "darwin": + return usage / (1024 * 1024) + return usage / 1024 + + +@dataclass +class BenchmarkResult: + name: str + scatter_time_s: float + reduce_time_s: float + peak_rss_mb: float + total_items: int + num_shards: int + num_keys: int + unique_keys_found: int + scatter_file_bytes: int + + @property + def scatter_items_per_sec(self) -> float: + return self.total_items / self.scatter_time_s if self.scatter_time_s > 0 else 0 + + @property + def reduce_items_per_sec(self) -> float: + return self.total_items / self.reduce_time_s if self.reduce_time_s > 0 else 0 + + +def _key_fn(item: dict) -> int: + return item["key"] + + +def run_scatter(items: list[dict], tmp_dir: str, num_shards: int) -> tuple[str, float, int]: + """Scatter items, return (manifest_path, elapsed_s, file_bytes).""" + t0 = time.monotonic() + parquet_path = f"{tmp_dir}/shard-0000.parquet" + list_shard = _write_parquet_scatter( + iter(items), + source_shard=0, + parquet_path=parquet_path, + key_fn=_key_fn, + num_output_shards=num_shards, + ) + seg_paths = list(list_shard) + manifest_path = f"{tmp_dir}/scatter_metadata" + _write_scatter_manifest(seg_paths, manifest_path) + elapsed = time.monotonic() - t0 + + file_bytes = sum(os.path.getsize(p) for p in seg_paths if os.path.exists(p)) + return manifest_path, elapsed, file_bytes + + +def _keep_first(_key: int, items: Iterator) -> dict: + return next(items) + + +def run_reduce(manifest_path: str, num_shards: int) -> tuple[int, float]: + """Reduce all shards (keep-first per key) using Arrow merge, return (unique_keys, elapsed_s).""" + t0 = time.monotonic() + count = 0 + for shard_idx in range(num_shards): + shard 
= _build_scatter_shard_from_manifest(manifest_path, shard_idx) + for _item in _arrow_reduce_gen(shard, _keep_first): + count += 1 + elapsed = time.monotonic() - t0 + return count, elapsed + + +@click.command() +@click.option("--num-items", default=500_000, help="Total items to scatter") +@click.option("--num-shards", default=64, help="Number of output shards") +@click.option("--num-keys", default=1000, help="Number of unique keys") +@click.option("--seed", default=42, help="Random seed for reproducibility") +def benchmark(num_items: int, num_shards: int, num_keys: int, seed: int) -> None: + """Benchmark scatter/reduce performance in isolation.""" + random.seed(seed) + + print(f"\nGenerating {num_items:,} items ({num_keys} unique keys)...") + gen_start = time.monotonic() + items = generate_items(num_items, num_keys) + gen_time = time.monotonic() - gen_start + print(f"Generated in {gen_time:.2f}s") + + tmp_dir = tempfile.mkdtemp(prefix="zephyr_bench_scatter_") + try: + print(f"\nScattering to {num_shards} shards...") + manifest_path, scatter_time, file_bytes = run_scatter(items, tmp_dir, num_shards) + + print("Reducing with Arrow merge (keep-first per key)...") + arrow_keys, arrow_reduce_time = run_reduce(manifest_path, num_shards) + + print(f"\n{'=' * 60}") + print("Scatter/Reduce Benchmark Results") + print(f"{'=' * 60}") + print(f" Items: {num_items:>12,}") + print(f" Shards: {num_shards:>12,}") + print(f" Unique keys: {num_keys:>12,}") + print(f" Keys found: {arrow_keys:>12,}") + print(f" Scatter file size: {file_bytes / (1024*1024):>12.1f} MB") + print(f"{'─' * 60}") + print(f" Scatter time: {scatter_time:>12.2f} s") + print(f" Scatter throughput: {num_items / scatter_time:>12,.0f} items/s") + print(f"{'─' * 60}") + print(f" Reduce (Arrow): {arrow_reduce_time:>12.2f} s ({num_items / arrow_reduce_time:>10,.0f} items/s)") + print(f"{'=' * 60}\n") + + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + +if __name__ == "__main__": + 
logging.basicConfig(level=logging.WARNING, format="%(asctime)s %(levelname)s %(message)s") + benchmark() diff --git a/lib/zephyr/tests/test_execution.py b/lib/zephyr/tests/test_execution.py index 58715d0be7..8fa95b52e9 100644 --- a/lib/zephyr/tests/test_execution.py +++ b/lib/zephyr/tests/test_execution.py @@ -257,11 +257,11 @@ def test_no_duplicate_results_on_heartbeat_timeout(actor_context, tmp_path): def test_disk_chunk_write_uses_unique_paths(tmp_path): - """Each PickleDiskChunk.write() writes to a unique location, avoiding collisions.""" - from zephyr.execution import PickleDiskChunk + """Each ParquetDiskChunk.write() writes to a unique location, avoiding collisions.""" + from zephyr.execution import ParquetDiskChunk - base_path = str(tmp_path / "chunk.pkl") - refs = [PickleDiskChunk.write(base_path, [i]) for i in range(3)] + base_path = str(tmp_path / "chunk.parquet") + refs = [ParquetDiskChunk.write(base_path, [i]) for i in range(3)] # Each written to a distinct UUID path (no rename needed) paths = [r.path for r in refs] @@ -280,7 +280,7 @@ def test_coordinator_accepts_winner_ignores_stale(actor_context, tmp_path): Stale chunk files are left for context-dir cleanup (no per-chunk deletion). 
""" - from zephyr.execution import ListShard, PickleDiskChunk, ShardTask, TaskResult, ZephyrCoordinator + from zephyr.execution import ListShard, ParquetDiskChunk, ShardTask, TaskResult, ZephyrCoordinator coord = ZephyrCoordinator() coord.set_chunk_config(str(tmp_path / "chunks"), "test-exec") @@ -299,7 +299,7 @@ def test_coordinator_accepts_winner_ignores_stale(actor_context, tmp_path): _task_a, attempt_a, _config = pulled_a # Worker A writes a chunk (simulating slow completion) - stale_ref = PickleDiskChunk.write(str(tmp_path / "stale-chunk.pkl"), [1, 2, 3]) + stale_ref = ParquetDiskChunk.write(str(tmp_path / "stale-chunk.parquet"), [1, 2, 3]) assert Path(stale_ref.path).exists() # Heartbeat timeout re-queues the task @@ -310,7 +310,7 @@ def test_coordinator_accepts_winner_ignores_stale(actor_context, tmp_path): pulled_b = coord.pull_task("worker-B") _task_b, attempt_b, _config = pulled_b - winner_ref = PickleDiskChunk.write(str(tmp_path / "winner-chunk.pkl"), [4, 5, 6]) + winner_ref = ParquetDiskChunk.write(str(tmp_path / "winner-chunk.parquet"), [4, 5, 6]) coord.report_result( "worker-B", @@ -343,13 +343,13 @@ def test_shard_streaming_low_memory(tmp_path): Verifies get_iterators yields data lazily and flat iteration works. 
""" - from zephyr.execution import ListShard, PickleDiskChunk + from zephyr.execution import ListShard, ParquetDiskChunk # Write 3 refs to disk (directly readable, no finalize needed) refs = [] for i in range(3): - path = str(tmp_path / f"chunk-{i}.pkl") - chunk = PickleDiskChunk.write(path, [i * 10 + j for j in range(5)]) + path = str(tmp_path / f"chunk-{i}.parquet") + chunk = ParquetDiskChunk.write(path, [i * 10 + j for j in range(5)]) refs.append(chunk) shard = ListShard(refs=refs) diff --git a/lib/zephyr/tests/test_groupby.py b/lib/zephyr/tests/test_groupby.py index 221820aac4..b0a668e403 100644 --- a/lib/zephyr/tests/test_groupby.py +++ b/lib/zephyr/tests/test_groupby.py @@ -365,11 +365,17 @@ def test_scatter_parquet_iterator_pickle_roundtrip(tmp_path): """ScatterParquetIterator with is_pickled=True round-trips non-Arrow-serializable items.""" import pyarrow.parquet as pq - from zephyr.shuffle import ScatterParquetIterator, _make_pickle_envelope + from zephyr.shuffle import ScatterParquetIterator, make_envelope_batch items = [frozenset([1, 2]), frozenset([3, 4, 5])] - envelope = _make_pickle_envelope(items, target_shard=0, chunk_idx=0) - batch = pa.RecordBatch.from_pylist(envelope) + batch = make_envelope_batch( + items, + shard_idx=0, + chunk_idx=0, + key_values=[0, 1], + sort_values=None, + pickled=True, + ) path = str(tmp_path / "test.parquet") pq.write_table(pa.Table.from_batches([batch]), path) @@ -535,3 +541,77 @@ def dedup_reducer(key, items): {"key": "a", "ids": [1, 2]}, {"key": "b", "ids": [3, 4]}, ] + + +# --------------------------------------------------------------------------- +# Hot shard splitting tests +# --------------------------------------------------------------------------- + + +def test_hot_shard_split_skewed_data(zephyr_ctx): + """Hot shard splitting produces correct results with skewed key distribution. + + 90% of items share the same key ("hot"), creating a hot shard. 
+ max_hot_shard_splits=4 should split it while producing identical results. + """ + hot_items = [{"key": "hot", "val": i} for i in range(90)] + cold_items = [{"key": f"cold_{i}", "val": i} for i in range(10)] + data = hot_items + cold_items + + def count_reducer(key, items): + vals = sorted(item["val"] for item in items) + return {"key": key, "count": len(vals), "vals": vals} + + ds_split = Dataset.from_list(data).group_by( + key=lambda x: x["key"], + reducer=count_reducer, + max_hot_shard_splits=4, + ) + ds_normal = Dataset.from_list(data).group_by( + key=lambda x: x["key"], + reducer=count_reducer, + ) + + results_split = sorted(zephyr_ctx.execute(ds_split), key=lambda x: x["key"]) + results_normal = sorted(zephyr_ctx.execute(ds_normal), key=lambda x: x["key"]) + + assert results_split == results_normal + + +def test_hot_shard_split_balanced_data(zephyr_ctx): + """When data is balanced, hot shard splitting does not activate.""" + data = [{"key": f"k{i % 5}", "val": i} for i in range(50)] + + ds = Dataset.from_list(data).group_by( + key=lambda x: x["key"], + reducer=lambda key, items: {"key": key, "count": sum(1 for _ in items)}, + max_hot_shard_splits=4, + ) + + results = sorted(zephyr_ctx.execute(ds), key=lambda x: x["key"]) + assert len(results) == 5 + for r in results: + assert r["count"] == 10 + + +def test_hot_shard_split_generator_reducer(zephyr_ctx): + """Hot shard splitting works with generator reducers (yielding multiple items).""" + data = [{"key": "hot", "val": i} for i in range(50)] + [{"key": "cold", "val": i} for i in range(5)] + + def gen_reducer(key, items): + for item in items: + yield {"key": key, "val": item["val"], "tagged": True} + + ds = Dataset.from_list(data).group_by( + key=lambda x: x["key"], + reducer=gen_reducer, + max_hot_shard_splits=4, + ) + + results = zephyr_ctx.execute(ds) + assert len(results) == 55 + assert all(r["tagged"] for r in results) + hot_results = [r for r in results if r["key"] == "hot"] + cold_results = [r for r in 
results if r["key"] == "cold"] + assert len(hot_results) == 50 + assert len(cold_results) == 5 diff --git a/lib/zephyr/tests/test_shuffle.py b/lib/zephyr/tests/test_shuffle.py index 820fc7734a..7d83ea66ed 100644 --- a/lib/zephyr/tests/test_shuffle.py +++ b/lib/zephyr/tests/test_shuffle.py @@ -1,24 +1,27 @@ # Copyright The Marin Authors # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for zephyr/shuffle.py. +"""Unit tests for zephyr/shuffle.py and zephyr/spill_writer.py. Tests the scatter write/read roundtrip, per-shard stats, needs_external_sort, -and multi-segment schema evolution — all without spinning up a full coordinator. +multi-segment schema evolution, SpillWriter, and TableAccumulator — all +without spinning up a full coordinator. """ import pyarrow as pa import pyarrow.parquet as pq -from zephyr.plan import deterministic_hash +from zephyr.plan import deterministic_hash, _arrow_reduce_gen from zephyr.shuffle import ( ScatterParquetIterator, ScatterShard, + _ZEPHYR_SORT_KEY, _build_scatter_shard_from_manifest, - _make_pickle_envelope, + make_envelope_batch, _write_parquet_scatter, _write_scatter_manifest, ) +from zephyr.spill_writer import SpillWriter, TableAccumulator # --------------------------------------------------------------------------- # Helpers @@ -225,8 +228,14 @@ def test_avg_item_bytes_written(tmp_path): def test_scatter_parquet_iterator_pickle_roundtrip(tmp_path): """ScatterParquetIterator with is_pickled=True round-trips non-Arrow-serializable items.""" items = [frozenset([1, 2]), frozenset([3, 4, 5])] - envelope = _make_pickle_envelope(items, target_shard=0, chunk_idx=0) - batch = pa.RecordBatch.from_pylist(envelope) + batch = make_envelope_batch( + items, + shard_idx=0, + chunk_idx=0, + key_values=[0, 1], + sort_values=None, + pickled=True, + ) path = str(tmp_path / "test.parquet") pq.write_table(pa.Table.from_batches([batch]), path) @@ -243,36 +252,365 @@ def test_scatter_parquet_iterator_pickle_roundtrip(tmp_path): assert chunks[0] == 
items +# --------------------------------------------------------------------------- +# Sort key column +# --------------------------------------------------------------------------- + + +def test_scatter_writes_sort_key_column(tmp_path): + """Parquet files contain a _zephyr_sort_key column matching key_fn values.""" + num_shards = 2 + items = [{"k": i % 3, "v": i} for i in range(20)] + parquet_path = str(tmp_path / "shard-0000.parquet") + list_shard = _write_parquet_scatter( + iter(items), + source_shard=0, + parquet_path=parquet_path, + key_fn=_key, + num_output_shards=num_shards, + ) + seg_paths = list(list_shard) + for seg_path in seg_paths: + table = pq.read_table(seg_path) + assert _ZEPHYR_SORT_KEY in table.column_names + for row in table.to_pylist(): + # In flat mode, user fields are top-level columns alongside _zephyr_* metadata + assert row[_ZEPHYR_SORT_KEY] == row["k"] + + +def test_get_chunk_tables_returns_arrow(tmp_path): + """get_chunk_tables yields pa.Table instances with sort key columns.""" + items = [{"k": i % 2, "v": i} for i in range(20)] + manifest_path, _ = _build_shard(tmp_path, items, num_output_shards=1) + shard = _build_scatter_shard_from_manifest(manifest_path, 0) + tables = [] + for it in shard.iterators: + tables.extend(list(it.get_chunk_tables())) + assert len(tables) > 0 + for t in tables: + assert isinstance(t, pa.Table) + assert _ZEPHYR_SORT_KEY in t.column_names + + +def test_scatter_with_combiner(tmp_path): + """Scatter with combiner_fn produces correct combined output.""" + # Duplicate keys — combiner keeps only one per key + items = [{"k": i % 3, "v": i} for i in range(30)] + + def combiner(key, items_iter): + return [next(items_iter)] + + parquet_path = str(tmp_path / "shard-0000.parquet") + list_shard = _write_parquet_scatter( + iter(items), + source_shard=0, + parquet_path=parquet_path, + key_fn=_key, + num_output_shards=1, + combiner_fn=combiner, + ) + seg_paths = list(list_shard) + manifest_path = str(tmp_path / 
"scatter_metadata") + _write_scatter_manifest(seg_paths, manifest_path) + + shard = _build_scatter_shard_from_manifest(manifest_path, 0) + recovered = list(shard) + # With combiner keeping first, we get at most 3 unique keys (0,1,2) + keys = {item["k"] for item in recovered} + assert keys == {0, 1, 2} + assert len(recovered) == 3 + + # --------------------------------------------------------------------------- # external_sort_merge # --------------------------------------------------------------------------- -def test_external_sort_merge_streaming(tmp_path): - """external_sort_merge streams items to disk; output is fully sorted.""" +def test_arrow_merge_produces_same_results(tmp_path): + """Arrow merge-sort produces correct grouped results.""" + num_shards = 4 + items = [{"k": i % 7, "v": i} for i in range(200)] + manifest_path, _ = _build_shard(tmp_path, items, num_output_shards=num_shards) + + def keep_first(key, items_iter): + return next(items_iter) + + for shard_idx in range(num_shards): + shard = _build_scatter_shard_from_manifest(manifest_path, shard_idx) + arrow_results = sorted(list(_arrow_reduce_gen(shard, keep_first)), key=lambda x: x["v"]) + + # Verify: each result should be a valid item from the original set + for result in arrow_results: + assert result in items, f"unexpected item {result} in shard {shard_idx}" + + # Verify: each key appears at most once (keep_first deduplication) + keys_seen = [result["k"] for result in arrow_results] + assert len(keys_seen) == len(set(keys_seen)), f"duplicate keys in shard {shard_idx}" + + +def test_arrow_external_sort_more_chunks_than_fan_in(tmp_path): + """Reproduces production crash: >1000 small chunks forces 2 run files. + + With small chunks, fan_in=1000. When there are >1000 chunks, the sort + produces 2 runs. Verifies both runs are written and merged correctly. 
+ """ + from zephyr.external_sort import external_sort_merge + + # 1020 chunks with 1 row each — forces 2 batches (1000 + 20) + chunks = [pa.table({"val": [i], _ZEPHYR_SORT_KEY: [i]}) for i in range(1020)] + sort_dir = str(tmp_path / "ext_sort") + + result_tables = list( + external_sort_merge( + iter(chunks), + sort_keys=[(_ZEPHYR_SORT_KEY, "ascending")], + external_sort_dir=sort_dir, + ) + ) + combined = pa.concat_tables(result_tables) + assert combined.column("val").to_pylist() == list(range(1020)) + + +def test_arrow_external_sort_run_files_exist_after_write(tmp_path): + """Verify run files are created during pass 1 and the merge is correct.""" + from zephyr.external_sort import external_sort_merge - # Build 3 sorted iterators, more than would fit in one batch if fan-in were 2 - iters = [iter([1, 4, 7]), iter([2, 5, 8]), iter([3, 6, 9])] + chunks = [pa.table({"val": [i], _ZEPHYR_SORT_KEY: [i]}) for i in range(1020)] + sort_dir = str(tmp_path / "ext_sort") + + result_tables = list( + external_sort_merge( + iter(chunks), + sort_keys=[(_ZEPHYR_SORT_KEY, "ascending")], + external_sort_dir=sort_dir, + ) + ) + + # Run files are cleaned up after merge, but the result must be correct + combined = pa.concat_tables(result_tables) + assert combined.column("val").to_pylist() == list(range(1020)) - result = list(external_sort_merge(iter(iters), merge_key=lambda x: x, external_sort_dir=str(tmp_path))) - assert result == list(range(1, 10)) +def test_arrow_external_sort_gcs_roundtrip(): + """Reproduce the production bug: write 2 run files to GCS, read back. -def test_external_sort_merge_single_batch(tmp_path): - """Works correctly when all iterators fit in a single pass-1 batch.""" + Uses pyarrow ParquetWriter (same as _write_spill_file) and fsspec + read_metadata (same as the verification loop). 
+ """ + import uuid from zephyr.external_sort import external_sort_merge - iters = [iter([i]) for i in range(10)] - result = list(external_sort_merge(iter(iters), merge_key=lambda x: x, external_sort_dir=str(tmp_path))) - assert result == list(range(10)) + gcs_dir = f"gs://marin-tmp-eu-west4/ttl=1d/test-external-sort/{uuid.uuid4().hex[:8]}" + + # 1020 chunks → 2 runs (fan_in=1000) + chunks = [pa.table({"val": [i], _ZEPHYR_SORT_KEY: [i]}) for i in range(1020)] + + result_tables = list( + external_sort_merge( + iter(chunks), + sort_keys=[(_ZEPHYR_SORT_KEY, "ascending")], + external_sort_dir=gcs_dir, + ) + ) + combined = pa.concat_tables(result_tables) + assert combined.column("val").to_pylist() == list(range(1020)) + + +def test_arrow_external_sort_roundtrip(tmp_path): + """Arrow external sort produces correctly sorted output.""" + from zephyr.external_sort import external_sort_merge + + tables = [ + pa.table({"a": [1, 4, 7], _ZEPHYR_SORT_KEY: [1, 4, 7]}), + pa.table({"a": [2, 5, 8], _ZEPHYR_SORT_KEY: [2, 5, 8]}), + pa.table({"a": [3, 6, 9], _ZEPHYR_SORT_KEY: [3, 6, 9]}), + ] + sort_dir = str(tmp_path / "ext_sort") + + result_tables = list( + external_sort_merge( + iter(tables), + sort_keys=[(_ZEPHYR_SORT_KEY, "ascending")], + external_sort_dir=sort_dir, + ) + ) + combined = pa.concat_tables(result_tables) + a_values = combined.column("a").to_pylist() + assert a_values == list(range(1, 10)) + + +def test_arrow_external_sort_cleans_up(tmp_path): + """Arrow external sort run files are deleted after merge.""" + from zephyr.external_sort import external_sort_merge + + tables = [pa.table({"val": [i], _ZEPHYR_SORT_KEY: [i]}) for i in range(10)] + sort_dir = str(tmp_path / "ext_sort") + + list( + external_sort_merge( + iter(tables), + sort_keys=[(_ZEPHYR_SORT_KEY, "ascending")], + external_sort_dir=sort_dir, + ) + ) + import os + + if os.path.exists(sort_dir): + remaining = os.listdir(sort_dir) + assert remaining == [], f"run files should be deleted, found: {remaining}" + + 
+def test_needs_external_sort_zero_memory(): + """needs_external_sort returns False when memory_limit is 0 (unknown).""" + shard = ScatterShard( + iterators=[ + ScatterParquetIterator( + path="gs://fake/path.parquet", + shard_idx=0, + chunk_count=1000, + is_pickled=False, + filesystem=pa.fs.LocalFileSystem(), + ) + ], + max_row_group_rows=1000, + avg_item_bytes=1000.0, + ) + assert not shard.needs_external_sort(memory_limit=0) + + +# --------------------------------------------------------------------------- +# TableAccumulator +# --------------------------------------------------------------------------- + + +def test_table_accumulator_yields_on_threshold(): + """add() returns a merged table once accumulated bytes exceed the threshold.""" + threshold = 100 + acc = TableAccumulator(threshold) + # Small table well under threshold + small = pa.table({"x": [1]}) + assert acc.add(small) is None + + # Large table that pushes over threshold + big = pa.table({"x": list(range(1000))}) + result = acc.add(big) + assert result is not None + assert len(result) == 1 + 1000 + + +def test_table_accumulator_flush_returns_remaining(): + """flush() returns accumulated data that hasn't hit the threshold yet.""" + acc = TableAccumulator(10**9) + t = pa.table({"x": [1, 2, 3]}) + assert acc.add(t) is None + flushed = acc.flush() + assert flushed is not None + assert flushed.column("x").to_pylist() == [1, 2, 3] + + +def test_table_accumulator_flush_empty(): + """flush() returns None when nothing has been accumulated.""" + acc = TableAccumulator(100) + assert acc.flush() is None + + +def test_table_accumulator_resets_after_yield(): + """After add() yields a result, the accumulator is empty.""" + acc = TableAccumulator(1) # threshold of 1 byte = flush every add + t = pa.table({"x": [1]}) + result = acc.add(t) + assert result is not None + assert acc.flush() is None + + +# --------------------------------------------------------------------------- +# SpillWriter +# 
def test_spill_writer_creates_valid_parquet(tmp_path):
    """Everything written through SpillWriter round-trips via pq.read_table."""
    path = str(tmp_path / "out.parquet")
    schema = pa.schema([("val", pa.int64())])

    # row_group_bytes=1 forces a flush on every write_table call.
    with SpillWriter(path, schema, row_group_bytes=1) as writer:
        for start in range(0, 50, 10):
            writer.write_table(pa.table({"val": list(range(start, start + 10))}))

    assert pq.read_table(path).column("val").to_pylist() == list(range(50))


def test_spill_writer_row_group_byte_budget(tmp_path):
    """Row-group boundaries follow the byte budget rather than row counts."""
    path = str(tmp_path / "out.parquet")
    schema = pa.schema([("val", pa.int64())])

    # 100 int64 values ~= 800 bytes, so a 1000-byte budget should close a
    # row group roughly once per table written.
    with SpillWriter(path, schema, row_group_bytes=1000) as writer:
        for _ in range(5):
            writer.write_table(pa.table({"val": list(range(100))}))

    meta = pq.read_metadata(path)
    assert meta.num_row_groups >= 2, f"expected multiple row groups, got {meta.num_row_groups}"


def test_spill_writer_write_row_group_no_accumulation(tmp_path):
    """write_row_group bypasses buffering: one call, one row group."""
    path = str(tmp_path / "out.parquet")
    schema = pa.schema([("val", pa.int64())])

    with SpillWriter(path, schema, row_group_bytes=10**9) as writer:
        for value in range(3):
            writer.write_row_group(pa.table({"val": [value]}))

    assert pq.read_metadata(path).num_row_groups == 3


def test_spill_writer_close_flushes_remaining(tmp_path):
    """Data still buffered when the writer closes is written out."""
    path = str(tmp_path / "out.parquet")
    schema = pa.schema([("val", pa.int64())])

    # The budget is huge, so nothing flushes until the context manager exits.
    with SpillWriter(path, schema, row_group_bytes=10**9) as writer:
        writer.write_table(pa.table({"val": [1, 2, 3]}))

    assert pq.read_table(path).column("val").to_pylist() == [1, 2, 3]


def test_spill_writer_uses_zstd_compression(tmp_path):
    """SpillWriter defaults to zstd-compressed columns."""
    path = str(tmp_path / "out.parquet")
    schema = pa.schema([("val", pa.int64())])

    with SpillWriter(path, schema) as writer:
        writer.write_table(pa.table({"val": list(range(100))}))

    # Inspect the first column of the first row group.
    first_column = pq.read_metadata(path).row_group(0).column(0)
    assert "ZSTD" in first_column.compression.upper()


def test_spill_writer_context_manager_on_error(tmp_path):
    """An exception inside the with-block still leaves a readable file."""
    path = str(tmp_path / "out.parquet")
    schema = pa.schema([("val", pa.int64())])

    try:
        with SpillWriter(path, schema) as writer:
            writer.write_table(pa.table({"val": [1]}))
            raise ValueError("simulated error")
    except ValueError:
        pass

    # Data written before the error must survive the abnormal exit.
    assert pq.read_table(path).column("val").to_pylist() == [1]
whitespace_regex() -> &'static Regex { + WHITESPACE_RE.get_or_init(|| Regex::new(r"\s+").unwrap()) +} + /// Clean text using the SlimPajama text cleaning process. /// 1. Lowercase /// 2. Remove punctuation @@ -14,7 +20,7 @@ use xxhash_rust::xxh3; /// 4. Trim pub fn clean_text(arr: &StringArray) -> PyResult> { let mut builder = StringBuilder::with_capacity(arr.len(), arr.len() * 50); - let whitespace_re = Regex::new(r"\s+").map_err(|e| PyValueError::new_err(e.to_string()))?; + let whitespace_re = whitespace_regex(); let punctuation: &[char] = &[ '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', @@ -73,9 +79,15 @@ pub fn compute_minhash( let hash = xxh3::xxh3_64(text.as_bytes()) as u128; update_signature(&mut signature, hash, &coeffs); } else { + // Reusable buffer for encoding char windows to bytes, avoiding + // a String allocation per ngram. + let mut ngram_buf = Vec::with_capacity(ngram_size * 4); for window in chars.windows(ngram_size) { - let s: String = window.iter().collect(); - let hash = xxh3::xxh3_64(s.as_bytes()) as u128; + ngram_buf.clear(); + for &ch in window { + ngram_buf.extend_from_slice(ch.encode_utf8(&mut [0; 4]).as_bytes()); + } + let hash = xxh3::xxh3_64(&ngram_buf) as u128; update_signature(&mut signature, hash, &coeffs); } }