Skip to content

Commit 64c82ea

Browse files
committed
Merge branch 'main' into mypy-gh-actions
2 parents 4793188 + 1063fa5 commit 64c82ea

File tree

6 files changed

+208
-61
lines changed

6 files changed

+208
-61
lines changed

src/google/adk/evaluation/vertex_ai_eval_facade.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,18 +159,25 @@ def _perform_eval(dataset, metrics):
159159
"""
160160
project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", None)
161161
location = os.environ.get("GOOGLE_CLOUD_LOCATION", None)
162-
163-
if not project_id:
164-
raise ValueError("Missing project id." + _ERROR_MESSAGE_SUFFIX)
165-
if not location:
166-
raise ValueError("Missing location." + _ERROR_MESSAGE_SUFFIX)
167-
168-
from vertexai import Client
169-
from vertexai import types as vertexai_types
170-
171-
client = Client(project=project_id, location=location)
162+
api_key = os.environ.get("GOOGLE_API_KEY", None)
163+
164+
from ..dependencies.vertexai import vertexai
165+
166+
if api_key:
167+
client = vertexai.Client(api_key=api_key)
168+
elif project_id or location:
169+
if not project_id:
170+
raise ValueError("Missing project id." + _ERROR_MESSAGE_SUFFIX)
171+
if not location:
172+
raise ValueError("Missing location." + _ERROR_MESSAGE_SUFFIX)
173+
client = vertexai.Client(project=project_id, location=location)
174+
else:
175+
raise ValueError(
176+
"Either API Key or Google cloud Project id and location should be"
177+
" specified."
178+
)
172179

173180
return client.evals.evaluate(
174-
dataset=vertexai_types.EvaluationDataset(eval_dataset_df=dataset),
181+
dataset=vertexai.types.EvaluationDataset(eval_dataset_df=dataset),
175182
metrics=metrics,
176183
)

src/google/adk/sessions/database_session_service.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from sqlalchemy.ext.asyncio import AsyncEngine
3232
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
3333
from sqlalchemy.ext.asyncio import create_async_engine
34-
from sqlalchemy.inspection import inspect
3534
from sqlalchemy.pool import StaticPool
3635
from typing_extensions import override
3736
from tzlocal import get_localzone
@@ -275,22 +274,29 @@ async def create_session(
275274
storage_user_state.state = storage_user_state.state | user_state_delta
276275

277276
# Store the session
277+
now = datetime.now(timezone.utc)
278+
is_sqlite = self.db_engine.dialect.name == "sqlite"
279+
if is_sqlite:
280+
now = now.replace(tzinfo=None)
281+
278282
storage_session = schema.StorageSession(
279283
app_name=app_name,
280284
user_id=user_id,
281285
id=session_id,
282286
state=session_state,
287+
create_time=now,
288+
update_time=now,
283289
)
284290
sql_session.add(storage_session)
285291
await sql_session.commit()
286292

287-
await sql_session.refresh(storage_session)
288-
289293
# Merge states for response
290294
merged_state = _merge_state(
291295
storage_app_state.state, storage_user_state.state, session_state
292296
)
293-
session = storage_session.to_session(state=merged_state)
297+
session = storage_session.to_session(
298+
state=merged_state, is_sqlite=is_sqlite
299+
)
294300
return session
295301

296302
@override
@@ -350,7 +356,10 @@ async def get_session(
350356

351357
# Convert storage session to session
352358
events = [e.to_event() for e in reversed(storage_events)]
353-
session = storage_session.to_session(state=merged_state, events=events)
359+
is_sqlite = self.db_engine.dialect.name == "sqlite"
360+
session = storage_session.to_session(
361+
state=merged_state, events=events, is_sqlite=is_sqlite
362+
)
354363
return session
355364

356365
@override
@@ -393,11 +402,14 @@ async def list_sessions(
393402
user_states_map[storage_user_state.user_id] = storage_user_state.state
394403

395404
sessions = []
405+
is_sqlite = self.db_engine.dialect.name == "sqlite"
396406
for storage_session in results:
397407
session_state = storage_session.state
398408
user_state = user_states_map.get(storage_session.user_id, {})
399409
merged_state = _merge_state(app_state, user_state, session_state)
400-
sessions.append(storage_session.to_session(state=merged_state))
410+
sessions.append(
411+
storage_session.to_session(state=merged_state, is_sqlite=is_sqlite)
412+
)
401413
return ListSessionsResponse(sessions=sessions)
402414

403415
@override
@@ -433,15 +445,6 @@ async def append_event(self, session: Session, event: Event) -> Event:
433445
schema.StorageSession, (session.app_name, session.user_id, session.id)
434446
)
435447

436-
if storage_session.update_timestamp_tz > session.last_update_time:
437-
raise ValueError(
438-
"The last_update_time provided in the session object"
439-
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
440-
" earlier than the update_time in the storage_session"
441-
f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}."
442-
" Please check if it is a stale session."
443-
)
444-
445448
# Fetch states from storage
446449
storage_app_state = await sql_session.get(
447450
schema.StorageAppState, (session.app_name)
@@ -450,6 +453,29 @@ async def append_event(self, session: Session, event: Event) -> Event:
450453
schema.StorageUserState, (session.app_name, session.user_id)
451454
)
452455

456+
is_sqlite = self.db_engine.dialect.name == "sqlite"
457+
if (
458+
storage_session.get_update_timestamp(is_sqlite)
459+
> session.last_update_time
460+
):
461+
# Reload the session from storage if it has been updated since it was
462+
# loaded.
463+
app_state = storage_app_state.state if storage_app_state else {}
464+
user_state = storage_user_state.state if storage_user_state else {}
465+
session_state = storage_session.state
466+
session.state = _merge_state(app_state, user_state, session_state)
467+
468+
stmt = (
469+
select(schema.StorageEvent)
470+
.filter(schema.StorageEvent.app_name == session.app_name)
471+
.filter(schema.StorageEvent.session_id == session.id)
472+
.filter(schema.StorageEvent.user_id == session.user_id)
473+
.order_by(schema.StorageEvent.timestamp.asc())
474+
)
475+
result = await sql_session.stream_scalars(stmt)
476+
storage_events = [e async for e in result]
477+
session.events = [e.to_event() for e in storage_events]
478+
453479
# Extract state delta
454480
if event.actions and event.actions.state_delta:
455481
state_deltas = _session_util.extract_state_delta(
@@ -466,7 +492,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
466492
if session_state_delta:
467493
storage_session.state = storage_session.state | session_state_delta
468494

469-
if storage_session._dialect_name == "sqlite":
495+
if is_sqlite:
470496
update_time = datetime.fromtimestamp(
471497
event.timestamp, timezone.utc
472498
).replace(tzinfo=None)
@@ -476,10 +502,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
476502
sql_session.add(schema.StorageEvent.from_event(session, event))
477503

478504
await sql_session.commit()
479-
await sql_session.refresh(storage_session)
480505

481506
# Update timestamp with commit time
482-
session.last_update_time = storage_session.update_timestamp_tz
507+
session.last_update_time = storage_session.get_update_timestamp(is_sqlite)
483508

484509
# Also update the in-memory session
485510
await super().append_event(session=session, event=event)

src/google/adk/sessions/schemas/v0.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from sqlalchemy import Text
4242
from sqlalchemy.dialects import mysql
4343
from sqlalchemy.ext.mutable import MutableDict
44-
from sqlalchemy.inspection import inspect
4544
from sqlalchemy.orm import DeclarativeBase
4645
from sqlalchemy.orm import Mapped
4746
from sqlalchemy.orm import mapped_column
@@ -131,15 +130,9 @@ class StorageSession(Base):
131130
def __repr__(self):
132131
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
133132

134-
@property
135-
def _dialect_name(self) -> Optional[str]:
136-
session = inspect(self).session
137-
return session.bind.dialect.name if session else None
138-
139-
@property
140-
def update_timestamp_tz(self) -> datetime:
133+
def get_update_timestamp(self, is_sqlite: bool) -> float:
141134
"""Returns the time zone aware update timestamp."""
142-
if self._dialect_name == "sqlite":
135+
if is_sqlite:
143136
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
144137
# object without timezone information. We need to convert it to UTC
145138
# manually.
@@ -150,6 +143,7 @@ def to_session(
150143
self,
151144
state: dict[str, Any] | None = None,
152145
events: list[Event] | None = None,
146+
is_sqlite: bool = False,
153147
) -> Session:
154148
"""Converts the storage session to a session object."""
155149
if state is None:
@@ -163,7 +157,7 @@ def to_session(
163157
id=self.id,
164158
state=state,
165159
events=events,
166-
last_update_time=self.update_timestamp_tz,
160+
last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
167161
)
168162

169163

src/google/adk/sessions/schemas/v1.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from sqlalchemy import ForeignKeyConstraint
3333
from sqlalchemy import func
3434
from sqlalchemy.ext.mutable import MutableDict
35-
from sqlalchemy.inspection import inspect
3635
from sqlalchemy.orm import DeclarativeBase
3736
from sqlalchemy.orm import Mapped
3837
from sqlalchemy.orm import mapped_column
@@ -106,15 +105,9 @@ class StorageSession(Base):
106105
def __repr__(self):
107106
return f"<StorageSession(id={self.id}, update_time={self.update_time})>"
108107

109-
@property
110-
def _dialect_name(self) -> Optional[str]:
111-
session = inspect(self).session
112-
return session.bind.dialect.name if session else None
113-
114-
@property
115-
def update_timestamp_tz(self) -> datetime:
108+
def get_update_timestamp(self, is_sqlite: bool) -> float:
116109
"""Returns the time zone aware update timestamp."""
117-
if self._dialect_name == "sqlite":
110+
if is_sqlite:
118111
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
119112
# object without timezone information. We need to convert it to UTC
120113
# manually.
@@ -125,6 +118,7 @@ def to_session(
125118
self,
126119
state: dict[str, Any] | None = None,
127120
events: list[Event] | None = None,
121+
is_sqlite: bool = False,
128122
) -> Session:
129123
"""Converts the storage session to a session object."""
130124
if state is None:
@@ -138,7 +132,7 @@ def to_session(
138132
id=self.id,
139133
state=state,
140134
events=events,
141-
last_update_time=self.update_timestamp_tz,
135+
last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
142136
)
143137

144138

tests/unittests/evaluation/test_vertex_ai_eval_facade.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616

1717
"""Tests for the Response Evaluator."""
1818
import math
19+
import os
1920
import random
2021

2122
from google.adk.dependencies.vertexai import vertexai
2223
from google.adk.evaluation.eval_case import Invocation
2324
from google.adk.evaluation.evaluator import EvalStatus
2425
from google.adk.evaluation.vertex_ai_eval_facade import _VertexAiEvalFacade
2526
from google.genai import types as genai_types
27+
import pandas as pd
2628
import pytest
2729

2830
vertexai_types = vertexai.types
@@ -246,3 +248,89 @@ def test_evaluate_invocations_metric_multiple_invocations(self, mocker):
246248
)
247249
assert evaluation_result.overall_eval_status == EvalStatus.FAILED
248250
assert mock_perform_eval.call_count == num_invocations
251+
252+
def test_perform_eval_with_api_key(self, mocker):
253+
mocker.patch.dict(
254+
os.environ, {"GOOGLE_API_KEY": "test_api_key"}, clear=True
255+
)
256+
mock_client_cls = mocker.patch(
257+
"google.adk.dependencies.vertexai.vertexai.Client"
258+
)
259+
mock_client_instance = mock_client_cls.return_value
260+
dummy_dataset = pd.DataFrame(
261+
[{"prompt": "p", "reference": "r", "response": "r"}]
262+
)
263+
dummy_metrics = [vertexai_types.PrebuiltMetric.COHERENCE]
264+
265+
_VertexAiEvalFacade._perform_eval(dummy_dataset, dummy_metrics)
266+
267+
mock_client_cls.assert_called_once_with(api_key="test_api_key")
268+
mock_client_instance.evals.evaluate.assert_called_once()
269+
270+
def test_perform_eval_with_project_and_location(self, mocker):
271+
mocker.patch.dict(
272+
os.environ,
273+
{
274+
"GOOGLE_CLOUD_PROJECT": "test_project",
275+
"GOOGLE_CLOUD_LOCATION": "test_location",
276+
},
277+
clear=True,
278+
)
279+
mock_client_cls = mocker.patch(
280+
"google.adk.dependencies.vertexai.vertexai.Client"
281+
)
282+
mock_client_instance = mock_client_cls.return_value
283+
dummy_dataset = pd.DataFrame(
284+
[{"prompt": "p", "reference": "r", "response": "r"}]
285+
)
286+
dummy_metrics = [vertexai_types.PrebuiltMetric.COHERENCE]
287+
288+
_VertexAiEvalFacade._perform_eval(dummy_dataset, dummy_metrics)
289+
290+
mock_client_cls.assert_called_once_with(
291+
project="test_project", location="test_location"
292+
)
293+
mock_client_instance.evals.evaluate.assert_called_once()
294+
295+
def test_perform_eval_with_project_only_raises_error(self, mocker):
296+
mocker.patch.dict(
297+
os.environ, {"GOOGLE_CLOUD_PROJECT": "test_project"}, clear=True
298+
)
299+
mocker.patch("google.adk.dependencies.vertexai.vertexai.Client")
300+
dummy_dataset = pd.DataFrame(
301+
[{"prompt": "p", "reference": "r", "response": "r"}]
302+
)
303+
dummy_metrics = [vertexai_types.PrebuiltMetric.COHERENCE]
304+
305+
with pytest.raises(ValueError, match="Missing location."):
306+
_VertexAiEvalFacade._perform_eval(dummy_dataset, dummy_metrics)
307+
308+
def test_perform_eval_with_location_only_raises_error(self, mocker):
309+
mocker.patch.dict(
310+
os.environ, {"GOOGLE_CLOUD_LOCATION": "test_location"}, clear=True
311+
)
312+
mocker.patch("google.adk.dependencies.vertexai.vertexai.Client")
313+
dummy_dataset = pd.DataFrame(
314+
[{"prompt": "p", "reference": "r", "response": "r"}]
315+
)
316+
dummy_metrics = [vertexai_types.PrebuiltMetric.COHERENCE]
317+
318+
with pytest.raises(ValueError, match="Missing project id."):
319+
_VertexAiEvalFacade._perform_eval(dummy_dataset, dummy_metrics)
320+
321+
def test_perform_eval_with_no_env_vars_raises_error(self, mocker):
322+
mocker.patch.dict(os.environ, {}, clear=True)
323+
mocker.patch("google.adk.dependencies.vertexai.vertexai.Client")
324+
dummy_dataset = pd.DataFrame(
325+
[{"prompt": "p", "reference": "r", "response": "r"}]
326+
)
327+
dummy_metrics = [vertexai_types.PrebuiltMetric.COHERENCE]
328+
329+
with pytest.raises(
330+
ValueError,
331+
match=(
332+
"Either API Key or Google cloud Project id and location should be"
333+
" specified."
334+
),
335+
):
336+
_VertexAiEvalFacade._perform_eval(dummy_dataset, dummy_metrics)

0 commit comments

Comments (0)