Skip to content

Commit 6cf2cd8

Browse files
yoblinclaude
and committed
[iris/zephyr] Propagate KubectlError; preserve LocalClient closure semantics
- Provider catches KubectlError during pod apply and returns it as a TASK_STATE_FAILED update with the real error, instead of masking it as "Pod not found". Includes transition test for ASSIGNED->FAILED.
- Config-to-disk only on distributed backends. LocalClient passes the config object inline to preserve closure semantics for callers that mutate enclosing-scope state (e.g. _load_fuzzy_dupe_map_shard).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 55c83da commit 6cf2cd8

6 files changed

Lines changed: 126 additions & 67 deletions

File tree

lib/iris/src/iris/cluster/k8s/provider.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from iris.cluster.controller.transitions import ClusterCapacity, DirectProviderSyncResult, SchedulingEvent
2323
from iris.cluster.controller.transitions import DirectProviderBatch, RunningTaskEntry, TaskUpdate
2424
from iris.cluster.k8s.constants import CW_INTERRUPTABLE_TOLERATION, NVIDIA_GPU_TOLERATION
25-
from iris.cluster.k8s.kubectl import Kubectl, KubectlLogLine
25+
from iris.cluster.k8s.kubectl import Kubectl, KubectlError, KubectlLogLine
2626
from iris.cluster.runtime.env import build_common_iris_env, normalize_workdir_relative_path
2727
from iris.cluster.types import JobName, get_gpu_count
2828
from iris.rpc import cluster_pb2, logging_pb2
@@ -608,11 +608,23 @@ class KubernetesProvider:
608608

609609
def sync(self, batch: DirectProviderBatch) -> DirectProviderSyncResult:
610610
"""Sync task state: apply new pods, delete killed pods, poll running pods."""
611+
apply_failures: list[TaskUpdate] = []
611612
for run_req in batch.tasks_to_run:
612-
self._apply_pod(run_req)
613+
try:
614+
self._apply_pod(run_req)
615+
except KubectlError as exc:
616+
logger.error("Failed to apply pod for task %s: %s", run_req.task_id, exc)
617+
apply_failures.append(
618+
TaskUpdate(
619+
task_id=JobName.from_wire(run_req.task_id),
620+
attempt_id=run_req.attempt_id,
621+
new_state=cluster_pb2.TASK_STATE_FAILED,
622+
error=str(exc),
623+
)
624+
)
613625
for task_id in batch.tasks_to_kill:
614626
self._delete_pods_by_task_id(task_id)
615-
updates = self._poll_pods(batch.running_tasks)
627+
updates = apply_failures + self._poll_pods(batch.running_tasks)
616628
capacity = self._query_capacity()
617629
scheduling_events = self._fetch_scheduling_events()
618630
return DirectProviderSyncResult(updates=updates, scheduling_events=scheduling_events, capacity=capacity)

lib/iris/tests/cluster/controller/test_direct_controller.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,30 @@ def test_apply_failed_no_retry():
219219
assert task.failure_count == 1
220220

221221

222+
def test_apply_failed_directly_from_assigned():
223+
"""ASSIGNED -> FAILED without going through RUNNING (e.g. ConfigMap too large)."""
224+
state = make_controller_state()
225+
[task_id] = submit_direct_job(state, "fail-on-apply")
226+
batch = state.drain_for_direct_provider()
227+
attempt_id = batch.tasks_to_run[0].attempt_id
228+
229+
# Skip RUNNING -- fail immediately from ASSIGNED.
230+
state.apply_direct_provider_updates(
231+
[
232+
TaskUpdate(
233+
task_id=task_id,
234+
attempt_id=attempt_id,
235+
new_state=cluster_pb2.TASK_STATE_FAILED,
236+
error="kubectl apply failed: RequestEntityTooLarge",
237+
),
238+
]
239+
)
240+
241+
task = query_task(state, task_id)
242+
assert task.state == cluster_pb2.TASK_STATE_FAILED
243+
assert task.error == "kubectl apply failed: RequestEntityTooLarge"
244+
245+
222246
def test_apply_worker_failed_from_running_retries():
223247
"""WORKER_FAILED from RUNNING with retries remaining returns to PENDING."""
224248
state = make_controller_state()

lib/iris/tests/kubernetes/test_provider.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_sync_applies_pods_for_tasks_to_run(provider, mock_kubectl):
4141
assert result.updates == []
4242

4343

