Skip to content

Commit 36d5d2c

Browse files
authored
Add Structured Output (#1443)
1 parent 89e816b commit 36d5d2c

File tree

51 files changed

+1034
-202
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1034
-202
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Added
1111

1212
- `BaseVectorStoreDriver.query_vector` for querying vector stores with vectors.
13+
- Structured Output support for all Prompt Drivers.
14+
- `PromptTask.output_schema` for setting an output schema to be used with Structured Output.
15+
- `Agent.output_schema` for setting an output schema to be used on the Agent's Prompt Task.
16+
- `BasePromptDriver.structured_output_strategy` for changing the Structured Output strategy between `native`, `tool`, and `rule`.
1317

1418
## [1.1.1] - 2025-01-03
1519

docs/griptape-framework/drivers/prompt-drivers.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,29 @@ You can pass images to the Driver if the model supports it:
2525
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_images.py"
2626
```
2727

28+
## Structured Output
29+
30+
Some LLMs provide functionality often referred to as "Structured Output".
31+
This means instructing the LLM to output data in a particular format, usually JSON.
32+
This can be useful for forcing the LLM to output in a parsable format that can be used by downstream systems.
33+
34+
!!! warning
35+
Each Driver may have a different default setting depending on the LLM provider's capabilities.
36+
37+
### Prompt Task
38+
39+
The easiest way to get started with structured output is by using a [PromptTask](../structures/tasks.md#prompt)'s [output_schema](../../reference/griptape/tasks/prompt_task.md#griptape.tasks.PromptTask.output_schema) parameter.
40+
41+
You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of:
42+
43+
- `native`: The Driver will use the LLM's structured output functionality provided by the API.
44+
- `tool`: The Driver will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and will try to force the LLM to use the Tool.
45+
- `rule`: The Driver will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort.
46+
47+
```python
48+
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
49+
```
50+
2851
## Prompt Drivers
2952

3053
Griptape offers the following Prompt Drivers for interacting with LLMs.

docs/griptape-framework/drivers/src/prompt_drivers_3.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import os
22

3-
import schema
4-
53
from griptape.drivers import OpenAiChatPromptDriver
64
from griptape.structures import Agent
75

@@ -11,14 +9,6 @@
119
model="gpt-4o-2024-08-06",
1210
temperature=0.1,
1311
seed=42,
14-
response_format={
15-
"type": "json_schema",
16-
"json_schema": {
17-
"strict": True,
18-
"name": "Output",
19-
"schema": schema.Schema({"css_code": str, "relevant_emojies": [str]}).json_schema("Output Schema"),
20-
},
21-
},
2212
),
2313
input="You will be provided with a description of a mood, and your task is to generate the CSS color code for a color that matches it. Description: {{ args[0] }}",
2414
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import schema
2+
from rich.pretty import pprint
3+
4+
from griptape.drivers import OpenAiChatPromptDriver
5+
from griptape.rules import Rule
6+
from griptape.structures import Pipeline
7+
from griptape.tasks import PromptTask
8+
9+
pipeline = Pipeline(
10+
tasks=[
11+
PromptTask(
12+
prompt_driver=OpenAiChatPromptDriver(
13+
model="gpt-4o",
14+
structured_output_strategy="native", # optional
15+
),
16+
output_schema=schema.Schema(
17+
{
18+
"steps": [schema.Schema({"explanation": str, "output": str})],
19+
"final_answer": str,
20+
}
21+
),
22+
rules=[
23+
Rule("You are a helpful math tutor. Guide the user through the solution step by step."),
24+
],
25+
)
26+
]
27+
)
28+
29+
output = pipeline.run("How can I solve 8x + 7 = -23").output.value
30+
31+
32+
pprint(output)

docs/griptape-framework/structures/rulesets.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru
2626

2727
### Json Schema
2828

29+
!!! tip
30+
[Structured Output](../drivers/prompt-drivers.md#structured-output) provides a more robust solution for having the LLM generate structured output.
31+
2932
[JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s defines a structured format for the LLM's output by providing a JSON schema.
3033
This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types.
3134

griptape/common/prompt_stack/prompt_stack.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Optional
44

55
from attrs import define, field
66

@@ -24,13 +24,16 @@
2424
from griptape.mixins.serializable_mixin import SerializableMixin
2525

2626
if TYPE_CHECKING:
27+
from schema import Schema
28+
2729
from griptape.tools import BaseTool
2830

2931

3032
@define
3133
class PromptStack(SerializableMixin):
3234
messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True})
3335
tools: list[BaseTool] = field(factory=list, kw_only=True)
36+
output_schema: Optional[Schema] = field(default=None, kw_only=True)
3437

3538
@property
3639
def system_messages(self) -> list[Message]:

griptape/drivers/prompt/amazon_bedrock_prompt_driver.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
from typing import TYPE_CHECKING, Any
55

6-
from attrs import Factory, define, field
6+
from attrs import Attribute, Factory, define, field
77
from schema import Schema
88

99
from griptape.artifacts import (
@@ -41,6 +41,7 @@
4141
import boto3
4242

4343
from griptape.common import PromptStack
44+
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
4445
from griptape.tools import BaseTool
4546

4647
logger = logging.getLogger(Defaults.logging_config.logger_name)
@@ -55,9 +56,19 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
5556
kw_only=True,
5657
)
5758
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
59+
structured_output_strategy: StructuredOutputStrategy = field(
60+
default="tool", kw_only=True, metadata={"serializable": True}
61+
)
5862
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
5963
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
6064

65+
@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
66+
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
67+
if value == "native":
68+
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
69+
70+
return value
71+
6172
@lazy_property()
6273
def client(self) -> Any:
6374
return self.session.client("bedrock-runtime")
@@ -103,10 +114,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
103114

104115
def _base_params(self, prompt_stack: PromptStack) -> dict:
105116
system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]
106-
107117
messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])
108118

109-
return {
119+
params = {
110120
"modelId": self.model,
111121
"messages": messages,
112122
"system": system_messages,
@@ -115,14 +125,22 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
115125
**({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
116126
},
117127
"additionalModelRequestFields": self.additional_model_request_fields,
118-
**(
119-
{"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
120-
if prompt_stack.tools and self.use_native_tools
121-
else {}
122-
),
123128
**self.extra_params,
124129
}
125130

131+
if prompt_stack.tools and self.use_native_tools:
132+
params["toolConfig"] = {
133+
"tools": [],
134+
"toolChoice": self.tool_choice,
135+
}
136+
137+
if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
138+
params["toolConfig"]["toolChoice"] = {"any": {}}
139+
140+
params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)
141+
142+
return params
143+
126144
def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
127145
return [
128146
{

griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import boto3
2121

2222
from griptape.common import PromptStack
23+
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
2324

2425
logger = logging.getLogger(Defaults.logging_config.logger_name)
2526

@@ -39,8 +40,18 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
3940
),
4041
kw_only=True,
4142
)
43+
structured_output_strategy: StructuredOutputStrategy = field(
44+
default="rule", kw_only=True, metadata={"serializable": True}
45+
)
4246
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
4347

48+
@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
49+
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
50+
if value != "rule":
51+
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
52+
53+
return value
54+
4455
@lazy_property()
4556
def client(self) -> Any:
4657
return self.session.client("sagemaker-runtime")

griptape/drivers/prompt/anthropic_prompt_driver.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
from typing import TYPE_CHECKING, Optional
55

6-
from attrs import Factory, define, field
6+
from attrs import Attribute, Factory, define, field
77
from schema import Schema
88

99
from griptape.artifacts import (
@@ -42,6 +42,7 @@
4242
from anthropic import Client
4343
from anthropic.types import ContentBlock, ContentBlockDeltaEvent, ContentBlockStartEvent
4444

45+
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
4546
from griptape.tools.base_tool import BaseTool
4647

4748

@@ -68,13 +69,23 @@ class AnthropicPromptDriver(BasePromptDriver):
6869
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
6970
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
7071
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
72+
structured_output_strategy: StructuredOutputStrategy = field(
73+
default="tool", kw_only=True, metadata={"serializable": True}
74+
)
7175
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
7276
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
7377

7478
@lazy_property()
7579
def client(self) -> Client:
7680
return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)
7781

82+
@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
83+
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
84+
if value == "native":
85+
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
86+
87+
return value
88+
7889
@observable
7990
def try_run(self, prompt_stack: PromptStack) -> Message:
8091
params = self._base_params(prompt_stack)
@@ -110,23 +121,28 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
110121
system_messages = prompt_stack.system_messages
111122
system_message = system_messages[0].to_text() if system_messages else None
112123

113-
return {
124+
params = {
114125
"model": self.model,
115126
"temperature": self.temperature,
116127
"stop_sequences": self.tokenizer.stop_sequences,
117128
"top_p": self.top_p,
118129
"top_k": self.top_k,
119130
"max_tokens": self.max_tokens,
120131
"messages": messages,
121-
**(
122-
{"tools": self.__to_anthropic_tools(prompt_stack.tools), "tool_choice": self.tool_choice}
123-
if prompt_stack.tools and self.use_native_tools
124-
else {}
125-
),
126132
**({"system": system_message} if system_message else {}),
127133
**self.extra_params,
128134
}
129135

136+
if prompt_stack.tools and self.use_native_tools:
137+
params["tool_choice"] = self.tool_choice
138+
139+
if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
140+
params["tool_choice"] = {"type": "any"}
141+
142+
params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)
143+
144+
return params
145+
130146
def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
131147
return [
132148
{"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}

griptape/drivers/prompt/base_prompt_driver.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING, Literal, Optional
55

66
from attrs import Factory, define, field
77

8-
from griptape.artifacts.base_artifact import BaseArtifact
8+
from griptape.artifacts import BaseArtifact, TextArtifact
99
from griptape.common import (
1010
ActionCallDeltaMessageContent,
1111
ActionCallMessageContent,
@@ -26,12 +26,15 @@
2626
)
2727
from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
2828
from griptape.mixins.serializable_mixin import SerializableMixin
29+
from griptape.rules.json_schema_rule import JsonSchemaRule
2930

3031
if TYPE_CHECKING:
3132
from collections.abc import Iterator
3233

3334
from griptape.tokenizers import BaseTokenizer
3435

36+
StructuredOutputStrategy = Literal["native", "tool", "rule"]
37+
3538

3639
@define(kw_only=True)
3740
class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
@@ -56,9 +59,13 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
5659
tokenizer: BaseTokenizer
5760
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
5861
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
62+
structured_output_strategy: StructuredOutputStrategy = field(
63+
default="rule", kw_only=True, metadata={"serializable": True}
64+
)
5965
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
6066

6167
def before_run(self, prompt_stack: PromptStack) -> None:
68+
self._init_structured_output(prompt_stack)
6269
EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))
6370

6471
def after_run(self, result: Message) -> None:
@@ -122,6 +129,34 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
122129
@abstractmethod
123130
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...
124131

132+
def _init_structured_output(self, prompt_stack: PromptStack) -> None:
133+
from griptape.tools import StructuredOutputTool
134+
135+
if (output_schema := prompt_stack.output_schema) is not None:
136+
if self.structured_output_strategy == "tool":
137+
structured_output_tool = StructuredOutputTool(output_schema=output_schema)
138+
if structured_output_tool not in prompt_stack.tools:
139+
prompt_stack.tools.append(structured_output_tool)
140+
elif self.structured_output_strategy == "rule":
141+
output_artifact = TextArtifact(JsonSchemaRule(output_schema.json_schema("Output Schema")).to_text())
142+
system_messages = prompt_stack.system_messages
143+
if system_messages:
144+
last_system_message = prompt_stack.system_messages[-1]
145+
last_system_message.content.extend(
146+
[
147+
TextMessageContent(TextArtifact("\n\n")),
148+
TextMessageContent(output_artifact),
149+
]
150+
)
151+
else:
152+
prompt_stack.messages.insert(
153+
0,
154+
Message(
155+
content=[TextMessageContent(output_artifact)],
156+
role=Message.SYSTEM_ROLE,
157+
),
158+
)
159+
125160
def __process_run(self, prompt_stack: PromptStack) -> Message:
126161
return self.try_run(prompt_stack)
127162

0 commit comments

Comments
 (0)