-
Notifications
You must be signed in to change notification settings - Fork 108
Expand file tree
/
Copy pathtest_transitions.py
More file actions
3608 lines (2807 loc) · 138 KB
/
test_transitions.py
File metadata and controls
3608 lines (2807 loc) · 138 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0
"""Tests for controller state management.
These tests exercise end-to-end observable behavior through the event-driven API (handle_event).
They focus on:
- Full workflows (submit job -> dispatch tasks -> complete/fail)
- Important edge cases (retry exhaustion, worker failure cascades, failure domains)
- Final state verification rather than intermediate steps
"""
import threading
from iris.cluster.constraints import DeviceType, WellKnownAttribute, constraints_from_resources
from iris.cluster.controller.autoscaler import DemandEntry
from iris.cluster.controller.controller import compute_demand_entries
from iris.cluster.controller.db import (
ControllerDB,
EndpointQuery,
attempt_is_terminal,
endpoint_query_sql,
)
from iris.cluster.controller.schema import (
ATTEMPT_PROJECTION,
ENDPOINT_PROJECTION,
JOB_DETAIL_PROJECTION,
TASK_DETAIL_PROJECTION,
WORKER_DETAIL_PROJECTION,
EndpointRow,
)
from iris.cluster.controller.scheduler import JobRequirements, Scheduler
from iris.cluster.controller.transitions import (
Assignment,
ControllerTransitions,
HEARTBEAT_FAILURE_THRESHOLD,
HeartbeatAction,
HeartbeatApplyRequest,
MAX_REPLICAS_PER_JOB,
PruneResult,
TaskUpdate,
)
from iris.cluster.types import JobName, WorkerId
from iris.rpc import job_pb2
from iris.rpc import controller_pb2
from iris.rpc import logging_pb2
from rigging.timing import Duration, Timestamp
from .conftest import (
building_counts as _building_counts,
check_task_can_be_scheduled,
check_task_is_finished,
dispatch_task,
fail_worker,
healthy_active_workers,
make_job_request,
make_test_entrypoint as _make_test_entrypoint,
make_worker_metadata,
query_job as _query_job,
query_task as _query_task,
query_tasks_for_job as _query_tasks_for_job,
query_worker as _query_worker,
register_worker,
schedulable_tasks as _schedulable_tasks,
submit_job,
transition_task,
worker_running_tasks,
)
# =============================================================================
# Test Helpers
# =============================================================================
def _queued_dispatch(
    state: ControllerTransitions, worker_id: WorkerId
) -> tuple[list[job_pb2.RunTaskRequest], list[str]]:
    """Read one worker's dispatch queue and split it into run requests and kill task ids.

    Rows are read in insertion order (``ORDER BY id ASC``). A row whose ``kind`` is
    ``"run"`` and that carries a payload is parsed into a ``RunTaskRequest``; any other
    row with a non-null ``task_id`` is reported as a kill.

    Returns:
        ``(run_requests, kill_task_ids)`` in queue order.
    """
    queued = state._db.fetchall(
        "SELECT kind, payload_proto, task_id FROM dispatch_queue WHERE worker_id = ? ORDER BY id ASC",
        (str(worker_id),),
    )
    run_requests: list[job_pb2.RunTaskRequest] = []
    kill_ids: list[str] = []
    for entry in queued:
        payload = entry["payload_proto"]
        if str(entry["kind"]) == "run" and payload is not None:
            request = job_pb2.RunTaskRequest()
            request.ParseFromString(bytes(payload))
            run_requests.append(request)
            continue
        if entry["task_id"] is not None:
            kill_ids.append(str(entry["task_id"]))
    return run_requests, kill_ids
def _endpoints(state: ControllerTransitions, query: EndpointQuery = EndpointQuery()) -> list[EndpointRow]:
    """Fetch endpoints matching ``query``, newest registration first (ties broken by endpoint_id)."""
    # NOTE(review): the default ``EndpointQuery()`` is built once at import time and shared by
    # every call that omits ``query``; this is only harmless if EndpointQuery is immutable.
    base_sql, bind_params = endpoint_query_sql(query)
    # Deterministic ordering so tests can index into the result.
    # NOTE(review): assumes endpoint_query_sql emits no ORDER BY/LIMIT of its own.
    ordered_sql = f"{base_sql} ORDER BY registered_at_ms DESC, endpoint_id ASC"
    with state._db.snapshot() as snap:
        raw_rows = snap.fetchall(ordered_sql, tuple(bind_params))
        return ENDPOINT_PROJECTION.decode(raw_rows)
def _build_scheduling_context(scheduler: Scheduler, state: ControllerTransitions):
    """Assemble a scheduler context from the current controller state.

    Gathers the schedulable tasks and healthy active workers, and builds one
    ``JobRequirements`` entry for each distinct parent job of those tasks.
    """
    schedulable = _schedulable_tasks(state)
    healthy_workers = healthy_active_workers(state)
    requirements: dict[JobName, JobRequirements] = {}
    for pending_task in schedulable:
        parent = pending_task.task_id.parent
        if not parent or parent in requirements:
            continue
        job = _query_job(state, parent)
        if not job:
            continue
        request = job.request
        coscheduled = request.HasField("coscheduling")
        requirements[parent] = JobRequirements(
            resources=request.resources,
            constraints=list(request.constraints),
            is_coscheduled=coscheduled,
            coscheduling_group_by=request.coscheduling.group_by if coscheduled else None,
        )
    return scheduler.create_scheduling_context(
        healthy_workers,
        building_counts=_building_counts(state),
        pending_tasks=[t.task_id for t in schedulable],
        jobs=requirements,
    )
