Skip to content

Commit ee02027

Browse files
authored
Tool call support for LangchainLLM (#3542)
1 parent 8ba032e commit ee02027

File tree

2 files changed

+95
-21
lines changed

2 files changed

+95
-21
lines changed

mem0/llms/langchain.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
try:
77
from langchain.chat_models.base import BaseChatModel
8+
from langchain_core.messages import AIMessage
89
except ImportError:
910
raise ImportError("langchain is not installed. Please install it using `pip install langchain`")
1011

@@ -21,6 +22,35 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
2122

2223
self.langchain_model = self.config.model
2324

25+
def _parse_response(self, response: AIMessage, tools: Optional[List[Dict]]):
26+
"""
27+
Process the response based on whether tools are used or not.
28+
29+
Args:
30+
response: AI Message.
31+
tools: The list of tools provided in the request.
32+
33+
Returns:
34+
str or dict: The processed response.
35+
"""
36+
if not tools:
37+
return response.content
38+
39+
processed_response = {
40+
"content": response.content,
41+
"tool_calls": [],
42+
}
43+
44+
for tool_call in response.tool_calls:
45+
processed_response["tool_calls"].append(
46+
{
47+
"name": tool_call["name"],
48+
"arguments": tool_call["args"],
49+
}
50+
)
51+
52+
return processed_response
53+
2454
def generate_response(
2555
self,
2656
messages: List[Dict[str, str]],
@@ -34,32 +64,31 @@ def generate_response(
3464
Args:
3565
messages (list): List of message dicts containing 'role' and 'content'.
3666
response_format (str or object, optional): Format of the response. Not used in Langchain.
37-
tools (list, optional): List of tools that the model can call. Not used in Langchain.
38-
tool_choice (str, optional): Tool choice method. Not used in Langchain.
67+
tools (list, optional): List of tools that the model can call.
68+
tool_choice (str, optional): Tool choice method.
3969
4070
Returns:
4171
str: The generated response.
4272
"""
43-
try:
44-
# Convert the messages to LangChain's tuple format
45-
langchain_messages = []
46-
for message in messages:
47-
role = message["role"]
48-
content = message["content"]
49-
50-
if role == "system":
51-
langchain_messages.append(("system", content))
52-
elif role == "user":
53-
langchain_messages.append(("human", content))
54-
elif role == "assistant":
55-
langchain_messages.append(("ai", content))
73+
# Convert the messages to LangChain's tuple format
74+
langchain_messages = []
75+
for message in messages:
76+
role = message["role"]
77+
content = message["content"]
5678

57-
if not langchain_messages:
58-
raise ValueError("No valid messages found in the messages list")
79+
if role == "system":
80+
langchain_messages.append(("system", content))
81+
elif role == "user":
82+
langchain_messages.append(("human", content))
83+
elif role == "assistant":
84+
langchain_messages.append(("ai", content))
5985

60-
ai_message = self.langchain_model.invoke(langchain_messages)
86+
if not langchain_messages:
87+
raise ValueError("No valid messages found in the messages list")
6188

62-
return ai_message.content
89+
langchain_model = self.langchain_model
90+
if tools:
91+
langchain_model = langchain_model.bind_tools(tools=tools, tool_choice=tool_choice)
6392

64-
except Exception as e:
65-
raise Exception(f"Error generating response using langchain model: {str(e)}")
93+
response: AIMessage = langchain_model.invoke(langchain_messages)
94+
return self._parse_response(response, tools)

tests/llms/test_langchain.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,51 @@ def test_generate_response(mock_langchain_model):
6868
assert response == "This is a test response"
6969

7070

71+
def test_generate_response_with_tools(mock_langchain_model):
    """Tool calls returned by the model are surfaced in the parsed response."""
    config = BaseLlmConfig(model=mock_langchain_model, temperature=0.7, max_tokens=100, api_key="test-api-key")
    llm = LangchainLLM(config)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "add_memory",
                "description": "Add a memory",
                "parameters": {
                    "type": "object",
                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
                    "required": ["data"],
                },
            },
        }
    ]

    # Fake AI message carrying one tool call. _parse_response only reads
    # `content` and does item lookup (["name"] / ["args"]) on each tool
    # call, so a plain dict stands in for a LangChain ToolCall.
    mock_response = Mock()
    mock_response.content = "I've added the memory for you."
    mock_response.tool_calls = [{"name": "add_memory", "args": {"data": "Today is a sunny day."}}]

    # bind_tools must hand back a model whose invoke yields our fake message.
    mock_langchain_model.bind_tools.return_value = mock_langchain_model
    mock_langchain_model.invoke.return_value = mock_response

    response = llm.generate_response(messages, tools=tools)

    mock_langchain_model.invoke.assert_called_once()

    assert response["content"] == "I've added the memory for you."
    assert len(response["tool_calls"]) == 1
    assert response["tool_calls"][0]["name"] == "add_memory"
    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
114+
115+
71116
def test_invalid_model():
72117
"""Test that LangchainLLM raises an error with an invalid model."""
73118
config = BaseLlmConfig(model="not-a-valid-model-instance", temperature=0.7, max_tokens=100, api_key="test-api-key")

0 commit comments

Comments
 (0)