Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions internal/prompts/promptsets.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ func (p Promptset) ToConfig() PromptsetConfig {
return p.PromptsetConfig
}

// ContainsPrompt reports whether the promptset includes a prompt with the given name.
func (p Promptset) ContainsPrompt(name string) bool {
for _, n := range p.PromptNames {
if n == name {
return true
}
}
return false
}
Comment on lines +40 to +47
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The ContainsPrompt method uses a linear search ($O(N)$), which is executed on every prompts/get request. While likely acceptable for small promptsets, this could impact performance as the number of prompts grows. Consider using a map for $O(1)$ lookups, which could be initialized once in the Initialize method.


type PromptsetManifest struct {
ServerVersion string `json:"serverVersion"`
PromptsManifest map[string]Manifest `json:"prompts"`
Expand Down
64 changes: 64 additions & 0 deletions internal/prompts/promptsets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,70 @@ func newMockPrompt(name, desc string) prompts.Prompt {
}
}

func TestPromptset_ContainsPrompt(t *testing.T) {
t.Parallel()

promptset := prompts.Promptset{
PromptsetConfig: prompts.PromptsetConfig{
Name: "test-promptset",
PromptNames: []string{"greet", "summarize"},
},
}

tests := []struct {
name string
promptName string
want bool
}{
{
name: "prompt exists in promptset",
promptName: "greet",
want: true,
},
{
name: "another prompt exists in promptset",
promptName: "summarize",
want: true,
},
{
name: "prompt not in promptset",
promptName: "admin_prompt",
want: false,
},
{
name: "empty prompt name",
promptName: "",
want: false,
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := promptset.ContainsPrompt(tc.promptName)
if got != tc.want {
t.Errorf("ContainsPrompt(%q) = %v, want %v", tc.promptName, got, tc.want)
}
})
}
}

func TestPromptset_ContainsPrompt_EmptyPromptset(t *testing.T) {
t.Parallel()

promptset := prompts.Promptset{
PromptsetConfig: prompts.PromptsetConfig{
Name: "empty-promptset",
PromptNames: []string{},
},
}

if promptset.ContainsPrompt("anything") {
t.Error("ContainsPrompt should return false for empty promptset")
}
}

func TestPromptsetConfig_Initialize(t *testing.T) {
t.Parallel()

Expand Down
20 changes: 16 additions & 4 deletions internal/server/mcp/v20241105/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too
case TOOLS_LIST:
return toolsListHandler(id, toolset, body)
case TOOLS_CALL:
return toolsCallHandler(ctx, id, resourceMgr, body, header)
return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header)
case PROMPTS_LIST:
return promptsListHandler(ctx, id, promptset, body)
case PROMPTS_GET:
return promptsGetHandler(ctx, id, resourceMgr, body)
return promptsGetHandler(ctx, id, promptset, resourceMgr, body)
default:
err := fmt.Errorf("invalid method %s", method)
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
Expand Down Expand Up @@ -87,7 +87,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
}

// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
authServices := resourceMgr.GetAuthServiceMap()

// retrieve logger from context
Expand All @@ -114,6 +114,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
attribute.String("gen_ai.operation.name", "execute_tool"),
)

// Verify tool belongs to the current toolset before resolving globally.
if !toolset.ContainsTool(toolName) {
err = fmt.Errorf("tool %q is not part of the current toolset", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

tool, ok := resourceMgr.GetTool(toolName)
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
Expand Down Expand Up @@ -339,7 +345,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro
}

// promptsGetHandler handles the "prompts/get" method.
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
Expand All @@ -361,6 +367,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))

// Verify prompt belongs to the current promptset before resolving globally.
if !promptset.ContainsPrompt(promptName) {
err := fmt.Errorf("prompt %q is not part of the current promptset", promptName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

prompt, ok := resourceMgr.GetPrompt(promptName)
if !ok {
err := fmt.Errorf("prompt with name %q does not exist", promptName)
Expand Down
20 changes: 16 additions & 4 deletions internal/server/mcp/v20250326/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too
case TOOLS_LIST:
return toolsListHandler(id, toolset, body)
case TOOLS_CALL:
return toolsCallHandler(ctx, id, resourceMgr, body, header)
return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header)
case PROMPTS_LIST:
return promptsListHandler(ctx, id, promptset, body)
case PROMPTS_GET:
return promptsGetHandler(ctx, id, resourceMgr, body)
return promptsGetHandler(ctx, id, promptset, resourceMgr, body)
default:
err := fmt.Errorf("invalid method %s", method)
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
Expand Down Expand Up @@ -87,7 +87,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
}

// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
authServices := resourceMgr.GetAuthServiceMap()

// retrieve logger from context
Expand All @@ -114,6 +114,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
attribute.String("gen_ai.operation.name", "execute_tool"),
)