def test_db_snapshot_select_returns_typed_rows(state) -> None:
    """Rows read in a snapshot and decoded via JOB_DETAIL_PROJECTION come back as typed job rows."""
    expected_job = JobName.root("test-user", "typed-rows")
    submitted = submit_job(state, "typed-rows", make_job_request("typed-rows"))
    wire_id = expected_job.to_wire()
    with state._db.snapshot() as q:
        job_rows = JOB_DETAIL_PROJECTION.decode(q.fetchall("SELECT * FROM jobs WHERE job_id = ?", (wire_id,)))
        task_count = q.fetchone("SELECT COUNT(*) FROM tasks WHERE job_id = ?", (wire_id,))[0]
    assert len(job_rows) == 1
    [job_row] = job_rows
    assert job_row.submitted_at is not None
    assert job_row.job_id == expected_job
    assert task_count == len(submitted)
def test_db_snapshot_projection_inferrs_typed_values(state) -> None:
    """A queued assignment appears as the single running task on its worker.

    NOTE(review): despite the name, this exercises assignment bookkeeping via
    ``worker_running_tasks`` rather than decoding a projection directly.
    """
    worker = register_worker(state, "proj-worker", "addr", make_worker_metadata())
    launch = controller_pb2.Controller.LaunchJobRequest(
        name=JobName.root("test-user", "projection").to_wire(),
        entrypoint=_make_test_entrypoint(),
        resources=job_pb2.ResourceSpecProto(cpu_millicores=1000, memory_bytes=1024**3),
        environment=job_pb2.EnvironmentConfig(),
        replicas=1,
    )
    (only_task,) = submit_job(state, "projection", launch)
    state.queue_assignments([Assignment(task_id=only_task.task_id, worker_id=worker)])
    on_worker = worker_running_tasks(state, worker)
    assert len(on_worker) == 1
    assert only_task.task_id in on_worker
def test_db_snapshot_exists_for_workers(state) -> None:
    """Registering a worker creates a row in the ``workers`` table."""
    register_worker(state, "exists-worker", "addr", make_worker_metadata())
    with state._db.snapshot() as q:
        found = q.fetchone("SELECT 1 FROM workers WHERE worker_id = ?", ("exists-worker",))
    assert found is not None
# =============================================================================
# Job/Task Lifecycle Integration Tests
# =============================================================================
def test_job_lifecycle_success(harness):
    """E2E: a two-replica job whose tasks all succeed ends SUCCEEDED with nothing left to schedule."""
    job_id = JobName.root("test-user", "j1")
    worker = harness.add_worker("w1")
    submitted = harness.submit("j1", replicas=2)
    assert len(submitted) == 2
    assert harness.query_job(job_id).state == job_pb2.JOB_STATE_PENDING

    for t in submitted:
        harness.dispatch(t, worker)
        harness.transition(t.task_id, job_pb2.TASK_STATE_SUCCEEDED)

    assert harness.query_job(job_id).state == job_pb2.JOB_STATE_SUCCEEDED
    assert all(harness.query_task(t.task_id).state == job_pb2.TASK_STATE_SUCCEEDED for t in submitted)
    assert not _schedulable_tasks(harness.state)
def test_job_lifecycle_failure_exhausted_retries(harness):
    """E2E: with no retry budget, a single task failure fails the whole job."""
    job_id = JobName.root("test-user", "j1")
    worker = harness.add_worker("w1")
    (only_task,) = harness.submit("j1")
    harness.dispatch(only_task, worker)
    harness.transition(only_task.task_id, job_pb2.TASK_STATE_FAILED, error="Task failed")

    final_task = harness.query_task(only_task.task_id)
    assert final_task.state == job_pb2.TASK_STATE_FAILED
    assert check_task_is_finished(final_task)
    assert harness.query_job(job_id).state == job_pb2.JOB_STATE_FAILED
def test_task_failure_with_retry_requeues(harness):
    """E2E: a failed task with retry budget left returns to PENDING while the job keeps RUNNING."""
    job_id = JobName.root("test-user", "j1")
    worker = harness.add_worker("w1")
    request = make_job_request("job1")
    request.max_task_failures = 1
    request.max_retries_failure = 1
    first = submit_job(harness.state, "j1", request)[0]
    harness.dispatch(first, worker)
    harness.transition(first.task_id, job_pb2.TASK_STATE_FAILED)

    requeued = harness.query_task(first.task_id)
    assert requeued.state == job_pb2.TASK_STATE_PENDING
    assert check_task_can_be_scheduled(requeued)
    assert harness.query_job(job_id).state == job_pb2.JOB_STATE_RUNNING

    # The retried task is the only one waiting to be scheduled.
    schedulable_ids = [t.task_id for t in _schedulable_tasks(harness.state)]
    assert schedulable_ids == [first.task_id]
