from langchain.chains import create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
)
from langchain_core.runnables import RunnableMap
from langchain_core.messages import AIMessage, HumanMessage
from langchain_google_vertexai import ChatVertexAI
from langchain_ai_agent.retriever.vector_store import DocumentEmbedder

# LangGraph memory imports
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, END, StateGraph

from typing import TypedDict, Annotated, Sequence
import operator
import logging

# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class AgentState(TypedDict):
    # `operator.add` tells LangGraph to append each node's returned messages
    # to the checkpointed history instead of overwriting it.
    messages: Annotated[Sequence[HumanMessage | AIMessage], operator.add]
    question: str
    graph_output: str

def get_chat_agent_with_memory(persist_dir: str):
    """
    Creates and returns an agent that maintains persistent conversation memory.
    The agent is invoked using .ainvoke({"question": ...}, config={"configurable": {"thread_id": ...}})
    """
    embedder = DocumentEmbedder(persist_dir=persist_dir)
    retriever = embedder.get_retriever(k=10)

    llm = ChatVertexAI(
        model_name="gemini-2.0-flash-lite",
        temperature=0.3,
        max_output_tokens=1024,
    )

    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question that can be understood "
        "without the chat history. Do NOT answer the question; just "
        "reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ]
    )

    try:
        history_aware_retriever = create_history_aware_retriever(
            llm, retriever, contextualize_q_prompt
        )
    except Exception as e:
        # Re-raise: swallowing this would leave history_aware_retriever
        # undefined when the retrieval chain is built below.
        logger.error(f"[Retriever] {e}")
        raise

    qa_system_prompt = (
        "You are an assistant for question-answering tasks. Use the following "
        "pieces of retrieved context to answer the question. If you don't know "
        "the answer, just say that you don't know. Use three sentences maximum and "
        "keep the answer concise."
    )

    # Reuse the system instructions in the document-stuffing prompt so the
    # two cannot drift apart.
    stuff_prompt = PromptTemplate(
        template=qa_system_prompt
        + "\nContext:\n{context}\nQuestion:\n{question}\nAnswer:",
        input_variables=["context", "question"],
    )

    try:
        combine_docs_chain = create_stuff_documents_chain(llm, stuff_prompt)
    except Exception as e:
        logger.error(f"[Chain] {e}")
        raise

    # Rewrite the question against the chat history, retrieve with the
    # rewritten query, then stuff the documents and the original question
    # into the QA prompt.
    retrieval_chain = RunnableMap(
        {
            "context": lambda x: history_aware_retriever.invoke(
                {"input": x["input"], "chat_history": x["chat_history"]}
            ),
            "question": lambda x: x["input"],
        }
    ) | combine_docs_chain


    graph_builder = StateGraph(AgentState)

    def call_model(state: AgentState) -> dict:
        logger.info(f"[call_model] Full state: {state}")
        question = state.get("question", "")
        logger.info(f"[call_model] Extracted question: {question}")

        if not question:
            return {
                "messages": [AIMessage(content="[call_model] Empty or missing question.")],
                "graph_output": "[call_model] Empty or missing question.",
            }

        chat_history = state.get("messages", [])

        chain_input = {
            "input": question,
            "chat_history": chat_history,
        }

        answer_text = retrieval_chain.invoke(chain_input)
        logger.info(f"[Agent] Response: {answer_text}")

        # Record the user's turn alongside the answer so the checkpointed
        # history contains both sides of the conversation.
        return {
            "messages": [
                HumanMessage(content=question),
                AIMessage(content=answer_text or "[call_model] No answer generated."),
            ],
            "graph_output": answer_text or "[call_model] No answer generated.",
        }

    graph_builder.add_node("model", call_model)
    graph_builder.add_edge(START, "model")
    graph_builder.add_edge("model", END)

    memory = MemorySaver()
    app = graph_builder.compile(checkpointer=memory)
    return app
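

# --- Minimal usage sketch (illustrative, not part of the module) ---
# Assumes a vector index already persisted at "./index" and Vertex AI
# credentials configured in the environment; the path, thread id, and
# questions below are placeholder values. Reusing the same thread_id is
# what lets MemorySaver replay prior turns into `messages`.
if __name__ == "__main__":
    agent = get_chat_agent_with_memory(persist_dir="./index")
    config = {"configurable": {"thread_id": "demo-thread"}}

    # First turn seeds the checkpointed history for this thread.
    first = agent.invoke({"question": "What does the handbook say about caching?"}, config=config)
    print(first["graph_output"])

    # Follow-up relies on the checkpointed history: the history-aware
    # retriever resolves "it" against the previous turn.
    followup = agent.invoke({"question": "How is it invalidated?"}, config=config)
    print(followup["graph_output"])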