// Verify tool belongs to the current toolset before resolving globally.
if !toolset.ContainsTool(toolName) {
err = fmt.Errorf("tool %q is not part of the current toolset", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

tool, ok := resourceMgr.GetTool(toolName)
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
Expand Down Expand Up @@ -338,7 +344,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro
}

// promptsGetHandler handles the "prompts/get" method.
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
Expand All @@ -360,6 +366,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))

// Verify prompt belongs to the current promptset before resolving globally.
if !promptset.ContainsPrompt(promptName) {
err := fmt.Errorf("prompt %q is not part of the current promptset", promptName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

prompt, ok := resourceMgr.GetPrompt(promptName)
if !ok {
err := fmt.Errorf("prompt with name %q does not exist", promptName)
Expand Down
20 changes: 16 additions & 4 deletions internal/server/mcp/v20250618/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too
case TOOLS_LIST:
return toolsListHandler(id, toolset, body)
case TOOLS_CALL:
return toolsCallHandler(ctx, id, resourceMgr, body, header)
return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header)
case PROMPTS_LIST:
return promptsListHandler(ctx, id, promptset, body)
case PROMPTS_GET:
return promptsGetHandler(ctx, id, resourceMgr, body)
return promptsGetHandler(ctx, id, promptset, resourceMgr, body)
default:
err := fmt.Errorf("invalid method %s", method)
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
Expand Down Expand Up @@ -80,7 +80,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
}

// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
authServices := resourceMgr.GetAuthServiceMap()

// retrieve logger from context
Expand All @@ -107,6 +107,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
attribute.String("gen_ai.operation.name", "execute_tool"),
)

// Verify tool belongs to the current toolset before resolving globally.
if !toolset.ContainsTool(toolName) {
err = fmt.Errorf("tool %q is not part of the current toolset", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

tool, ok := resourceMgr.GetTool(toolName)
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
Expand Down Expand Up @@ -332,7 +338,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro
}

// promptsGetHandler handles the "prompts/get" method.
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
Expand All @@ -354,6 +360,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))

// Verify prompt belongs to the current promptset before resolving globally.
if !promptset.ContainsPrompt(promptName) {
err := fmt.Errorf("prompt %q is not part of the current promptset", promptName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

prompt, ok := resourceMgr.GetPrompt(promptName)
if !ok {
err := fmt.Errorf("prompt with name %q does not exist", promptName)
Expand Down
20 changes: 16 additions & 4 deletions internal/server/mcp/v20251125/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, too
case TOOLS_LIST:
return toolsListHandler(id, toolset, body)
case TOOLS_CALL:
return toolsCallHandler(ctx, id, resourceMgr, body, header)
return toolsCallHandler(ctx, id, toolset, resourceMgr, body, header)
case PROMPTS_LIST:
return promptsListHandler(ctx, id, promptset, body)
case PROMPTS_GET:
return promptsGetHandler(ctx, id, resourceMgr, body)
return promptsGetHandler(ctx, id, promptset, resourceMgr, body)
default:
err := fmt.Errorf("invalid method %s", method)
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
Expand Down Expand Up @@ -80,7 +80,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
}

// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolset tools.Toolset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
authServices := resourceMgr.GetAuthServiceMap()

// retrieve logger from context
Expand All @@ -107,6 +107,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
attribute.String("gen_ai.operation.name", "execute_tool"),
)

// Verify tool belongs to the current toolset before resolving globally.
if !toolset.ContainsTool(toolName) {
err = fmt.Errorf("tool %q is not part of the current toolset", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

tool, ok := resourceMgr.GetTool(toolName)
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
Expand Down Expand Up @@ -332,7 +338,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro
}

// promptsGetHandler handles the "prompts/get" method.
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
Expand All @@ -354,6 +360,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))

// Verify prompt belongs to the current promptset before resolving globally.
if !promptset.ContainsPrompt(promptName) {
err := fmt.Errorf("prompt %q is not part of the current promptset", promptName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

prompt, ok := resourceMgr.GetPrompt(promptName)
if !ok {
err := fmt.Errorf("prompt with name %q does not exist", promptName)
Expand Down
10 changes: 10 additions & 0 deletions internal/tools/toolsets.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ func (t Toolset) ToConfig() ToolsetConfig {
return t.ToolsetConfig
}

// ContainsTool reports whether the toolset includes a tool with the given name.
func (t Toolset) ContainsTool(name string) bool {
for _, n := range t.ToolNames {
if n == name {
return true
}
}
return false
}
Comment on lines +39 to +46
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The ContainsTool method performs a linear search ($O(N)$) on every tools/call. For toolsets with many tools, this may become a performance bottleneck. Using a map for $O(1)$ lookups (populated during Initialize) would be more efficient and scalable.


type ToolsetManifest struct {
ServerVersion string `json:"serverVersion"`
ToolsManifest map[string]Manifest `json:"tools"`
Expand Down
Loading