Skip to content

Commit 9e74c89

Browse files
committed
Fix tool schema conversion for Gemini API
The Gemini model was failing to respect 'required' parameter constraints in tool definitions, leading to invalid tool calls. The root cause was that the proxy was sending a standard JSON Schema to the Gemini API, but the API expects a proprietary, non-standard format (uppercase types, no '$schema' or 'additionalProperties' keywords). This caused the API to return 400 errors or silently ignore the schema. This commit refactors the tool schema conversion to use strongly-typed structs that represent the Gemini-specific format. A new conversion function recursively builds the schema, ensuring only supported fields are sent and types are correctly formatted. This also corrects the JSON key for the parameters object from 'parametersJsonSchema' to 'parameters' to match the target API. Additionally, this commit includes several robustness improvements discovered during debugging: - Enable TCP KeepAlive on the upstream HTTP client to prevent idle timeouts. - Add panic-safe checks when parsing the Gemini SSE stream. - Add a keep-alive pinger to the client-facing SSE stream.
1 parent da9c6cb commit 9e74c89

6 files changed

Lines changed: 150 additions & 23 deletions

File tree

internal/gemini/types.go

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,15 @@ type SystemInstruction struct {
2424
Parts []ContentPart `json:"parts,omitempty"`
2525
}
2626

27-
// JSONSchema represents a JSON schema.
28-
type JSONSchema map[string]interface{}
27+
// GeminiParameterSchema defines the proprietary schema format for Gemini function parameters.
28+
type GeminiParameterSchema struct {
29+
Type string `json:"type,omitempty"`
30+
Description string `json:"description,omitempty"`
31+
Properties map[string]*GeminiParameterSchema `json:"properties,omitempty"`
32+
Items *GeminiParameterSchema `json:"items,omitempty"`
33+
Required []string `json:"required,omitempty"`
34+
Enum []string `json:"enum,omitempty"`
35+
}
2936

