Commit 6c8735c (1 parent: 38083b2)

Added support for function calling, json mode & new models

4 files changed: +337 additions, −28 deletions

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -21,3 +21,4 @@
 go.work

 .vscode/
+.idea/

chat.go

Lines changed: 29 additions & 25 deletions

@@ -9,26 +9,16 @@ import (
 	"net/http"
 )

-const (
-	RoleUser      = "user"
-	RoleAssistant = "assistant"
-	RoleSystem    = "system"
-)
-
-type FinishReason string
-
-const (
-	FinishReasonStop   FinishReason = "stop"
-	FinishReasonLength FinishReason = "length"
-)
-
 // ChatRequestParams represents the parameters for the Chat/ChatStream method of MistralClient.
 type ChatRequestParams struct {
-	Temperature float64 `json:"temperature"` // The temperature to use for sampling. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or TopP but not both.
-	TopP        float64 `json:"top_p"`       // An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or Temperature but not both.
-	RandomSeed  int     `json:"random_seed"`
-	MaxTokens   int     `json:"max_tokens"`
-	SafePrompt  bool    `json:"safe_prompt"` // Adds a Mistral defined safety message to the system prompt to enforce guardrailing
+	Temperature    float64        `json:"temperature"` // The temperature to use for sampling. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or TopP but not both.
+	TopP           float64        `json:"top_p"`       // An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or Temperature but not both.
+	RandomSeed     int            `json:"random_seed"`
+	MaxTokens      int            `json:"max_tokens"`
+	SafePrompt     bool           `json:"safe_prompt"` // Adds a Mistral defined safety message to the system prompt to enforce guardrailing
+	Tools          []Tool         `json:"tools"`
+	ToolChoice     string         `json:"tool_choice"`
+	ResponseFormat ResponseFormat `json:"response_format"`
 }

 var DefaultChatRequestParams = ChatRequestParams{
@@ -39,12 +29,6 @@ var DefaultChatRequestParams = ChatRequestParams{
 	SafePrompt: false,
 }

-// ChatMessage represents a single message in a chat.
-type ChatMessage struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
 // ChatCompletionResponseChoice represents a choice in the chat completion response.
 type ChatCompletionResponseChoice struct {
 	Index int `json:"index"`
@@ -55,7 +39,7 @@ type ChatCompletionResponseChoice struct {
 // ChatCompletionResponseChoice represents a choice in the chat completion response.
 type ChatCompletionResponseChoiceStream struct {
 	Index        int          `json:"index"`
-	Delta        ChatMessage  `json:"delta"`
+	Delta        DeltaMessage `json:"delta"`
 	FinishReason FinishReason `json:"finish_reason,omitempty"`
 }

@@ -102,6 +86,16 @@ func (c *MistralClient) Chat(model string, messages []ChatMessage, params *ChatR
 		"safe_prompt": params.SafePrompt,
 	}

+	if params.Tools != nil {
+		requestData["tools"] = params.Tools
+	}
+	if params.ToolChoice != "" {
+		requestData["tool_choice"] = params.ToolChoice
+	}
+	if params.ResponseFormat != "" {
+		requestData["response_format"] = map[string]any{"type": params.ResponseFormat}
+	}
+
 	response, err := c.request(http.MethodPost, requestData, "v1/chat/completions", false, nil)
 	if err != nil {
 		return nil, err
@@ -140,6 +134,16 @@ func (c *MistralClient) ChatStream(model string, messages []ChatMessage, params
 		"stream": true,
 	}

+	if params.Tools != nil {
+		requestData["tools"] = params.Tools
+	}
+	if params.ToolChoice != "" {
+		requestData["tool_choice"] = params.ToolChoice
+	}
+	if params.ResponseFormat != "" {
+		requestData["response_format"] = map[string]any{"type": params.ResponseFormat}
+	}
+
 	response, err := c.request(http.MethodPost, requestData, "v1/chat/completions", true, nil)
 	if err != nil {
 		return nil, err
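Note: the new ChatRequestParams fields are opt-in; Chat and ChatStream only attach tools, tool_choice, and response_format to the request body when they are set, so existing callers keep their current request shape. Below is a minimal usage sketch, written package-internal like the tests to avoid guessing the module's import path (imports: fmt). Tool, Function, ToolTypeFunction, ToolChoiceAuto, ModelMistralSmallLatest, ChatMessage, and RoleUser come from this commit and the rest of the package; the get_weather tool is illustrative only, and the empty key passed to NewMistralClientDefault is assumed to fall back to environment configuration, as the tests suggest.

// Sketch only: exercising the new function-calling parameters end to end.
func exampleFunctionCall() error {
	client := NewMistralClientDefault("") // "" as in the tests; key resolution is left to the client

	params := DefaultChatRequestParams
	params.Tools = []Tool{{
		Type: ToolTypeFunction,
		Function: Function{
			Name:        "get_weather", // illustrative tool, mirroring chat_test.go
			Description: "Retrieve the weather for a city in the US",
			Parameters: map[string]interface{}{
				"type":     "object",
				"required": []string{"city", "state"},
				"properties": map[string]interface{}{
					"city":  map[string]interface{}{"type": "string"},
					"state": map[string]interface{}{"type": "string"},
				},
			},
		},
	}}
	params.ToolChoice = ToolChoiceAuto // let the model decide whether to call the tool

	res, err := client.Chat(
		ModelMistralSmallLatest,
		[]ChatMessage{{Role: RoleUser, Content: "What's the weather like in Dallas, TX?"}},
		&params,
	)
	if err != nil {
		return err
	}
	if len(res.Choices) == 0 {
		return nil
	}
	// When the model opts to call a tool, the call arrives on the assistant
	// message with its arguments encoded as a JSON string.
	for _, call := range res.Choices[0].Message.ToolCalls {
		fmt.Printf("%s(%s)\n", call.Function.Name, call.Function.Arguments)
	}
	return nil
}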

chat_test.go

Lines changed: 214 additions & 3 deletions

@@ -12,7 +12,7 @@ func TestChat(t *testing.T) {
 	params.MaxTokens = 10
 	params.Temperature = 0
 	res, err := client.Chat(
-		"mistral-tiny",
+		ModelMistralTiny,
 		[]ChatMessage{
 			{
 				Role: RoleUser,
@@ -30,13 +30,99 @@ func TestChat(t *testing.T) {
 	assert.Equal(t, res.Choices[0].Message.Content, "Test Succeeded")
 }

+func TestChatFunctionCall(t *testing.T) {
+	client := NewMistralClientDefault("")
+	params := DefaultChatRequestParams
+	params.Temperature = 0
+	params.Tools = []Tool{
+		{
+			Type: ToolTypeFunction,
+			Function: Function{
+				Name:        "get_weather",
+				Description: "Retrieve the weather for a city in the US",
+				Parameters: map[string]interface{}{
+					"type":     "object",
+					"required": []string{"city", "state"},
+					"properties": map[string]interface{}{
+						"city":  map[string]interface{}{"type": "string", "description": "Name of the city for the weather"},
+						"state": map[string]interface{}{"type": "string", "description": "Name of the state for the weather"},
+					},
+				},
+			},
+		},
+		{
+			Type: ToolTypeFunction,
+			Function: Function{
+				Name:        "send_text",
+				Description: "Send text message using SMS service",
+				Parameters: map[string]interface{}{
+					"type":     "object",
+					"required": []string{"contact_name", "message"},
+					"properties": map[string]interface{}{
+						"contact_name": map[string]interface{}{"type": "string", "description": "Name of the contact that will receive the message"},
+						"message":      map[string]interface{}{"type": "string", "description": "Content of the message that will be sent"},
+					},
+				},
+			},
+		},
+	}
+	params.ToolChoice = ToolChoiceAuto
+	res, err := client.Chat(
+		ModelMistralSmallLatest,
+		[]ChatMessage{
+			{
+				Role:    RoleUser,
+				Content: "What's the weather like in Dallas, TX?",
+			},
+		},
+		&params,
+	)
+	assert.NoError(t, err)
+	assert.NotNil(t, res)
+
+	assert.Greater(t, len(res.Choices), 0)
+	assert.Greater(t, len(res.Choices[0].Message.ToolCalls), 0)
+	assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant)
+	assert.Equal(t, res.Choices[0].Message.ToolCalls[0].Function.Name, "get_weather")
+	assert.Equal(t, res.Choices[0].Message.ToolCalls[0].Function.Arguments, "{\"city\": \"Dallas\", \"state\": \"TX\"}")
+}
+
+func TestChatJsonMode(t *testing.T) {
+	client := NewMistralClientDefault("")
+	params := DefaultChatRequestParams
+	params.Temperature = 0
+	params.ResponseFormat = ResponseFormatJsonObject
+	res, err := client.Chat(
+		ModelMistralSmallLatest,
+		[]ChatMessage{
+			{
+				Role: RoleUser,
+				Content: "Extract all of the code symbols in this text chunk and return them in the following JSON: " +
+					"{\"symbols\":[\"SymbolOne\",\"SymbolTwo\"]}\n```\nI'm working on updating the Go client for the " +
+					"new release, is it expected that the function call will be passed back into the model or just " +
+					"the tool response?\nI ask because ChatMessage can handle the tool response but the messages list " +
+					"has an Any option that I assume would be for the FunctionCall/ToolCall type\nAdditionally the " +
+					"example in the docs only shows the tool response appended to the messages\n```",
+			},
+		},
+		&params,
+	)
+	assert.NoError(t, err)
+	assert.NotNil(t, res)
+
+	assert.Greater(t, len(res.Choices), 0)
+	assert.Greater(t, len(res.Choices[0].Message.Content), 0)
+	assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant)
+	assert.Equal(t, res.Choices[0].Message.Content, "{\"symbols\": [\"Go\", \"ChatMessage\", \"FunctionCall\", \"ToolCall\"]}")
+}
+
 func TestChatStream(t *testing.T) {
 	client := NewMistralClientDefault("")
 	params := DefaultChatRequestParams
 	params.MaxTokens = 50
 	params.Temperature = 0
 	resChan, err := client.ChatStream(
-		"mistral-tiny",
+		ModelMistralTiny,
 		[]ChatMessage{
 			{
 				Role: RoleUser,
@@ -53,16 +139,141 @@ func TestChatStream(t *testing.T) {
 	for res := range resChan {
 		assert.NoError(t, res.Error)

+		assert.Greater(t, len(res.Choices), 0)
+		if idx == 0 {
+			assert.Equal(t, res.Choices[0].Delta.Role, RoleAssistant)
+		}
+		totalOutput += res.Choices[0].Delta.Content
+		idx++
+
 		if res.Choices[0].FinishReason == FinishReasonStop {
 			break
 		}
+	}
+	assert.Equal(t, totalOutput, "Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded")
+}
+
+func TestChatStreamFunctionCall(t *testing.T) {
+	client := NewMistralClientDefault("")
+	params := DefaultChatRequestParams
+	params.Temperature = 0
+	params.Tools = []Tool{
+		{
+			Type: ToolTypeFunction,
+			Function: Function{
+				Name:        "get_weather",
+				Description: "Retrieve the weather for a city in the US",
+				Parameters: map[string]interface{}{
+					"type":     "object",
+					"required": []string{"city", "state"},
+					"properties": map[string]interface{}{
+						"city":  map[string]interface{}{"type": "string", "description": "Name of the city for the weather"},
+						"state": map[string]interface{}{"type": "string", "description": "Name of the state for the weather"},
+					},
+				},
+			},
+		},
+		{
+			Type: ToolTypeFunction,
+			Function: Function{
+				Name:        "send_text",
+				Description: "Send text message using SMS service",
+				Parameters: map[string]interface{}{
+					"type":     "object",
+					"required": []string{"contact_name", "message"},
+					"properties": map[string]interface{}{
+						"contact_name": map[string]interface{}{"type": "string", "description": "Name of the contact that will receive the message"},
+						"message":      map[string]interface{}{"type": "string", "description": "Content of the message that will be sent"},
+					},
+				},
+			},
+		},
+	}
+	params.ToolChoice = ToolChoiceAuto
+	resChan, err := client.ChatStream(
+		ModelMistralSmallLatest,
+		[]ChatMessage{
+			{
+				Role:    RoleUser,
+				Content: "What's the weather like in Dallas, TX?",
+			},
+		},
+		&params,
+	)
+	assert.NoError(t, err)
+	assert.NotNil(t, resChan)
+
+	totalOutput := ""
+	var functionCall *ToolCall
+	idx := 0
+	for res := range resChan {
+		assert.NoError(t, res.Error)

 		assert.Greater(t, len(res.Choices), 0)
 		if idx == 0 {
 			assert.Equal(t, res.Choices[0].Delta.Role, RoleAssistant)
 		}
 		totalOutput += res.Choices[0].Delta.Content
+		if len(res.Choices[0].Delta.ToolCalls) > 0 {
+			functionCall = &res.Choices[0].Delta.ToolCalls[0]
+		}
 		idx++
+
+		if res.Choices[0].FinishReason == FinishReasonStop {
+			break
+		}
 	}
-	assert.Equal(t, totalOutput, "Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded, Test Succeeded")
+
+	assert.Equal(t, totalOutput, "")
+	assert.NotNil(t, functionCall)
+	assert.Equal(t, functionCall.Function.Name, "get_weather")
+	assert.Equal(t, functionCall.Function.Arguments, "{\"city\": \"Dallas\", \"state\": \"TX\"}")
+}
+
+func TestChatStreamJsonMode(t *testing.T) {
+	client := NewMistralClientDefault("")
+	params := DefaultChatRequestParams
+	params.Temperature = 0
+	params.ResponseFormat = ResponseFormatJsonObject
+	resChan, err := client.ChatStream(
+		ModelMistralSmallLatest,
+		[]ChatMessage{
+			{
+				Role: RoleUser,
+				Content: "Extract all of the code symbols in this text chunk and return them in the following JSON: " +
+					"{\"symbols\":[\"SymbolOne\",\"SymbolTwo\"]}\n```\nI'm working on updating the Go client for the " +
+					"new release, is it expected that the function call will be passed back into the model or just " +
+					"the tool response?\nI ask because ChatMessage can handle the tool response but the messages list " +
+					"has an Any option that I assume would be for the FunctionCall/ToolCall type\nAdditionally the " +
+					"example in the docs only shows the tool response appended to the messages\n```",
+			},
+		},
+		&params,
+	)
+	assert.NoError(t, err)
+	assert.NotNil(t, resChan)
+
+	totalOutput := ""
+	var functionCall *ToolCall
+	idx := 0
+	for res := range resChan {
+		assert.NoError(t, res.Error)
+
+		assert.Greater(t, len(res.Choices), 0)
+		if idx == 0 {
+			assert.Equal(t, res.Choices[0].Delta.Role, RoleAssistant)
+		}
+		totalOutput += res.Choices[0].Delta.Content
+		if len(res.Choices[0].Delta.ToolCalls) > 0 {
+			functionCall = &res.Choices[0].Delta.ToolCalls[0]
+		}
+		idx++
+
+		if res.Choices[0].FinishReason == FinishReasonStop {
+			break
+		}
+	}
+
+	assert.Equal(t, totalOutput, "{\"symbols\": [\"Go\", \"ChatMessage\", \"FunctionCall\", \"ToolCall\"]}")
+	assert.Nil(t, functionCall)
 }
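The stream tests above show the two shapes a caller has to handle: with tool calls the content deltas stay empty and the call surfaces on Delta.ToolCalls, while in JSON mode the concatenated deltas (or Message.Content in the non-streaming case) form a JSON document. Either way the payload is a plain JSON string the caller decodes. A minimal sketch of that last step, assuming only encoding/json and the ToolCall type from this commit; WeatherArgs is a hypothetical struct matching the get_weather schema used in the tests:

// Sketch only: decoding the JSON-encoded arguments of a returned tool call.
// WeatherArgs is hypothetical; it mirrors the get_weather parameter schema above.
type WeatherArgs struct {
	City  string `json:"city"`
	State string `json:"state"`
}

func decodeWeatherArgs(call ToolCall) (WeatherArgs, error) {
	var args WeatherArgs
	// call.Function.Arguments holds a JSON string such as {"city": "Dallas", "state": "TX"}.
	err := json.Unmarshal([]byte(call.Function.Arguments), &args)
	return args, err
}

The same json.Unmarshal step applies to JSON-mode output, with res.Choices[0].Message.Content (or the accumulated stream deltas) as the document to decode.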