def test_unschedulable_task_finalizes_job_with_timeout_error(harness):
    """E2E: an UNSCHEDULABLE task carries a scheduling-timeout error that becomes the job's final error."""
    job_id = JobName.root("test-user", "j1")
    worker = harness.add_worker("w1")
    (only_task,) = harness.submit("j1", scheduling_timeout_seconds=300)
    harness.dispatch(only_task, worker)
    harness.transition(only_task.task_id, job_pb2.TASK_STATE_UNSCHEDULABLE)

    expected_error = "Scheduling timeout exceeded"
    task_row = harness.query_task(only_task.task_id)
    job_row = harness.query_job(job_id)
    assert task_row.state == job_pb2.TASK_STATE_UNSCHEDULABLE
    assert task_row.error == expected_error
    assert job_row.state == job_pb2.JOB_STATE_UNSCHEDULABLE
    assert job_row.error == expected_error
def test_job_cancellation_kills_all_tasks(harness):
    """E2E: cancelling a job kills it and every task, whether dispatched or still pending."""
    job_id = JobName.root("test-user", "j1")
    worker = harness.add_worker("w1")
    submitted = harness.submit("j1", replicas=3)
    # Start two of the three tasks; the third stays pending.
    for started in submitted[:2]:
        harness.dispatch(started, worker)

    harness.state.cancel_job(job_id, reason="User cancelled")

    assert harness.query_job(job_id).state == job_pb2.JOB_STATE_KILLED
    assert all(harness.query_task(t.task_id).state == job_pb2.TASK_STATE_KILLED for t in submitted)
def test_cancel_job_releases_committed_worker_resources(harness):
    """cancel_job must decommit resources on workers that had active tasks.

    Regression: cancel_job marked tasks KILLED without calling _decommit_worker_resources.
    apply_heartbeat then skipped the update (task already finished), so committed resources
    were never released, permanently blocking scheduling on those workers.
    """
    w1 = harness.add_worker("w1")
    w2 = harness.add_worker("w2")
    tasks = harness.submit("j1", replicas=3)
    harness.dispatch(tasks[0], w1)
    harness.dispatch(tasks[1], w2)

    workers = {"w1": w1, "w2": w2}

    # Precondition: each worker holds one task's worth of CPU *and* memory. Checking every
    # counter on every worker ensures each post-cancel "== 0" check below proves a release,
    # not just a counter that was never incremented.
    for label, wid in workers.items():
        before = _query_worker(harness.state, wid)
        assert before.committed_cpu_millicores == 1000, f"{label} precondition: committed_cpu_millicores"
        assert before.committed_mem == 1024**3, f"{label} precondition: committed_mem"

    harness.state.cancel_job(JobName.root("test-user", "j1"), reason="User cancelled")

    for label, wid in workers.items():
        after = _query_worker(harness.state, wid)
        assert after.committed_cpu_millicores == 0, f"{label} leaked committed_cpu_millicores"
        assert after.committed_mem == 0, f"{label} leaked committed_mem"
        assert len(worker_running_tasks(harness.state, wid)) == 0
def test_cancel_job_preserves_kill_worker_mapping_after_clearing_tasks(harness):
    """cancel_job reports each killed task's worker even though it clears current_worker_id on the tasks."""
    first_worker = harness.add_worker("w1")
    second_worker = harness.add_worker("w2")
    t0, t1 = harness.submit("j1", replicas=2)
    harness.dispatch(t0, first_worker)
    harness.dispatch(t1, second_worker)

    outcome = harness.state.cancel_job(JobName.root("test-user", "j1"), reason="User cancelled")

    # Kill RPC routing must survive the clearing of the task -> worker link.
    expected_routing = {t0.task_id: first_worker, t1.task_id: second_worker}
    assert outcome.tasks_to_kill == set(expected_routing)
    assert outcome.task_kill_workers == expected_routing
    for t in (t0, t1):
        assert harness.query_task(t.task_id).current_worker_id is None
def test_cancel_job_removes_endpoints_for_job_tree(state):
    """Cancelling a parent job removes endpoints registered by it and by its child job."""
    parent_job = JobName.root("test-user", "parent")
    child_job = JobName.from_string("/test-user/parent/child")

    parent_worker = register_worker(state, "w1", "host1:8080", make_worker_metadata())
    child_worker = register_worker(state, "w2", "host2:8080", make_worker_metadata())

    parent_tasks = submit_job(state, "parent", make_job_request("parent"))
    child_req = make_job_request("child")
    child_req.name = child_job.to_wire()
    child_tasks = submit_job(state, "/test-user/parent/child", child_req)

    dispatch_task(state, parent_tasks[0], parent_worker)
    dispatch_task(state, child_tasks[0], child_worker)

    registrations = (
        ("parent-ep", "parent/actor", "host1:9000", parent_job, parent_tasks[0]),
        ("child-ep", "parent/child/actor", "host2:9000", child_job, child_tasks[0]),
    )
    for endpoint_id, name, address, owning_job, owner in registrations:
        state.add_endpoint(
            EndpointRow(
                endpoint_id=endpoint_id,
                name=name,
                address=address,
                job_id=owning_job,
                metadata={},
                registered_at=Timestamp.now(),
            ),
            task_id=owner.task_id,
        )

    assert len(_endpoints(state, EndpointQuery())) == 2
    state.cancel_job(parent_job, reason="User cancelled")
    assert _endpoints(state, EndpointQuery()) == []
