diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go index 70a873e40c..e009c56c51 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/gemini.go @@ -443,7 +443,13 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) if request.Tools != nil { functions := make([]function, 0, len(request.Tools)) for _, tool := range request.Tools { - functions = append(functions, tool.Function) + // 清理 function parameters 中不支持的 JSON Schema 字段 + cleanedFunc := function{ + Name: tool.Function.Name, + Description: tool.Function.Description, + Parameters: cleanFunctionParameters(tool.Function.Parameters), + } + functions = append(functions, cleanedFunc) } geminiRequest.Tools = []geminiChatTools{ { diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model.go b/plugins/wasm-go/extensions/ai-proxy/provider/model.go index 3a11a279bc..0021b604c1 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/model.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model.go @@ -130,6 +130,68 @@ type function struct { Parameters map[string]interface{} `json:"parameters,omitempty"` } +// cleanFunctionParameters 清理 function parameters 中某些 AI 服务(如 Vertex AI、Gemini)不支持的 JSON Schema 字段 +// 这些服务基于 OpenAPI 3.0 规范,不支持标准 JSON Schema 的元数据字段如 $schema, $id, $ref 等 +// 注意:某些客户端可能使用不带 $ 前缀的变体(如 ref 代替 $ref),也需要一并清理 +func cleanFunctionParameters(params map[string]interface{}) map[string]interface{} { + if params == nil { + return nil + } + + // 需要移除的 JSON Schema 元数据字段 + // 包括带 $ 前缀的标准字段和不带 $ 前缀的变体 + unsupportedKeys := []string{ + // 标准 JSON Schema 元数据字段 + "$schema", + "$id", + "$ref", + "$defs", + "definitions", + "$comment", + "$vocabulary", + "$anchor", + "$dynamicRef", + "$dynamicAnchor", + // 不带 $ 前缀的变体(某些客户端可能使用) + "ref", + } + + result := make(map[string]interface{}) + for key, value := range params { + // 检查是否是不支持的字段 + isUnsupported := false + for _, unsupportedKey := range unsupportedKeys { + if key == unsupportedKey { + isUnsupported = true + break + } + } + if isUnsupported { + continue + } + + // 递归清理嵌套的 map + switch v := value.(type) { + case map[string]interface{}: + result[key] = cleanFunctionParameters(v) + case []interface{}: + // 处理数组中的 map 元素 + cleanedArray := make([]interface{}, len(v)) + for i, item := range v { + if itemMap, ok := item.(map[string]interface{}); ok { + cleanedArray[i] = cleanFunctionParameters(itemMap) + } else { + cleanedArray[i] = item + } + } + result[key] = cleanedArray + default: + result[key] = value + } + } + return result +} + type toolChoice struct { Type string `json:"type"` Function function `json:"function"` diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/model_test.go b/plugins/wasm-go/extensions/ai-proxy/provider/model_test.go new file mode 100644 index 0000000000..ae18ee4531 --- /dev/null +++ b/plugins/wasm-go/extensions/ai-proxy/provider/model_test.go @@ -0,0 +1,265 @@ +package provider + +import ( + "reflect" + "testing" +) + +func TestCleanFunctionParameters(t *testing.T) { + tests := []struct { + name string + input map[string]interface{} + expected map[string]interface{} + }{ + { + name: "nil input", + input: nil, + expected: nil, + }, + { + name: "empty map", + input: map[string]interface{}{}, + expected: map[string]interface{}{}, + }, + { + name: "remove $schema at root level", + input: map[string]interface{}{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state", + }, + }, + }, + expected: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state", + }, + }, + }, + }, + { + name: "remove multiple unsupported fields", + input: map[string]interface{}{ + "$schema": "http://json-schema.org/draft-07/schema#", + "$id": "https://example.com/schema", + "$comment": "This is a comment", + "definitions": map[string]interface{}{}, + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + }, + }, + }, + expected: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + { + name: "nested $schema in properties", + input: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "nested": map[string]interface{}{ + "$schema": "should be removed", + "type": "object", + "properties": map[string]interface{}{ + "field": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + }, + expected: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "nested": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "field": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + }, + }, + { + name: "array with map elements", + input: map[string]interface{}{ + "type": "array", + "items": []interface{}{ + map[string]interface{}{ + "$schema": "should be removed", + "type": "string", + }, + map[string]interface{}{ + "type": "number", + }, + }, + }, + expected: map[string]interface{}{ + "type": "array", + "items": []interface{}{ + map[string]interface{}{ + "type": "string", + }, + map[string]interface{}{ + "type": "number", + }, + }, + }, + }, + { + name: "preserve valid fields", + input: map[string]interface{}{ + "type": "object", + "description": "A valid description", + "properties": map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city", + "enum": []interface{}{"NYC", "LA", "SF"}, + }, + }, + "required": []interface{}{"location"}, + }, + expected: map[string]interface{}{ + "type": "object", + "description": "A valid description", + "properties": map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city", + "enum": []interface{}{"NYC", "LA", "SF"}, + }, + }, + "required": []interface{}{"location"}, + }, + }, + { + name: "remove $defs field", + input: map[string]interface{}{ + "$defs": map[string]interface{}{ + "Address": map[string]interface{}{ + "type": "object", + }, + }, + "type": "object", + }, + expected: map[string]interface{}{ + "type": "object", + }, + }, + { + name: "remove ref field without dollar sign", + input: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "options": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "ref": "QuestionOption", + "type": "object", + "properties": map[string]interface{}{ + "label": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + }, + }, + expected: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "options": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "label": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + }, + }, + }, + { + name: "real world question tool schema", + input: map[string]interface{}{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": map[string]interface{}{ + "questions": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "options": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "ref": "QuestionOption", + "type": "object", + "properties": map[string]interface{}{ + "label": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + }, + }, + }, + }, + }, + expected: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "questions": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "options": map[string]interface{}{ + "type": "array", + "items": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "label": map[string]interface{}{ + "type": "string", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := cleanFunctionParameters(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("cleanFunctionParameters() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index f7cebccf72..8a312098b5 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -279,6 +279,48 @@ type TransformResponseBodyHandler interface { TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) } +// RedisConfig Redis配置结构体 +type RedisConfig struct { + // @Title zh-CN Redis服务名称 + // @Description zh-CN Redis服务的FQDN,如 redis.static、redis.my-ns.svc.cluster.local + ServiceName string `yaml:"serviceName" json:"serviceName"` + // @Title zh-CN Redis服务端口 + // @Description zh-CN Redis服务端口,static服务默认80,其他默认6379 + ServicePort int `yaml:"servicePort" json:"servicePort"` + // @Title zh-CN Redis用户名 + // @Description zh-CN Redis认证用户名(可选) + Username string `yaml:"username" json:"username"` + // @Title zh-CN Redis密码 + // @Description zh-CN Redis认证密码(可选) + Password string `yaml:"password" json:"password"` + // @Title zh-CN 连接超时时间 + // @Description zh-CN Redis连接超时时间,单位毫秒,默认1000 + Timeout int `yaml:"timeout" json:"timeout"` + // @Title zh-CN 数据库ID + // @Description zh-CN Redis数据库ID,默认0 + Database int `yaml:"database" json:"database"` +} + +// FromJson 从JSON解析Redis配置 +func (r *RedisConfig) FromJson(json gjson.Result) { + r.ServiceName = json.Get("serviceName").String() + r.ServicePort = int(json.Get("servicePort").Int()) + if r.ServicePort == 0 { + if strings.HasSuffix(r.ServiceName, ".static") { + r.ServicePort = 80 + } else { + r.ServicePort = 6379 + } + } + r.Username = json.Get("username").String() + r.Password = json.Get("password").String() + r.Timeout = int(json.Get("timeout").Int()) + if r.Timeout == 0 { + r.Timeout = 1000 + } + r.Database = int(json.Get("database").Int()) +} + type ProviderConfig struct { // @Title zh-CN ID // @Description zh-CN AI服务提供商标识 @@ -391,6 +433,17 @@ type ProviderConfig struct { // @Title zh-CN Vertex AI OpenAI兼容模式 // @Description zh-CN 启用后将使用Vertex AI的OpenAI兼容API,请求和响应均使用OpenAI格式,无需协议转换。与Express Mode(apiTokens)互斥。 vertexOpenAICompatible bool `required:"false" yaml:"vertexOpenAICompatible" json:"vertexOpenAICompatible"` + // @Title zh-CN Vertex AI Thought Signature 缓存开关 + // @Description zh-CN 启用后将使用Redis缓存Gemini 3模型的thought_signature,用于多轮function calling场景。需要配置Redis连接信息。 + vertexEnableThoughtSigCache bool `required:"false" yaml:"vertexEnableThoughtSigCache" json:"vertexEnableThoughtSigCache"` + // @Title zh-CN Vertex AI Thought Signature 缓存过期时间 + // @Description zh-CN thought_signature在Redis中的过期时间,单位为秒。默认值为3600(1小时)。 + vertexThoughtSigCacheTTL int `required:"false" yaml:"vertexThoughtSigCacheTTL" json:"vertexThoughtSigCacheTTL"` + // @Title zh-CN Redis服务配置 + // @Description zh-CN 用于thought_signature缓存的Redis服务配置 + redisConfig *RedisConfig `required:"false" yaml:"redis" json:"redis"` + // Redis客户端实例(运行时使用,不从配置中读取) + redisClient wrapper.RedisClient `yaml:"-" json:"-"` // @Title zh-CN 翻译服务需指定的目标语种 // @Description zh-CN 翻译结果的语种,目前仅适用于DeepL服务。 targetLang string `required:"false" yaml:"targetLang" json:"targetLang"` @@ -468,6 +521,21 @@ func (c *ProviderConfig) GetContextCleanupCommands() []string { return c.contextCleanupCommands } +// GetVertexEnableThoughtSigCache 返回是否启用Vertex thought_signature缓存 +func (c *ProviderConfig) GetVertexEnableThoughtSigCache() bool { + return c.vertexEnableThoughtSigCache +} + +// GetVertexThoughtSigCacheTTL 返回thought_signature缓存的过期时间(秒) +func (c *ProviderConfig) GetVertexThoughtSigCacheTTL() int { + return c.vertexThoughtSigCacheTTL +} + +// GetRedisClient 返回Redis客户端实例 +func (c *ProviderConfig) GetRedisClient() wrapper.RedisClient { + return c.redisClient +} + func (c *ProviderConfig) IsOpenAIProtocol() bool { return c.protocol == protocolOpenAI } @@ -555,6 +623,19 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { c.vertexTokenRefreshAhead = 60 } c.vertexOpenAICompatible = json.Get("vertexOpenAICompatible").Bool() + c.vertexEnableThoughtSigCache = json.Get("vertexEnableThoughtSigCache").Bool() + c.vertexThoughtSigCacheTTL = int(json.Get("vertexThoughtSigCacheTTL").Int()) + if c.vertexThoughtSigCacheTTL == 0 { + c.vertexThoughtSigCacheTTL = 3600 // 默认1小时 + } + + // 解析Redis配置 + redisJson := json.Get("redis") + if redisJson.Exists() { + c.redisConfig = &RedisConfig{} + c.redisConfig.FromJson(redisJson) + } + c.targetLang = json.Get("targetLang").String() if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok { @@ -680,6 +761,32 @@ func (c *ProviderConfig) Validate() error { if err := initializer.ValidateConfig(c); err != nil { return err } + + // 初始化Redis客户端(如果启用了Vertex thought_signature缓存) + if c.vertexEnableThoughtSigCache { + if c.redisConfig == nil || c.redisConfig.ServiceName == "" { + log.Warn("vertexEnableThoughtSigCache is enabled but redis config is missing, disabling thought_signature cache") + c.vertexEnableThoughtSigCache = false + } else { + c.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{ + FQDN: c.redisConfig.ServiceName, + Port: int64(c.redisConfig.ServicePort), + }) + err := c.redisClient.Init( + c.redisConfig.Username, + c.redisConfig.Password, + int64(c.redisConfig.Timeout), + wrapper.WithDataBase(c.redisConfig.Database), + ) + if err != nil { + log.Errorf("failed to init redis client for thought_signature cache: %v", err) + c.vertexEnableThoughtSigCache = false + } else { + log.Info("redis client initialized for Vertex thought_signature cache") + } + } + } + return nil } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go index 3791e06ef8..1c21277cdf 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/vertex.go @@ -16,11 +16,13 @@ import ( "time" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" + "github.com/google/uuid" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/wasm-go/pkg/log" "github.com/higress-group/wasm-go/pkg/wrapper" "github.com/tidwall/gjson" + "github.com/tidwall/resp" "github.com/tidwall/sjson" ) @@ -46,6 +48,14 @@ const ( contextOpenAICompatibleMarker = "isOpenAICompatibleRequest" contextVertexRawMarker = "isVertexRawRequest" vertexAnthropicVersion = "vertex-2023-10-16" + + // Redis key prefix for thought_signature cache + thoughtSigRedisKeyPrefix = "higress-vertex-thought-sig:" + + // Context keys for thought_signature caching + ctxThoughtSigMap = "vertexThoughtSigMap" + ctxThoughtSigPending = "vertexThoughtSigPending" + ctxThoughtSigReady = "vertexThoughtSigReady" ) // vertexRawPathRegex 匹配原生 Vertex AI REST API 路径 @@ -149,6 +159,223 @@ func (v *vertexProvider) GetProviderType() string { return providerTypeVertex } +// storeThoughtSignature 将 thought_signature 存入 Redis +func (v *vertexProvider) storeThoughtSignature(toolCallId, thoughtSignature string) { + if !v.config.GetVertexEnableThoughtSigCache() { + return + } + + redisClient := v.config.GetRedisClient() + if redisClient == nil { + log.Warnf("[ThoughtSig] Redis client not available") + return + } + + key := thoughtSigRedisKeyPrefix + toolCallId + ttl := v.config.GetVertexThoughtSigCacheTTL() + + err := redisClient.SetEx(key, thoughtSignature, ttl, func(response resp.Value) { + if err := response.Error(); err != nil { + log.Errorf("[ThoughtSig] STORE FAILED: key=%s, err=%v", key, err) + } + }) + if err != nil { + log.Errorf("[ThoughtSig] STORE CALL FAILED: key=%s, err=%v", key, err) + } +} + +// extractToolCallIdsFromMessages 从请求消息中提取所有 tool 角色消息的 tool_call_id +func extractToolCallIdsFromMessages(body []byte) []string { + var toolCallIds []string + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return toolCallIds + } + + for _, msg := range messages.Array() { + role := msg.Get("role").String() + if role == "tool" { + toolCallId := msg.Get("tool_call_id").String() + if toolCallId != "" { + toolCallIds = append(toolCallIds, toolCallId) + } + } + } + return toolCallIds +} + +// fetchThoughtSignaturesFromRedis 批量获取 thought_signature 并存入 context +// 返回 true 表示需要暂停请求等待 Redis 回调 +// 在 Redis 回调完成后,会自动进行请求体转换并恢复请求 +func (v *vertexProvider) fetchThoughtSignaturesFromRedis(ctx wrapper.HttpContext, toolCallIds []string) bool { + if !v.config.GetVertexEnableThoughtSigCache() || len(toolCallIds) == 0 { + return false + } + + redisClient := v.config.GetRedisClient() + if redisClient == nil { + log.Warnf("[ThoughtSig] Redis client not available") + return false + } + + // 初始化 map 和计数器 + thoughtSigMap := make(map[string]string) + ctx.SetContext(ctxThoughtSigMap, thoughtSigMap) + ctx.SetContext(ctxThoughtSigPending, len(toolCallIds)) + + // 发起所有 Redis 查询 + for _, toolCallId := range toolCallIds { + id := toolCallId // 避免闭包问题 + key := thoughtSigRedisKeyPrefix + id + err := redisClient.Get(key, func(response resp.Value) { + // 获取当前的 map 和 pending 计数 + sigMap, ok := ctx.GetContext(ctxThoughtSigMap).(map[string]string) + if !ok { + sigMap = make(map[string]string) + } + + if err := response.Error(); err != nil { + log.Errorf("[ThoughtSig] FETCH ERROR: key=%s, err=%v", key, err) + } else if !response.IsNull() { + sigMap[id] = response.String() + } + + ctx.SetContext(ctxThoughtSigMap, sigMap) + + // 减少 pending 计数 + pending, _ := ctx.GetContext(ctxThoughtSigPending).(int) + pending-- + ctx.SetContext(ctxThoughtSigPending, pending) + + if pending <= 0 { + // 所有查询完成,标记 ready + ctx.SetContext(ctxThoughtSigReady, true) + + // 执行请求体转换,transformRequestBodyAfterRedis 会负责恢复请求 + // (直接恢复或通过 token 获取回调恢复) + if !v.transformRequestBodyAfterRedis(ctx) { + // 转换失败,恢复请求让它继续(会失败但不阻塞) + log.Errorf("[ThoughtSig] transform failed, resuming request anyway") + proxywasm.ResumeHttpRequest() + } + } + }) + + if err != nil { + log.Errorf("[ThoughtSig] FETCH CALL FAILED: id=%s, err=%v", id, err) + // 减少 pending 计数 + pending, _ := ctx.GetContext(ctxThoughtSigPending).(int) + pending-- + ctx.SetContext(ctxThoughtSigPending, pending) + if pending <= 0 { + ctx.SetContext(ctxThoughtSigReady, true) + return false + } + } + } + + return true +} + +// transformRequestBodyAfterRedis 在 Redis 获取完成后执行请求体转换 +// 返回 true 表示请求已恢复或将由 token 回调恢复 +// 返回 false 表示出错,调用者应该恢复请求 +func (v *vertexProvider) transformRequestBodyAfterRedis(ctx wrapper.HttpContext) bool { + // 获取保存的原始请求体 + bodyData := ctx.GetContext(ctxOriginalRequestBody) + if bodyData == nil { + log.Errorf("[ThoughtSig] original request body not found in context") + return false + } + + body, ok := bodyData.([]byte) + if !ok { + log.Errorf("[ThoughtSig] invalid body type in context") + return false + } + + headers := util.GetRequestHeaders() + + // 解析请求 + request := &chatCompletionRequest{} + err := v.config.parseRequestAndMapModel(ctx, request, body) + if err != nil { + log.Errorf("[ThoughtSig] failed to parse request: %v", err) + return false + } + + // 设置请求路径 + if strings.HasPrefix(request.Model, "claude") { + ctx.SetContext(contextClaudeMarker, true) + path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream) + util.OverwriteRequestPathHeader(headers, path) + + claudeRequest := v.claude.buildClaudeTextGenRequest(request) + claudeRequest.Model = "" + claudeRequest.AnthropicVersion = vertexAnthropicVersion + transformedBody, err := json.Marshal(claudeRequest) + if err != nil { + log.Errorf("[ThoughtSig] failed to marshal claude request: %v", err) + return false + } + headers.Set("Content-Length", fmt.Sprint(len(transformedBody))) + util.ReplaceRequestHeaders(headers) + _ = proxywasm.ReplaceHttpRequestBody(transformedBody) + } else { + path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) + util.OverwriteRequestPathHeader(headers, path) + + vertexRequest := v.buildVertexChatRequest(ctx, request) + transformedBody, err := json.Marshal(vertexRequest) + if err != nil { + log.Errorf("[ThoughtSig] failed to marshal vertex request: %v", err) + return false + } + + headers.Set("Content-Length", fmt.Sprint(len(transformedBody))) + + if v.isExpressMode() { + headers.Del("Authorization") + } + util.ReplaceRequestHeaders(headers) + _ = proxywasm.ReplaceHttpRequestBody(transformedBody) + } + + // 处理 OAuth token(标准模式需要) + if v.isExpressMode() { + // Express Mode 不需要 OAuth token,直接恢复请求 + proxywasm.ResumeHttpRequest() + return true + } + + // 标准模式需要获取 OAuth token + cached, err := v.getToken() + if err != nil { + log.Errorf("[ThoughtSig] failed to get token: %v", err) + // 出错时恢复请求,让它继续(会失败但至少不会阻塞) + proxywasm.ResumeHttpRequest() + return true + } + + if cached { + // Token 已缓存,直接恢复请求 + proxywasm.ResumeHttpRequest() + return true + } + + // Token 需要获取,getAccessToken 的回调会恢复请求 + return true +} + +// getThoughtSignatureFromContext 从 context 中获取缓存的 thought_signature +func (v *vertexProvider) getThoughtSignatureFromContext(ctx wrapper.HttpContext, toolCallId string) string { + sigMap, ok := ctx.GetContext(ctxThoughtSigMap).(map[string]string) + if !ok { + return "" + } + return sigMap[toolCallId] +} + func (v *vertexProvider) GetApiName(path string) ApiName { // 优先匹配原生 Vertex AI REST API 路径,支持任意 basePath 前缀 // 格式: [任意前缀]/{api-version}/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{action} @@ -220,11 +447,34 @@ func (v *vertexProvider) getToken() (cached bool, err error) { return false, err } +// Context key for saving original request body during Redis fetch +const ctxOriginalRequestBody = "vertexOriginalRequestBody" + +// Context key to mark that request body has been transformed in Redis callback +const ctxBodyTransformed = "vertexBodyTransformed" + func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) { if !v.config.isSupportedAPI(apiName) { return types.ActionContinue, errUnsupportedApiName } + // 检查是否已完成 thought_signature 获取(两阶段处理的第二阶段) + thoughtSigReady, _ := ctx.GetContext(ctxThoughtSigReady).(bool) + + // 如果启用了 thought_signature 缓存且尚未获取,先从 Redis 获取 + if !thoughtSigReady && v.config.GetVertexEnableThoughtSigCache() && apiName == ApiNameChatCompletion { + toolCallIds := extractToolCallIdsFromMessages(body) + if len(toolCallIds) > 0 { + // 保存原始请求体,以便 Redis 回调后使用 + ctx.SetContext(ctxOriginalRequestBody, body) + // 需要获取 thought_signature,暂停请求 + if v.fetchThoughtSignaturesFromRedis(ctx, toolCallIds) { + return types.ActionPause, nil + } + // 如果 fetchThoughtSignaturesFromRedis 返回 false,继续处理(Redis 不可用) + } + } + // Vertex Raw 模式: 透传请求体,只做 OAuth 认证 // 用于直接访问 Vertex AI REST API,不做协议转换 // 注意:此检查必须在 IsOriginal() 之前,因为 Vertex Raw 模式通常与 original 协议一起使用 @@ -361,7 +611,7 @@ func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream) util.OverwriteRequestPathHeader(headers, path) - vertexRequest := v.buildVertexChatRequest(request) + vertexRequest := v.buildVertexChatRequest(ctx, request) return json.Marshal(vertexRequest) } } @@ -501,7 +751,6 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) { return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk) } - log.Infof("[vertexProvider] receive chunk body: %s", string(chunk)) if isLastChunk { return []byte(ssePrefix + "[DONE]\n\n"), nil } @@ -524,6 +773,7 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A log.Errorf("unable to unmarshal vertex response: %v", err) continue } + response := v.buildChatCompletionStreamResponse(ctx, &vertexResp) responseBody, err := json.Marshal(response) if err != nil { @@ -594,22 +844,48 @@ func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re FinishReason: util.Ptr(candidate.FinishReason), } if len(candidate.Content.Parts) > 0 { - part := candidate.Content.Parts[0] - if part.FunctionCall != nil { - args, _ := json.Marshal(part.FunctionCall.Args) + // 遍历所有 parts 查找 functionCall 和 thought_signature + var foundFunctionCall *vertexFunctionCall + var foundThoughtSig string + var firstPart *vertexPart + + for i := range candidate.Content.Parts { + part := &candidate.Content.Parts[i] + if i == 0 { + firstPart = part + } + if part.FunctionCall != nil { + foundFunctionCall = part.FunctionCall + } + if part.ThoughtSignature != "" { + foundThoughtSig = part.ThoughtSignature + } + } + + if foundFunctionCall != nil { + args, _ := json.Marshal(foundFunctionCall.Args) + toolCallId := fmt.Sprintf("call_%s", uuid.New().String()) choice.Message.ToolCalls = []toolCall{ { + Id: toolCallId, Type: "function", Function: functionCall{ - Name: part.FunctionCall.Name, + Name: foundFunctionCall.Name, Arguments: string(args), }, }, } - } else if part.Thounght != nil && len(candidate.Content.Parts) > 1 { - choice.Message.Content = reasoningStartTag + part.Text + reasoningEndTag + candidate.Content.Parts[1].Text - } else if part.Text != "" { - choice.Message.Content = part.Text + + // Store thought_signature in Redis if found + if foundThoughtSig != "" { + v.storeThoughtSignature(toolCallId, foundThoughtSig) + } + } else if firstPart != nil { + if firstPart.Thounght != nil && len(candidate.Content.Parts) > 1 { + choice.Message.Content = reasoningStartTag + firstPart.Text + reasoningEndTag + candidate.Content.Parts[1].Text + } else if firstPart.Text != "" { + choice.Message.Content = firstPart.Text + } } } else { choice.Message.Content = "" @@ -701,33 +977,61 @@ func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte var choice chatCompletionChoice choice.Delta = &chatMessage{} if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 { - part := vertexResp.Candidates[0].Content.Parts[0] - if part.FunctionCall != nil { - args, _ := json.Marshal(part.FunctionCall.Args) + parts := vertexResp.Candidates[0].Content.Parts + + // 遍历所有 parts 查找 functionCall 和 thought_signature + var foundFunctionCall *vertexFunctionCall + var foundThoughtSig string + var firstPart *vertexPart + + for i := range parts { + part := &parts[i] + if i == 0 { + firstPart = part + } + if part.FunctionCall != nil { + foundFunctionCall = part.FunctionCall + } + if part.ThoughtSignature != "" { + foundThoughtSig = part.ThoughtSignature + } + } + + if foundFunctionCall != nil { + args, _ := json.Marshal(foundFunctionCall.Args) + toolCallId := fmt.Sprintf("call_%s", uuid.New().String()) choice.Delta = &chatMessage{ ToolCalls: []toolCall{ { + Id: toolCallId, Type: "function", Function: functionCall{ - Name: part.FunctionCall.Name, + Name: foundFunctionCall.Name, Arguments: string(args), }, }, }, } - } else if part.Thounght != nil { - if ctx.GetContext("thinking_start") == nil { - choice.Delta = &chatMessage{Content: reasoningStartTag + part.Text} - ctx.SetContext("thinking_start", true) - } else { - choice.Delta = &chatMessage{Content: part.Text} + + // Store thought_signature in Redis if found + if foundThoughtSig != "" { + v.storeThoughtSignature(toolCallId, foundThoughtSig) } - } else if part.Text != "" { - if ctx.GetContext("thinking_start") != nil && ctx.GetContext("thinking_end") == nil { - choice.Delta = &chatMessage{Content: reasoningEndTag + part.Text} - ctx.SetContext("thinking_end", true) - } else { - choice.Delta = &chatMessage{Content: part.Text} + } else if firstPart != nil { + if firstPart.Thounght != nil { + if ctx.GetContext("thinking_start") == nil { + choice.Delta = &chatMessage{Content: reasoningStartTag + firstPart.Text} + ctx.SetContext("thinking_start", true) + } else { + choice.Delta = &chatMessage{Content: firstPart.Text} + } + } else if firstPart.Text != "" { + if ctx.GetContext("thinking_start") != nil && ctx.GetContext("thinking_end") == nil { + choice.Delta = &chatMessage{Content: reasoningEndTag + firstPart.Text} + ctx.SetContext("thinking_end", true) + } else { + choice.Delta = &chatMessage{Content: firstPart.Text} + } } } } @@ -818,7 +1122,7 @@ func (v *vertexProvider) getOpenAICompatibleRequestPath() string { return fmt.Sprintf(vertexOpenAICompatiblePathTemplate, v.config.vertexProjectId, v.config.vertexRegion) } -func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest { +func (v *vertexProvider) buildVertexChatRequest(ctx wrapper.HttpContext, request *chatCompletionRequest) *vertexChatRequest { safetySettings := make([]vertexChatSafetySetting, 0) for category, threshold := range v.config.geminiSafetySetting { safetySettings = append(safetySettings, vertexChatSafetySetting{ @@ -855,8 +1159,28 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) } if request.Tools != nil { functions := make([]function, 0, len(request.Tools)) - for _, tool := range request.Tools { - functions = append(functions, tool.Function) + for i, tool := range request.Tools { + // DEBUG: 打印清理前的 function parameters + if tool.Function.Parameters != nil { + originalParamsJson, _ := json.Marshal(tool.Function.Parameters) + log.Debugf("[vertexProvider] tool[%d] %s original parameters: %s", i, tool.Function.Name, string(originalParamsJson)) + } + + // 清理 function parameters 中不支持的 JSON Schema 字段 + cleanedParams := cleanFunctionParameters(tool.Function.Parameters) + + // DEBUG: 打印清理后的 function parameters + if cleanedParams != nil { + cleanedParamsJson, _ := json.Marshal(cleanedParams) + log.Debugf("[vertexProvider] tool[%d] %s cleaned parameters: %s", i, tool.Function.Name, string(cleanedParamsJson)) + } + + cleanedFunc := function{ + Name: tool.Function.Name, + Description: tool.Function.Description, + Parameters: cleanedParams, + } + functions = append(functions, cleanedFunc) } vertexRequest.Tools = []vertexTool{ { @@ -865,32 +1189,57 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) } } shouldAddDummyModelMessage := false - var lastFunctionName string + // Map to track tool_call_id -> function_name for tool response messages + toolCallIdToFunctionName := make(map[string]string) + for _, message := range request.Messages { content := vertexChatContent{ Role: message.Role, Parts: []vertexPart{}, } if len(message.ToolCalls) > 0 { - lastFunctionName = message.ToolCalls[0].Function.Name - args := make(map[string]interface{}) - if err := json.Unmarshal([]byte(message.ToolCalls[0].Function.Arguments), &args); err != nil { - log.Errorf("unable to unmarshal function arguments: %v", err) + // Process ALL tool calls in the message, not just the first one + for i, tc := range message.ToolCalls { + args := make(map[string]interface{}) + if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { + log.Errorf("unable to unmarshal function arguments: %v", err) + } + // Track tool_call_id -> function_name mapping for tool response messages + toolCallIdToFunctionName[tc.Id] = tc.Function.Name + + // Get thought_signature from Redis cache if available + // According to Google docs, thought_signature should be attached to the functionCall part + // For parallel function calls, only the first one has the signature + var thoughtSig string + if i == 0 { + thoughtSig = v.getThoughtSignatureFromContext(ctx, tc.Id) + } + + content.Parts = append(content.Parts, vertexPart{ + FunctionCall: &vertexFunctionCall{ + Name: tc.Function.Name, + Args: args, + }, + ThoughtSignature: thoughtSig, + }) } - content.Parts = append(content.Parts, vertexPart{ - FunctionCall: &vertexFunctionCall{ - Name: lastFunctionName, - Args: args, - }, - }) } else { for _, part := range message.ParseContent() { switch part.Type { case contentTypeText: if message.Role == roleTool { + // Use tool_call_id to find the corresponding function name + functionName := toolCallIdToFunctionName[message.ToolCallId] + if functionName == "" { + log.Warnf("[vertexProvider] could not find function name for tool_call_id: %s", message.ToolCallId) + } + + // Note: thought_signature is attached to the functionCall part (model's response), + // NOT the functionResponse part (tool's response). + // This follows Google's documentation requirements. content.Parts = append(content.Parts, vertexPart{ FunctionResponse: &vertexFunctionResponse{ - Name: lastFunctionName, + Name: functionName, Response: vertexFunctionResponseDetail{ Output: part.Text, }, @@ -975,6 +1324,7 @@ type vertexPart struct { FunctionCall *vertexFunctionCall `json:"functionCall,omitempty"` FunctionResponse *vertexFunctionResponse `json:"functionResponse,omitempty"` Thounght *bool `json:"thought,omitempty"` + ThoughtSignature string `json:"thoughtSignature,omitempty"` } type blob struct {