Skip to content

Commit 7858491

Browse files
committed
pkg/aflow: fix Temperature handling
If LLMAgent.Temperature is assigned an untyped float constant (e.g. 0.5), it will be typed as float64 rather than float32, so recast it accordingly. Also cap Temperature at the model's supported MaxTemperature.
1 parent bbddaf1 commit 7858491

File tree

4 files changed

+35
-16
lines changed

4 files changed

+35
-16
lines changed

pkg/aflow/execute.go

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func (flow *Flow) Execute(ctx context.Context, model, workdir string, inputs map
3838
state: maps.Clone(inputs),
3939
onEvent: onEvent,
4040
}
41+
4142
defer c.close()
4243
if s := ctx.Value(stubContextKey); s != nil {
4344
c.stubContext = *s.(*stubContext)
@@ -143,10 +144,17 @@ var (
143144
createClientOnce sync.Once
144145
createClientErr error
145146
client *genai.Client
146-
modelList = make(map[string]bool)
147+
modelList = make(map[string]*modelInfo)
147148
stubContextKey = contextKeyType(1)
148149
)
149150

151+
type modelInfo struct {
152+
Thinking bool
153+
MaxTemperature float32
154+
InputTokenLimit int
155+
OutputTokenLimit int
156+
}
157+
150158
func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateContentConfig,
151159
req []*genai.Content) (*genai.GenerateContentResponse, error) {
152160
const modelPrefix = "models/"
@@ -165,19 +173,30 @@ func (ctx *Context) generateContentGemini(model string, cfg *genai.GenerateConte
165173
createClientErr = err
166174
return
167175
}
168-
modelList[strings.TrimPrefix(m.Name, modelPrefix)] = m.Thinking
176+
if !slices.Contains(m.SupportedActions, "generateContent") ||
177+
strings.Contains(m.Name, "-image") ||
178+
strings.Contains(m.Name, "-audio") {
179+
continue
180+
}
181+
modelList[strings.TrimPrefix(m.Name, modelPrefix)] = &modelInfo{
182+
Thinking: m.Thinking,
183+
MaxTemperature: m.MaxTemperature,
184+
InputTokenLimit: int(m.InputTokenLimit),
185+
OutputTokenLimit: int(m.OutputTokenLimit),
186+
}
169187
}
170188
})
171189
if createClientErr != nil {
172190
return nil, createClientErr
173191
}
174-
thinking, ok := modelList[model]
175-
if !ok {
192+
info := modelList[model]
193+
if info == nil {
176194
models := slices.Collect(maps.Keys(modelList))
177195
slices.Sort(models)
178196
return nil, fmt.Errorf("model %q does not exist (models: %v)", model, models)
179197
}
180-
if thinking {
198+
*cfg.Temperature = min(*cfg.Temperature, info.MaxTemperature)
199+
if info.Thinking {
181200
// Don't alter the original object (that may affect request caching).
182201
cfgCopy := *cfg
183202
cfg = &cfgCopy

pkg/aflow/flow_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ func TestToolMisbehavior(t *testing.T) {
321321
&LLMAgent{
322322
Name: "smarty",
323323
Model: "model",
324-
Temperature: 1,
324+
Temperature: 0.5,
325325
Reply: "Reply",
326326

327327
Outputs: LLMOutputs[struct {

pkg/aflow/llm_agent.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ type LLMAgent struct {
3535
// Value that controls the degree of randomness in token selection.
3636
// Lower temperatures are good for prompts that require a less open-ended or creative response,
3737
// while higher temperatures can lead to more diverse or creative results.
38-
// Must be assigned a float32 value in the range [0, 2].
38+
// Must be assigned a number in the range [0, 2].
3939
Temperature any
4040
// If set, the agent will generate that many candidates and the outputs will be arrays
4141
// instead of scalars.
@@ -245,7 +245,7 @@ func (a *LLMAgent) config(ctx *Context) (*genai.GenerateContentConfig, string, m
245245
}
246246
return &genai.GenerateContentConfig{
247247
ResponseModalities: []string{"TEXT"},
248-
Temperature: genai.Ptr(a.Temperature.(float32)),
248+
Temperature: genai.Ptr(float32(a.Temperature.(float64))),
249249
SystemInstruction: genai.NewContentFromText(instruction, genai.RoleUser),
250250
Tools: tools,
251251
}, instruction, toolMap
@@ -399,10 +399,10 @@ func (a *LLMAgent) verify(ctx *verifyContext) {
399399
ctx.requireNotEmpty(a.Name, "Model", a.Model)
400400
ctx.requireNotEmpty(a.Name, "Reply", a.Reply)
401401
if temp, ok := a.Temperature.(int); ok {
402-
a.Temperature = float32(temp)
402+
a.Temperature = float64(temp)
403403
}
404-
if temp, ok := a.Temperature.(float32); !ok || temp < 0 || temp > 2 {
405-
ctx.errorf(a.Name, "Temperature must have a float32 value in the range [0, 2]")
404+
if temp, ok := a.Temperature.(float64); !ok || temp < 0 || temp > 2 {
405+
ctx.errorf(a.Name, "Temperature must be a number in the range [0, 2]")
406406
}
407407
if a.Candidates < 0 || a.Candidates > 100 {
408408
ctx.errorf(a.Name, "Candidates must be in the range [0, 100]")

pkg/aflow/testdata/TestToolMisbehavior.llm.json

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
],
1111
"role": "user"
1212
},
13-
"temperature": 1,
13+
"temperature": 0.5,
1414
"tools": [
1515
{
1616
"functionDeclarations": [
@@ -132,7 +132,7 @@
132132
],
133133
"role": "user"
134134
},
135-
"temperature": 1,
135+
"temperature": 0.5,
136136
"tools": [
137137
{
138138
"functionDeclarations": [
@@ -367,7 +367,7 @@
367367
],
368368
"role": "user"
369369
},
370-
"temperature": 1,
370+
"temperature": 0.5,
371371
"tools": [
372372
{
373373
"functionDeclarations": [
@@ -618,7 +618,7 @@
618618
],
619619
"role": "user"
620620
},
621-
"temperature": 1,
621+
"temperature": 0.5,
622622
"tools": [
623623
{
624624
"functionDeclarations": [
@@ -915,7 +915,7 @@
915915
],
916916
"role": "user"
917917
},
918-
"temperature": 1,
918+
"temperature": 0.5,
919919
"tools": [
920920
{
921921
"functionDeclarations": [

0 commit comments

Comments (0)