Skip to content

Commit d07d699

Browse files
committed
Add Structured Output functionality
1 parent af1e5eb commit d07d699

39 files changed

+1127
-170
lines changed

griptape/common/prompt_stack/prompt_stack.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Optional
44

55
from attrs import define, field
66

@@ -24,13 +24,16 @@
2424
from griptape.mixins.serializable_mixin import SerializableMixin
2525

2626
if TYPE_CHECKING:
27+
from schema import Schema
28+
2729
from griptape.tools import BaseTool
2830

2931

3032
@define
3133
class PromptStack(SerializableMixin):
3234
messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True})
3335
tools: list[BaseTool] = field(factory=list, kw_only=True)
36+
output_schema: Optional[Schema] = field(default=None, kw_only=True)
3437

3538
@property
3639
def system_messages(self) -> list[Message]:

griptape/drivers/prompt/amazon_bedrock_prompt_driver.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, Literal
55

6-
from attrs import Factory, define, field
6+
from attrs import Attribute, Factory, define, field
77
from schema import Schema
88

99
from griptape.artifacts import (
@@ -55,9 +55,20 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
5555
kw_only=True,
5656
)
5757
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
58+
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
59+
native_structured_output_mode: Literal["native", "tool"] = field(
60+
default="tool", kw_only=True, metadata={"serializable": True}
61+
)
5862
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
5963
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
6064

65+
@native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
66+
def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str:
67+
if value == "native":
68+
raise ValueError("AmazonBedrockPromptDriver does not support `native` structured output mode.")
69+
70+
return value
71+
6172
@lazy_property()
6273
def client(self) -> Any:
6374
return self.session.client("bedrock-runtime")
@@ -103,10 +114,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
103114

104115
def _base_params(self, prompt_stack: PromptStack) -> dict:
105116
system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]
106-
107117
messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])
108118

