Skip to content

Commit b85574f

Browse files
mcp: Enable legacy initialize fallback on any error (#1014)
To guarantee a better compatibility with legacy server, enable the fallback to the legacy initialize rpc in case of any error returned by the server, in case of `server/discover` request.
1 parent 6333aa7 commit b85574f

5 files changed

Lines changed: 84 additions & 148 deletions

File tree

internal/jsonrpc2/conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ func (c *Connection) write(ctx context.Context, msg Message) error {
735735

736736
// For cancelled or rejected requests, we don't set the writeErr (which would
737737
// break the connection). They can just be returned to the caller.
738-
if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) && !errors.Is(err, ErrUnsupportedProtocolVersion) && !errors.Is(err, ErrMethodNotFound) {
738+
if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) {
739739
// The call to Write failed, and since ctx.Err() is nil we can't attribute
740740
// the failure (even indirectly) to Context cancellation. The writer appears
741741
// to be broken, and future writes are likely to also fail.

mcp/client.go

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -303,13 +303,10 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp
303303
return cs, nil
304304
}
305305

306-
var werr *jsonrpc.Error
307-
if !errors.As(err, &werr) {
308-
return nil, err
309-
}
310306
// Try to negotiate a mutually supported version if the server
311307
// reports an UnsupportedProtocolVersionError with a supported version.
312-
if werr.Code == CodeUnsupportedProtocolVersion && werr.Data != nil {
308+
var werr *jsonrpc.Error
309+
if errors.As(err, &werr) && werr.Code == CodeUnsupportedProtocolVersion && werr.Data != nil {
313310
var data UnsupportedProtocolVersionData
314311
if err := json.Unmarshal(werr.Data, &data); err == nil {
315312
if negotiatedVersion := negotiateMutuallySupportedVersion(data.Supported); negotiatedVersion != "" && negotiatedVersion >= protocolVersion20260630 {
@@ -318,13 +315,11 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp
318315
}
319316
}
320317
}
321-
// MethodNotFound and UnsupportedProtocolVersion trigger a fallback to legacy initialize.
322-
if werr.Code == jsonrpc.CodeMethodNotFound || werr.Code == CodeUnsupportedProtocolVersion {
323-
break
324-
}
325-
return nil, err
318+
// Per the spec, fall back to the legacy initialize handshake on any
319+
// non-modern error from server/discover.
320+
break
326321
}
327-
// Fallback to the legacy initialize handshake with the legacy protocol version.
322+
// Use the latest legacy protocol version for the fallback initialize.
328323
protocolVersion = protocolVersion20251125
329324
}
330325

mcp/client_test.go

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,6 @@ func TestClientConnectDiscover(t *testing.T) {
657657
// Returning (nil, nil) means "let the default stub handle it" (which
658658
// returns ErrMethodNotFound).
659659
discoverHandler func() (Result, error)
660-
wantConnectErr bool
661660
// wantInitialize is true if the legacy initialize handshake should
662661
// have run (i.e. discover signaled "not supported").
663662
wantInitialize bool
@@ -710,16 +709,6 @@ func TestClientConnectDiscover(t *testing.T) {
710709
wantInitialize: true,
711710
wantVersion: latestProtocolVersion,
712711
},
713-
{
714-
name: "unexpected error propagates and aborts Connect",
715-
discoverHandler: func() (Result, error) {
716-
return nil, &jsonrpc.Error{
717-
Code: jsonrpc.CodeInternalError,
718-
Message: "boom",
719-
}
720-
},
721-
wantConnectErr: true,
722-
},
723712
}
724713