def test_cancelled_job_tasks_excluded_from_demand(harness):
    """Regression for issue #2777: killed tasks, including never-attempted ones, must not produce demand entries."""
    job_id = JobName.root("test-user", "j1")
    worker = harness.add_worker("w1")
    submitted = harness.submit("j1", replicas=3)
    # Only the first task ever gets an attempt; the other two are killed while still pending.
    harness.dispatch(submitted[0], worker)
    harness.state.cancel_job(job_id, reason="User cancelled")

    assert harness.query_job(job_id).state == job_pb2.JOB_STATE_KILLED
    for t in submitted:
        row = harness.query_task(t.task_id)
        assert row.state == job_pb2.TASK_STATE_KILLED
        assert not check_task_can_be_scheduled(row)
    assert not _schedulable_tasks(harness.state)
    assert not compute_demand_entries(harness.state._db)
# =============================================================================
# Worker Failure Cascade Tests
# =============================================================================
def test_worker_failure_cascades_to_running_tasks(harness):
    """E2E: a failed worker is removed and its dispatched task is requeued as PENDING."""
    worker = harness.add_worker("w1")
    request = make_job_request("job1")
    request.max_retries_preemption = 1
    victim = submit_job(harness.state, "j1", request)[0]
    harness.dispatch(victim, worker)

    fail_worker(harness.state, worker, "Connection lost")

    assert _query_worker(harness.state, worker) is None
    after = harness.query_task(victim.task_id)
    assert after.state == job_pb2.TASK_STATE_PENDING
    assert check_task_can_be_scheduled(after)
    assert len(_schedulable_tasks(harness.state)) == 1
def test_failed_worker_is_pruned_from_state(state):
    """E2E: Worker failure removes worker from state, preventing dead worker accumulation."""
    w1 = register_worker(state, "w1", "host1:8080", make_worker_metadata())
    w2 = register_worker(state, "w2", "host2:8080", make_worker_metadata())
    req = make_job_request("job1")
    req.max_retries_preemption = 1
    tasks = submit_job(state, "j1", req)
    dispatch_task(state, tasks[0], w1)

    # Worker w1 fails.
    fail_worker(state, w1, "Connection lost")

    # w1 is gone from state entirely; w2 is still present.
    assert _query_worker(state, w1) is None
    assert _query_worker(state, w2) is not None

    # Only w2 remains in the workers table.
    with state._db.snapshot() as q:
        all_workers = WORKER_DETAIL_PROJECTION.decode(q.fetchall("SELECT * FROM workers"))
    assert len(all_workers) == 1
    assert all_workers[0].worker_id == w2

    # The task was requeued despite the worker removal. Re-query it: ``tasks[0]`` is the
    # value submit_job returned before dispatch (already PENDING then), so asserting on it
    # would not observe the requeue.
    requeued = _query_task(state, tasks[0].task_id)
    assert requeued.state == job_pb2.TASK_STATE_PENDING
    assert check_task_can_be_scheduled(requeued)

    # A re-registering worker creates a fresh, healthy entry.
    w1_again = register_worker(state, "w1", "host1:8080", make_worker_metadata())
    rejoined = _query_worker(state, w1_again)
    assert rejoined is not None
    assert rejoined.healthy is True
    with state._db.snapshot() as q:
        assert len(WORKER_DETAIL_PROJECTION.decode(q.fetchall("SELECT * FROM workers"))) == 2
def test_dispatch_failure_marks_worker_failed_and_requeues_task(state):
    """E2E: a dispatch RPC failure while the task is still ASSIGNED fails the worker and requeues the task.

    Because the task never reached BUILDING/RUNNING, this counts as a delivery failure:
    neither the preemption count nor the failure count is incremented.
    """
    worker_id = register_worker(state, "w1", "host:8080", make_worker_metadata())
    req = make_job_request("job1")
    req.max_retries_preemption = 1
    tasks = submit_job(state, "j1", req)
    task = tasks[0]

    # Assignment creates attempt 0 and puts the task in ASSIGNED.
    state.queue_assignments([Assignment(task_id=task.task_id, worker_id=worker_id)])
    assigned = _query_task(state, task.task_id)
    assert assigned.state == job_pb2.TASK_STATE_ASSIGNED
    assert assigned.current_attempt_id == 0

    # Dispatch RPC fails -> the worker is failed.
    fail_worker(state, worker_id, "Dispatch RPC failed: Connection refused")

    # 1. The worker is removed from state entirely (lookups return None).
    assert _query_worker(state, worker_id) is None

    # 2. The task is back in PENDING with no retry budget consumed.
    requeued = _query_task(state, task.task_id)
    assert requeued.state == job_pb2.TASK_STATE_PENDING
    assert requeued.preemption_count == 0
    assert requeued.failure_count == 0
    assert check_task_can_be_scheduled(requeued)

    # 3. It is the only task waiting to be scheduled.
    pending = _schedulable_tasks(state)
    assert len(pending) == 1
    assert pending[0].task_id == task.task_id
