
Commit 9e9c786

Merge pull request #21 from kaarthik108/move-to-agents
Move to Langgraph Agents - Its time
2 parents b703295 + d62fbb2 commit 9e9c786

13 files changed: +539 -299 lines

.github/workflows/lint.yml (-29)

This file was deleted.

Makefile (+57, new file)

@@ -0,0 +1,57 @@
.PHONY: all format lint lint_diff format_diff lint_package lint_tests spell_check spell_fix help lint-fix

# Define a variable for Python and notebook files.
PYTHON_FILES=src/
MYPY_CACHE=.mypy_cache

######################
# LINTING AND FORMATTING
######################

lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=src
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	python -m ruff check .
	[ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES) --diff
	[ "$(PYTHON_FILES)" = "" ] || python -m ruff check --select I,F401 --fix $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || python -m mypy --strict $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && python -m mypy --strict $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	ruff format $(PYTHON_FILES)
	ruff check --fix $(PYTHON_FILES)

spell_check:
	codespell --toml pyproject.toml

spell_fix:
	codespell --toml pyproject.toml -w

######################
# RUN ALL
######################

all: format lint spell_check

######################
# HELP
######################

help:
	@echo '----'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'spell_check - run spell check'
	@echo 'all - run all tasks'
	@echo 'lint-fix - run lint and fix issues'

######################
# LINT-FIX TARGET
######################

lint-fix: format lint
	@echo "Linting and fixing completed successfully."

README.md (+10 -11)

@@ -15,9 +15,11 @@
 ## Supported LLM's

-- GPT-3.5-turbo-0125
-- CodeLlama-70B
-- Mistral Medium
+- GPT-4o
+- Gemini Flash 1.5 8B
+- Claude 3 Haiku
+- Llama 3.2 3B
+- Llama 3.1 405B

 #

@@ -27,11 +29,12 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6-
 ## 🌟 Features

-- **Conversational AI**: Harnesses ChatGPT to translate natural language into precise SQL queries.
+- **Conversational AI**: Use ChatGPT and other models to translate natural language into precise SQL queries.
 - **Conversational Memory**: Retains context for interactive, dynamic responses.
 - **Snowflake Integration**: Offers seamless, real-time data insights straight from your Snowflake database.
 - **Self-healing SQL**: Proactively suggests solutions for SQL errors, streamlining data access.
 - **Interactive User Interface**: Transforms data querying into an engaging conversation, complete with a chat reset option.
+- **Agent-based Architecture**: Utilizes an agent to manage interactions and tool usage.

 ## 🛠️ Installation
@@ -42,7 +45,9 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6-
    cd snowchat
    pip install -r requirements.txt

-3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL`, `SUPABASE_SERVICE_KEY` and `REPLICATE_API_TOKEN` in project directory `secrets.toml`.
+3. Set up your `OPENAI_API_KEY`, `ACCOUNT`, `USER_NAME`, `PASSWORD`, `ROLE`, `DATABASE`, `SCHEMA`, `WAREHOUSE`, `SUPABASE_URL`, `SUPABASE_SERVICE_KEY`, `SUPABASE_STORAGE_URL`, `CLOUDFLARE_ACCOUNT_ID`, `CLOUDFLARE_NAMESPACE_ID`,
+`CLOUDFLARE_API_TOKEN` in project directory `secrets.toml`.
+Cloudflare is used here for caching Snowflake responses in KV.

 4. Make you're schemas and store them in docs folder that matches you're database.
@@ -53,12 +58,6 @@ https://github.com/kaarthik108/snowChat/assets/53030784/24105e23-69d3-4676-b6d6-
 7. Run the Streamlit app to start chatting:
    streamlit run main.py

-## 🚀 Additional Enhancements
-
-1. **Platform Integration**: Connect snowChat with popular communication platforms like Slack or Discord for seamless interaction.
-2. **Voice Integration**: Implement voice recognition and text-to-speech functionality to make the chatbot more interactive and user-friendly.
-3. **Advanced Analytics**: Integrate with popular data visualization libraries like Plotly or Matplotlib to generate interactive visualizations based on the user's queries (AutoGPT).
-
 ## Star History

 [![Star History Chart](https://api.star-history.com/svg?repos=kaarthik108/snowChat&type=Date)]
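Note on the new setup step: the README now asks for `CLOUDFLARE_ACCOUNT_ID`, `CLOUDFLARE_NAMESPACE_ID`, and `CLOUDFLARE_API_TOKEN` because Snowflake query results are cached in Cloudflare Workers KV. The caching code is not part of this diff; the following is only a rough sketch of how such a KV-backed cache could be wired up through the public Workers KV REST endpoints. The helper names, the SHA-256 key scheme, and the idea of storing raw result text are illustrative assumptions, not code from this repository.

import hashlib

import requests
import streamlit as st

CF_API = "https://api.cloudflare.com/client/v4"


def _kv_url(key: str) -> str:
    # Hypothetical helper: build the Workers KV "values" URL for a single key.
    account = st.secrets["CLOUDFLARE_ACCOUNT_ID"]
    namespace = st.secrets["CLOUDFLARE_NAMESPACE_ID"]
    return f"{CF_API}/accounts/{account}/storage/kv/namespaces/{namespace}/values/{key}"


def _headers() -> dict:
    return {"Authorization": f"Bearer {st.secrets['CLOUDFLARE_API_TOKEN']}"}


def get_cached_result(sql: str) -> str | None:
    # Key the cache on a hash of the SQL text; a miss comes back as HTTP 404.
    key = hashlib.sha256(sql.encode()).hexdigest()
    resp = requests.get(_kv_url(key), headers=_headers())
    return resp.text if resp.ok else None


def store_result(sql: str, result_text: str) -> None:
    # Write the result back under the same hashed key for later reuse.
    key = hashlib.sha256(sql.encode()).hexdigest()
    requests.put(_kv_url(key), headers=_headers(), data=result_text)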

agent.py (+115, new file)

@@ -0,0 +1,115 @@
import os
import streamlit as st
from dataclasses import dataclass
from typing import Annotated, Sequence, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph.message import add_messages
from langchain_core.messages import BaseMessage

from tools import retriever_tool
from tools import search, sql_executor_tool
from PIL import Image
from io import BytesIO


@dataclass
class MessagesState:
    messages: Annotated[Sequence[BaseMessage], add_messages]


memory = MemorySaver()


@dataclass
class ModelConfig:
    model_name: str
    api_key: str
    base_url: Optional[str] = None


model_configurations = {
    "gpt-4o": ModelConfig(
        model_name="gpt-4o", api_key=st.secrets["OPENAI_API_KEY"]
    ),
    "Gemini Flash 1.5 8B": ModelConfig(
        model_name="google/gemini-flash-1.5-8b",
        api_key=st.secrets["OPENROUTER_API_KEY"],
        base_url="https://openrouter.ai/api/v1",
    ),
    "claude3-haiku": ModelConfig(
        model_name="claude-3-haiku-20240307", api_key=st.secrets["ANTHROPIC_API_KEY"]
    ),
    "llama-3.2-3b": ModelConfig(
        model_name="accounts/fireworks/models/llama-v3p2-3b-instruct",
        api_key=st.secrets["FIREWORKS_API_KEY"],
        base_url="https://api.fireworks.ai/inference/v1",
    ),
    "llama-3.1-405b": ModelConfig(
        model_name="accounts/fireworks/models/llama-v3p1-405b-instruct",
        api_key=st.secrets["FIREWORKS_API_KEY"],
        base_url="https://api.fireworks.ai/inference/v1",
    ),
}
sys_msg = SystemMessage(
    content="""You're an AI assistant specializing in data analysis with Snowflake SQL. When providing responses, strive to exhibit friendliness and adopt a conversational tone, similar to how a friend or tutor would communicate. Do not ask the user for schema or database details. You have access to the following tools:
- Database_Schema: This tool allows you to search for database schema details when needed to generate the SQL code.
- Internet_Search: This tool allows you to search the internet for snowflake sql related information when needed to generate the SQL code.
- Snowflake_SQL_Executor: This tool allows you to execute snowflake sql queries when needed to generate the SQL code. You only have read access to the database, do not modify the database in any way.
"""
)
tools = [retriever_tool, search, sql_executor_tool]


def create_agent(callback_handler: BaseCallbackHandler, model_name: str) -> StateGraph:
    config = model_configurations.get(model_name)
    if not config:
        raise ValueError(f"Unsupported model name: {model_name}")

    if not config.api_key:
        raise ValueError(f"API key for model '{model_name}' is not set. Please check your environment variables or secrets configuration.")

    llm = (
        ChatOpenAI(
            model=config.model_name,
            api_key=config.api_key,
            callbacks=[callback_handler],
            streaming=True,
            base_url=config.base_url,
            temperature=0.01,
        )
        if config.model_name != "claude-3-haiku-20240307"
        else ChatAnthropic(
            model=config.model_name,
            api_key=config.api_key,
            callbacks=[callback_handler],
            streaming=True,
        )
    )

    llm_with_tools = llm.bind_tools(tools)

    def llm_agent(state: MessagesState):
        return {"messages": [llm_with_tools.invoke([sys_msg] + state.messages)]}

    builder = StateGraph(MessagesState)
    builder.add_node("llm_agent", llm_agent)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "llm_agent")
    builder.add_conditional_edges("llm_agent", tools_condition)
    builder.add_edge("tools", "llm_agent")

    react_graph = builder.compile(checkpointer=memory)

    # png_data = react_graph.get_graph(xray=True).draw_mermaid_png()
    # with open("graph.png", "wb") as f:
    #     f.write(png_data)

    # image = Image.open(BytesIO(png_data))
    # st.image(image, caption="React Graph")

    return react_graph
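Because the graph is compiled with the `MemorySaver` checkpointer, callers have to pass a `thread_id` in the run config; reusing the same id carries the conversation history forward between turns. A minimal driving sketch for `create_agent`, not part of this commit: the `PrintHandler` class and the sample question are made up for illustration, and the real app wires this into Streamlit instead of stdout.

from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.messages import HumanMessage

from agent import create_agent


class PrintHandler(BaseCallbackHandler):
    # Illustrative handler: stream tokens to stdout as the model generates them.
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)


graph = create_agent(PrintHandler(), "gpt-4o")

# MemorySaver keys saved state by thread_id; reuse it to continue a conversation.
config = {"configurable": {"thread_id": "demo-thread"}}

result = graph.invoke(
    {"messages": [HumanMessage(content="Which tables contain customer orders?")]},
    config=config,
)
print(result["messages"][-1].content)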
