Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
7 changes: 7 additions & 0 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ class CreateSessionRequest(common.BaseModel):
default=None,
description="A list of events to initialize the session with.",
)
display_name: Optional[str] = Field(
default=None,
description="The display name of the session.",
)


class AddSessionToEvalSetRequest(common.BaseModel):
Expand Down Expand Up @@ -594,13 +598,15 @@ async def _create_session(
user_id: str,
session_id: Optional[str] = None,
state: Optional[dict[str, Any]] = None,
display_name: Optional[str] = None,
) -> Session:
try:
session = await self.session_service.create_session(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
display_name=display_name,
)
logger.info("New session created: %s", session.id)
return session
Expand Down Expand Up @@ -795,6 +801,7 @@ async def create_session(
user_id=user_id,
state=req.state,
session_id=req.session_id,
display_name=req.display_name,
)

if req.events:
Expand Down
2 changes: 2 additions & 0 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
"""Creates a new session.

Expand All @@ -65,6 +66,7 @@ async def create_session(
state: the initial state of the session.
session_id: the client-provided id of the session. If not provided, a
generated ID will be used.
display_name: the display name of the session.

Returns:
session: The newly created session instance.
Expand Down
7 changes: 7 additions & 0 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ class StorageSession(Base):
PreciseTimestamp, default=func.now(), onupdate=func.now()
)

display_name: Mapped[Optional[str]] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)

storage_events: Mapped[list[StorageEvent]] = relationship(
"StorageEvent",
back_populates="storage_session",
Expand Down Expand Up @@ -215,6 +219,7 @@ def to_session(
state=state,
events=events,
last_update_time=self.update_timestamp_tz,
display_name=self.display_name,
)


Expand Down Expand Up @@ -477,6 +482,7 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
# 1. Populate states.
# 2. Build storage session object
Expand Down Expand Up @@ -526,6 +532,7 @@ async def create_session(
user_id=user_id,
id=session_id,
state=session_state,
display_name=display_name,
)
sql_session.add(storage_session)
await sql_session.commit()
Expand Down
6 changes: 6 additions & 0 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
return self._create_session_impl(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
display_name=display_name,
)

def create_session_sync(
Expand All @@ -73,13 +75,15 @@ def create_session_sync(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
logger.warning('Deprecated. Please migrate to the async method.')
return self._create_session_impl(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
display_name=display_name,
)

def _create_session_impl(
Expand All @@ -89,6 +93,7 @@ def _create_session_impl(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
if session_id and self._get_session_impl(
app_name=app_name, user_id=user_id, session_id=session_id
Expand Down Expand Up @@ -116,6 +121,7 @@ def _create_session_impl(
id=session_id,
state=session_state or {},
last_update_time=time.time(),
display_name=display_name,
)

if app_name not in self.sessions:
Expand Down
3 changes: 3 additions & 0 deletions src/google/adk/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from typing import Any
from typing import Optional

from pydantic import alias_generators
from pydantic import BaseModel
Expand Down Expand Up @@ -48,3 +49,5 @@ class Session(BaseModel):
call/response, etc."""
last_update_time: float = 0.0
"""The last update time of the session."""
display_name: Optional[str] = None
"""The display name of the session."""
23 changes: 15 additions & 8 deletions src/google/adk/sessions/sqlite_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
state TEXT NOT NULL,
create_time REAL NOT NULL,
update_time REAL NOT NULL,
display_name TEXT,
PRIMARY KEY (app_name, user_id, id)
);
"""
Expand Down Expand Up @@ -121,6 +122,7 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
if session_id:
session_id = session_id.strip()
Expand Down Expand Up @@ -160,8 +162,8 @@ async def create_session(
# Store the session
await db.execute(
"""
INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time, display_name)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
app_name,
Expand All @@ -170,6 +172,7 @@ async def create_session(
json.dumps(session_state),
now,
now,
display_name,
),
)
await db.commit()
Expand All @@ -185,6 +188,7 @@ async def create_session(
state=merged_state,
events=[],
last_update_time=now,
display_name=display_name,
)

@override
Expand All @@ -198,15 +202,16 @@ async def get_session(
) -> Optional[Session]:
async with self._get_db_connection() as db:
async with db.execute(
"SELECT state, update_time FROM sessions WHERE app_name=? AND"
" user_id=? AND id=?",
"SELECT state, update_time, display_name FROM sessions WHERE"
" app_name=? AND user_id=? AND id=?",
(app_name, user_id, session_id),
) as cursor:
session_row = await cursor.fetchone()
if session_row is None:
return None
session_state = json.loads(session_row["state"])
last_update_time = session_row["update_time"]
display_name = session_row["display_name"]

# Build events query
query_parts = [
Expand Down Expand Up @@ -248,6 +253,7 @@ async def get_session(
state=merged_state,
events=events,
last_update_time=last_update_time,
display_name=display_name,
)

@override
Expand All @@ -259,14 +265,14 @@ async def list_sessions(
# Fetch sessions
if user_id:
session_rows = await db.execute_fetchall(
"SELECT id, user_id, state, update_time FROM sessions WHERE"
" app_name=? AND user_id=?",
"SELECT id, user_id, state, update_time, display_name FROM sessions"
" WHERE app_name=? AND user_id=?",
(app_name, user_id),
)
else:
session_rows = await db.execute_fetchall(
"SELECT id, user_id, state, update_time FROM sessions WHERE"
" app_name=?",
"SELECT id, user_id, state, update_time, display_name FROM sessions"
" WHERE app_name=?",
(app_name,),
)

Expand Down Expand Up @@ -301,6 +307,7 @@ async def list_sessions(
state=merged_state,
events=[],
last_update_time=row["update_time"],
display_name=row["display_name"],
)
)
return ListSessionsResponse(sessions=sessions_list)
Expand Down
3 changes: 3 additions & 0 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ async def create_session(
id=session_id,
state=getattr(get_session_response, 'session_state', None) or {},
last_update_time=get_session_response.update_time.timestamp(),
display_name=getattr(get_session_response, 'display_name', None) or None,
)
return session

Expand Down Expand Up @@ -175,6 +176,7 @@ async def get_session(
id=session_id,
state=getattr(get_session_response, 'session_state', None) or {},
last_update_time=update_timestamp,
display_name=getattr(get_session_response, 'display_name', None) or None,
)
session.events += [
_from_api_event(event)
Expand Down Expand Up @@ -213,6 +215,7 @@ async def list_sessions(
id=api_session.name.split('/')[-1],
state=getattr(api_session, 'session_state', None) or {},
last_update_time=api_session.update_time.timestamp(),
display_name=getattr(api_session, 'display_name', None),
)
)

Expand Down
39 changes: 39 additions & 0 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def test_create_get_session(service_type, tmp_path):
assert session.user_id == user_id
assert session.id
assert session.state == state
assert session.display_name is None
assert (
session.last_update_time
<= datetime.now().astimezone(timezone.utc).timestamp()
Expand Down Expand Up @@ -618,3 +619,41 @@ async def test_partial_events_are_not_persisted(service_type, tmp_path):
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(session_got.events) == 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_create_session_with_display_name(service_type, tmp_path):
"""Test that display_name is properly stored and retrieved."""
session_service = get_session_service(service_type, tmp_path)
app_name = 'my_app'
user_id = 'test_user'
display_name = 'My Test Session'

# Create a session with a display_name
session = await session_service.create_session(
app_name=app_name,
user_id=user_id,
display_name=display_name,
)
assert session.display_name == display_name

# Verify display_name is persisted when fetching the session
got_session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert got_session.display_name == display_name

# Verify display_name appears in list_sessions
list_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id
)
assert len(list_response.sessions) == 1
assert list_response.sessions[0].display_name == display_name
16 changes: 16 additions & 0 deletions tests/unittests/sessions/test_vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,22 @@ async def test_create_session_with_custom_config(mock_api_client_instance):
)


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_create_session_with_display_name(mock_api_client_instance):
"""Test that display_name parameter is passed through to the API."""
session_service = mock_vertex_ai_session_service()

display_name = 'Display Name'
await session_service.create_session(
app_name='123', user_id='user', display_name=display_name
)
assert (
mock_api_client_instance.last_create_session_config['display_name']
== display_name
)


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event():
Expand Down