109-
return {
119+
params = {
110120
"modelId": self.model,
111121
"messages": messages,
112122
"system": system_messages,
@@ -115,14 +125,27 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
115125
**({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
116126
},
117127
"additionalModelRequestFields": self.additional_model_request_fields,
118-
**(
119-
{"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
120-
if prompt_stack.tools and self.use_native_tools
121-
else {}
122-
),
123128
**self.extra_params,
124129
}
125130

131+
if prompt_stack.tools and self.use_native_tools:
132+
params["toolConfig"] = {
133+
"tools": [],
134+
"toolChoice": self.tool_choice,
135+
}
136+
137+
if (
138+
prompt_stack.output_schema is not None
139+
and self.use_native_structured_output
140+
and self.native_structured_output_mode == "tool"
141+
):
142+
self._add_structured_output_tool(prompt_stack)
143+
params["toolConfig"]["toolChoice"] = {"any": {}}
144+
145+
params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)
146+
147+
return params
148+
126149
def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
127150
return [
128151
{

griptape/drivers/prompt/anthropic_prompt_driver.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING, Literal, Optional
55

6-
from attrs import Factory, define, field
6+
from attrs import Attribute, Factory, define, field
77
from schema import Schema
88

99
from griptape.artifacts import (
@@ -68,13 +68,24 @@ class AnthropicPromptDriver(BasePromptDriver):
6868
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
6969
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
7070
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
71+
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
72+
native_structured_output_mode: Literal["native", "tool"] = field(
73+
default="tool", kw_only=True, metadata={"serializable": True}
74+
)
7175
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
7276
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
7377

7478
@lazy_property()
7579
def client(self) -> Client:
7680
return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)
7781

82+
@native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
83+
def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str:
84+
if value == "native":
85+
raise ValueError("AnthropicPromptDriver does not support `native` structured output mode.")
86+
87+
return value
88+
7889
@observable
7990
def try_run(self, prompt_stack: PromptStack) -> Message:
8091
params = self._base_params(prompt_stack)
@@ -110,23 +121,33 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
110121
system_messages = prompt_stack.system_messages
111122
system_message = system_messages[0].to_text() if system_messages else None
112123

113-
return {
124+
params = {
114125
"model": self.model,
115126
"temperature": self.temperature,
116127
"stop_sequences": self.tokenizer.stop_sequences,
117128
"top_p": self.top_p,
118129
"top_k": self.top_k,
119130
"max_tokens": self.max_tokens,
120131
"messages": messages,
121-
**(
122-
{"tools": self.__to_anthropic_tools(prompt_stack.tools), "tool_choice": self.tool_choice}
123-
if prompt_stack.tools and self.use_native_tools
124-
else {}
125-
),
126132
**({"system": system_message} if system_message else {}),
127133
**self.extra_params,
128134
}
129135

136+
if prompt_stack.tools and self.use_native_tools:
137+
params["tool_choice"] = self.tool_choice
138+
139+
if (
140+
prompt_stack.output_schema is not None
141+
and self.use_native_structured_output
142+
and self.native_structured_output_mode == "tool"
143+
):
144+
self._add_structured_output_tool(prompt_stack)
145+
params["tool_choice"] = {"type": "any"}
146+
147+
params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)
148+
149+
return params
150+
130151
def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
131152
return [
132153
{"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}

griptape/drivers/prompt/base_prompt_driver.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING, Literal, Optional
55

66
from attrs import Factory, define, field
77

@@ -56,6 +56,10 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
5656
tokenizer: BaseTokenizer
5757
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
5858
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
59+
use_native_structured_output: bool = field(default=False, kw_only=True, metadata={"serializable": True})
60+
native_structured_output_mode: Literal["native", "tool"] = field(
61+
default="native", kw_only=True, metadata={"serializable": True}
62+
)
5963
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
6064

6165
def before_run(self, prompt_stack: PromptStack) -> None:
@@ -122,6 +126,16 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
122126
@abstractmethod
123127
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...
124128

129+
def _add_structured_output_tool(self, prompt_stack: PromptStack) -> None:
130+
from griptape.tools.structured_output.tool import StructuredOutputTool
131+
132+
if prompt_stack.output_schema is None:
133+
raise ValueError("PromptStack must have an output schema to use structured output.")
134+
135+
structured_output_tool = StructuredOutputTool(output_schema=prompt_stack.output_schema)
136+
if structured_output_tool not in prompt_stack.tools:
137+
prompt_stack.tools.append(structured_output_tool)
138+
125139
def __process_run(self, prompt_stack: PromptStack) -> Message:
126140
return self.try_run(prompt_stack)
127141

griptape/drivers/prompt/cohere_prompt_driver.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class CoherePromptDriver(BasePromptDriver):
5353
model: str = field(metadata={"serializable": True})
5454
force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
5555
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
56+
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
5657
_client: ClientV2 = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
5758
tokenizer: BaseTokenizer = field(
5859
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
@@ -101,21 +102,31 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
101102

102103
messages = self.__to_cohere_messages(prompt_stack.messages)
103104

104-
return {
105+
params = {
105106
"model": self.model,
106107
"messages": messages,
107108
"temperature": self.temperature,
108109
"stop_sequences": self.tokenizer.stop_sequences,
109110
"max_tokens": self.max_tokens,
110111
**({"tool_results": tool_results} if tool_results else {}),
111-
**(
112-
{"tools": self.__to_cohere_tools(prompt_stack.tools)}
113-
if prompt_stack.tools and self.use_native_tools
114-
else {}
115-
),
116112
**self.extra_params,
117113
}
118114

115+
if prompt_stack.output_schema is not None and self.use_native_structured_output:
116+
if self.native_structured_output_mode == "native":
117+
params["response_format"] = {
118+
"type": "json_object",
119+
"schema": prompt_stack.output_schema.json_schema("Output"),
120+
}
121+
elif self.native_structured_output_mode == "tool":
122+
# TODO: Implement tool choice once supported
123+
self._add_structured_output_tool(prompt_stack)
124+
125+
if prompt_stack.tools and self.use_native_tools:
126+
params["tools"] = self.__to_cohere_tools(prompt_stack.tools)
127+
128+
return params
129+
119130
def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
120131
cohere_messages = []
121132

griptape/drivers/prompt/google_prompt_driver.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import json
44
import logging
5-
from typing import TYPE_CHECKING, Optional
5+
from typing import TYPE_CHECKING, Literal, Optional
66

7-
from attrs import Factory, define, field
7+
from attrs import Attribute, Factory, define, field
88
from schema import Schema
99

1010
from griptape.artifacts import ActionArtifact, TextArtifact
@@ -63,9 +63,20 @@ class GooglePromptDriver(BasePromptDriver):
6363
top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
6464
top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
6565
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
66+
use_native_structured_output: bool = field(default=True, kw_only=True, metadata={"serializable": True})
67+
native_structured_output_mode: Literal["native", "tool"] = field(
68+
default="tool", kw_only=True, metadata={"serializable": True}
69+
)
6670
tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
6771
_client: GenerativeModel = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
6872

73+
@native_structured_output_mode.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
74+
def validate_native_structured_output_mode(self, attribute: Attribute, value: str) -> str:
75+
if value == "native":
76+
raise ValueError("GooglePromptDriver does not support `native` structured output mode.")
77+
78+
return value
79+
6980
@lazy_property()
7081
def client(self) -> GenerativeModel:
7182
genai = import_optional_dependency("google.generativeai")
@@ -135,7 +146,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
135146
parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
136147
)
137148

138-
return {
149+
params = {
139150
"generation_config": types.GenerationConfig(
140151
**{
141152
# For some reason, providing stop sequences when streaming breaks native functions
@@ -148,16 +159,23 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
148159
**self.extra_params,
149160
},
150161
),
151-
**(
152-
{
153-
"tools": self.__to_google_tools(prompt_stack.tools),
154-
"tool_config": {"function_calling_config": {"mode": self.tool_choice}},
155-
}
156-
if prompt_stack.tools and self.use_native_tools
157-
else {}
158-
),
159162
}
160163

164+
if prompt_stack.tools and self.use_native_tools:
165+
params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}
166+
167+
if (
168+
prompt_stack.output_schema is not None
169+
and self.use_native_structured_output
170+
and self.native_structured_output_mode == "tool"
171+
):
172+
params["tool_config"]["function_calling_config"]["mode"] = "auto"
173+
self._add_structured_output_tool(prompt_stack)
174+
175+
params["tools"] = self.__to_google_tools(prompt_stack.tools)
176+
177+
return params
178+
161179
def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
162180
types = import_optional_dependency("google.generativeai.types")
163181

0 commit comments
Comments (0)