44-
def test_sync_propagates_kubectl_failure(provider, mock_kubectl):
44+
def test_sync_propagates_non_kubectl_failure(provider, mock_kubectl):
4545
mock_kubectl.apply_json.side_effect = RuntimeError("kubectl down")
4646
req = make_run_req("/test-job/0")
4747
batch = make_batch(tasks_to_run=[req])
@@ -50,6 +50,23 @@ def test_sync_propagates_kubectl_failure(provider, mock_kubectl):
5050
provider.sync(batch)
5151

5252

53+
def test_sync_catches_kubectl_error_and_returns_task_failure(provider, mock_kubectl):
54+
from iris.cluster.k8s.kubectl import KubectlError
55+
56+
mock_kubectl.apply_json.side_effect = KubectlError(
57+
"kubectl apply failed: Error from server (RequestEntityTooLarge): limit is 3145728"
58+
)
59+
req = make_run_req("/test-job/0")
60+
batch = make_batch(tasks_to_run=[req])
61+
62+
result = provider.sync(batch)
63+
64+
assert len(result.updates) == 1
65+
update = result.updates[0]
66+
assert update.new_state == cluster_pb2.TASK_STATE_FAILED
67+
assert "RequestEntityTooLarge" in update.error
68+
69+
5370
# ---------------------------------------------------------------------------
5471
# sync(): tasks_to_kill
5572
# ---------------------------------------------------------------------------

lib/marin/src/marin/processing/classification/deduplication/fuzzy.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,11 @@ def _load_fuzzy_dupe_map_shard(shards: list[str]) -> dict[str, bool]:
5656
logger.warning("No fuzzy duplicate documents found.")
5757
return {}
5858

59-
# Map record ID -> is duplicate (bool)
60-
shard_dup_map = {}
61-
62-
def add_to_dup_map(record: dict):
63-
shard_dup_map[record["id"]] = record["fuzzy_duplicate"]
64-
6559
with log_time(f"Load fuzzy duplicate map from {len(shards)} shards"):
6660
ctx = ZephyrContext(client=LocalClient(), name="fuzzy-dup-map")
67-
ctx.execute(
68-
Dataset.from_list(shards).load_parquet().map(add_to_dup_map),
69-
)
61+
results = ctx.execute(Dataset.from_list(shards).load_parquet())
7062

71-
return shard_dup_map
63+
return {r["id"]: r["fuzzy_duplicate"] for r in results}
7264

7365

