Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions cmd/docker-mcp/secret-management/secret/secretsengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ import (
var ErrSecretNotFound = errors.New("secret not found")

type Envelope struct {
ID string `json:"id"`
Value []byte `json:"value"`
Provider string `json:"provider"`
ID string `json:"id"`
Value []byte `json:"value"`
Provider string `json:"provider"`
Metadata map[string]string `json:"metadata,omitempty"`
}

func socketPath() string {
Expand Down
89 changes: 46 additions & 43 deletions pkg/gateway/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,23 +343,22 @@ func (g *Gateway) Run(ctx context.Context) error {
log.Log("- Starting OAuth notification monitor")
monitor := oauth.NewNotificationMonitor()
monitor.OnOAuthEvent = func(event oauth.Event) {
g.handleOAuthEvent(event)
// Route event to specific provider
g.routeEventToProvider(event)
}
monitor.Start(ctx)
}

// Start OAuth provider for each OAuth server (CE mode only)
// In Desktop mode, tokens are auto-refreshed by Secrets Engine - no polling needed
if oauth.IsCEMode() {
log.Log("- Starting OAuth provider loops (CE mode)...")
for _, serverName := range configuration.ServerNames() {
serverConfig, _, found := configuration.Find(serverName)
if !found || serverConfig == nil || !serverConfig.Spec.IsRemoteOAuthServer() {
continue
}

g.startProvider(ctx, serverName)
// Start OAuth provider for each OAuth server.
// Each provider runs in its own goroutine with dynamic timing based on token expiry.
log.Log("- Starting OAuth provider loops...")
for _, serverName := range configuration.ServerNames() {
serverConfig, _, found := configuration.Find(serverName)
if !found || serverConfig == nil || !serverConfig.Spec.IsRemoteOAuthServer() {
continue
}

g.startProvider(ctx, serverName)
}
}

Expand Down Expand Up @@ -662,49 +661,53 @@ func (g *Gateway) stopProvider(serverName string) {
}
}

// handleOAuthEvent handles SSE events from Docker Desktop's notification monitor.
// Only called in Desktop mode (the notification monitor doesn't start in CE mode).
// In CE mode, providers handle token refresh via polling - no events needed.
func (g *Gateway) handleOAuthEvent(event oauth.Event) {
// routeEventToProvider routes SSE events to the appropriate provider
func (g *Gateway) routeEventToProvider(event oauth.Event) {
g.providersMu.RLock()
_, exists := g.oauthProviders[event.Provider]
g.providersMu.RUnlock()

switch event.Type {
case oauth.EventLoginSuccess:
log.Logf("- OAuth login success for %s", event.Provider)
g.reloadOAuthServer(event.Provider)
// User just authorized - ensure provider exists
if !exists {
log.Logf("- Creating provider for %s after login", event.Provider)
g.startProvider(context.Background(), event.Provider)
}

// Always send event to trigger reload (connects server and lists tools)
// Wait briefly if we just created the provider
if !exists {
time.Sleep(100 * time.Millisecond)
}

g.providersMu.RLock()
provider, exists := g.oauthProviders[event.Provider]
g.providersMu.RUnlock()

if exists {
provider.SendEvent(event)
}

case oauth.EventTokenRefresh:
// Secrets Engine refreshed the token - invalidate cached connections
// Next request will create new connection with fresh token
log.Logf("- OAuth token refreshed for %s", event.Provider)
// Token refreshed - invalidate cached connections directly.
// Don't route to Provider (reloadFn would trigger Secrets Engine Filter → SSE loop).
// The Provider's timer handles refresh scheduling independently.
log.Logf("- OAuth token refreshed for %s, invalidating connections", event.Provider)
g.clientPool.InvalidateOAuthClients(event.Provider)

case oauth.EventLogoutSuccess:
log.Logf("- OAuth logout for %s", event.Provider)
g.clientPool.InvalidateOAuthClients(event.Provider)
// User logged out - stop provider if exists
if exists {
log.Logf("- Stopping provider for %s after logout", event.Provider)
g.stopProvider(event.Provider)
}

default:
// Other events (login-start, code-received, error) - no action needed
// Other events (login-start, code-received, error) - ignore
}
}

// reloadOAuthServer invalidates cached clients and reloads server capabilities.
// Used by Desktop mode after login success.
func (g *Gateway) reloadOAuthServer(serverName string) {
g.clientPool.InvalidateOAuthClients(serverName)

oldCaps, err := g.reloadServerCapabilities(context.Background(), serverName, nil)
if err != nil {
log.Logf("! Failed to reload OAuth server %s: %v", serverName, err)
return
}

g.capabilitiesMu.Lock()
newCaps := g.allCapabilities(serverName)
_ = g.updateServerCapabilities(serverName, oldCaps, newCaps, nil)
g.capabilitiesMu.Unlock()

log.Logf("> OAuth server %s reloaded", serverName)
}

// GetToolRegistrations returns a copy of all registered tools
// This is useful for introspection and serialization
func (g *Gateway) GetToolRegistrations() map[string]ToolRegistration {
Expand Down
67 changes: 60 additions & 7 deletions pkg/oauth/credhelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,68 @@ func (h *CredentialHelper) TokenExists(ctx context.Context, serverName string) (
}

// GetTokenStatus checks token validity and expiry for refresh scheduling.
// CE mode only - requires JSON token format with expiry metadata.
// In Desktop mode, use TokenExists() instead since Secrets Engine returns
// raw tokens without expiry information.
func (h *CredentialHelper) GetTokenStatus(_ context.Context, serverName string) (TokenStatus, error) {
if !IsCEMode() {
return TokenStatus{}, fmt.Errorf("GetTokenStatus is only available in CE mode; use TokenExists() for Desktop mode")
// Works in both CE and Desktop modes:
// - CE mode: reads token JSON from credential helper and parses expiry
// - Desktop mode: queries Secrets Engine and reads ExpiryAt from response metadata
func (h *CredentialHelper) GetTokenStatus(ctx context.Context, serverName string) (TokenStatus, error) {
if IsCEMode() {
return h.getTokenStatusCE(serverName)
}
return h.getTokenStatusDesktop(ctx, serverName)
}

// getTokenStatusDesktop retrieves token status in Desktop mode using Secrets Engine metadata.
// The Secrets Engine response includes ExpiryAt and ExpiresIn in the metadata map,
// added by Docker Desktop's OAuthCredential.Metadata() method.
func (h *CredentialHelper) getTokenStatusDesktop(ctx context.Context, serverName string) (TokenStatus, error) {
oauthID := secret.GetOAuthKey(serverName)
env, err := secret.GetSecret(ctx, oauthID)
if errors.Is(err, secret.ErrSecretNotFound) {
return TokenStatus{Valid: false}, fmt.Errorf("OAuth token not found for %s", serverName)
}
if err != nil {
return TokenStatus{Valid: false}, fmt.Errorf("failed to query Secrets Engine for %s: %w", serverName, err)
}

if string(env.Value) == "" {
return TokenStatus{Valid: false}, fmt.Errorf("empty OAuth token found for %s", serverName)
}

// Read ExpiryAt from Secrets Engine response metadata
expiryAtStr, hasExpiry := env.Metadata["ExpiryAt"]
if !hasExpiry || expiryAtStr == "" {
// No expiry metadata available - token exists but we can't determine when it expires.
// Return valid with no scheduled refresh (rely on SSE events as fallback).
log.Logf("- Token status for %s: valid=true, no expiry metadata available", serverName)
return TokenStatus{
Valid: true,
ExpiresAt: time.Time{},
NeedsRefresh: false,
}, nil
}

// CE mode: Use credential helper directly
expiresAt, err := time.Parse(time.RFC3339, expiryAtStr)
if err != nil {
return TokenStatus{Valid: false}, fmt.Errorf("failed to parse ExpiryAt metadata for %s: %w", serverName, err)
}

now := time.Now()
timeUntilExpiry := expiresAt.Sub(now)
needsRefresh := timeUntilExpiry <= 10*time.Second

log.Logf("- Token status for %s: valid=true, expires_at=%s, time_until_expiry=%v, needs_refresh=%v",
serverName, expiresAt.Format(time.RFC3339), timeUntilExpiry.Round(time.Second), needsRefresh)

return TokenStatus{
Valid: true,
ExpiresAt: expiresAt,
NeedsRefresh: needsRefresh,
}, nil
}

// getTokenStatusCE retrieves token status in CE mode using the credential helper.
// Reads the base64-encoded JSON token and parses the expiry field.
func (h *CredentialHelper) getTokenStatusCE(serverName string) (TokenStatus, error) {
dcrMgr := dcr.NewManager(h.credentialHelper, "")
client, err := dcrMgr.GetDCRClient(serverName)
if err != nil {
Expand Down
86 changes: 64 additions & 22 deletions pkg/oauth/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"golang.org/x/oauth2"

"github.com/docker/mcp-gateway/pkg/desktop"
"github.com/docker/mcp-gateway/pkg/log"
"github.com/docker/mcp-gateway/pkg/oauth/dcr"
)
Expand Down Expand Up @@ -61,33 +62,37 @@ func (p *DCRProvider) GeneratePKCE() string {
return oauth2.GenerateVerifier()
}

// Provider manages OAuth token lifecycle for a single MCP server (CE mode only).
// Polls token expiry, triggers refresh when needed, and reloads the server connection.
// In Desktop mode, Secrets Engine handles token refresh and SSE events trigger reloads.
// Provider manages OAuth token lifecycle for a single MCP server.
// This is used for background token refresh loops in the gateway.
// CE mode: refreshes tokens directly via oauth2 library, then reloads.
// Desktop mode: triggers refresh via GetOAuthApp Desktop API, then SSE events
// interrupt the timer, trigger reload, and reset retry counters.
type Provider struct {
name string
lastRefreshExpiry time.Time
refreshRetryCount int
stopOnce sync.Once
stopChan chan struct{}
eventChan chan Event
credHelper *CredentialHelper
reloadFn func(ctx context.Context, serverName string) error
}

const maxRefreshRetries = 7 // Max attempts to refresh when expiry hasn't changed

// NewProvider creates a new OAuth provider for token refresh polling
// NewProvider creates a new OAuth provider for token refresh
func NewProvider(name string, reloadFn func(context.Context, string) error) *Provider {
return &Provider{
name: name,
stopChan: make(chan struct{}),
eventChan: make(chan Event),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trying to think if this chan needs any buffering just in case we get a send before a receive but I don't think that's possible here.

credHelper: NewOAuthCredentialHelper(),
reloadFn: reloadFn,
}
}

// Run starts the provider's background polling loop.
// Checks token expiry, triggers refresh when needed, and reloads server connections.
// Run starts the provider's background loop.
// Loop dynamically adjusts timing based on token expiry.
func (p *Provider) Run(ctx context.Context) {
log.Logf("- Started OAuth provider loop for %s", p.name)
defer log.Logf("- Stopped OAuth provider loop for %s", p.name)
Expand All @@ -101,8 +106,9 @@ func (p *Provider) Run(ctx context.Context) {
return
}

// Calculate wait duration based on token status
// Calculate wait duration and whether to trigger refresh
var waitDuration time.Duration
var shouldTriggerRefresh bool

if status.NeedsRefresh {
// Token needs refresh - check if expiry unchanged from last attempt
Expand All @@ -128,32 +134,63 @@ func (p *Provider) Run(ctx context.Context) {
p.name, p.refreshRetryCount, maxRefreshRetries, waitDuration)

p.lastRefreshExpiry = status.ExpiresAt

// Refresh token and reload server connection
go func() {
if err := p.refreshTokenCE(); err != nil {
log.Logf("! Token refresh failed for %s: %v", p.name, err)
return
}
// Reload server to pick up the new token
if err := p.reloadFn(ctx, p.name); err != nil {
log.Logf("! Failed to reload %s after token refresh: %v", p.name, err)
}
}()
shouldTriggerRefresh = true

} else {
// Token still valid
if status.ExpiresAt.IsZero() {
// No expiry information available — can't schedule proactive refresh.
// Fall back to SSE events (Desktop mode) for refresh notification.
log.Logf("- No token expiry info for %s, stopping provider loop (SSE events will handle refresh)", p.name)
return
}
timeUntilExpiry := time.Until(status.ExpiresAt)
waitDuration = max(0, timeUntilExpiry-10*time.Second)
log.Logf("- Token valid for %s, next check in %v", p.name, waitDuration.Round(time.Second))
shouldTriggerRefresh = false
}

// Trigger refresh if needed
if shouldTriggerRefresh {
if IsCEMode() {
// CE mode: Refresh token directly
go func() {
if err := p.refreshTokenCE(); err != nil {
log.Logf("! Token refresh failed for %s: %v", p.name, err)
}
}()
} else {
// Desktop mode: Trigger refresh via Desktop API
go func() {
authClient := desktop.NewAuthClient()
app, err := authClient.GetOAuthApp(context.Background(), p.name)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

won't be cancelled if the provider stops

if err != nil {
log.Logf("! GetOAuthApp failed for %s: %v", p.name, err)
return
}
if !app.Authorized {
log.Logf("! GetOAuthApp returned Authorized=false for %s", p.name)
return
}
}()
}
}

// Wait until next check, interruptible by stop signal
// Wait pattern - interruptible by login events
if waitDuration > 0 {
timer := time.NewTimer(waitDuration)
select {
case <-timer.C:
// Wait complete, continue to next iteration
// Wait complete
case event := <-p.eventChan:
timer.Stop()
log.Logf("- Provider %s received event: %s", p.name, event.Type)
if err := p.reloadFn(ctx, p.name); err != nil {
log.Logf("- Failed to reload %s after %s: %v", p.name, event.Type, err)
}
if event.Type == EventLoginSuccess {
p.refreshRetryCount = 0
p.lastRefreshExpiry = time.Time{}
}
case <-p.stopChan:
timer.Stop()
return
Expand All @@ -172,6 +209,11 @@ func (p *Provider) Stop() {
})
}

// SendEvent sends an SSE event to this provider's event channel
func (p *Provider) SendEvent(event Event) {
p.eventChan <- event
}

// refreshTokenCE refreshes an OAuth token in CE mode
// Uses the same oauth2 library refresh mechanism as Desktop
func (p *Provider) refreshTokenCE() error {
Expand Down
Loading