Skip to content

Commit 6a6788e

Browse files
authored
Add optional --temperature parameter for spice chat CLI command (spiceai#5429)
1 parent ee84b63 commit 6a6788e

1 file changed

Lines changed: 49 additions & 6 deletions

File tree

bin/spice/cmd/chat.go

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ const (
4040
modelKeyFlag = "model"
4141
httpEndpointKeyFlag = "http-endpoint"
4242
userAgentKeyFlag = "user-agent"
43+
temperatureFlag = "temperature"
4344
)
4445

4546
type Message struct {
@@ -52,6 +53,38 @@ type ChatRequestBody struct {
5253
Model string `json:"model"`
5354
Stream bool `json:"stream"`
5455
StreamOptions *StreamOptions `json:"stream_options"`
56+
ChatRequestOptions
57+
}
58+
59+
// ChatRequestOptions contains all optional fields for chat requests
60+
type ChatRequestOptions struct {
61+
Temperature *float32 `json:"temperature,omitempty"`
62+
}
63+
64+
func NewChatRequestBody(messages []Message, model string, stream bool, streamOptions *StreamOptions) *ChatRequestBody {
65+
return &ChatRequestBody{
66+
Messages: messages,
67+
Model: model,
68+
Stream: stream,
69+
StreamOptions: streamOptions,
70+
}
71+
}
72+
73+
func ApplyChatOptions(body *ChatRequestBody, cmd *cobra.Command) (*ChatRequestBody, error) {
74+
if cmd.Flags().Changed("temperature") {
75+
temperature, err := cmd.Flags().GetFloat32("temperature")
76+
if err != nil {
77+
slog.Error("could not get temperature flag", "error", err)
78+
os.Exit(1)
79+
}
80+
if temperature < 0 {
81+
slog.Error("temperature must be greater than or equal to 0")
82+
os.Exit(1)
83+
}
84+
body.Temperature = &temperature
85+
}
86+
87+
return body, nil
5588
}
5689

5790
type StreamOptions struct {
@@ -120,6 +153,16 @@ spice chat --model <model> --cloud
120153
rtcontext.SetApiKey(apiKey)
121154
}
122155

156+
temperature, err := cmd.Flags().GetFloat32("temperature")
157+
if err != nil {
158+
slog.Error("could not get temperature flag", "error", err)
159+
os.Exit(1)
160+
}
161+
if temperature < 0 {
162+
slog.Error("temperature must be greater than or equal to 0")
163+
os.Exit(1)
164+
}
165+
123166
userAgent, _ := cmd.Flags().GetString(userAgentKeyFlag)
124167
if userAgent != "" {
125168
rtcontext.SetUserAgent(userAgent)
@@ -215,12 +258,11 @@ spice chat --model <model> --cloud
215258
util.ShowSpinner(done)
216259
}()
217260

218-
body := &ChatRequestBody{
219-
Messages: messages,
220-
Model: model,
221-
Stream: true,
222-
StreamOptions: &StreamOptions{IncludeUsage: true},
223-
}
261+
body := NewChatRequestBody(messages, model, true, &StreamOptions{
262+
IncludeUsage: true,
263+
})
264+
body, _ = ApplyChatOptions(body, cmd)
265+
224266
var timeAtCompletion time.Time
225267
var timeAtFirstToken time.Time
226268
startTime := time.Now()
@@ -382,6 +424,7 @@ func init() {
382424
chatCmd.Flags().String(modelKeyFlag, "", "Model to chat with")
383425
chatCmd.Flags().String(httpEndpointKeyFlag, "", "HTTP endpoint for chat (default: http://localhost:8090)")
384426
chatCmd.Flags().String(userAgentKeyFlag, "", "User agent to use in all requests")
427+
chatCmd.Flags().Float32(temperatureFlag, 1, "Model temperature for chat request")
385428

386429
RootCmd.AddCommand(chatCmd)
387430
}

0 commit comments

Comments
 (0)