
Added support for AWS Bedrock Claude 3 #322

Open · wants to merge 2 commits into main
2 changes: 2 additions & 0 deletions backend/Dockerfile
@@ -14,6 +14,8 @@ COPY poetry.lock pyproject.toml /app/
# Disable the creation of virtual environments
RUN poetry config virtualenvs.create false

RUN poetry add boto3@^1.34.76

# Install dependencies
RUN poetry install

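boto3 is presumably what AsyncAnthropicBedrock relies on for AWS credential handling and request signing (an assumption on my part; the diff itself only adds the pin). A quick sanity-check sketch you could run inside the built image:

import boto3

# The Dockerfile pins boto3 at ^1.34.76, so the printed version should be at least that.
print(boto3.__version__)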
71 changes: 70 additions & 1 deletion backend/llm.py
@@ -1,6 +1,6 @@
from enum import Enum
from typing import Any, Awaitable, Callable, List, cast
from anthropic import AsyncAnthropic
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
from config import IS_DEBUG_ENABLED
@@ -17,6 +17,7 @@ class Llm(Enum):
    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
    CLAUDE_3_OPUS = "claude-3-opus-20240229"
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
    AWS_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"


# Will throw errors if you send a garbage string
@@ -25,6 +26,8 @@ def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
        return Llm.GPT_4_VISION
    elif frontend_str == "claude_3_sonnet":
        return Llm.CLAUDE_3_SONNET
    elif frontend_str == "aws_claude_3_sonnet":
        return Llm.AWS_CLAUDE_3_SONNET
    else:
        return Llm(frontend_str)

@@ -129,6 +132,72 @@ async def stream_claude_response(
    return response.content[0].text


# TODO: Have a separate function that translates OpenAI messages to Claude messages
async def stream_aws_claude_response(
    messages: List[ChatCompletionMessageParam],
    aws_access_key: str,
    aws_secret_key: str,
    aws_region: str,
    callback: Callable[[str], Awaitable[None]],
) -> str:

    client = AsyncAnthropicBedrock(
        aws_access_key=aws_access_key,
        aws_secret_key=aws_secret_key,
        aws_region=aws_region,
    )

    # Base parameters
    model = Llm.AWS_CLAUDE_3_SONNET
    max_tokens = 4096
    temperature = 0.0

    # Translate OpenAI messages to Claude messages
    system_prompt = cast(str, messages[0].get("content"))
    claude_messages = [dict(message) for message in messages[1:]]
    for message in claude_messages:
        if not isinstance(message["content"], list):
            continue

        for content in message["content"]:  # type: ignore
            if content["type"] == "image_url":
                content["type"] = "image"

                # Extract base64 data and media type from data URL
                # Example base64 data URL: data:image/png;base64,iVBOR...
                image_data_url = cast(str, content["image_url"]["url"])
                media_type = image_data_url.split(";")[0].split(":")[1]
                base64_data = image_data_url.split(",")[1]

                # Remove OpenAI parameter
                del content["image_url"]

                content["source"] = {
                    "type": "base64",
                    "media_type": media_type,
                    "data": base64_data,
                }

    # Stream Claude response
    async with client.messages.stream(
        model=model.value,
        max_tokens=max_tokens,
        temperature=temperature,
        system=system_prompt,
        messages=claude_messages,  # type: ignore
    ) as stream:
        async for text in stream.text_stream:
            await callback(text)

        # Return final message
        response = await stream.get_final_message()

    # Close the Anthropic client
    await client.close()

    return response.content[0].text


async def stream_claude_response_native(
    system_prompt: str,
    messages: list[Any],
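Stepping outside the diff for a moment: a minimal usage sketch for the new stream_aws_claude_response helper above, assuming it is run from the backend package (so llm is importable) and using placeholder credentials and image data:

import asyncio

from llm import stream_aws_claude_response


async def main() -> None:
    # Print each streamed chunk as it arrives.
    async def on_chunk(text: str) -> None:
        print(text, end="", flush=True)

    completion = await stream_aws_claude_response(
        messages=[
            {"role": "system", "content": "Turn screenshots into HTML."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Generate code for this screenshot."},
                    {
                        "type": "image_url",
                        # Placeholder data URL; a real call needs valid base64 image data.
                        "image_url": {"url": "data:image/png;base64,iVBOR..."},
                    },
                ],
            },
        ],
        aws_access_key="AKIA...",  # placeholder
        aws_secret_key="...",  # placeholder
        aws_region="us-east-1",  # placeholder region
        callback=on_chunk,
    )
    print("\nTotal characters streamed:", len(completion))


asyncio.run(main())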
16 changes: 16 additions & 0 deletions backend/routes/generate_code.py
@@ -7,6 +7,7 @@
from llm import (
    Llm,
    convert_frontend_str_to_llm,
    stream_aws_claude_response,
    stream_claude_response,
    stream_claude_response_native,
    stream_openai_response,
@@ -258,6 +259,21 @@ async def process_chunk(content: str):
            callback=lambda x: process_chunk(x),
        )
        exact_llm_version = code_generation_model
    elif code_generation_model == Llm.AWS_CLAUDE_3_SONNET:
        if not (
            os.environ.get("AWS_AK")
            and os.environ.get("AWS_SK")
            and os.environ.get("AWS_REGION")
        ):
            await throw_error(
                "Missing AWS Bedrock credentials. Please set the environment variables AWS_AK, AWS_SK, and AWS_REGION in backend/.env"
            )
            raise Exception("No AWS Bedrock credentials found")

        completion = await stream_aws_claude_response(
            prompt_messages,  # type: ignore
            aws_access_key=os.environ["AWS_AK"],
            aws_secret_key=os.environ["AWS_SK"],
            aws_region=os.environ["AWS_REGION"],
            callback=lambda x: process_chunk(x),
        )
        exact_llm_version = code_generation_model
    else:
        completion = await stream_openai_response(
            prompt_messages,  # type: ignore
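For reference, a sketch of the backend/.env entries the route expects (the variable names come from the check above; the values are placeholders) and how they are read:

import os

# backend/.env (illustrative values):
#   AWS_AK=AKIA...          AWS access key id
#   AWS_SK=...              AWS secret access key
#   AWS_REGION=us-east-1    a Bedrock region with Claude 3 access enabled

aws_access_key = os.environ["AWS_AK"]
aws_secret_key = os.environ["AWS_SK"]
aws_region = os.environ["AWS_REGION"]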
2 changes: 2 additions & 0 deletions frontend/src/lib/models.ts
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ export enum CodeGenerationModel {
  GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09",
  GPT_4_VISION = "gpt_4_vision",
  CLAUDE_3_SONNET = "claude_3_sonnet",
  AWS_CLAUDE_3_SONNET = "aws_claude_3_sonnet",
}

// Will generate a static error if a model in the enum above is not in the descriptions
@@ -15,4 +16,5 @@ export const CODE_GENERATION_MODEL_DESCRIPTIONS: {
"gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false },
gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false },
claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false },
aws_claude_3_sonnet: { name: "Claude 3 Sonnet (AWS Bedrock)", inBeta: false },
};
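The new frontend key has to round-trip through the backend converter shown earlier; a small backend-side sketch of that contract:

from llm import Llm, convert_frontend_str_to_llm

# The string sent by the frontend maps onto the Bedrock model id.
assert convert_frontend_str_to_llm("aws_claude_3_sonnet") is Llm.AWS_CLAUDE_3_SONNET
assert Llm.AWS_CLAUDE_3_SONNET.value == "anthropic.claude-3-sonnet-20240229-v1:0"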