Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ const (
// It is the version that the client sends in the initialization request, and
// the default version used by the server.
latestProtocolVersion = protocolVersion20251125
protocolVersion20260630 = "2026-06-30"
protocolVersion20251125 = "2025-11-25"
protocolVersion20250618 = "2025-06-18"
protocolVersion20250326 = "2025-03-26"
Expand Down Expand Up @@ -343,6 +344,9 @@ func clientSessionMethod[P Params, R Result](f func(*ClientSession, context.Cont

// MCP-specific error codes.
const (
// CodeHeaderMismatch indicates that HTTP headers do not match the corresponding values
// in the request body, or that required headers are missing or malformed.
CodeHeaderMismatch = -32001
// CodeResourceNotFound indicates that a requested resource could not be found.
CodeResourceNotFound = -32002
// CodeURLElicitationRequired indicates that the server requires URL elicitation
Expand Down
28 changes: 22 additions & 6 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ import (
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)

const (
protocolVersionHeader = "Mcp-Protocol-Version"
sessionIDHeader = "Mcp-Session-Id"
lastEventIDHeader = "Last-Event-ID"
)

// A StreamableHTTPHandler is an http.Handler that serves streamable MCP
// sessions, as defined by the [MCP spec].
//
Expand Down Expand Up @@ -1191,6 +1185,24 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
}
}

// Validate MCP standard headers (Mcp-Method, Mcp-Name)
if !isBatch && len(incoming) == 1 {
Comment thread
guglielmo-san marked this conversation as resolved.
if err := validateMcpHeaders(req.Header, incoming[0]); err != nil {
resp := &jsonrpc.Response{
Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()),
}
if jreq, ok := incoming[0].(*jsonrpc.Request); ok {
resp.ID = jreq.ID
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
if data, err := jsonrpc2.EncodeMessage(resp); err == nil {
w.Write(data)
}
return
}
}

// The prime and close events were added in protocol version 2025-11-25 (SEP-1699).
// Use the version from InitializeParams if this is an initialize request,
// otherwise use the protocol version header.
Expand Down Expand Up @@ -1797,6 +1809,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
// and permanently break the connection.
return nil, nil, fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err)
}
// Keep this after the setMCPHeaders call to ensure that the
// protocol version header is set.
setStandardHeaders(req.Header, msg)
Comment thread
maciej-kisiel marked this conversation as resolved.
resp, err := c.client.Do(req)
if err != nil {
// Any error from client.Do means the request didn't reach the server.
Expand Down Expand Up @@ -1932,6 +1947,7 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error {
if c.sessionID != "" {
req.Header.Set(sessionIDHeader, c.sessionID)
}

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
key := streamableRequestKey{
httpMethod: req.Method,
sessionID: req.Header.Get(sessionIDHeader),
lastEventID: req.Header.Get("Last-Event-ID"), // TODO: extract this to a constant, like sessionIDHeader
lastEventID: req.Header.Get(lastEventIDHeader),
}
var jsonrpcReq *jsonrpc.Request
if req.Method == http.MethodPost {
Expand Down
98 changes: 98 additions & 0 deletions mcp/streamable_headers.go
Comment thread
guglielmo-san marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by the license
// that can be found in the LICENSE file.

package mcp

import (
"encoding/json"
"errors"
"fmt"
"net/http"

internaljson "github.com/modelcontextprotocol/go-sdk/internal/json"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)

const (
protocolVersionHeader = "Mcp-Protocol-Version"
sessionIDHeader = "Mcp-Session-Id"
lastEventIDHeader = "Last-Event-ID"
methodHeader = "Mcp-Method"
nameHeader = "Mcp-Name"
minVersionForStandardHeaders = protocolVersion20260630
)

func extractName(method string, params json.RawMessage) (string, bool) {
switch method {
case "tools/call":
var p CallToolParams
if err := internaljson.Unmarshal(params, &p); err == nil {
return p.Name, true
}
case "prompts/get":
var p GetPromptParams
if err := internaljson.Unmarshal(params, &p); err == nil {
return p.Name, true
}
case "resources/read":
var p ReadResourceParams
if err := internaljson.Unmarshal(params, &p); err == nil {
return p.URI, true
}
}

return "", false
}

// setStandardHeaders populates standard MCP headers.
// It requires the protocol version header to be set.
func setStandardHeaders(header http.Header, msg jsonrpc.Message) {
Comment thread
guglielmo-san marked this conversation as resolved.
if msg == nil {
return
}
if header.Get(protocolVersionHeader) == "" || header.Get(protocolVersionHeader) < minVersionForStandardHeaders {
return
}

switch msg := msg.(type) {
case *jsonrpc.Request:
header.Set(methodHeader, msg.Method)
if name, ok := extractName(msg.Method, msg.Params); ok {
header.Set(nameHeader, name)
}
}
}

func validateMcpHeaders(header http.Header, msg jsonrpc.Message) error {
protocolVersion := header.Get(protocolVersionHeader)
if protocolVersion == "" || protocolVersion < minVersionForStandardHeaders {
return nil
}

switch msg := msg.(type) {
case *jsonrpc.Request:
methodInHeader := header.Get(methodHeader)
if methodInHeader == "" {
return errors.New("missing required Mcp-Method header")
}
if methodInHeader != msg.Method {
return fmt.Errorf("header mismatch: Mcp-Method header value '%s' does not match body value '%s'", methodInHeader, msg.Method)
}

if msg.Method == "tools/call" || msg.Method == "resources/read" || msg.Method == "prompts/get" {
nameInHeader := header.Get(nameHeader)
if nameInHeader == "" {
return fmt.Errorf("missing required Mcp-Name header for method %q", msg.Method)
}
nameInBody, ok := extractName(msg.Method, msg.Params)
if !ok {
return fmt.Errorf("failed to extract name from parameters for method %q", msg.Method)
}
if nameInHeader != nameInBody {
return fmt.Errorf("header mismatch: Mcp-Name header value '%s' does not match body value '%s'", nameInHeader, nameInBody)
}
}
}
return nil
}
Loading
Loading