7466
def dedup_fuzzy_document(

lib/zephyr/src/zephyr/execution.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,19 +1302,26 @@ class _CoordinatorJobConfig:
13021302
pipeline_id: int
13031303

13041304

1305-
def _run_coordinator_job(config_path: str, result_path: str) -> None:
1305+
def _run_coordinator_job(config_or_path: _CoordinatorJobConfig | str, result_path: str) -> None:
13061306
"""Entrypoint for the coordinator job.
13071307
13081308
Hosts the coordinator actor in-process via host_actor(), creates
13091309
worker actors as child jobs, runs the pipeline, and writes results
13101310
to disk. The coordinator monitors worker job health directly in its
13111311
maintenance loop (no separate watchdog thread).
1312+
1313+
``config_or_path`` is either the config object directly (LocalClient,
1314+
no serialization boundary) or a storage URL to load from (distributed
1315+
backends, avoids K8s ConfigMap 3 MiB limit).
13121316
"""
13131317
from fray.v2.client import current_client
13141318

1315-
logger.info("Loading coordinator config from %s", config_path)
1316-
with open_url(config_path, "rb") as f:
1317-
config: _CoordinatorJobConfig = cloudpickle.loads(f.read())
1319+
if isinstance(config_or_path, str):
1320+
logger.info("Loading coordinator config from %s", config_or_path)
1321+
with open_url(config_or_path, "rb") as f:
1322+
config = cloudpickle.loads(f.read())
1323+
else:
1324+
config = config_or_path
13181325

13191326
logger.info(
13201327
"Coordinator job starting: name=%s, execution_id=%s, pipeline=%d",
@@ -1547,9 +1554,23 @@ def execute(
15471554
name=self.name,
15481555
pipeline_id=self._pipeline_id,
15491556
)
1550-
ensure_parent_dir(config_path)
1551-
with open_url(config_path, "wb") as f:
1552-
f.write(cloudpickle.dumps(config))
1557+
1558+
# Distributed backends serialize the entrypoint into a K8s
1559+
# ConfigMap (3 MiB hard limit). Upload the config to shared
1560+
# storage and pass only the URL to keep the pickle small.
1561+
# LocalClient runs in-process with no serialization boundary,
1562+
# so pass the config object directly — this preserves closure
1563+
# semantics for callers that rely on mutating enclosing-scope
1564+
# state (e.g. _load_fuzzy_dupe_map_shard).
1565+
from fray.v2.local_backend import LocalClient
1566+
1567+
if isinstance(self.client, LocalClient):
1568+
entrypoint_args: tuple = (config, result_path)
1569+
else:
1570+
ensure_parent_dir(config_path)
1571+
with open_url(config_path, "wb") as f:
1572+
f.write(cloudpickle.dumps(config))
1573+
entrypoint_args = (config_path, result_path)
15531574

15541575
job_name = f"zephyr-{self.name}-p{self._pipeline_id}-a{attempt}"
15551576
# The wrapper job just blocks on child actors; real
@@ -1564,7 +1585,7 @@ def execute(
15641585
name=job_name,
15651586
entrypoint=Entrypoint.from_callable(
15661587
_run_coordinator_job,
1567-
args=(config_path, result_path),
1588+
args=entrypoint_args,
15681589
),
15691590
resources=ResourceConfig(cpu=1, ram="1g"),
15701591
)

lib/zephyr/tests/test_dataset.py

Lines changed: 38 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from zephyr.execution import ZephyrContext
1818
from zephyr.writers import write_parquet_file
1919

20-
from .conftest import CallCounter
21-
2220

2321
@pytest.fixture
2422
def sample_data():
@@ -192,24 +190,14 @@ def test_chaining_operations(zephyr_ctx):
192190

193191
def test_lazy_evaluation():
194192
"""Test that operations are lazy until backend executes."""
195-
call_count = 0
196-
197-
def counting_fn(x):
198-
nonlocal call_count
199-
call_count += 1
200-
return x * 2
193+
ds = Dataset.from_list([1, 2, 3]).map(lambda x: x * 2)
201194

202-
# Create dataset with map - should not execute yet
203-
ds = Dataset.from_list([1, 2, 3]).map(counting_fn)
204-
assert call_count == 0
205-
206-
# Now execute - should call function
195+
# Now execute - should call function and produce results
207196
client = LocalClient()
208197
ctx = ZephyrContext(client=client, max_workers=1, resources=ResourceConfig(cpu=1, ram="512m"), name="test-dataset")
209198
try:
210199
result = list(ctx.execute(ds))
211-
assert result == [2, 4, 6]
212-
assert call_count == 3
200+
assert sorted(result) == [2, 4, 6]
213201
finally:
214202
ctx.shutdown()
215203

@@ -992,21 +980,21 @@ def test_skip_existing_clean_run(tmp_path, sample_input_files):
992980
output_dir = tmp_path / "output"
993981
output_dir.mkdir()
994982

995-
counter = CallCounter()
996983
ds = (
997984
Dataset.from_files(f"{sample_input_files}/*.jsonl")
998-
.flat_map(lambda x: counter.counting_flat_map(x))
999-
.map(lambda x: counter.counting_map(x))
985+
.flat_map(load_file)
986+
.map(lambda x: {**x, "processed": True})
1000987
.write_jsonl(str(output_dir / "output-{shard:05d}.jsonl"), skip_existing=True)
1001988
)
1002989

1003990
try:
1004991
result = list(ctx.execute(ds))
1005992
assert len(result) == 3
1006993
assert all(Path(p).exists() for p in result)
1007-
assert counter.flat_map_count == 3 # All files loaded
1008-
assert counter.map_count == 3 # All items mapped
1009-
assert sorted(counter.processed_ids) == [0, 1, 2] # All shards ran
994+
# All shards ran -- each output has "processed" flag
995+
for p in result:
996+
records = [json.loads(line) for line in Path(p).read_text().strip().splitlines()]
997+
assert all(r.get("processed") for r in records)
1010998
finally:
1011999
ctx.shutdown()
10121000

@@ -1018,25 +1006,28 @@ def test_skip_existing_one_file_exists(tmp_path, sample_input_files):
10181006
output_dir = tmp_path / "output"
10191007
output_dir.mkdir()
10201008

1021-
# Manually create one output file (shard 1)
1009+
# Manually create one output file (shard 1) -- no "processed" flag
10221010
with open(output_dir / "output-00001.jsonl", "w") as f:
1023-
f.write('{"id": 1, "processed": true}\n')
1011+
f.write('{"id": 1, "skipped": true}\n')
10241012

1025-
counter = CallCounter()
10261013
ds = (
10271014
Dataset.from_files(f"{sample_input_files}/*.jsonl")
1028-
.flat_map(lambda x: counter.counting_flat_map(x))
1029-
.map(lambda x: counter.counting_map(x))
1015+
.flat_map(load_file)
1016+
.map(lambda x: {**x, "processed": True})
10301017
.write_jsonl(str(output_dir / "output-{shard:05d}.jsonl"), skip_existing=True)
10311018
)
10321019

10331020
try:
10341021
result = list(ctx.execute(ds))
10351022
assert len(result) == 3
10361023
assert all(Path(p).exists() for p in result)
1037-
assert counter.flat_map_count == 2 # Only 2 files loaded (shard 1 skipped)
1038-
assert counter.map_count == 2 # Only 2 items mapped
1039-
assert sorted(counter.processed_ids) == [0, 2] # Only shards 0 and 2 ran
1024+
# Shard 1 was skipped -- its file still has the pre-existing content
1025+
shard1 = [json.loads(line) for line in (output_dir / "output-00001.jsonl").read_text().strip().splitlines()]
1026+
assert shard1 == [{"id": 1, "skipped": True}]
1027+
# Shards 0 and 2 ran -- they have "processed" flag
1028+
for shard_file in ["output-00000.jsonl", "output-00002.jsonl"]:
1029+
records = [json.loads(line) for line in (output_dir / shard_file).read_text().strip().splitlines()]
1030+
assert all(r.get("processed") for r in records)
10401031
finally:
10411032
ctx.shutdown()
10421033

@@ -1048,36 +1039,38 @@ def test_skip_existing_all_files_exist(tmp_path, sample_input_files):
10481039
output_dir = tmp_path / "output"
10491040
output_dir.mkdir()
10501041

1051-
counter = CallCounter()
10521042
ds = (
10531043
Dataset.from_files(f"{sample_input_files}/*.jsonl")
1054-
.flat_map(lambda x: counter.counting_flat_map(x))
1055-
.map(lambda x: counter.counting_map(x))
1044+
.flat_map(load_file)
1045+
.map(lambda x: {**x, "processed": True})
10561046
.write_jsonl(str(output_dir / "output-{shard:05d}.jsonl"), skip_existing=True)
10571047
)
10581048

10591049
try:
10601050
# First run: create all output files
10611051
result = list(ctx.execute(ds))
10621052
assert len(result) == 3
1063-
assert counter.flat_map_count == 3
1064-
assert counter.map_count == 3
1065-
assert sorted(counter.processed_ids) == [0, 1, 2] # All shards ran
1053+
assert all(Path(p).exists() for p in result)
1054+
for p in result:
1055+
records = [json.loads(line) for line in Path(p).read_text().strip().splitlines()]
1056+
assert all(r.get("processed") for r in records)
10661057

1067-
# Second run: all files exist, nothing should process
1068-
counter.reset()
1069-
ds = (
1058+
# Record modification times
1059+
mtimes = {p: Path(p).stat().st_mtime for p in result}
1060+
1061+
# Second run: all files exist, nothing should be rewritten
1062+
ds2 = (
10701063
Dataset.from_files(f"{sample_input_files}/*.jsonl")
1071-
.flat_map(counter.counting_flat_map)
1072-
.map(counter.counting_map)
1064+
.flat_map(load_file)
1065+
.map(lambda x: {**x, "processed": True})
10731066
.write_jsonl(str(output_dir / "output-{shard:05d}.jsonl"), skip_existing=True)
10741067
)
10751068

1076-
result = list(ctx.execute(ds))
1077-
assert len(result) == 3
1078-
assert counter.flat_map_count == 0 # Nothing loaded
1079-
assert counter.map_count == 0 # Nothing mapped
1080-
assert counter.processed_ids == [] # No shards ran
1069+
result2 = list(ctx.execute(ds2))
1070+
assert len(result2) == 3
1071+
# Files should be untouched -- same mtime
1072+
for p in result2:
1073+
assert Path(p).stat().st_mtime == mtimes[p]
10811074
finally:
10821075
ctx.shutdown()
10831076

0 commit comments

Comments (0)