Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 75 additions & 52 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import (
"sync/atomic"
"time"

"golang.org/x/oauth2"

"github.com/modelcontextprotocol/go-sdk/auth"
internaljson "github.com/modelcontextprotocol/go-sdk/internal/json"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
Expand Down Expand Up @@ -1793,53 +1795,24 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
req.Header.Set("Accept", "application/json, text/event-stream")
if err := c.setMCPHeaders(req); err != nil {
// Failure to set headers means that the request was not sent.
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
// and permanently break the connection.
return nil, nil, fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err)
// Return the request anyway so that the caller (doWithAuth) can use it
// to trigger the OAuth Authorize flow if the failure was due to an
// expired token.
return req, nil, err
}
resp, err := c.client.Do(req)
if err != nil {
// Any error from client.Do means the request didn't reach the server.
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
// and permanently break the connection.
err = fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err)
}
// Any error from client.Do means the request didn't reach the server.
// The caller handles wrapping the error with ErrRejected to ensure the
// jsonrpc2 connection isn't permanently broken.
return req, resp, err
}

req, resp, err := doRequest()
_, resp, err := c.doWithAuth(ctx, requestSummary, doRequest)
if err != nil {
return err
}

if (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) && c.oauthHandler != nil {
if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil {
// If the caller's context was cancelled while we were running the
// authorization flow, treat the connection as failed so subsequent
// operations on it (e.g. the cancellation notify the call layer
// sends in response to ctx cancellation) short-circuit instead of
// re-invoking the OAuth handler. Otherwise the user gets prompted
// to authorize a request they have already abandoned. See #882.
//
// We check ctx.Err() rather than the error returned by Authorize,
// because the handler is user-implemented and may return an error
// that does not wrap context.Canceled (e.g. a custom sentinel or
// a fmt.Errorf with %v). The context itself is the authoritative
// source for whether the caller abandoned the request.
ctxErr := ctx.Err()
if errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
c.fail(fmt.Errorf("%s: authorization cancelled: %w", requestSummary, err))
}
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
// and permanently break the connection.
// Wrap the authorization error as well for client inspection.
return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err)
}
// Retry the request after successful authorization.
_, resp, err = doRequest()
if err != nil {
return err
}
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
// and permanently break the connection.
// Wrap the authorization error as well for client inspection.
return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err)
}

if err := c.checkResponse(requestSummary, resp); err != nil {
Expand Down Expand Up @@ -2181,18 +2154,23 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin
return nil, ctx.Err()

case <-time.After(delay):
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil)
if err != nil {
return nil, err
}
if err := c.setMCPHeaders(req); err != nil {
return nil, err
}
if lastEventID != "" {
req.Header.Set(lastEventIDHeader, lastEventID)
doRequest := func() (*http.Request, *http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil)
if err != nil {
return nil, nil, err
}
if err := c.setMCPHeaders(req); err != nil {
return req, nil, err
}
if lastEventID != "" {
req.Header.Set(lastEventIDHeader, lastEventID)
}
req.Header.Set("Accept", "text/event-stream")
resp, err := c.client.Do(req)
return req, resp, err
}
req.Header.Set("Accept", "text/event-stream")
resp, err := c.client.Do(req)

_, resp, err := c.doWithAuth(ctx, "standalone SSE stream", doRequest)
if err != nil {
finalErr = err // Store the error and try again.
delay = calculateReconnectDelay(attempt + 1)
Expand All @@ -2208,6 +2186,51 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin
return nil, fmt.Errorf("connection aborted after %d attempts", c.maxRetries)
}

// doWithAuth executes an HTTP request, automatically handling OAuth2 token retrieval errors
// and 401/403 HTTP status codes by triggering the OAuthHandler's Authorize flow and retrying.
//
// doRequest should construct and send the HTTP request, and return the sent request (which
// may be needed for authorization), the response (if any), and any error.
func (c *streamableClientConn) doWithAuth(ctx context.Context, requestSummary string, doRequest func() (*http.Request, *http.Response, error)) (*http.Request, *http.Response, error) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: let's make doRequest take context.Context to avoid unintentional wrong context capture

req, resp, err := doRequest()

if c.oauthHandler == nil || req == nil {
return req, resp, err
}

var authResp *http.Response
var retrieveErr *oauth2.RetrieveError
if err != nil && errors.As(err, &retrieveErr) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if the check is too broad. Should we only be re-authorizing if retriveErr.ErrorCode == "invalid_grant"?

The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client.

https://datatracker.ietf.org/doc/html/rfc6749#section-5.2

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think you are right. I'll fix this up.

authResp = retrieveErr.Response
} else if err == nil && (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) {
authResp = resp
} else {
return req, resp, err
}

