Skip to content

Commit c582618

Browse files
committed
Add Parallel Tool mode for Vertex AI
1 parent 58eef74 commit c582618

File tree

4 files changed

+101
-15
lines changed

4 files changed

+101
-15
lines changed

instructor/client_vertexai.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import Any, Type, Union, get_origin
44

55
from vertexai.preview.generative_models import ToolConfig # type: ignore
66
import vertexai.generative_models as gm # type: ignore
77
from pydantic import BaseModel
88
import instructor
9+
from instructor.dsl.parallel import get_types_array
910
import jsonref
1011

1112

1213
def _create_gemini_json_schema(model: BaseModel):
14+
# Add type check to ensure we have a concrete model class
15+
if get_origin(model) is not None:
16+
raise TypeError(f"Expected concrete model class, got type hint {model}")
17+
1318
schema = model.model_json_schema()
1419
schema_without_refs: dict[str, Any] = jsonref.replace_refs(schema) # type: ignore
1520
gemini_schema: dict[Any, Any] = {
@@ -22,16 +27,28 @@ def _create_gemini_json_schema(model: BaseModel):
2227
return gemini_schema
2328

2429

25-
def _create_vertexai_tool(model: BaseModel) -> gm.Tool:
26-
parameters = _create_gemini_json_schema(model)
27-
28-
declaration = gm.FunctionDeclaration(
29-
name=model.__name__, description=model.__doc__, parameters=parameters
30-
)
31-
32-
tool = gm.Tool(function_declarations=[declaration])
30+
def _create_vertexai_tool(models: Union[BaseModel, list[BaseModel], Type]) -> gm.Tool:
31+
"""Creates a tool with function declarations for single model or list of models"""
32+
# Handle Iterable case first
33+
if get_origin(models) is not None:
34+
model_list = list(get_types_array(models))
35+
else:
36+
# Handle both single model and list of models
37+
model_list = models if isinstance(models, list) else [models]
38+
39+
print(f"Debug - Model list: {[model.__name__ for model in model_list]}")
40+
41+
declarations = []
42+
for model in model_list:
43+
parameters = _create_gemini_json_schema(model)
44+
declaration = gm.FunctionDeclaration(
45+
name=model.__name__,
46+
description=model.__doc__,
47+
parameters=parameters
48+
)
49+
declarations.append(declaration)
3350

34-
return tool
51+
return gm.Tool(function_declarations=declarations)
3552

3653

3754
def vertexai_message_parser(
@@ -84,11 +101,11 @@ def vertexai_function_response_parser(
84101
)
85102

86103

87-
def vertexai_process_response(_kwargs: dict[str, Any], model: BaseModel):
104+
def vertexai_process_response(_kwargs: dict[str, Any], model: Union[BaseModel, list[BaseModel], Type]):
88105
messages: list[dict[str, str]] = _kwargs.pop("messages")
89106
contents = _vertexai_message_list_parser(messages) # type: ignore
90107

91-
tool = _create_vertexai_tool(model=model)
108+
tool = _create_vertexai_tool(models=model)
92109

93110
tool_config = ToolConfig(
94111
function_calling_config=ToolConfig.FunctionCallingConfig(
@@ -122,6 +139,7 @@ def from_vertexai(
122139
**kwargs: Any,
123140
) -> instructor.Instructor:
124141
assert mode in {
142+
instructor.Mode.VERTEXAI_PARALLEL_TOOLS,
125143
instructor.Mode.VERTEXAI_TOOLS,
126144
instructor.Mode.VERTEXAI_JSON,
127145
}, "Mode must be instructor.Mode.VERTEXAI_TOOLS"

instructor/dsl/parallel.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
import json
23
from typing import (
34
Any,
45
Optional,
@@ -45,6 +46,38 @@ def from_response(
4546
)
4647

4748

49+
class VertexAIParallelBase(ParallelBase):
50+
def from_response(
51+
self,
52+
response: Any,
53+
mode: Mode,
54+
validation_context: Optional[Any] = None,
55+
strict: Optional[bool] = None,
56+
) -> Generator[BaseModel, None, None]:
57+
assert mode == Mode.VERTEXAI_PARALLEL_TOOLS, "Mode must be VERTEXAI_PARALLEL_TOOLS"
58+
59+
if not response or not response.candidates:
60+
return
61+
62+
for candidate in response.candidates:
63+
if not candidate.content or not candidate.content.parts:
64+
continue
65+
66+
for part in candidate.content.parts:
67+
if (hasattr(part, 'function_call') and
68+
part.function_call is not None):
69+
70+
name = part.function_call.name
71+
arguments = part.function_call.args
72+
73+
if name in self.registry:
74+
# Convert dict to JSON string before validation
75+
json_str = json.dumps(arguments)
76+
yield self.registry[name].model_validate_json(
77+
json_str, context=validation_context, strict=strict
78+
)
79+
80+
4881
if sys.version_info >= (3, 10):
4982
from types import UnionType
5083

@@ -82,3 +115,7 @@ def handle_parallel_model(typehint: type[Iterable[T]]) -> list[dict[str, Any]]:
82115
def ParallelModel(typehint: type[Iterable[T]]) -> ParallelBase:
83116
the_types = get_types_array(typehint)
84117
return ParallelBase(*[model for model in the_types])
118+
119+
def VertexAIParallelModel(typehint: type[Iterable[T]]) -> VertexAIParallelBase:
120+
the_types = get_types_array(typehint)
121+
return VertexAIParallelBase(*[model for model in the_types])

instructor/mode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class Mode(enum.Enum):
1818
COHERE_TOOLS = "cohere_tools"
1919
VERTEXAI_TOOLS = "vertexai_tools"
2020
VERTEXAI_JSON = "vertexai_json"
21+
VERTEXAI_PARALLEL_TOOLS = "vertexai_parallel_tools"
2122
GEMINI_JSON = "gemini_json"
2223
GEMINI_TOOLS = "gemini_tools"
2324
COHERE_JSON_SCHEMA = "json_object"

instructor/process_response.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616

1717
from instructor.mode import Mode
1818
from instructor.dsl.iterable import IterableBase, IterableModel
19-
from instructor.dsl.parallel import ParallelBase, ParallelModel, handle_parallel_model
19+
from instructor.dsl.parallel import (
20+
ParallelBase,
21+
ParallelModel,
22+
handle_parallel_model,
23+
get_types_array,
24+
VertexAIParallelBase,
25+
VertexAIParallelModel
26+
)
2027
from instructor.dsl.partial import PartialBase
2128
from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
2229
from instructor.function_calls import OpenAISchema, openai_schema
@@ -112,7 +119,7 @@ def process_response(
112119
validation_context: dict[str, Any] | None = None,
113120
strict=None,
114121
mode: Mode = Mode.TOOLS,
115-
):
122+
) -> T_Model | list[T_Model] | VertexAIParallelBase | None:
116123
"""
117124
Process the response from the API call and convert it to the specified response model.
118125
@@ -485,6 +492,27 @@ def handle_gemini_tools(
485492
return response_model, new_kwargs
486493

487494

495+
def handle_vertexai_parallel_tools(
496+
response_model: type[Iterable[T]], new_kwargs: dict[str, Any]
497+
) -> tuple[VertexAIParallelBase, dict[str, Any]]:
498+
assert (
499+
new_kwargs.get("stream", False) is False
500+
), "stream=True is not supported when using PARALLEL_TOOLS mode"
501+
502+
from instructor.client_vertexai import vertexai_process_response
503+
from instructor.dsl.parallel import VertexAIParallelModel
504+
505+
# Extract concrete types before passing to vertexai_process_response
506+
model_types = list(get_types_array(response_model))
507+
contents, tools, tool_config = vertexai_process_response(new_kwargs, model_types)
508+
509+
new_kwargs["contents"] = contents
510+
new_kwargs["tools"] = tools
511+
new_kwargs["tool_config"] = tool_config
512+
513+
return VertexAIParallelModel(typehint=response_model), new_kwargs
514+
515+
488516
def handle_vertexai_tools(
489517
response_model: type[T], new_kwargs: dict[str, Any]
490518
) -> tuple[type[T], dict[str, Any]]:
@@ -646,7 +674,7 @@ def prepare_response_model(response_model: type[T] | None) -> type[T] | None:
646674

647675
def handle_response_model(
648676
response_model: type[T] | None, mode: Mode = Mode.TOOLS, **kwargs: Any
649-
) -> tuple[type[T] | None, dict[str, Any]]:
677+
) -> tuple[type[T] | VertexAIParallelBase | None, dict[str, Any]]:
650678
"""
651679
Handles the response model based on the specified mode and prepares the kwargs for the API call.
652680
@@ -690,6 +718,8 @@ def handle_response_model(
690718

691719
if mode in {Mode.PARALLEL_TOOLS}:
692720
return handle_parallel_tools(response_model, new_kwargs)
721+
elif mode in {Mode.VERTEXAI_PARALLEL_TOOLS}:
722+
return handle_vertexai_parallel_tools(response_model, new_kwargs)
693723

694724
response_model = prepare_response_model(response_model)
695725

0 commit comments

Comments
 (0)