def test_task_assigned_to_missing_worker_is_ignored(state):
    """An assignment naming an already-removed worker is dropped; the task stays schedulable with no attempt."""
    ghost = register_worker(state, "w1", "host:8080", make_worker_metadata())
    target = submit_job(state, "j1", make_job_request("job1"))[0]

    # Simulate the worker vanishing between the scheduling decision and the commit.
    state.remove_worker(ghost)
    state.queue_assignments([Assignment(task_id=target.task_id, worker_id=ghost)])

    current = _query_task(state, target.task_id)
    assert current.state == job_pb2.TASK_STATE_PENDING
    assert current.current_attempt_id == -1
    assert check_task_can_be_scheduled(current)
    assert target.task_id in {t.task_id for t in _schedulable_tasks(state)}
# =============================================================================
# Failure Domain Tests (max_task_failures)
# =============================================================================
def test_failure_domain_kills_remaining_tasks(state):
    """E2E: with max_task_failures=0 the first task failure fails the job and kills its siblings."""
    worker = register_worker(state, "w1", "host:8080", make_worker_metadata())
    launch = controller_pb2.Controller.LaunchJobRequest(
        name="multi-task-job",
        entrypoint=_make_test_entrypoint(),
        resources=job_pb2.ResourceSpecProto(cpu_millicores=1000, memory_bytes=1024**3),
        environment=job_pb2.EnvironmentConfig(),
        max_task_failures=0,
        replicas=3,
    )
    first, second, third = submit_job(state, "j1", launch)
    job = _query_job(state, JobName.root("test-user", "j1"))

    # Two tasks start; the third is left pending.
    dispatch_task(state, first, worker)
    dispatch_task(state, second, worker)

    transition_task(state, first.task_id, job_pb2.TASK_STATE_FAILED, error="Task failed")

    assert _query_job(state, job.job_id).state == job_pb2.JOB_STATE_FAILED
    expected_states = {
        first.task_id: job_pb2.TASK_STATE_FAILED,
        second.task_id: job_pb2.TASK_STATE_KILLED,
        third.task_id: job_pb2.TASK_STATE_KILLED,
    }
    for task_id, want in expected_states.items():
        assert _query_task(state, task_id).state == want
def test_max_task_failures_tolerance(state):
    """E2E: with max_task_failures=1 the job survives one failed task and fails on the second."""
    worker = register_worker(state, "w1", "host:8080", make_worker_metadata())
    launch = controller_pb2.Controller.LaunchJobRequest(
        name="tolerant-job",
        entrypoint=_make_test_entrypoint(),
        resources=job_pb2.ResourceSpecProto(cpu_millicores=1000, memory_bytes=1024**3),
        replicas=3,
        environment=job_pb2.EnvironmentConfig(),
        max_task_failures=1,
    )
    submitted = submit_job(state, "j1", launch)
    job_id = _query_job(state, JobName.root("test-user", "j1")).job_id
    for t in submitted:
        dispatch_task(state, t, worker)

    def job_state():
        return _query_job(state, job_id).state

    # One failure is within tolerance.
    transition_task(state, submitted[0].task_id, job_pb2.TASK_STATE_FAILED, error="First")
    assert job_state() == job_pb2.JOB_STATE_RUNNING

    # After a success the job is still RUNNING while the third task is outstanding.
    transition_task(state, submitted[1].task_id, job_pb2.TASK_STATE_SUCCEEDED)
    assert job_state() == job_pb2.JOB_STATE_RUNNING

    # The second failure exceeds the threshold.
    transition_task(state, submitted[2].task_id, job_pb2.TASK_STATE_FAILED, error="Second")
    assert job_state() == job_pb2.JOB_STATE_FAILED
def test_preemption_does_not_count_toward_max_task_failures(state):
    """E2E: Worker failures (preemptions) don't count toward max_task_failures.

    With ``max_task_failures=0`` a single counted failure would fail the job, so the job
    staying RUNNING after a WORKER_FAILED update shows the preemption was not counted.
    """
    worker_id = register_worker(state, "w1", "host:8080", make_worker_metadata())
    req = controller_pb2.Controller.LaunchJobRequest(
        name="preemption-job",
        entrypoint=_make_test_entrypoint(),
        resources=job_pb2.ResourceSpecProto(cpu_millicores=1000, memory_bytes=1024**3),
        replicas=2,
        environment=job_pb2.EnvironmentConfig(),
        max_task_failures=0,
        max_retries_preemption=1,
    )
    tasks = submit_job(state, "j1", req)
    job = _query_job(state, JobName.root("test-user", "j1"))
    dispatch_task(state, tasks[0], worker_id)
    transition_task(state, tasks[0].task_id, job_pb2.TASK_STATE_WORKER_FAILED, error="Worker died")

    # The task is requeued to PENDING. Re-query it: ``tasks[0]`` is the value submit_job
    # returned before dispatch (already PENDING then), so asserting on it would not
    # observe the requeue.
    requeued = _query_task(state, tasks[0].task_id)
    assert requeued.state == job_pb2.TASK_STATE_PENDING
    assert check_task_can_be_scheduled(requeued)
    assert _query_job(state, job.job_id).state == job_pb2.JOB_STATE_RUNNING
