diff --git a/mcp/shared.go b/mcp/shared.go index cc28de85..b5b09db2 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -36,6 +36,7 @@ const ( // It is the version that the client sends in the initialization request, and // the default version used by the server. latestProtocolVersion = protocolVersion20251125 + protocolVersion20260630 = "2026-06-30" protocolVersion20251125 = "2025-11-25" protocolVersion20250618 = "2025-06-18" protocolVersion20250326 = "2025-03-26" @@ -343,6 +344,9 @@ func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Cont // MCP-specific error codes. const ( + // CodeHeaderMismatch indicates that HTTP headers do not match the corresponding values + // in the request body, or that required headers are missing or malformed. + CodeHeaderMismatch = -32001 // CodeResourceNotFound indicates that a requested resource could not be found. CodeResourceNotFound = -32002 // CodeURLElicitationRequired indicates that the server requires URL elicitation diff --git a/mcp/streamable.go b/mcp/streamable.go index a176789e..4470954e 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -39,12 +39,6 @@ import ( "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) -const ( - protocolVersionHeader = "Mcp-Protocol-Version" - sessionIDHeader = "Mcp-Session-Id" - lastEventIDHeader = "Last-Event-ID" -) - // A StreamableHTTPHandler is an http.Handler that serves streamable MCP // sessions, as defined by the [MCP spec]. // @@ -1191,6 +1185,24 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } + // Validate MCP standard headers (Mcp-Method, Mcp-Name) + if !isBatch && len(incoming) == 1 { + if err := validateMcpHeaders(req.Header, incoming[0]); err != nil { + resp := &jsonrpc.Response{ + Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()), + } + if jreq, ok := incoming[0].(*jsonrpc.Request); ok { + resp.ID = jreq.ID + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if data, err := jsonrpc2.EncodeMessage(resp); err == nil { + w.Write(data) + } + return + } + } + // The prime and close events were added in protocol version 2025-11-25 (SEP-1699). // Use the version from InitializeParams if this is an initialize request, // otherwise use the protocol version header. @@ -1797,6 +1809,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e // and permanently break the connection. return nil, nil, fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) } + // Keep this after the setMCPHeaders call to ensure that the + // protocol version header is set. + setStandardHeaders(req.Header, msg) resp, err := c.client.Do(req) if err != nil { // Any error from client.Do means the request didn't reach the server. @@ -1932,6 +1947,7 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) } + return nil } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 9aac7b54..96c5e75f 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -70,7 +70,7 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques key := streamableRequestKey{ httpMethod: req.Method, sessionID: req.Header.Get(sessionIDHeader), - lastEventID: req.Header.Get("Last-Event-ID"), // TODO: extract this to a constant, like sessionIDHeader + lastEventID: req.Header.Get(lastEventIDHeader), } var jsonrpcReq *jsonrpc.Request if req.Method == http.MethodPost { diff --git a/mcp/streamable_headers.go b/mcp/streamable_headers.go new file mode 100644 index 00000000..ac9a4121 --- /dev/null +++ b/mcp/streamable_headers.go @@ -0,0 +1,98 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +const ( + protocolVersionHeader = "Mcp-Protocol-Version" + sessionIDHeader = "Mcp-Session-Id" + lastEventIDHeader = "Last-Event-ID" + methodHeader = "Mcp-Method" + nameHeader = "Mcp-Name" + minVersionForStandardHeaders = protocolVersion20260630 +) + +func extractName(method string, params json.RawMessage) (string, bool) { + switch method { + case "tools/call": + var p CallToolParams + if err := internaljson.Unmarshal(params, &p); err == nil { + return p.Name, true + } + case "prompts/get": + var p GetPromptParams + if err := internaljson.Unmarshal(params, &p); err == nil { + return p.Name, true + } + case "resources/read": + var p ReadResourceParams + if err := internaljson.Unmarshal(params, &p); err == nil { + return p.URI, true + } + } + + return "", false +} + +// setStandardHeaders populates standard MCP headers. +// It requires the protocol version header to be set. +func setStandardHeaders(header http.Header, msg jsonrpc.Message) { + if msg == nil { + return + } + if header.Get(protocolVersionHeader) == "" || header.Get(protocolVersionHeader) < minVersionForStandardHeaders { + return + } + + switch msg := msg.(type) { + case *jsonrpc.Request: + header.Set(methodHeader, msg.Method) + if name, ok := extractName(msg.Method, msg.Params); ok { + header.Set(nameHeader, name) + } + } +} + +func validateMcpHeaders(header http.Header, msg jsonrpc.Message) error { + protocolVersion := header.Get(protocolVersionHeader) + if protocolVersion == "" || protocolVersion < minVersionForStandardHeaders { + return nil + } + + switch msg := msg.(type) { + case *jsonrpc.Request: + methodInHeader := header.Get(methodHeader) + if methodInHeader == "" { + return errors.New("missing required Mcp-Method header") + } + if methodInHeader != msg.Method { + return fmt.Errorf("header mismatch: Mcp-Method header value '%s' does not match body value '%s'", methodInHeader, msg.Method) + } + + if msg.Method == "tools/call" || msg.Method == "resources/read" || msg.Method == "prompts/get" { + nameInHeader := header.Get(nameHeader) + if nameInHeader == "" { + return fmt.Errorf("missing required Mcp-Name header for method %q", msg.Method) + } + nameInBody, ok := extractName(msg.Method, msg.Params) + if !ok { + return fmt.Errorf("failed to extract name from parameters for method %q", msg.Method) + } + if nameInHeader != nameInBody { + return fmt.Errorf("header mismatch: Mcp-Name header value '%s' does not match body value '%s'", nameInHeader, nameInBody) + } + } + } + return nil +} diff --git a/mcp/streamable_headers_test.go b/mcp/streamable_headers_test.go new file mode 100644 index 00000000..0abf1b28 --- /dev/null +++ b/mcp/streamable_headers_test.go @@ -0,0 +1,415 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "encoding/json" + "net/http" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/jsonrpc" +) + +func TestExtractName(t *testing.T) { + tests := []struct { + name string + method string + params json.RawMessage + wantName string + wantOK bool + }{ + { + name: "tools/call", + method: "tools/call", + params: mustMarshal(&CallToolParams{Name: "my-tool"}), + wantName: "my-tool", + wantOK: true, + }, + { + name: "prompts/get", + method: "prompts/get", + params: mustMarshal(&GetPromptParams{Name: "code_review"}), + wantName: "code_review", + wantOK: true, + }, + { + name: "resources/read", + method: "resources/read", + params: mustMarshal(&ReadResourceParams{URI: "file:///info.txt"}), + wantName: "file:///info.txt", + wantOK: true, + }, + { + name: "tools/call with empty name", + method: "tools/call", + params: mustMarshal(&CallToolParams{Name: ""}), + wantName: "", + wantOK: true, + }, + { + name: "tool name with hyphen", + method: "tools/call", + params: mustMarshal(&CallToolParams{Name: "my-tool-v2"}), + wantName: "my-tool-v2", + wantOK: true, + }, + { + name: "tool name with underscore", + method: "tools/call", + params: mustMarshal(&CallToolParams{Name: "my_tool_name"}), + wantName: "my_tool_name", + wantOK: true, + }, + { + name: "resource URI with special chars", + method: "resources/read", + params: mustMarshal(&ReadResourceParams{URI: "file:///path/to/file%20name.txt"}), + wantName: "file:///path/to/file%20name.txt", + wantOK: true, + }, + { + name: "resource URI with query string", + method: "resources/read", + params: mustMarshal(&ReadResourceParams{URI: "https://example.com/resource?id=123"}), + wantName: "https://example.com/resource?id=123", + wantOK: true, + }, + { + name: "unrelated method", + method: "initialize", + params: mustMarshal(&InitializeParams{ProtocolVersion: "2025-06-18"}), + wantName: "", + wantOK: false, + }, + { + name: "notification method", + method: "notifications/initialized", + params: nil, + wantName: "", + wantOK: false, + }, + { + name: "invalid JSON params", + method: "tools/call", + params: []byte("not json"), + wantName: "", + wantOK: false, + }, + { + name: "nil params", + method: "tools/call", + params: nil, + wantName: "", + wantOK: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotName, gotOK := extractName(tt.method, tt.params) + if gotName != tt.wantName || gotOK != tt.wantOK { + t.Errorf("extractName(%q, ...) = (%q, %v), want (%q, %v)", + tt.method, gotName, gotOK, tt.wantName, tt.wantOK) + } + }) + } +} + +func TestSetStandardHeaders(t *testing.T) { + tests := []struct { + name string + protocolVersion string + msg jsonrpc.Message + wantMethodHeader string + wantNameHeader string + }{ + { + name: "tools/call with future version", + protocolVersion: minVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantMethodHeader: "tools/call", + wantNameHeader: "my-tool", + }, + { + name: "prompts/get with future version", + protocolVersion: minVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "prompts/get", Params: mustMarshal(&GetPromptParams{Name: "code_review"})}, + wantMethodHeader: "prompts/get", + wantNameHeader: "code_review", + }, + { + name: "resources/read with future version", + protocolVersion: minVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "file:///info.txt"})}, + wantMethodHeader: "resources/read", + wantNameHeader: "file:///info.txt", + }, + { + name: "initialize sets method only", + protocolVersion: minVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "initialize", Params: mustMarshal(&InitializeParams{ProtocolVersion: minVersionForStandardHeaders})}, + wantMethodHeader: "initialize", + wantNameHeader: "", + }, + { + name: "notification sets method only", + protocolVersion: minVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "notifications/initialized"}, + wantMethodHeader: "notifications/initialized", + wantNameHeader: "", + }, + { + name: "old version skips all headers", + protocolVersion: protocolVersion20251125, + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantMethodHeader: "", + wantNameHeader: "", + }, + { + name: "empty version skips all headers", + protocolVersion: "", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantMethodHeader: "", + wantNameHeader: "", + }, + { + name: "nil message is a no-op", + protocolVersion: minVersionForStandardHeaders, + msg: nil, + wantMethodHeader: "", + wantNameHeader: "", + }, + { + name: "response message is ignored", + protocolVersion: minVersionForStandardHeaders, + msg: &jsonrpc.Response{}, + wantMethodHeader: "", + wantNameHeader: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + header := http.Header{} + if tt.protocolVersion != "" { + header.Set(protocolVersionHeader, tt.protocolVersion) + } + + setStandardHeaders(header, tt.msg) + + if got := header.Get(methodHeader); got != tt.wantMethodHeader { + t.Errorf("MethodHeader = %q, want %q", got, tt.wantMethodHeader) + } + if got := header.Get(nameHeader); got != tt.wantNameHeader { + t.Errorf("NameHeader = %q, want %q", got, tt.wantNameHeader) + } + }) + } +} + +func TestValidateMcpHeaders(t *testing.T) { + + tests := []struct { + name string + version string + methodHeader string + nameHeader string + msg jsonrpc.Message + wantErr bool + wantErrContain string + }{ + { + name: "old version skips validation", + version: protocolVersion20251125, + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: false, + }, + { + name: "empty version skips validation", + version: "", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: false, + }, + { + name: "missing Mcp-Method header", + version: minVersionForStandardHeaders, + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: true, + wantErrContain: "missing required Mcp-Method header", + }, + { + name: "missing Mcp-Name for tools/call", + version: minVersionForStandardHeaders, + methodHeader: "tools/call", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: true, + wantErrContain: "missing required Mcp-Name header", + }, + { + name: "missing Mcp-Name for resources/read", + version: minVersionForStandardHeaders, + methodHeader: "resources/read", + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "file:///info.txt"})}, + wantErr: true, + wantErrContain: "missing required Mcp-Name header", + }, + { + name: "missing Mcp-Name for prompts/get", + version: minVersionForStandardHeaders, + methodHeader: "prompts/get", + msg: &jsonrpc.Request{Method: "prompts/get", Params: mustMarshal(&GetPromptParams{Name: "review"})}, + wantErr: true, + wantErrContain: "missing required Mcp-Name header", + }, + { + name: "method mismatch", + version: minVersionForStandardHeaders, + methodHeader: "tools/call", + msg: &jsonrpc.Request{Method: "prompts/get", Params: mustMarshal(&GetPromptParams{Name: "review"})}, + wantErr: true, + wantErrContain: "Mcp-Method header value 'tools/call' does not match body value 'prompts/get'", + }, + { + name: "tool name mismatch", + version: minVersionForStandardHeaders, + methodHeader: "tools/call", + nameHeader: "wrong-tool", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "right-tool"})}, + wantErr: true, + wantErrContain: "Mcp-Name header value 'wrong-tool' does not match body value 'right-tool'", + }, + { + name: "resource URI mismatch", + version: minVersionForStandardHeaders, + methodHeader: "resources/read", + nameHeader: "file:///wrong.txt", + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "file:///right.txt"})}, + wantErr: true, + wantErrContain: "Mcp-Name header value 'file:///wrong.txt' does not match body value 'file:///right.txt'", + }, + { + name: "prompt name mismatch", + version: minVersionForStandardHeaders, + methodHeader: "prompts/get", + nameHeader: "wrong-prompt", + msg: &jsonrpc.Request{Method: "prompts/get", Params: mustMarshal(&GetPromptParams{Name: "right-prompt"})}, + wantErr: true, + wantErrContain: "Mcp-Name header value 'wrong-prompt' does not match body value 'right-prompt'", + }, + { + name: "method value is case-sensitive", + version: minVersionForStandardHeaders, + methodHeader: "TOOLS/CALL", + nameHeader: "my-tool", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: true, + wantErrContain: "Mcp-Method header value 'TOOLS/CALL' does not match body value 'tools/call'", + }, + { + name: "valid tools/call", + version: minVersionForStandardHeaders, + methodHeader: "tools/call", + nameHeader: "my-tool", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool"})}, + wantErr: false, + }, + { + name: "valid resources/read", + version: minVersionForStandardHeaders, + methodHeader: "resources/read", + nameHeader: "file:///info.txt", + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "file:///info.txt"})}, + wantErr: false, + }, + { + name: "valid prompts/get", + version: minVersionForStandardHeaders, + methodHeader: "prompts/get", + nameHeader: "code_review", + msg: &jsonrpc.Request{Method: "prompts/get", Params: mustMarshal(&GetPromptParams{Name: "code_review"})}, + wantErr: false, + }, + { + name: "valid initialize (no name needed)", + version: minVersionForStandardHeaders, + methodHeader: "initialize", + msg: &jsonrpc.Request{Method: "initialize", Params: mustMarshal(&InitializeParams{ProtocolVersion: minVersionForStandardHeaders})}, + wantErr: false, + }, + { + name: "valid notification (no name needed)", + version: minVersionForStandardHeaders, + methodHeader: "notifications/initialized", + msg: &jsonrpc.Request{Method: "notifications/initialized"}, + wantErr: false, + }, + { + name: "tool name with hyphen", + version: minVersionForStandardHeaders, + methodHeader: "tools/call", + nameHeader: "my-tool-name", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my-tool-name"})}, + wantErr: false, + }, + { + name: "tool name with underscore", + version: minVersionForStandardHeaders, + methodHeader: "tools/call", + nameHeader: "my_tool_name", + msg: &jsonrpc.Request{Method: "tools/call", Params: mustMarshal(&CallToolParams{Name: "my_tool_name"})}, + wantErr: false, + }, + { + name: "resource URI with special chars", + version: minVersionForStandardHeaders, + methodHeader: "resources/read", + nameHeader: "file:///path/to/file%20name.txt", + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "file:///path/to/file%20name.txt"})}, + wantErr: false, + }, + { + name: "resource URI with query string", + version: minVersionForStandardHeaders, + methodHeader: "resources/read", + nameHeader: "https://example.com/resource?id=123", + msg: &jsonrpc.Request{Method: "resources/read", Params: mustMarshal(&ReadResourceParams{URI: "https://example.com/resource?id=123"})}, + wantErr: false, + }, + { + name: "response message is ignored", + version: minVersionForStandardHeaders, + msg: &jsonrpc.Response{}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + header := http.Header{} + if tt.version != "" { + header.Set(protocolVersionHeader, tt.version) + } + if tt.methodHeader != "" { + header.Set(methodHeader, tt.methodHeader) + } + if tt.nameHeader != "" { + header.Set(nameHeader, tt.nameHeader) + } + + err := validateMcpHeaders(header, tt.msg) + if tt.wantErr { + if err == nil { + t.Fatalf("validateMcpHeaders() = nil, want error containing %q", tt.wantErrContain) + } + if !strings.Contains(err.Error(), tt.wantErrContain) { + t.Errorf("validateMcpHeaders() error = %q, want substring %q", err.Error(), tt.wantErrContain) + } + } else if err != nil { + t.Errorf("validateMcpHeaders() = %v, want nil", err) + } + }) + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 83326218..80fb4baa 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1781,7 +1781,7 @@ func TestSessionHijackingPrevention(t *testing.T) { req.Header.Set("Accept", "application/json, text/event-stream") req.Header.Set("Authorization", "Bearer "+userID) if sessionID != "" { - req.Header.Set("Mcp-Session-Id", sessionID) + req.Header.Set(sessionIDHeader, sessionID) } resp, err := http.DefaultClient.Do(req) if err != nil { @@ -1802,7 +1802,7 @@ func TestSessionHijackingPrevention(t *testing.T) { body, _ := io.ReadAll(resp.Body) t.Fatalf("initialize failed with status %d: %s", resp.StatusCode, body) } - sessionID := resp.Header.Get("Mcp-Session-Id") + sessionID := resp.Header.Get(sessionIDHeader) if sessionID == "" { t.Fatal("no session ID in response") } @@ -1915,6 +1915,261 @@ func TestStreamableGET(t *testing.T) { } } +// TestStreamableMcpHeaderValidation tests the server-side Mcp-Method and +// Mcp-Name header validation through the full HTTP handler, as specified +// in SEP-2243. +func TestStreamableMcpHeaderValidation(t *testing.T) { + // Temporarily register the future version so the handler accepts it. + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{Name: "my-tool", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + + initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: minVersionForStandardHeaders}) + initResp := resp(1, &InitializeResult{ + Capabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: minVersionForStandardHeaders, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + }, nil) + + initialize := streamableRequest{ + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: true, + } + initialized := streamableRequest{ + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {notificationInitialized}, + }, + messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, + wantStatusCode: http.StatusAccepted, + } + + testStreamableHandler(t, handler, []streamableRequest{ + initialize, + initialized, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"my-tool"}, + }, + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"prompts/get"}, + nameHeader: {"my-tool"}, + }, + messages: []jsonrpc.Message{req(3, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "Mcp-Method header value", + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"wrong-tool"}, + }, + messages: []jsonrpc.Message{req(4, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "Mcp-Name header value", + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"TOOLS/CALL"}, + nameHeader: {"my-tool"}, + }, + messages: []jsonrpc.Message{req(5, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "Mcp-Method header value", + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"my-tool"}, + }, + messages: []jsonrpc.Message{req(6, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, + }, + }) +} + +// TestStreamableMcpHeaderValidationErrorFormat verifies that header +// validation errors return a JSON-RPC error with code -32001 and +// Content-Type application/json, per SEP-2243. +func TestStreamableMcpHeaderValidationErrorFormat(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{Name: "my-tool", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + // Use the MCP client with a custom RoundTripper to inject a bad header. + var toolCallResp *http.Response + var toolCallRespBody []byte + + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + var originalMethodHeader string + if req.Header.Get(methodHeader) == "tools/call" { + originalMethodHeader = req.Header.Get(methodHeader) + req.Header.Set(methodHeader, "wrong-method") + } + resp, err := http.DefaultTransport.RoundTrip(req) + if err == nil && originalMethodHeader == "tools/call" { + toolCallResp = resp + toolCallRespBody, _ = io.ReadAll(resp.Body) + } + return resp, err + }), + } + + clientTransport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + HTTPClient: customClient, + } + + client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + ctx := context.Background() + session, err := client.Connect(ctx, clientTransport, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) + if err != nil { + t.Fatal(err) + } + defer session.Close() + + _, err = session.CallTool(ctx, &CallToolParams{Name: "my-tool"}) + // We expect an error because the server should reject it. + if err == nil { + t.Error("CallTool succeeded unexpectedly") + } + + if toolCallResp == nil { + t.Fatal("no response captured") + } + + // Verify HTTP status code. + if toolCallResp.StatusCode != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", toolCallResp.StatusCode, http.StatusBadRequest) + } + + // Verify Content-Type. + if baseMediaType(toolCallResp.Header.Get("Content-Type")) != "application/json" { + t.Errorf("Content-Type = %q, want %q", baseMediaType(toolCallResp.Header.Get("Content-Type")), "application/json") + } + + // Verify JSON-RPC error body contains error code -32001. + msg, err := jsonrpc2.DecodeMessage(toolCallRespBody) + if err != nil { + t.Fatalf("failed to decode message: %v", err) + } + resp, ok := msg.(*jsonrpc2.Response) + if !ok { + t.Fatalf("expected *jsonrpc2.Response, got %T", msg) + } + var wireErr *jsonrpc2.WireError + if !errors.As(resp.Error, &wireErr) { + t.Fatalf("expected *jsonrpc2.WireError, got %T", resp.Error) + } + if wireErr.Code != -32001 { + t.Errorf("wireErr.Code = %d, want -32001", wireErr.Code) + } + if !strings.Contains(wireErr.Message, "Mcp-Method header value") { + t.Errorf("wireErr.Message = %q, want it to contain %q", wireErr.Message, "Mcp-Method header value") + } +} + +// TestStreamableMcpHeaderVersionGating verifies that header validation +// is skipped for protocol versions older than minVersionForStandardHeaders. +func TestStreamableMcpHeaderVersionGating(t *testing.T) { + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{Name: "my-tool", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + + initReq := req(1, methodInitialize, &InitializeParams{ProtocolVersion: protocolVersion20251125}) + initResp := resp(1, &InitializeResult{ + Capabilities: &ServerCapabilities{ + Logging: &LoggingCapabilities{}, + Tools: &ToolCapabilities{ListChanged: true}, + }, + ProtocolVersion: protocolVersion20251125, + ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"}, + }, nil) + + testStreamableHandler(t, handler, []streamableRequest{ + { + method: "POST", + messages: []jsonrpc.Message{initReq}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{initResp}, + wantSessionID: true, + }, + { + method: "POST", + headers: http.Header{protocolVersionHeader: {protocolVersion20251125}}, + messages: []jsonrpc.Message{req(0, notificationInitialized, &InitializedParams{})}, + wantStatusCode: http.StatusAccepted, + }, + // Requests with deliberately wrong MCP headers should still succeed + // because the protocol version is too old for validation. + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {protocolVersion20251125}, + methodHeader: {"wrong-method"}, + nameHeader: {"wrong-name"}, + }, + messages: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "my-tool"})}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(2, &CallToolResult{Content: []Content{}}, nil)}, + }, + }) +} + // TestStreamable405AllowHeader verifies RFC 9110 ยง15.5.6 compliance: // 405 Method Not Allowed responses MUST include an Allow header. func TestStreamable405AllowHeader(t *testing.T) { @@ -2082,7 +2337,7 @@ func TestStreamableClientContextPropagation(t *testing.T) { switch req.Method { case "POST": w.Header().Set("Content-Type", "application/json") - w.Header().Set("Mcp-Session-Id", "test-session") + w.Header().Set(sessionIDHeader, "test-session") w.WriteHeader(http.StatusOK) w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2025-03-26","capabilities":{},"serverInfo":{"name":"test","version":"1.0"}}}`)) case "GET":