-
Notifications
You must be signed in to change notification settings - Fork 110
Expand file tree
/
Copy pathtest_direct_controller.py
More file actions
363 lines (281 loc) · 12 KB
/
test_direct_controller.py
File metadata and controls
363 lines (281 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0
"""Tests for KubernetesProvider integration with controller and transitions."""
from iris.cluster.controller.transitions import (
DirectProviderBatch,
DirectProviderSyncResult,
TaskUpdate,
)
from iris.cluster.types import JobName
from iris.rpc import cluster_pb2, logging_pb2
from iris.time_utils import Timestamp
from .conftest import (
make_direct_job_request,
query_attempt,
query_task,
submit_direct_job,
)
class FakeDirectProvider:
    """Test double standing in for a KubernetesProvider.

    Records every batch passed to sync() and answers with a canned
    result, so tests can inspect exactly what the controller handed
    to the provider. Log fetching and shutdown are inert.
    """

    def __init__(self):
        self.closed = False
        # Every batch passed to sync(), in call order.
        self.sync_calls: list[DirectProviderBatch] = []
        # Canned result returned by every sync() call; tests may swap it out.
        self.sync_result = DirectProviderSyncResult()

    def sync(self, batch: DirectProviderBatch) -> DirectProviderSyncResult:
        """Record *batch* and reply with the preconfigured result."""
        self.sync_calls.append(batch)
        return self.sync_result

    def fetch_live_logs(
        self,
        task_id: str,
        attempt_id: int,
        cursor: int,
        max_lines: int,
    ) -> tuple[list[logging_pb2.LogEntry], int]:
        """Report no log output; the cursor comes back unchanged."""
        return [], cursor

    def close(self) -> None:
        """Flag the provider as shut down."""
        self.closed = True
# =============================================================================
# Transition-level tests: drain_for_direct_provider
# =============================================================================
def test_drain_pending_creates_attempt_rows(state):
    """Draining promotes PENDING tasks to ASSIGNED (NULL worker_id) and opens attempt 0."""
    [task_id] = submit_direct_job(state, "drain-pending")
    assert query_task(state, task_id).state == cluster_pb2.TASK_STATE_PENDING

    drained = state.drain_for_direct_provider()
    assert len(drained.tasks_to_run) == 1
    run_req = drained.tasks_to_run[0]
    assert run_req.task_id == task_id.to_wire()
    assert run_req.attempt_id == 0

    promoted = query_task(state, task_id)
    assert promoted.state == cluster_pb2.TASK_STATE_ASSIGNED
    assert promoted.current_attempt_id == 0

    # The attempt row exists but is not bound to any worker.
    attempt = query_attempt(state, task_id, 0)
    assert attempt is not None
    assert attempt.worker_id is None
def test_drain_skips_already_assigned(state):
    """A task that is already ASSIGNED shows up in running_tasks, never tasks_to_run."""
    [task_id] = submit_direct_job(state, "drain-skip")

    # The first drain is the one that promotes the task to ASSIGNED.
    first = state.drain_for_direct_provider()
    assert len(first.tasks_to_run) == 1
    assert len(first.running_tasks) == 0

    # On the next drain the same task is reported as running, not re-launched.
    second = state.drain_for_direct_provider()
    assert len(second.tasks_to_run) == 0
    assert len(second.running_tasks) == 1
    assert second.running_tasks[0].task_id == task_id
def test_drain_kill_queue(state):
    """A kill buffered via buffer_direct_kill is surfaced in tasks_to_kill on the next drain."""
    [task_id] = submit_direct_job(state, "drain-kill")
    state.drain_for_direct_provider()  # promote the task to ASSIGNED first
    state.buffer_direct_kill(task_id.to_wire())
    drained = state.drain_for_direct_provider()
    assert task_id.to_wire() in drained.tasks_to_kill
# =============================================================================
# Transition-level tests: apply_direct_provider_updates
# =============================================================================
def test_apply_running(state):
    """A RUNNING update from the provider moves an ASSIGNED task to RUNNING."""
    [task_id] = submit_direct_job(state, "apply-running")
    attempt = state.drain_for_direct_provider().tasks_to_run[0].attempt_id

    update = TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=cluster_pb2.TASK_STATE_RUNNING)
    result = state.apply_direct_provider_updates([update])

    assert query_task(state, task_id).state == cluster_pb2.TASK_STATE_RUNNING
    assert not result.tasks_to_kill
