Skip to content

Commit 7274657

Browse files
Add file search stores, register_files, chat example, and full integration test
- Add TemporalAsyncFileSearchStores: subclasses AsyncFileSearchStores, overriding upload_to_file_search_store() to dispatch through a Temporal activity. All other methods (create, get, delete, list, import_file, documents) work via async_request. - Add TemporalAsyncFiles.register_files(): dispatches through gemini_files_register activity. The auth param from the caller is ignored; the activity uses extra_credentials from GeminiPlugin init or falls back to the client's own credentials. - Add extra_credentials param to GeminiPlugin for operations that need explicit auth (e.g. GCS file registration). - Wire TemporalAsyncFileSearchStores into TemporalAsyncClient. - Add 12 new tests (31 total): file upload (str + bytes), download, file search store upload, multi-turn chat, TemporalAsyncClient wiring checks, low-level raise checks, and a full integration test that runs real activities with a mocked client covering generate_content, streaming, file upload/download, store upload, RAG query, and store deletion.
1 parent 0fb8300 commit 7274657

7 files changed

Lines changed: 724 additions & 20 deletions

File tree

temporalio/contrib/google_gemini_sdk/_gemini_activity.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from __future__ import annotations
1010

1111
from collections.abc import Sequence
12-
from typing import Any, Callable
12+
from typing import Any, Callable, Optional
1313

14+
import google.auth.credentials
1415
from google.genai import Client as GeminiClient, types
1516
from google.genai.types import HttpOptions
1617
from google.genai.types import HttpResponse as SdkHttpResponse
@@ -21,7 +22,9 @@
2122
_GeminiApiResponse,
2223
_GeminiApiStreamedResponse,
2324
_GeminiDownloadFileRequest,
25+
_GeminiRegisterFilesRequest,
2426
_GeminiUploadFileRequest,
27+
_GeminiUploadToFileSearchStoreRequest,
2528
)
2629

2730

