diff --git a/cmd/docker-mcp/oauth/auth.go b/cmd/docker-mcp/oauth/auth.go index 6d5cb6c99..78a008fee 100644 --- a/cmd/docker-mcp/oauth/auth.go +++ b/cmd/docker-mcp/oauth/auth.go @@ -3,20 +3,51 @@ package oauth import ( "context" "fmt" + "net/url" "time" + "golang.org/x/oauth2" + + "github.com/docker/mcp-gateway/pkg/catalog" "github.com/docker/mcp-gateway/pkg/desktop" pkgoauth "github.com/docker/mcp-gateway/pkg/oauth" + "github.com/docker/mcp-gateway/pkg/oauth/dcr" ) +// Authorize performs OAuth authorization for a server, routing to the +// appropriate flow based on the per-server mode (Desktop, CE, or Community). func Authorize(ctx context.Context, app string, scopes string) error { - // Check if running in CE mode - if pkgoauth.IsCEMode() { + isCommunity, err := lookupIsCommunity(ctx, app) + if err != nil { + // Server not in catalog -- fall back to legacy global routing + // so existing servers without catalog entries still work. + if pkgoauth.IsCEMode() { + return authorizeCEMode(ctx, app, scopes) + } + return authorizeDesktopMode(ctx, app, scopes) + } + + switch pkgoauth.DetermineMode(ctx, isCommunity) { + case pkgoauth.ModeCE: return authorizeCEMode(ctx, app, scopes) + case pkgoauth.ModeCommunity: + return authorizeCommunityMode(ctx, app, scopes) + default: // ModeDesktop + return authorizeDesktopMode(ctx, app, scopes) } +} - // Desktop mode - existing implementation - return authorizeDesktopMode(ctx, app, scopes) +// lookupIsCommunity checks the catalog to determine if a server is a community server. +func lookupIsCommunity(ctx context.Context, serverName string) (bool, error) { + cat, err := catalog.GetWithOptions(ctx, true, nil) + if err != nil { + return false, err + } + server, found := cat.Servers[serverName] + if !found { + return false, fmt.Errorf("server %s not found in catalog", serverName) + } + return server.IsCommunity(), nil } // authorizeDesktopMode handles OAuth via Docker Desktop (existing behavior) @@ -115,3 +146,140 @@ func authorizeCEMode(ctx context.Context, serverName string, scopes string) erro return nil } + +// authorizeCommunityMode handles OAuth for community servers in Desktop mode. +// Uses the Gateway OAuth flow (localhost callback, PKCE) with docker pass storage. +func authorizeCommunityMode(ctx context.Context, serverName string, scopes string) error { + fmt.Printf("Starting OAuth authorization for %s (community)...\n", serverName) + + // Validate docker pass is available (required for community mode) + if err := desktop.CheckHasDockerPass(ctx); err != nil { + return fmt.Errorf("docker pass required for community server OAuth: %w", err) + } + + // Step 1: Ensure DCR client is registered in docker pass + fmt.Printf("Checking DCR registration...\n") + dcrClient, err := pkgoauth.GetDCRClientFromDockerPass(ctx, serverName) + if err != nil || dcrClient.ClientID == "" { + // No DCR client in docker pass -- perform discovery and registration + dcrClient, err = dcr.DiscoverAndRegister(ctx, serverName, scopes, pkgoauth.DefaultRedirectURI) + if err != nil { + return fmt.Errorf("DCR registration failed: %w", err) + } + if err := pkgoauth.SaveDCRClientToDockerPass(ctx, serverName, dcrClient); err != nil { + return fmt.Errorf("failed to save DCR client: %w", err) + } + } + + // Step 2: Create callback server + callbackServer, err := pkgoauth.NewCallbackServer() + if err != nil { + return fmt.Errorf("failed to create callback server: %w", err) + } + + // Start callback server in background + go func() { + if err := callbackServer.Start(); err != nil { + fmt.Printf("Callback server error: %v\n", err) + } + }() + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := callbackServer.Shutdown(shutdownCtx); err != nil { + fmt.Printf("Warning: failed to shutdown callback server: %v\n", err) + } + }() + + // Step 3: Build authorization URL with PKCE + fmt.Printf("Generating authorization URL...\n") + + provider := pkgoauth.NewDCRProvider(dcrClient, pkgoauth.DefaultRedirectURI) + verifier := provider.GeneratePKCE() + + stateManager := pkgoauth.NewStateManager() + baseState := stateManager.Generate(serverName, verifier) + + // Encode callback port in state for mcp-oauth proxy routing + callbackURL := callbackServer.URL() + parsedCallback, err := url.Parse(callbackURL) + if err != nil { + return fmt.Errorf("invalid callback URL: %w", err) + } + port := parsedCallback.Port() + if port == "" { + return fmt.Errorf("callback URL missing port") + } + state := fmt.Sprintf("mcp-gateway:%s:%s", port, baseState) + + config := provider.Config() + + scopesList := []string{} + if scopes != "" { + scopesList = []string{scopes} + } + if len(scopesList) > 0 { + config.Scopes = scopesList + } + + opts := []oauth2.AuthCodeOption{ + oauth2.AccessTypeOffline, + oauth2.S256ChallengeOption(verifier), + } + if provider.ResourceURL() != "" { + opts = append(opts, oauth2.SetAuthURLParam("resource", provider.ResourceURL())) + } + + authURL := config.AuthCodeURL(state, opts...) + + // Step 4: Display authorization URL + fmt.Printf("Please visit this URL to authorize:\n\n %s\n\n", authURL) + + // Step 5: Wait for callback + fmt.Printf("Waiting for authorization callback on http://localhost:%d/callback...\n", callbackServer.Port()) + + timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + code, callbackState, err := callbackServer.Wait(timeoutCtx) + if err != nil { + return fmt.Errorf("failed to receive callback: %w", err) + } + + // Validate the returned state to prevent CSRF attacks. + // The mcp-oauth proxy strips the "mcp-gateway:PORT:" prefix and passes + // the bare UUID to our localhost callback, so callbackState is the UUID + // that stateManager.Generate() returned. + validatedServer, validatedVerifier, err := stateManager.Validate(callbackState) + if err != nil { + return fmt.Errorf("OAuth state validation failed: %w", err) + } + if validatedServer != serverName { + return fmt.Errorf("OAuth state mismatch: expected server %q, got %q", serverName, validatedServer) + } + + // Step 6: Exchange code for token + fmt.Printf("Exchanging authorization code for access token...\n") + + exchangeOpts := []oauth2.AuthCodeOption{ + oauth2.VerifierOption(validatedVerifier), + } + if provider.ResourceURL() != "" { + exchangeOpts = append(exchangeOpts, oauth2.SetAuthURLParam("resource", provider.ResourceURL())) + } + + token, err := config.Exchange(ctx, code, exchangeOpts...) + if err != nil { + return fmt.Errorf("token exchange failed: %w", err) + } + + // Step 7: Store token in docker pass + if err := pkgoauth.SaveTokenToDockerPass(ctx, serverName, token); err != nil { + return fmt.Errorf("failed to store token: %w", err) + } + + fmt.Printf("Authorization successful! Token stored securely.\n") + fmt.Printf("You can now use: docker mcp server start %s\n", serverName) + + return nil +} diff --git a/cmd/docker-mcp/oauth/revoke.go b/cmd/docker-mcp/oauth/revoke.go index 448b1ae9a..5059c914a 100644 --- a/cmd/docker-mcp/oauth/revoke.go +++ b/cmd/docker-mcp/oauth/revoke.go @@ -4,22 +4,35 @@ import ( "context" "fmt" + "github.com/docker/mcp-gateway/cmd/docker-mcp/secret-management/secret" "github.com/docker/mcp-gateway/pkg/db" "github.com/docker/mcp-gateway/pkg/desktop" pkgoauth "github.com/docker/mcp-gateway/pkg/oauth" "github.com/docker/mcp-gateway/pkg/workingset" ) +// Revoke revokes OAuth access for a server, routing to the appropriate flow +// based on the per-server mode (Desktop, CE, or Community). func Revoke(ctx context.Context, app string) error { fmt.Printf("Revoking OAuth access for %s...\n", app) - // Check if CE mode - if pkgoauth.IsCEMode() { - return revokeCEMode(ctx, app) + isCommunity, err := lookupIsCommunity(ctx, app) + if err != nil { + // Server not in catalog -- fall back to legacy global routing. + if pkgoauth.IsCEMode() { + return revokeCEMode(ctx, app) + } + return revokeDesktopMode(ctx, app) } - // Desktop mode - existing implementation - return revokeDesktopMode(ctx, app) + switch pkgoauth.DetermineMode(ctx, isCommunity) { + case pkgoauth.ModeCE: + return revokeCEMode(ctx, app) + case pkgoauth.ModeCommunity: + return revokeCommunityMode(ctx, app) + default: // ModeDesktop + return revokeDesktopMode(ctx, app) + } } // revokeDesktopMode handles revoke via Docker Desktop (existing behavior) @@ -64,3 +77,22 @@ func revokeCEMode(ctx context.Context, app string) error { fmt.Printf("OAuth access revoked for %s\n", app) return nil } + +// revokeCommunityMode handles revoke for community servers in Desktop mode. +// Deletes the OAuth token and DCR client from docker pass. +func revokeCommunityMode(ctx context.Context, app string) error { + // Delete OAuth token from docker pass + if err := secret.DeleteOAuthToken(ctx, app); err != nil { + // Token might not exist, continue to DCR deletion + fmt.Printf("Note: %v\n", err) + } + + // Delete DCR client from docker pass (soft failure -- entry may not exist + // if authorize was never completed or was already revoked) + if err := secret.DeleteDCRClient(ctx, app); err != nil { + fmt.Printf("Note: %v\n", err) + } + + fmt.Printf("OAuth access revoked for %s\n", app) + return nil +} diff --git a/pkg/gateway/mcpadd.go b/pkg/gateway/mcpadd.go index 0f1fe4de8..823075d72 100644 --- a/pkg/gateway/mcpadd.go +++ b/pkg/gateway/mcpadd.go @@ -18,6 +18,7 @@ import ( "github.com/docker/mcp-gateway/pkg/desktop" "github.com/docker/mcp-gateway/pkg/log" "github.com/docker/mcp-gateway/pkg/oauth" + "github.com/docker/mcp-gateway/pkg/oauth/dcr" "github.com/docker/mcp-gateway/pkg/oci" "github.com/docker/mcp-gateway/pkg/policy" ) @@ -433,10 +434,21 @@ func shortenURL(ctx context.Context, longURL string) (string, error) { return response.Link, nil } -// addRemoteOAuthServer handles the OAuth setup for a remote OAuth server -// It registers the provider, starts it, and handles authorization through elicitation or direct URL -// Returns the text message for the CallToolResult +// getRemoteOAuthServerStatus handles the OAuth setup for a remote OAuth server. +// It registers the provider, starts it, and handles authorization through +// elicitation or direct URL. Routes per-server based on DetermineMode: +// - ModeDesktop: registers with Desktop API, uses PostOAuthApp +// - ModeCE: uses credential helper for DCR/tokens +// - ModeCommunity: uses docker pass for DCR/tokens +// +// Returns (authorized bool, message string). func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName string, req *mcp.CallToolRequest, shouldSendTools bool) (bool, string) { + // Determine per-server mode + serverConfig, _, found := g.configuration.Find(serverName) + isCommunity := found && serverConfig != nil && serverConfig.Spec.IsCommunity() + mode := oauth.DetermineMode(ctx, isCommunity) + useGatewayOAuth := oauth.ShouldUseGatewayOAuth(ctx, isCommunity) + // Check if provider already exists g.providersMu.RLock() _, providerExists := g.oauthProviders[serverName] @@ -444,32 +456,37 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str // Only register and start provider if it doesn't already exist if !providerExists { - // Register DCR client with DD so user can authorize - if err := oauth.RegisterProviderForLazySetup(ctx, serverName); err != nil { - // Fallback: try dynamic discovery for community servers without oauth.providers - if serverConfig, _, found := g.configuration.Find(serverName); found && serverConfig.Spec.Remote.URL != "" { - if err := oauth.RegisterProviderForDynamicDiscovery(ctx, serverName, serverConfig.Spec.Remote.URL); err != nil { + if useGatewayOAuth { + // Gateway-owned OAuth (CE or Community mode) + g.registerGatewayOAuthDCR(ctx, serverName, mode) + + // Verify DCR exists in the appropriate backend + if !g.gatewayDCRExists(ctx, serverName, mode) { + return true, "" // Server doesn't require OAuth + } + + g.startProvider(ctx, serverName, mode) + } else { + // Desktop mode: register with Desktop API (existing behavior) + if err := oauth.RegisterProviderForLazySetup(ctx, serverName); err != nil { + if found && serverConfig.Spec.Remote.URL != "" { + if err := oauth.RegisterProviderForDynamicDiscovery(ctx, serverName, serverConfig.Spec.Remote.URL); err != nil { + log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err) + } + } else { log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err) } - } else { - log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err) } - } - // Verify DCR entry was created — dynamic discovery may have found no OAuth requirement. - // Distinguish "not found" (server doesn't need OAuth) from transient API errors. - authClient := desktop.NewAuthClient() - if _, err := authClient.GetDCRClient(ctx, serverName); err != nil { - if strings.Contains(err.Error(), "HTTP 404") { - return true, "" // Server doesn't require OAuth + // Verify DCR entry was created via Desktop API + authClient := desktop.NewAuthClient() + if _, err := authClient.GetDCRClient(ctx, serverName); err != nil { + if strings.Contains(err.Error(), "HTTP 404") { + return true, "" // Server doesn't require OAuth + } + log.Logf("Warning: Failed to verify DCR entry for %s (may be transient): %v", serverName, err) + return true, "" // Fail open } - log.Logf("Warning: Failed to verify DCR entry for %s (may be transient): %v", serverName, err) - return true, "" // Fail open — avoid blocking the add flow on transient errors - } - - // Start provider (CE mode only - Desktop mode doesn't need polling) - if oauth.IsCEMode() { - g.startProvider(ctx, serverName, oauth.ModeCE) } } @@ -496,9 +513,15 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str log.Logf("Warning: Failed to elicit authorization response for %s: %v", serverName, err) return false, "Client rejected eliciation to authorize" } else if elicitResult.Action == "accept" && elicitResult.Content != nil { - // Check if user authorized if authorize, ok := elicitResult.Content["authorize"].(bool); ok && authorize { - // User agreed to authorize, call the OAuth authorize function + if useGatewayOAuth { + // Gateway-owned OAuth: direct the user to the CLI authorize command. + // The tool handler should not block waiting for browser auth completion. + return false, fmt.Sprintf( + "Successfully added server '%s'. To complete authorization, run:\n docker mcp oauth authorize %s\n\nAfter authorizing, reconnect your agent to the MCP gateway.", + serverName, serverName) + } + // Desktop mode: trigger OAuth via Desktop API (existing behavior) client := desktop.NewAuthClient() authResponse, err := client.PostOAuthApp(ctx, serverName, "", false) if err != nil { @@ -518,10 +541,9 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str // Check if user is already authorized by checking if token exists (only if provider exists) if providerExists { - credHelper := oauth.NewOAuthCredentialHelper() + credHelper := oauth.NewOAuthCredentialHelperWithMode(mode) exists, err := credHelper.TokenExists(ctx, serverName) if err == nil && exists { - // User is already authorized, skip the OAuth URL generation if shouldSendTools { return true, fmt.Sprintf("You will need to authorize this server with: docker mcp oauth authorize %s.\n After authorizing, reconnect your agent to the MCP gateway.", serverName) } @@ -529,25 +551,28 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str } } - // Client doesn't support elicitations, get the login link and include it in the response + // Client doesn't support elicitations -- provide authorize instructions + if useGatewayOAuth { + // Gateway-owned OAuth: direct to CLI command + return false, fmt.Sprintf( + "Successfully added server '%s'. You will need to authorize this server with: docker mcp oauth authorize %s\n After authorizing, reconnect your agent to the MCP gateway.", + serverName, serverName) + } + + // Desktop mode: get the login link via Desktop API (existing behavior) client := desktop.NewAuthClient() - // Set context flag to enable disableAutoOpen parameter ctxWithFlag := context.WithValue(ctx, contextkeys.OAuthInterceptorEnabledKey, true) - // disable auto-open authResponse, err := client.PostOAuthApp(ctxWithFlag, serverName, "", true) if err != nil { log.Logf("Warning: Failed to get OAuth URL for %s: %v", serverName, err) return false, "Unable to get OAuth URL" } else if authResponse.BrowserURL != "" { - // Try to shorten the URL using Bitly shortURL, err := shortenURL(ctx, authResponse.BrowserURL) var displayLink string if err != nil { - // If shortening fails, use the original URL log.Logf("Warning: Failed to shorten URL for %s: %v", serverName, err) displayLink = fmt.Sprintf("[Click here to authorize](%s)", authResponse.BrowserURL) } else { - // Use the shortened URL in the markdown link displayLink = fmt.Sprintf("[Click here to authorize](%s)", shortURL) } @@ -556,3 +581,57 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str return false, fmt.Sprintf("Successfully added server '%s'. You will need to authorize this server with: docker mcp oauth authorize %s", serverName, serverName) } + +// registerGatewayOAuthDCR registers a DCR client for Gateway-owned OAuth modes +// (CE or Community). For Community mode, stores in docker pass. For CE mode, +// stores in the credential helper. +func (g *Gateway) registerGatewayOAuthDCR(ctx context.Context, serverName string, mode oauth.Mode) { + switch mode { + case oauth.ModeCommunity: + // Check docker pass availability + if err := desktop.CheckHasDockerPass(ctx); err != nil { + log.Logf("Warning: docker pass unavailable for community server %s: %v", serverName, err) + return + } + + // Check if DCR client already exists in docker pass + if dcrClient, err := oauth.GetDCRClientFromDockerPass(ctx, serverName); err == nil && dcrClient.ClientID != "" { + return // Already registered + } + + // Perform discovery and registration, save to docker pass + dcrClient, err := dcr.DiscoverAndRegister(ctx, serverName, "", oauth.DefaultRedirectURI) + if err != nil { + log.Logf("Warning: DCR registration failed for community server %s: %v", serverName, err) + return + } + if err := oauth.SaveDCRClientToDockerPass(ctx, serverName, dcrClient); err != nil { + log.Logf("Warning: Failed to save DCR client for %s: %v", serverName, err) + } + + case oauth.ModeCE: + // CE mode: use credential helper via Manager + credHelper := oauth.NewReadWriteCredentialHelper() + manager := oauth.NewManager(credHelper) + if err := manager.EnsureDCRClient(ctx, serverName, ""); err != nil { + log.Logf("Warning: DCR registration failed for CE server %s: %v", serverName, err) + } + } +} + +// gatewayDCRExists checks if a DCR client exists for Gateway-owned OAuth modes. +// Returns false if the server doesn't require OAuth. +func (g *Gateway) gatewayDCRExists(ctx context.Context, serverName string, mode oauth.Mode) bool { + switch mode { + case oauth.ModeCommunity: + dcrClient, err := oauth.GetDCRClientFromDockerPass(ctx, serverName) + return err == nil && dcrClient.ClientID != "" + case oauth.ModeCE: + credHelper := oauth.NewReadWriteCredentialHelper() + dcrMgr := dcr.NewManager(credHelper, "") + client, err := dcrMgr.GetDCRClient(serverName) + return err == nil && client.ClientID != "" + default: + return false + } +} diff --git a/pkg/oauth/dcr/manager.go b/pkg/oauth/dcr/manager.go index c7db7917e..44fb69c4f 100644 --- a/pkg/oauth/dcr/manager.go +++ b/pkg/oauth/dcr/manager.go @@ -41,12 +41,30 @@ func (m *Manager) GetDCRClient(serverName string) (Client, error) { // PerformDiscoveryAndRegistration executes OAuth discovery and DCR for a server // This is called when no DCR client exists or when it needs re-registration func (m *Manager) PerformDiscoveryAndRegistration(ctx context.Context, serverName string, scopes string) error { + dcrClient, err := DiscoverAndRegister(ctx, serverName, scopes, m.redirectURI) + if err != nil { + return err + } + + if err := m.credentials.SaveClient(serverName, dcrClient); err != nil { + return fmt.Errorf("saving DCR client for %s: %w", serverName, err) + } + + log.Logf("- Completed DCR for: %s", serverName) + return nil +} + +// DiscoverAndRegister performs OAuth discovery and DCR for a server, returning +// the resulting Client without persisting it. Callers are responsible for +// storing the client in the appropriate backend (credential helper, docker pass, +// etc.). This is the storage-agnostic core of PerformDiscoveryAndRegistration. +func DiscoverAndRegister(ctx context.Context, serverName string, scopes string, redirectURI string) (Client, error) { log.Logf("- Performing OAuth discovery and DCR for: %s", serverName) // Get server URL from catalog serverURL, err := getServerURL(ctx, serverName) if err != nil { - return fmt.Errorf("getting server URL: %w", err) + return Client{}, fmt.Errorf("getting server URL: %w", err) } // Perform OAuth discovery (RFC 9728, RFC 8414) @@ -54,7 +72,7 @@ func (m *Manager) PerformDiscoveryAndRegistration(ctx context.Context, serverNam ctx = oauth.WithLogger(ctx, &logger{}) discovery, err := oauth.DiscoverOAuthRequirements(ctx, serverURL) if err != nil { - return fmt.Errorf("discovering OAuth requirements for %s: %w", serverName, err) + return Client{}, fmt.Errorf("discovering OAuth requirements for %s: %w", serverName, err) } log.Logf("- Discovery successful for: %s", serverName) @@ -66,14 +84,13 @@ func (m *Manager) PerformDiscoveryAndRegistration(ctx context.Context, serverNam } // Perform Dynamic Client Registration (RFC 7591) with our redirect URI - creds, err := oauth.PerformDCR(ctx, discovery, serverName, m.redirectURI) + creds, err := oauth.PerformDCR(ctx, discovery, serverName, redirectURI) if err != nil { - return fmt.Errorf("registering DCR client for %s: %w", serverName, err) + return Client{}, fmt.Errorf("registering DCR client for %s: %w", serverName, err) } log.Logf("- Registration successful for: %s, clientID: %s", serverName, creds.ClientID) - // Create and save DCR client - dcrClient := Client{ + return Client{ ServerName: serverName, ProviderName: serverName, // For DCR, provider name = server name ClientID: creds.ClientID, @@ -84,14 +101,7 @@ func (m *Manager) PerformDiscoveryAndRegistration(ctx context.Context, serverNam ScopesSupported: discovery.ScopesSupported, RequiredScopes: discovery.Scopes, RegisteredAt: time.Now(), - } - - if err := m.credentials.SaveClient(serverName, dcrClient); err != nil { - return fmt.Errorf("saving DCR client for %s: %w", serverName, err) - } - - log.Logf("- Completed DCR for: %s", serverName) - return nil + }, nil } // DeleteDCRClient removes a DCR client from storage