def test_apply_succeeded(state):
    """Walking a task through RUNNING then SUCCEEDED leaves it terminal with exit code 0."""
    [task_id] = submit_direct_job(state, "apply-succeeded")
    attempt = state.drain_for_direct_provider().tasks_to_run[0].attempt_id

    # Apply the lifecycle transitions one update at a time.
    for next_state in (cluster_pb2.TASK_STATE_RUNNING, cluster_pb2.TASK_STATE_SUCCEEDED):
        state.apply_direct_provider_updates(
            [TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=next_state)]
        )

    finished = query_task(state, task_id)
    assert finished.state == cluster_pb2.TASK_STATE_SUCCEEDED
    assert finished.exit_code == 0
def test_apply_failed_with_retry(state):
    """A FAILED update with retry budget remaining sends the task back to PENDING."""
    from iris.cluster.controller.db import TASKS

    jid = JobName.root("test-user", "retry-job")
    req = make_direct_job_request("retry-job")
    req.max_retries_failure = 2
    state.submit_job(jid, req, Timestamp.now())

    # Look up the task id straight from the controller DB.
    with state._db.snapshot() as q:
        task_id = q.select(TASKS, where=TASKS.c.job_id == jid.to_wire())[0].task_id

    attempt = state.drain_for_direct_provider().tasks_to_run[0].attempt_id
    state.apply_direct_provider_updates(
        [TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=cluster_pb2.TASK_STATE_RUNNING)]
    )
    state.apply_direct_provider_updates(
        [TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=cluster_pb2.TASK_STATE_FAILED, error="boom")]
    )

    task = query_task(state, task_id)
    # failure_count (1) is within max_retries_failure (2), so the task is requeued.
    assert task.state == cluster_pb2.TASK_STATE_PENDING
    assert task.failure_count == 1
def test_apply_failed_no_retry(state):
    """A FAILED update with zero retry budget leaves the task terminally FAILED."""
    from iris.cluster.controller.db import TASKS

    jid = JobName.root("test-user", "no-retry-job")
    req = make_direct_job_request("no-retry-job")
    req.max_retries_failure = 0
    state.submit_job(jid, req, Timestamp.now())

    # Look up the task id straight from the controller DB.
    with state._db.snapshot() as q:
        task_id = q.select(TASKS, where=TASKS.c.job_id == jid.to_wire())[0].task_id

    attempt = state.drain_for_direct_provider().tasks_to_run[0].attempt_id
    state.apply_direct_provider_updates(
        [TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=cluster_pb2.TASK_STATE_RUNNING)]
    )
    state.apply_direct_provider_updates(
        [TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=cluster_pb2.TASK_STATE_FAILED, error="fatal")]
    )

    task = query_task(state, task_id)
    assert task.state == cluster_pb2.TASK_STATE_FAILED
    assert task.failure_count == 1
def test_apply_failed_directly_from_assigned(state):
    """A task may fail straight from ASSIGNED, skipping RUNNING entirely."""
    [task_id] = submit_direct_job(state, "fail-on-apply")
    attempt = state.drain_for_direct_provider().tasks_to_run[0].attempt_id

    failure = TaskUpdate(
        task_id=task_id,
        attempt_id=attempt,
        new_state=cluster_pb2.TASK_STATE_FAILED,
        error="kubectl apply failed: RequestEntityTooLarge",
    )
    state.apply_direct_provider_updates([failure])

    task = query_task(state, task_id)
    assert task.state == cluster_pb2.TASK_STATE_FAILED
    assert task.error == "kubectl apply failed: RequestEntityTooLarge"
