forked from docker/mcp-gateway
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprovider.go
More file actions
254 lines (221 loc) · 7.54 KB
/
provider.go
File metadata and controls
254 lines (221 loc) · 7.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
package oauth
import (
"context"
"fmt"
"sync"
"time"
"golang.org/x/oauth2"
"github.com/docker/mcp-gateway/pkg/desktop"
"github.com/docker/mcp-gateway/pkg/log"
"github.com/docker/mcp-gateway/pkg/oauth/dcr"
)
// DCRProvider is an OAuth provider whose client was obtained through Dynamic
// Client Registration. It is configured as a public client: no client secret
// is set, and PKCE (see GeneratePKCE) is relied on instead.
type DCRProvider struct {
	name        string
	config      *oauth2.Config
	resourceURL string // RFC 8707 resource indicator, for token audience binding
}

// NewDCRProvider builds a DCRProvider from a registered DCR client. The
// client's authorization/token endpoints and required scopes populate the
// oauth2 configuration, and redirectURL is used as the authorization callback.
func NewDCRProvider(dcrClient dcr.Client, redirectURL string) *DCRProvider {
	endpoint := oauth2.Endpoint{
		AuthURL:  dcrClient.AuthorizationEndpoint,
		TokenURL: dcrClient.TokenEndpoint,
	}
	return &DCRProvider{
		name:        dcrClient.ServerName,
		resourceURL: dcrClient.ResourceURL,
		config: &oauth2.Config{
			ClientID:     dcrClient.ClientID,
			ClientSecret: "", // public client: intentionally no secret
			RedirectURL:  redirectURL,
			Endpoint:     endpoint,
			Scopes:       dcrClient.RequiredScopes,
		},
	}
}
// Name reports the provider name (the DCR client's server name when built by
// NewDCRProvider).
func (p *DCRProvider) Name() (name string) {
	name = p.name
	return name
}

// Config exposes the provider's oauth2 configuration. The returned pointer is
// the provider's own configuration, not a copy.
func (p *DCRProvider) Config() (cfg *oauth2.Config) {
	cfg = p.config
	return cfg
}

// ResourceURL reports the RFC 8707 resource indicator used to bind issued
// tokens to their intended audience.
func (p *DCRProvider) ResourceURL() (resource string) {
	resource = p.resourceURL
	return resource
}

// GeneratePKCE returns a fresh PKCE code verifier. The matching S256 code
// challenge is derived by the oauth2 library when the verifier is supplied via
// oauth2.S256ChallengeOption.
func (p *DCRProvider) GeneratePKCE() (verifier string) {
	verifier = oauth2.GenerateVerifier()
	return verifier
}
// Provider manages the OAuth token lifecycle for a single MCP server and is
// used by the gateway's background token refresh loops.
//
// How a refresh is carried out depends on the mode:
//   - CE mode: the token is refreshed directly with the oauth2 library.
//   - Desktop mode: a refresh is requested through the Desktop GetOAuthApp
//     API; the resulting SSE events, delivered via SendEvent, interrupt the
//     wait and trigger a reload. A login-success event also resets the retry
//     counters.
type Provider struct {
	name string

	// Refresh bookkeeping; in this file only Run reads or writes these.
	lastRefreshExpiry time.Time // expiry seen when the last refresh was triggered
	refreshRetryCount int       // refresh attempts made against that same expiry

	stopOnce  sync.Once
	stopChan  chan struct{} // closed (once) by Stop
	eventChan chan Event    // unbuffered; received by Run while it waits

	credHelper *CredentialHelper
	reloadFn   func(ctx context.Context, serverName string) error
}

// maxRefreshRetries caps how many refreshes Run triggers while the stored
// token expiry stays unchanged before it gives up.
const maxRefreshRetries = 7

