-
Notifications
You must be signed in to change notification settings - Fork 417
mcp: handle oauth2.RetrieveError in client authorization retry logic #909
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,8 @@ import ( | |
| "sync/atomic" | ||
| "time" | ||
|
|
||
| "golang.org/x/oauth2" | ||
|
|
||
| "github.com/modelcontextprotocol/go-sdk/auth" | ||
| internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" | ||
| "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" | ||
|
|
@@ -1793,53 +1795,24 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e | |
| req.Header.Set("Accept", "application/json, text/event-stream") | ||
| if err := c.setMCPHeaders(req); err != nil { | ||
| // Failure to set headers means that the request was not sent. | ||
| // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr | ||
| // and permanently break the connection. | ||
| return nil, nil, fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) | ||
| // Return the request anyway so that the caller (doWithAuth) can use it | ||
| // to trigger the OAuth Authorize flow if the failure was due to an | ||
| // expired token. | ||
| return req, nil, err | ||
| } | ||
| resp, err := c.client.Do(req) | ||
| if err != nil { | ||
| // Any error from client.Do means the request didn't reach the server. | ||
| // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr | ||
| // and permanently break the connection. | ||
| err = fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) | ||
| } | ||
| // Any error from client.Do means the request didn't reach the server. | ||
| // The caller handles wrapping the error with ErrRejected to ensure the | ||
| // jsonrpc2 connection isn't permanently broken. | ||
| return req, resp, err | ||
| } | ||
|
|
||
| req, resp, err := doRequest() | ||
| _, resp, err := c.doWithAuth(ctx, requestSummary, doRequest) | ||
| if err != nil { | ||
| return err | ||
| } | ||
|
|
||
| if (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) && c.oauthHandler != nil { | ||
| if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { | ||
| // If the caller's context was cancelled while we were running the | ||
| // authorization flow, treat the connection as failed so subsequent | ||
| // operations on it (e.g. the cancellation notify the call layer | ||
| // sends in response to ctx cancellation) short-circuit instead of | ||
| // re-invoking the OAuth handler. Otherwise the user gets prompted | ||
| // to authorize a request they have already abandoned. See #882. | ||
| // | ||
| // We check ctx.Err() rather than the error returned by Authorize, | ||
| // because the handler is user-implemented and may return an error | ||
| // that does not wrap context.Canceled (e.g. a custom sentinel or | ||
| // a fmt.Errorf with %v). The context itself is the authoritative | ||
| // source for whether the caller abandoned the request. | ||
| ctxErr := ctx.Err() | ||
| if errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) { | ||
| c.fail(fmt.Errorf("%s: authorization cancelled: %w", requestSummary, err)) | ||
| } | ||
| // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr | ||
| // and permanently break the connection. | ||
| // Wrap the authorization error as well for client inspection. | ||
| return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) | ||
| } | ||
| // Retry the request after successful authorization. | ||
| _, resp, err = doRequest() | ||
| if err != nil { | ||
| return err | ||
| } | ||
| // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr | ||
| // and permanently break the connection. | ||
| // Wrap the authorization error as well for client inspection. | ||
| return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) | ||
| } | ||
|
|
||
| if err := c.checkResponse(requestSummary, resp); err != nil { | ||
|
|
@@ -2181,18 +2154,23 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin | |
| return nil, ctx.Err() | ||
|
|
||
| case <-time.After(delay): | ||
| req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| if err := c.setMCPHeaders(req); err != nil { | ||
| return nil, err | ||
| } | ||
| if lastEventID != "" { | ||
| req.Header.Set(lastEventIDHeader, lastEventID) | ||
| doRequest := func() (*http.Request, *http.Response, error) { | ||
| req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.url, nil) | ||
| if err != nil { | ||
| return nil, nil, err | ||
| } | ||
| if err := c.setMCPHeaders(req); err != nil { | ||
| return req, nil, err | ||
| } | ||
| if lastEventID != "" { | ||
| req.Header.Set(lastEventIDHeader, lastEventID) | ||
| } | ||
| req.Header.Set("Accept", "text/event-stream") | ||
| resp, err := c.client.Do(req) | ||
| return req, resp, err | ||
| } | ||
| req.Header.Set("Accept", "text/event-stream") | ||
| resp, err := c.client.Do(req) | ||
|
|
||
| _, resp, err := c.doWithAuth(ctx, "standalone SSE stream", doRequest) | ||
| if err != nil { | ||
| finalErr = err // Store the error and try again. | ||
| delay = calculateReconnectDelay(attempt + 1) | ||
|
|
@@ -2208,6 +2186,51 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin | |
| return nil, fmt.Errorf("connection aborted after %d attempts", c.maxRetries) | ||
| } | ||
|
|
||
| // doWithAuth executes an HTTP request, automatically handling OAuth2 token retrieval errors | ||
| // and 401/403 HTTP status codes by triggering the OAuthHandler's Authorize flow and retrying. | ||
| // | ||
| // doRequest should construct and send the HTTP request, and return the sent request (which | ||
| // may be needed for authorization), the response (if any), and any error. | ||
| func (c *streamableClientConn) doWithAuth(ctx context.Context, requestSummary string, doRequest func() (*http.Request, *http.Response, error)) (*http.Request, *http.Response, error) { | ||
| req, resp, err := doRequest() | ||
|
|
||
| if c.oauthHandler == nil || req == nil { | ||
| return req, resp, err | ||
| } | ||
|
|
||
| var authResp *http.Response | ||
| var retrieveErr *oauth2.RetrieveError | ||
| if err != nil && errors.As(err, &retrieveErr) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering if the check is too broad. Should we only be re-authorizing if
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes I think you are right. I'll fix this up. |
||
| authResp = retrieveErr.Response | ||
| } else if err == nil && (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) { | ||
| authResp = resp | ||
| } else { | ||
| return req, resp, err | ||
| } | ||
|
|
||
| if authErr := c.oauthHandler.Authorize(ctx, req, authResp); authErr != nil { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looking into Authorize it relies on authResp to have might be a bit hacky, but what if we handle oauth2.RetrieveError earlier by just not setting "Authorization" header and letting request fail with 401/403?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I see what you are saying, but how likely/possible is this? Seems to defeat the purpose of being "well-known" 😕
That is an interesting idea. In fact, it might be a better solution than this PR. I'll look into it.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't know, but it is possible. The spec says:
So the failed empty-auth request approach should be more reliable. It's also a bit easier to reason about in the sense that @maciej-kisiel and I also considered how |
||
| // If the caller's context was cancelled while we were running the | ||
| // authorization flow, treat the connection as failed so subsequent | ||
| // operations on it (e.g. the cancellation notify the call layer | ||
| // sends in response to ctx cancellation) short-circuit instead of | ||
| // re-invoking the OAuth handler. Otherwise the user gets prompted | ||
| // to authorize a request they have already abandoned. See #882. | ||
| // | ||
| // We check ctx.Err() rather than the error returned by Authorize, | ||
| // because the handler is user-implemented and may return an error | ||
| // that does not wrap context.Canceled (e.g. a custom sentinel or | ||
| // a fmt.Errorf with %v). The context itself is the authoritative | ||
| // source for whether the caller abandoned the request. | ||
| ctxErr := ctx.Err() | ||
| if errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) { | ||
| c.fail(fmt.Errorf("%s: authorization cancelled: %w", requestSummary, authErr)) | ||
| } | ||
| return req, nil, authErr | ||
| } | ||
| // Retry | ||
| return doRequest() | ||
| } | ||
|
|
||
| // Close implements the [Connection] interface. | ||
| func (c *streamableClientConn) Close() error { | ||
| c.closeOnce.Do(func() { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: let's make doRequest take
context.Contextto avoid unintentional wrong context capture