|
| 1 | +import os |
| 2 | +import streamlit as st |
| 3 | +from dataclasses import dataclass |
| 4 | +from typing import Annotated, Sequence, Optional |
| 5 | + |
| 6 | +from langchain.callbacks.base import BaseCallbackHandler |
| 7 | +from langchain_anthropic import ChatAnthropic |
| 8 | +from langchain_core.messages import SystemMessage |
| 9 | +from langchain_openai import ChatOpenAI |
| 10 | +from langgraph.checkpoint.memory import MemorySaver |
| 11 | +from langgraph.graph import START, StateGraph |
| 12 | +from langgraph.prebuilt import ToolNode, tools_condition |
| 13 | +from langgraph.graph.message import add_messages |
| 14 | +from langchain_core.messages import BaseMessage |
| 15 | + |
| 16 | +from tools import retriever_tool |
| 17 | +from tools import search, sql_executor_tool |
| 18 | +from PIL import Image |
| 19 | +from io import BytesIO |
| 20 | + |
@dataclass
class MessagesState:
    """Graph state: the running conversation for the agent graph."""

    # The add_messages reducer merges newly returned messages into the
    # existing sequence instead of replacing it wholesale.
    messages: Annotated[Sequence[BaseMessage], add_messages]
| 24 | + |
| 25 | + |
| 26 | +memory = MemorySaver() |
| 27 | + |
| 28 | + |
@dataclass
class ModelConfig:
    """Connection settings for one selectable chat model."""

    model_name: str  # provider-side model identifier passed to the client
    api_key: str  # secret key for the provider (from st.secrets)
    # Endpoint override for OpenAI-compatible providers (OpenRouter,
    # Fireworks); None means the client's default endpoint.
    base_url: Optional[str] = None
| 34 | + |
| 35 | + |
# Registry of selectable models, keyed by the display name used elsewhere to
# look up a configuration (see create_agent). API keys are read from Streamlit
# secrets; OpenAI-compatible gateways (OpenRouter, Fireworks) also set a
# base_url, while native OpenAI/Anthropic entries leave it as None.
model_configurations = {
    "gpt-4o": ModelConfig(
        model_name="gpt-4o", api_key=st.secrets["OPENAI_API_KEY"]
    ),
    "Gemini Flash 1.5 8B": ModelConfig(
        model_name="google/gemini-flash-1.5-8b",
        api_key=st.secrets["OPENROUTER_API_KEY"],
        base_url="https://openrouter.ai/api/v1",
    ),
    "claude3-haiku": ModelConfig(
        model_name="claude-3-haiku-20240307", api_key=st.secrets["ANTHROPIC_API_KEY"]
    ),
    "llama-3.2-3b": ModelConfig(
        model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
        api_key=st.secrets["FIREWORKS_API_KEY"],
        base_url="https://api.fireworks.ai/inference/v1",
    ),
    "llama-3.1-405b": ModelConfig(
        model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
        api_key=st.secrets["FIREWORKS_API_KEY"],
        base_url="https://api.fireworks.ai/inference/v1",
    ),
}
# System prompt prepended to the conversation on every model invocation
# (see create_agent). It names the three tools bound to the model below.
sys_msg = SystemMessage(
    content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools:
 - Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code.
 - Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code.
 - Snowflake_SQL_Executor: This tool allows you to execute snowflake sql queries when needed to generate the SQL code. You only have read access to the database, do not modify the database in any way.
    """
)
# Tools exposed to the model; also executed by the graph's ToolNode.
tools = [retriever_tool, search, sql_executor_tool]
| 67 | + |
def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph:
    """Build and compile a tool-calling agent graph for the requested model.

    Args:
        callback_handler: Callback attached to the LLM, used to stream
            tokens back to the UI.
        model_name: Key into ``model_configurations`` (e.g. ``"gpt-4o"``).

    Returns:
        The compiled LangGraph agent, checkpointed with the module-level
        ``MemorySaver`` so conversation state persists across turns.

    Raises:
        ValueError: If ``model_name`` is not configured, or its API key is
            missing from the secrets configuration.
    """
    config = model_configurations.get(model_name)
    if not config:
        raise ValueError(f"Unsupported model name: {model_name}")

    if not config.api_key:
        raise ValueError(
            f"API key for model '{model_name}' is not set. "
            "Please check your environment variables or secrets configuration."
        )

    # Anthropic models need their own client class; every other configured
    # provider (OpenAI, OpenRouter, Fireworks) speaks the OpenAI-compatible
    # API. Matching on the "claude" prefix generalizes the previous
    # hard-coded equality check against the single haiku model id, so any
    # future Claude entry is routed correctly.
    if config.model_name.startswith("claude"):
        llm = ChatAnthropic(
            model=config.model_name,
            api_key=config.api_key,
            callbacks=[callback_handler],
            streaming=True,
        )
    else:
        llm = ChatOpenAI(
            model=config.model_name,
            api_key=config.api_key,
            callbacks=[callback_handler],
            streaming=True,
            base_url=config.base_url,
            temperature=0.01,
        )

    llm_with_tools = llm.bind_tools(tools)

    def llm_agent(state: MessagesState):
        # Prepend the system prompt on every turn; the add_messages reducer
        # merges the model's reply into the running conversation state.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}

    # ReAct-style loop: the agent node either answers or emits tool calls;
    # tools_condition routes tool calls to the ToolNode, which loops back.
    builder = StateGraph(MessagesState)
    builder.add_node("llm_agent", llm_agent)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "llm_agent")
    builder.add_conditional_edges("llm_agent", tools_condition)
    builder.add_edge("tools", "llm_agent")

    return builder.compile(checkpointer=memory)
0 commit comments