725714
for _, tc := range tests {
@@ -754,19 +743,6 @@ func TestClientConnectDiscover(t *testing.T) {
754743

755744
c := NewClient(testImpl, nil)
756745
cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630})
757-
if tc.wantConnectErr {
758-
if err == nil {
759-
_ = cs.Close()
760-
t.Fatal("Connect succeeded, want error")
761-
}
762-
if !gotDiscover.Load() {
763-
t.Error("server did not receive server/discover")
764-
}
765-
if gotInitialize.Load() {
766-
t.Error("server received initialize but discover should have aborted Connect")
767-
}
768-
return
769-
}
770746
if err != nil {
771747
t.Fatalf("Connect: %v", err)
772748
}
@@ -1061,51 +1037,6 @@ func TestInMemory_E2E_DiscoverFallback_UnsupportedProtocolVersion(t *testing.T)
10611037
}
10621038
}
10631039

1064-
// TestInMemory_E2E_DiscoverPropagatesOtherErrors verifies that an unrelated
1065-
// error from the discover handler aborts Connect and does NOT silently
1066-
// fall back.
1067-
func TestInMemory_E2E_DiscoverPropagatesOtherErrors(t *testing.T) {
1068-
ctx := context.Background()
1069-
1070-
orig := supportedProtocolVersions
1071-
supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...)
1072-
t.Cleanup(func() { supportedProtocolVersions = orig })
1073-
1074-
var sawInitialize atomic.Bool
1075-
server := NewServer(&Implementation{Name: "broken-server", Version: "v1"}, nil)
1076-
server.AddReceivingMiddleware(func(next MethodHandler) MethodHandler {
1077-
return func(ctx context.Context, method string, req Request) (Result, error) {
1078-
switch method {
1079-
case methodDiscover:
1080-
return nil, &jsonrpc.Error{
1081-
Code: jsonrpc.CodeInternalError,
1082-
Message: "boom",
1083-
}
1084-
case methodInitialize:
1085-
sawInitialize.Store(true)
1086-
}
1087-
return next(ctx, method, req)
1088-
}
1089-
})
1090-
1091-
ct, st := NewInMemoryTransports()
1092-
ss, err := server.Connect(ctx, st, nil)
1093-
if err != nil {
1094-
t.Fatalf("server.Connect: %v", err)
1095-
}
1096-
defer ss.Close()
1097-
1098-
client := NewClient(&Implementation{Name: "new-client", Version: "v1"}, nil)
1099-
cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630})
1100-
if err == nil {
1101-
_ = cs.Close()
1102-
t.Fatal("Connect succeeded; want propagated discover error")
1103-
}
1104-
if sawInitialize.Load() {
1105-
t.Error("server received initialize; Connect should have aborted on the discover error")
1106-
}
1107-
}
1108-
11091040
// TestClientConnectDiscover_UnsupportedVersionNegotiation verifies the
11101041
// per SEP-2575 Version Negotiation Flow: when the client probes server/discover
11111042
// with a protocol version the server doesn't implement, the server's

mcp/streamable.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2103,12 +2103,14 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
21032103
}
21042104

21052105
var requestSummary string
2106+
var requestMethod string
21062107
var forCall *jsonrpc.Request
21072108
switch msg := msg.(type) {
21082109
case *jsonrpc.Request:
21092110
requestSummary = fmt.Sprintf("sending %q", msg.Method)
21102111
if msg.IsCall() {
21112112
forCall = msg
2113+
requestMethod = msg.Method
21122114
}
21132115
case *jsonrpc.Response:
21142116
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
@@ -2184,10 +2186,14 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
21842186
}
21852187