@@ -35,15 +38,20 @@ def _resolve_http_options(
3538

3639

3740
class GeminiApiCaller:
38-
"""Wraps a ``genai.Client`` and exposes a Temporal activity for API calls.
41+
"""Wraps a ``genai.Client`` and exposes Temporal activities for SDK calls.
3942
4043
The caller owns a reference to the user-provided ``genai.Client``.
4144
All credential management, HTTP client configuration, etc. is the
4245
responsibility of whoever constructs the client.
4346
"""
4447

45-
def __init__(self, client: GeminiClient) -> None:
48+
def __init__(
49+
self,
50+
client: GeminiClient,
51+
credentials: Optional[google.auth.credentials.Credentials] = None,
52+
) -> None:
4653
self._client = client
54+
self._credentials = credentials
4755

4856
def activities(self) -> Sequence[Callable]:
4957
"""Return activities that route SDK calls through this client."""
@@ -112,9 +120,46 @@ async def gemini_files_download(
112120
file=req.file, config=req.config
113121
)
114122

123+
@activity.defn(name="gemini_files_register")
124+
async def gemini_files_register(
125+
req: _GeminiRegisterFilesRequest,
126+
) -> types.RegisterFilesResponse:
127+
"""Register GCS files using the real genai.Client on the worker.
128+
129+
Uses ``credentials`` if provided at plugin init,
130+
otherwise falls back to the client's own credentials.
131+
Token refresh happens here on the worker side, so no auth
132+
material enters the workflow event history.
133+
"""
134+
return await self._client.aio.files.register_files(
135+
auth=self._credentials or self._client._api_client._credentials,
136+
uris=req.uris,
137+
config=req.config,
138+
)
139+
140+
@activity.defn(name="gemini_file_search_stores_upload")
141+
async def gemini_file_search_stores_upload(
142+
req: _GeminiUploadToFileSearchStoreRequest,
143+
) -> types.UploadToFileSearchStoreOperation:
144+
"""Upload a file to a file search store on the worker."""
145+
if req.file_bytes is not None:
146+
import io
147+
148+
file_arg: Any = io.BytesIO(req.file_bytes)
149+
else:
150+
file_arg = req.file_path
151+
152+
return await self._client.aio.file_search_stores.upload_to_file_search_store(
153+
file_search_store_name=req.file_search_store_name,
154+
file=file_arg,
155+
config=req.config,
156+
)
157+
115158
return [
116159
gemini_api_client_async_request,
117160
gemini_api_client_async_request_streamed,
118161
gemini_files_upload,
119162
gemini_files_download,
163+
gemini_files_register,
164+
gemini_file_search_stores_upload,
120165
]

temporalio/contrib/google_gemini_sdk/_gemini_plugin.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from __future__ import annotations
44

55
import dataclasses
6+
from typing import Optional
67

8+
import google.auth.credentials
79
from google.genai import Client as GeminiClient
810

911
from temporalio.contrib.google_gemini_sdk._gemini_activity import GeminiApiCaller
@@ -51,17 +53,33 @@ class GeminiPlugin(SimplePlugin):
5153
vertexai=True, project="my-project", location="us-central1",
5254
)
5355
plugin = GeminiPlugin(client)
56+
57+
Example (with separate GCS credentials for file registration)::
58+
59+
client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])
60+
gcs_creds, _ = google.auth.default()
61+
plugin = GeminiPlugin(client, extra_credentials=gcs_creds)
5462
"""
5563

56-
def __init__(self, client: GeminiClient) -> None:
64+
def __init__(
65+
self,
66+
client: GeminiClient,
67+
extra_credentials: Optional[google.auth.credentials.Credentials] = None,
68+
) -> None:
5769
"""Initialize the Gemini plugin.
5870
5971
Args:
6072
client: A fully configured ``genai.Client`` instance.
6173
All credential management, HTTP client configuration, etc.
6274
is the responsibility of the caller.
75+
extra_credentials: Optional Google Cloud credentials used for
76+
operations that require explicit auth (e.g.
77+
``files.register_files()``). If not provided, the
78+
client's own credentials are used.
6379
"""
64-
self._api_caller = GeminiApiCaller(client)
80+
self._api_caller = GeminiApiCaller(
81+
client, credentials=extra_credentials
82+
)
6583

6684
def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner:
6785
if not runner:

temporalio/contrib/google_gemini_sdk/_temporal_api_client.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,22 @@ class _GeminiDownloadFileRequest(BaseModel):
122122
config: types.DownloadFileConfig | None = None
123123

124124

125+
class _GeminiRegisterFilesRequest(BaseModel):
126+
"""Serializable activity input for registering GCS files."""
127+
128+
uris: list[str]
129+
config: types.RegisterFilesConfig | None = None
130+
131+
132+
class _GeminiUploadToFileSearchStoreRequest(BaseModel):
133+
"""Serializable activity input for uploading a file to a file search store."""
134+
135+
file_search_store_name: str
136+
file_bytes: bytes | None = None
137+
file_path: str | None = None
138+
config: types.UploadToFileSearchStoreConfig | None = None
139+
140+
125141
class TemporalApiClient(BaseApiClient):
126142
"""A ``BaseApiClient`` that routes all API calls through Temporal activities.
127143

temporalio/contrib/google_gemini_sdk/_temporal_async_client.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Temporal-aware AsyncClient shim.
22
33
``TemporalAsyncClient`` is an ``AsyncClient`` subclass that wires up
4-
``TemporalAsyncFiles`` in place of the default ``AsyncFiles``.
4+
Temporal-aware replacements for modules that need special handling
5+
(files, file search stores).
56
"""
67

78
from __future__ import annotations
@@ -13,6 +14,9 @@
1314
from temporalio.contrib.google_gemini_sdk._temporal_api_client import (
1415
TemporalApiClient,
1516
)
17+
from temporalio.contrib.google_gemini_sdk._temporal_file_search_stores import (
18+
TemporalAsyncFileSearchStores,
19+
)
1620
from temporalio.contrib.google_gemini_sdk._temporal_files import (
1721
TemporalAsyncFiles,
1822
)
@@ -21,10 +25,14 @@
2125
class TemporalAsyncClient(AsyncClient):
2226
"""``AsyncClient`` subclass that uses Temporal-aware modules.
2327
24-
Replaces ``AsyncFiles`` with ``TemporalAsyncFiles`` so that file
25-
upload/download operations run entirely inside Temporal activities.
26-
Other modules (models, caches, etc.) are inherited unchanged and
27-
work through ``TemporalApiClient``'s activity-backed HTTP methods.
28+
Replaces ``AsyncFiles`` with ``TemporalAsyncFiles`` and
29+
``AsyncFileSearchStores`` with ``TemporalAsyncFileSearchStores``
30+
so that file upload/download operations and file search store uploads
31+
run entirely inside Temporal activities.
32+
33+
Other modules (models, tunings, caches, batches, live, tokens,
34+
operations) are inherited unchanged and work through
35+
``TemporalApiClient``'s activity-backed HTTP methods.
2836
"""
2937

3038
def __init__(
@@ -34,3 +42,6 @@ def __init__(
3442
) -> None:
3543
super().__init__(api_client)
3644
self._files = TemporalAsyncFiles(api_client, activity_config)
45+
self._file_search_stores = TemporalAsyncFileSearchStores(
46+
api_client, activity_config
47+
)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""Temporal-aware AsyncFileSearchStores shim.
2+
3+
``TemporalAsyncFileSearchStores`` is an ``AsyncFileSearchStores`` subclass
4+
whose ``upload_to_file_search_store`` method dispatches through a Temporal
5+
activity so the entire upload (including filesystem access and resumable
6+
upload negotiation) runs on the activity worker.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
import io
12+
from datetime import timedelta
13+
from typing import Optional, Union
14+
15+
from google.genai import types
16+
from google.genai.file_search_stores import AsyncFileSearchStores
17+
from google.genai.types import HttpOptions
18+
19+
from temporalio import workflow as temporal_workflow
20+
from temporalio.workflow import ActivityConfig
21+
22+
from temporalio.contrib.google_gemini_sdk._temporal_api_client import (
23+
TemporalApiClient,
24+
_GeminiUploadToFileSearchStoreRequest,
25+
_validate_http_options,
26+
)
27+
28+
29+
class TemporalAsyncFileSearchStores(AsyncFileSearchStores):
30+
"""``AsyncFileSearchStores`` subclass that routes ``upload_to_file_search_store`` through an activity.
31+
32+
The entire upload operation — including filesystem access, resumable
33+
upload negotiation, and chunked transfer — runs inside a Temporal
34+
activity on the worker. All other methods (``create``, ``get``,
35+
``delete``, ``list``, ``import_file``, ``documents``) are inherited
36+
and already work through the ``TemporalApiClient``'s ``async_request``
37+
activity.
38+
"""
39+
40+
def __init__(
41+
self,
42+
api_client: TemporalApiClient,
43+
activity_config: ActivityConfig | None = None,
44+
) -> None:
45+
super().__init__(api_client)
46+
self._activity_config = activity_config or ActivityConfig(
47+
start_to_close_timeout=timedelta(seconds=60),
48+
)
49+
50+
async def upload_to_file_search_store(
51+
self,
52+
*,
53+
file_search_store_name: str,
54+
file: Union[str, "os.PathLike[str]", io.IOBase],
55+
config: Optional[types.UploadToFileSearchStoreConfigOrDict] = None,
56+
) -> types.UploadToFileSearchStoreOperation:
57+
"""Upload a file to a file search store via a Temporal activity.
58+
59+
Accepts a file path (resolved on the worker), ``os.PathLike``, or
60+
an ``io.IOBase`` (bytes sent across the activity boundary).
61+
"""
62+
act_config: ActivityConfig = {**self._activity_config}
63+
if "summary" not in act_config:
64+
act_config["summary"] = "file_search_stores.upload"
65+
66+
upload_config = None
67+
if config is not None:
68+
if isinstance(config, dict):
69+
upload_config = types.UploadToFileSearchStoreConfig.model_validate(
70+
config
71+
)
72+
else:
73+
upload_config = config
74+
_validate_http_options(upload_config.http_options)
75+
76+
if isinstance(file, io.IOBase):
77+
req = _GeminiUploadToFileSearchStoreRequest(
78+
file_search_store_name=file_search_store_name,
79+
file_bytes=file.read(),
80+
config=upload_config,
81+
)
82+
elif isinstance(file, str):
83+
req = _GeminiUploadToFileSearchStoreRequest(
84+
file_search_store_name=file_search_store_name,
85+
file_path=file,
86+
config=upload_config,
87+
)
88+
else:
89+
req = _GeminiUploadToFileSearchStoreRequest(
90+
file_search_store_name=file_search_store_name,
91+
file_path=file.__fspath__(),
92+
config=upload_config,
93+
)
94+
95+
return await temporal_workflow.execute_activity(
96+
"gemini_file_search_stores_upload",
97+
req,
98+
result_type=types.UploadToFileSearchStoreOperation,
99+
**act_config,
100+
)

temporalio/contrib/google_gemini_sdk/_temporal_files.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@
1919
from temporalio import workflow as temporal_workflow
2020
from temporalio.workflow import ActivityConfig
2121

22+
import google.auth.credentials
23+
2224
from temporalio.contrib.google_gemini_sdk._temporal_api_client import (
2325
TemporalApiClient,
2426
_GeminiDownloadFileRequest,
27+
_GeminiRegisterFilesRequest,
2528
_GeminiUploadFileRequest,
2629
_validate_http_options,
2730
)
@@ -123,3 +126,38 @@ async def download(
123126
result_type=bytes,
124127
**act_config,
125128
)
129+
130+
async def register_files(
131+
self,
132+
*,
133+
auth: google.auth.credentials.Credentials,
134+
uris: list[str],
135+
config: Optional[types.RegisterFilesConfigOrDict] = None,
136+
) -> types.RegisterFilesResponse:
137+
"""Register GCS files via a Temporal activity.
138+
139+
.. note::
140+
The ``auth`` parameter is **ignored**. The activity uses
141+
``credentials`` if provided to ``GeminiPlugin``,
142+
otherwise falls back to the ``genai.Client``'s own credentials.
143+
Either way, those credentials must have access to the GCS URIs
144+
being registered.
145+
"""
146+
act_config: ActivityConfig = {**self._activity_config}
147+
if "summary" not in act_config:
148+
act_config["summary"] = "files.register_files"
149+
150+
register_config = None
151+
if config is not None:
152+
if isinstance(config, dict):
153+
register_config = types.RegisterFilesConfig.model_validate(config)
154+
else:
155+
register_config = config
156+
_validate_http_options(register_config.http_options)
157+
158+
return await temporal_workflow.execute_activity(
159+
"gemini_files_register",
160+
_GeminiRegisterFilesRequest(uris=uris, config=register_config),
161+
result_type=types.RegisterFilesResponse,
162+
**act_config,
163+
)

0 commit comments

Comments
 (0)