# =============================================================================
# Endpoint Cleanup Tests
# =============================================================================
def test_terminal_states_clean_up_endpoints(state):
    """E2E: once a task succeeds, the endpoints it registered are gone."""
    worker = register_worker(state, "w1", "host:8080", make_worker_metadata())
    running = submit_job(state, "j1", make_job_request("job1"))[0]
    dispatch_task(state, running, worker)

    state.add_endpoint(
        EndpointRow(
            endpoint_id="ep1",
            name="j1/actor",
            address="a:1",
            job_id=JobName.root("test-user", "j1"),
            metadata={},
            registered_at=Timestamp.now(),
        ),
        running.task_id,
    )

    def visible():
        return _endpoints(state, EndpointQuery(exact_name="j1/actor"))

    # Listed while the task runs...
    assert len(visible()) == 1
    transition_task(state, running.task_id, job_pb2.TASK_STATE_SUCCEEDED)
    # ...and removed once it reaches a terminal state.
    assert visible() == []
def test_endpoint_visibility_by_job_state(state):
    """An endpoint stays listed while its task is pending or running and is deleted once the task is terminal."""
    worker = register_worker(state, "w1", "host:8080", make_worker_metadata())
    owner = submit_job(state, "ns-1", make_job_request("test"))[0]
    job_id = _query_job(state, JobName.root("test-user", "ns-1")).job_id
    state.add_endpoint(
        EndpointRow(
            endpoint_id="ep-1",
            name="ns-1/actor",
            address="10.0.0.1:8080",
            job_id=JobName.root("test-user", "ns-1"),
            metadata={},
            registered_at=Timestamp.now(),
        ),
        task_id=owner.task_id,
    )

    def endpoint_count() -> int:
        return len(_endpoints(state, EndpointQuery(exact_name="ns-1/actor")))

    # Pending task: endpoint listed.
    assert endpoint_count() == 1

    # Running task: still listed.
    dispatch_task(state, owner, worker)
    assert _query_job(state, job_id).state == job_pb2.JOB_STATE_RUNNING
    assert endpoint_count() == 1

    # Terminal task: endpoint removed.
    transition_task(state, owner.task_id, job_pb2.TASK_STATE_SUCCEEDED)
    assert _query_job(state, job_id).state == job_pb2.JOB_STATE_SUCCEEDED
    assert endpoint_count() == 0
def test_endpoint_deleted_on_task_failure_with_retry(state):
    """A failed task that retries back to PENDING still has its stale endpoints removed."""
    worker = register_worker(state, "w1", "host:8080", make_worker_metadata())
    request = make_job_request("test")
    request.max_retries_failure = 1
    owner = submit_job(state, "ns-1", request)[0]
    dispatch_task(state, owner, worker)

    state.add_endpoint(
        EndpointRow(
            endpoint_id="ep-1",
            name="ns-1/actor",
            address="10.0.0.1:8080",
            job_id=JobName.root("test-user", "ns-1"),
            metadata={},
            registered_at=Timestamp.now(),
        ),
        task_id=owner.task_id,
    )

    def endpoint_count() -> int:
        return len(_endpoints(state, EndpointQuery(exact_name="ns-1/actor")))

    assert endpoint_count() == 1

    # A failure with retry budget left sends the task back to PENDING...
    transition_task(state, owner.task_id, job_pb2.TASK_STATE_FAILED, error="crash")
    assert _query_task(state, owner.task_id).state == job_pb2.TASK_STATE_PENDING
    # ...but the endpoint from the failed attempt must not linger.
    assert endpoint_count() == 0
def test_endpoint_deleted_on_worker_failure(state):
    """Endpoints are removed when their worker dies, even though the task itself is retried."""
    worker = register_worker(state, "w1", "host:8080", make_worker_metadata())
    request = make_job_request("test")
    request.max_retries_preemption = 1
    owner = submit_job(state, "ns-1", request)[0]
    dispatch_task(state, owner, worker)

    state.add_endpoint(
        EndpointRow(
            endpoint_id="ep-1",
            name="ns-1/actor",
            address="10.0.0.1:8080",
            job_id=JobName.root("test-user", "ns-1"),
            metadata={},
            registered_at=Timestamp.now(),
        ),
        task_id=owner.task_id,
    )

    def endpoint_count() -> int:
        return len(_endpoints(state, EndpointQuery(exact_name="ns-1/actor")))

    assert endpoint_count() == 1

    # The worker dies; the task is requeued to PENDING...
    fail_worker(state, worker, "Connection lost")
    assert _query_task(state, owner.task_id).state == job_pb2.TASK_STATE_PENDING
    # ...and the endpoint pointing at the dead worker is gone.
    assert endpoint_count() == 0