21862188
if err := c.checkResponse(ctx, requestSummary, resp); err != nil {
2187-
// Only fail the connection for non-transient errors.
2188-
// Transient errors (wrapped with ErrRejected) should not break the connection.
2189-
// ErrMethodNotFound and ErrUnsupportedProtocolVersion should not break the connection as they trigger the initialize fallback.
2190-
if !errors.Is(err, jsonrpc2.ErrRejected) && !errors.Is(err, jsonrpc2.ErrMethodNotFound) && !errors.Is(err, jsonrpc2.ErrUnsupportedProtocolVersion) {
2189+
if requestMethod == methodDiscover {
2190+
// Wrap the discover failure with ErrRejected so the jsonrpc2 layer
2191+
// doesn't set writeErr, which would prevent the legacy initialize
2192+
// fallback from succeeding on the same connection.
2193+
err = fmt.Errorf("%w: %w", err, jsonrpc2.ErrRejected)
2194+
} else if !errors.Is(err, jsonrpc2.ErrRejected) {
2195+
// Only fail the connection for non-transient errors.
2196+
// Transient errors (wrapped with ErrRejected) should not break the connection.
21912197
c.fail(err)
21922198
}
21932199
return err
@@ -2400,10 +2406,8 @@ func (c *streamableClientConn) checkResponse(ctx context.Context, requestSummary
24002406
}
24012407
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
24022408
// By default, always try to decode the body and surface the underlying
2403-
// JSON-RPC error (or detect vPre servers that reject "server/discover"
2404-
// as an unsupported method with a plain HTTP 400) regardless of the
2405-
// negotiated protocol version. Setting MCPGODEBUG=noprotocolerrorbody=1
2406-
// restores the previous behavior.
2409+
// JSON-RPC error.
2410+
// Setting MCPGODEBUG=noprotocolerrorbody=1 restores the previous behavior.
24072411
if noprotocolerrorbody == "1" {
24082412
return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode))
24092413
}
@@ -2412,12 +2416,6 @@ func (c *streamableClientConn) checkResponse(ctx context.Context, requestSummary
24122416
if response, ok := msg.(*jsonrpc.Response); ok && response.Error != nil {
24132417
return fmt.Errorf("%s: %w: %v", requestSummary, response.Error, http.StatusText(resp.StatusCode))
24142418
}
2415-
if strings.Contains(string(body), fmt.Sprintf("%s: %q unsupported", jsonrpc2.ErrNotHandled, methodDiscover)) {
2416-
return fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrMethodNotFound, http.StatusText(resp.StatusCode))
2417-
}
2418-
if strings.Contains(string(body), "Unsupported protocol version") {
2419-
return fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrUnsupportedProtocolVersion, http.StatusText(resp.StatusCode))
2420-
}
24212419
return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode))
24222420
}
24232421
return nil

mcp/streamable_client_test.go

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,59 +1470,6 @@ func TestStreamableClientConnect_DiscoverUnsupportedVersion(t *testing.T) {
14701470
}
14711471
}
14721472

