Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 50 additions & 9 deletions pkg/tools/mcp/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/http"
"regexp"
"strings"
"sync"
"time"

mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
Expand Down Expand Up @@ -78,22 +79,28 @@ func (o *oauth) getAuthorizationServerMetadata(ctx context.Context, authServerUR
if resp.StatusCode == http.StatusNotFound {
// Try OpenID Connect discovery as fallback
openIDURL := strings.Replace(metadataURL, "/.well-known/oauth-authorization-server", "/.well-known/openid-configuration", 1)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openIDURL, http.NoBody)
oidcReq, err := http.NewRequestWithContext(ctx, http.MethodGet, openIDURL, http.NoBody)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
oidcReq.Header.Set("Accept", "application/json")

resp, err := o.metadataClient.Do(req)
oidcResp, err := o.metadataClient.Do(oidcReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
defer oidcResp.Body.Close()

if resp.StatusCode != http.StatusOK {
if oidcResp.StatusCode != http.StatusOK {
// Return default metadata if all discovery fails
return createDefaultMetadata(authServerURL), nil
}

var metadata AuthorizationServerMetadata
if err := json.NewDecoder(oidcResp.Body).Decode(&metadata); err != nil {
return nil, fmt.Errorf("failed to decode metadata from %s: %w", openIDURL, err)
}
return validateAndFillDefaults(&metadata, authServerURL), nil
} else if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status %d from %s", resp.StatusCode, metadataURL)
}
Expand Down Expand Up @@ -165,9 +172,19 @@ type oauthTransport struct {
tokenStore OAuthTokenStore
baseURL string
managed bool

// mu protects refreshFailedAt from concurrent access.
mu sync.Mutex
// refreshFailedAt tracks the last time a silent token refresh failed,
// so we avoid retrying on every request.
refreshFailedAt time.Time
}

func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.roundTrip(req, false)
}

func (t *oauthTransport) roundTrip(req *http.Request, isRetry bool) (*http.Response, error) {
var bodyBytes []byte
if req.Body != nil && req.Body != http.NoBody {
var err error
Expand All @@ -190,7 +207,7 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err
}

if resp.StatusCode == http.StatusUnauthorized {
if resp.StatusCode == http.StatusUnauthorized && !isRetry {
wwwAuth := resp.Header.Get("WWW-Authenticate")
if wwwAuth != "" {
resp.Body.Close()
Expand All @@ -204,7 +221,7 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
}

return t.RoundTrip(req)
return t.roundTrip(req, true)
}
}

Expand All @@ -228,6 +245,15 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken {
return nil
}

// Avoid hammering the token endpoint if a recent refresh already failed.
const refreshBackoff = 30 * time.Second
t.mu.Lock()
failedAt := t.refreshFailedAt
t.mu.Unlock()
if !failedAt.IsZero() && time.Since(failedAt) < refreshBackoff {
return nil
}

slog.Debug("Attempting silent token refresh", "url", t.baseURL)

o := &oauth{metadataClient: &http.Client{Timeout: 5 * time.Second}}
Expand All @@ -240,9 +266,16 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken {
newToken, err := RefreshAccessToken(ctx, metadata.TokenEndpoint, token.RefreshToken, token.ClientID, token.ClientSecret)
if err != nil {
slog.Debug("Token refresh failed, will require interactive auth", "error", err)
t.mu.Lock()
t.refreshFailedAt = time.Now()
t.mu.Unlock()
return nil
}

t.mu.Lock()
t.refreshFailedAt = time.Time{} // reset on success
t.mu.Unlock()

if err := t.tokenStore.StoreToken(t.baseURL, newToken); err != nil {
slog.Warn("Failed to store refreshed token", "error", err)
}
Expand All @@ -265,7 +298,11 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer,

resourceURL := cmp.Or(resourceMetadataFromWWWAuth(wwwAuth), authServer+"/.well-known/oauth-protected-resource")

resp, err := http.Get(resourceURL)
resourceReq, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, http.NoBody)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(resourceReq)
if err != nil {
return err
}
Expand Down Expand Up @@ -411,7 +448,11 @@ func (t *oauthTransport) handleUnmanagedOAuthFlow(ctx context.Context, authServe
// Extract resource URL from WWW-Authenticate header
resourceURL := cmp.Or(resourceMetadataFromWWWAuth(wwwAuth), authServer+"/.well-known/oauth-protected-resource")

resp, err := http.Get(resourceURL)
resourceReq, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, http.NoBody)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(resourceReq)
if err != nil {
return err
}
Expand Down
6 changes: 5 additions & 1 deletion pkg/tools/mcp/oauth_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ func PerformOAuthLogin(ctx context.Context, serverURL string) error {

// Discover protected resource metadata.
resourceURL := baseURL + "/.well-known/oauth-protected-resource"
resp, err := http.Get(resourceURL) //nolint:gosec // URL is user-provided
resourceReq, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, http.NoBody)
if err != nil {
return fmt.Errorf("failed to create resource metadata request: %w", err)
}
resp, err := http.DefaultClient.Do(resourceReq)
if err != nil {
return fmt.Errorf("failed to fetch protected resource metadata: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/tools/mcp/oauth_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"html"
"log/slog"
"net"
"net/http"
Expand Down Expand Up @@ -101,7 +102,7 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request)
<p>%s</p>
<p>You can close this window.</p>
</body>
</html>`, errMsg)
</html>`, html.EscapeString(errMsg))
return
}

Expand Down
Loading