Skip to content

Commit 6fb4676

Browse files
Add drain and resume mutation for build-id rollout state
Operators could inspect per-build-id rollout state via Client.list_task_queue_build_ids, but the async and sync clients had no way to cut traffic off a build before deleting its workers. Without that mutation, rollouts had to stop each worker process by hand, which is easy to mis-sequence when rolling back a bad build or cutting over from unversioned to versioned workers. Add two coroutine methods on the async Client plus their sync mirrors that wrap the new server endpoints: - drain_task_queue_build_id(task_queue, build_id) POSTs to /task-queues/<tq>/build-ids/drain with {"build_id": ...}. - resume_task_queue_build_id(task_queue, build_id) POSTs to /task-queues/<tq>/build-ids/resume with {"build_id": ...}. Passing None as build_id targets the unversioned cohort (pre-rollout default). Both calls are idempotent: repeated drains do not shift the recorded drained_at timestamp, and repeated resumes leave the cohort active. Both methods return a TaskQueueBuildIdRolloutState dataclass carrying the server's mutation receipt (namespace, task_queue, build_id, drain_intent, drained_at) so operators can verify the outcome programmatically. The raw server payload is preserved on the dataclass for callers that need a byte-identical mirror. Two new polyglot parity fixtures cover the cross-repo contract with the CLI under tests/fixtures/control-plane/: - task-queue-build-id-drain-parity.json - task-queue-build-id-resume-parity.json The matching CLI-side fixtures and parity test will land separately so the cross-repo drift gate stays green.
1 parent 810807c commit 6fb4676

7 files changed

Lines changed: 354 additions & 0 deletions

File tree

src/durable_workflow/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
TaskQueueAdmission,
3535
TaskQueueBuildIdCohort,
3636
TaskQueueBuildIdRollout,
37+
TaskQueueBuildIdRolloutState,
3738
TaskQueueDescription,
3839
TaskQueueList,
3940
TaskQueueQueryAdmission,
@@ -181,6 +182,7 @@
181182
"TaskQueueAdmission",
182183
"TaskQueueBuildIdCohort",
183184
"TaskQueueBuildIdRollout",
185+
"TaskQueueBuildIdRolloutState",
184186
"TaskQueueDescription",
185187
"TaskQueueList",
186188
"TaskQueueQueryAdmission",

src/durable_workflow/client.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,36 @@ def from_dict(cls, data: dict[str, Any]) -> TaskQueueBuildIdRollout:
534534
)
535535

536536

537+
@dataclass
class TaskQueueBuildIdRolloutState:
    """Operator-recorded drain intent for one ``(task_queue, build_id)`` cohort.

    Returned by ``drain_task_queue_build_id`` and ``resume_task_queue_build_id``.
    ``build_id`` is ``None`` for the unversioned cohort (workers registered
    without a build identifier). ``drain_intent`` is ``"active"`` or
    ``"draining"``. ``drained_at`` is set only when ``drain_intent`` is
    ``"draining"``; repeated drains do not shift the timestamp.
    """

    namespace: str | None
    task_queue: str
    build_id: str | None
    drain_intent: str
    drained_at: str | None
    # Raw server payload, preserved for callers needing a byte-identical mirror.
    raw: dict[str, Any] | None = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> TaskQueueBuildIdRolloutState:
        """Build a state record from a raw server payload.

        Missing or wrongly-typed optional fields collapse to ``None``;
        required strings default to ``""``. ``data`` itself is kept on ``raw``.
        """

        def optional_str(value: Any) -> str | None:
            # Non-string values (including null) are treated as absent.
            return value if isinstance(value, str) else None

        return cls(
            namespace=data.get("namespace"),
            task_queue=str(data.get("task_queue") or ""),
            build_id=optional_str(data.get("build_id")),
            drain_intent=str(data.get("drain_intent") or ""),
            drained_at=optional_str(data.get("drained_at")),
            raw=data,
        )
