Skip to content

Commit 30cebb1

Browse files
committed
feat: Tool calling
1 parent 1abc110 commit 30cebb1

File tree

6 files changed

+173
-5
lines changed

6 files changed

+173
-5
lines changed

Changelog.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Changelog
22

3+
## 11.4.0
4+
- Tool calling for chat tasks where the model supports it
5+
36
## 11.3.0
47

58
- Drop support for python3.9

aleph_alpha_client/chat.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import asdict, dataclass
33
from enum import Enum
44
from io import BytesIO
5-
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
5+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, Literal
66

77
from pydantic import BaseModel
88
from aleph_alpha_client.structured_output import ResponseFormat
@@ -41,6 +41,40 @@ def to_json(self) -> Mapping[str, Any]:
4141
return result
4242

4343

44+
@dataclass(frozen=True)
class FunctionCall:
    """A function invocation requested by the model.

    `arguments` is the JSON-serialized argument object exactly as the API
    returned it; callers are responsible for parsing it.
    """

    name: str
    arguments: str


@dataclass(frozen=True)
class ToolCall:
    """One tool call emitted by the model inside an assistant message."""

    id: str
    type: str
    function: FunctionCall

    @staticmethod
    def from_json(json: Dict[str, Any]) -> "ToolCall":
        """Build a ToolCall from its API wire representation."""
        fn = json["function"]
        return ToolCall(
            id=json["id"],
            type=json["type"],
            function=FunctionCall(name=fn["name"], arguments=fn["arguments"]),
        )

    def to_json(self) -> Mapping[str, Any]:
        """Serialize back to the API wire representation.

        `asdict` recurses into the nested FunctionCall, yielding the same
        ``{"id", "type", "function": {"name", "arguments"}}`` shape the API
        expects.
        """
        return asdict(self)
76+
77+
4478
# We introduce a more specific message type because chat responses can only
4579
# contain text at the moment. This enables static type checking to proof that
4680
# `content` is always a string.
@@ -59,12 +93,17 @@ class TextMessage:
5993

6094
role: Role
6195
content: str
96+
tool_calls: Optional[List[ToolCall]] = None
6297

6398
@staticmethod
6499
def from_json(json: Dict[str, Any]) -> "TextMessage":
100+
tool_calls = json.get("tool_calls")
65101
return TextMessage(
66102
role=Role(json["role"]),
67103
content=json["content"],
104+
tool_calls=None
105+
if tool_calls is None
106+
else [ToolCall.from_json(tool_call) for tool_call in tool_calls],
68107
)
69108

70109
# In multi-turn conversations the returned TextMessage is part of the chat
@@ -76,6 +115,8 @@ def to_json(self) -> Mapping[str, Any]:
76115
"role": self.role.value,
77116
"content": _message_content_to_json(self.content),
78117
}
118+
if self.tool_calls is not None:
119+
result["tool_calls"] = [t.to_json() for t in self.tool_calls]
79120
return result
80121

81122

@@ -122,6 +163,12 @@ class StreamOptions:
122163
include_usage: bool
123164

124165

166+
@dataclass(frozen=True)
class ToolFunction:
    """A `tool_choice` value forcing the model to call one specific function."""

    # The API only defines the "function" tool type today.
    type: Literal["function"]
    # Function selector payload as the API expects it; left untyped because
    # its schema is defined by the (OpenAI-compatible) endpoint, not here.
    function: Any
170+
171+
125172
@dataclass(frozen=True)
126173
class ChatRequest:
127174
"""
@@ -141,6 +188,12 @@ class ChatRequest:
141188
steering_concepts: Optional[List[str]] = None
142189
response_format: Optional[ResponseFormat] = None
143190

191+
tools: Optional[List[Any]] = None
192+
tool_choice: Optional[Union[Literal["auto", "required", "none"], ToolFunction]] = (
193+
None
194+
)
195+
parallel_tool_calls: Optional[bool] = None
196+
144197
def to_json(self) -> Mapping[str, Any]:
145198
payload = {k: v for k, v in asdict(self).items() if v is not None}
146199
payload["messages"] = [message.to_json() for message in self.messages]
@@ -164,7 +217,7 @@ class FinishReason(str, Enum):
164217
"""
165218
The reason the model stopped generating tokens.
166219
167-
This will be stop if the model hit a natural stop point or a provided stop
220+
This will be `stop` if the model hit a natural stop point or a provided stop
168221
sequence or length if the maximum number of tokens specified in the request
169222
was reached. If the API is unable to understand the stop reason emitted by
170223
one of the workers, content_filter is returned.
@@ -173,6 +226,7 @@ class FinishReason(str, Enum):
173226
Stop = "stop"
174227
Length = "length"
175228
ContentFilter = "content_filter"
229+
ToolCalls = "tool_calls"
176230

177231

