
Commit b6facdc

Merge branch 'main' into feat/parallelize-llm-judge-evaluation
2 parents: 93effce + 1063fa5

File tree

11 files changed: +335 -77 lines


src/google/adk/a2a/converters/event_converter.py

Lines changed: 1 addition & 1 deletion

@@ -131,6 +131,7 @@ def _get_context_metadata(
       _get_adk_metadata_key("session_id"): invocation_context.session.id,
       _get_adk_metadata_key("invocation_id"): event.invocation_id,
       _get_adk_metadata_key("author"): event.author,
+      _get_adk_metadata_key("event_id"): event.id,
   }

   # Add optional metadata fields if present
@@ -479,7 +480,6 @@ def _create_status_update_event(
     task_id: Optional task ID to use for generated events.
     context_id: Optional Context ID to use for generated events.

-
   Returns:
     A TaskStatusUpdateEvent with RUNNING state.
   """

src/google/adk/evaluation/custom_metric_evaluator.py

Lines changed: 10 additions & 13 deletions

@@ -24,7 +24,6 @@
 from .eval_case import ConversationScenario
 from .eval_case import Invocation
 from .eval_metrics import EvalMetric
-from .eval_metrics import EvalStatus
 from .evaluator import EvaluationResult
 from .evaluator import Evaluator

@@ -44,12 +43,6 @@ def _get_metric_function(
     ) from e


-def _get_eval_status(score: Optional[float], threshold: float) -> EvalStatus:
-  if score is None:
-    return EvalStatus.NOT_EVALUATED
-  return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED
-
-
 class _CustomMetricEvaluator(Evaluator):
   """Evaluator for custom metrics."""

@@ -64,16 +57,20 @@ async def evaluate_invocations(
       expected_invocations: Optional[list[Invocation]],
       conversation_scenario: Optional[ConversationScenario] = None,
   ) -> EvaluationResult:
+    eval_metric = self._eval_metric.model_copy(deep=True)
+    eval_metric.threshold = None
    if inspect.iscoroutinefunction(self._metric_function):
       eval_result = await self._metric_function(
-          actual_invocations, expected_invocations, conversation_scenario
+          eval_metric,
+          actual_invocations,
+          expected_invocations,
+          conversation_scenario,
       )
     else:
       eval_result = self._metric_function(
-          actual_invocations, expected_invocations, conversation_scenario
+          eval_metric,
+          actual_invocations,
+          expected_invocations,
+          conversation_scenario,
       )
-
-    eval_result.overall_eval_status = _get_eval_status(
-        eval_result.overall_score, self._eval_metric.threshold
-    )
     return eval_result
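
Net effect of this file's change: custom metric callables now receive a deep copy of the EvalMetric (with its threshold cleared) as their first positional argument, and the evaluator no longer derives overall_eval_status from a threshold itself. A minimal sketch of a callable matching the new call shape; the function name, body, and scoring logic are assumptions for illustration:

    # Sketch only: argument order mirrors the evaluate_invocations call above.
    async def my_custom_metric(
        eval_metric,            # EvalMetric copy; threshold is set to None
        actual_invocations,     # list[Invocation]
        expected_invocations,   # Optional[list[Invocation]]
        conversation_scenario,  # Optional[ConversationScenario]
    ) -> EvaluationResult:
      score = 1.0 if actual_invocations else 0.0  # placeholder scoring logic
      return EvaluationResult(overall_score=score)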

src/google/adk/evaluation/eval_metrics.py

Lines changed: 4 additions & 2 deletions

@@ -258,9 +258,11 @@ class EvalMetric(EvalBaseModel):
       description="The name of the metric.",
   )

-  threshold: float = Field(
+  threshold: Optional[float] = Field(
+      default=None,
       description=(
-          "A threshold value. Each metric decides how to interpret this"
+          "This field will be deprecated soon. Please use `criterion` instead."
+          " A threshold value. Each metric decides how to interpret this"
           " threshold."
       ),
   )
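
With threshold now optional and defaulting to None, an EvalMetric can be declared without one, which matches how _CustomMetricEvaluator clears it before invoking the metric function. A minimal sketch; the metric_name value is a placeholder:

    # Sketch only: threshold may now be omitted (defaults to None).
    metric = EvalMetric(metric_name="my_custom_metric")
    assert metric.threshold is None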

src/google/adk/evaluation/vertex_ai_eval_facade.py

Lines changed: 18 additions & 11 deletions

@@ -159,18 +159,25 @@ def _perform_eval(dataset, metrics):
   """
   project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", None)
   location = os.environ.get("GOOGLE_CLOUD_LOCATION", None)
-
-  if not project_id:
-    raise ValueError("Missing project id." + _ERROR_MESSAGE_SUFFIX)
-  if not location:
-    raise ValueError("Missing location." + _ERROR_MESSAGE_SUFFIX)
-
-  from vertexai import Client
-  from vertexai import types as vertexai_types
-
-  client = Client(project=project_id, location=location)
+  api_key = os.environ.get("GOOGLE_API_KEY", None)
+
+  from ..dependencies.vertexai import vertexai
+
+  if api_key:
+    client = vertexai.Client(api_key=api_key)
+  elif project_id or location:
+    if not project_id:
+      raise ValueError("Missing project id." + _ERROR_MESSAGE_SUFFIX)
+    if not location:
+      raise ValueError("Missing location." + _ERROR_MESSAGE_SUFFIX)
+    client = vertexai.Client(project=project_id, location=location)
+  else:
+    raise ValueError(
+        "Either API Key or Google cloud Project id and location should be"
+        " specified."
+    )

   return client.evals.evaluate(
-      dataset=vertexai_types.EvaluationDataset(eval_dataset_df=dataset),
+      dataset=vertexai.types.EvaluationDataset(eval_dataset_df=dataset),
       metrics=metrics,
   )
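
Credential resolution in _perform_eval now checks GOOGLE_API_KEY first, then falls back to GOOGLE_CLOUD_PROJECT plus GOOGLE_CLOUD_LOCATION, and raises if neither configuration is present. A minimal sketch of the two valid environment setups; all values are placeholders:

    import os

    # Option 1: API key (takes precedence when set).
    os.environ["GOOGLE_API_KEY"] = "your-api-key"

    # Option 2: project and location (both required if either is set).
    os.environ["GOOGLE_CLOUD_PROJECT"] = "my-gcp-project"
    os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1"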

src/google/adk/sessions/database_session_service.py

Lines changed: 43 additions & 18 deletions

@@ -31,7 +31,6 @@
 from sqlalchemy.ext.asyncio import AsyncEngine
 from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
 from sqlalchemy.ext.asyncio import create_async_engine
-from sqlalchemy.inspection import inspect
 from sqlalchemy.pool import StaticPool
 from typing_extensions import override
 from tzlocal import get_localzone
@@ -275,22 +274,29 @@ async def create_session(
         storage_user_state.state = storage_user_state.state | user_state_delta

       # Store the session
+      now = datetime.now(timezone.utc)
+      is_sqlite = self.db_engine.dialect.name == "sqlite"
+      if is_sqlite:
+        now = now.replace(tzinfo=None)
+
       storage_session = schema.StorageSession(
           app_name=app_name,
           user_id=user_id,
           id=session_id,
           state=session_state,
+          create_time=now,
+          update_time=now,
       )
       sql_session.add(storage_session)
       await sql_session.commit()

-      await sql_session.refresh(storage_session)
-
       # Merge states for response
       merged_state = _merge_state(
           storage_app_state.state, storage_user_state.state, session_state
       )
-      session = storage_session.to_session(state=merged_state)
+      session = storage_session.to_session(
+          state=merged_state, is_sqlite=is_sqlite
+      )
       return session

   @override
@@ -350,7 +356,10 @@ async def get_session(

       # Convert storage session to session
       events = [e.to_event() for e in reversed(storage_events)]
-      session = storage_session.to_session(state=merged_state, events=events)
+      is_sqlite = self.db_engine.dialect.name == "sqlite"
+      session = storage_session.to_session(
+          state=merged_state, events=events, is_sqlite=is_sqlite
+      )
       return session

   @override
@@ -393,11 +402,14 @@ async def list_sessions(
         user_states_map[storage_user_state.user_id] = storage_user_state.state

       sessions = []
+      is_sqlite = self.db_engine.dialect.name == "sqlite"
       for storage_session in results:
         session_state = storage_session.state
         user_state = user_states_map.get(storage_session.user_id, {})
         merged_state = _merge_state(app_state, user_state, session_state)
-        sessions.append(storage_session.to_session(state=merged_state))
+        sessions.append(
+            storage_session.to_session(state=merged_state, is_sqlite=is_sqlite)
+        )
       return ListSessionsResponse(sessions=sessions)

   @override
@@ -433,15 +445,6 @@ async def append_event(self, session: Session, event: Event) -> Event:
           schema.StorageSession, (session.app_name, session.user_id, session.id)
       )

-      if storage_session.update_timestamp_tz > session.last_update_time:
-        raise ValueError(
-            "The last_update_time provided in the session object"
-            f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
-            " earlier than the update_time in the storage_session"
-            f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}."
-            " Please check if it is a stale session."
-        )
-
       # Fetch states from storage
       storage_app_state = await sql_session.get(
           schema.StorageAppState, (session.app_name)
@@ -450,6 +453,29 @@ async def append_event(self, session: Session, event: Event) -> Event:
           schema.StorageUserState, (session.app_name, session.user_id)
       )

+      is_sqlite = self.db_engine.dialect.name == "sqlite"
+      if (
+          storage_session.get_update_timestamp(is_sqlite)
+          > session.last_update_time
+      ):
+        # Reload the session from storage if it has been updated since it was
+        # loaded.
+        app_state = storage_app_state.state if storage_app_state else {}
+        user_state = storage_user_state.state if storage_user_state else {}
+        session_state = storage_session.state
+        session.state = _merge_state(app_state, user_state, session_state)
+
+        stmt = (
+            select(schema.StorageEvent)
+            .filter(schema.StorageEvent.app_name == session.app_name)
+            .filter(schema.StorageEvent.session_id == session.id)
+            .filter(schema.StorageEvent.user_id == session.user_id)
+            .order_by(schema.StorageEvent.timestamp.asc())
+        )
+        result = await sql_session.stream_scalars(stmt)
+        storage_events = [e async for e in result]
+        session.events = [e.to_event() for e in storage_events]
+
       # Extract state delta
       if event.actions and event.actions.state_delta:
         state_deltas = _session_util.extract_state_delta(
@@ -466,7 +492,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
       if session_state_delta:
         storage_session.state = storage_session.state | session_state_delta

-      if storage_session._dialect_name == "sqlite":
+      if is_sqlite:
         update_time = datetime.fromtimestamp(
             event.timestamp, timezone.utc
         ).replace(tzinfo=None)
@@ -476,10 +502,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
       sql_session.add(schema.StorageEvent.from_event(session, event))

       await sql_session.commit()
-      await sql_session.refresh(storage_session)

       # Update timestamp with commit time
-      session.last_update_time = storage_session.update_timestamp_tz
+      session.last_update_time = storage_session.get_update_timestamp(is_sqlite)

       # Also update the in-memory session
       await super().append_event(session=session, event=event)
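
Behavior change worth noting: append_event no longer raises ValueError when the session object's last_update_time lags the stored update_time; it now reloads the session's state and events from storage and proceeds with the append. A rough caller-side sketch; service, stale_session, and event are assumed to already exist:

    # Sketch only, assuming an existing DatabaseSessionService instance.
    # Previously this path raised ValueError ("... stale session ...");
    # now the service refreshes state/events from storage and appends.
    appended = await service.append_event(session=stale_session, event=event)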

src/google/adk/sessions/schemas/v0.py

Lines changed: 4 additions & 10 deletions

@@ -41,7 +41,6 @@
 from sqlalchemy import Text
 from sqlalchemy.dialects import mysql
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -131,15 +130,9 @@ class StorageSession(Base):
   def __repr__(self):
     return f"<StorageSession(id={self.id}, update_time={self.update_time})>"

-  @property
-  def _dialect_name(self) -> Optional[str]:
-    session = inspect(self).session
-    return session.bind.dialect.name if session else None
-
-  @property
-  def update_timestamp_tz(self) -> datetime:
+  def get_update_timestamp(self, is_sqlite: bool) -> float:
     """Returns the time zone aware update timestamp."""
-    if self._dialect_name == "sqlite":
+    if is_sqlite:
       # SQLite does not support timezone. SQLAlchemy returns a naive datetime
       # object without timezone information. We need to convert it to UTC
       # manually.
@@ -150,6 +143,7 @@ def to_session(
       self,
       state: dict[str, Any] | None = None,
       events: list[Event] | None = None,
+      is_sqlite: bool = False,
   ) -> Session:
     """Converts the storage session to a session object."""
     if state is None:
@@ -163,7 +157,7 @@ def to_session(
         id=self.id,
         state=state,
         events=events,
-        last_update_time=self.update_timestamp_tz,
+        last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
     )
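
The dialect is now the caller's responsibility in both schema versions (v0 and v1 below): instead of StorageSession inspecting its bound engine, callers pass is_sqlite explicitly, as database_session_service.py now does via self.db_engine.dialect.name. A minimal usage sketch; engine and storage_session are assumed to already exist:

    # Sketch only: callers decide the dialect and pass it in.
    is_sqlite = engine.dialect.name == "sqlite"
    last_update = storage_session.get_update_timestamp(is_sqlite)  # float, per the new annotation
    session = storage_session.to_session(state={}, events=[], is_sqlite=is_sqlite)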

src/google/adk/sessions/schemas/v1.py

Lines changed: 4 additions & 10 deletions

@@ -32,7 +32,6 @@
 from sqlalchemy import ForeignKeyConstraint
 from sqlalchemy import func
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -106,15 +105,9 @@ class StorageSession(Base):
   def __repr__(self):
     return f"<StorageSession(id={self.id}, update_time={self.update_time})>"

-  @property
-  def _dialect_name(self) -> Optional[str]:
-    session = inspect(self).session
-    return session.bind.dialect.name if session else None
-
-  @property
-  def update_timestamp_tz(self) -> datetime:
+  def get_update_timestamp(self, is_sqlite: bool) -> float:
     """Returns the time zone aware update timestamp."""
-    if self._dialect_name == "sqlite":
+    if is_sqlite:
       # SQLite does not support timezone. SQLAlchemy returns a naive datetime
       # object without timezone information. We need to convert it to UTC
       # manually.
@@ -125,6 +118,7 @@ def to_session(
       self,
       state: dict[str, Any] | None = None,
      events: list[Event] | None = None,
+      is_sqlite: bool = False,
   ) -> Session:
     """Converts the storage session to a session object."""
     if state is None:
@@ -138,7 +132,7 @@ def to_session(
         id=self.id,
         state=state,
         events=events,
-        last_update_time=self.update_timestamp_tz,
+        last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
     )

tests/unittests/a2a/converters/test_event_converter.py

Lines changed: 2 additions & 0 deletions

@@ -56,6 +56,7 @@ def setup_method(self):
     self.mock_invocation_context.artifact_service = self.mock_artifact_service

     self.mock_event = Mock(spec=Event)
+    self.mock_event.id = None
     self.mock_event.invocation_id = "test-invocation-id"
     self.mock_event.author = "test-author"
     self.mock_event.branch = None
@@ -130,6 +131,7 @@ def test_get_context_metadata_success(self):
         f"{ADK_METADATA_KEY_PREFIX}session_id",
         f"{ADK_METADATA_KEY_PREFIX}invocation_id",
         f"{ADK_METADATA_KEY_PREFIX}author",
+        f"{ADK_METADATA_KEY_PREFIX}event_id",
     ]

     for key in expected_keys:
