diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index b97932d042..4b51e1dba4 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -224,6 +224,10 @@ class CreateSessionRequest(common.BaseModel): default=None, description="A list of events to initialize the session with.", ) + display_name: Optional[str] = Field( + default=None, + description="The display name of the session.", + ) class SaveArtifactRequest(common.BaseModel): @@ -649,6 +653,7 @@ async def _create_session( user_id: str, session_id: Optional[str] = None, state: Optional[dict[str, Any]] = None, + display_name: Optional[str] = None, ) -> Session: try: session = await self.session_service.create_session( @@ -656,6 +661,7 @@ async def _create_session( user_id=user_id, state=state, session_id=session_id, + display_name=display_name, ) logger.info("New session created: %s", session.id) return session @@ -861,6 +867,7 @@ async def create_session( user_id=user_id, state=req.state, session_id=req.session_id, + display_name=req.display_name, ) if req.events: diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index dddc2c83e0..c0c7075929 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -56,6 +56,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: """Creates a new session. @@ -65,6 +66,7 @@ async def create_session( state: the initial state of the session. session_id: the client-provided id of the session. If not provided, a generated ID will be used. + display_name: the display name of the session. Returns: session: The newly created session instance. diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index afc09b3e9d..06baec28ff 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -228,6 +228,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: # 1. Populate states. # 2. Build storage session object @@ -284,6 +285,7 @@ async def create_session( user_id=user_id, id=session_id, state=session_state, + display_name=display_name, create_time=now, update_time=now, ) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index d782072d6a..19d45cd967 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -58,12 +58,14 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: return self._create_session_impl( app_name=app_name, user_id=user_id, state=state, session_id=session_id, + display_name=display_name, ) def create_session_sync( @@ -73,6 +75,7 @@ def create_session_sync( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: logger.warning('Deprecated. Please migrate to the async method.') return self._create_session_impl( @@ -80,6 +83,7 @@ def create_session_sync( user_id=user_id, state=state, session_id=session_id, + display_name=display_name, ) def _create_session_impl( @@ -89,6 +93,7 @@ def _create_session_impl( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: if session_id and self._get_session_impl( app_name=app_name, user_id=user_id, session_id=session_id @@ -116,6 +121,7 @@ def _create_session_impl( id=session_id, state=session_state or {}, last_update_time=time.time(), + display_name=display_name, ) if app_name not in self.sessions: diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index 8603ed8197..88007aeeb4 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -110,6 +110,9 @@ class StorageSession(Base): primary_key=True, default=lambda: str(uuid.uuid4()), ) + display_name: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) state: Mapped[MutableDict[str, Any]] = mapped_column( MutableDict.as_mutable(DynamicJSON), default={} @@ -157,6 +160,7 @@ def to_session( id=self.id, state=state, events=events, + display_name=self.display_name, last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite), ) diff --git a/src/google/adk/sessions/schemas/v1.py b/src/google/adk/sessions/schemas/v1.py index a7c429b242..ad6e87cffb 100644 --- a/src/google/adk/sessions/schemas/v1.py +++ b/src/google/adk/sessions/schemas/v1.py @@ -83,6 +83,9 @@ class StorageSession(Base): primary_key=True, default=lambda: str(uuid.uuid4()), ) + display_name: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) state: Mapped[MutableDict[str, Any]] = mapped_column( MutableDict.as_mutable(DynamicJSON), default={} @@ -132,6 +135,7 @@ def to_session( id=self.id, state=state, events=events, + display_name=self.display_name, last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite), ) diff --git a/src/google/adk/sessions/session.py b/src/google/adk/sessions/session.py index 89af1d4167..6473ae9956 100644 --- a/src/google/adk/sessions/session.py +++ b/src/google/adk/sessions/session.py @@ -15,6 +15,7 @@ from __future__ import annotations from typing import Any +from typing import Optional from pydantic import alias_generators from pydantic import BaseModel @@ -48,3 +49,5 @@ class Session(BaseModel): call/response, etc.""" last_update_time: float = 0.0 """The last update time of the session.""" + display_name: Optional[str] = None + """The display name of the session.""" diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index d23c8278cf..a006049c4e 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -68,6 +68,7 @@ state TEXT NOT NULL, create_time REAL NOT NULL, update_time REAL NOT NULL, + display_name TEXT, PRIMARY KEY (app_name, user_id, id) ); """ @@ -161,6 +162,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: if session_id: session_id = session_id.strip() @@ -200,8 +202,8 @@ async def create_session( # Store the session await db.execute( """ - INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time) - VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time, display_name) + VALUES (?, ?, ?, ?, ?, ?, ?) """, ( app_name, @@ -210,6 +212,7 @@ async def create_session( json.dumps(session_state), now, now, + display_name, ), ) await db.commit() @@ -225,6 +228,7 @@ async def create_session( state=merged_state, events=[], last_update_time=now, + display_name=display_name, ) @override @@ -238,8 +242,8 @@ async def get_session( ) -> Optional[Session]: async with self._get_db_connection() as db: async with db.execute( - "SELECT state, update_time FROM sessions WHERE app_name=? AND" - " user_id=? AND id=?", + "SELECT state, update_time, display_name FROM sessions WHERE" + " app_name=? AND user_id=? AND id=?", (app_name, user_id, session_id), ) as cursor: session_row = await cursor.fetchone() @@ -247,6 +251,7 @@ async def get_session( return None session_state = json.loads(session_row["state"]) last_update_time = session_row["update_time"] + display_name = session_row["display_name"] # Build events query query_parts = [ @@ -288,6 +293,7 @@ async def get_session( state=merged_state, events=events, last_update_time=last_update_time, + display_name=display_name, ) @override @@ -299,14 +305,14 @@ async def list_sessions( # Fetch sessions if user_id: session_rows = await db.execute_fetchall( - "SELECT id, user_id, state, update_time FROM sessions WHERE" - " app_name=? AND user_id=?", + "SELECT id, user_id, state, update_time, display_name FROM sessions" + " WHERE app_name=? AND user_id=?", (app_name, user_id), ) else: session_rows = await db.execute_fetchall( - "SELECT id, user_id, state, update_time FROM sessions WHERE" - " app_name=?", + "SELECT id, user_id, state, update_time, display_name FROM sessions" + " WHERE app_name=?", (app_name,), ) @@ -341,6 +347,7 @@ async def list_sessions( state=merged_state, events=[], last_update_time=row["update_time"], + display_name=row["display_name"], ) ) return ListSessionsResponse(sessions=sessions_list) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index f5cafee467..949a58dc96 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -84,6 +84,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, **kwargs: Any, ) -> Session: """Creates a new session. @@ -93,6 +94,7 @@ async def create_session( user_id: The ID of the user. state: The initial state of the session. session_id: The ID of the session. + display_name: An optional display name for the session. **kwargs: Additional arguments to pass to the session creation. E.g. set expire_time='2025-10-01T00:00:00Z' to set the session expiration time. See https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1beta1/projects.locations.reasoningEngines.sessions @@ -110,6 +112,8 @@ async def create_session( reasoning_engine_id = self._get_reasoning_engine_id(app_name) config = {'session_state': state} if state else {} + if display_name is not None: + config['display_name'] = display_name config.update(kwargs) async with self._get_api_client() as api_client: api_response = await api_client.agent_engines.sessions.create( @@ -127,6 +131,7 @@ async def create_session( id=session_id, state=getattr(get_session_response, 'session_state', None) or {}, last_update_time=get_session_response.update_time.timestamp(), + display_name=getattr(get_session_response, 'display_name', None), ) return session @@ -185,10 +190,6 @@ async def get_session( state=getattr(get_session_response, 'session_state', None) or {}, last_update_time=update_timestamp, ) - # Preserve the entire event stream that Vertex returns rather than trying - # to discard events written milliseconds after the session resource was - # updated. Clock skew between those writes can otherwise drop tool_result - # events and permanently break the replayed conversation. async for event in events_iterator: session.events.append(_from_api_event(event)) @@ -223,6 +224,7 @@ async def list_sessions( id=api_session.name.split('/')[-1], state=getattr(api_session, 'session_state', None) or {}, last_update_time=api_session.update_time.timestamp(), + display_name=getattr(api_session, 'display_name', None), ) ) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 30eb15678b..b51af700ac 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -121,6 +121,7 @@ async def test_create_get_session(session_service): assert session.user_id == user_id assert session.id assert session.state == state + assert session.display_name is None assert ( session.last_update_time <= datetime.now().astimezone(timezone.utc).timestamp() @@ -642,3 +643,41 @@ async def test_partial_events_are_not_persisted(session_service): app_name=app_name, user_id=user_id, session_id=session.id ) assert len(session_got.events) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', + [ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, + ], +) +async def test_create_session_with_display_name(service_type, tmp_path): + """Test that display_name is properly stored and retrieved.""" + session_service = get_session_service(service_type, tmp_path) + app_name = 'my_app' + user_id = 'test_user' + display_name = 'My Test Session' + + # Create a session with a display_name + session = await session_service.create_session( + app_name=app_name, + user_id=user_id, + display_name=display_name, + ) + assert session.display_name == display_name + + # Verify display_name is persisted when fetching the session + got_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert got_session.display_name == display_name + + # Verify display_name appears in list_sessions + list_response = await session_service.list_sessions( + app_name=app_name, user_id=user_id + ) + assert len(list_response.sessions) == 1 + assert list_response.sessions[0].display_name == display_name diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 12a11a93a2..380d8547b8 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -781,6 +781,22 @@ async def test_create_session_with_custom_config(mock_api_client_instance): ) +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_create_session_with_display_name(mock_api_client_instance): + """Test that display_name parameter is passed through to the API.""" + session_service = mock_vertex_ai_session_service() + + display_name = 'Display Name' + await session_service.create_session( + app_name='123', user_id='user', display_name=display_name + ) + assert ( + mock_api_client_instance.last_create_session_config['display_name'] + == display_name + ) + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_append_event():