Skip to content

Commit 3dcb209

Browse files
authored
Merge branch 'main' into fix/enforce-arg-types-4612
2 parents ba3c4ea + b004da5 commit 3dcb209

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+4802
-581
lines changed

contributing/samples/authn-adk-all-in-one/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
google-adk==1.12
2-
Flask==3.1.1
2+
Flask==3.1.3
33
flask-cors==6.0.1
44
python-dotenv==1.1.1
55
PyJWT[crypto]==2.10.1

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ dependencies = [
5656
"opentelemetry-resourcedetector-gcp>=1.9.0a0, <2.0.0",
5757
"opentelemetry-sdk>=1.36.0, <1.39.0",
5858
"pyarrow>=14.0.0",
59-
"pydantic>=2.7.0, <3.0.0", # For data validation/models
59+
"pydantic>=2.12.0, <3.0.0", # For data validation/models
6060
"python-dateutil>=2.9.0.post0, <3.0.0", # For Vertext AI Session Service
6161
"python-dotenv>=1.0.0, <2.0.0", # To manage environment variables
6262
"requests>=2.32.4, <3.0.0",
@@ -109,6 +109,7 @@ community = [
109109
eval = [
110110
# go/keep-sorted start
111111
"Jinja2>=3.1.4,<4.0.0", # For eval template rendering
112+
"gepa>=0.1.0",
112113
"google-cloud-aiplatform[evaluation]>=1.100.0",
113114
"pandas>=2.2.3",
114115
"rouge-score>=0.1.2",
@@ -155,6 +156,7 @@ extensions = [
155156
"crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+
156157
"docker>=7.0.0", # For ContainerCodeExecutor
157158
"kubernetes>=29.0.0", # For GkeCodeExecutor
159+
"k8s-agent-sandbox>=0.1.1.post2", # For GkeCodeExecutor sandbox mode
158160
"langgraph>=0.2.60, <0.4.8", # For LangGraphAgent
159161
"litellm>=1.75.5, <2.0.0", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
160162
"llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.

src/google/adk/a2a/converters/event_converter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def convert_a2a_message_to_event(
370370
@a2a_experimental
371371
def convert_event_to_a2a_message(
372372
event: Event,
373-
invocation_context: InvocationContext,
373+
invocation_context: InvocationContext | None = None,
374374
role: Role = Role.agent,
375375
part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part,
376376
) -> Optional[Message]:
@@ -390,8 +390,6 @@ def convert_event_to_a2a_message(
390390
"""
391391
if not event:
392392
raise ValueError("Event cannot be None")
393-
if not invocation_context:
394-
raise ValueError("Invocation context cannot be None")
395393

396394
if not event.content or not event.content.parts:
397395
return None

src/google/adk/a2a/converters/part_converter.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ def convert_a2a_part_to_genai_part(
7070
if isinstance(part.file, a2a_types.FileWithUri):
7171
return genai_types.Part(
7272
file_data=genai_types.FileData(
73-
file_uri=part.file.uri, mime_type=part.file.mime_type
73+
file_uri=part.file.uri,
74+
mime_type=part.file.mime_type,
75+
display_name=part.file.name,
7476
)
7577
)
7678

@@ -79,6 +81,7 @@ def convert_a2a_part_to_genai_part(
7981
inline_data=genai_types.Blob(
8082
data=base64.b64decode(part.file.bytes),
8183
mime_type=part.file.mime_type,
84+
display_name=part.file.name,
8285
)
8386
)
8487
else:
@@ -104,10 +107,25 @@ def convert_a2a_part_to_genai_part(
104107
part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)]
105108
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
106109
):
110+
# Restore thought_signature if present
111+
thought_signature = None
112+
thought_sig_key = _get_adk_metadata_key('thought_signature')
113+
if thought_sig_key in part.metadata:
114+
sig_value = part.metadata[thought_sig_key]
115+
if isinstance(sig_value, bytes):
116+
thought_signature = sig_value
117+
elif isinstance(sig_value, str):
118+
try:
119+
thought_signature = base64.b64decode(sig_value)
120+
except Exception:
121+
logger.warning(
122+
'Failed to decode thought_signature: %s', sig_value
123+
)
107124
return genai_types.Part(
108125
function_call=genai_types.FunctionCall.model_validate(
109126
part.data, by_alias=True
110-
)
127+
),
128+
thought_signature=thought_signature,
111129
)
112130
if (
113131
part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)]
@@ -173,6 +191,7 @@ def convert_genai_part_to_a2a_part(
173191
file=a2a_types.FileWithUri(
174192
uri=part.file_data.file_uri,
175193
mime_type=part.file_data.mime_type,
194+
name=part.file_data.display_name,
176195
)
177196
)
178197
)
@@ -196,6 +215,7 @@ def convert_genai_part_to_a2a_part(
196215
file=a2a_types.FileWithBytes(
197216
bytes=base64.b64encode(part.inline_data.data).decode('utf-8'),
198217
mime_type=part.inline_data.mime_type,
218+
name=part.inline_data.display_name,
199219
)
200220
)
201221

@@ -214,16 +234,22 @@ def convert_genai_part_to_a2a_part(
214234
# TODO once A2A defined how to service such information, migrate below
215235
# logic accordingly
216236
if part.function_call:
237+
fc_metadata = {
238+
_get_adk_metadata_key(
239+
A2A_DATA_PART_METADATA_TYPE_KEY
240+
): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
241+
}
242+
# Preserve thought_signature if present
243+
if part.thought_signature is not None:
244+
fc_metadata[_get_adk_metadata_key('thought_signature')] = (
245+
base64.b64encode(part.thought_signature).decode('utf-8')
246+
)
217247
return a2a_types.Part(
218248
root=a2a_types.DataPart(
219249
data=part.function_call.model_dump(
220250
by_alias=True, exclude_none=True
221251
),
222-
metadata={
223-
_get_adk_metadata_key(
224-
A2A_DATA_PART_METADATA_TYPE_KEY
225-
): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
226-
},
252+
metadata=fc_metadata,
227253
)
228254
)
229255

src/google/adk/artifacts/base_artifact_service.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616
from abc import ABC
1717
from abc import abstractmethod
1818
from datetime import datetime
19+
import logging
1920
from typing import Any
2021
from typing import Optional
22+
from typing import Union
2123

2224
from google.genai import types
2325
from pydantic import alias_generators
2426
from pydantic import BaseModel
2527
from pydantic import ConfigDict
2628
from pydantic import Field
2729

30+
logger = logging.getLogger("google_adk." + __name__)
31+
2832

2933
class ArtifactVersion(BaseModel):
3034
"""Metadata describing a specific version of an artifact."""
@@ -60,6 +64,26 @@ class ArtifactVersion(BaseModel):
6064
)
6165

6266

67+
def ensure_part(artifact: Union[types.Part, dict[str, Any]]) -> types.Part:
68+
"""Normalizes an artifact to a ``types.Part`` instance.
69+
70+
External callers may provide artifacts as
71+
plain dictionaries with camelCase keys (``inlineData``) instead of properly
72+
deserialized ``types.Part`` objects. ``model_validate`` handles both
73+
camelCase and snake_case dictionaries transparently via Pydantic aliases.
74+
75+
Args:
76+
artifact: A ``types.Part`` instance or a dictionary representation.
77+
78+
Returns:
79+
A validated ``types.Part`` instance.
80+
"""
81+
if isinstance(artifact, dict):
82+
logger.debug("Normalizing artifact dict to types.Part: %s", list(artifact))
83+
return types.Part.model_validate(artifact)
84+
return artifact
85+
86+
6387
class BaseArtifactService(ABC):
6488
"""Abstract base class for artifact services."""
6589

@@ -70,7 +94,7 @@ async def save_artifact(
7094
app_name: str,
7195
user_id: str,
7296
filename: str,
73-
artifact: types.Part,
97+
artifact: Union[types.Part, dict[str, Any]],
7498
session_id: Optional[str] = None,
7599
custom_metadata: Optional[dict[str, Any]] = None,
76100
) -> int:
@@ -84,10 +108,12 @@ async def save_artifact(
84108
app_name: The app name.
85109
user_id: The user ID.
86110
filename: The filename of the artifact.
87-
artifact: The artifact to save. If the artifact consists of `file_data`,
88-
the artifact service assumes its content has been uploaded separately,
89-
and this method will associate the `file_data` with the artifact if
90-
necessary.
111+
artifact: The artifact to save. Accepts a ``types.Part`` instance or a
112+
plain dictionary (camelCase or snake_case keys) which will be
113+
normalized via ``ensure_part``. If the artifact consists of
114+
``file_data``, the artifact service assumes its content has been
115+
uploaded separately, and this method will associate the ``file_data``
116+
with the artifact if necessary.
91117
session_id: The session ID. If `None`, the artifact is user-scoped.
92118
custom_metadata: custom metadata to associate with the artifact.
93119

src/google/adk/artifacts/file_artifact_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import shutil
2323
from typing import Any
2424
from typing import Optional
25+
from typing import Union
2526
from urllib.parse import unquote
2627
from urllib.parse import urlparse
2728

@@ -35,6 +36,7 @@
3536
from ..errors.input_validation_error import InputValidationError
3637
from .base_artifact_service import ArtifactVersion
3738
from .base_artifact_service import BaseArtifactService
39+
from .base_artifact_service import ensure_part
3840

3941
logger = logging.getLogger("google_adk." + __name__)
4042

@@ -314,7 +316,7 @@ async def save_artifact(
314316
app_name: str,
315317
user_id: str,
316318
filename: str,
317-
artifact: types.Part,
319+
artifact: Union[types.Part, dict[str, Any]],
318320
session_id: Optional[str] = None,
319321
custom_metadata: Optional[dict[str, Any]] = None,
320322
) -> int:
@@ -339,11 +341,12 @@ def _save_artifact_sync(
339341
self,
340342
user_id: str,
341343
filename: str,
342-
artifact: types.Part,
344+
artifact: Union[types.Part, dict[str, Any]],
343345
session_id: Optional[str],
344346
custom_metadata: Optional[dict[str, Any]],
345347
) -> int:
346348
"""Saves an artifact to disk and returns its version."""
349+
artifact = ensure_part(artifact)
347350
artifact_dir = self._artifact_dir(
348351
user_id=user_id,
349352
session_id=session_id,

src/google/adk/artifacts/gcs_artifact_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@
2727
import logging
2828
from typing import Any
2929
from typing import Optional
30+
from typing import Union
3031

3132
from google.genai import types
3233
from typing_extensions import override
3334

3435
from ..errors.input_validation_error import InputValidationError
3536
from .base_artifact_service import ArtifactVersion
3637
from .base_artifact_service import BaseArtifactService
38+
from .base_artifact_service import ensure_part
3739

3840
logger = logging.getLogger("google_adk." + __name__)
3941

@@ -61,7 +63,7 @@ async def save_artifact(
6163
app_name: str,
6264
user_id: str,
6365
filename: str,
64-
artifact: types.Part,
66+
artifact: Union[types.Part, dict[str, Any]],
6567
session_id: Optional[str] = None,
6668
custom_metadata: Optional[dict[str, Any]] = None,
6769
) -> int:
@@ -198,9 +200,10 @@ def _save_artifact(
198200
user_id: str,
199201
session_id: Optional[str],
200202
filename: str,
201-
artifact: types.Part,
203+
artifact: Union[types.Part, dict[str, Any]],
202204
custom_metadata: Optional[dict[str, Any]] = None,
203205
) -> int:
206+
artifact = ensure_part(artifact)
204207
versions = self._list_versions(
205208
app_name=app_name,
206209
user_id=user_id,

src/google/adk/artifacts/in_memory_artifact_service.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import logging
1818
from typing import Any
1919
from typing import Optional
20+
from typing import Union
2021

2122
from google.genai import types
2223
from pydantic import BaseModel
@@ -27,6 +28,7 @@
2728
from ..errors.input_validation_error import InputValidationError
2829
from .base_artifact_service import ArtifactVersion
2930
from .base_artifact_service import BaseArtifactService
31+
from .base_artifact_service import ensure_part
3032

3133
logger = logging.getLogger("google_adk." + __name__)
3234

@@ -99,10 +101,11 @@ async def save_artifact(
99101
app_name: str,
100102
user_id: str,
101103
filename: str,
102-
artifact: types.Part,
104+
artifact: Union[types.Part, dict[str, Any]],
103105
session_id: Optional[str] = None,
104106
custom_metadata: Optional[dict[str, Any]] = None,
105107
) -> int:
108+
artifact = ensure_part(artifact)
106109
path = self._artifact_path(app_name, user_id, filename, session_id)
107110
if path not in self.artifacts:
108111
self.artifacts[path] = []

0 commit comments

Comments
 (0)