Skip to content

Commit 1063fa5

Browse files
DeanChensjcopybara-github
authored andcommitted
fix: Reload the stale session in database_session_service when the update_time of storage is later than the current in memory session object
This changes the behavior of how we handle stale session, instead of reject the transaction entirely, we aggresively refresh the session. Removed synchronous inspect(self).session calls within StorageSession.update_timestamp_tz and _dialect_name. This introspection was causing deadlocks/hangs when used with sqlalchemy.ext.asyncio in Python 3.13. Closes issue: #1733 Co-authored-by: Shangjie Chen <deanchen@google.com> PiperOrigin-RevId: 861353058
1 parent 43d6075 commit 1063fa5

File tree

4 files changed

+102
-50
lines changed

4 files changed

+102
-50
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from sqlalchemy.ext.asyncio import AsyncEngine
3232
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
3333
from sqlalchemy.ext.asyncio import create_async_engine
34-
from sqlalchemy.inspection import inspect
3534
from sqlalchemy.pool import StaticPool
3635
from typing_extensions import override
3736
from tzlocal import get_localzone
@@ -275,22 +274,29 @@ async def create_session(
275274
storage_user_state.state = storage_user_state.state | user_state_delta
276275

277276
# Store the session
277+
now = datetime.now(timezone.utc)
278+
is_sqlite = self.db_engine.dialect.name == "sqlite"
279+
if is_sqlite:
280+
now = now.replace(tzinfo=None)
281+
278282
storage_session = schema.StorageSession(
279283
app_name=app_name,
280284
user_id=user_id,
281285
id=session_id,
282286
state=session_state,
287+
create_time=now,
288+
update_time=now,
283289
)
284290
sql_session.add(storage_session)
285291
await sql_session.commit()
286292

287-
await sql_session.refresh(storage_session)
288-
289293
# Merge states for response
290294
merged_state = _merge_state(
291295
storage_app_state.state, storage_user_state.state, session_state
292296
)
293-
session = storage_session.to_session(state=merged_state)
297+
session = storage_session.to_session(
298+
state=merged_state, is_sqlite=is_sqlite
299+
)
294300
return session
295301

296302
@override
@@ -350,7 +356,10 @@ async def get_session(
350356

351357
# Convert storage session to session
352358
events = [e.to_event() for e in reversed(storage_events)]
353-
session = storage_session.to_session(state=merged_state, events=events)
359+
is_sqlite = self.db_engine.dialect.name == "sqlite"
360+
session = storage_session.to_session(
361+
state=merged_state, events=events, is_sqlite=is_sqlite
362+
)
354363
return session
355364

356365
@override
@@ -393,11 +402,14 @@ async def list_sessions(
393402
user_states_map[storage_user_state.user_id] = storage_user_state.state
394403

395404
sessions = []
405+
is_sqlite = self.db_engine.dialect.name == "sqlite"
396406
for storage_session in results:
397407
session_state = storage_session.state
398408
user_state = user_states_map.get(storage_session.user_id, {})
399409
merged_state = _merge_state(app_state, user_state, session_state)
400-
sessions.append(storage_session.to_session(state=merged_state))
410+
sessions.append(
411+
storage_session.to_session(state=merged_state, is_sqlite=is_sqlite)
412+
)
401413
return ListSessionsResponse(sessions=sessions)
402414

403415
@override
@@ -433,15 +445,6 @@ async def append_event(self, session: Session, event: Event) -> Event:
433445
schema.StorageSession, (session.app_name, session.user_id, session.id)
434446
)
435447

436-
if storage_session.update_timestamp_tz > session.last_update_time:
437-
raise ValueError(
438-
"The last_update_time provided in the session object"
439-
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
440-
" earlier than the update_time in the storage_session"
441-
f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}."
442-
" Please check if it is a stale session."
443-
)
444-
445448
# Fetch states from storage
446449
storage_app_state = await sql_session.get(
447450
schema.StorageAppState, (session.app_name)
@@ -450,6 +453,29 @@ async def append_event(self, session: Session, event: Event) -> Event:
450453
schema.StorageUserState, (session.app_name, session.user_id)
451454
)
452455

456+
is_sqlite = self.db_engine.dialect.name == "sqlite"
457+
if (
458+
storage_session.get_update_timestamp(is_sqlite)
459+
> session.last_update_time
460+
):
461+
# Reload the session from storage if it has been updated since it was
462+
# loaded.
463+
app_state = storage_app_state.state if storage_app_state else {}
464+
user_state = storage_user_state.state if storage_user_state else {}
465+
session_state = storage_session.state
466+
session.state = _merge_state(app_state, user_state, session_state)
467+
468+
stmt = (
469+
select(schema.StorageEvent)
470+
.filter(schema.StorageEvent.app_name == session.app_name)
471+
.filter(schema.StorageEvent.session_id == session.id)
472+
.filter(schema.StorageEvent.user_id == session.user_id)
473+
.order_by(schema.StorageEvent.timestamp.asc())
474+
)
475+
result = await sql_session.stream_scalars(stmt)
476+
storage_events = [e async for e in result]
477+
session.events = [e.to_event() for e in storage_events]
478+
453479
# Extract state delta
454480
if event.actions and event.actions.state_delta:
455481
state_deltas = _session_util.extract_state_delta(
@@ -466,7 +492,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
466492
if session_state_delta:
467493
storage_session.state = storage_session.state | session_state_delta
468494

469-
if storage_session._dialect_name == "sqlite":
495+
if is_sqlite:
470496
update_time = datetime.fromtimestamp(
471497
event.timestamp, timezone.utc
472498
).replace(tzinfo=None)
@@ -476,10 +502,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
476502
sql_session.add(schema.StorageEvent.from_event(session, event))
477503

478504
await sql_session.commit()
479-
await sql_session.refresh(storage_session)
480505

481506
# Update timestamp with commit time
482-
session.last_update_time = storage_session.update_timestamp_tz
507+
session.last_update_time = storage_session.get_update_timestamp(is_sqlite)
483508

484509
# Also update the in-memory session
485510
await super().append_event(session=session, event=event)

src/google/adk/sessions/schemas/v0.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from sqlalchemy import Text
4242
from sqlalchemy.dialects import mysql
4343
from sqlalchemy.ext.mutable import MutableDict
44-
from sqlalchemy.inspection import inspect
4544
from sqlalchemy.orm import DeclarativeBase
4645
from sqlalchemy.orm import Mapped
4746
from sqlalchemy.orm import mapped_column
@@ -131,15 +130,9 @@ class StorageSession(Base):
131130
def __repr__(self):
132131
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
133132

134-
@property
135-
def _dialect_name(self) -> Optional[str]:
136-
session = inspect(self).session
137-
return session.bind.dialect.name if session else None
138-
139-
@property
140-
def update_timestamp_tz(self) -> datetime:
133+
def get_update_timestamp(self, is_sqlite: bool) -> float:
141134
"""Returns the time zone aware update timestamp."""
142-
if self._dialect_name == "sqlite":
135+
if is_sqlite:
143136
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
144137
# object without timezone information. We need to convert it to UTC
145138
# manually.
@@ -150,6 +143,7 @@ def to_session(
150143
self,
151144
state: dict[str, Any] | None = None,
152145
events: list[Event] | None = None,
146+
is_sqlite: bool = False,
153147
) -> Session:
154148
"""Converts the storage session to a session object."""
155149
if state is None:
@@ -163,7 +157,7 @@ def to_session(
163157
id=self.id,
164158
state=state,
165159
events=events,
166-
last_update_time=self.update_timestamp_tz,
160+
last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
167161
)
168162

169163

src/google/adk/sessions/schemas/v1.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from sqlalchemy import ForeignKeyConstraint
3333
from sqlalchemy import func
3434
from sqlalchemy.ext.mutable import MutableDict
35-
from sqlalchemy.inspection import inspect
3635
from sqlalchemy.orm import DeclarativeBase
3736
from sqlalchemy.orm import Mapped
3837
from sqlalchemy.orm import mapped_column
@@ -106,15 +105,9 @@ class StorageSession(Base):
106105
def __repr__(self):
107106
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
108107

109-
@property
110-
def _dialect_name(self) -> Optional[str]:
111-
session = inspect(self).session
112-
return session.bind.dialect.name if session else None
113-
114-
@property
115-
def update_timestamp_tz(self) -> datetime:
108+
def get_update_timestamp(self, is_sqlite: bool) -> float:
116109
"""Returns the time zone aware update timestamp."""
117-
if self._dialect_name == "sqlite":
110+
if is_sqlite:
118111
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
119112
# object without timezone information. We need to convert it to UTC
120113
# manually.
@@ -125,6 +118,7 @@ def to_session(
125118
self,
126119
state: dict[str, Any] | None = None,
127120
events: list[Event] | None = None,
121+
is_sqlite: bool = False,
128122
) -> Session:
129123
"""Converts the storage session to a session object."""
130124
if state is None:
@@ -138,7 +132,7 @@ def to_session(
138132
id=self.id,
139133
state=state,
140134
events=events,
141-
last_update_time=self.update_timestamp_tz,
135+
last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
142136
)
143137

144138

tests/unittests/sessions/test_session_service.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -504,26 +504,65 @@ async def test_session_last_update_time_updates_on_event(session_service):
504504

505505

506506
@pytest.mark.asyncio
507-
async def test_get_session_with_config(session_service):
507+
async def test_append_event_to_stale_session():
508+
session_service = get_session_service(
509+
service_type=SessionServiceType.DATABASE
510+
)
511+
508512
app_name = 'my_app'
509513
user_id = 'user'
514+
current_time = datetime.now().astimezone(timezone.utc).timestamp()
510515

511-
session = await session_service.create_session(
516+
original_session = await session_service.create_session(
512517
app_name=app_name, user_id=user_id
513518
)
514-
original_update_time = session.last_update_time
519+
event1 = Event(
520+
invocation_id='inv1',
521+
author='user',
522+
timestamp=current_time + 1,
523+
actions=EventActions(state_delta={'sk1': 'v1'}),
524+
)
525+
await session_service.append_event(original_session, event1)
515526

516-
event = Event(invocation_id='invocation', author='user')
517-
await session_service.append_event(session=session, event=event)
527+
updated_session = await session_service.get_session(
528+
app_name=app_name, user_id=user_id, session_id=original_session.id
529+
)
530+
event2 = Event(
531+
invocation_id='inv2',
532+
author='user',
533+
timestamp=current_time + 2,
534+
actions=EventActions(state_delta={'sk2': 'v2'}),
535+
)
536+
await session_service.append_event(updated_session, event2)
518537

519-
assert session.last_update_time >= event.timestamp
538+
# original_session is now stale
539+
assert original_session.last_update_time < updated_session.last_update_time
540+
assert len(original_session.events) == 1
541+
assert 'sk2' not in original_session.state
520542

521-
refreshed_session = await session_service.get_session(
522-
app_name=app_name, user_id=user_id, session_id=session.id
523-
)
524-
assert refreshed_session is not None
525-
assert refreshed_session.last_update_time >= event.timestamp
526-
assert refreshed_session.last_update_time > original_update_time
543+
# Appending another event to stale original_session
544+
event3 = Event(
545+
invocation_id='inv3',
546+
author='user',
547+
timestamp=current_time + 3,
548+
actions=EventActions(state_delta={'sk3': 'v3'}),
549+
)
550+
await session_service.append_event(original_session, event3)
551+
552+
# If we fetch session from DB, it should contain all 3 events and all state
553+
# changes.
554+
session_final = await session_service.get_session(
555+
app_name=app_name, user_id=user_id, session_id=original_session.id
556+
)
557+
assert len(session_final.events) == 3
558+
assert session_final.state.get('sk1') == 'v1'
559+
assert session_final.state.get('sk2') == 'v2'
560+
assert session_final.state.get('sk3') == 'v3'
561+
assert [e.invocation_id for e in session_final.events] == [
562+
'inv1',
563+
'inv2',
564+
'inv3',
565+
]
527566

528567

529568
@pytest.mark.asyncio

0 commit comments

Comments
 (0)