Skip to content

Commit 801ed1e

Browse files
rapsealk and claude committed
fix(agent): Ensure shared aiodocker client is closed on shutdown and boot failure
Addresses review feedback on PR #11226:

- Wrap shutdown() so that self.docker.close() runs in a finally block, preventing an aiohttp.ClientSession leak if super().shutdown() raises.
- Release the shared Docker client if __ainit__ raises after the session is constructed.
- Add regression tests asserting that the session is closed both on successful shutdown and when the inner shutdown raises.

Refs #11218
Refs #11226

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 27f9c36 commit 801ed1e

2 files changed

Lines changed: 166 additions & 92 deletions

File tree

src/ai/backend/agent/docker/agent.py

Lines changed: 101 additions & 92 deletions
Original file line number | Diff line number | Diff line change
@@ -1481,106 +1481,115 @@ def __init__(
14811481

14821482
async def __ainit__(self) -> None:
14831483
self.docker = Docker()
1484-
docker = self.docker
1485-
docker_host = ""
1486-
match docker.connector:
1487-
case aiohttp.TCPConnector():
1488-
if docker.docker_host is None:
1489-
raise InvalidArgumentError("docker_host is not set for TCP connector")
1490-
docker_host = docker.docker_host
1491-
case aiohttp.NamedPipeConnector() | aiohttp.UnixConnector() as connector:
1492-
docker_host = connector.path
1493-
case _:
1494-
docker_host = "(unknown)"
1495-
log.info("accessing the local Docker daemon via {}", docker_host)
1496-
docker_version = await docker.version()
1497-
log.info(
1498-
"running with Docker {0} with API {1}",
1499-
docker_version["Version"],
1500-
docker_version["ApiVersion"],
1501-
)
1502-
kernel_version = docker_version["KernelVersion"]
1503-
if "linuxkit" in kernel_version:
1504-
self.local_config.agent.docker_mode = "linuxkit"
1505-
else:
1506-
self.local_config.agent.docker_mode = "native"
1507-
docker_info = await docker.system.info()
1508-
docker_info = dict(docker_info)
1509-
# Assume cgroup v1 if CgroupVersion key is absent
1510-
if "CgroupVersion" not in docker_info:
1511-
docker_info["CgroupVersion"] = "1"
1512-
log.info(
1513-
"Cgroup Driver: {0}, Cgroup Version: {1}",
1514-
docker_info["CgroupDriver"],
1515-
docker_info["CgroupVersion"],
1516-
)
1517-
self.docker_info = docker_info
1518-
await self._kernel_recovery_adapter.adapt_recovery_data()
1519-
await super().__ainit__()
15201484
try:
1521-
gwbridge = await docker.networks.get("docker_gwbridge")
1522-
gwbridge_info = await gwbridge.show()
1523-
self.gwbridge_subnet = gwbridge_info["IPAM"]["Config"][0]["Subnet"]
1524-
except (DockerError, KeyError, IndexError):
1525-
self.gwbridge_subnet = None
1526-
ipc_base_path = self.local_config.agent.ipc_base_path
1527-
(ipc_base_path / "container").mkdir(parents=True, exist_ok=True)
1528-
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock"
1529-
# Workaround for Docker Desktop for Mac's UNIX socket mount failure with virtiofs
1530-
if sys.platform != "darwin":
1531-
socket_relay_name = f"backendai-socket-relay.{self.id}"
1532-
socket_relay_container = PersistentServiceContainer(
1533-
"backendai-socket-relay:latest",
1534-
{
1535-
"Cmd": [
1536-
f"UNIX-LISTEN:/ipc/{self.agent_sockpath.name},unlink-early,fork,mode=777",
1537-
f"TCP-CONNECT:127.0.0.1:{self.local_config.agent.agent_sock_port}",
1538-
],
1539-
"HostConfig": {
1540-
"Mounts": [
1541-
{
1542-
"Type": "bind",
1543-
"Source": str(ipc_base_path / "container"),
1544-
"Target": "/ipc",
1545-
},
1485+
docker = self.docker
1486+
docker_host = ""
1487+
match docker.connector:
1488+
case aiohttp.TCPConnector():
1489+
if docker.docker_host is None:
1490+
raise InvalidArgumentError("docker_host is not set for TCP connector")
1491+
docker_host = docker.docker_host
1492+
case aiohttp.NamedPipeConnector() | aiohttp.UnixConnector() as connector:
1493+
docker_host = connector.path
1494+
case _:
1495+
docker_host = "(unknown)"
1496+
log.info("accessing the local Docker daemon via {}", docker_host)
1497+
docker_version = await docker.version()
1498+
log.info(
1499+
"running with Docker {0} with API {1}",
1500+
docker_version["Version"],
1501+
docker_version["ApiVersion"],
1502+
)
1503+
kernel_version = docker_version["KernelVersion"]
1504+
if "linuxkit" in kernel_version:
1505+
self.local_config.agent.docker_mode = "linuxkit"
1506+
else:
1507+
self.local_config.agent.docker_mode = "native"
1508+
docker_info = await docker.system.info()
1509+
docker_info = dict(docker_info)
1510+
# Assume cgroup v1 if CgroupVersion key is absent
1511+
if "CgroupVersion" not in docker_info:
1512+
docker_info["CgroupVersion"] = "1"
1513+
log.info(
1514+
"Cgroup Driver: {0}, Cgroup Version: {1}",
1515+
docker_info["CgroupDriver"],
1516+
docker_info["CgroupVersion"],
1517+
)
1518+
self.docker_info = docker_info
1519+
await self._kernel_recovery_adapter.adapt_recovery_data()
1520+
await super().__ainit__()
1521+
try:
1522+
gwbridge = await docker.networks.get("docker_gwbridge")
1523+
gwbridge_info = await gwbridge.show()
1524+
self.gwbridge_subnet = gwbridge_info["IPAM"]["Config"][0]["Subnet"]
1525+
except (DockerError, KeyError, IndexError):
1526+
self.gwbridge_subnet = None
1527+
ipc_base_path = self.local_config.agent.ipc_base_path
1528+
(ipc_base_path / "container").mkdir(parents=True, exist_ok=True)
1529+
self.agent_sockpath = ipc_base_path / "container" / f"agent.{self.id}.sock"
1530+
# Workaround for Docker Desktop for Mac's UNIX socket mount failure with virtiofs
1531+
if sys.platform != "darwin":
1532+
socket_relay_name = f"backendai-socket-relay.{self.id}"
1533+
socket_relay_container = PersistentServiceContainer(
1534+
"backendai-socket-relay:latest",
1535+
{
1536+
"Cmd": [
1537+
f"UNIX-LISTEN:/ipc/{self.agent_sockpath.name},unlink-early,fork,mode=777",
1538+
f"TCP-CONNECT:127.0.0.1:{self.local_config.agent.agent_sock_port}",
15461539
],
1547-
"NetworkMode": "host",
1540+
"HostConfig": {
1541+
"Mounts": [
1542+
{
1543+
"Type": "bind",
1544+
"Source": str(ipc_base_path / "container"),
1545+
"Target": "/ipc",
1546+
},
1547+
],
1548+
"NetworkMode": "host",
1549+
},
15481550
},
1549-
},
1550-
name=socket_relay_name,
1551-
)
1552-
await socket_relay_container.ensure_running_latest()
1553-
self.agent_sock_task = asyncio.create_task(self.handle_agent_socket())
1554-
self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events())
1555-
self.docker_ptask_group = aiotools.PersistentTaskGroup()
1551+
name=socket_relay_name,
1552+
)
1553+
await socket_relay_container.ensure_running_latest()
1554+
self.agent_sock_task = asyncio.create_task(self.handle_agent_socket())
1555+
self.monitor_docker_task = asyncio.create_task(self.monitor_docker_events())
1556+
self.docker_ptask_group = aiotools.PersistentTaskGroup()
15561557

1557-
self.network_plugin_ctx = NetworkPluginContext(
1558-
self.etcd, self.local_config.model_dump(by_alias=True)
1559-
)
1560-
await self.network_plugin_ctx.init(
1561-
context=self,
1562-
allowlist=self.local_config.agent.allow_network_plugins,
1563-
blocklist=self.local_config.agent.block_network_plugins,
1564-
)
1558+
self.network_plugin_ctx = NetworkPluginContext(
1559+
self.etcd, self.local_config.model_dump(by_alias=True)
1560+
)
1561+
await self.network_plugin_ctx.init(
1562+
context=self,
1563+
allowlist=self.local_config.agent.allow_network_plugins,
1564+
blocklist=self.local_config.agent.block_network_plugins,
1565+
)
1566+
except BaseException:
1567+
# Release the shared aiodocker client if boot fails after its construction
1568+
# so the underlying aiohttp.ClientSession does not leak.
1569+
await self.docker.close()
1570+
raise
15651571

15661572
async def shutdown(self, stop_signal: signal.Signals) -> None:
    """
    Stop the Docker-specific background work, run the parent class' shutdown,
    and finally close the shared aiodocker client.

    The client is closed in an outer ``finally`` clause, so it is released
    even when an earlier shutdown step raises; the error still propagates
    to the caller afterwards.

    Args:
        stop_signal: The signal that triggered the shutdown; forwarded to the
            parent class' ``shutdown()``.
    """
    try:
        # The agent socket handler is stopped first.
        sock_task = self.agent_sock_task
        if sock_task is not None:
            sock_task.cancel()
            await sock_task
        ptask_group = self.docker_ptask_group
        if ptask_group is not None:
            await ptask_group.shutdown()

        try:
            await super().shutdown(stop_signal)
        finally:
            # Docker event monitoring is stopped whether or not the parent
            # shutdown succeeded.
            monitor_task = self.monitor_docker_task
            if monitor_task is not None:
                monitor_task.cancel()
                await monitor_task
    finally:
        # Always release the shared aiodocker client so that its
        # aiohttp.ClientSession does not leak when a step above raises.
        docker_client = self.docker
        if docker_client is not None:
            await docker_client.close()
15841593

15851594
@override
15861595
async def _load_kernel_registry_from_recovery(self) -> MutableMapping[KernelId, AbstractKernel]:

tests/component/agent/docker/test_agent.py

Lines changed: 65 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -183,3 +183,68 @@ async def test_save_last_registry_exception(agent: DockerAgent, mocker: Any) ->
183183
)
184184
await agent.save_last_registry()
185185
assert not registry_state_path.exists()
186+
187+
188+
@pytest.fixture
async def unmanaged_agent(
    local_config: Any, test_id: str, mocker: Any, socket_relay_image: Any
) -> Any:
    """
    Like the ``agent`` fixture, but leaves shutdown entirely to the test so it
    can assert against the lifecycle of the shared aiodocker client.

    ``test_id`` and ``socket_relay_image`` are not used in the body; they are
    requested so that those fixtures are set up first.

    Teardown: if the test left the client's session open (it never shut the
    agent down), the agent is shut down in full so that the background tasks
    started during initialization do not outlive the test.
    """
    dummy_etcd = DummyEtcd()
    mocked_etcd_get_prefix = AsyncMock(return_value={})
    mocker.patch.object(dummy_etcd, "get_prefix", new=mocked_etcd_get_prefix)
    test_case_id = secrets.token_hex(8)
    kernel_registry = KernelRegistry()
    agent = await DockerAgent.new(
        dummy_etcd,
        local_config,
        stats_monitor=None,
        error_monitor=None,
        skip_initial_scan=True,
        agent_public_key=None,
        kernel_registry=kernel_registry,
        computers={},
        slots={},
        agent_class=AgentClass.PRIMARY,
    )
    agent.local_instance_id = test_case_id
    yield agent
    # An open session means the test never shut the agent down. Closing only
    # the Docker client here would leave the agent's tasks running against a
    # closed client, so perform the complete shutdown (which also closes the
    # client). A closed session means shutdown already ran; nothing to do.
    if not agent.docker.session.closed:
        await agent.shutdown(signal.SIGTERM)
220+
221+
222+
async def test_shared_docker_client_open_after_ainit(unmanaged_agent: DockerAgent) -> None:
    """A freshly initialized agent must hold an open aiodocker client session."""
    agent = unmanaged_agent
    assert agent.docker.session.closed is False
224+
225+
226+
async def test_shared_docker_client_closed_after_shutdown(
    unmanaged_agent: DockerAgent,
) -> None:
    """A successful shutdown must close the shared aiodocker client session."""
    agent = unmanaged_agent
    # Precondition: the session is still open before shutdown is requested.
    assert agent.docker.session.closed is False
    await agent.shutdown(signal.SIGTERM)
    assert agent.docker.session.closed is True
232+
233+
234+
async def test_shared_docker_client_closed_when_super_shutdown_raises(
    unmanaged_agent: DockerAgent, mocker: Any
) -> None:
    """
    The shared aiodocker client must be closed even when the base class'
    ``shutdown()`` raises, so that its aiohttp.ClientSession does not leak.
    """

    class _SimulatedShutdownError(Exception):
        pass

    # Replace the base shutdown with a coroutine mock that fails mid-shutdown.
    failing_shutdown = AsyncMock(side_effect=_SimulatedShutdownError("simulated"))
    mocker.patch(
        "ai.backend.agent.agent.AbstractAgent.shutdown",
        new=failing_shutdown,
    )
    agent = unmanaged_agent
    assert agent.docker.session.closed is False
    # The simulated failure must still reach the caller ...
    with pytest.raises(_SimulatedShutdownError):
        await agent.shutdown(signal.SIGTERM)
    # ... but only after the client has been released.
    assert agent.docker.session.closed is True

0 commit comments

Comments
 (0)