565+
566+
537567
@dataclass
538568
class WorkerDescription:
539569
"""Current server view of one registered worker."""
@@ -1392,6 +1422,67 @@ async def list_task_queue_build_ids(self, task_queue: str) -> TaskQueueBuildIdRo
13921422
)
13931423
return TaskQueueBuildIdRollout.from_dict(data)
13941424

1425+
async def drain_task_queue_build_id(
1426+
self,
1427+
task_queue: str,
1428+
build_id: str | None,
1429+
) -> TaskQueueBuildIdRolloutState:
1430+
"""Mark a build-id cohort as draining so it stops claiming new tasks.
1431+
1432+
Workers registered under ``build_id`` keep running their in-flight
1433+
work but are blocked from claiming fresh tasks, and future workers
1434+
that heartbeat under the same ``build_id`` land as draining too.
1435+
Pass ``None`` to drain the unversioned cohort (workers registered
1436+
without a build identifier). Idempotent: repeated drains do not
1437+
shift the recorded ``drained_at`` timestamp.
1438+
"""
1439+
return await self._mutate_task_queue_build_id_rollout(
1440+
task_queue,
1441+
build_id,
1442+
action="drain",
1443+
)
1444+
1445+
async def resume_task_queue_build_id(
1446+
self,
1447+
task_queue: str,
1448+
build_id: str | None,
1449+
) -> TaskQueueBuildIdRolloutState:
1450+
"""Revert a previous drain so a build-id cohort can claim work again.
1451+
1452+
Resuming clears both ``drain_intent`` and ``drained_at`` for the
1453+
cohort and flips any still-running workers back to ``active``.
1454+
Pass ``None`` to resume the unversioned cohort. Idempotent:
1455+
resuming an already-active cohort is a no-op.
1456+
"""
1457+
return await self._mutate_task_queue_build_id_rollout(
1458+
task_queue,
1459+
build_id,
1460+
action="resume",
1461+
)
1462+
1463+
async def _mutate_task_queue_build_id_rollout(
1464+
self,
1465+
task_queue: str,
1466+
build_id: str | None,
1467+
*,
1468+
action: str,
1469+
) -> TaskQueueBuildIdRolloutState:
1470+
data = await self._request(
1471+
"POST",
1472+
f"/task-queues/{quote(task_queue, safe='')}/build-ids/{action}",
1473+
json={"build_id": build_id},
1474+
context=task_queue,
1475+
)
1476+
if not isinstance(data, dict):
1477+
raise ServerError(
1478+
200,
1479+
{
1480+
"reason": f"invalid_task_queue_build_id_{action}_response",
1481+
"message": f"expected JSON object, got {type(data).__name__}",
1482+
},
1483+
)
1484+
return TaskQueueBuildIdRolloutState.from_dict(data)
1485+
13951486
# ── Search attributes ─────────────────────────────────────────────
13961487
async def list_search_attributes(self) -> SearchAttributeList:
13971488
"""List system and custom search attribute definitions for this namespace."""

src/durable_workflow/sync.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ScheduleTriggerResult,
1919
StorageTestResult,
2020
TaskQueueBuildIdRollout,
21+
TaskQueueBuildIdRolloutState,
2122
TaskQueueDescription,
2223
TaskQueueList,
2324
WorkflowCommandResult,
@@ -291,6 +292,26 @@ def list_task_queue_build_ids(self, task_queue: str) -> TaskQueueBuildIdRollout:
291292
)
292293
return result
293294