1473-
// TestStreamableClientConnect_DiscoverPropagatesOtherErrors verifies that
1474-
// Client.Connect does NOT fall back to initialize when server/discover
1475-
// returns an unrelated JSON-RPC error (here, CodeInternalError). The Connect
1476-
// call should fail with the propagated error rather than masking it.
1477-
func TestStreamableClientConnect_DiscoverPropagatesOtherErrors(t *testing.T) {
1478-
ctx := context.Background()
1479-
1480-
var sawInitialize atomic.Bool
1481-
fake := &fakeStreamableServer{
1482-
t: t,
1483-
responses: fakeResponses{
1484-
{"POST", "", methodDiscover, ""}: {
1485-
header: header{"Content-Type": "application/json"},
1486-
wantProtocolVersion: protocolVersion20260630,
1487-
responseFunc: func(r *jsonrpc.Request) (string, int) {
1488-
return jsonBody(t, &jsonrpc.Response{
1489-
ID: r.ID,
1490-
Error: &jsonrpc.Error{
1491-
Code: jsonrpc.CodeInternalError,
1492-
Message: "boom",
1493-
},
1494-
}), http.StatusOK
1495-
},
1496-
},
1497-
{"POST", "", methodInitialize, ""}: {
1498-
responseFunc: func(r *jsonrpc.Request) (string, int) {
1499-
sawInitialize.Store(true)
1500-
return jsonBody(t, &jsonrpc.Response{ID: r.ID, Result: mustMarshal(initResult)}), http.StatusOK
1501-
},
1502-
header: header{
1503-
"Content-Type": "application/json",
1504-
sessionIDHeader: "fallback",
1505-
},
1506-
optional: true,
1507-
},
1508-
},
1509-
}
1510-
1511-
httpServer := httptest.NewServer(fake)
1512-
defer httpServer.Close()
1513-
1514-
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
1515-
client := NewClient(testImpl, nil)
1516-
session, err := client.Connect(ctx, transport, &ClientSessionOptions{protocolVersion: protocolVersion20260630})
1517-
if err == nil {
1518-
_ = session.Close()
1519-
t.Fatal("Connect succeeded; want propagated error")
1520-
}
1521-
if sawInitialize.Load() {
1522-
t.Error("server received initialize; Connect should have aborted on the discover error")
1523-
}
1524-
}
1525-
15261473
// TestStreamableClientConnect_DiscoverMethodNotFoundVPre verifies that
15271474
// Client.Connect falls back to the legacy initialize handshake when a
15281475
// pre-SEP-2575 (vPre) server rejects server/discover.
@@ -1635,3 +1582,68 @@ func TestStreamableClientConnect_DiscoverUnsupportedVersionVPre(t *testing.T) {
16351582
t.Errorf("InitializeResult.ProtocolVersion = %q, want %q (initialize fallback)", got, latestProtocolVersion)
16361583
}
16371584
}
1585+
1586+
// TestStreamableClientConnect_DiscoverUnsupportedVersionNegotiation verifies that
1587+
// when Client.Connect over a streamable transport receives an
1588+
// UnsupportedProtocolVersion error containing Data.Supported, it negotiates a
1589+
// mutually supported version and retries server/discover.
1590+
func TestStreamableClientConnect_DiscoverUnsupportedVersionNegotiation(t *testing.T) {
1591+
ctx := context.Background()
1592+
1593+
oldSupported := supportedProtocolVersions
1594+
supportedProtocolVersions = append([]string{protocolVersion20260630}, supportedProtocolVersions...)
1595+
t.Cleanup(func() {
1596+
supportedProtocolVersions = oldSupported
1597+
})
1598+
1599+
const unsupportedClientVersion = "2099-12-31"
1600+
1601+
var discoverCalls atomic.Int32
1602+
1603+
fake := &fakeStreamableServer{
1604+
t: t,
1605+
responses: fakeResponses{
1606+
{"POST", "", methodDiscover, ""}: {
1607+
header: header{
1608+
"Content-Type": "application/json",
1609+
},
1610+
responseFunc: func(r *jsonrpc.Request) (string, int) {
1611+
n := discoverCalls.Add(1)
1612+
if n == 1 {
1613+
data, _ := json.Marshal(UnsupportedProtocolVersionData{
1614+
Supported: []string{protocolVersion20260630},
1615+
})
1616+
respMsg := &jsonrpc.Response{
1617+
ID: r.ID,
1618+
Error: &jsonrpc.Error{
1619+
Code: CodeUnsupportedProtocolVersion,
1620+
Message: "unsupported protocol version",
1621+
Data: data,
1622+
},
1623+
}
1624+
return jsonBody(t, respMsg), http.StatusOK
1625+
}
1626+
return jsonBody(t, &jsonrpc.Response{ID: r.ID, Result: mustMarshal(discoverResult)}), http.StatusOK
1627+
},
1628+
},
1629+
},
1630+
}
1631+
1632+
httpServer := httptest.NewServer(fake)
1633+
defer httpServer.Close()
1634+
1635+
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
1636+
client := NewClient(testImpl, nil)
1637+
session, err := client.Connect(ctx, transport, &ClientSessionOptions{protocolVersion: unsupportedClientVersion})
1638+
if err != nil {
1639+
t.Fatalf("Connect: %v", err)
1640+
}
1641+
defer session.Close()
1642+
1643+
if got, want := discoverCalls.Load(), int32(2); got != want {
1644+
t.Errorf("discover call count = %d, want %d", got, want)
1645+
}
1646+
if got := session.InitializeResult().ProtocolVersion; got != protocolVersion20260630 {
1647+
t.Errorf("InitializeResult.ProtocolVersion = %q, want %q", got, protocolVersion20260630)
1648+
}
1649+
}

0 commit comments

Comments
 (0)