Skip to content

Commit

Permalink
feat(autofix): Add prompt caching for Claude (#1884)
Browse files Browse the repository at this point in the history
Iteratively cache the last message in the conversation when using
`AnthropicProvider`, and also cache the system prompt. Based on evals, this
appears to save an average of 15-20 seconds per run, and it should also
significantly reduce API costs.
  • Loading branch information
roaga authored Feb 6, 2025
1 parent b40edb9 commit 3986daf
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 10 deletions.
2 changes: 1 addition & 1 deletion requirements-constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ chromadb==0.4.14
google-cloud-storage==2.*
google-cloud-aiplatform==1.*
google-cloud-secret-manager==2.*
anthropic[vertex]==0.34.2
anthropic[vertex]==0.45.*
langfuse @ git+https://github.com/jennmueng/langfuse-python.git@9d9350de1e4e84fa548fe84f82c1b826be17956e
watchdog
stumpy==1.13.0
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ amqp==5.3.1
# via kombu
annotated-types==0.7.0
# via pydantic
anthropic==0.34.2
anthropic==0.45.2
# via -r requirements-constraints.txt
anyio==4.8.0
# via
Expand Down Expand Up @@ -729,7 +729,6 @@ threadpoolctl==3.2.0
# scikit-learn
tokenizers==0.15.2
# via
# anthropic
# chromadb
# transformers
torch==2.2.0
Expand Down
20 changes: 14 additions & 6 deletions src/seer/automation/agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def generate_text(
max_tokens: int | None = None,
timeout: float | None = None,
):
message_dicts, tool_dicts, system_prompt = self._prep_message_and_tools(
message_dicts, tool_dicts, system_prompt_block = self._prep_message_and_tools(
messages=messages,
prompt=prompt,
system_prompt=system_prompt,
Expand All @@ -447,7 +447,7 @@ def generate_text(
anthropic_client = self.get_client()

completion = anthropic_client.messages.create(
system=system_prompt or NOT_GIVEN,
system=system_prompt_block or NOT_GIVEN,
model=self.model_name,
tools=cast(Iterable[ToolParam], tool_dicts) if tool_dicts else NOT_GIVEN,
messages=cast(Iterable[MessageParam], message_dicts),
Expand Down Expand Up @@ -560,16 +560,24 @@ def _prep_message_and_tools(
prompt: str | None = None,
system_prompt: str | None = None,
tools: list[FunctionTool] | None = None,
) -> tuple[list[MessageParam], list[ToolParam] | None, str | None]:
) -> tuple[list[MessageParam], list[ToolParam] | None, list[TextBlockParam] | None]:
message_dicts = [cls.to_message_param(message) for message in messages] if messages else []
if prompt:
message_dicts.append(cls.to_message_param(Message(role="user", content=prompt)))
if message_dicts:
message_dicts[-1]["content"][0]["cache_control"] = {"type": "ephemeral"} # type: ignore[index]

tool_dicts = (
[cls.to_tool_dict(tool) for tool in tools] if tools and len(tools) > 0 else None
)

return message_dicts, tool_dicts, system_prompt
system_prompt_block = (
[TextBlockParam(type="text", text=system_prompt, cache_control={"type": "ephemeral"})]
if system_prompt
else None
)

return message_dicts, tool_dicts, system_prompt_block

@observe(as_type="generation", name="Anthropic Stream")
def generate_text_stream(
Expand All @@ -583,7 +591,7 @@ def generate_text_stream(
max_tokens: int | None = None,
timeout: float | None = None,
) -> Iterator[str | ToolCall | Usage]:
message_dicts, tool_dicts, system_prompt = self._prep_message_and_tools(
message_dicts, tool_dicts, system_prompt_block = self._prep_message_and_tools(
messages=messages,
prompt=prompt,
system_prompt=system_prompt,
Expand All @@ -593,7 +601,7 @@ def generate_text_stream(
anthropic_client = self.get_client()

stream = anthropic_client.messages.create(
system=system_prompt or NOT_GIVEN,
system=system_prompt_block or NOT_GIVEN,
model=self.model_name,
tools=cast(Iterable[ToolParam], tool_dicts) if tool_dicts else NOT_GIVEN,
messages=cast(Iterable[MessageParam], message_dicts),
Expand Down
2 changes: 1 addition & 1 deletion tests/automation/agent/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def test_anthropic_prep_message_and_tools():
assert "description" in tool_dicts[0]
assert "input_schema" in tool_dicts[0]

assert returned_system_prompt == system_prompt
assert returned_system_prompt[0]["text"] == system_prompt


@pytest.mark.vcr()
Expand Down
Binary file not shown.

0 comments on commit 3986daf

Please sign in to comment.