diff --git a/pkg/tools/mcp/oauth.go b/pkg/tools/mcp/oauth.go index a4f0a60d1..5d661f71f 100644 --- a/pkg/tools/mcp/oauth.go +++ b/pkg/tools/mcp/oauth.go @@ -11,6 +11,7 @@ import ( "net/http" "regexp" "strings" + "sync" "time" mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" @@ -78,22 +79,28 @@ func (o *oauth) getAuthorizationServerMetadata(ctx context.Context, authServerUR if resp.StatusCode == http.StatusNotFound { // Try OpenID Connect discovery as fallback openIDURL := strings.Replace(metadataURL, "/.well-known/oauth-authorization-server", "/.well-known/openid-configuration", 1) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, openIDURL, http.NoBody) + oidcReq, err := http.NewRequestWithContext(ctx, http.MethodGet, openIDURL, http.NoBody) if err != nil { return nil, err } - req.Header.Set("Accept", "application/json") + oidcReq.Header.Set("Accept", "application/json") - resp, err := o.metadataClient.Do(req) + oidcResp, err := o.metadataClient.Do(oidcReq) if err != nil { return nil, err } - defer resp.Body.Close() + defer oidcResp.Body.Close() - if resp.StatusCode != http.StatusOK { + if oidcResp.StatusCode != http.StatusOK { // Return default metadata if all discovery fails return createDefaultMetadata(authServerURL), nil } + + var metadata AuthorizationServerMetadata + if err := json.NewDecoder(oidcResp.Body).Decode(&metadata); err != nil { + return nil, fmt.Errorf("failed to decode metadata from %s: %w", openIDURL, err) + } + return validateAndFillDefaults(&metadata, authServerURL), nil } else if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("unexpected status %d from %s", resp.StatusCode, metadataURL) } @@ -165,9 +172,19 @@ type oauthTransport struct { tokenStore OAuthTokenStore baseURL string managed bool + + // mu protects refreshFailedAt from concurrent access. + mu sync.Mutex + // refreshFailedAt tracks the last time a silent token refresh failed, + // so we avoid retrying on every request. + refreshFailedAt time.Time } func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.roundTrip(req, false) +} + +func (t *oauthTransport) roundTrip(req *http.Request, isRetry bool) (*http.Response, error) { var bodyBytes []byte if req.Body != nil && req.Body != http.NoBody { var err error @@ -190,7 +207,7 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } - if resp.StatusCode == http.StatusUnauthorized { + if resp.StatusCode == http.StatusUnauthorized && !isRetry { wwwAuth := resp.Header.Get("WWW-Authenticate") if wwwAuth != "" { resp.Body.Close() @@ -204,7 +221,7 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) { req.Body = io.NopCloser(strings.NewReader(string(bodyBytes))) } - return t.RoundTrip(req) + return t.roundTrip(req, true) } } @@ -228,6 +245,15 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken { return nil } + // Avoid hammering the token endpoint if a recent refresh already failed. + const refreshBackoff = 30 * time.Second + t.mu.Lock() + failedAt := t.refreshFailedAt + t.mu.Unlock() + if !failedAt.IsZero() && time.Since(failedAt) < refreshBackoff { + return nil + } + slog.Debug("Attempting silent token refresh", "url", t.baseURL) o := &oauth{metadataClient: &http.Client{Timeout: 5 * time.Second}} @@ -240,9 +266,16 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken { newToken, err := RefreshAccessToken(ctx, metadata.TokenEndpoint, token.RefreshToken, token.ClientID, token.ClientSecret) if err != nil { slog.Debug("Token refresh failed, will require interactive auth", "error", err) + t.mu.Lock() + t.refreshFailedAt = time.Now() + t.mu.Unlock() return nil } + t.mu.Lock() + t.refreshFailedAt = time.Time{} // reset on success + t.mu.Unlock() + if err := t.tokenStore.StoreToken(t.baseURL, newToken); err != nil { slog.Warn("Failed to store refreshed token", "error", err) } @@ -265,7 +298,11 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer, resourceURL := cmp.Or(resourceMetadataFromWWWAuth(wwwAuth), authServer+"/.well-known/oauth-protected-resource") - resp, err := http.Get(resourceURL) + resourceReq, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, http.NoBody) + if err != nil { + return err + } + resp, err := http.DefaultClient.Do(resourceReq) if err != nil { return err } @@ -411,7 +448,11 @@ func (t *oauthTransport) handleUnmanagedOAuthFlow(ctx context.Context, authServe // Extract resource URL from WWW-Authenticate header resourceURL := cmp.Or(resourceMetadataFromWWWAuth(wwwAuth), authServer+"/.well-known/oauth-protected-resource") - resp, err := http.Get(resourceURL) + resourceReq, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, http.NoBody) + if err != nil { + return err + } + resp, err := http.DefaultClient.Do(resourceReq) if err != nil { return err } diff --git a/pkg/tools/mcp/oauth_login.go b/pkg/tools/mcp/oauth_login.go index e3df38fdb..b1beb1f73 100644 --- a/pkg/tools/mcp/oauth_login.go +++ b/pkg/tools/mcp/oauth_login.go @@ -31,7 +31,11 @@ func PerformOAuthLogin(ctx context.Context, serverURL string) error { // Discover protected resource metadata. resourceURL := baseURL + "/.well-known/oauth-protected-resource" - resp, err := http.Get(resourceURL) //nolint:gosec // URL is user-provided + resourceReq, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, http.NoBody) + if err != nil { + return fmt.Errorf("failed to create resource metadata request: %w", err) + } + resp, err := http.DefaultClient.Do(resourceReq) if err != nil { return fmt.Errorf("failed to fetch protected resource metadata: %w", err) } diff --git a/pkg/tools/mcp/oauth_server.go b/pkg/tools/mcp/oauth_server.go index b66f30b6a..b7c676a9f 100644 --- a/pkg/tools/mcp/oauth_server.go +++ b/pkg/tools/mcp/oauth_server.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "html" "log/slog" "net" "net/http" @@ -101,7 +102,7 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request)

%s

You can close this window.

-`, errMsg) +`, html.EscapeString(errMsg)) return }