def test_endpoint_survives_building_state(state):
    """An endpoint registered while its task is BUILDING survives the BUILDING -> RUNNING transition."""
    worker_id = register_worker(state, "w1", "host:8080", make_worker_metadata())
    req = make_job_request("test")
    tasks = submit_job(state, "ns-1", req)
    task = tasks[0]

    def report(new_state: int) -> None:
        """Apply one heartbeat update moving the task's current attempt to ``new_state``."""
        state.apply_task_updates(
            HeartbeatApplyRequest(
                worker_id=worker_id,
                worker_resource_snapshot=None,
                updates=[
                    TaskUpdate(
                        task_id=task.task_id,
                        # Read the attempt id at update time: ``task`` is the value from
                        # submit_job, captured before the assignment created an attempt.
                        attempt_id=_query_task(state, task.task_id).current_attempt_id,
                        new_state=new_state,
                    )
                ],
            )
        )

    def endpoint_count() -> int:
        return len(_endpoints(state, EndpointQuery(exact_name="ns-1/actor")))

    # Assign the task and move it to BUILDING.
    state.queue_assignments([Assignment(task_id=task.task_id, worker_id=worker_id)])
    report(job_pb2.TASK_STATE_BUILDING)

    # Register an endpoint during BUILDING (e.g. jax_init.py pre-registration).
    state.add_endpoint(
        EndpointRow(
            endpoint_id="ep-1",
            name="ns-1/actor",
            address="10.0.0.1:8080",
            job_id=JobName.root("test-user", "ns-1"),
            metadata={},
            registered_at=Timestamp.now(),
        ),
        task_id=task.task_id,
    )
    assert endpoint_count() == 1

    # Moving to RUNNING must not delete the endpoint.
    report(job_pb2.TASK_STATE_RUNNING)
    assert endpoint_count() == 1
def test_namespace_isolation(state):
    """E2E: an exact-name endpoint lookup returns only the endpoint from that namespace."""
    worker = register_worker(state, "w1", "host:8080", make_worker_metadata())
    first_req = make_job_request("test1")
    second_req = make_job_request("test2")
    first_tasks = submit_job(state, "ns-1", first_req)
    second_tasks = submit_job(state, "ns-2", second_req)

    # Dispatch one task per job so both jobs are RUNNING.
    dispatch_task(state, first_tasks[0], worker)
    dispatch_task(state, second_tasks[0], worker)

    registrations = (
        ("ep-1", "ns-1", "10.0.0.1:8080"),
        ("ep-2", "ns-2", "10.0.0.2:8080"),
    )
    for endpoint_id, ns, address in registrations:
        state.add_endpoint(
            EndpointRow(
                endpoint_id=endpoint_id,
                name=f"{ns}/actor",
                address=address,
                job_id=JobName.root("test-user", ns),
                metadata={},
                registered_at=Timestamp.now(),
            )
        )

    for _, ns, address in registrations:
        matches = _endpoints(state, EndpointQuery(exact_name=f"{ns}/actor"))
        assert len(matches) == 1
        assert matches[0].address == address
# =============================================================================
# Queue and Worker State Tests
# =============================================================================
def test_task_queue_fifo_order(state):
    """Schedulable tasks are returned in submission order."""
    first_req = make_job_request("job1")
    second_req = make_job_request("job2")
    submit_job(state, "j1", first_req)
    submit_job(state, "j2", second_req)

    queued_jobs = [t.job_id for t in _schedulable_tasks(state)]
    assert queued_jobs == [JobName.root("test-user", "j1"), JobName.root("test-user", "j2")]
def test_hierarchical_job_tracking(state):
    """Parent-child job relationships are tracked correctly.

    Builds parent -> {child1 -> grandchild, child2} and checks, via the jobs
    table's ``parent_job_id`` column, that a job lists only its direct children.
    """

    def children_of(job_name):
        # Decode inside the snapshot so rows are read under one consistent view.
        with state._db.snapshot() as q:
            rows = q.fetchall("SELECT * FROM jobs WHERE parent_job_id = ?", (job_name.to_wire(),))
            return JOB_DETAIL_PROJECTION.decode(rows)

    submit_job(state, "parent", make_job_request("parent"))
    for job_path, request_name in (
        ("/test-user/parent/child1", "child1"),
        ("/test-user/parent/child2", "child2"),
        ("/test-user/parent/child1/grandchild", "grandchild"),
    ):
        submit_job(state, job_path, make_job_request(request_name))

    # Querying by parent_job_id yields direct children only (no grandchild).
    direct_children = children_of(JobName.root("test-user", "parent"))
    assert len(direct_children) == 2
    assert {child.job_id for child in direct_children} == {
        JobName.from_string("/test-user/parent/child1"),
        JobName.from_string("/test-user/parent/child2"),
    }

    # A leaf job has no children.
    assert children_of(JobName.from_string("/test-user/parent/child1/grandchild")) == []
