diff --git a/internal/prompts/promptsets.go b/internal/prompts/promptsets.go index 2d3cf6315246..a147ca9c8917 100644 --- a/internal/prompts/promptsets.go +++ b/internal/prompts/promptsets.go @@ -36,6 +36,16 @@ func (p Promptset) ToConfig() PromptsetConfig { return p.PromptsetConfig } +// ContainsPrompt reports whether the promptset includes a prompt with the given name. +func (p Promptset) ContainsPrompt(name string) bool { + for _, n := range p.PromptNames { + if n == name { + return true + } + } + return false +} + type PromptsetManifest struct { ServerVersion string `json:"serverVersion"` PromptsManifest map[string]Manifest `json:"prompts"` diff --git a/internal/prompts/promptsets_test.go b/internal/prompts/promptsets_test.go index 170120de5e1c..8ed1900cb58c 100644 --- a/internal/prompts/promptsets_test.go +++ b/internal/prompts/promptsets_test.go @@ -65,6 +65,70 @@ func newMockPrompt(name, desc string) prompts.Prompt { } } +func TestPromptset_ContainsPrompt(t *testing.T) { + t.Parallel() + + promptset := prompts.Promptset{ + PromptsetConfig: prompts.PromptsetConfig{ + Name: "test-promptset", + PromptNames: []string{"greet", "summarize"}, + }, + } + + tests := []struct { + name string + promptName string + want bool + }{ + { + name: "prompt exists in promptset", + promptName: "greet", + want: true, + }, + { + name: "another prompt exists in promptset", + promptName: "summarize", + want: true, + }, + { + name: "prompt not in promptset", + promptName: "admin_prompt", + want: false, + }, + { + name: "empty prompt name", + promptName: "", + want: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := promptset.ContainsPrompt(tc.promptName) + if got != tc.want { + t.Errorf("ContainsPrompt(%q) = %v, want %v", tc.promptName, got, tc.want) + } + }) + } +} + +func TestPromptset_ContainsPrompt_EmptyPromptset(t *testing.T) { + t.Parallel() + + promptset := prompts.Promptset{ + PromptsetConfig: prompts.PromptsetConfig{ + Name: "empty-promptset", + PromptNames: []string{}, + }, + } + + if promptset.ContainsPrompt("anything") { + t.Error("ContainsPrompt should return false for empty promptset") + } +} + func TestPromptsetConfig_Initialize(t *testing.T) { t.Parallel() diff --git a/internal/server/mcp/v20241105/method.go b/internal/server/mcp/v20241105/method.go index efcbab26f64c..f9aeb25d2b42 100644 --- a/internal/server/mcp/v20241105/method.go +++ b/internal/server/mcp/v20241105/method.go @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, resourceMgr, body, header) + return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, resourceMgr, body) + return promptsGetHandler(ctx, id, promptset, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -87,7 +87,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. -func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { authServices := resourceMgr.GetAuthServiceMap() // retrieve logger from context @@ -114,6 +114,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re attribute.String("gen_ai.operation.name", "execute_tool"), ) + // Verify tool belongs to the current toolset before resolving globally. + if !toolset.ContainsTool(toolName) { + err = fmt.Errorf("tool %q is not part of the current toolset", toolName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -339,7 +345,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. -func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -361,6 +367,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + // Verify prompt belongs to the current promptset before resolving globally. + if !promptset.ContainsPrompt(promptName) { + err := fmt.Errorf("prompt %q is not part of the current promptset", promptName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20250326/method.go b/internal/server/mcp/v20250326/method.go index 1d4292f38467..76b859664dcd 100644 --- a/internal/server/mcp/v20250326/method.go +++ b/internal/server/mcp/v20250326/method.go @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, resourceMgr, body, header) + return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, resourceMgr, body) + return promptsGetHandler(ctx, id, promptset, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -87,7 +87,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. -func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { authServices := resourceMgr.GetAuthServiceMap() // retrieve logger from context @@ -114,6 +114,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re attribute.String("gen_ai.operation.name", "execute_tool"), ) + // Verify tool belongs to the current toolset before resolving globally. + if !toolset.ContainsTool(toolName) { + err = fmt.Errorf("tool %q is not part of the current toolset", toolName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -338,7 +344,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. -func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -360,6 +366,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + // Verify prompt belongs to the current promptset before resolving globally. + if !promptset.ContainsPrompt(promptName) { + err := fmt.Errorf("prompt %q is not part of the current promptset", promptName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20250618/method.go b/internal/server/mcp/v20250618/method.go index 529bd90e3870..94bd47859d7f 100644 --- a/internal/server/mcp/v20250618/method.go +++ b/internal/server/mcp/v20250618/method.go @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, resourceMgr, body, header) + return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, resourceMgr, body) + return promptsGetHandler(ctx, id, promptset, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -80,7 +80,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. -func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { authServices := resourceMgr.GetAuthServiceMap() // retrieve logger from context @@ -107,6 +107,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re attribute.String("gen_ai.operation.name", "execute_tool"), ) + // Verify tool belongs to the current toolset before resolving globally. + if !toolset.ContainsTool(toolName) { + err = fmt.Errorf("tool %q is not part of the current toolset", toolName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -332,7 +338,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. -func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -354,6 +360,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + // Verify prompt belongs to the current promptset before resolving globally. + if !promptset.ContainsPrompt(promptName) { + err := fmt.Errorf("prompt %q is not part of the current promptset", promptName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20251125/method.go b/internal/server/mcp/v20251125/method.go index 68fd5c3bfd57..11d64ecc3727 100644 --- a/internal/server/mcp/v20251125/method.go +++ b/internal/server/mcp/v20251125/method.go @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, resourceMgr, body, header) + return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, resourceMgr, body) + return promptsGetHandler(ctx, id, promptset, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -80,7 +80,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. -func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { authServices := resourceMgr.GetAuthServiceMap() // retrieve logger from context @@ -107,6 +107,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re attribute.String("gen_ai.operation.name", "execute_tool"), ) + // Verify tool belongs to the current toolset before resolving globally. + if !toolset.ContainsTool(toolName) { + err = fmt.Errorf("tool %q is not part of the current toolset", toolName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -332,7 +338,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. -func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -354,6 +360,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + // Verify prompt belongs to the current promptset before resolving globally. + if !promptset.ContainsPrompt(promptName) { + err := fmt.Errorf("prompt %q is not part of the current promptset", promptName) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/tools/toolsets.go b/internal/tools/toolsets.go index b429ef5b19df..008ba844e5c6 100644 --- a/internal/tools/toolsets.go +++ b/internal/tools/toolsets.go @@ -35,6 +35,16 @@ func (t Toolset) ToConfig() ToolsetConfig { return t.ToolsetConfig } +// ContainsTool reports whether the toolset includes a tool with the given name. +func (t Toolset) ContainsTool(name string) bool { + for _, n := range t.ToolNames { + if n == name { + return true + } + } + return false +} + type ToolsetManifest struct { ServerVersion string `json:"serverVersion"` ToolsManifest map[string]Manifest `json:"tools"` diff --git a/internal/tools/toolsets_test.go b/internal/tools/toolsets_test.go new file mode 100644 index 000000000000..e72e49d3e878 --- /dev/null +++ b/internal/tools/toolsets_test.go @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools_test + +import ( + "testing" + + "github.com/googleapis/mcp-toolbox/internal/tools" +) + +func TestToolset_ContainsTool(t *testing.T) { + t.Parallel() + + toolset := tools.Toolset{ + ToolsetConfig: tools.ToolsetConfig{ + Name: "test-toolset", + ToolNames: []string{"echo", "list_tables"}, + }, + } + + tests := []struct { + name string + toolName string + want bool + }{ + { + name: "tool exists in toolset", + toolName: "echo", + want: true, + }, + { + name: "another tool exists in toolset", + toolName: "list_tables", + want: true, + }, + { + name: "tool not in toolset", + toolName: "admin_delete", + want: false, + }, + { + name: "empty tool name", + toolName: "", + want: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := toolset.ContainsTool(tc.toolName) + if got != tc.want { + t.Errorf("ContainsTool(%q) = %v, want %v", tc.toolName, got, tc.want) + } + }) + } +} + +func TestToolset_ContainsTool_EmptyToolset(t *testing.T) { + t.Parallel() + + toolset := tools.Toolset{ + ToolsetConfig: tools.ToolsetConfig{ + Name: "empty-toolset", + ToolNames: []string{}, + }, + } + + if toolset.ContainsTool("anything") { + t.Error("ContainsTool should return false for empty toolset") + } +}