3037
// FunctionCall represents a tool call emitted by the model.
3138
type FunctionCall struct {
@@ -41,9 +48,9 @@ type FunctionResponse struct {
4148

4249
// FunctionDeclaration defines a function that can be called by the model.
4350
type FunctionDeclaration struct {
44-
Name string `json:"name,omitempty"`
45-
Description string `json:"description,omitempty"`
46-
ParametersJsonSchema JSONSchema `json:"parametersJsonSchema,omitempty"`
51+
Name string `json:"name,omitempty"`
52+
Description string `json:"description,omitempty"`
53+
Parameters *GeminiParameterSchema `json:"parameters,omitempty"`
4754
}
4855

4956
// UnmarshalJSON: accept parametersJsonSchema (camelCase) and parameters (snake_case).
@@ -54,7 +61,7 @@ func (f *FunctionDeclaration) UnmarshalJSON(b []byte) error {
5461
var a alias
5562
if err := json.Unmarshal(b, &a); err == nil {
5663
*f = FunctionDeclaration(a)
57-
if f.ParametersJsonSchema != nil {
64+
if f.Parameters != nil {
5865
return nil
5966
}
6067
}
@@ -70,17 +77,17 @@ func (f *FunctionDeclaration) UnmarshalJSON(b []byte) error {
7077
}
7178
// snake_case key used by generativelanguage public API
7279
if v, ok := raw["parameters"]; ok {
73-
var m map[string]interface{}
74-
if err := json.Unmarshal(v, &m); err == nil {
75-
f.ParametersJsonSchema = m
80+
var schema GeminiParameterSchema
81+
if err := json.Unmarshal(v, &schema); err == nil {
82+
f.Parameters = &schema
7683
return nil
7784
}
7885
}
7986
// or explicit camelCase if present but empty earlier
80-
if v, ok := raw["parametersJsonSchema"]; ok && f.ParametersJsonSchema == nil {
81-
var m map[string]interface{}
82-
if err := json.Unmarshal(v, &m); err == nil {
83-
f.ParametersJsonSchema = m
87+
if v, ok := raw["parametersJsonSchema"]; ok && f.Parameters == nil {
88+
var schema GeminiParameterSchema
89+
if err := json.Unmarshal(v, &schema); err == nil {
90+
f.Parameters = &schema
8491
}
8592
}
8693
return nil

internal/http/http_client_default.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package http
44

55
import (
6+
"net"
67
"net/http"
78
"time"
89
)
@@ -11,6 +12,10 @@ import (
1112
func NewHTTPClient() HTTPClient {
1213
return &http.Client{
1314
Transport: &http.Transport{
15+
DialContext: (&net.Dialer{
16+
Timeout: 30 * time.Second,
17+
KeepAlive: 30 * time.Second,
18+
}).DialContext,
1419
MaxIdleConns: 100,
1520
MaxIdleConnsPerHost: 10,
1621
IdleConnTimeout: 90 * time.Second,

internal/logger/logger.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ func colorize(s interface{}, c int) string {
4444
func newLogger() *zerolog.Logger {
4545
env := os.Getenv("ENV")
4646

47+
// Set log level based on LOG_LEVEL env var, default to info
48+
logLevel := zerolog.InfoLevel
49+
if levelStr := os.Getenv("LOG_LEVEL"); levelStr != "" {
50+
if parsedLevel, err := zerolog.ParseLevel(strings.ToLower(levelStr)); err == nil {
51+
logLevel = parsedLevel
52+
} else {
53+
fmt.Fprintf(os.Stderr, "Invalid LOG_LEVEL \"%s\"; defaulting to 'info'\n", levelStr)
54+
}
55+
}
56+
57+
zerolog.SetGlobalLevel(logLevel)
58+
4759
if env == "development" || env == "dev" || env == "" {
4860
return newDevelopment()
4961
}

internal/openai/stream_transformer.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"time"
77

8+
"github.com/dvcrn/gemini-code-assist-proxy/internal/logger"
89
"github.com/google/uuid"
910
)
1011

@@ -128,6 +129,8 @@ func CreateOpenAIStreamTransformer(model string) func(<-chan StreamChunk) <-chan
128129

129130
// Process each chunk
130131
for chunk := range input {
132+
logger.Get().Debug().Interface("chunk", chunk).Msg("Processing Gemini stream chunk")
133+
131134
delta := OpenAIDelta{}
132135
shouldSend := false
133136

@@ -222,7 +225,9 @@ func CreateOpenAIStreamTransformer(model string) func(<-chan StreamChunk) <-chan
222225
}
223226

224227
if jsonBytes, err := json.Marshal(openAIChunk); err == nil {
225-
output <- fmt.Sprintf("data: %s\n\n", string(jsonBytes))
228+
sse := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
229+
logger.Get().Debug().Str("sse", sse).Msg("Sending OpenAI SSE chunk")
230+
output <- sse
226231
}
227232
}
228233
}

internal/server/chat_completions_handler.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package server
22

33
import (
4+
"context"
45
"encoding/json"
56
"io"
67
"net/http"
@@ -147,6 +148,31 @@ func (s *Server) chatCompletionRequestStream(w http.ResponseWriter, r *http.Requ
147148
logger.Get().Info().
148149
Str("model", gemReq.Model).
149150
Msg("Starting upstream StreamGenerateContent")
151+
152+
// Pinger to keep connection alive
153+
pingerCtx, cancelPinger := context.WithCancel(r.Context())
154+
defer cancelPinger()
155+
156+
go func() {
157+
ticker := time.NewTicker(10 * time.Second)
158+
defer ticker.Stop()
159+
for {
160+
select {
161+
case <-ticker.C:
162+
logger.Get().Debug().Msg("Sending SSE ping to keep connection alive")
163+
if _, err := io.WriteString(w, ": ping\n\n"); err != nil {
164+
logger.Get().Warn().Err(err).Msg("Failed to write SSE ping")
165+
return
166+
}
167+
if flusher, ok := w.(http.Flusher); ok {
168+
flusher.Flush()
169+
}
170+
case <-pingerCtx.Done():
171+
return
172+
}
173+
}
174+
}()
175+
150176
if err := s.geminiClient.StreamGenerateContent(r.Context(), gemReq, upstream); err != nil {
151177
logger.Get().Error().Err(err).Msg("StreamGenerateContent call failed")
152178
http.Error(w, "Upstream streaming error", http.StatusInternalServerError)
@@ -161,6 +187,9 @@ func (s *Server) chatCompletionRequestStream(w http.ResponseWriter, r *http.Requ
161187
firstUpstream := true
162188
firstThoughtSeen := false
163189
for line := range upstream {
190+
if firstUpstream {
191+
cancelPinger() // Stop pinger on first data
192+
}
164193
// Process only data lines
165194
if !strings.HasPrefix(line, "data: ") {
166195
continue
@@ -207,7 +236,11 @@ func (s *Server) chatCompletionRequestStream(w http.ResponseWriter, r *http.Requ
207236
// Extract candidate content parts
208237
if cands, ok := obj["candidates"].([]interface{}); ok {
209238
for _, c := range cands {
210-
cand, _ := c.(map[string]interface{})
239+
cand, ok := c.(map[string]interface{})
240+
if !ok {
241+
logger.Get().Warn().Interface("candidate", c).Msg("Skipping invalid candidate in Gemini stream")
242+
continue
243+
}
211244

212245
// Optional grounding metadata passthrough
213246
if gm, ok := cand["groundingMetadata"]; ok && gm != nil {
@@ -229,7 +262,11 @@ func (s *Server) chatCompletionRequestStream(w http.ResponseWriter, r *http.Requ
229262

230263
// Process parts
231264
for _, p := range parts {
232-
part, _ := p.(map[string]interface{})
265+
part, ok := p.(map[string]interface{})
266+
if !ok {
267+
logger.Get().Warn().Interface("part", p).Msg("Skipping invalid part in Gemini stream")
268+
continue
269+
}
233270

234271
// Thought tokens (reasoning) — map to OpenAI reasoning stream
235272
if isThought, ok := part["thought"].(bool); ok && isThought {

internal/transform/openai_to_gemini.go

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,16 +248,29 @@ func convertToolsToGeminiTools(tools []openai.Tool) []gemini.Tool {
248248
continue
249249
}
250250

251-
var schema gemini.JSONSchema
251+
var geminiSchema *gemini.GeminiParameterSchema
252252
if m, ok := t.Function.Parameters.(map[string]interface{}); ok {
253-
schema = m
253+
geminiSchema = convertToGeminiSchema(m)
254254
}
255255

256-
fns = append(fns, gemini.FunctionDeclaration{
257-
Name: t.Function.Name,
258-
Description: t.Function.Description,
259-
ParametersJsonSchema: schema,
260-
})
256+
convertedFn := gemini.FunctionDeclaration{
257+
Name: t.Function.Name,
258+
Description: t.Function.Description,
259+
Parameters: geminiSchema,
260+
}
261+
262+
// For specific tools, log the before and after transformation for debugging
263+
if t.Function.Name == "Read" || t.Function.Name == "Edit" || t.Function.Name == "MultiEdit" {
264+
originalJSON, _ := json.Marshal(t.Function)
265+
convertedJSON, _ := json.Marshal(convertedFn)
266+
logger.Get().Info().
267+
Str("tool_name", t.Function.Name).
268+
RawJSON("original_schema", originalJSON).
269+
RawJSON("converted_schema", convertedJSON).
270+
Msg("Dumping tool schema conversion from OpenAI to Gemini")
271+
}
272+
273+
fns = append(fns, convertedFn)
261274
}
262275

263276
if len(fns) == 0 {
@@ -268,3 +281,51 @@ func convertToolsToGeminiTools(tools []openai.Tool) []gemini.Tool {
268281
{FunctionDeclarations: fns},
269282
}
270283
}
284+
285+
// convertToGeminiSchema recursively converts a generic map representing a JSON schema
286+
// into the strongly-typed GeminiParameterSchema struct, only mapping supported fields.
287+
func convertToGeminiSchema(input map[string]interface{}) *gemini.GeminiParameterSchema {
288+
if input == nil {
289+
return nil
290+
}
291+
292+
output := &gemini.GeminiParameterSchema{}
293+
294+
if t, ok := input["type"].(string); ok {
295+
output.Type = strings.ToUpper(t)
296+
}
297+
if d, ok := input["description"].(string); ok {
298+
output.Description = d
299+
}
300+
301+
if r, ok := input["required"].([]interface{}); ok {
302+
for _, v := range r {
303+
if s, ok := v.(string); ok {
304+
output.Required = append(output.Required, s)
305+
}
306+
}
307+
}
308+
309+
if e, ok := input["enum"].([]interface{}); ok {
310+
for _, v := range e {
311+
if s, ok := v.(string); ok {
312+
output.Enum = append(output.Enum, s)
313+
}
314+
}
315+
}
316+
317+
if p, ok := input["properties"].(map[string]interface{}); ok {
318+
output.Properties = make(map[string]*gemini.GeminiParameterSchema)
319+
for k, v := range p {
320+
if vMap, ok := v.(map[string]interface{}); ok {
321+
output.Properties[k] = convertToGeminiSchema(vMap)
322+
}
323+
}
324+
}
325+
326+
if i, ok := input["items"].(map[string]interface{}); ok {
327+
output.Items = convertToGeminiSchema(i)
328+
}
329+
330+
return output
331+
}

0 commit comments

Comments
 (0)