Skip to content

Commit c0283e5

Browse files
committed
fix: OAuth token security and bug fixes
- Fix XSS vulnerability in OAuth callback error page by escaping HTML output - Fix infinite recursion in RoundTrip on persistent 401 responses - Fix OpenID fallback response never being decoded due to variable shadowing - Replace http.Get with context-aware requests for proper cancellation - Add 30s backoff on failed token refresh to avoid hammering token endpoint Assisted-By: docker-agent
1 parent 7c6204b commit c0283e5

File tree

3 files changed

+57
-11
lines changed

3 files changed

+57
-11
lines changed

pkg/tools/mcp/oauth.go

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"net/http"
1212
"regexp"
1313
"strings"
14+
"sync"
1415
"time"
1516

1617
mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
@@ -78,22 +79,28 @@ func (o *oauth) getAuthorizationServerMetadata(ctx context.Context, authServerUR
7879
if resp.StatusCode == http.StatusNotFound {
7980
// Try OpenID Connect discovery as fallback
8081
openIDURL := strings.Replace(metadataURL, "/.well-known/oauth-authorization-server", "/.well-known/openid-configuration", 1)
81-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, openIDURL, http.NoBody)
82+
oidcReq, err := http.NewRequestWithContext(ctx, http.MethodGet, openIDURL, http.NoBody)
8283
if err != nil {
8384
return nil, err
8485
}
85-
req.Header.Set("Accept", "application/json")
86+
oidcReq.Header.Set("Accept", "application/json")
8687

87-
resp, err := o.metadataClient.Do(req)
88+
oidcResp, err := o.metadataClient.Do(oidcReq)
8889
if err != nil {
8990
return nil, err
9091
}
91-
defer resp.Body.Close()
92+
defer oidcResp.Body.Close()
9293

93-
if resp.StatusCode != http.StatusOK {
94+
if oidcResp.StatusCode != http.StatusOK {
9495
// Return default metadata if all discovery fails
9596
return createDefaultMetadata(authServerURL), nil
9697
}
98+
99+
var metadata AuthorizationServerMetadata
100+
if err := json.NewDecoder(oidcResp.Body).Decode(&metadata); err != nil {
101+
return nil, fmt.Errorf("failed to decode metadata from %s: %w", openIDURL, err)
102+
}
103+
return validateAndFillDefaults(&metadata, authServerURL), nil
97104
} else if resp.StatusCode != http.StatusOK {
98105
return nil, fmt.Errorf("unexpected status %d from %s", resp.StatusCode, metadataURL)
99106
}
@@ -165,9 +172,19 @@ type oauthTransport struct {
165172
tokenStore OAuthTokenStore
166173
baseURL string
167174
managed bool
175+
176+
// mu protects refreshFailedAt from concurrent access.
177+
mu sync.Mutex
178+
// refreshFailedAt tracks the last time a silent token refresh failed,
179+
// so we avoid retrying on every request.
180+
refreshFailedAt time.Time
168181
}
169182

170183
func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
184+
return t.roundTrip(req, false)
185+
}
186+
187+
func (t *oauthTransport) roundTrip(req *http.Request, isRetry bool) (*http.Response, error) {
171188
var bodyBytes []byte
172189
if req.Body != nil && req.Body != http.NoBody {
173190
var err error
@@ -190,7 +207,7 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
190207
return nil, err
191208
}
192209

193-
if resp.StatusCode == http.StatusUnauthorized {
210+
if resp.StatusCode == http.StatusUnauthorized && !isRetry {
194211
wwwAuth := resp.Header.Get("WWW-Authenticate")
195212
if wwwAuth != "" {
196213
resp.Body.Close()
@@ -204,7 +221,7 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
204221
req.Body = io.NopCloser(strings.NewReader(string(bodyBytes)))
205222
}
206223

207-
return t.RoundTrip(req)
224+
return t.roundTrip(req, true)
208225
}
209226
}
210227

@@ -228,6 +245,15 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken {
228245
return nil
229246
}
230247

248+
// Avoid hammering the token endpoint if a recent refresh already failed.
249+
const refreshBackoff = 30 * time.Second
250+
t.mu.Lock()
251+
failedAt := t.refreshFailedAt
252+
t.mu.Unlock()
253+
if !failedAt.IsZero() && time.Since(failedAt) < refreshBackoff {
254+
return nil
255+
}
256+
231257
slog.Debug("Attempting silent token refresh", "url", t.baseURL)
232258

233259
o := &oauth{metadataClient: &http.Client{Timeout: 5 * time.Second}}
@@ -240,9 +266,16 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken {
240266
newToken, err := RefreshAccessToken(ctx, metadata.TokenEndpoint, token.RefreshToken, token.ClientID, token.ClientSecret)
241267
if err != nil {
242268
slog.Debug("Token refresh failed, will require interactive auth", "error", err)
269+
t.mu.Lock()
270+
t.refreshFailedAt = time.Now()
271+
t.mu.Unlock()
243272
return nil
244273
}
245274

275+
t.mu.Lock()
276+
t.refreshFailedAt = time.Time{} // reset on success
277+
t.mu.Unlock()
278+
246279
if err := t.tokenStore.StoreToken(t.baseURL, newToken); err != nil {
247280
slog.Warn("Failed to store refreshed token", "error", err)
248281
}
@@ -265,7 +298,11 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer,
265298

266299
resourceURL := cmp.Or(resourceMetadataFromWWWAuth(wwwAuth), authServer+"/.well-known/oauth-protected-resource")
267300

268-
resp, err := http.Get(resourceURL)
301+
resourceReq, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, http.NoBody)
302+
if err != nil {
303+
return err
304+
}
305+
resp, err := http.DefaultClient.Do(resourceReq)
269306
if err != nil {
270307
return err
271308
}
@@ -411,7 +448,11 @@ func (t *oauthTransport) handleUnmanagedOAuthFlow(ctx context.Context, authServe
411448
// Extract resource URL from WWW-Authenticate header
412449
resourceURL := cmp.Or(resourceMetadataFromWWWAuth(wwwAuth), authServer+"/.well-known/oauth-protected-resource")
413450

414-
resp, err := http.Get(resourceURL)
451+
resourceReq, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, http.NoBody)
452+
if err != nil {
453+
return err
454+
}
455+
resp, err := http.DefaultClient.Do(resourceReq)
415456
if err != nil {
416457
return err
417458
}

pkg/tools/mcp/oauth_login.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ func PerformOAuthLogin(ctx context.Context, serverURL string) error {
3131

3232
// Discover protected resource metadata.
3333
resourceURL := baseURL + "/.well-known/oauth-protected-resource"
34-
resp, err := http.Get(resourceURL) //nolint:gosec // URL is user-provided
34+
resourceReq, err := http.NewRequestWithContext(ctx, http.MethodGet, resourceURL, http.NoBody)
35+
if err != nil {
36+
return fmt.Errorf("failed to create resource metadata request: %w", err)
37+
}
38+
resp, err := http.DefaultClient.Do(resourceReq)
3539
if err != nil {
3640
return fmt.Errorf("failed to fetch protected resource metadata: %w", err)
3741
}

pkg/tools/mcp/oauth_server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"html"
78
"log/slog"
89
"net"
910
"net/http"
@@ -101,7 +102,7 @@ func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request)
101102
<p>%s</p>
102103
<p>You can close this window.</p>
103104
</body>
104-
</html>`, errMsg)
105+
</html>`, html.EscapeString(errMsg))
105106
return
106107
}
107108

0 commit comments

Comments
 (0)