// NewProvider returns a Provider for the named server. reloadFn is invoked by
// Run whenever an event arrives through SendEvent. Call Run to start the loop.
func NewProvider(name string, reloadFn func(context.Context, string) error) *Provider {
	p := &Provider{
		name:     name,
		reloadFn: reloadFn,
	}
	p.stopChan = make(chan struct{})
	p.eventChan = make(chan Event)
	p.credHelper = NewOAuthCredentialHelper()
	return p
}
// Run starts the provider's background refresh loop and blocks until it ends.
//
// Each iteration reads the stored token status for p.name and then:
//   - if the token needs a refresh, triggers one in the background and waits
//     with exponential backoff (30s, 1m, 2m, ...). The backoff grows only while
//     the stored expiry stays the same; after maxRefreshRetries such attempts
//     Run gives up and returns.
//   - otherwise, waits until 10s before the token expires, but never less than
//     minRecheckInterval.
//
// A wait ends early when Stop is called or ctx is cancelled (Run returns), or
// when an event arrives via SendEvent (reloadFn is called, and a login-success
// event also resets the retry bookkeeping).
//
// Run also returns when the token status cannot be read or when the token
// carries no expiry information.
func (p *Provider) Run(ctx context.Context) {
	// Lower bound on every wait. Without it, a token within 10s of expiry that
	// is not yet flagged NeedsRefresh produces a zero wait, and the loop would
	// call GetTokenStatus back-to-back without ever checking Stop, ctx, or
	// incoming events.
	const minRecheckInterval = time.Second

	log.Logf("- Started OAuth provider loop for %s", p.name)
	defer log.Logf("- Stopped OAuth provider loop for %s", p.name)

	for {
		status, err := p.credHelper.GetTokenStatus(ctx, p.name)
		if err != nil {
			log.Logf("! Unable to get token status for %s: %v", p.name, err)
			log.Logf("! Run 'docker mcp oauth authorize %s' if not yet authorized", p.name)
			return
		}

		var waitDuration time.Duration
		if status.NeedsRefresh {
			// Count consecutive attempts that observed the same expiry; a new
			// expiry means an earlier refresh landed, so start over at 1.
			expiryUnchanged := !p.lastRefreshExpiry.IsZero() && status.ExpiresAt.Equal(p.lastRefreshExpiry)
			if expiryUnchanged {
				p.refreshRetryCount++
			} else {
				if p.refreshRetryCount > 0 {
					log.Logf("- Token expiry updated for %s, resetting refresh count", p.name)
				}
				p.refreshRetryCount = 1
			}
			if p.refreshRetryCount > maxRefreshRetries {
				log.Logf("! Token expiry unchanged after %d refresh attempts for %s", maxRefreshRetries, p.name)
				return
			}

			// Exponential backoff: 30s, 1min, 2min, 4min, 8min...
			waitDuration = time.Duration(30*(1<<(p.refreshRetryCount-1))) * time.Second
			log.Logf("- Triggering token refresh for %s, attempt %d/%d, waiting %v",
				p.name, p.refreshRetryCount, maxRefreshRetries, waitDuration)
			p.lastRefreshExpiry = status.ExpiresAt
			p.triggerRefreshAsync(ctx)
		} else {
			if status.ExpiresAt.IsZero() {
				// No expiry information available — can't schedule proactive refresh.
				// Fall back to SSE events (Desktop mode) for refresh notification.
				log.Logf("- No token expiry info for %s, stopping provider loop (SSE events will handle refresh)", p.name)
				return
			}
			// Re-check 10s before expiry (or after the floor, if that is sooner).
			waitDuration = max(minRecheckInterval, time.Until(status.ExpiresAt)-10*time.Second)
			log.Logf("- Token valid for %s, next check in %v", p.name, waitDuration.Round(time.Second))
		}

		// Wait; interruptible by events, Stop, and ctx cancellation.
		timer := time.NewTimer(waitDuration)
		select {
		case <-timer.C:
			// Wait complete.
		case event := <-p.eventChan:
			timer.Stop()
			log.Logf("- Provider %s received event: %s", p.name, event.Type)
			if err := p.reloadFn(ctx, p.name); err != nil {
				log.Logf("- Failed to reload %s after %s: %v", p.name, event.Type, err)
			}
			if event.Type == EventLoginSuccess {
				p.refreshRetryCount = 0
				p.lastRefreshExpiry = time.Time{}
			}
		case <-p.stopChan:
			timer.Stop()
			return
		case <-ctx.Done():
			timer.Stop()
			return
		}
	}
}

