Skip to content

Commit 1af1f9e

Browse files
committed
Fix ZephyrContext not propagating explicit client to coordinator job
ZephyrContext.execute() submits a coordinator job that calls current_client() to auto-detect the backend. When a ZephyrContext was created with an explicit client (e.g., LocalClient()), the coordinator ignored it and auto-detected Ray, spawning workers as separate processes. This broke side-effect-based patterns like _load_fuzzy_dupe_map_shard, where a closure modifies a shared dict — each Ray actor got a serialized copy, leaving the original empty (zero fuzzy duplicates). Fix: set the context variable via set_current_client before submitting the coordinator job so it inherits the ZephyrContext's client. This also reverts the band-aid set_current_client wrappers in fuzzy.py, since ZephyrContext now handles propagation itself. Regression from b1d7828 (Refactor Zephyr coordinator to job).
1 parent 0234e37 commit 1af1f9e

2 files changed

Lines changed: 35 additions & 35 deletions

File tree

lib/marin/src/marin/processing/classification/deduplication/fuzzy.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import logging
88
from marin.utils import rebase_file_path
99
import pyarrow as pa
10-
from fray.v2.client import set_current_client
1110
from fray.v2.local_backend import LocalClient
1211
from marin.processing.classification.deduplication.dedup_commons import (
1312
DEFAULT_FILETYPES,
@@ -32,25 +31,23 @@
3231

3332
def _compute_fuzzy_dedup_stats(shards: list[str] | Sequence[str], method: str, level: str) -> DupCounters:
3433
with log_time(f"Compute fuzzy deduplication stats from {len(shards)} shards"):
35-
client = LocalClient()
36-
with set_current_client(client):
37-
ctx = ZephyrContext(client=client, name="fuzzy-dup-counts")
38-
result: DupCounters = ctx.execute( # type: ignore[bad-assignment]
39-
Dataset.from_list(shards)
40-
.load_parquet(columns=["component_id"])
41-
# Compute the per-component statistics and then roll them up into a single counter group
42-
.group_by(
43-
key=lambda r: r["component_id"],
44-
reducer=lambda _, items: DupCounters(
45-
method=method,
46-
level=level,
47-
total=(total := sum(1 for _ in items)),
48-
dups=total if total > 1 else 0,
49-
unique=1,
50-
),
51-
)
52-
.reduce(partial(sum, start=DupCounters(method=method, level=level))),
53-
)[0]
34+
ctx = ZephyrContext(client=LocalClient(), name="fuzzy-dup-counts")
35+
result: DupCounters = ctx.execute( # type: ignore[bad-assignment]
36+
Dataset.from_list(shards)
37+
.load_parquet(columns=["component_id"])
38+
# Compute the per-component statistics and then roll them up into a single counter group
39+
.group_by(
40+
key=lambda r: r["component_id"],
41+
reducer=lambda _, items: DupCounters(
42+
method=method,
43+
level=level,
44+
total=(total := sum(1 for _ in items)),
45+
dups=total if total > 1 else 0,
46+
unique=1,
47+
),
48+
)
49+
.reduce(partial(sum, start=DupCounters(method=method, level=level))),
50+
)[0]
5451
return result
5552

5653

@@ -66,12 +63,10 @@ def add_to_dup_map(record: dict):
6663
shard_dup_map[record["id"]] = record["fuzzy_duplicate"]
6764

6865
with log_time(f"Load fuzzy duplicate map from {len(shards)} shards"):
69-
client = LocalClient()
70-
with set_current_client(client):
71-
ctx = ZephyrContext(client=client, name="fuzzy-dup-map")
72-
ctx.execute(
73-
Dataset.from_list(shards).load_parquet().map(add_to_dup_map),
74-
)
66+
ctx = ZephyrContext(client=LocalClient(), name="fuzzy-dup-map")
67+
ctx.execute(
68+
Dataset.from_list(shards).load_parquet().map(add_to_dup_map),
69+
)
7570

7671
return shard_dup_map
7772

lib/zephyr/src/zephyr/execution.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,16 +1546,21 @@ def execute(
15461546
job_name = f"zephyr-{self.name}-p{self._pipeline_id}-a{attempt}"
15471547
# The wrapper job just blocks on child actors; real
15481548
# resources are requested by the coordinator/worker children.
1549-
self._coordinator_job = self.client.submit(
1550-
JobRequest(
1551-
name=job_name,
1552-
entrypoint=Entrypoint.from_callable(
1553-
_run_coordinator_job,
1554-
args=(config, result_path),
1555-
),
1556-
resources=ResourceConfig(cpu=1, ram="1g"),
1549+
# Set the context var so the coordinator job inherits self.client
1550+
# instead of auto-detecting (which may pick a different backend).
1551+
from fray.v2.client import set_current_client
1552+
1553+
with set_current_client(self.client):
1554+
self._coordinator_job = self.client.submit(
1555+
JobRequest(
1556+
name=job_name,
1557+
entrypoint=Entrypoint.from_callable(
1558+
_run_coordinator_job,
1559+
args=(config, result_path),
1560+
),
1561+
resources=ResourceConfig(cpu=1, ram="1g"),
1562+
)
15571563
)
1558-
)
15591564

15601565
backoff.reset()
15611566
logger.info("Coordinator job submitted: %s (job_id=%s)", job_name, self._coordinator_job.job_id)

0 commit comments

Comments
 (0)