diff --git a/mcp/streamable.go b/mcp/streamable.go index a176789e..95d9e237 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -30,6 +30,8 @@ import ( "sync/atomic" "time" + "golang.org/x/oauth2" + "github.com/modelcontextprotocol/go-sdk/auth" internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" @@ -1793,53 +1795,24 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e req.Header.Set("Accept", "application/json, text/event-stream") if err := c.setMCPHeaders(req); err != nil { // Failure to set headers means that the request was not sent. - // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr - // and permanently break the connection. - return nil, nil, fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) + // Return the request anyway so that the caller (doWithAuth) can use it + // to trigger the OAuth Authorize flow if the failure was due to an + // expired token. + return req, nil, err } resp, err := c.client.Do(req) - if err != nil { - // Any error from client.Do means the request didn't reach the server. - // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr - // and permanently break the connection. - err = fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) - } + // Any error from client.Do means the request didn't reach the server. + // The caller handles wrapping the error with ErrRejected to ensure the + // jsonrpc2 connection isn't permanently broken. return req, resp, err } - req, resp, err := doRequest() + _, resp, err := c.doWithAuth(ctx, requestSummary, doRequest) if err != nil { - return err - } - - if (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) && c.oauthHandler != nil { - if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { - // If the caller's context was cancelled while we were running the - // authorization flow, treat the connection as failed so subsequent - // operations on it (e.g. the cancellation notify the call layer - // sends in response to ctx cancellation) short-circuit instead of - // re-invoking the OAuth handler. Otherwise the user gets prompted - // to authorize a request they have already abandoned. See #882. - // - // We check ctx.Err() rather than the error returned by Authorize, - // because the handler is user-implemented and may return an error - // that does not wrap context.Canceled (e.g. a custom sentinel or - // a fmt.Errorf with %v). The context itself is the authoritative - // source for whether the caller abandoned the request. - ctxErr := ctx.Err() - if errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) { - c.fail(fmt.Errorf("%s: authorization cancelled: %w", requestSummary, err)) - } - // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr - // and permanently break the connection. - // Wrap the authorization error as well for client inspection. - return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) - } - // Retry the request after successful authorization. - _, resp, err = doRequest() - if err != nil { - return err - } + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + // Wrap the authorization error as well for client inspection. + return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) } if err := c.checkResponse(requestSummary, resp); err != nil { @@ -2181,18 +2154,23 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin return nil, ctx.Err() case <-time.After(delay): - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) - if err != nil { - return nil, err - } - if err := c.setMCPHeaders(req); err != nil { - return nil, err - } - if lastEventID != "" { - req.Header.Set(lastEventIDHeader, lastEventID) + doRequest := func() (*http.Request, *http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) + if err != nil { + return nil, nil, err + } + if err := c.setMCPHeaders(req); err != nil { + return req, nil, err + } + if lastEventID != "" { + req.Header.Set(lastEventIDHeader, lastEventID) + } + req.Header.Set("Accept", "text/event-stream") + resp, err := c.client.Do(req) + return req, resp, err } - req.Header.Set("Accept", "text/event-stream") - resp, err := c.client.Do(req) + + _, resp, err := c.doWithAuth(ctx, "standalone SSE stream", doRequest) if err != nil { finalErr = err // Store the error and try again. delay = calculateReconnectDelay(attempt + 1) @@ -2208,6 +2186,51 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin return nil, fmt.Errorf("connection aborted after %d attempts", c.maxRetries) } +// doWithAuth executes an HTTP request, automatically handling OAuth2 token retrieval errors +// and 401/403 HTTP status codes by triggering the OAuthHandler's Authorize flow and retrying. +// +// doRequest should construct and send the HTTP request, and return the sent request (which +// may be needed for authorization), the response (if any), and any error. +func (c *streamableClientConn) doWithAuth(ctx context.Context, requestSummary string, doRequest func() (*http.Request, *http.Response, error)) (*http.Request, *http.Response, error) { + req, resp, err := doRequest() + + if c.oauthHandler == nil || req == nil { + return req, resp, err + } + + var authResp *http.Response + var retrieveErr *oauth2.RetrieveError + if err != nil && errors.As(err, &retrieveErr) { + authResp = retrieveErr.Response + } else if err == nil && (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) { + authResp = resp + } else { + return req, resp, err + } + + if authErr := c.oauthHandler.Authorize(ctx, req, authResp); authErr != nil { + // If the caller's context was cancelled while we were running the + // authorization flow, treat the connection as failed so subsequent + // operations on it (e.g. the cancellation notify the call layer + // sends in response to ctx cancellation) short-circuit instead of + // re-invoking the OAuth handler. Otherwise the user gets prompted + // to authorize a request they have already abandoned. See #882. + // + // We check ctx.Err() rather than the error returned by Authorize, + // because the handler is user-implemented and may return an error + // that does not wrap context.Canceled (e.g. a custom sentinel or + // a fmt.Errorf with %v). The context itself is the authoritative + // source for whether the caller abandoned the request. + ctxErr := ctx.Err() + if errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) { + c.fail(fmt.Errorf("%s: authorization cancelled: %w", requestSummary, authErr)) + } + return req, nil, authErr + } + // Retry + return doRequest() +} + // Close implements the [Connection] interface. func (c *streamableClientConn) Close() error { c.closeOnce.Do(func() { diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 9aac7b54..4ecfccbd 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -6,6 +6,7 @@ package mcp import ( "context" + "errors" "fmt" "io" "net/http" @@ -1156,3 +1157,65 @@ func TestTokenInfo(t *testing.T) { t.Errorf("got %q, want %q", g, w) } } + +// errTestAuthorizeFailed is a sentinel error returned by +// retrieveErrorOAuthHandler.Authorize(). +var errTestAuthorizeFailed = errors.New("authorize intentionally failed for test") + +// retrieveErrorOAuthHandler is a mock OAuthHandler that always returns +// an oauth2.RetrieveError from its TokenSource's Token() method. +type retrieveErrorOAuthHandler struct { + authorizeCalled bool +} + +func (h *retrieveErrorOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return h, nil +} + +func (h *retrieveErrorOAuthHandler) Token() (*oauth2.Token, error) { + return nil, &oauth2.RetrieveError{ + Response: &http.Response{StatusCode: http.StatusBadRequest}, + Body: []byte("test retrieve error"), + } +} + +func (h *retrieveErrorOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + h.authorizeCalled = true + return errTestAuthorizeFailed +} + +// TestStreamableClientOAuth_RetrieveError verifies that a RetrieveError from +// the OAuth token source correctly triggers the Authorize fallback flow instead +// of immediately breaking the connection. +func TestStreamableClientOAuth_RetrieveError(t *testing.T) { + ctx := context.Background() + oauthHandler := &retrieveErrorOAuthHandler{} + + // Setup a dummy HTTP server. The server won't actually be reached because + // the token retrieval fails before the request is dispatched. + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(httpServer.Close) + + // Configure the transport with our mock OAuth handler that returns a RetrieveError. + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: oauthHandler, + } + client := NewClient(testImpl, nil) + + // Attempt to connect. The Connect call will trigger the initialization request, + // which will fail to retrieve the token and instead invoke Authorize(). + _, err := client.Connect(ctx, transport, nil) + + // Expect the connection to fail with a sentinel error. + if !errors.Is(err, errTestAuthorizeFailed) { + t.Fatalf("client.Connect() error = %v, want %v", err, errTestAuthorizeFailed) + } + + // Double-check that Authorize() was actually invoked. + if !oauthHandler.authorizeCalled { + t.Errorf("expected Authorize to be called") + } +}