Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import (
"net/http"
)

const (
DefaultChatCompletion = "v1/chat/completions"
DefaultAgentChatCompletion = "v1/agents/completions"
)

// ChatRequestParams represents the parameters for the Chat/ChatStream method of MistralClient.
type ChatRequestParams struct {
Temperature float64 `json:"temperature"` // The temperature to use for sampling. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or TopP but not both.
Expand All @@ -19,6 +24,8 @@ type ChatRequestParams struct {
Tools []Tool `json:"tools"`
ToolChoice string `json:"tool_choice"`
ResponseFormat ResponseFormat `json:"response_format"`

AgentId string `json:"agent_id"` // If `agent_id` is provided, `model` and `temperature` will be overridden.
}

var DefaultChatRequestParams = ChatRequestParams{
Expand Down Expand Up @@ -76,15 +83,22 @@ func (c *MistralClient) Chat(model string, messages []ChatMessage, params *ChatR
params = &DefaultChatRequestParams
}

chatCompletion := DefaultChatCompletion

requestData := map[string]interface{}{
"model": model,
"messages": messages,
"temperature": params.Temperature,
"max_tokens": params.MaxTokens,
"top_p": params.TopP,
"random_seed": params.RandomSeed,
"safe_prompt": params.SafePrompt,
"messages": messages,
}
if params.AgentId != "" {
chatCompletion = DefaultAgentChatCompletion
requestData["agent_id"] = params.AgentId
} else {
requestData["model"] = model
requestData["temperature"] = params.Temperature
}
requestData["max_tokens"] = params.MaxTokens
requestData["top_p"] = params.TopP
requestData["random_seed"] = params.RandomSeed
requestData["safe_prompt"] = params.SafePrompt

if params.Tools != nil {
requestData["tools"] = params.Tools
Expand All @@ -96,7 +110,7 @@ func (c *MistralClient) Chat(model string, messages []ChatMessage, params *ChatR
requestData["response_format"] = map[string]any{"type": params.ResponseFormat}
}

response, err := c.request(http.MethodPost, requestData, "v1/chat/completions", false, nil)
response, err := c.request(http.MethodPost, requestData, chatCompletion, false, nil)
if err != nil {
return nil, err
}
Expand All @@ -121,18 +135,25 @@ func (c *MistralClient) ChatStream(model string, messages []ChatMessage, params
params = &DefaultChatRequestParams
}

chatCompletion := DefaultChatCompletion

responseChannel := make(chan ChatCompletionStreamResponse)

requestData := map[string]interface{}{
"model": model,
"messages": messages,
"temperature": params.Temperature,
"max_tokens": params.MaxTokens,
"top_p": params.TopP,
"random_seed": params.RandomSeed,
"safe_prompt": params.SafePrompt,
"stream": true,
"messages": messages,
}
if params.AgentId != "" {
chatCompletion = DefaultAgentChatCompletion
requestData["agent_id"] = params.AgentId
} else {
requestData["model"] = model
requestData["temperature"] = params.Temperature
}
requestData["max_tokens"] = params.MaxTokens
requestData["top_p"] = params.TopP
requestData["random_seed"] = params.RandomSeed
requestData["safe_prompt"] = params.SafePrompt
requestData["stream"] = true

if params.Tools != nil {
requestData["tools"] = params.Tools
Expand All @@ -144,7 +165,7 @@ func (c *MistralClient) ChatStream(model string, messages []ChatMessage, params
requestData["response_format"] = map[string]any{"type": params.ResponseFormat}
}

response, err := c.request(http.MethodPost, requestData, "v1/chat/completions", true, nil)
response, err := c.request(http.MethodPost, requestData, chatCompletion, true, nil)
if err != nil {
return nil, err
}
Expand Down
23 changes: 23 additions & 0 deletions examples/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,27 @@ func main() {
}

log.Printf("Embeddings response: %+v\n", embsRes)

// Example: Using Agent Chat Completions
agentReqParam := mistral.DefaultChatRequestParams
agentReqParam.AgentId = "your-agent-id"
agentRes, err := client.Chat("mistral-tiny", []mistral.ChatMessage{{Content: "Hello, world!", Role: mistral.RoleUser}}, &agentReqParam)
if err != nil {
log.Fatalf("Error getting chat completion: %v", err)
}
log.Printf("Agent chat completion: %+v\n", agentRes)

// Example: Using Agent Chat Completions Stream
agentStreamReqParam := mistral.DefaultChatRequestParams
agentStreamReqParam.AgentId = "your-agent-id"
agentResChan, err := client.ChatStream("mistral-tiny", []mistral.ChatMessage{{Content: "Hello, world!", Role: mistral.RoleUser}}, &agentStreamReqParam)
if err != nil {
log.Fatalf("Error getting chat completion stream: %v", err)
}
for agentResChunk := range agentResChan {
if agentResChunk.Error != nil {
log.Fatalf("Error while streaming response: %v", agentResChunk.Error)
}
log.Printf("Agent chat completion stream part: %+v\n", agentResChunk)
}
}