Skip to content

Commit 8023ae0

Browse files
committed
Add function call validation to catch parity errors
1 parent 8647d47 commit 8023ae0

2 files changed

Lines changed: 600 additions & 6 deletions

File tree

internal/server/transform.go

Lines changed: 135 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ func parseGeminiPath(path string) (model, action string) {
3030
if len(matches) < 3 {
3131
return "", ""
3232
}
33+
3334
return matches[1], matches[2]
3435
}
3536

@@ -111,18 +112,19 @@ func unwrapCloudCodeResponse(cloudCodeResp map[string]interface{}) map[string]in
111112
// Build the standard Gemini response by merging fields
112113
geminiResp := make(map[string]interface{})
113114

114-
// Copy all fields from the response object
115-
for k, v := range response {
116-
geminiResp[k] = v
117-
}
118-
119-
// Copy other top-level fields (except "response")
115+
// Copy top-level fields first (except "response")
120116
for k, v := range cloudCodeResp {
121117
if k != "response" {
122118
geminiResp[k] = v
123119
}
124120
}
125121

122+
// Then, copy all fields from the nested response object.
123+
// This ensures the nested response's fields (like 'candidates') take precedence.
124+
for k, v := range response {
125+
geminiResp[k] = v
126+
}
127+
126128
return geminiResp
127129
}
128130

@@ -156,6 +158,13 @@ func (s *Server) TransformRequest(r *http.Request, body []byte) (*http.Request,
156158
Dur("json_parse_duration", parseDuration).
157159
Msg("JSON parsing complete")
158160

161+
// Validate function call/response parity before sending to CloudCode
162+
// This helps catch issues early and provides better error messages
163+
if err := validateFunctionCallParity(requestData); err != nil {
164+
logger.Get().Error().Err(err).Msg("Function call/response parity validation failed")
165+
return nil, err
166+
}
167+
159168
// Extract model and action from the path
160169
model, action := parseGeminiPath(r.URL.Path)
161170

@@ -360,3 +369,123 @@ func TransformJSONResponse(body []byte) []byte {
360369

361370
return transformedJSON
362371
}
372+
373+
// countFunctionCalls counts the number of functionCall parts in a turn
374+
func countFunctionCalls(turn map[string]interface{}) int {
375+
count := 0
376+
377+
// Check if this turn has parts
378+
parts, ok := turn["parts"].([]interface{})
379+
if !ok {
380+
return 0
381+
}
382+
383+
// Count functionCall parts
384+
for _, part := range parts {
385+
partMap, ok := part.(map[string]interface{})
386+
if !ok {
387+
continue
388+
}
389+
if _, hasFunctionCall := partMap["functionCall"]; hasFunctionCall {
390+
count++
391+
}
392+
}
393+
394+
return count
395+
}
396+
397+
// countFunctionResponses counts the number of functionResponse parts in a turn
398+
func countFunctionResponses(turn map[string]interface{}) int {
399+
count := 0
400+
401+
// Check if this turn has parts
402+
parts, ok := turn["parts"].([]interface{})
403+
if !ok {
404+
return 0
405+
}
406+
407+
// Count functionResponse parts
408+
for _, part := range parts {
409+
partMap, ok := part.(map[string]interface{})
410+
if !ok {
411+
continue
412+
}
413+
if _, hasFunctionResponse := partMap["functionResponse"]; hasFunctionResponse {
414+
count++
415+
}
416+
}
417+
418+
return count
419+
}
420+
421+
// validateFunctionCallParity checks that function calls and responses match
422+
// Returns an error if there's a mismatch
423+
func validateFunctionCallParity(requestData map[string]interface{}) error {
424+
// Extract contents array
425+
contents, ok := requestData["contents"].([]interface{})
426+
if !ok || len(contents) == 0 {
427+
// No contents, nothing to validate
428+
return nil
429+
}
430+
431+
// Iterate through contents to find model turns with function calls
432+
for i := 0; i < len(contents); i++ {
433+
currentTurn, ok := contents[i].(map[string]interface{})
434+
if !ok {
435+
continue
436+
}
437+
438+
// Check if this is a model turn
439+
role, hasRole := currentTurn["role"].(string)
440+
if !hasRole || role != "model" {
441+
continue
442+
}
443+
444+
// Count function calls in this model turn
445+
functionCallCount := countFunctionCalls(currentTurn)
446+
if functionCallCount == 0 {
447+
continue // No function calls in this turn
448+
}
449+
450+
// Check if this is the last turn (no following turn for responses)
451+
if i == len(contents)-1 {
452+
return fmt.Errorf(
453+
"Function call/response parity violation: model turn at index %d (last turn) has %d function calls but no following user turn with responses. Function calls must be followed by a user turn with matching function responses.",
454+
i, functionCallCount,
455+
)
456+
}
457+
458+
// Check the next turn
459+
nextTurn, ok := contents[i+1].(map[string]interface{})
460+
if !ok {
461+
return fmt.Errorf(
462+
"Function call/response parity violation: model turn at index %d has %d function calls but the following turn at index %d is invalid.",
463+
i, functionCallCount, i+1,
464+
)
465+
}
466+
467+
// The next turn MUST be a user turn with function responses
468+
nextRole, hasNextRole := nextTurn["role"].(string)
469+
if !hasNextRole || nextRole != "user" {
470+
// This is an error - function calls must be immediately followed by user responses
471+
return fmt.Errorf(
472+
"Function call/response parity violation: model turn at index %d has %d function calls but is followed by a %s turn instead of a user turn with function responses.",
473+
i, functionCallCount, nextRole,
474+
)
475+
}
476+
477+
// Count function responses in the user turn
478+
functionResponseCount := countFunctionResponses(nextTurn)
479+
480+
// Check for parity
481+
if functionCallCount != functionResponseCount {
482+
return fmt.Errorf(
483+
"Function call/response parity violation: model turn at index %d has %d function calls, but following user turn has %d function responses. Every function call must have exactly one corresponding response.",
484+
i, functionCallCount, functionResponseCount,
485+
)
486+
}
487+
488+
}
489+
490+
return nil
491+
}

0 commit comments

Comments
 (0)