// triggerRefreshAsync starts a token refresh for p.name in a new goroutine and
// returns immediately; failures are only logged.
//
// In CE mode the token is refreshed directly by refreshTokenCE. Otherwise the
// Desktop GetOAuthApp API is called, and completion is expected to reach Run
// later as an SSE event via SendEvent. The Desktop request is bound to ctx so
// it does not outlive the Run loop.
//
// NOTE(review): no reloadFn call follows a successful CE refresh in this file;
// confirm the server picks up the new token without a reload.
func (p *Provider) triggerRefreshAsync(ctx context.Context) {
	if IsCEMode() {
		go func() {
			if err := p.refreshTokenCE(); err != nil {
				log.Logf("! Token refresh failed for %s: %v", p.name, err)
			}
		}()
		return
	}

	go func() {
		authClient := desktop.NewAuthClient()
		app, err := authClient.GetOAuthApp(ctx, p.name)
		if err != nil {
			log.Logf("! GetOAuthApp failed for %s: %v", p.name, err)
			return
		}
		if !app.Authorized {
			log.Logf("! GetOAuthApp returned Authorized=false for %s", p.name)
		}
	}()
}
// Stop tells Run to return. It is idempotent: only the first call closes
// stopChan, so repeated or concurrent calls are safe.
func (p *Provider) Stop() {
	closeStop := func() { close(p.stopChan) }
	p.stopOnce.Do(closeStop)
}
// SendEvent delivers an SSE event to this provider's Run loop.
//
// The send is synchronous: eventChan is unbuffered, so the call blocks until
// Run receives the event, which Run only does while waiting between token
// checks. If the provider has been stopped, the event is dropped instead of
// blocking the caller forever.
//
// NOTE(review): Run can also return on its own (token status error, missing
// expiry, exhausted retries) without Stop being called; a SendEvent after that
// still blocks. Callers should Stop a provider whose Run has returned.
func (p *Provider) SendEvent(event Event) {
	select {
	case p.eventChan <- event:
	case <-p.stopChan:
		log.Logf("- Provider %s is stopped, dropping event: %s", p.name, event.Type)
	}
}
// refreshTokenCE refreshes the stored OAuth token for p.name in CE mode.
//
// It loads the DCR client registration and the current token through a
// read-write credential helper, performs a refresh_token grant against the
// client's token endpoint using the oauth2 library, and saves the new token
// back to the token store.
//
// The refresh is forced. oauth2's Config.TokenSource hands back the cached
// token unchanged while it is still Valid (until 10s before expiry by
// default). Passing the full stored token would therefore turn an early
// refresh into a silent no-op that re-saves the same token and leaves the
// expiry unchanged.
func (p *Provider) refreshTokenCE() error {
	// Read-write helper: the refreshed token is saved through it.
	rwHelper := NewReadWriteCredentialHelper()

	dcrMgr := dcr.NewManager(rwHelper, "")
	dcrClient, err := dcrMgr.GetDCRClient(p.name)
	if err != nil {
		return fmt.Errorf("failed to get DCR client: %w", err)
	}

	tokenStore := NewTokenStore(rwHelper)
	token, err := tokenStore.Retrieve(dcrClient)
	if err != nil {
		return fmt.Errorf("failed to retrieve token: %w", err)
	}
	if token == nil || token.RefreshToken == "" {
		return fmt.Errorf("no refresh token stored for %s", p.name)
	}

	config := NewDCRProvider(dcrClient, DefaultRedirectURI).Config()

	// Seed the TokenSource with only the refresh token. A token without an
	// access token is never Valid, so Token() always performs the refresh.
	seed := &oauth2.Token{RefreshToken: token.RefreshToken}
	refreshedToken, err := config.TokenSource(context.Background(), seed).Token()
	if err != nil {
		return fmt.Errorf("token refresh failed: %w", err)
	}

	if err := tokenStore.Save(dcrClient, refreshedToken); err != nil {
		return fmt.Errorf("failed to save refreshed token: %w", err)
	}

	log.Logf("- Successfully refreshed token for %s", p.name)
	return nil
}