Skip to content

Commit 7ecec65

Browse files
authored
Cherry-pick: Proactive OAuth token refresh for Desktop mode (#407)
1 parent 768bcc5 commit 7ecec65

File tree

4 files changed

+174
-75
lines changed

4 files changed

+174
-75
lines changed

cmd/docker-mcp/secret-management/secret/secretsengine.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ import (
1616
var ErrSecretNotFound = errors.New("secret not found")
1717

1818
type Envelope struct {
19-
ID string `json:"id"`
20-
Value []byte `json:"value"`
21-
Provider string `json:"provider"`
19+
ID string `json:"id"`
20+
Value []byte `json:"value"`
21+
Provider string `json:"provider"`
22+
Metadata map[string]string `json:"metadata,omitempty"`
2223
}
2324

2425
func socketPath() string {

pkg/gateway/run.go

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -343,23 +343,22 @@ func (g *Gateway) Run(ctx context.Context) error {
343343
log.Log("- Starting OAuth notification monitor")
344344
monitor := oauth.NewNotificationMonitor()
345345
monitor.OnOAuthEvent = func(event oauth.Event) {
346-
g.handleOAuthEvent(event)
346+
// Route event to specific provider
347+
g.routeEventToProvider(event)
347348
}
348349
monitor.Start(ctx)
349350
}
350351

351-
// Start OAuth provider for each OAuth server (CE mode only)
352-
// In Desktop mode, tokens are auto-refreshed by Secrets Engine - no polling needed
353-
if oauth.IsCEMode() {
354-
log.Log("- Starting OAuth provider loops (CE mode)...")
355-
for _, serverName := range configuration.ServerNames() {
356-
serverConfig, _, found := configuration.Find(serverName)
357-
if !found || serverConfig == nil || !serverConfig.Spec.IsRemoteOAuthServer() {
358-
continue
359-
}
360-
361-
g.startProvider(ctx, serverName)
352+
// Start OAuth provider for each OAuth server.
353+
// Each provider runs in its own goroutine with dynamic timing based on token expiry.
354+
log.Log("- Starting OAuth provider loops...")
355+
for _, serverName := range configuration.ServerNames() {
356+
serverConfig, _, found := configuration.Find(serverName)
357+
if !found || serverConfig == nil || !serverConfig.Spec.IsRemoteOAuthServer() {
358+
continue
362359
}
360+
361+
g.startProvider(ctx, serverName)
363362
}
364363
}
365364

@@ -662,49 +661,53 @@ func (g *Gateway) stopProvider(serverName string) {
662661
}
663662
}
664663

665-
// handleOAuthEvent handles SSE events from Docker Desktop's notification monitor.
666-
// Only called in Desktop mode (the notification monitor doesn't start in CE mode).
667-
// In CE mode, providers handle token refresh via polling - no events needed.
668-
func (g *Gateway) handleOAuthEvent(event oauth.Event) {
664+
// routeEventToProvider routes SSE events to the appropriate provider
665+
func (g *Gateway) routeEventToProvider(event oauth.Event) {
666+
g.providersMu.RLock()
667+
_, exists := g.oauthProviders[event.Provider]
668+
g.providersMu.RUnlock()
669+
669670
switch event.Type {
670671
case oauth.EventLoginSuccess:
671-
log.Logf("- OAuth login success for %s", event.Provider)
672-
g.reloadOAuthServer(event.Provider)
672+
// User just authorized - ensure provider exists
673+
if !exists {
674+
log.Logf("- Creating provider for %s after login", event.Provider)
675+
g.startProvider(context.Background(), event.Provider)
676+
}
677+
678+
// Always send event to trigger reload (connects server and lists tools)
679+
// Wait briefly if we just created the provider
680+
if !exists {
681+
time.Sleep(100 * time.Millisecond)
682+
}
683+
684+
g.providersMu.RLock()
685+
provider, exists := g.oauthProviders[event.Provider]
686+
g.providersMu.RUnlock()
687+
688+
if exists {
689+
provider.SendEvent(event)
690+
}
673691

674692
case oauth.EventTokenRefresh:
675-
// Secrets Engine refreshed the token - invalidate cached connections
676-
// Next request will create new connection with fresh token
677-
log.Logf("- OAuth token refreshed for %s", event.Provider)
693+
// Token refreshed - invalidate cached connections directly.
694+
// Don't route to Provider (reloadFn would trigger Secrets Engine Filter → SSE loop).
695+
// The Provider's timer handles refresh scheduling independently.
696+
log.Logf("- OAuth token refreshed for %s, invalidating connections", event.Provider)
678697
g.clientPool.InvalidateOAuthClients(event.Provider)
679698

680699
case oauth.EventLogoutSuccess:
681-
log.Logf("- OAuth logout for %s", event.Provider)
682-
g.clientPool.InvalidateOAuthClients(event.Provider)
700+
// User logged out - stop provider if exists
701+
if exists {
702+
log.Logf("- Stopping provider for %s after logout", event.Provider)
703+
g.stopProvider(event.Provider)
704+
}
683705

684706
default:
685-
// Other events (login-start, code-received, error) - no action needed
707+
// Other events (login-start, code-received, error) - ignore
686708
}
687709
}
688710

689-
// reloadOAuthServer invalidates cached clients and reloads server capabilities.
690-
// Used by Desktop mode after login success.
691-
func (g *Gateway) reloadOAuthServer(serverName string) {
692-
g.clientPool.InvalidateOAuthClients(serverName)
693-
694-
oldCaps, err := g.reloadServerCapabilities(context.Background(), serverName, nil)
695-
if err != nil {
696-
log.Logf("! Failed to reload OAuth server %s: %v", serverName, err)
697-
return
698-
}
699-
700-
g.capabilitiesMu.Lock()
701-
newCaps := g.allCapabilities(serverName)
702-
_ = g.updateServerCapabilities(serverName, oldCaps, newCaps, nil)
703-
g.capabilitiesMu.Unlock()
704-
705-
log.Logf("> OAuth server %s reloaded", serverName)
706-
}
707-
708711
// GetToolRegistrations returns a copy of all registered tools
709712
// This is useful for introspection and serialization
710713
func (g *Gateway) GetToolRegistrations() map[string]ToolRegistration {

pkg/oauth/credhelper.go

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,68 @@ func (h *CredentialHelper) TokenExists(ctx context.Context, serverName string) (
146146
}
147147

148148
// GetTokenStatus checks token validity and expiry for refresh scheduling.
149-
// CE mode only - requires JSON token format with expiry metadata.
150-
// In Desktop mode, use TokenExists() instead since Secrets Engine returns
151-
// raw tokens without expiry information.
152-
func (h *CredentialHelper) GetTokenStatus(_ context.Context, serverName string) (TokenStatus, error) {
153-
if !IsCEMode() {
154-
return TokenStatus{}, fmt.Errorf("GetTokenStatus is only available in CE mode; use TokenExists() for Desktop mode")
149+
// Works in both CE and Desktop modes:
150+
// - CE mode: reads token JSON from credential helper and parses expiry
151+
// - Desktop mode: queries Secrets Engine and reads ExpiryAt from response metadata
152+
func (h *CredentialHelper) GetTokenStatus(ctx context.Context, serverName string) (TokenStatus, error) {
153+
if IsCEMode() {
154+
return h.getTokenStatusCE(serverName)
155+
}
156+
return h.getTokenStatusDesktop(ctx, serverName)
157+
}
158+
159+
// getTokenStatusDesktop retrieves token status in Desktop mode using Secrets Engine metadata.
160+
// The Secrets Engine response includes ExpiryAt and ExpiresIn in the metadata map,
161+
// added by Docker Desktop's OAuthCredential.Metadata() method.
162+
func (h *CredentialHelper) getTokenStatusDesktop(ctx context.Context, serverName string) (TokenStatus, error) {
163+
oauthID := secret.GetOAuthKey(serverName)
164+
env, err := secret.GetSecret(ctx, oauthID)
165+
if errors.Is(err, secret.ErrSecretNotFound) {
166+
return TokenStatus{Valid: false}, fmt.Errorf("OAuth token not found for %s", serverName)
167+
}
168+
if err != nil {
169+
return TokenStatus{Valid: false}, fmt.Errorf("failed to query Secrets Engine for %s: %w", serverName, err)
170+
}
171+
172+
if string(env.Value) == "" {
173+
return TokenStatus{Valid: false}, fmt.Errorf("empty OAuth token found for %s", serverName)
174+
}
175+
176+
// Read ExpiryAt from Secrets Engine response metadata
177+
expiryAtStr, hasExpiry := env.Metadata["ExpiryAt"]
178+
if !hasExpiry || expiryAtStr == "" {
179+
// No expiry metadata available - token exists but we can't determine when it expires.
180+
// Return valid with no scheduled refresh (rely on SSE events as fallback).
181+
log.Logf("- Token status for %s: valid=true, no expiry metadata available", serverName)
182+
return TokenStatus{
183+
Valid: true,
184+
ExpiresAt: time.Time{},
185+
NeedsRefresh: false,
186+
}, nil
155187
}
156188

157-
// CE mode: Use credential helper directly
189+
expiresAt, err := time.Parse(time.RFC3339, expiryAtStr)
190+
if err != nil {
191+
return TokenStatus{Valid: false}, fmt.Errorf("failed to parse ExpiryAt metadata for %s: %w", serverName, err)
192+
}
193+
194+
now := time.Now()
195+
timeUntilExpiry := expiresAt.Sub(now)
196+
needsRefresh := timeUntilExpiry <= 10*time.Second
197+
198+
log.Logf("- Token status for %s: valid=true, expires_at=%s, time_until_expiry=%v, needs_refresh=%v",
199+
serverName, expiresAt.Format(time.RFC3339), timeUntilExpiry.Round(time.Second), needsRefresh)
200+
201+
return TokenStatus{
202+
Valid: true,
203+
ExpiresAt: expiresAt,
204+
NeedsRefresh: needsRefresh,
205+
}, nil
206+
}
207+
208+
// getTokenStatusCE retrieves token status in CE mode using the credential helper.
209+
// Reads the base64-encoded JSON token and parses the expiry field.
210+
func (h *CredentialHelper) getTokenStatusCE(serverName string) (TokenStatus, error) {
158211
dcrMgr := dcr.NewManager(h.credentialHelper, "")
159212
client, err := dcrMgr.GetDCRClient(serverName)
160213
if err != nil {

pkg/oauth/provider.go

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"golang.org/x/oauth2"
1010

11+
"github.com/docker/mcp-gateway/pkg/desktop"
1112
"github.com/docker/mcp-gateway/pkg/log"
1213
"github.com/docker/mcp-gateway/pkg/oauth/dcr"
1314
)
@@ -61,33 +62,37 @@ func (p *DCRProvider) GeneratePKCE() string {
6162
return oauth2.GenerateVerifier()
6263
}
6364

64-
// Provider manages OAuth token lifecycle for a single MCP server (CE mode only).
65-
// Polls token expiry, triggers refresh when needed, and reloads the server connection.
66-
// In Desktop mode, Secrets Engine handles token refresh and SSE events trigger reloads.
65+
// Provider manages OAuth token lifecycle for a single MCP server.
66+
// This is used for background token refresh loops in the gateway.
67+
// CE mode: refreshes tokens directly via oauth2 library, then reloads.
68+
// Desktop mode: triggers refresh via GetOAuthApp Desktop API, then SSE events
69+
// interrupt the timer, trigger reload, and reset retry counters.
6770
type Provider struct {
6871
name string
6972
lastRefreshExpiry time.Time
7073
refreshRetryCount int
7174
stopOnce sync.Once
7275
stopChan chan struct{}
76+
eventChan chan Event
7377
credHelper *CredentialHelper
7478
reloadFn func(ctx context.Context, serverName string) error
7579
}
7680

7781
const maxRefreshRetries = 7 // Max attempts to refresh when expiry hasn't changed
7882

79-
// NewProvider creates a new OAuth provider for token refresh polling
83+
// NewProvider creates a new OAuth provider for token refresh
8084
func NewProvider(name string, reloadFn func(context.Context, string) error) *Provider {
8185
return &Provider{
8286
name: name,
8387
stopChan: make(chan struct{}),
88+
eventChan: make(chan Event),
8489
credHelper: NewOAuthCredentialHelper(),
8590
reloadFn: reloadFn,
8691
}
8792
}
8893

89-
// Run starts the provider's background polling loop.
90-
// Checks token expiry, triggers refresh when needed, and reloads server connections.
94+
// Run starts the provider's background loop.
95+
// Loop dynamically adjusts timing based on token expiry.
9196
func (p *Provider) Run(ctx context.Context) {
9297
log.Logf("- Started OAuth provider loop for %s", p.name)
9398
defer log.Logf("- Stopped OAuth provider loop for %s", p.name)
@@ -101,8 +106,9 @@ func (p *Provider) Run(ctx context.Context) {
101106
return
102107
}
103108

104-
// Calculate wait duration based on token status
109+
// Calculate wait duration and whether to trigger refresh
105110
var waitDuration time.Duration
111+
var shouldTriggerRefresh bool
106112

107113
if status.NeedsRefresh {
108114
// Token needs refresh - check if expiry unchanged from last attempt
@@ -128,32 +134,63 @@ func (p *Provider) Run(ctx context.Context) {
128134
p.name, p.refreshRetryCount, maxRefreshRetries, waitDuration)
129135

130136
p.lastRefreshExpiry = status.ExpiresAt
131-
132-
// Refresh token and reload server connection
133-
go func() {
134-
if err := p.refreshTokenCE(); err != nil {
135-
log.Logf("! Token refresh failed for %s: %v", p.name, err)
136-
return
137-
}
138-
// Reload server to pick up the new token
139-
if err := p.reloadFn(ctx, p.name); err != nil {
140-
log.Logf("! Failed to reload %s after token refresh: %v", p.name, err)
141-
}
142-
}()
137+
shouldTriggerRefresh = true
143138

144139
} else {
145-
// Token still valid
140+
if status.ExpiresAt.IsZero() {
141+
// No expiry information available — can't schedule proactive refresh.
142+
// Fall back to SSE events (Desktop mode) for refresh notification.
143+
log.Logf("- No token expiry info for %s, stopping provider loop (SSE events will handle refresh)", p.name)
144+
return
145+
}
146146
timeUntilExpiry := time.Until(status.ExpiresAt)
147147
waitDuration = max(0, timeUntilExpiry-10*time.Second)
148148
log.Logf("- Token valid for %s, next check in %v", p.name, waitDuration.Round(time.Second))
149+
shouldTriggerRefresh = false
150+
}
151+
152+
// Trigger refresh if needed
153+
if shouldTriggerRefresh {
154+
if IsCEMode() {
155+
// CE mode: Refresh token directly
156+
go func() {
157+
if err := p.refreshTokenCE(); err != nil {
158+
log.Logf("! Token refresh failed for %s: %v", p.name, err)
159+
}
160+
}()
161+
} else {
162+
// Desktop mode: Trigger refresh via Desktop API
163+
go func() {
164+
authClient := desktop.NewAuthClient()
165+
app, err := authClient.GetOAuthApp(context.Background(), p.name)
166+
if err != nil {
167+
log.Logf("! GetOAuthApp failed for %s: %v", p.name, err)
168+
return
169+
}
170+
if !app.Authorized {
171+
log.Logf("! GetOAuthApp returned Authorized=false for %s", p.name)
172+
return
173+
}
174+
}()
175+
}
149176
}
150177

151-
// Wait until next check, interruptible by stop signal
178+
// Wait pattern - interruptible by login events
152179
if waitDuration > 0 {
153180
timer := time.NewTimer(waitDuration)
154181
select {
155182
case <-timer.C:
156-
// Wait complete, continue to next iteration
183+
// Wait complete
184+
case event := <-p.eventChan:
185+
timer.Stop()
186+
log.Logf("- Provider %s received event: %s", p.name, event.Type)
187+
if err := p.reloadFn(ctx, p.name); err != nil {
188+
log.Logf("- Failed to reload %s after %s: %v", p.name, event.Type, err)
189+
}
190+
if event.Type == EventLoginSuccess {
191+
p.refreshRetryCount = 0
192+
p.lastRefreshExpiry = time.Time{}
193+
}
157194
case <-p.stopChan:
158195
timer.Stop()
159196
return
@@ -172,6 +209,11 @@ func (p *Provider) Stop() {
172209
})
173210
}
174211

212+
// SendEvent sends an SSE event to this provider's event channel
213+
func (p *Provider) SendEvent(event Event) {
214+
p.eventChan <- event
215+
}
216+
175217
// refreshTokenCE refreshes an OAuth token in CE mode
176218
// Uses the same oauth2 library refresh mechanism as Desktop
177219
func (p *Provider) refreshTokenCE() error {

0 commit comments

Comments
 (0)