diff --git a/backend/config.py b/backend/config.py
index 05592b03d..d96167daa 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -5,6 +5,11 @@
 ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)
 
+# AWS
+AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None)
+AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
+AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME", "us-west-2")
+
 # Debugging-related
 SHOULD_MOCK_AI_RESPONSE = bool(os.environ.get("MOCK", False))
diff --git a/backend/evals/core.py b/backend/evals/core.py
index 5e0536289..5213a2eb1 100644
--- a/backend/evals/core.py
+++ b/backend/evals/core.py
@@ -1,7 +1,11 @@
 import os
+
+from config import AWS_REGION_NAME
+from config import AWS_ACCESS_KEY
+from config import AWS_SECRET_ACCESS_KEY
 from config import ANTHROPIC_API_KEY
-from llm import Llm, stream_claude_response, stream_openai_response
+from llm import Llm, stream_claude_response, stream_openai_response, stream_claude_response_aws_bedrock
 from prompts import assemble_prompt
 from prompts.types import Stack
 
@@ -10,20 +14,32 @@ async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
     prompt_messages = assemble_prompt(image_url, stack)
     openai_api_key = os.environ.get("OPENAI_API_KEY")
     anthropic_api_key = ANTHROPIC_API_KEY
+    aws_access_key = AWS_ACCESS_KEY
+    aws_secret_access_key = AWS_SECRET_ACCESS_KEY
+    aws_region_name = AWS_REGION_NAME
     openai_base_url = None
 
     async def process_chunk(content: str):
         pass
 
     if model == Llm.CLAUDE_3_SONNET:
-        if not anthropic_api_key:
-            raise Exception("Anthropic API key not found")
-
-        completion = await stream_claude_response(
-            prompt_messages,
-            api_key=anthropic_api_key,
-            callback=lambda x: process_chunk(x),
-        )
+        if not anthropic_api_key and not (aws_access_key and aws_secret_access_key):
+            raise Exception("No Anthropic API key or AWS credentials found")
+
+        if anthropic_api_key:
+            completion = await stream_claude_response(
+                prompt_messages,
+                api_key=anthropic_api_key,
+                callback=lambda x: process_chunk(x),
+            )
+        else:
+            completion = await stream_claude_response_aws_bedrock(
+                prompt_messages,
+                access_key=aws_access_key,
+                secret_access_key=aws_secret_access_key,
+                aws_region_name=aws_region_name,
+                callback=lambda x: process_chunk(x),
+            )
     else:
         if not openai_api_key:
             raise Exception("OpenAI API key not found")
diff --git a/backend/llm.py b/backend/llm.py
index e32051c8c..53bf5f9d9 100644
--- a/backend/llm.py
+++ b/backend/llm.py
@@ -5,7 +5,10 @@
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
 
 from config import IS_DEBUG_ENABLED
 from debug.DebugFileWriter import DebugFileWriter
-
+import json
+import boto3
+from typing import List
+from botocore.exceptions import ClientError
 from utils import pprint_prompt
 
@@ -15,6 +18,7 @@ class Llm(Enum):
     GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09"
     GPT_4O_2024_05_13 = "gpt-4o-2024-05-13"
     CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
+    CLAUDE_3_SONNET_BEDROCK = "anthropic.claude-3-sonnet-20240229-v1:0"
     CLAUDE_3_OPUS = "claude-3-opus-20240229"
     CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
 
@@ -128,6 +132,128 @@ async def stream_claude_response(
     return response.content[0].text
 
 
+def initialize_bedrock_client(access_key: str, secret_access_key: str, aws_region_name: str):
+    try:
+        # Initialize the Bedrock Runtime client
+        bedrock_runtime = boto3.client(
+            service_name='bedrock-runtime',
+            aws_access_key_id=access_key,
+            aws_secret_access_key=secret_access_key,
+            region_name=aws_region_name,
+        )
+        return bedrock_runtime
+    except ClientError as err:
+        message = err.response["Error"]["Message"]
+        print(f"A client error occurred: {message}")
+        raise
+    except Exception as err:
+        print("An error occurred!")
+        raise err
+
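+# boto3's Bedrock Runtime client is synchronous, so the async helpers below
+# block the event loop while reading the response stream; an async wrapper
+# (e.g. aioboto3) could lift that restriction.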
+async def stream_bedrock_response(
+    bedrock_runtime,
+    messages: List[dict],
+    system_prompt: str,
+    model_id: str,
+    max_tokens: int,
+    content_type: str,
+    accept: str,
+    temperature: float,
+    callback: Callable[[str], Awaitable[None]],
+) -> str:
+    try:
+        # Prepare the request body
+        body = json.dumps({
+            "anthropic_version": "bedrock-2023-05-31",
+            "max_tokens": max_tokens,
+            "messages": messages,
+            "system": system_prompt,
+            "temperature": temperature,
+        })
+
+        # Invoke the Bedrock Runtime API with response stream
+        response = bedrock_runtime.invoke_model_with_response_stream(
+            body=body,
+            modelId=model_id,
+            accept=accept,
+            contentType=content_type,
+        )
+        stream = response.get("body")
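+        # Each streamed "chunk" decodes to a JSON-encoded Anthropic streaming
+        # event; only "content_block_delta" events carry generated text, so
+        # the other event types (message_start, message_stop, ...) are skipped.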
+
+        # Stream the response
+        final_message = ""
+        if stream:
+            for event in stream:
+                chunk = event.get("chunk")
+                if chunk:
+                    data = chunk.get("bytes").decode()
+                    chunk_obj = json.loads(data)
+                    if chunk_obj["type"] == "content_block_delta":
+                        text = chunk_obj["delta"]["text"]
+                        await callback(text)
+                        final_message += text
+
+        return final_message
+
+    except ClientError as err:
+        message = err.response["Error"]["Message"]
+        print(f"A client error occurred: {message}")
+        raise
+    except Exception as err:
+        print("An error occurred!")
+        raise err
+
+async def stream_claude_response_aws_bedrock(
+    messages: List[dict],
+    access_key: str,
+    secret_access_key: str,
+    aws_region_name: str,
+    callback: Callable[[str], Awaitable[None]],
+) -> str:
+    bedrock_runtime = initialize_bedrock_client(access_key, secret_access_key, aws_region_name)
+
+    # Set model parameters
+    model_id = Llm.CLAUDE_3_SONNET_BEDROCK.value
+    max_tokens = 4096
+    content_type = 'application/json'
+    accept = '*/*'
+    temperature = 0.0
+
+    # Translate OpenAI messages to Claude messages
+    system_prompt = cast(str, messages[0].get("content"))
+    claude_messages = [dict(message) for message in messages[1:]]
+    for message in claude_messages:
+        if not isinstance(message["content"], list):
+            continue
+
+        for content in message["content"]:  # type: ignore
+            if content["type"] == "image_url":
+                content["type"] = "image"
+
+                # Extract base64 data and media type from data URL
+                # Example base64 data URL: data:image/png;base64,iVBOR...
+                image_data_url = cast(str, content["image_url"]["url"])
+                media_type = image_data_url.split(";")[0].split(":")[1]
+                base64_data = image_data_url.split(",")[1]
+
+                # Remove OpenAI parameter
+                del content["image_url"]
+
+                content["source"] = {
+                    "type": "base64",
+                    "media_type": media_type,
+                    "data": base64_data,
+                }
+
+    return await stream_bedrock_response(
+        bedrock_runtime,
+        claude_messages,
+        system_prompt,
+        model_id,
+        max_tokens,
+        content_type,
+        accept,
+        temperature,
+        callback,
+    )
+
 
 async def stream_claude_response_native(
     system_prompt: str,
     messages: list[Any],
@@ -216,3 +342,62 @@ async def stream_claude_response_native(
         raise Exception("No HTML response found in AI response")
     else:
         return response.content[0].text
+
+async def stream_claude_response_native_aws_bedrock(
+    system_prompt: str,
+    messages: list[Any],
+    access_key: str,
+    secret_access_key: str,
+    aws_region_name: str,
+    callback: Callable[[str], Awaitable[None]],
+    include_thinking: bool = False,
+    model: Llm = Llm.CLAUDE_3_SONNET_BEDROCK,
+) -> str:
+    bedrock_runtime = initialize_bedrock_client(access_key, secret_access_key, aws_region_name)
+
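+    # NOTE: the model argument is currently unused; every Bedrock request is
+    # pinned to CLAUDE_3_SONNET_BEDROCK below, so callers passing
+    # model=Llm.CLAUDE_3_OPUS (the video flow) still run on Sonnet.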
+    # Set model parameters
+    model_id = Llm.CLAUDE_3_SONNET_BEDROCK.value
+    max_tokens = 4096
+    content_type = 'application/json'
+    accept = '*/*'
+    temperature = 0.0
+
+    # Multi-pass flow
+    current_pass_num = 1
+    max_passes = 2
+    prefix = ""
+
+    while current_pass_num <= max_passes:
+        current_pass_num += 1
+
+        # Set up message depending on whether we have a prefix
+        messages_to_send = (
+            messages + [{"role": "assistant", "content": prefix}]
+            if include_thinking
+            else messages
+        )
+
+        response_text = await stream_bedrock_response(
+            bedrock_runtime,
+            messages_to_send,
+            system_prompt,
+            model_id,
+            max_tokens,
+            content_type,
+            accept,
+            temperature,
+            callback,
+        )
+
+        # Set up messages array for next pass
+        messages += [
+            {"role": "assistant", "content": str(prefix) + response_text},
+            {
+                "role": "user",
+                "content": "You've done a good job with a first draft. Improve this further based on the original instructions so that the app is fully functional and looks like the original video of the app we're trying to replicate.",
+            },
+        ]
+
+    return response_text
\ No newline at end of file
diff --git a/backend/routes/generate_code.py b/backend/routes/generate_code.py
index 379042efe..ff924a1b0 100644
--- a/backend/routes/generate_code.py
+++ b/backend/routes/generate_code.py
@@ -2,14 +2,14 @@
 import traceback
 
 from fastapi import APIRouter, WebSocket
 import openai
-from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
+from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE, AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME
 from custom_types import InputMode
 from llm import (
     Llm,
     convert_frontend_str_to_llm,
     stream_claude_response,
     stream_claude_response_native,
-    stream_openai_response,
+    stream_openai_response,
+    stream_claude_response_aws_bedrock,
+    stream_claude_response_native_aws_bedrock,
 )
 from openai.types.chat import ChatCompletionMessageParam
 from mock_llm import mock_completion
@@ -25,7 +25,6 @@
 from video.utils import extract_tag_content, assemble_claude_prompt_video
 
 from ws.constants import APP_ERROR_WEB_SOCKET_CODE  # type: ignore
-
 router = APIRouter()
@@ -55,7 +54,7 @@ async def stream_code(websocket: WebSocket):
     print("Incoming websocket connection...")
 
     async def throw_error(
-        message: str, 
+        message: str,
     ):
         await websocket.send_json({"type": "error", "value": message})
         await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
@@ -230,33 +229,53 @@ async def process_chunk(content: str):
     else:
         try:
             if validated_input_mode == "video":
-                if not anthropic_api_key:
-                    await throw_error(
-                        "Video only works with Anthropic models. No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
-                    )
-                    raise Exception("No Anthropic key")
-
-                completion = await stream_claude_response_native(
-                    system_prompt=VIDEO_PROMPT,
-                    messages=prompt_messages,  # type: ignore
-                    api_key=anthropic_api_key,
-                    callback=lambda x: process_chunk(x),
-                    model=Llm.CLAUDE_3_OPUS,
-                    include_thinking=True,
-                )
-                exact_llm_version = Llm.CLAUDE_3_OPUS
-            elif code_generation_model == Llm.CLAUDE_3_SONNET:
-                if not anthropic_api_key:
-                    await throw_error(
-                        "No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
-                    )
-                    raise Exception("No Anthropic key")
-
-                completion = await stream_claude_response(
-                    prompt_messages,  # type: ignore
-                    api_key=anthropic_api_key,
-                    callback=lambda x: process_chunk(x),
-                )
+                if not anthropic_api_key and not (AWS_ACCESS_KEY and AWS_SECRET_ACCESS_KEY):
+                    await throw_error(
+                        "Video only works with Anthropic models. No Anthropic API key or AWS credentials found. Please add the environment variable ANTHROPIC_API_KEY or AWS_ACCESS_KEY/AWS_SECRET_ACCESS_KEY to backend/.env or in the settings dialog"
+                    )
+                    raise Exception("No Anthropic key")
+
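+                # An explicit Anthropic API key takes precedence; Bedrock is
+                # only used when ANTHROPIC_API_KEY is not set.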
+                if anthropic_api_key:
+                    completion = await stream_claude_response_native(
+                        system_prompt=VIDEO_PROMPT,
+                        messages=prompt_messages,  # type: ignore
+                        api_key=anthropic_api_key,
+                        callback=lambda x: process_chunk(x),
+                        model=Llm.CLAUDE_3_OPUS,
+                        include_thinking=True,
+                    )
+                else:
+                    completion = await stream_claude_response_native_aws_bedrock(
+                        system_prompt=VIDEO_PROMPT,
+                        messages=prompt_messages,  # type: ignore
+                        access_key=AWS_ACCESS_KEY,
+                        secret_access_key=AWS_SECRET_ACCESS_KEY,
+                        aws_region_name=AWS_REGION_NAME,
+                        callback=lambda x: process_chunk(x),
+                        model=Llm.CLAUDE_3_OPUS,
+                        include_thinking=True,
+                    )
+                exact_llm_version = Llm.CLAUDE_3_OPUS
+            elif code_generation_model == Llm.CLAUDE_3_SONNET:
+                if not anthropic_api_key and not (AWS_ACCESS_KEY and AWS_SECRET_ACCESS_KEY):
+                    await throw_error(
+                        "No Anthropic API key or AWS credentials found. Please add the environment variable ANTHROPIC_API_KEY or AWS_ACCESS_KEY/AWS_SECRET_ACCESS_KEY to backend/.env or in the settings dialog"
+                    )
+                    raise Exception("No Anthropic key")
+
+                if anthropic_api_key:
+                    completion = await stream_claude_response(
+                        prompt_messages,  # type: ignore
+                        api_key=anthropic_api_key,
+                        callback=lambda x: process_chunk(x),
+                    )
+                else:
+                    completion = await stream_claude_response_aws_bedrock(
+                        prompt_messages,  # type: ignore
+                        access_key=AWS_ACCESS_KEY,
+                        secret_access_key=AWS_SECRET_ACCESS_KEY,
+                        aws_region_name=AWS_REGION_NAME,
+                        callback=lambda x: process_chunk(x),
+                    )
                 exact_llm_version = code_generation_model
             else:
                 completion = await stream_openai_response(