Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- `BaseVectorStoreDriver.query_vector` for querying vector stores with vectors.
- Structured Output support for all Prompt Drivers.
- `PromptTask.output_schema` for setting an output schema to be used with Structured Output.
- `Agent.output_schema` for setting an output schema to be used on the Agent's Prompt Task.
- `BasePromptDriver.structured_output_strategy` for changing the Structured Output strategy between `native`, `tool`, and `rule`.

## [1.1.1] - 2025-01-03

Expand Down
23 changes: 23 additions & 0 deletions docs/griptape-framework/drivers/prompt-drivers.md
Expand Up @@ -25,6 +25,29 @@ You can pass images to the Driver if the model supports it:
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_images.py"
```

## Structured Output

Some LLMs provide functionality often referred to as "Structured Output".
This means instructing the LLM to output data in a particular format, usually JSON.
This is useful for forcing the LLM to produce parsable output that downstream systems can consume.

!!! warning
    Each Driver may have a different default `structured_output_strategy` depending on the LLM provider's capabilities.

### Prompt Task

The easiest way to get started with structured output is by using a [PromptTask](../structures/tasks.md#prompt)'s [output_schema](../../reference/griptape/tasks/prompt_task.md#griptape.tasks.PromptTask.output_schema) parameter.

You can change _how_ the output is structured by setting the Driver's [structured_output_strategy](../../reference/griptape/drivers/prompt/base_prompt_driver.md#griptape.drivers.prompt.base_prompt_driver.BasePromptDriver.structured_output_strategy) to one of:

- `native`: The Driver will use the LLM's structured output functionality provided by the API.
- `tool`: The Driver will add a special tool, [StructuredOutputTool](../../reference/griptape/tools/structured_output/tool.md), and will try to force the LLM to use the Tool.
- `rule`: The Driver will add a [JsonSchemaRule](../structures/rulesets.md#json-schema-rule) to the Task's system prompt. This strategy does not guarantee that the LLM will output JSON and should only be used as a last resort.

```python
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_structured_output.py"
```

## Prompt Drivers

Griptape offers the following Prompt Drivers for interacting with LLMs.
Expand Down
10 changes: 0 additions & 10 deletions docs/griptape-framework/drivers/src/prompt_drivers_3.py
@@ -1,7 +1,5 @@
import os

import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.structures import Agent

Expand All @@ -11,14 +9,6 @@
model="gpt-4o-2024-08-06",
temperature=0.1,
seed=42,
response_format={
"type": "json_schema",
"json_schema": {
"strict": True,
"name": "Output",
"schema": schema.Schema({"css_code": str, "relevant_emojies": [str]}).json_schema("Output Schema"),
},
},
),
input="You will be provided with a description of a mood, and your task is to generate the CSS color code for a color that matches it. Description: {{ args[0] }}",
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import schema
from rich.pretty import pprint

from griptape.drivers import OpenAiChatPromptDriver
from griptape.rules import Rule
from griptape.structures import Pipeline
from griptape.tasks import PromptTask

pipeline = Pipeline(
tasks=[
PromptTask(
prompt_driver=OpenAiChatPromptDriver(
model="gpt-4o",
structured_output_strategy="native", # optional
),
output_schema=schema.Schema(
{
"steps": [schema.Schema({"explanation": str, "output": str})],
"final_answer": str,
}
),
rules=[
Rule("You are a helpful math tutor. Guide the user through the solution step by step."),
],
)
]
)

output = pipeline.run("How can I solve 8x + 7 = -23?").output.value


pprint(output)
3 changes: 3 additions & 0 deletions docs/griptape-framework/structures/rulesets.md
Expand Up @@ -26,6 +26,9 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru

### Json Schema

!!! tip
[Structured Output](../drivers/prompt-drivers.md#structured-output) provides a more robust solution for having the LLM generate structured output.

A [JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md) defines a structured format for the LLM's output by providing a JSON schema.
This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types.
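Conceptually, the rule just serializes a JSON Schema into system-prompt text. A stdlib-only sketch of the idea (the function name and wording are illustrative, not the framework's API):

```python
import json


def json_schema_rule_text(json_schema: dict) -> str:
    # Mimic of what a JSON-schema rule contributes to the system prompt:
    # an instruction followed by the serialized schema.
    return "Output valid JSON that matches this schema:\n" + json.dumps(json_schema, indent=2)


schema_dict = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}
print(json_schema_rule_text(schema_dict))
```

Note that, as the warning in the Structured Output section says, prompt-level instructions like this cannot guarantee valid JSON; the `native` and `tool` strategies are more reliable.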

Expand Down
5 changes: 4 additions & 1 deletion griptape/common/prompt_stack/prompt_stack.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from attrs import define, field

Expand All @@ -24,13 +24,16 @@
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from schema import Schema

from griptape.tools import BaseTool


@define
class PromptStack(SerializableMixin):
messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True})
tools: list[BaseTool] = field(factory=list, kw_only=True)
output_schema: Optional[Schema] = field(default=None, kw_only=True)

@property
def system_messages(self) -> list[Message]:
Expand Down
34 changes: 26 additions & 8 deletions griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import TYPE_CHECKING, Any

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field
from schema import Schema

from griptape.artifacts import (
Expand Down Expand Up @@ -41,6 +41,7 @@
import boto3

from griptape.common import PromptStack
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
from griptape.tools import BaseTool

logger = logging.getLogger(Defaults.logging_config.logger_name)
Expand All @@ -55,9 +56,19 @@ class AmazonBedrockPromptDriver(BasePromptDriver):
kw_only=True,
)
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: StructuredOutputStrategy = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value == "native":
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

@lazy_property()
def client(self) -> Any:
return self.session.client("bedrock-runtime")
Expand Down Expand Up @@ -103,10 +114,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:

def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]

messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])

return {
params = {
"modelId": self.model,
"messages": messages,
"system": system_messages,
Expand All @@ -115,14 +125,22 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
},
"additionalModelRequestFields": self.additional_model_request_fields,
**(
{"toolConfig": {"tools": self.__to_bedrock_tools(prompt_stack.tools), "toolChoice": self.tool_choice}}
if prompt_stack.tools and self.use_native_tools
else {}
),
**self.extra_params,
}

if prompt_stack.tools and self.use_native_tools:
params["toolConfig"] = {
"tools": [],
"toolChoice": self.tool_choice,
}

if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
params["toolConfig"]["toolChoice"] = {"any": {}}

params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)

return params

def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
return [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import boto3

from griptape.common import PromptStack
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy

logger = logging.getLogger(Defaults.logging_config.logger_name)

Expand All @@ -39,8 +40,18 @@ class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
),
kw_only=True,
)
structured_output_strategy: StructuredOutputStrategy = field(
default="rule", kw_only=True, metadata={"serializable": True}
)
_client: Any = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value != "rule":
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

@lazy_property()
def client(self) -> Any:
return self.session.client("sagemaker-runtime")
Expand Down
30 changes: 23 additions & 7 deletions griptape/drivers/prompt/anthropic_prompt_driver.py
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import TYPE_CHECKING, Optional

from attrs import Factory, define, field
from attrs import Attribute, Factory, define, field
from schema import Schema

from griptape.artifacts import (
Expand Down Expand Up @@ -42,6 +42,7 @@
from anthropic import Client
from anthropic.types import ContentBlock, ContentBlockDeltaEvent, ContentBlockStartEvent

from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
from griptape.tools.base_tool import BaseTool


Expand All @@ -68,13 +69,23 @@ class AnthropicPromptDriver(BasePromptDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
structured_output_strategy: StructuredOutputStrategy = field(
default="tool", kw_only=True, metadata={"serializable": True}
)
max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
_client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> Client:
return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
if value == "native":
raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

return value

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
params = self._base_params(prompt_stack)
Expand Down Expand Up @@ -110,23 +121,28 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
system_messages = prompt_stack.system_messages
system_message = system_messages[0].to_text() if system_messages else None

return {
params = {
"model": self.model,
"temperature": self.temperature,
"stop_sequences": self.tokenizer.stop_sequences,
"top_p": self.top_p,
"top_k": self.top_k,
"max_tokens": self.max_tokens,
"messages": messages,
**(
{"tools": self.__to_anthropic_tools(prompt_stack.tools), "tool_choice": self.tool_choice}
if prompt_stack.tools and self.use_native_tools
else {}
),
**({"system": system_message} if system_message else {}),
**self.extra_params,
}

if prompt_stack.tools and self.use_native_tools:
params["tool_choice"] = self.tool_choice

if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
params["tool_choice"] = {"type": "any"}

params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)

return params

def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
return [
{"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}
Expand Down
39 changes: 37 additions & 2 deletions griptape/drivers/prompt/base_prompt_driver.py
@@ -1,11 +1,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

from attrs import Factory, define, field

from griptape.artifacts.base_artifact import BaseArtifact
from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.common import (
ActionCallDeltaMessageContent,
ActionCallMessageContent,
Expand All @@ -26,12 +26,15 @@
)
from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
from griptape.mixins.serializable_mixin import SerializableMixin
from griptape.rules.json_schema_rule import JsonSchemaRule

if TYPE_CHECKING:
from collections.abc import Iterator

from griptape.tokenizers import BaseTokenizer

StructuredOutputStrategy = Literal["native", "tool", "rule"]


@define(kw_only=True)
class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
Expand All @@ -56,9 +59,13 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
tokenizer: BaseTokenizer
stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
structured_output_strategy: StructuredOutputStrategy = field(
default="rule", kw_only=True, metadata={"serializable": True}
)
extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

def before_run(self, prompt_stack: PromptStack) -> None:
self._init_structured_output(prompt_stack)
EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))

def after_run(self, result: Message) -> None:
Expand Down Expand Up @@ -122,6 +129,34 @@ def try_run(self, prompt_stack: PromptStack) -> Message: ...
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...

def _init_structured_output(self, prompt_stack: PromptStack) -> None:
from griptape.tools import StructuredOutputTool

if (output_schema := prompt_stack.output_schema) is not None:
Contributor:

Optional, but I think avoiding nesting is cleaner:

output_schema = prompt_stack.output_schema

if output_schema is None:
    return
    
...

if self.structured_output_strategy == "tool":
structured_output_tool = StructuredOutputTool(output_schema=output_schema)
if structured_output_tool not in prompt_stack.tools:
prompt_stack.tools.append(structured_output_tool)
elif self.structured_output_strategy == "rule":
output_artifact = TextArtifact(JsonSchemaRule(output_schema.json_schema("Output Schema")).to_text())
system_messages = prompt_stack.system_messages
if system_messages:
last_system_message = prompt_stack.system_messages[-1]
last_system_message.content.extend(
[
TextMessageContent(TextArtifact("\n\n")),
TextMessageContent(output_artifact),
]
)
else:
prompt_stack.messages.insert(
0,
Message(
content=[TextMessageContent(output_artifact)],
role=Message.SYSTEM_ROLE,
),
)
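Outside the diff, the three-way dispatch above can be summarized as a standalone, provider-agnostic sketch (the parameter names are illustrative, not any real provider API):

```python
import json
from typing import Any, Literal

StructuredOutputStrategy = Literal["native", "tool", "rule"]


def apply_strategy(
    strategy: StructuredOutputStrategy,
    output_schema: dict,
    params: dict[str, Any],
) -> dict[str, Any]:
    """Illustrative mimic of the dispatch in `_init_structured_output`."""
    if strategy == "native":
        # Hand the schema to the provider's structured-output API directly.
        params["response_format"] = {"type": "json_schema", "json_schema": output_schema}
    elif strategy == "tool":
        # Register a tool whose input schema is the output schema, and force its use.
        params.setdefault("tools", []).append({"name": "StructuredOutput", "schema": output_schema})
        params["tool_choice"] = "any"
    elif strategy == "rule":
        # Last resort: append the schema to the system prompt as an instruction.
        params["system"] = params.get("system", "") + "\n\nOutput JSON matching: " + json.dumps(output_schema)
    return params


params = apply_strategy("tool", {"type": "object"}, {})
print(params["tool_choice"])  # any
```

This also explains the per-driver validators in this PR: a driver that cannot express one of these branches (e.g. Bedrock has no `native` response-format API) simply rejects that strategy value up front.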

def __process_run(self, prompt_stack: PromptStack) -> Message:
return self.try_run(prompt_stack)

Expand Down