def test_apply_worker_failed_from_running_retries(state):
    """WORKER_FAILED while RUNNING requeues the task and counts one preemption."""
    from iris.cluster.controller.db import TASKS

    jid = JobName.root("test-user", "wf-retry")
    req = make_direct_job_request("wf-retry")
    req.max_retries_preemption = 5
    state.submit_job(jid, req, Timestamp.now())

    # Look up the task id straight from the controller DB.
    with state._db.snapshot() as q:
        task_id = q.select(TASKS, where=TASKS.c.job_id == jid.to_wire())[0].task_id

    attempt = state.drain_for_direct_provider().tasks_to_run[0].attempt_id
    state.apply_direct_provider_updates(
        [TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=cluster_pb2.TASK_STATE_RUNNING)]
    )
    state.apply_direct_provider_updates(
        [TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=cluster_pb2.TASK_STATE_WORKER_FAILED)]
    )

    task = query_task(state, task_id)
    assert task.state == cluster_pb2.TASK_STATE_PENDING
    assert task.preemption_count == 1
def test_apply_worker_failed_from_assigned(state):
    """WORKER_FAILED on a task that never ran requeues it without charging a preemption."""
    [task_id] = submit_direct_job(state, "wf-assigned")
    # After drain the task is ASSIGNED, not RUNNING.
    attempt = state.drain_for_direct_provider().tasks_to_run[0].attempt_id

    update = TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=cluster_pb2.TASK_STATE_WORKER_FAILED)
    state.apply_direct_provider_updates([update])

    task = query_task(state, task_id)
    assert task.state == cluster_pb2.TASK_STATE_PENDING
    assert task.preemption_count == 0
def test_buffer_direct_kill(state):
    """buffer_direct_kill enqueues a kill row whose worker_id column is NULL."""
    state.buffer_direct_kill("some-task-id")
    rows = state._db.fetchall(
        "SELECT worker_id, kind, task_id FROM dispatch_queue WHERE worker_id IS NULL",
        (),
    )
    assert len(rows) == 1
    row = rows[0]
    assert row["worker_id"] is None
    assert row["kind"] == "kill"
    assert row["task_id"] == "some-task-id"
# =============================================================================
# Controller-level tests
# =============================================================================
def test_drain_multiple_tasks(state):
    """One drain call promotes every pending task of a multi-replica job."""
    task_ids = submit_direct_job(state, "multi-task", replicas=3)
    assert len(task_ids) == 3
    drained = state.drain_for_direct_provider()
    assert len(drained.tasks_to_run) == 3
    # Every submitted task id must appear among the promoted requests.
    assert {req.task_id for req in drained.tasks_to_run} == {tid.to_wire() for tid in task_ids}
def test_apply_ignores_stale_attempt(state):
    """An update carrying a wrong attempt_id is dropped without any effect."""
    [task_id] = submit_direct_job(state, "stale-attempt")
    attempt = state.drain_for_direct_provider().tasks_to_run[0].attempt_id

    stale = TaskUpdate(task_id=task_id, attempt_id=attempt + 99, new_state=cluster_pb2.TASK_STATE_RUNNING)
    result = state.apply_direct_provider_updates([stale])

    # The task never left ASSIGNED: the stale update was skipped.
    assert query_task(state, task_id).state == cluster_pb2.TASK_STATE_ASSIGNED
    assert not result.tasks_to_kill
def test_apply_ignores_finished_task(state):
    """Once a task has reached a terminal state, later updates are dropped."""
    [task_id] = submit_direct_job(state, "finished-task")
    attempt = state.drain_for_direct_provider().tasks_to_run[0].attempt_id

    # Drive the task to a successful finish.
    for next_state in (cluster_pb2.TASK_STATE_RUNNING, cluster_pb2.TASK_STATE_SUCCEEDED):
        state.apply_direct_provider_updates(
            [TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=next_state)]
        )

    # A late FAILED must not flip the terminal state.
    result = state.apply_direct_provider_updates(
        [TaskUpdate(task_id=task_id, attempt_id=attempt, new_state=cluster_pb2.TASK_STATE_FAILED)]
    )

    assert query_task(state, task_id).state == cluster_pb2.TASK_STATE_SUCCEEDED
    assert not result.tasks_to_kill