diff --git a/backend/Dockerfile b/backend/Dockerfile
index c5205175..55fd99e2 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -14,6 +14,8 @@ COPY poetry.lock pyproject.toml /app/
 
 # Disable the creation of virtual environments
 RUN poetry config virtualenvs.create false
 
+RUN poetry add boto3@^1.34.76
+
 # Install dependencies
 RUN poetry install
diff --git a/backend/llm.py b/backend/llm.py
index e32051c8..9e1ec448 100644
--- a/backend/llm.py
+++ b/backend/llm.py
@@ -1,6 +1,6 @@
 from enum import Enum
 from typing import Any, Awaitable, Callable, List, cast
-from anthropic import AsyncAnthropic
+from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
 from config import IS_DEBUG_ENABLED
@@ -17,6 +17,7 @@ class Llm(Enum):
     CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
     CLAUDE_3_OPUS = "claude-3-opus-20240229"
     CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
+    AWS_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
 
 
 # Will throw errors if you send a garbage string
@@ -25,6 +26,8 @@ def convert_frontend_str_to_llm(frontend_str: str) -> Llm:
         return Llm.GPT_4_VISION
     elif frontend_str == "claude_3_sonnet":
         return Llm.CLAUDE_3_SONNET
+    elif frontend_str == "aws_claude_3_sonnet":
+        return Llm.AWS_CLAUDE_3_SONNET
     else:
         return Llm(frontend_str)
 
@@ -129,6 +132,72 @@ async def stream_claude_response(
     return response.content[0].text
 
 
+# TODO: Have a separate function that translates OpenAI messages to Claude messages
+async def stream_aws_claude_response(
+    messages: List[ChatCompletionMessageParam],
+    aws_access_key: str,
+    aws_secret_key: str,
+    aws_region: str,
+    callback: Callable[[str], Awaitable[None]],
+) -> str:
+
+    client = AsyncAnthropicBedrock(
+        aws_access_key=aws_access_key,
+        aws_secret_key=aws_secret_key,
+        aws_region=aws_region,
+    )
+
+    # Base parameters
+    model = Llm.AWS_CLAUDE_3_SONNET
+    max_tokens = 4096
+    temperature = 0.0
+
+    # Translate OpenAI messages to Claude messages
+    system_prompt = cast(str, messages[0].get("content"))
+    claude_messages = [dict(message) for message in messages[1:]]
+    for message in claude_messages:
+        if not isinstance(message["content"], list):
+            continue
+
+        for content in message["content"]:  # type: ignore
+            if content["type"] == "image_url":
+                content["type"] = "image"
+
+                # Extract base64 data and media type from data URL
+                # Example base64 data URL: data:image/png;base64,iVBOR...
+                image_data_url = cast(str, content["image_url"]["url"])
+                media_type = image_data_url.split(";")[0].split(":")[1]
+                base64_data = image_data_url.split(",")[1]
+
+                # Remove OpenAI parameter
+                del content["image_url"]
+
+                content["source"] = {
+                    "type": "base64",
+                    "media_type": media_type,
+                    "data": base64_data,
+                }
+
+    # Stream Claude response
+    async with client.messages.stream(
+        model=model.value,
+        max_tokens=max_tokens,
+        temperature=temperature,
+        system=system_prompt,
+        messages=claude_messages,  # type: ignore
+    ) as stream:
+        async for text in stream.text_stream:
+            await callback(text)
+
+    # Return final message
+    response = await stream.get_final_message()
+
+    # Close the Anthropic client
+    await client.close()
+
+    return response.content[0].text
+
+
 async def stream_claude_response_native(
     system_prompt: str,
     messages: list[Any],
diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py
index 379042ef..1693d4ac 100644
--- a/backend/routes/generate_code.py
+++ b/backend/routes/generate_code.py
@@ -7,6 +7,7 @@
 from llm import (
     Llm,
     convert_frontend_str_to_llm,
+    stream_aws_claude_response,
     stream_claude_response,
     stream_claude_response_native,
     stream_openai_response,
@@ -258,6 +259,27 @@ async def process_chunk(content: str):
                     callback=lambda x: process_chunk(x),
                 )
                 exact_llm_version = code_generation_model
+            elif code_generation_model == Llm.AWS_CLAUDE_3_SONNET:
+                # Require all three variables up front so a missing secret key
+                # or region fails fast with a clear message
+                if not (
+                    os.environ.get("AWS_AK")
+                    and os.environ.get("AWS_SK")
+                    and os.environ.get("AWS_REGION")
+                ):
+                    await throw_error(
+                        "No AWS Bedrock credentials found. Please add the environment variables AWS_AK, AWS_SK and AWS_REGION to backend/.env"
+                    )
+                    raise Exception("No AWS Bedrock credentials")
+
+                completion = await stream_aws_claude_response(
+                    prompt_messages,  # type: ignore
+                    aws_access_key=os.environ["AWS_AK"],
+                    aws_secret_key=os.environ["AWS_SK"],
+                    aws_region=os.environ["AWS_REGION"],
+                    callback=lambda x: process_chunk(x),
+                )
+                exact_llm_version = code_generation_model
             else:
                 completion = await stream_openai_response(
                     prompt_messages,  # type: ignore
diff --git a/frontend/src/lib/models.ts b/frontend/src/lib/models.ts
index d6fd09db..afe18f4c 100644
--- a/frontend/src/lib/models.ts
+++ b/frontend/src/lib/models.ts
@@ -5,6 +5,7 @@ export enum CodeGenerationModel {
   GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09",
   GPT_4_VISION = "gpt_4_vision",
   CLAUDE_3_SONNET = "claude_3_sonnet",
+  AWS_CLAUDE_3_SONNET = "aws_claude_3_sonnet",
 }
 
 // Will generate a static error if a model in the enum above is not in the descriptions
@@ -15,4 +16,5 @@ export const CODE_GENERATION_MODEL_DESCRIPTIONS: {
   "gpt-4-turbo-2024-04-09": { name: "GPT-4 Turbo (Apr 2024)", inBeta: false },
   gpt_4_vision: { name: "GPT-4 Vision (Nov 2023)", inBeta: false },
   claude_3_sonnet: { name: "Claude 3 Sonnet", inBeta: false },
+  aws_claude_3_sonnet: { name: "Claude 3 Sonnet (AWS Bedrock)", inBeta: false },
 };
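
Note: the new code path reads three entries from backend/.env: AWS_AK (the AWS access key ID), AWS_SK (the secret access key), and AWS_REGION (a Bedrock-enabled region, e.g. "us-east-1"; that example value is an assumption, not part of the diff).

The TODO at the top of stream_aws_claude_response suggests extracting the OpenAI-to-Claude message translation into its own function. A minimal sketch of what that helper could look like, reusing only the conversion logic already shown in the diff (the function name is invented for illustration and is not part of this change):

    # Hypothetical helper, not part of this diff: converts one OpenAI
    # "image_url" content block into Claude's "image" block format.
    def openai_image_to_claude(content: dict) -> dict:
        # Example data URL: data:image/png;base64,iVBOR...
        image_data_url: str = content["image_url"]["url"]
        media_type = image_data_url.split(";")[0].split(":")[1]  # e.g. "image/png"
        base64_data = image_data_url.split(",")[1]
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": base64_data,
            },
        }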