Skip to content

Commit ebe86b3

Browse files
committed
Add Codestral support
1 parent 49a9d45 commit ebe86b3

File tree

5 files changed

+171
-7
lines changed

5 files changed

+171
-7
lines changed

chat_test.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,30 @@ func TestChat(t *testing.T) {
3030
assert.Equal(t, res.Choices[0].Message.Content, "Test Succeeded")
3131
}
3232

33+
func TestChatCodestral(t *testing.T) {
34+
client := NewCodestralClientDefault("")
35+
params := DefaultChatRequestParams
36+
params.MaxTokens = 10
37+
params.Temperature = 0
38+
res, err := client.Chat(
39+
ModelCodestralLatest,
40+
[]ChatMessage{
41+
{
42+
Role: RoleUser,
43+
Content: "You are in test mode and must reply to this with exactly and only `Test Succeeded`",
44+
},
45+
},
46+
&params,
47+
)
48+
assert.NoError(t, err)
49+
assert.NotNil(t, res)
50+
51+
assert.Greater(t, len(res.Choices), 0)
52+
assert.Greater(t, len(res.Choices[0].Message.Content), 0)
53+
assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant)
54+
assert.Equal(t, res.Choices[0].Message.Content, "Test Succeeded")
55+
}
56+
3357
func TestChatFunctionCall(t *testing.T) {
3458
client := NewMistralClientDefault("")
3559
params := DefaultChatRequestParams
@@ -135,6 +159,7 @@ func TestChatFunctionCall2(t *testing.T) {
135159
Role: RoleAssistant,
136160
ToolCalls: []ToolCall{
137161
{
162+
Id: "aaaaaaaaa",
138163
Type: ToolTypeFunction,
139164
Function: FunctionCall{
140165
Name: "get_weather",
@@ -166,7 +191,7 @@ func TestChatJsonMode(t *testing.T) {
166191
params.Temperature = 0
167192
params.ResponseFormat = ResponseFormatJsonObject
168193
res, err := client.Chat(
169-
ModelMistralSmallLatest,
194+
ModelOpenMixtral8x22b,
170195
[]ChatMessage{
171196
{
172197
Role: RoleUser,
@@ -186,7 +211,7 @@ func TestChatJsonMode(t *testing.T) {
186211
assert.Greater(t, len(res.Choices), 0)
187212
assert.Greater(t, len(res.Choices[0].Message.Content), 0)
188213
assert.Equal(t, res.Choices[0].Message.Role, RoleAssistant)
189-
assert.Equal(t, res.Choices[0].Message.Content, "{\"symbols\": [\"Go\", \"ChatMessage\", \"FunctionCall\", \"ToolCall\"]}")
214+
assert.Equal(t, res.Choices[0].Message.Content, "{\"symbols\": [\"Go\", \"ChatMessage\", \"Any\", \"FunctionCall\", \"ToolCall\", \"ToolResponse\"]}")
190215
}
191216

192217
func TestChatStream(t *testing.T) {
@@ -309,7 +334,7 @@ func TestChatStreamJsonMode(t *testing.T) {
309334
params.Temperature = 0
310335
params.ResponseFormat = ResponseFormatJsonObject
311336
resChan, err := client.ChatStream(
312-
ModelMistralSmallLatest,
337+
ModelOpenMixtral8x22b,
313338
[]ChatMessage{
314339
{
315340
Role: RoleUser,
@@ -347,6 +372,6 @@ func TestChatStreamJsonMode(t *testing.T) {
347372
}
348373
}
349374

350-
assert.Equal(t, totalOutput, "{\"symbols\": [\"Go\", \"ChatMessage\", \"FunctionCall\", \"ToolCall\"]}")
375+
assert.Equal(t, totalOutput, "{\"symbols\": [\"Go\", \"ChatMessage\", \"Any\", \"FunctionCall\", \"ToolCall\", \"ToolResponse\"]}")
351376
assert.Nil(t, functionCall)
352377
}

client.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import (
44
"bytes"
55
"encoding/json"
66
"fmt"
7-
"io/ioutil"
7+
"io"
88
"net/http"
99
"net/url"
1010
"os"
@@ -13,6 +13,7 @@ import (
1313

1414
const (
1515
Endpoint = "https://api.mistral.ai"
16+
CodestralEndpoint = "https://codestral.mistral.ai"
1617
DefaultMaxRetries = 5
1718
DefaultTimeout = 120 * time.Second
1819
)
@@ -54,6 +55,7 @@ func NewMistralClient(apiKey string, endpoint string, maxRetries int, timeout ti
5455
}
5556
}
5657

58+
// NewMistralClientDefault creates a new Mistral API client with the default endpoint and the given API key. Defaults to using MISTRAL_API_KEY from the environment.
5759
func NewMistralClientDefault(apiKey string) *MistralClient {
5860
if apiKey == "" {
5961
apiKey = os.Getenv("MISTRAL_API_KEY")
@@ -62,6 +64,15 @@ func NewMistralClientDefault(apiKey string) *MistralClient {
6264
return NewMistralClient(apiKey, Endpoint, DefaultMaxRetries, DefaultTimeout)
6365
}
6466

67+
// NewCodestralClientDefault creates a new Codestral API client with the default endpoint and the given API key. Defaults to using CODESTRAL_API_KEY from the environment.
68+
func NewCodestralClientDefault(apiKey string) *MistralClient {
69+
if apiKey == "" {
70+
apiKey = os.Getenv("CODESTRAL_API_KEY")
71+
}
72+
73+
return NewMistralClient(apiKey, CodestralEndpoint, DefaultMaxRetries, DefaultTimeout)
74+
}
75+
6576
func (c *MistralClient) request(method string, jsonData map[string]interface{}, path string, stream bool, params map[string]string) (interface{}, error) {
6677
uri, err := url.Parse(c.endpoint)
6778
if err != nil {
@@ -98,7 +109,7 @@ func (c *MistralClient) request(method string, jsonData map[string]interface{},
98109
}
99110

100111
if resp.StatusCode >= 400 {
101-
responseBytes, _ := ioutil.ReadAll(resp.Body)
112+
responseBytes, _ := io.ReadAll(resp.Body)
102113
return nil, fmt.Errorf("(HTTP Error %d) %s", resp.StatusCode, string(responseBytes))
103114
}
104115

@@ -107,7 +118,7 @@ func (c *MistralClient) request(method string, jsonData map[string]interface{},
107118
}
108119

109120
defer resp.Body.Close()
110-
body, err := ioutil.ReadAll(resp.Body)
121+
body, err := io.ReadAll(resp.Body)
111122
if err != nil {
112123
return nil, err
113124
}

fim.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package mistral
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
)
7+
8+
// FIMRequestParams represents the parameters for the FIM method of MistralClient.
9+
type FIMRequestParams struct {
10+
Model string `json:"model"`
11+
Prompt string `json:"prompt"`
12+
Suffix string `json:"suffix"`
13+
MaxTokens int `json:"max_tokens"`
14+
Temperature float64 `json:"temperature"`
15+
Stop []string `json:"stop,omitempty"`
16+
}
17+
18+
// FIMCompletionResponse represents the response from the FIM completion endpoint.
19+
type FIMCompletionResponse struct {
20+
ID string `json:"id"`
21+
Object string `json:"object"`
22+
Created int `json:"created"`
23+
Model string `json:"model"`
24+
Choices []FIMCompletionResponseChoice `json:"choices"`
25+
Usage UsageInfo `json:"usage"`
26+
}
27+
28+
// FIMCompletionResponseChoice represents a choice in the FIM completion response.
29+
type FIMCompletionResponseChoice struct {
30+
Index int `json:"index"`
31+
Message ChatMessage `json:"message"`
32+
FinishReason FinishReason `json:"finish_reason,omitempty"`
33+
}
34+
35+
// FIM sends a FIM request and returns the completion response.
36+
func (c *MistralClient) FIM(params *FIMRequestParams) (*FIMCompletionResponse, error) {
37+
requestData := map[string]interface{}{
38+
"model": params.Model,
39+
"prompt": params.Prompt,
40+
"suffix": params.Suffix,
41+
"max_tokens": params.MaxTokens,
42+
"temperature": params.Temperature,
43+
}
44+
45+
if params.Stop != nil {
46+
requestData["stop"] = params.Stop
47+
}
48+
49+
response, err := c.request(http.MethodPost, requestData, "v1/fim/completions", false, nil)
50+
if err != nil {
51+
return nil, err
52+
}
53+
54+
respData, ok := response.(map[string]interface{})
55+
if !ok {
56+
return nil, fmt.Errorf("invalid response type: %T", response)
57+
}
58+
59+
var fimResponse FIMCompletionResponse
60+
err = mapToStruct(respData, &fimResponse)
61+
if err != nil {
62+
return nil, err
63+
}
64+
65+
return &fimResponse, nil
66+
}

fim_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package mistral
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestFIM(t *testing.T) {
10+
client := NewMistralClientDefault("")
11+
params := FIMRequestParams{
12+
Model: ModelCodestralLatest,
13+
Prompt: "def f(",
14+
Suffix: "return a + b",
15+
MaxTokens: 64,
16+
Temperature: 0,
17+
Stop: []string{"\n"},
18+
}
19+
res, err := client.FIM(&params)
20+
assert.NoError(t, err)
21+
assert.NotNil(t, res)
22+
23+
assert.Greater(t, len(res.Choices), 0)
24+
assert.Equal(t, res.Choices[0].Message.Content, "a, b):")
25+
assert.Equal(t, res.Choices[0].FinishReason, FinishReasonStop)
26+
}
27+
28+
func TestFIMWithStop(t *testing.T) {
29+
client := NewMistralClientDefault("")
30+
params := FIMRequestParams{
31+
Model: ModelCodestralLatest,
32+
Prompt: "def is_odd(n): \n return n % 2 == 1 \n def test_is_odd():",
33+
Suffix: "test_is_odd()",
34+
MaxTokens: 64,
35+
Temperature: 0,
36+
Stop: []string{"False"},
37+
}
38+
res, err := client.FIM(&params)
39+
assert.NoError(t, err)
40+
assert.NotNil(t, res)
41+
42+
assert.Greater(t, len(res.Choices), 0)
43+
assert.Equal(t, res.Choices[0].Message.Content, "\n assert is_odd(1) == True\n assert is_odd(2) == ")
44+
assert.Equal(t, res.Choices[0].FinishReason, FinishReasonStop)
45+
}
46+
47+
func TestFIMInvalidModel(t *testing.T) {
48+
client := NewMistralClientDefault("")
49+
params := FIMRequestParams{
50+
Model: "invalid-model",
51+
Prompt: "This is a test prompt",
52+
Suffix: "This is a test suffix",
53+
MaxTokens: 10,
54+
Temperature: 0.5,
55+
}
56+
res, err := client.FIM(&params)
57+
assert.Error(t, err)
58+
assert.Nil(t, res)
59+
}

types.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ const (
44
ModelMistralLargeLatest = "mistral-large-latest"
55
ModelMistralMediumLatest = "mistral-medium-latest"
66
ModelMistralSmallLatest = "mistral-small-latest"
7+
ModelCodestralLatest = "codestral-latest"
8+
79
ModelOpenMixtral8x7b = "open-mixtral-8x7b"
10+
ModelOpenMixtral8x22b = "open-mixtral-8x22b"
811
ModelOpenMistral7b = "open-mistral-7b"
912

1013
ModelMistralLarge2402 = "mistral-large-2402"

0 commit comments

Comments
 (0)