295+
def drain_task_queue_build_id(
296+
self,
297+
task_queue: str,
298+
build_id: str | None,
299+
) -> TaskQueueBuildIdRolloutState:
300+
result: TaskQueueBuildIdRolloutState = _run(
301+
self._async.drain_task_queue_build_id(task_queue, build_id)
302+
)
303+
return result
304+
305+
def resume_task_queue_build_id(
306+
self,
307+
task_queue: str,
308+
build_id: str | None,
309+
) -> TaskQueueBuildIdRolloutState:
310+
result: TaskQueueBuildIdRolloutState = _run(
311+
self._async.resume_task_queue_build_id(task_queue, build_id)
312+
)
313+
return result
314+
294315
def list_namespaces(self) -> NamespaceList:
295316
result: NamespaceList = _run(self._async.list_namespaces())
296317
return result
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
{
2+
"schema": "durable-workflow.polyglot.control-plane-request-fixture",
3+
"version": 1,
4+
"operation": "task_queue.build_id.drain",
5+
"request": {
6+
"method": "POST",
7+
"path": "/task-queues/orders-critical/build-ids/drain",
8+
"body": {
9+
"build_id": "build-2026.04.21-z9"
10+
}
11+
},
12+
"semantic_body": {
13+
"namespace": "orders-prod",
14+
"task_queue": "orders-critical",
15+
"build_id": "build-2026.04.21-z9",
16+
"drain_intent": "draining",
17+
"drained_at": "2026-04-22T09:45:00Z"
18+
},
19+
"response_body": {
20+
"namespace": "orders-prod",
21+
"task_queue": "orders-critical",
22+
"build_id": "build-2026.04.21-z9",
23+
"drain_intent": "draining",
24+
"drained_at": "2026-04-22T09:45:00Z"
25+
},
26+
"cli": {
27+
"argv": {
28+
"task-queue": "orders-critical",
29+
"--build-id": "build-2026.04.21-z9",
30+
"--json": true
31+
}
32+
},
33+
"sdk_python": {
34+
"method": "drain_task_queue_build_id",
35+
"args": {
36+
"task_queue": "orders-critical",
37+
"build_id": "build-2026.04.21-z9"
38+
}
39+
}
40+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
{
2+
"schema": "durable-workflow.polyglot.control-plane-request-fixture",
3+
"version": 1,
4+
"operation": "task_queue.build_id.resume",
5+
"request": {
6+
"method": "POST",
7+
"path": "/task-queues/orders-critical/build-ids/resume",
8+
"body": {
9+
"build_id": "build-2026.04.21-z9"
10+
}
11+
},
12+
"semantic_body": {
13+
"namespace": "orders-prod",
14+
"task_queue": "orders-critical",
15+
"build_id": "build-2026.04.21-z9",
16+
"drain_intent": "active",
17+
"drained_at": null
18+
},
19+
"response_body": {
20+
"namespace": "orders-prod",
21+
"task_queue": "orders-critical",
22+
"build_id": "build-2026.04.21-z9",
23+
"drain_intent": "active",
24+
"drained_at": null
25+
},
26+
"cli": {
27+
"argv": {
28+
"task-queue": "orders-critical",
29+
"--build-id": "build-2026.04.21-z9",
30+
"--json": true
31+
}
32+
},
33+
"sdk_python": {
34+
"method": "resume_task_queue_build_id",
35+
"args": {
36+
"task_queue": "orders-critical",
37+
"build_id": "build-2026.04.21-z9"
38+
}
39+
}
40+
}

tests/test_client.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,116 @@ async def test_list_task_queue_build_ids_surfaces_cohort_worker_counts(
11111111
assert unversioned.last_heartbeat_at is None
11121112
assert unversioned.first_seen_at is None
11131113

1114+
@pytest.mark.asyncio
1115+
async def test_drain_task_queue_build_id_matches_polyglot_fixture(self, client: Client) -> None:
1116+
fixture_path = (
1117+
Path(__file__).parent
1118+
/ "fixtures"
1119+
/ "control-plane"
1120+
/ "task-queue-build-id-drain-parity.json"
1121+
)
1122+
fixture = json.loads(fixture_path.read_text())
1123+
assert fixture["operation"] == "task_queue.build_id.drain"
1124+
sdk = fixture["sdk_python"]
1125+
resp = _mock_response(200, fixture["response_body"])
1126+
1127+
with patch.object(
1128+
client._http, "request", new_callable=AsyncMock, return_value=resp
1129+
) as mock:
1130+
result = await client.drain_task_queue_build_id(**sdk["args"])
1131+
1132+
assert mock.call_args.args[0] == fixture["request"]["method"]
1133+
assert mock.call_args.args[1] == f"/api{fixture['request']['path']}"
1134+
body = mock.call_args.kwargs.get("json")
1135+
assert body == fixture["request"]["body"]
1136+
1137+
semantic = fixture["semantic_body"]
1138+
assert result.namespace == semantic["namespace"]
1139+
assert result.task_queue == semantic["task_queue"]
1140+
assert result.build_id == semantic["build_id"]
1141+
assert result.drain_intent == semantic["drain_intent"]
1142+
assert result.drained_at == semantic["drained_at"]
1143+
1144+
@pytest.mark.asyncio
1145+
async def test_resume_task_queue_build_id_matches_polyglot_fixture(self, client: Client) -> None:
1146+
fixture_path = (
1147+
Path(__file__).parent
1148+
/ "fixtures"
1149+
/ "control-plane"
1150+
/ "task-queue-build-id-resume-parity.json"
1151+
)
1152+
fixture = json.loads(fixture_path.read_text())
1153+
assert fixture["operation"] == "task_queue.build_id.resume"
1154+
sdk = fixture["sdk_python"]
1155+
resp = _mock_response(200, fixture["response_body"])
1156+
1157+
with patch.object(
1158+
client._http, "request", new_callable=AsyncMock, return_value=resp
1159+
) as mock:
1160+
result = await client.resume_task_queue_build_id(**sdk["args"])
1161+
1162+
assert mock.call_args.args[0] == fixture["request"]["method"]
1163+
assert mock.call_args.args[1] == f"/api{fixture['request']['path']}"
1164+
body = mock.call_args.kwargs.get("json")
1165+
assert body == fixture["request"]["body"]
1166+
1167+
semantic = fixture["semantic_body"]
1168+
assert result.namespace == semantic["namespace"]
1169+
assert result.task_queue == semantic["task_queue"]
1170+
assert result.build_id == semantic["build_id"]
1171+
assert result.drain_intent == semantic["drain_intent"]
1172+
assert result.drained_at is None
1173+
1174+
@pytest.mark.asyncio
1175+
async def test_drain_task_queue_build_id_targets_unversioned_cohort_with_null_body(
1176+
self, client: Client
1177+
) -> None:
1178+
resp = _mock_response(
1179+
200,
1180+
{
1181+
"namespace": "default",
1182+
"task_queue": "orders",
1183+
"build_id": None,
1184+
"drain_intent": "draining",
1185+
"drained_at": "2026-04-22T09:50:00Z",
1186+
},
1187+
)
1188+
1189+
with patch.object(
1190+
client._http, "request", new_callable=AsyncMock, return_value=resp
1191+
) as mock:
1192+
result = await client.drain_task_queue_build_id("orders", None)
1193+
1194+
assert mock.call_args.args[0] == "POST"
1195+
assert mock.call_args.args[1] == "/api/task-queues/orders/build-ids/drain"
1196+
assert mock.call_args.kwargs.get("json") == {"build_id": None}
1197+
assert result.build_id is None
1198+
assert result.drain_intent == "draining"
1199+
assert result.drained_at == "2026-04-22T09:50:00Z"
1200+
1201+
@pytest.mark.asyncio
1202+
async def test_resume_task_queue_build_id_clears_drained_at(self, client: Client) -> None:
1203+
resp = _mock_response(
1204+
200,
1205+
{
1206+
"namespace": "default",
1207+
"task_queue": "orders",
1208+
"build_id": "build-alpha",
1209+
"drain_intent": "active",
1210+
"drained_at": None,
1211+
},
1212+
)
1213+
1214+
with patch.object(
1215+
client._http, "request", new_callable=AsyncMock, return_value=resp
1216+
) as mock:
1217+
result = await client.resume_task_queue_build_id("orders", "build-alpha")
1218+
1219+
assert mock.call_args.args[1] == "/api/task-queues/orders/build-ids/resume"
1220+
assert mock.call_args.kwargs.get("json") == {"build_id": "build-alpha"}
1221+
assert result.drain_intent == "active"
1222+
assert result.drained_at is None
1223+
11141224
@pytest.mark.asyncio
11151225
async def test_list_task_queues_parses_admission(self, client: Client) -> None:
11161226
resp = _mock_response(200, {

0 commit comments

Comments
 (0)