
Commit af8fe89

Add SiliconFlow as an LLM provider, with unit tests and an optional live integration test
1 parent ec1a60b commit af8fe89

File tree: 5 files changed, 231 additions, 0 deletions


mem0/llms/configs.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ def validate_config(cls, v, values):
             "ollama",
             "anthropic",
             "groq",
+            "siliconflow",
             "together",
             "aws_bedrock",
             "litellm",

mem0/llms/siliconflow.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+import json
+import os
+import requests
+from typing import Dict, List, Optional, Any
+
+from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.base import LLMBase
+from mem0.memory.utils import extract_json
+
+
+class SiliconFlowLLM(LLMBase):
+    """
+    SiliconFlow chat completion provider.
+    Docs:
+    https://docs.siliconflow.com/en/api-reference/chat-completions/chat-completions
+    """
+
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model = "Qwen/Qwen2.5-7B-Instruct"
+
+        self.api_key = self.config.api_key or os.getenv("SILICONFLOW_API_KEY")
+        if not self.api_key:
+            raise ValueError("SiliconFlow API key not found. Set SILICONFLOW_API_KEY or pass via config.api_key.")
+
+        # Allow overriding the base URL via config or environment (docs show the .com domain)
+        self.base_url = (
+            getattr(self.config, "base_url", None)
+            or os.getenv("SILICONFLOW_BASE_URL")
+            or "https://api.siliconflow.com/v1"
+        )
+
+        # Pre-build headers
+        self.headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    def _endpoint(self) -> str:
+        return f"{self.base_url}/chat/completions"
+
+    def _parse_response(self, data: Dict[str, Any], tools: Optional[List[Dict]]) -> Any:
+        """
+        Parses an OpenAI-like response structure.
+        """
+        try:
+            choice = data["choices"][0]
+            message = choice.get("message", {})
+        except (KeyError, IndexError):
+            raise ValueError(f"Unexpected SiliconFlow response format: {data}")
+
+        if tools:
+            processed = {"content": message.get("content"), "tool_calls": []}
+            # If SiliconFlow returns tool_calls similar to OpenAI:
+            for tc in message.get("tool_calls", []) or []:
+                try:
+                    name = tc["function"]["name"]
+                    raw_args = tc["function"].get("arguments", "{}")
+                    # Ensure JSON object parsing
+                    args = json.loads(extract_json(raw_args))
+                    processed["tool_calls"].append({"name": name, "arguments": args})
+                except Exception:
+                    # Fall back to the raw values
+                    processed["tool_calls"].append(
+                        {
+                            "name": tc.get("function", {}).get("name"),
+                            "arguments": tc.get("function", {}).get("arguments"),
+                        }
+                    )
+            return processed
+        else:
+            return message.get("content")
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Create a chat completion via SiliconFlow.
+        Adjust request body keys if the docs differ.
+        """
+        payload: Dict[str, Any] = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "top_p": self.config.top_p,
+            "max_tokens": self.config.max_tokens,
+        }
+
+        # Response format (if SiliconFlow supports 'response_format': {"type": "json_object"})
+        if response_format:
+            payload["response_format"] = response_format
+
+        # Tool / function calling (verify the exact schema in the docs; it may differ)
+        if tools:
+            payload["tools"] = tools
+            # Some APIs expect {"type": "function", "function": {...}} structures;
+            # tool_choice might be "auto" or {"type": "function", "function": {"name": "..."}}
+            payload["tool_choice"] = tool_choice
+
+        resp = requests.post(self._endpoint(), headers=self.headers, json=payload, timeout=60)
+        if resp.status_code >= 400:
+            extra_hint = ""
+            if resp.status_code == 401:
+                extra_hint = (
+                    " (401 Unauthorized: verify SILICONFLOW_API_KEY is correct and matches the domain "
+                    f"{self.base_url.split('/v1')[0]}; you can also set SILICONFLOW_BASE_URL if needed)"
+                )
+            raise RuntimeError(f"SiliconFlow error {resp.status_code}: {resp.text}{extra_hint}")

+        data = resp.json()
+        return self._parse_response(data, tools)
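
Taken together, the provider can be exercised directly. A minimal usage sketch, assuming SILICONFLOW_API_KEY is exported in the environment (the model and sampling values are illustrative):

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.siliconflow import SiliconFlowLLM

# Requires SILICONFLOW_API_KEY in the environment, or pass api_key via the config.
config = BaseLlmConfig(
    model="Qwen/Qwen2.5-7B-Instruct",  # the provider's default model
    temperature=0.2,
    top_p=0.9,
    max_tokens=128,
)
llm = SiliconFlowLLM(config)

# Without tools, generate_response returns the assistant message as a plain string;
# with tools, it returns {"content": ..., "tool_calls": [{"name": ..., "arguments": {...}}]}.
reply = llm.generate_response([{"role": "user", "content": "Say hello in one sentence."}])
print(reply)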

mem0/utils/factory.py

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ class LlmFactory:
         "ollama": ("mem0.llms.ollama.OllamaLLM", OllamaConfig),
         "openai": ("mem0.llms.openai.OpenAILLM", OpenAIConfig),
         "groq": ("mem0.llms.groq.GroqLLM", BaseLlmConfig),
+        "siliconflow": ("mem0.llms.siliconflow.SiliconFlowLLM", BaseLlmConfig),
         "together": ("mem0.llms.together.TogetherLLM", BaseLlmConfig),
         "aws_bedrock": ("mem0.llms.aws_bedrock.AWSBedrockLLM", BaseLlmConfig),
         "litellm": ("mem0.llms.litellm.LiteLLM", BaseLlmConfig),

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ dependencies = [
     "pytz>=2024.1",
     "sqlalchemy>=2.0.31",
     "protobuf>=5.29.0,<6.0.0",
+    "requests>=2.32.0",
 ]

 [project.optional-dependencies]

tests/llms/test_siliconflow.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+from mem0.configs.llms.base import BaseLlmConfig
+from mem0.llms.siliconflow import SiliconFlowLLM
+
+
+@patch("mem0.llms.siliconflow.requests.post")
+def test_generate_response_without_tools(mock_post, monkeypatch):
+    monkeypatch.setenv("SILICONFLOW_API_KEY", "test-key")
+    config = BaseLlmConfig(model="Qwen/Qwen2.5-7B-Instruct", temperature=0.3, max_tokens=64, top_p=1.0)
+    llm = SiliconFlowLLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello SiliconFlow"},
+    ]
+
+    mock_post.return_value = Mock(
+        status_code=200,
+        json=lambda: {"choices": [{"message": {"content": "Hello back!"}}]},
+    )
+
+    response = llm.generate_response(messages)
+
+    mock_post.assert_called_once()
+    called_payload = mock_post.call_args.kwargs["json"]
+    assert called_payload["model"] == config.model
+    assert called_payload["messages"][1]["content"] == "Hello SiliconFlow"
+    assert response == "Hello back!"
+
+
+@patch("mem0.llms.siliconflow.requests.post")
+def test_generate_response_with_tools(mock_post, monkeypatch):
+    monkeypatch.setenv("SILICONFLOW_API_KEY", "test-key")
+    config = BaseLlmConfig(model="Qwen/Qwen2.5-7B-Instruct", temperature=0.3, max_tokens=64, top_p=1.0)
+    llm = SiliconFlowLLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Call a tool"},
+    ]
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "echo",
+                "description": "Echo input",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"text": {"type": "string"}},
+                    "required": ["text"],
+                },
+            },
+        }
+    ]
+
+    mock_post.return_value = Mock(
+        status_code=200,
+        json=lambda: {
+            "choices": [
+                {
+                    "message": {
+                        "content": "Tool called.",
+                        "tool_calls": [{"function": {"name": "echo", "arguments": '{"text":"hi"}'}}],
+                    }
+                }
+            ]
+        },
+    )
+
+    response = llm.generate_response(messages, tools=tools)
+
+    mock_post.assert_called_once()
+    called_payload = mock_post.call_args.kwargs["json"]
+    assert called_payload["tools"] == tools
+    assert response["content"] == "Tool called."
+    assert len(response["tool_calls"]) == 1
+    assert response["tool_calls"][0]["name"] == "echo"
+    assert response["tool_calls"][0]["arguments"]["text"] == "hi"
+
+
+@patch("mem0.llms.siliconflow.requests.post")
+def test_generate_response_error(mock_post, monkeypatch):
+    monkeypatch.setenv("SILICONFLOW_API_KEY", "test-key")
+    config = BaseLlmConfig(model="Qwen/Qwen2.5-7B-Instruct", temperature=0.3, max_tokens=64, top_p=1.0)
+    llm = SiliconFlowLLM(config)
+
+    mock_post.return_value = Mock(status_code=500, text="Internal Error")
+
+    with pytest.raises(RuntimeError):
+        llm.generate_response([{"role": "user", "content": "Hi"}])
+
+
+# ------------------------- LIVE INTEGRATION (optional) ------------------------- #
+@pytest.mark.skipif(not os.getenv("SILICONFLOW_API_KEY"), reason="No SiliconFlow API key set")
+def test_siliconflow_live_basic():
+    """Live call to the SiliconFlow API (non-mocked); skipped when no key is set.
+    Keeps tokens low to control cost.
+    Set SILICONFLOW_MODEL to override the model name.
+    """
+    model = os.getenv("SILICONFLOW_MODEL", "Qwen/QwQ-32B")
+    cfg = BaseLlmConfig(model=model, temperature=0.2, max_tokens=64, top_p=0.9)
+    llm = SiliconFlowLLM(cfg)
+
+    prompt = "In one concise sentence, say hello from SiliconFlow integration test."
+    resp = llm.generate_response([{"role": "user", "content": prompt}])
+
+    assert isinstance(resp, str)
+    assert resp.strip() and resp.strip() != prompt
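
The mocked tests above run fully offline; the live test executes only when a key is present. A small sketch for invoking the suite from Python, assuming pytest is installed and the repository root is the working directory:

import sys

import pytest

# The live test auto-skips unless SILICONFLOW_API_KEY is exported;
# set SILICONFLOW_MODEL to try a different model.
sys.exit(pytest.main(["-q", "tests/llms/test_siliconflow.py"]))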
