Skip to content

Commit 11403ea

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add Tool compatibility with RagRetrievalConfig in both Vertex AI SDK and GenAI SDK for use with generate_content.
PiperOrigin-RevId: 890232571
1 parent 8a0483a commit 11403ea

File tree

3 files changed

+231
-1
lines changed

3 files changed

+231
-1
lines changed

google/genai/models.py

Lines changed: 160 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4137,7 +4137,13 @@ def _Tool_to_vertex(
41374137
) -> dict[str, Any]:
41384138
to_object: dict[str, Any] = {}
41394139
if getv(from_object, ['retrieval']) is not None:
4140-
setv(to_object, ['retrieval'], getv(from_object, ['retrieval']))
4140+
setv(
4141+
to_object,
4142+
['retrieval'],
4143+
_Retrieval_to_vertex(
4144+
getv(from_object, ['retrieval']), to_object, root_object
4145+
),
4146+
)
41414147

41424148
if getv(from_object, ['computer_use']) is not None:
41434149
setv(to_object, ['computerUse'], getv(from_object, ['computer_use']))
@@ -4194,6 +4200,159 @@ def _Tool_to_vertex(
41944200
return to_object
41954201

41964202

4203+
def _Retrieval_to_vertex(
4204+
from_object: Union[dict[str, Any], object],
4205+
parent_object: Optional[dict[str, Any]] = None,
4206+
root_object: Optional[Union[dict[str, Any], object]] = None,
4207+
) -> dict[str, Any]:
4208+
to_object: dict[str, Any] = {}
4209+
if getv(from_object, ['disable_attribution']) is not None:
4210+
setv(
4211+
to_object, ['disableAttribution'], getv(from_object, ['disable_attribution'])
4212+
)
4213+
4214+
if getv(from_object, ['vertex_ai_search']) is not None:
4215+
setv(
4216+
to_object,
4217+
['vertexAiSearch'],
4218+
_VertexAISearch_to_vertex(
4219+
getv(from_object, ['vertex_ai_search']), to_object, root_object
4220+
),
4221+
)
4222+
4223+
if getv(from_object, ['vertex_rag_store']) is not None:
4224+
setv(
4225+
to_object,
4226+
['vertexRagStore'],
4227+
_VertexRagStore_to_vertex(
4228+
getv(from_object, ['vertex_rag_store']), to_object, root_object
4229+
),
4230+
)
4231+
4232+
return to_object
4233+
4234+
4235+
def _VertexAISearch_to_vertex(
4236+
from_object: Union[dict[str, Any], object],
4237+
parent_object: Optional[dict[str, Any]] = None,
4238+
root_object: Optional[Union[dict[str, Any], object]] = None,
4239+
) -> dict[str, Any]:
4240+
to_object: dict[str, Any] = {}
4241+
if getv(from_object, ['data_store_specs']) is not None:
4242+
setv(to_object, ['dataStoreSpecs'], getv(from_object, ['data_store_specs']))
4243+
if getv(from_object, ['datastore']) is not None:
4244+
setv(to_object, ['datastore'], getv(from_object, ['datastore']))
4245+
if getv(from_object, ['engine']) is not None:
4246+
setv(to_object, ['engine'], getv(from_object, ['engine']))
4247+
if getv(from_object, ['filter']) is not None:
4248+
setv(to_object, ['filter'], getv(from_object, ['filter']))
4249+
if getv(from_object, ['max_results']) is not None:
4250+
setv(to_object, ['maxResults'], getv(from_object, ['max_results']))
4251+
return to_object
4252+
4253+
4254+
def _VertexRagStore_to_vertex(
4255+
from_object: Union[dict[str, Any], object],
4256+
parent_object: Optional[dict[str, Any]] = None,
4257+
root_object: Optional[Union[dict[str, Any], object]] = None,
4258+
) -> dict[str, Any]:
4259+
to_object: dict[str, Any] = {}
4260+
if getv(from_object, ['rag_resources']) is not None:
4261+
setv(
4262+
to_object,
4263+
['ragResources'],
4264+
[
4265+
_VertexRagStoreRagResource_to_vertex(item, to_object, root_object)
4266+
for item in getv(from_object, ['rag_resources'])
4267+
],
4268+
)
4269+
4270+
if getv(from_object, ['rag_retrieval_config']) is not None:
4271+
setv(
4272+
to_object,
4273+
['ragRetrievalConfig'],
4274+
_RagRetrievalConfig_to_vertex(
4275+
getv(from_object, ['rag_retrieval_config']), to_object, root_object
4276+
),
4277+
)
4278+
4279+
if getv(from_object, ['similarity_top_k']) is not None:
4280+
setv(to_object, ['similarityTopK'], getv(from_object, ['similarity_top_k']))
4281+
4282+
if getv(from_object, ['vector_distance_threshold']) is not None:
4283+
setv(
4284+
to_object,
4285+
['vectorDistanceThreshold'],
4286+
getv(from_object, ['vector_distance_threshold']),
4287+
)
4288+
4289+
return to_object
4290+
4291+
4292+
def _VertexRagStoreRagResource_to_vertex(
4293+
from_object: Union[dict[str, Any], object],
4294+
parent_object: Optional[dict[str, Any]] = None,
4295+
root_object: Optional[Union[dict[str, Any], object]] = None,
4296+
) -> dict[str, Any]:
4297+
to_object: dict[str, Any] = {}
4298+
if getv(from_object, ['rag_corpus_name']) is not None:
4299+
setv(to_object, ['ragCorpus'], getv(from_object, ['rag_corpus_name']))
4300+
elif getv(from_object, ['rag_corpus']) is not None:
4301+
setv(to_object, ['ragCorpus'], getv(from_object, ['rag_corpus']))
4302+
4303+
if getv(from_object, ['rag_file_ids']) is not None:
4304+
setv(to_object, ['ragFileIds'], getv(from_object, ['rag_file_ids']))
4305+
4306+
return to_object
4307+
4308+
4309+
def _RagRetrievalConfig_to_vertex(
4310+
from_object: Union[dict[str, Any], object],
4311+
parent_object: Optional[dict[str, Any]] = None,
4312+
root_object: Optional[Union[dict[str, Any], object]] = None,
4313+
) -> dict[str, Any]:
4314+
to_object: dict[str, Any] = {}
4315+
if getv(from_object, ['top_k']) is not None:
4316+
setv(to_object, ['topK'], getv(from_object, ['top_k']))
4317+
4318+
if getv(from_object, ['filter']) is not None:
4319+
setv(
4320+
to_object,
4321+
['filter'],
4322+
_RagRetrievalConfigFilter_to_vertex(
4323+
getv(from_object, ['filter']), to_object, root_object
4324+
),
4325+
)
4326+
4327+
return to_object
4328+
4329+
4330+
def _RagRetrievalConfigFilter_to_vertex(
4331+
from_object: Union[dict[str, Any], object],
4332+
parent_object: Optional[dict[str, Any]] = None,
4333+
root_object: Optional[Union[dict[str, Any], object]] = None,
4334+
) -> dict[str, Any]:
4335+
to_object: dict[str, Any] = {}
4336+
if getv(from_object, ['vector_distance_threshold']) is not None:
4337+
setv(
4338+
to_object,
4339+
['vectorDistanceThreshold'],
4340+
getv(from_object, ['vector_distance_threshold']),
4341+
)
4342+
4343+
if getv(from_object, ['vector_similarity_threshold']) is not None:
4344+
setv(
4345+
to_object,
4346+
['vectorSimilarityThreshold'],
4347+
getv(from_object, ['vector_similarity_threshold']),
4348+
)
4349+
4350+
if getv(from_object, ['metadata_filter']) is not None:
4351+
setv(to_object, ['metadataFilter'], getv(from_object, ['metadata_filter']))
4352+
4353+
return to_object
4354+
4355+
41974356
def _TunedModelInfo_from_vertex(
41984357
from_object: Union[dict[str, Any], object],
41994358
parent_object: Optional[dict[str, Any]] = None,

google/genai/tests/models/test_generate_content_tools.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,34 @@ def divide_floats(a: float, b: float) -> float:
290290
),
291291
exception_if_mldev='retrieval',
292292
),
293+
pytest_helper.TestTableItem(
294+
name='test_rag_with_metadata_filter',
295+
parameters=types._GenerateContentParameters(
296+
model='gemini-2.5-flash',
297+
contents=t.t_contents('What is the capital of France?'),
298+
config={
299+
'tools': [
300+
types.Tool(
301+
retrieval=types.Retrieval(
302+
vertex_rag_store=types.VertexRagStore(
303+
rag_resources=[
304+
types.VertexRagStoreRagResource(
305+
rag_corpus='projects/test-project/locations/us-central1/ragCorpora/test-corpus'
306+
)
307+
],
308+
rag_retrieval_config=types.RagRetrievalConfig(
309+
filter=types.RagRetrievalConfigFilter(
310+
metadata_filter='color = "red"',
311+
),
312+
),
313+
)
314+
),
315+
),
316+
]
317+
},
318+
),
319+
exception_if_mldev='retrieval',
320+
),
293321
pytest_helper.TestTableItem(
294322
name='test_file_search',
295323
parameters=types._GenerateContentParameters(
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import sys
2+
import os
3+
4+
# Set up path so we can import and run scripts as if they were in the workspace.
5+
sys.path.append(os.getcwd())
6+
7+
from third_party.py.google.genai import types
8+
from third_party.py.google.genai import models
9+
from third_party.py.google.genai import _common
10+
import json
11+
12+
class MockApiClient:
13+
def __init__(self):
14+
self.vertexai = True
15+
self.location = 'us-central1'
16+
self.project = 'test-project'
17+
18+
def test_transformation():
19+
tool = types.Tool(
20+
retrieval=types.Retrieval(
21+
vertex_rag_store=types.VertexRagStore(
22+
rag_resources=[
23+
types.VertexRagStoreRagResource(
24+
rag_corpus='projects/test-project/locations/us-central1/ragCorpora/test-corpus'
25+
)
26+
],
27+
rag_retrieval_config=types.RagRetrievalConfig(
28+
filter=types.RagRetrievalConfigFilter(
29+
metadata_filter='color = "red"',
30+
),
31+
),
32+
)
33+
),
34+
)
35+
36+
api_client = MockApiClient()
37+
# Mocking internal transform
38+
transformed = models._Tool_to_vertex(api_client, tool)
39+
final_dict = _common.convert_to_dict(transformed)
40+
print(json.dumps(final_dict, indent=2))
41+
42+
if __name__ == '__main__':
43+
test_transformation()

0 commit comments

Comments
 (0)