if authErr := c.oauthHandler.Authorize(ctx, req, authResp); authErr != nil {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looking into Authorize it relies on authResp to have WWW-Authenticate set. if fallbacks to well-known won't work the whole Authorize will fail.

might be a bit hacky, but what if we handle oauth2.RetrieveError earlier by just not setting "Authorization" header and letting request fail with 401/403?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looking into Authorize it relies on authResp to have WWW-Authenticate set. if fallbacks to well-known won't work the whole Authorize will fail.

I see what you are saying, but how likely/possible is this? Seems to defeat the purpose of being "well-known" 😕

might be a bit hacky, but what if we handle oauth2.RetrieveError earlier by just not setting "Authorization" header and letting request fail with 401/403?

That is an interesting idea. In fact, it might be a better solution than this PR. I'll look into it.

Copy link
Copy Markdown
Member

@yarolegovich yarolegovich Apr 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you are saying, but how likely/possible is this? Seems to defeat the purpose of being "well-known"

I don't know, but it is possible. The spec says:

MCP servers MUST implement one of the following discovery mechanisms to provide authorization server location information to MCP clients:

  • WWW-Authenticate Header: ...
  • Well-Known URI: ...

So the failed empty-auth request approach should be more reliable. It's also a bit easier to reason about in the sense that Authorize only receives MCP (not auth) server response.

@maciej-kisiel and I also considered how oauth.Config which token refresher already has can be passed to Authorize for skipping the discovery step in re-authorization, but all the options seem not worth it (in terms of code complexity) if a single failed request can handle this edge case.

// If the caller's context was cancelled while we were running the
// authorization flow, treat the connection as failed so subsequent
// operations on it (e.g. the cancellation notify the call layer
// sends in response to ctx cancellation) short-circuit instead of
// re-invoking the OAuth handler. Otherwise the user gets prompted
// to authorize a request they have already abandoned. See #882.
//
// We check ctx.Err() rather than the error returned by Authorize,
// because the handler is user-implemented and may return an error
// that does not wrap context.Canceled (e.g. a custom sentinel or
// a fmt.Errorf with %v). The context itself is the authoritative
// source for whether the caller abandoned the request.
ctxErr := ctx.Err()
if errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) {
c.fail(fmt.Errorf("%s: authorization cancelled: %w", requestSummary, authErr))
}
return req, nil, authErr
}
// Retry
return doRequest()
}

// Close implements the [Connection] interface.
func (c *streamableClientConn) Close() error {
c.closeOnce.Do(func() {
Expand Down
63 changes: 63 additions & 0 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package mcp

import (
"context"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -1156,3 +1157,65 @@ func TestTokenInfo(t *testing.T) {
t.Errorf("got %q, want %q", g, w)
}
}

// errTestAuthorizeFailed is a sentinel error returned by
// retrieveErrorOAuthHandler.Authorize().
var errTestAuthorizeFailed = errors.New("authorize intentionally failed for test")

// retrieveErrorOAuthHandler is a mock OAuthHandler that always returns
// an oauth2.RetrieveError from its TokenSource's Token() method.
type retrieveErrorOAuthHandler struct {
authorizeCalled bool
}

func (h *retrieveErrorOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) {
return h, nil
}

func (h *retrieveErrorOAuthHandler) Token() (*oauth2.Token, error) {
return nil, &oauth2.RetrieveError{
Response: &http.Response{StatusCode: http.StatusBadRequest},
Body: []byte("test retrieve error"),
}
}

func (h *retrieveErrorOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error {
h.authorizeCalled = true
return errTestAuthorizeFailed
}

// TestStreamableClientOAuth_RetrieveError verifies that a RetrieveError from
// the OAuth token source correctly triggers the Authorize fallback flow instead
// of immediately breaking the connection.
func TestStreamableClientOAuth_RetrieveError(t *testing.T) {
ctx := context.Background()
oauthHandler := &retrieveErrorOAuthHandler{}

// Setup a dummy HTTP server. The server won't actually be reached because
// the token retrieval fails before the request is dispatched.
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(httpServer.Close)

// Configure the transport with our mock OAuth handler that returns a RetrieveError.
transport := &StreamableClientTransport{
Endpoint: httpServer.URL,
OAuthHandler: oauthHandler,
}
client := NewClient(testImpl, nil)

// Attempt to connect. The Connect call will trigger the initialization request,
// which will fail to retrieve the token and instead invoke Authorize().
_, err := client.Connect(ctx, transport, nil)

// Expect the connection to fail with a sentinel error.
if !errors.Is(err, errTestAuthorizeFailed) {
t.Fatalf("client.Connect() error = %v, want %v", err, errTestAuthorizeFailed)
}

// Double-check that Authorize() was actually invoked.
if !oauthHandler.authorizeCalled {
t.Errorf("expected Authorize to be called")
}
}
Loading