Skip to content

Commit 249216e

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add Graceful Plugin Shutdown to Runner
This change introduces a shutdown lifecycle hook for plugins. The `PluginManager` now has an `async def shutdown()` method that will call `await plugin.shutdown()` on any registered plugins that implement the method. This is called from `Runner.close()`, allowing plugins to perform cleanup tasks like flushing logs or closing connections when the runner instance is being closed. This improves the reliability of plugins that perform background operations. PiperOrigin-RevId: 831037737
1 parent 01bac62 commit 249216e

File tree

7 files changed

+152
-7
lines changed

7 files changed

+152
-7
lines changed

src/google/adk/plugins/base_plugin.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,14 @@ async def after_run_callback(
187187
"""
188188
pass
189189

190+
async def close(self) -> None:
191+
"""Method executed when the runner is closed.
192+
193+
This method is used for cleanup tasks such as closing network connections
194+
or releasing resources.
195+
"""
196+
pass
197+
190198
async def before_agent_callback(
191199
self, *, agent: BaseAgent, callback_context: CallbackContext
192200
) -> Optional[types.Content]:

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ async def _log(self, data: dict):
455455
self._background_tasks.add(task)
456456
task.add_done_callback(self._background_tasks.discard)
457457

458-
async def shutdown(self):
458+
async def close(self):
459459
"""Flushes pending logs and closes client."""
460460
# 1. Wait for pending background logs (best effort, 2s timeout)
461461
if self._background_tasks:

src/google/adk/plugins/plugin_manager.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import logging
19+
import sys
1820
from typing import Any
1921
from typing import List
2022
from typing import Literal
@@ -70,13 +72,19 @@ class PluginManager:
7072
tool calls, or model requests.
7173
"""
7274

73-
def __init__(self, plugins: Optional[List[BasePlugin]] = None):
75+
def __init__(
76+
self,
77+
plugins: Optional[List[BasePlugin]] = None,
78+
close_timeout: float = 5.0,
79+
):
7480
"""Initializes the plugin service.
7581
7682
Args:
7783
plugins: An optional list of plugins to register upon initialization.
84+
close_timeout: The timeout in seconds for each plugin's close method.
7885
"""
7986
self.plugins: List[BasePlugin] = []
87+
self._close_timeout = close_timeout
8088
if plugins:
8189
for plugin in plugins:
8290
self.register_plugin(plugin)
@@ -297,3 +305,43 @@ async def _run_callbacks(
297305
raise RuntimeError(error_message) from e
298306

299307
return None
308+
309+
async def close(self) -> None:
310+
"""Calls the close method on all registered plugins concurrently.
311+
312+
Raises:
313+
RuntimeError: If one or more plugins failed to close, containing
314+
details of all failures.
315+
"""
316+
exceptions = {}
317+
# We iterate sequentially to avoid creating new tasks which can cause issues
318+
# with some libraries (like anyio/mcp) that rely on task-local context.
319+
for plugin in self.plugins:
320+
try:
321+
if sys.version_info >= (3, 11):
322+
async with asyncio.timeout(self._close_timeout):
323+
await plugin.close()
324+
else:
325+
# For Python < 3.11, we use wait_for which creates a new task.
326+
# This might still cause issues with task-local contexts, but
327+
# asyncio.timeout is not available.
328+
await asyncio.wait_for(plugin.close(), timeout=self._close_timeout)
329+
except Exception as e:
330+
exceptions[plugin.name] = e
331+
if isinstance(e, (asyncio.TimeoutError, asyncio.CancelledError)):
332+
logger.warning(
333+
"Timeout/Cancelled while closing plugin: %s", plugin.name
334+
)
335+
else:
336+
logger.error(
337+
"Error during close of plugin %s: %s",
338+
plugin.name,
339+
e,
340+
exc_info=e,
341+
)
342+
343+
if exceptions:
344+
error_summary = ", ".join(
345+
f"'{name}': {type(exc).__name__}" for name, exc in exceptions.items()
346+
)
347+
raise RuntimeError(f"Failed to close plugins: {error_summary}")

src/google/adk/runners.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(
115115
session_service: BaseSessionService,
116116
memory_service: Optional[BaseMemoryService] = None,
117117
credential_service: Optional[BaseCredentialService] = None,
118+
plugin_close_timeout: float = 5.0,
118119
):
119120
"""Initializes the Runner.
120121
@@ -134,6 +135,7 @@ def __init__(
134135
session_service: The session service for the runner.
135136
memory_service: The memory service for the runner.
136137
credential_service: The credential service for the runner.
138+
plugin_close_timeout: The timeout in seconds for plugin close methods.
137139
138140
Raises:
139141
ValueError: If `app` is provided along with `app_name` or `plugins`, or
@@ -151,7 +153,9 @@ def __init__(
151153
self.session_service = session_service
152154
self.memory_service = memory_service
153155
self.credential_service = credential_service
154-
self.plugin_manager = PluginManager(plugins=plugins)
156+
self.plugin_manager = PluginManager(
157+
plugins=plugins, close_timeout=plugin_close_timeout
158+
)
155159
(
156160
self._agent_origin_app_name,
157161
self._agent_origin_dir,
@@ -1297,8 +1301,16 @@ async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]):
12971301

12981302
async def close(self):
12991303
"""Closes the runner."""
1304+
logger.info('Closing runner...')
1305+
# Close Toolsets
13001306
await self._cleanup_toolsets(self._collect_toolset(self.agent))
13011307

1308+
# Close Plugins
1309+
if self.plugin_manager:
1310+
await self.plugin_manager.close()
1311+
1312+
logger.info('Runner closed.')
1313+
13021314
async def __aenter__(self):
13031315
"""Async context manager entry."""
13041316
return self
@@ -1329,13 +1341,17 @@ def __init__(
13291341
app_name: Optional[str] = None,
13301342
plugins: Optional[list[BasePlugin]] = None,
13311343
app: Optional[App] = None,
1344+
plugin_close_timeout: float = 5.0,
13321345
):
13331346
"""Initializes the InMemoryRunner.
13341347
13351348
Args:
13361349
agent: The root agent to run.
13371350
app_name: The application name of the runner. Defaults to
13381351
'InMemoryRunner'.
1352+
plugins: Optional list of plugins for the runner.
1353+
app: Optional App instance.
1354+
plugin_close_timeout: The timeout in seconds for plugin close methods.
13391355
"""
13401356
if app is None and app_name is None:
13411357
app_name = 'InMemoryRunner'
@@ -1347,4 +1363,5 @@ def __init__(
13471363
app=app,
13481364
session_service=InMemorySessionService(),
13491365
memory_service=InMemoryMemoryService(),
1366+
plugin_close_timeout=plugin_close_timeout,
13501367
)

tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -513,10 +513,8 @@ async def fake_append_rows_with_error(requests, **kwargs):
513513
mock_write_client.append_rows.assert_called_once()
514514

515515
@pytest.mark.asyncio
516-
async def test_shutdown(
517-
self, bq_plugin_inst, mock_bq_client, mock_write_client
518-
):
519-
await bq_plugin_inst.shutdown()
516+
async def test_close(self, bq_plugin_inst, mock_bq_client, mock_write_client):
517+
await bq_plugin_inst.close()
520518
mock_write_client.transport.close.assert_called_once()
521519
mock_bq_client.close.assert_called_once()
522520

tests/unittests/plugins/test_plugin_manager.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from __future__ import annotations
1818

19+
import asyncio
20+
from unittest.mock import AsyncMock
1921
from unittest.mock import Mock
2022

2123
from google.adk.models.llm_response import LlmResponse
@@ -267,3 +269,51 @@ async def test_all_callbacks_are_supported(
267269
"on_model_error_callback",
268270
]
269271
assert set(plugin1.call_log) == set(expected_callbacks)
272+
273+
274+
@pytest.mark.asyncio
275+
async def test_close_calls_plugin_close(
276+
service: PluginManager, plugin1: TestPlugin
277+
):
278+
"""Tests that close calls the close method on registered plugins."""
279+
plugin1.close = AsyncMock()
280+
service.register_plugin(plugin1)
281+
282+
await service.close()
283+
284+
plugin1.close.assert_awaited_once()
285+
286+
287+
@pytest.mark.asyncio
288+
async def test_close_raises_runtime_error_on_plugin_exception(
289+
service: PluginManager, plugin1: TestPlugin
290+
):
291+
"""Tests that close raises a RuntimeError if a plugin's close fails."""
292+
plugin1.close = AsyncMock(side_effect=ValueError("Shutdown error"))
293+
service.register_plugin(plugin1)
294+
295+
with pytest.raises(
296+
RuntimeError, match="Failed to close plugins: 'plugin1': ValueError"
297+
):
298+
await service.close()
299+
300+
plugin1.close.assert_awaited_once()
301+
302+
303+
@pytest.mark.asyncio
304+
async def test_close_with_timeout(plugin1: TestPlugin):
305+
"""Tests that close respects the timeout and raises on failure."""
306+
service = PluginManager(close_timeout=0.1)
307+
308+
async def slow_close():
309+
await asyncio.sleep(0.2)
310+
311+
plugin1.close = slow_close
312+
service.register_plugin(plugin1)
313+
314+
with pytest.raises(RuntimeError) as excinfo:
315+
await service.close()
316+
317+
assert "Failed to close plugins: 'plugin1': TimeoutError" in str(
318+
excinfo.value
319+
)

tests/unittests/test_runners.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pathlib import Path
1616
import textwrap
1717
from typing import Optional
18+
from unittest.mock import AsyncMock
1819

1920
from google.adk.agents.base_agent import BaseAgent
2021
from google.adk.agents.context_cache_config import ContextCacheConfig
@@ -562,6 +563,29 @@ async def test_runner_modifies_event_after_execution(self):
562563

563564
assert modified_event_message == MockPlugin.ON_EVENT_CALLBACK_MSG
564565

566+
@pytest.mark.asyncio
567+
async def test_runner_close_calls_plugin_close(self):
568+
"""Test that runner.close() calls plugin manager close."""
569+
# Mock the plugin manager's close method
570+
self.runner.plugin_manager.close = AsyncMock()
571+
572+
await self.runner.close()
573+
574+
self.runner.plugin_manager.close.assert_awaited_once()
575+
576+
@pytest.mark.asyncio
577+
async def test_runner_passes_plugin_close_timeout(self):
578+
"""Test that runner passes plugin_close_timeout to PluginManager."""
579+
runner = Runner(
580+
app_name="test_app",
581+
agent=MockLlmAgent("test_agent"),
582+
session_service=self.session_service,
583+
artifact_service=self.artifact_service,
584+
plugins=[self.plugin],
585+
plugin_close_timeout=10.0,
586+
)
587+
assert runner.plugin_manager._close_timeout == 10.0
588+
565589
def test_runner_init_raises_error_with_app_and_app_name_and_agent(self):
566590
"""Test that ValueError is raised when app, app_name and agent are provided."""
567591
with pytest.raises(

0 commit comments

Comments
 (0)