def test_thread_safety(state):
    """Concurrent job submission neither raises nor loses/duplicates jobs.

    ``num_threads`` threads are released together by a barrier, each submitting
    ``jobs_per_thread`` uniquely named jobs. Afterwards no thread may have raised,
    and every submitted job must appear exactly once among the schedulable tasks.
    """
    num_threads = 10
    jobs_per_thread = 50
    barrier = threading.Barrier(num_threads)
    errors = []

    def add_jobs(thread_id: int) -> None:
        try:
            # Start all submitters at the same moment to maximise contention.
            barrier.wait()
            for i in range(jobs_per_thread):
                job_id = f"t{thread_id}_j{i}"
                req = make_job_request(f"job-{job_id}")
                submit_job(state, job_id, req)
        except Exception as e:
            # Thread.join() does not re-raise worker exceptions, so collect them
            # here and assert on them from the main thread.
            errors.append(e)

    threads = [threading.Thread(target=add_jobs, args=(i,)) for i in range(num_threads)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    assert not errors, f"job submission raised in worker threads: {errors!r}"

    expected_job_ids = {
        JobName.root("test-user", f"t{thread_id}_j{i}")
        for thread_id in range(num_threads)
        for i in range(jobs_per_thread)
    }
    pending = _schedulable_tasks(state)
    assert len(pending) == num_threads * jobs_per_thread
    # The count alone could hide a race that duplicates one job while dropping
    # another, so also require the exact set of submitted jobs.
    assert {task.job_id for task in pending} == expected_job_ids
# =============================================================================
# Validation Tests
# =============================================================================
def test_excessive_replicas_fails_job(state):
    """E2E: Job with replicas exceeding MAX_REPLICAS_PER_JOB fails immediately.

    Submission yields no tasks. The job is recorded as FAILED with an error
    naming the limit, and nothing is left for the scheduler.
    """
    req = make_job_request("too-many-replicas")
    req.replicas = MAX_REPLICAS_PER_JOB + 1
    tasks = submit_job(state, "j1", req)

    # Fetch the job once and assert on that row; re-querying adds nothing.
    job = _query_job(state, JobName.root("test-user", "j1"))
    assert job is not None
    assert job.state == job_pb2.JOB_STATE_FAILED
    # Fail with a clear assertion (not a TypeError on `in None`) if no error was recorded.
    assert job.error, "failed job should carry an error message"
    assert f"exceeds max {MAX_REPLICAS_PER_JOB}" in job.error
    assert len(tasks) == 0
    assert len(_schedulable_tasks(state)) == 0
# =============================================================================
# Worker Resource Commitment Tests
# =============================================================================
def test_worker_cannot_accept_task_when_resources_committed(state):
    """E2E: A worker with committed resources cannot accept tasks that exceed remaining capacity.

    A 4-CPU worker runs a 3-CPU task. A second task needing 2 CPUs must stay
    pending because only 1 CPU remains.
    """
    four_cpu_worker = register_worker(state, "w1", "host:8080", make_worker_metadata(cpu=4))

    # Occupy 3 of the worker's 4 CPUs.
    running_tasks = submit_job(state, "j1", make_job_request(cpu=3))
    dispatch_task(state, running_tasks[0], four_cpu_worker)

    # Needs 2 CPUs while only 1 is free.
    submit_job(state, "j2", make_job_request(cpu=2))

    waiting = _schedulable_tasks(state)
    assert len(waiting) == 1
    assert waiting[0].job_id == JobName.root("test-user", "j2")

    scheduler = Scheduler()
    outcome = scheduler.find_assignments(_build_scheduling_context(scheduler, state))
    # No worker has enough free capacity, so nothing gets assigned.
    assert len(outcome.assignments) == 0
def test_worker_can_accept_new_task_after_previous_completes(state):
    """E2E: After a task completes, its resources are freed and new tasks can be scheduled.

    A 4-CPU worker runs a 3-CPU task, so a second 3-CPU task cannot be placed.
    Once the first task succeeds, the next scheduling cycle assigns the second.
    """
    worker = register_worker(state, "w1", "host:8080", make_worker_metadata(cpu=4))

    first_job_tasks = submit_job(state, "j1", make_job_request(cpu=3))
    dispatch_task(state, first_job_tasks[0], worker)

    # 3 + 3 CPUs exceeds the worker's 4, so j2 has to wait.
    submit_job(state, "j2", make_job_request(cpu=3))

    scheduler = Scheduler()

    def run_cycle():
        """Build a fresh context from current state and return this cycle's assignments."""
        return scheduler.find_assignments(_build_scheduling_context(scheduler, state)).assignments

    assert len(run_cycle()) == 0

    # Completing j1's task should release its committed CPUs.
    transition_task(state, first_job_tasks[0].task_id, job_pb2.TASK_STATE_SUCCEEDED)

    assignments = run_cycle()
    assert len(assignments) == 1
    assigned_task_id, _ = assignments[0]
    assert assigned_task_id.parent == JobName.root("test-user", "j2")
def test_multiple_small_tasks_fill_worker_capacity(state):
"""E2E: Multiple small tasks can fill a worker's capacity, blocking further assignments.
This verifies that the scheduler correctly tracks cumulative resource usage across
multiple running tasks. With round-robin scheduling, each worker gets at most one
task per cycle, so we run multiple cycles to fill capacity.
"""
# Worker with 4 CPUs
register_worker(state, "w1", "host:8080", make_worker_metadata(cpu=4))
# Submit 3 jobs, each using 2 CPUs
for i in range(3):
submit_job(state, f"j{i}", make_job_request(cpu=2))
scheduler = Scheduler()
# First scheduling cycle: 1 task assigned (round-robin: 1 per worker per cycle)
context = _build_scheduling_context(scheduler, state)
result = scheduler.find_assignments(context)
assert len(result.assignments) == 1
for task_id, worker_id in result.assignments:
task = _query_task(state, task_id)
dispatch_task(state, task, worker_id)