88
99 "golang.org/x/oauth2"
1010
11+ "github.com/docker/mcp-gateway/pkg/desktop"
1112 "github.com/docker/mcp-gateway/pkg/log"
1213 "github.com/docker/mcp-gateway/pkg/oauth/dcr"
1314)
@@ -61,33 +62,37 @@ func (p *DCRProvider) GeneratePKCE() string {
6162 return oauth2 .GenerateVerifier ()
6263}
6364
64- // Provider manages OAuth token lifecycle for a single MCP server (CE mode only).
65- // Polls token expiry, triggers refresh when needed, and reloads the server connection.
66- // In Desktop mode, Secrets Engine handles token refresh and SSE events trigger reloads.
65+ // Provider manages OAuth token lifecycle for a single MCP server.
66+ // This is used for background token refresh loops in the gateway.
67+ // CE mode: refreshes tokens directly via oauth2 library, then reloads.
68+ // Desktop mode: triggers refresh via GetOAuthApp Desktop API, then SSE events
69+ // interrupt the timer, trigger reload, and reset retry counters.
6770type Provider struct {
6871 name string
6972 lastRefreshExpiry time.Time
7073 refreshRetryCount int
7174 stopOnce sync.Once
7275 stopChan chan struct {}
76+ eventChan chan Event
7377 credHelper * CredentialHelper
7478 reloadFn func (ctx context.Context , serverName string ) error
7579}
7680
7781const maxRefreshRetries = 7 // Max attempts to refresh when expiry hasn't changed
7882
79- // NewProvider creates a new OAuth provider for token refresh polling
83+ // NewProvider creates a new OAuth provider for token refresh
8084func NewProvider (name string , reloadFn func (context.Context , string ) error ) * Provider {
8185 return & Provider {
8286 name : name ,
8387 stopChan : make (chan struct {}),
88+ eventChan : make (chan Event ),
8489 credHelper : NewOAuthCredentialHelper (),
8590 reloadFn : reloadFn ,
8691 }
8792}
8893
89- // Run starts the provider's background polling loop.
90- // Checks token expiry, triggers refresh when needed, and reloads server connections .
94+ // Run starts the provider's background loop.
95+ // Loop dynamically adjusts timing based on token expiry .
9196func (p * Provider ) Run (ctx context.Context ) {
9297 log .Logf ("- Started OAuth provider loop for %s" , p .name )
9398 defer log .Logf ("- Stopped OAuth provider loop for %s" , p .name )
@@ -101,8 +106,9 @@ func (p *Provider) Run(ctx context.Context) {
101106 return
102107 }
103108
104- // Calculate wait duration based on token status
109+ // Calculate wait duration and whether to trigger refresh
105110 var waitDuration time.Duration
111+ var shouldTriggerRefresh bool
106112
107113 if status .NeedsRefresh {
108114 // Token needs refresh - check if expiry unchanged from last attempt
@@ -128,32 +134,63 @@ func (p *Provider) Run(ctx context.Context) {
128134 p .name , p .refreshRetryCount , maxRefreshRetries , waitDuration )
129135
130136 p .lastRefreshExpiry = status .ExpiresAt
131-
132- // Refresh token and reload server connection
133- go func () {
134- if err := p .refreshTokenCE (); err != nil {
135- log .Logf ("! Token refresh failed for %s: %v" , p .name , err )
136- return
137- }
138- // Reload server to pick up the new token
139- if err := p .reloadFn (ctx , p .name ); err != nil {
140- log .Logf ("! Failed to reload %s after token refresh: %v" , p .name , err )
141- }
142- }()
137+ shouldTriggerRefresh = true
143138
144139 } else {
145- // Token still valid
140+ if status .ExpiresAt .IsZero () {
141+ // No expiry information available — can't schedule proactive refresh.
142+ // Fall back to SSE events (Desktop mode) for refresh notification.
143+ log .Logf ("- No token expiry info for %s, stopping provider loop (SSE events will handle refresh)" , p .name )
144+ return
145+ }
146146 timeUntilExpiry := time .Until (status .ExpiresAt )
147147 waitDuration = max (0 , timeUntilExpiry - 10 * time .Second )
148148 log .Logf ("- Token valid for %s, next check in %v" , p .name , waitDuration .Round (time .Second ))
149+ shouldTriggerRefresh = false
150+ }
151+
152+ // Trigger refresh if needed
153+ if shouldTriggerRefresh {
154+ if IsCEMode () {
155+ // CE mode: Refresh token directly
156+ go func () {
157+ if err := p .refreshTokenCE (); err != nil {
158+ log .Logf ("! Token refresh failed for %s: %v" , p .name , err )
159+ }
160+ }()
161+ } else {
162+ // Desktop mode: Trigger refresh via Desktop API
163+ go func () {
164+ authClient := desktop .NewAuthClient ()
165+ app , err := authClient .GetOAuthApp (context .Background (), p .name )
166+ if err != nil {
167+ log .Logf ("! GetOAuthApp failed for %s: %v" , p .name , err )
168+ return
169+ }
170+ if ! app .Authorized {
171+ log .Logf ("! GetOAuthApp returned Authorized=false for %s" , p .name )
172+ return
173+ }
174+ }()
175+ }
149176 }
150177
151- // Wait until next check, interruptible by stop signal
178+ // Wait pattern - interruptible by login events
152179 if waitDuration > 0 {
153180 timer := time .NewTimer (waitDuration )
154181 select {
155182 case <- timer .C :
156- // Wait complete, continue to next iteration
183+ // Wait complete
184+ case event := <- p .eventChan :
185+ timer .Stop ()
186+ log .Logf ("- Provider %s received event: %s" , p .name , event .Type )
187+ if err := p .reloadFn (ctx , p .name ); err != nil {
188+ log .Logf ("- Failed to reload %s after %s: %v" , p .name , event .Type , err )
189+ }
190+ if event .Type == EventLoginSuccess {
191+ p .refreshRetryCount = 0
192+ p .lastRefreshExpiry = time.Time {}
193+ }
157194 case <- p .stopChan :
158195 timer .Stop ()
159196 return
@@ -172,6 +209,11 @@ func (p *Provider) Stop() {
172209 })
173210}
174211
212+ // SendEvent sends an SSE event to this provider's event channel
213+ func (p * Provider ) SendEvent (event Event ) {
214+ p .eventChan <- event
215+ }
216+
175217// refreshTokenCE refreshes an OAuth token in CE mode
176218// Uses the same oauth2 library refresh mechanism as Desktop
177219func (p * Provider ) refreshTokenCE () error {
0 commit comments