178232
@dataclass(frozen=True)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "aleph-alpha-client"
3-
version = "11.3.0"
3+
version = "11.4.0"
44
description = "python client to interact with Aleph Alpha api endpoints"
55
authors = [{ name = "Aleph Alpha", email = "[email protected]" }]
66
requires-python = ">=3.10,<3.14"
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
interactions:
2+
- request:
3+
body:
4+
messages:
5+
- content: You are a helpful assistant.
6+
role: system
7+
- content: What is the weather like in Paris today?
8+
role: user
9+
model: qwen3-32b-tool
10+
tools:
11+
- function:
12+
description: Get current temperature for a given location.
13+
name: get_weather
14+
parameters:
15+
additionalProperties: false
16+
properties:
17+
location:
18+
description: "City and country e.g. Bogot\xE1, Colombia"
19+
type: string
20+
required:
21+
- location
22+
type: object
23+
strict: true
24+
type: function
25+
headers: {}
26+
method: POST
27+
uri: https://inference-api.stage.product.pharia.com/chat/completions
28+
response:
29+
body:
30+
string: '{"id":"chatcmpl-11b3f640-841a-478a-93cb-0c7ac98fc3da","choices":[{"finish_reason":"tool_calls","index":0,"message":{"role":"assistant","content":"\n\n","reasoning_content":"\nOkay,
31+
the user is asking about the weather in Paris today. I need to figure out
32+
which function to use. The available tool is get_weather, which requires a
33+
location parameter. Paris is the city mentioned, and the country is France.
34+
So I should format the location as \"Paris, France\". Let me make sure there
35+
are no other parameters needed. The function only needs the location, so I''ll
36+
construct the tool call with that.\n","tool_calls":[{"id":"chatcmpl-tool-2370633f184e43d8a700b78806cb1083","type":"function","function":{"name":"get_weather","arguments":"{\"location\":
37+
\"Paris, France\"}"}}]},"logprobs":null}],"created":1755691940,"model":"qwen3-32b-tool","system_fingerprint":null,"object":"chat.completion","usage":{"prompt_tokens":188,"completion_tokens":114,"total_tokens":302}}'
38+
headers:
39+
Access-Control-Allow-Credentials:
40+
- 'true'
41+
Access-Control-Expose-Headers:
42+
- content-type
43+
Connection:
44+
- keep-alive
45+
Content-Encoding:
46+
- gzip
47+
Content-Type:
48+
- application/json
49+
Date:
50+
- Wed, 20 Aug 2025 12:12:23 GMT
51+
Strict-Transport-Security:
52+
- max-age=31536000; includeSubDomains
53+
Transfer-Encoding:
54+
- chunked
55+
Vary:
56+
- Origin, Access-Control-Request-Method, Access-Control-Request-Headers
57+
- accept-encoding
58+
status:
59+
code: 200
60+
message: OK
61+
version: 1

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ def structured_output_model_name() -> str:
116116
return "qwen3-32b-tool"
117117

118118

119+
@pytest.fixture(scope="session")
def tool_calling_model_name() -> str:
    """Name of a deployed model known to support tool calling."""
    return "qwen3-32b-tool"
122+
123+
119124
@pytest.fixture(scope="session")
120125
def dummy_model_name() -> str:
121126
return "dummy-model"

tests/test_chat.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,53 @@ async def test_can_chat_with_async_client(
6565
assert response.message.content is not None
6666

6767

68+
# OpenAI-compatible tool schema shared by the tool-calling tests: a single
# `get_weather` function with one required string parameter, `location`.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get current temperature for a given location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "City and country e.g. Bogotá, Colombia",
                    }
                },
                "required": ["location"],
                "additionalProperties": False,
            },
            "strict": True,
        },
    }
]
89+
90+
91+
@pytest.mark.vcr
async def test_can_chat_with_tools(
    async_client: AsyncClient, tool_calling_model_name: str
):
    """A chat request carrying `tools` yields an assistant message with
    exactly one parsed tool call for the declared `get_weather` function.

    Runs against the recorded VCR cassette, so no live endpoint is needed.
    """
    conversation = [
        Message(role=Role.System, content="You are a helpful assistant."),
        Message(
            role=Role.User, content="What is the weather like in Paris today?"
        ),
    ]
    request = ChatRequest(
        messages=conversation,
        model=tool_calling_model_name,
        tools=TOOLS,
    )

    response = await async_client.chat(request, model=tool_calling_model_name)

    message = response.message
    assert message.role == Role.Assistant
    assert message.content is not None
    assert message.tool_calls is not None
    assert len(message.tool_calls) == 1
    (call,) = message.tool_calls
    assert call.type == "function"
    assert call.function.name == "get_weather"
114+
68115
@pytest.mark.vcr
69116
async def test_can_chat_with_streaming_support(
70117
async_client: AsyncClient, chat_model_name: str
@@ -263,7 +310,6 @@ def test_response_format_json_schema(
263310
assert field in json_response.keys(), (
264311
f"Required field '{field}' is missing from response"
265312
)
266-
267313
# Validate field types
268314
assert isinstance(json_response["nemo"], str), "Field 'nemo' should be a string"
269315
assert isinstance(json_response["species"], str), (
@@ -273,7 +319,6 @@ def test_response_format_json_schema(
273319
assert isinstance(json_response["size_cm"], (int, float)), (
274320
"Field 'size_cm' should be a number"
275321
)
276-
277322
# Validate size constraints
278323
assert 0.1 <= json_response["size_cm"] <= 100.0, (
279324
"Field 'size_cm' should be between 0.1 and 100.0"

0 commit comments

Comments
 (0)