diff --git a/cmd/docker-mcp/secret-management/secret/credstore.go b/cmd/docker-mcp/secret-management/secret/credstore.go index 5529eca8b..d35e38ea0 100644 --- a/cmd/docker-mcp/secret-management/secret/credstore.go +++ b/cmd/docker-mcp/secret-management/secret/credstore.go @@ -93,3 +93,45 @@ func DeleteDefaultSecret(ctx context.Context, id string) error { } return nil } + +// SetOAuthToken stores an OAuth token via docker pass at docker/mcp/oauth/{serverName}. +// The value should be base64-encoded JSON of the full oauth2.Token. +func SetOAuthToken(ctx context.Context, serverName string, value string) error { + c := cmd(ctx, "set", GetOAuthKey(serverName)) + c.Stdin = strings.NewReader(value) + out, err := c.CombinedOutput() + if err != nil { + return fmt.Errorf("could not store OAuth token for %s: %s\n%s", serverName, bytes.TrimSpace(out), err) + } + return nil +} + +// DeleteOAuthToken removes an OAuth token from docker pass. +func DeleteOAuthToken(ctx context.Context, serverName string) error { + out, err := cmd(ctx, "rm", GetOAuthKey(serverName)).CombinedOutput() + if err != nil { + return fmt.Errorf("could not delete OAuth token for %s: %s\n%s", serverName, bytes.TrimSpace(out), err) + } + return nil +} + +// SetDCRClient stores a DCR client config via docker pass at docker/mcp/oauth-dcr/{serverName}. +// The value should be base64-encoded JSON of the DCR client. +func SetDCRClient(ctx context.Context, serverName string, value string) error { + c := cmd(ctx, "set", GetDCRKey(serverName)) + c.Stdin = strings.NewReader(value) + out, err := c.CombinedOutput() + if err != nil { + return fmt.Errorf("could not store DCR client for %s: %s\n%s", serverName, bytes.TrimSpace(out), err) + } + return nil +} + +// DeleteDCRClient removes a DCR client config from docker pass. +func DeleteDCRClient(ctx context.Context, serverName string) error { + out, err := cmd(ctx, "rm", GetDCRKey(serverName)).CombinedOutput() + if err != nil { + return fmt.Errorf("could not delete DCR client for %s: %s\n%s", serverName, bytes.TrimSpace(out), err) + } + return nil +} diff --git a/pkg/catalog/catalog_test.go b/pkg/catalog/catalog_test.go index 42ebb1fe1..a1aec8705 100644 --- a/pkg/catalog/catalog_test.go +++ b/pkg/catalog/catalog_test.go @@ -204,6 +204,46 @@ func setupTestCatalogs(t *testing.T, homeDir string) { require.NoError(t, err) } +func TestServer_IsCommunity(t *testing.T) { + tests := []struct { + name string + server Server + expected bool + }{ + { + name: "nil metadata", + server: Server{}, + expected: false, + }, + { + name: "empty tags", + server: Server{Metadata: &Metadata{Tags: []string{}}}, + expected: false, + }, + { + name: "has community tag", + server: Server{Metadata: &Metadata{Tags: []string{"community"}}}, + expected: true, + }, + { + name: "community among other tags", + server: Server{Metadata: &Metadata{Tags: []string{"featured", "community", "ai"}}}, + expected: true, + }, + { + name: "no community tag", + server: Server{Metadata: &Metadata{Tags: []string{"featured", "official"}}}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.server.IsCommunity()) + }) + } +} + func setupOverlappingCatalogs(t *testing.T, homeDir string) { t.Helper() diff --git a/pkg/catalog/types.go b/pkg/catalog/types.go index 9ee7768c6..bc11a72da 100644 --- a/pkg/catalog/types.go +++ b/pkg/catalog/types.go @@ -1,6 +1,10 @@ package catalog -import "github.com/docker/mcp-gateway/pkg/policy" +import ( + "slices" + + "github.com/docker/mcp-gateway/pkg/policy" +) type Catalog struct { Servers map[string]Server @@ -59,6 +63,16 @@ type Metadata struct { RegistryURL string `yaml:"registryUrl,omitempty" json:"registryUrl,omitempty"` } +// IsCommunity returns true if this server was sourced from the community MCP +// registry. Community servers are tagged with "community" in Metadata.Tags by +// catalog_next/create.go when importing from the community registry. +func (s *Server) IsCommunity() bool { + if s.Metadata == nil { + return false + } + return slices.Contains(s.Metadata.Tags, "community") +} + func (s *Server) IsOAuthServer() bool { return s.OAuth != nil && len(s.OAuth.Providers) > 0 } diff --git a/pkg/gateway/mcpadd.go b/pkg/gateway/mcpadd.go index a1eacbd3b..0f1fe4de8 100644 --- a/pkg/gateway/mcpadd.go +++ b/pkg/gateway/mcpadd.go @@ -469,7 +469,7 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str // Start provider (CE mode only - Desktop mode doesn't need polling) if oauth.IsCEMode() { - g.startProvider(ctx, serverName) + g.startProvider(ctx, serverName, oauth.ModeCE) } } diff --git a/pkg/gateway/run.go b/pkg/gateway/run.go index a3903f9ee..8a98cdd57 100644 --- a/pkg/gateway/run.go +++ b/pkg/gateway/run.go @@ -355,23 +355,39 @@ func (g *Gateway) Run(ctx context.Context) error { // Start OAuth provider for each OAuth server. // Each provider runs in its own goroutine with dynamic timing based on token expiry. log.Log("- Starting OAuth provider loops...") - credHelper := oauth.NewOAuthCredentialHelper() + + // Pre-flight check: verify docker pass availability once for all + // community servers that may need it. Avoids repeated shell-outs. + hasDockerPass := desktop.CheckHasDockerPass(ctx) == nil + for _, serverName := range configuration.ServerNames() { serverConfig, _, found := configuration.Find(serverName) if !found || serverConfig == nil { continue } + isCommunity := serverConfig.Spec.IsCommunity() + mode := oauth.DetermineMode(ctx, isCommunity) + + // Community mode requires docker pass. If unavailable, fall + // back to Desktop mode so the server is not left unmanaged. + if mode == oauth.ModeCommunity && !hasDockerPass { + log.Logf("! docker pass unavailable -- falling back to Desktop OAuth for community server %s", serverName) + mode = oauth.ModeDesktop + } + + credHelper := oauth.NewOAuthCredentialHelperWithMode(mode) + if serverConfig.Spec.HasExplicitOAuthProviders() { - g.startProvider(ctx, serverName) + g.startProvider(ctx, serverName, mode) } else if serverConfig.IsRemote() { - // Community servers: start provider if they have a stored OAuth token + // Community/remote servers: start provider if they have a stored OAuth token // from dynamic discovery (DCR without explicit OAuth metadata) if exists, err := credHelper.TokenExists(ctx, serverName); err != nil { log.Logf("Warning: Failed to check OAuth token for %s: %v", serverName, err) } else if exists { - log.Logf("- Starting OAuth provider for community server: %s", serverName) - g.startProvider(ctx, serverName) + log.Logf("- Starting OAuth provider for remote server: %s (mode=%s)", serverName, mode) + g.startProvider(ctx, serverName, mode) } } } @@ -611,8 +627,11 @@ func (g *Gateway) periodicMetricExport(ctx context.Context) { // OAuth Provider Management Methods -// startProvider creates and starts an OAuth provider goroutine for a server -func (g *Gateway) startProvider(ctx context.Context, serverName string) { +// startProvider creates and starts an OAuth provider goroutine for a server. +// The mode parameter controls which credential storage backend the provider +// uses. Pass ModeAuto when the caller does not know the server type +// (backward compat); the provider will fall back to runtime IsCEMode() detection. +func (g *Gateway) startProvider(ctx context.Context, serverName string, mode oauth.Mode) { g.providersMu.Lock() defer g.providersMu.Unlock() @@ -649,7 +668,7 @@ func (g *Gateway) startProvider(ctx context.Context, serverName string) { } // Create and start provider - provider := oauth.NewProvider(serverName, reloadFn) + provider := oauth.NewProvider(serverName, mode, reloadFn) g.oauthProviders[serverName] = provider // Wrapper goroutine handles cleanup after provider exits @@ -684,10 +703,12 @@ func (g *Gateway) routeEventToProvider(event oauth.Event) { switch event.Type { case oauth.EventLoginSuccess: - // User just authorized - ensure provider exists + // User just authorized via Desktop SSE - ensure provider exists. + // SSE events are only received in Desktop mode (the notification + // monitor is skipped in CE mode), so ModeDesktop is correct. if !exists { log.Logf("- Creating provider for %s after login", event.Provider) - g.startProvider(context.Background(), event.Provider) + g.startProvider(context.Background(), event.Provider, oauth.ModeDesktop) } // Always send event to trigger reload (connects server and lists tools) diff --git a/pkg/oauth/credhelper.go b/pkg/oauth/credhelper.go index ab8f1be47..38b19a5d3 100644 --- a/pkg/oauth/credhelper.go +++ b/pkg/oauth/credhelper.go @@ -18,9 +18,12 @@ import ( "github.com/docker/mcp-gateway/pkg/oauth/dcr" ) -// CredentialHelper provides secure access to OAuth tokens via credential helpers +// CredentialHelper provides secure access to OAuth tokens via credential helpers. +// The mode field controls which storage backend is used: Secrets Engine (Desktop +// catalog), credential helper (CE), or docker pass (Desktop community). type CredentialHelper struct { credentialHelper credentials.Helper + mode Mode } // GetHelper returns the underlying credential helper @@ -28,13 +31,39 @@ func (h *CredentialHelper) GetHelper() credentials.Helper { return h.credentialHelper } -// NewOAuthCredentialHelper creates a new OAuth credential helper +// NewOAuthCredentialHelper creates a new OAuth credential helper with +// auto-detected mode (preserves existing behavior for callers that have not +// been updated to pass an explicit mode). func NewOAuthCredentialHelper() *CredentialHelper { return &CredentialHelper{ credentialHelper: newOAuthHelper(), + mode: ModeAuto, } } +// NewOAuthCredentialHelperWithMode creates a credential helper that uses the +// specified storage mode. Use this when the caller knows whether the server +// is Desktop-catalog, CE, or Desktop-community. +func NewOAuthCredentialHelperWithMode(mode Mode) *CredentialHelper { + return &CredentialHelper{ + credentialHelper: newOAuthHelper(), + mode: mode, + } +} + +// resolveMode returns the effective Mode. When mode is ModeAuto, the +// runtime IsCEMode() check determines the backend (CE or Desktop). Explicit +// modes are returned as-is. +func (h *CredentialHelper) resolveMode() Mode { + if h.mode == ModeAuto { + if IsCEMode() { + return ModeCE + } + return ModeDesktop + } + return h.mode +} + // TokenStatus represents the validity status of an OAuth token type TokenStatus struct { Valid bool @@ -42,12 +71,20 @@ type TokenStatus struct { NeedsRefresh bool } -// GetOAuthToken retrieves an OAuth token for the specified server +// GetOAuthToken retrieves an OAuth token for the specified server. +// Routes to the appropriate storage backend based on the resolved mode: +// - CE: credential helper (base64-encoded JSON) +// - Desktop: Secrets Engine (raw access token) +// - Community: docker pass via Secrets Engine (base64-encoded JSON) func (h *CredentialHelper) GetOAuthToken(ctx context.Context, serverName string) (string, error) { - if IsCEMode() { + switch h.resolveMode() { + case ModeCE: return h.getOAuthTokenCE(serverName) + case ModeCommunity: + return h.getOAuthTokenDockerPass(ctx, serverName) + default: + return h.getOAuthTokenDesktop(ctx, serverName) } - return h.getOAuthTokenDesktop(ctx, serverName) } // getOAuthTokenCE retrieves OAuth token in CE mode using credential helper. @@ -114,26 +151,35 @@ func (h *CredentialHelper) getOAuthTokenDesktop(ctx context.Context, serverName } // TokenExists checks if an OAuth token exists for the specified server. -// This is the appropriate check for Desktop mode where Secrets Engine returns -// raw tokens without validity or expiry metadata - we can only verify existence. -// For CE mode, this also works but GetTokenStatus provides more detail. +// Routes to the appropriate storage backend based on the resolved mode. func (h *CredentialHelper) TokenExists(ctx context.Context, serverName string) (bool, error) { - if IsCEMode() { - // CE mode: check credential helper - dcrMgr := dcr.NewManager(h.credentialHelper, "") - client, err := dcrMgr.GetDCRClient(serverName) - if err != nil { - return false, nil // No DCR client = no token - } - credentialKey := fmt.Sprintf("%s/%s", client.AuthorizationEndpoint, client.ProviderName) - _, tokenSecret, err := h.credentialHelper.Get(credentialKey) - if err != nil || tokenSecret == "" { - return false, nil - } - return true, nil + switch h.resolveMode() { + case ModeCE: + return h.tokenExistsCE(serverName) + case ModeCommunity: + return h.tokenExistsDockerPass(ctx, serverName) + default: + return h.tokenExistsDesktop(ctx, serverName) } +} + +// tokenExistsCE checks if a token exists in the credential helper (CE mode). +func (h *CredentialHelper) tokenExistsCE(serverName string) (bool, error) { + dcrMgr := dcr.NewManager(h.credentialHelper, "") + client, err := dcrMgr.GetDCRClient(serverName) + if err != nil { + return false, nil // No DCR client = no token + } + credentialKey := fmt.Sprintf("%s/%s", client.AuthorizationEndpoint, client.ProviderName) + _, tokenSecret, err := h.credentialHelper.Get(credentialKey) + if err != nil || tokenSecret == "" { + return false, nil + } + return true, nil +} - // Desktop mode: check Secrets Engine +// tokenExistsDesktop checks if a token exists in the Secrets Engine (Desktop mode). +func (h *CredentialHelper) tokenExistsDesktop(ctx context.Context, serverName string) (bool, error) { oauthID := secret.GetOAuthKey(serverName) env, err := secret.GetSecret(ctx, oauthID) if errors.Is(err, secret.ErrSecretNotFound) { @@ -146,14 +192,19 @@ func (h *CredentialHelper) TokenExists(ctx context.Context, serverName string) ( } // GetTokenStatus checks token validity and expiry for refresh scheduling. -// Works in both CE and Desktop modes: -// - CE mode: reads token JSON from credential helper and parses expiry -// - Desktop mode: queries Secrets Engine and reads ExpiryAt from response metadata +// Routes to the appropriate storage backend based on the resolved mode: +// - CE: reads token JSON from credential helper, parses expiry +// - Desktop: queries Secrets Engine, reads ExpiryAt from response metadata +// - Community: reads token JSON from docker pass via Secrets Engine, parses expiry func (h *CredentialHelper) GetTokenStatus(ctx context.Context, serverName string) (TokenStatus, error) { - if IsCEMode() { + switch h.resolveMode() { + case ModeCE: return h.getTokenStatusCE(serverName) + case ModeCommunity: + return h.getTokenStatusDockerPass(ctx, serverName) + default: + return h.getTokenStatusDesktop(ctx, serverName) } - return h.getTokenStatusDesktop(ctx, serverName) } // getTokenStatusDesktop retrieves token status in Desktop mode using Secrets Engine metadata. diff --git a/pkg/oauth/dockerpass.go b/pkg/oauth/dockerpass.go new file mode 100644 index 000000000..9c8491b59 --- /dev/null +++ b/pkg/oauth/dockerpass.go @@ -0,0 +1,255 @@ +package oauth + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "time" + + "golang.org/x/oauth2" + + "github.com/docker/mcp-gateway/cmd/docker-mcp/secret-management/secret" + "github.com/docker/mcp-gateway/pkg/log" + "github.com/docker/mcp-gateway/pkg/oauth/dcr" +) + +// Docker pass stores tokens and DCR clients in the OS Keychain at well-known +// key paths. Writes use the `docker pass set` CLI command (via the secret +// package). Reads go through the Secrets Engine API, which aggregates all +// providers including the docker-pass plugin. +// +// Plugin resolution: both the docker-pass plugin (pattern `**`) and the +// docker-desktop-mcp-oauth plugin (pattern `docker/mcp/oauth/**`) match the +// token key path. For community servers the token is written via `docker pass`, +// so only the docker-pass plugin has an entry; the Desktop OAuth plugin returns +// "not found" and the Secrets Engine falls through to docker-pass. For catalog +// servers written by Desktop's OAuth manager, the Desktop plugin responds first +// (more specific pattern). This asymmetry is what makes the same GetSecret call +// return the right value for both modes. +// +// Token format: base64-encoded JSON of oauth2.Token (same as CE mode). +// DCR format: base64-encoded JSON of dcr.Client (same as CE mode). + +// --- CredentialHelper methods for docker pass (community mode) --- + +// getOAuthTokenDockerPass retrieves an OAuth access token for a community +// server stored in docker pass. The Secrets Engine's docker-pass plugin +// returns the raw stored value at docker/mcp/oauth/{server}. +func (h *CredentialHelper) getOAuthTokenDockerPass(ctx context.Context, serverName string) (string, error) { + oauthID := secret.GetOAuthKey(serverName) + env, err := secret.GetSecret(ctx, oauthID) + if errors.Is(err, secret.ErrSecretNotFound) { + return "", fmt.Errorf("OAuth token not found for %s. Run 'docker mcp oauth authorize %s' to authenticate", serverName, serverName) + } + if err != nil { + return "", fmt.Errorf("failed to query Secrets Engine for %s: %w", serverName, err) + } + + storedValue := string(env.Value) + if storedValue == "" { + return "", fmt.Errorf("empty OAuth token found for %s", serverName) + } + + // Docker pass stores base64-encoded JSON of the full oauth2.Token. + tokenJSON, err := base64.StdEncoding.DecodeString(storedValue) + if err != nil { + return "", fmt.Errorf("failed to decode OAuth token for %s: %w", serverName, err) + } + + var tokenData struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + } + if err := json.Unmarshal(tokenJSON, &tokenData); err != nil { + return "", fmt.Errorf("failed to parse OAuth token JSON for %s: %w", serverName, err) + } + + if tokenData.AccessToken == "" { + return "", fmt.Errorf("empty OAuth access token found for %s", serverName) + } + + return tokenData.AccessToken, nil +} + +// tokenExistsDockerPass checks whether a token exists in docker pass for a +// community server. Reads via the Secrets Engine. +func (h *CredentialHelper) tokenExistsDockerPass(ctx context.Context, serverName string) (bool, error) { + oauthID := secret.GetOAuthKey(serverName) + env, err := secret.GetSecret(ctx, oauthID) + if errors.Is(err, secret.ErrSecretNotFound) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to query Secrets Engine for %s: %w", serverName, err) + } + return string(env.Value) != "", nil +} + +// getTokenStatusDockerPass retrieves token validity and expiry from docker +// pass for a community server. The stored value is base64-encoded JSON +// containing the expiry field. +func (h *CredentialHelper) getTokenStatusDockerPass(ctx context.Context, serverName string) (TokenStatus, error) { + oauthID := secret.GetOAuthKey(serverName) + env, err := secret.GetSecret(ctx, oauthID) + if errors.Is(err, secret.ErrSecretNotFound) { + return TokenStatus{Valid: false}, fmt.Errorf("OAuth token not found for %s", serverName) + } + if err != nil { + return TokenStatus{Valid: false}, fmt.Errorf("failed to query Secrets Engine for %s: %w", serverName, err) + } + + storedValue := string(env.Value) + if storedValue == "" { + return TokenStatus{Valid: false}, fmt.Errorf("empty OAuth token found for %s", serverName) + } + + tokenJSON, err := base64.StdEncoding.DecodeString(storedValue) + if err != nil { + return TokenStatus{Valid: false}, fmt.Errorf("failed to decode OAuth token for %s: %w", serverName, err) + } + + var tokenData struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token,omitempty"` + Expiry string `json:"expiry,omitempty"` + } + if err := json.Unmarshal(tokenJSON, &tokenData); err != nil { + return TokenStatus{Valid: false}, fmt.Errorf("failed to parse OAuth token JSON for %s: %w", serverName, err) + } + + if tokenData.AccessToken == "" { + return TokenStatus{Valid: false}, fmt.Errorf("empty OAuth access token found for %s", serverName) + } + + if tokenData.Expiry == "" { + // No expiry -- token is valid but trigger immediate refresh check. + return TokenStatus{ + Valid: true, + ExpiresAt: time.Time{}, + NeedsRefresh: true, + }, nil + } + + expiresAt, err := time.Parse(time.RFC3339, tokenData.Expiry) + if err != nil { + return TokenStatus{Valid: false}, fmt.Errorf("failed to parse expiry time for %s: %w", serverName, err) + } + + now := time.Now() + timeUntilExpiry := expiresAt.Sub(now) + needsRefresh := timeUntilExpiry <= 10*time.Second + + log.Logf("- Token status for %s (docker pass): valid=true, expires_at=%s, time_until_expiry=%v, needs_refresh=%v", + serverName, expiresAt.Format(time.RFC3339), timeUntilExpiry.Round(time.Second), needsRefresh) + + return TokenStatus{ + Valid: true, + ExpiresAt: expiresAt, + NeedsRefresh: needsRefresh, + }, nil +} + +// --- Exported helpers for docker pass token and DCR operations --- + +// GetTokenFromDockerPass retrieves the full oauth2.Token from docker pass via +// the Secrets Engine. Used by the refresh loop to get the current token +// (including refresh_token) before refreshing. +func GetTokenFromDockerPass(ctx context.Context, serverName string) (*oauth2.Token, error) { + oauthID := secret.GetOAuthKey(serverName) + env, err := secret.GetSecret(ctx, oauthID) + if errors.Is(err, secret.ErrSecretNotFound) { + return nil, fmt.Errorf("OAuth token not found for %s", serverName) + } + if err != nil { + return nil, fmt.Errorf("failed to query Secrets Engine for %s: %w", serverName, err) + } + + storedValue := string(env.Value) + if storedValue == "" { + return nil, fmt.Errorf("empty OAuth token found for %s", serverName) + } + + tokenJSON, err := base64.StdEncoding.DecodeString(storedValue) + if err != nil { + return nil, fmt.Errorf("failed to decode OAuth token for %s: %w", serverName, err) + } + + var token oauth2.Token + if err := json.Unmarshal(tokenJSON, &token); err != nil { + return nil, fmt.Errorf("failed to parse OAuth token for %s: %w", serverName, err) + } + + return &token, nil +} + +// SaveTokenToDockerPass stores an oauth2.Token in docker pass as base64-encoded +// JSON at docker/mcp/oauth/{serverName}. Used by the authorize flow and the +// refresh loop for community servers in Desktop mode. +func SaveTokenToDockerPass(ctx context.Context, serverName string, token *oauth2.Token) error { + tokenJSON, err := json.Marshal(token) + if err != nil { + return fmt.Errorf("marshalling token for %s: %w", serverName, err) + } + + encoded := base64.StdEncoding.EncodeToString(tokenJSON) + + if err := secret.SetOAuthToken(ctx, serverName, encoded); err != nil { + return fmt.Errorf("storing OAuth token for %s: %w", serverName, err) + } + + log.Logf("- Stored OAuth token for %s (docker pass)", serverName) + return nil +} + +// GetDCRClientFromDockerPass retrieves a DCR client from docker pass via the +// Secrets Engine. The value is base64-encoded JSON of dcr.Client stored at +// docker/mcp/oauth-dcr/{serverName}. +func GetDCRClientFromDockerPass(ctx context.Context, serverName string) (dcr.Client, error) { + dcrID := secret.GetDCRKey(serverName) + env, err := secret.GetSecret(ctx, dcrID) + if errors.Is(err, secret.ErrSecretNotFound) { + return dcr.Client{}, fmt.Errorf("DCR client not found for %s", serverName) + } + if err != nil { + return dcr.Client{}, fmt.Errorf("failed to query Secrets Engine for DCR client %s: %w", serverName, err) + } + + storedValue := string(env.Value) + if storedValue == "" { + return dcr.Client{}, fmt.Errorf("empty DCR client found for %s", serverName) + } + + jsonData, err := base64.StdEncoding.DecodeString(storedValue) + if err != nil { + return dcr.Client{}, fmt.Errorf("failed to decode DCR client for %s: %w", serverName, err) + } + + var client dcr.Client + if err := json.Unmarshal(jsonData, &client); err != nil { + return dcr.Client{}, fmt.Errorf("failed to parse DCR client for %s: %w", serverName, err) + } + + return client, nil +} + +// SaveDCRClientToDockerPass stores a DCR client in docker pass as base64-encoded +// JSON at docker/mcp/oauth-dcr/{serverName}. Used by the authorize flow for +// community servers in Desktop mode. +func SaveDCRClientToDockerPass(ctx context.Context, serverName string, client dcr.Client) error { + jsonData, err := json.Marshal(client) + if err != nil { + return fmt.Errorf("marshalling DCR client for %s: %w", serverName, err) + } + + encoded := base64.StdEncoding.EncodeToString(jsonData) + + if err := secret.SetDCRClient(ctx, serverName, encoded); err != nil { + return fmt.Errorf("storing DCR client for %s: %w", serverName, err) + } + + log.Logf("- Stored DCR client for %s (docker pass)", serverName) + return nil +} diff --git a/pkg/oauth/dockerpass_test.go b/pkg/oauth/dockerpass_test.go new file mode 100644 index 000000000..9a6ccbd30 --- /dev/null +++ b/pkg/oauth/dockerpass_test.go @@ -0,0 +1,171 @@ +package oauth + +import ( + "encoding/base64" + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + + "github.com/docker/mcp-gateway/pkg/oauth/dcr" +) + +// TestEncodeDecodeToken verifies the round-trip encoding of oauth2.Token +// used by SaveTokenToDockerPass and GetTokenFromDockerPass. This tests the +// serialization logic without requiring docker pass or the Secrets Engine. +func TestEncodeDecodeToken(t *testing.T) { + expiry := time.Now().Add(1 * time.Hour).Truncate(time.Second) + original := &oauth2.Token{ + AccessToken: "access-abc", + TokenType: "Bearer", + RefreshToken: "refresh-xyz", + Expiry: expiry, + } + + // Encode (same logic as SaveTokenToDockerPass) + tokenJSON, err := json.Marshal(original) + require.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(tokenJSON) + + // Decode (same logic as GetTokenFromDockerPass) + decoded, err := base64.StdEncoding.DecodeString(encoded) + require.NoError(t, err) + + var restored oauth2.Token + require.NoError(t, json.Unmarshal(decoded, &restored)) + + assert.Equal(t, original.AccessToken, restored.AccessToken) + assert.Equal(t, original.TokenType, restored.TokenType) + assert.Equal(t, original.RefreshToken, restored.RefreshToken) + assert.True(t, original.Expiry.Equal(restored.Expiry), + "expiry mismatch: want %v, got %v", original.Expiry, restored.Expiry) +} + +// TestEncodeDecodeDCRClient verifies the round-trip encoding of dcr.Client +// used by SaveDCRClientToDockerPass and GetDCRClientFromDockerPass. +func TestEncodeDecodeDCRClient(t *testing.T) { + original := dcr.Client{ + ServerName: "notion-remote", + ProviderName: "notion-remote", + ClientID: "client-123", + ClientName: "MCP Gateway - notion-remote", + AuthorizationEndpoint: "https://api.notion.com/v1/oauth/authorize", + TokenEndpoint: "https://api.notion.com/v1/oauth/token", + ResourceURL: "https://mcp.notion.com/sse", + RequiredScopes: []string{"read", "write"}, + RegisteredAt: time.Now().Truncate(time.Second), + } + + // Encode (same logic as SaveDCRClientToDockerPass) + jsonData, err := json.Marshal(original) + require.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(jsonData) + + // Decode (same logic as GetDCRClientFromDockerPass) + decoded, err := base64.StdEncoding.DecodeString(encoded) + require.NoError(t, err) + + var restored dcr.Client + require.NoError(t, json.Unmarshal(decoded, &restored)) + + assert.Equal(t, original.ServerName, restored.ServerName) + assert.Equal(t, original.ProviderName, restored.ProviderName) + assert.Equal(t, original.ClientID, restored.ClientID) + assert.Equal(t, original.AuthorizationEndpoint, restored.AuthorizationEndpoint) + assert.Equal(t, original.TokenEndpoint, restored.TokenEndpoint) + assert.Equal(t, original.ResourceURL, restored.ResourceURL) + assert.Equal(t, original.RequiredScopes, restored.RequiredScopes) + assert.True(t, original.RegisteredAt.Equal(restored.RegisteredAt)) +} + +// TestGetTokenStatusDockerPass_ParsesExpiry verifies that token status +// correctly parses the expiry field from base64-encoded JSON tokens. +func TestTokenStatusFromBase64JSON(t *testing.T) { + tests := []struct { + name string + token map[string]any + expectValid bool + expectRefresh bool + expectHasExpiry bool + }{ + { + name: "valid token with future expiry", + token: map[string]any{ + "access_token": "tok-123", + "token_type": "Bearer", + "refresh_token": "ref-456", + "expiry": time.Now().Add(1 * time.Hour).Format(time.RFC3339), + }, + expectValid: true, + expectRefresh: false, + expectHasExpiry: true, + }, + { + name: "token expiring within 10 seconds", + token: map[string]any{ + "access_token": "tok-123", + "token_type": "Bearer", + "refresh_token": "ref-456", + "expiry": time.Now().Add(5 * time.Second).Format(time.RFC3339), + }, + expectValid: true, + expectRefresh: true, + expectHasExpiry: true, + }, + { + name: "expired token", + token: map[string]any{ + "access_token": "tok-123", + "token_type": "Bearer", + "refresh_token": "ref-456", + "expiry": time.Now().Add(-10 * time.Minute).Format(time.RFC3339), + }, + expectValid: true, + expectRefresh: true, + expectHasExpiry: true, + }, + { + name: "token without expiry", + token: map[string]any{ + "access_token": "tok-123", + "token_type": "Bearer", + }, + expectValid: true, + expectRefresh: true, + expectHasExpiry: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokenJSON, err := json.Marshal(tt.token) + require.NoError(t, err) + + // Simulate the parsing logic from getTokenStatusDockerPass + var tokenData struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token,omitempty"` + Expiry string `json:"expiry,omitempty"` + } + require.NoError(t, json.Unmarshal(tokenJSON, &tokenData)) + + assert.NotEmpty(t, tokenData.AccessToken) + + if tt.expectHasExpiry { + assert.NotEmpty(t, tokenData.Expiry) + expiresAt, err := time.Parse(time.RFC3339, tokenData.Expiry) + require.NoError(t, err) + + timeUntilExpiry := time.Until(expiresAt) + needsRefresh := timeUntilExpiry <= 10*time.Second + assert.Equal(t, tt.expectRefresh, needsRefresh) + } else { + assert.Empty(t, tokenData.Expiry) + } + }) + } +} diff --git a/pkg/oauth/mode.go b/pkg/oauth/mode.go index f8440fbe6..63ca6fa10 100644 --- a/pkg/oauth/mode.go +++ b/pkg/oauth/mode.go @@ -7,6 +7,61 @@ import ( "github.com/docker/mcp-gateway/pkg/desktop" ) +// Mode determines which credential storage backend to use for a server. +type Mode int + +const ( + // ModeAuto auto-detects mode at runtime: IsCEMode() -> CE, else Desktop. + // Used for backward compatibility when callers have not yet been updated + // to pass an explicit mode. + ModeAuto Mode = iota + // ModeDesktop reads/writes via Secrets Engine (Desktop catalog servers). + ModeDesktop + // ModeCE reads/writes via the system credential helper (CE standalone). + ModeCE + // ModeCommunity reads/writes via docker pass (Desktop community servers). + ModeCommunity +) + +// String returns a human-readable label for the mode. +func (m Mode) String() string { + switch m { + case ModeDesktop: + return "Desktop" + case ModeCE: + return "CE" + case ModeCommunity: + return "Community" + default: + return "Auto" + } +} + +// DetermineMode returns the credential storage mode for a server. +// +// - CE mode (no Desktop): ModeCE +// - Desktop + catalog server: ModeDesktop +// - Desktop + community server + McpGatewayOAuth flag ON: ModeCommunity +// - Desktop + community server + flag OFF/error: ModeDesktop (fallback) +func DetermineMode(ctx context.Context, isCommunity bool) Mode { + return determineMode(ctx, IsCEMode(), isCommunity, desktop.CheckFeatureFlagIsEnabled) +} + +// determineMode is the testable core. ceMode is pre-resolved so tests +// don't need to mock env/OS detection or the Desktop backend socket. +func determineMode(ctx context.Context, ceMode bool, isCommunity bool, checkFlag featureFlagChecker) Mode { + if ceMode { + return ModeCE + } + if isCommunity { + enabled, err := checkFlag(ctx, "McpGatewayOAuth") + if err == nil && enabled { + return ModeCommunity + } + } + return ModeDesktop +} + // IsCEMode returns true if running in Docker CE mode (standalone OAuth flows). // When false, uses Docker Desktop for OAuth orchestration. // @@ -32,40 +87,12 @@ type featureFlagChecker func(ctx context.Context, featureName string) (bool, err // ShouldUseGatewayOAuth returns true when the Gateway should own the OAuth // lifecycle for a server (localhost callback, PKCE, token storage via -// credential helper or docker pass). -// -// Decision logic: -// - CE mode (no Desktop): always true -// - Desktop + catalog server (isCommunity=false): false (Desktop owns OAuth) -// - Desktop + community server + McpGatewayOAuth flag ON: true -// - Desktop + community server + McpGatewayOAuth flag OFF or error: false +// credential helper or docker pass). This is a convenience wrapper around +// DetermineMode -- Gateway owns OAuth for every mode except ModeDesktop. // // IsCEMode() remains the global decision for the notification monitor -// (pkg/gateway/run.go). This function is the per-server decision that later -// tickets (MCPT-482 through MCPT-486) will wire into call sites. +// (pkg/gateway/run.go). This function is the per-server decision used by +// MCPT-483 through MCPT-486 call sites. func ShouldUseGatewayOAuth(ctx context.Context, isCommunity bool) bool { - return shouldUseGatewayOAuth(ctx, IsCEMode(), isCommunity, desktop.CheckFeatureFlagIsEnabled) -} - -// shouldUseGatewayOAuth is the testable core. ceMode is pre-resolved so tests -// don't need to mock env/OS detection or the Desktop backend socket. -func shouldUseGatewayOAuth(ctx context.Context, ceMode bool, isCommunity bool, checkFlag featureFlagChecker) bool { - if ceMode { - return true - } - - // Desktop mode: catalog servers continue to use Desktop OAuth. - if !isCommunity { - return false - } - - // Desktop mode + community server: gate on the Unleash feature flag - // exposed by the Desktop backend. If the flag is not registered yet - // (MCPT-480 not deployed) or the backend is unreachable, treat as - // disabled -- callers fall back to Desktop OAuth. - enabled, err := checkFlag(ctx, "McpGatewayOAuth") - if err != nil { - return false - } - return enabled + return DetermineMode(ctx, isCommunity) != ModeDesktop } diff --git a/pkg/oauth/mode_test.go b/pkg/oauth/mode_test.go index f99ee13d5..e350682de 100644 --- a/pkg/oauth/mode_test.go +++ b/pkg/oauth/mode_test.go @@ -9,9 +9,53 @@ import ( ) func TestShouldUseGatewayOAuth(t *testing.T) { + // ShouldUseGatewayOAuth is a wrapper: DetermineMode(...) != ModeDesktop. + // The full decision-tree coverage lives in TestDetermineMode; here we + // verify the bool mapping via determineMode (testable core). + + flagOn := func(_ context.Context, _ string) (bool, error) { + return true, nil + } + flagOff := func(_ context.Context, _ string) (bool, error) { + return false, nil + } + + tests := []struct { + name string + ceMode bool + isCommunity bool + checkFlag featureFlagChecker + expected bool + }{ + {"CE mode -> true (ModeCE)", true, false, flagOff, true}, + {"Desktop catalog -> false (ModeDesktop)", false, false, flagOn, false}, + {"Desktop community flag ON -> true (ModeCommunity)", false, true, flagOn, true}, + {"Desktop community flag OFF -> false (ModeDesktop)", false, true, flagOff, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mode := determineMode(t.Context(), tt.ceMode, tt.isCommunity, tt.checkFlag) + got := mode != ModeDesktop + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestShouldUseGatewayOAuth_CEModeIntegration(t *testing.T) { + // Verify the public function wiring: when DOCKER_MCP_USE_CE=true, + // ShouldUseGatewayOAuth returns true regardless of isCommunity. + t.Setenv("DOCKER_MCP_USE_CE", "true") + + assert.True(t, ShouldUseGatewayOAuth(t.Context(), false), + "CE mode override should make ShouldUseGatewayOAuth return true for catalog servers") + assert.True(t, ShouldUseGatewayOAuth(t.Context(), true), + "CE mode override should make ShouldUseGatewayOAuth return true for community servers") +} + +func TestDetermineMode(t *testing.T) { // Test the internal function directly so we can control ceMode and the - // feature-flag checker without depending on the OS, env vars, or a - // running Docker Desktop backend. + // feature-flag checker. flagOn := func(_ context.Context, _ string) (bool, error) { return true, nil @@ -28,87 +72,133 @@ func TestShouldUseGatewayOAuth(t *testing.T) { ceMode bool isCommunity bool checkFlag featureFlagChecker - expected bool + expected Mode }{ - // --- CE mode: Gateway always owns OAuth regardless of server type --- + // CE mode: always ModeCE regardless of server type { name: "CE mode, catalog server", ceMode: true, isCommunity: false, - checkFlag: flagOff, // should not be called - expected: true, + checkFlag: flagOff, + expected: ModeCE, }, { name: "CE mode, community server", ceMode: true, isCommunity: true, - checkFlag: flagOff, // should not be called - expected: true, + checkFlag: flagOff, + expected: ModeCE, }, - // --- Desktop + catalog server: Desktop always owns OAuth --- + // Desktop + catalog server: always ModeDesktop { name: "Desktop, catalog server", ceMode: false, isCommunity: false, - checkFlag: flagOn, // should not be called - expected: false, + checkFlag: flagOn, + expected: ModeDesktop, }, - // --- Desktop + community server: gated on feature flag --- + // Desktop + community server: depends on feature flag { name: "Desktop, community server, flag ON", ceMode: false, isCommunity: true, checkFlag: flagOn, - expected: true, + expected: ModeCommunity, }, { name: "Desktop, community server, flag OFF", ceMode: false, isCommunity: true, checkFlag: flagOff, - expected: false, + expected: ModeDesktop, }, { name: "Desktop, community server, flag error", ceMode: false, isCommunity: true, checkFlag: flagErr, - expected: false, + expected: ModeDesktop, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := shouldUseGatewayOAuth(t.Context(), tt.ceMode, tt.isCommunity, tt.checkFlag) + got := determineMode(t.Context(), tt.ceMode, tt.isCommunity, tt.checkFlag) assert.Equal(t, tt.expected, got) }) } } -func TestShouldUseGatewayOAuth_CEModeIntegration(t *testing.T) { +func TestDetermineMode_CEModeIntegration(t *testing.T) { // Verify the public function wiring: when DOCKER_MCP_USE_CE=true, - // IsCEMode() returns true and the public ShouldUseGatewayOAuth must - // return true regardless of isCommunity. + // DetermineMode returns ModeCE regardless of isCommunity. t.Setenv("DOCKER_MCP_USE_CE", "true") - assert.True(t, ShouldUseGatewayOAuth(t.Context(), false), - "CE mode override should make ShouldUseGatewayOAuth return true for catalog servers") - assert.True(t, ShouldUseGatewayOAuth(t.Context(), true), - "CE mode override should make ShouldUseGatewayOAuth return true for community servers") + assert.Equal(t, ModeCE, DetermineMode(t.Context(), false), + "CE mode override should return ModeCE for catalog servers") + assert.Equal(t, ModeCE, DetermineMode(t.Context(), true), + "CE mode override should return ModeCE for community servers") } -func TestShouldUseGatewayOAuth_FeatureFlagName(t *testing.T) { - // Verify the unexported function passes the correct feature flag name - // to the checker. - var capturedName string - spy := func(_ context.Context, name string) (bool, error) { - capturedName = name - return false, nil +func TestMode_ResolveMode(t *testing.T) { + tests := []struct { + name string + mode Mode + ceMode bool // controlled via env var + expected Mode + }{ + { + name: "explicit Desktop stays Desktop", + mode: ModeDesktop, + expected: ModeDesktop, + }, + { + name: "explicit CE stays CE", + mode: ModeCE, + expected: ModeCE, + }, + { + name: "explicit Community stays Community", + mode: ModeCommunity, + expected: ModeCommunity, + }, + { + name: "Auto in CE mode resolves to CE", + mode: ModeAuto, + ceMode: true, + expected: ModeCE, + }, } - shouldUseGatewayOAuth(t.Context(), false, true, spy) - assert.Equal(t, "McpGatewayOAuth", capturedName, - "should query the McpGatewayOAuth feature flag") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.ceMode { + t.Setenv("DOCKER_MCP_USE_CE", "true") + } + h := NewOAuthCredentialHelperWithMode(tt.mode) + got := h.resolveMode() + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestMode_String(t *testing.T) { + tests := []struct { + mode Mode + expected string + }{ + {ModeAuto, "Auto"}, + {ModeDesktop, "Desktop"}, + {ModeCE, "CE"}, + {ModeCommunity, "Community"}, + {Mode(99), "Auto"}, // unknown falls through to default + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.mode.String()) + }) + } } diff --git a/pkg/oauth/provider.go b/pkg/oauth/provider.go index 1256edf51..d5ef494c7 100644 --- a/pkg/oauth/provider.go +++ b/pkg/oauth/provider.go @@ -64,11 +64,18 @@ func (p *DCRProvider) GeneratePKCE() string { // Provider manages OAuth token lifecycle for a single MCP server. // This is used for background token refresh loops in the gateway. -// CE mode: refreshes tokens directly via oauth2 library, then reloads. -// Desktop mode: triggers refresh via GetOAuthApp Desktop API, then SSE events -// interrupt the timer, trigger reload, and reset retry counters. +// +// The mode field (cached at construction) determines behavior: +// - ModeAuto: runtime IsCEMode() detection (backward compat) +// - ModeDesktop: triggers refresh via GetOAuthApp Desktop API; SSE events +// interrupt the timer, trigger reload, and reset retry counters. +// - ModeCE: refreshes tokens directly via oauth2 library using the +// credential helper, then reloads. +// - ModeCommunity: refreshes tokens directly via oauth2 library using +// docker pass for storage, then reloads. type Provider struct { name string + mode Mode lastRefreshExpiry time.Time refreshRetryCount int stopOnce sync.Once @@ -80,13 +87,16 @@ type Provider struct { const maxRefreshRetries = 7 // Max attempts to refresh when expiry hasn't changed -// NewProvider creates a new OAuth provider for token refresh -func NewProvider(name string, reloadFn func(context.Context, string) error) *Provider { +// NewProvider creates a new OAuth provider for token refresh. +// The mode parameter controls which credential storage backend is used. +// Pass ModeAuto to preserve the existing IsCEMode() runtime behavior. +func NewProvider(name string, mode Mode, reloadFn func(context.Context, string) error) *Provider { return &Provider{ name: name, + mode: mode, stopChan: make(chan struct{}), eventChan: make(chan Event), - credHelper: NewOAuthCredentialHelper(), + credHelper: NewOAuthCredentialHelperWithMode(mode), reloadFn: reloadFn, } } @@ -151,7 +161,8 @@ func (p *Provider) Run(ctx context.Context) { // Trigger refresh if needed if shouldTriggerRefresh { - if IsCEMode() { + switch p.resolveRefreshMode() { + case ModeCE: // CE mode: Refresh token directly, then reload server connection go func() { if err := p.refreshTokenCE(); err != nil { @@ -162,7 +173,18 @@ func (p *Provider) Run(ctx context.Context) { log.Logf("! Failed to reload %s after token refresh: %v", p.name, err) } }() - } else { + case ModeCommunity: + // Community mode: Refresh token via oauth2, store in docker pass + go func() { + if err := p.refreshTokenCommunity(ctx); err != nil { + log.Logf("! Token refresh failed for %s: %v", p.name, err) + return + } + if err := p.reloadFn(ctx, p.name); err != nil { + log.Logf("! Failed to reload %s after token refresh: %v", p.name, err) + } + }() + default: // Desktop mode: Trigger refresh via Desktop API go func() { authClient := desktop.NewAuthClient() @@ -218,6 +240,57 @@ func (p *Provider) SendEvent(event Event) { p.eventChan <- event } +// resolveRefreshMode returns the effective mode for refresh branching. +// When mode is ModeAuto, falls back to the runtime IsCEMode() check. +func (p *Provider) resolveRefreshMode() Mode { + if p.mode == ModeAuto { + if IsCEMode() { + return ModeCE + } + return ModeDesktop + } + return p.mode +} + +// refreshTokenCommunity refreshes an OAuth token for a community server. +// Reads the DCR client and current token from docker pass (via Secrets Engine), +// refreshes using the oauth2 library, and writes the new token back to docker pass. +func (p *Provider) refreshTokenCommunity(ctx context.Context) error { + // Get DCR client from docker pass + dcrClient, err := GetDCRClientFromDockerPass(ctx, p.name) + if err != nil { + return fmt.Errorf("failed to get DCR client from docker pass: %w", err) + } + + // Get current token from docker pass + token, err := GetTokenFromDockerPass(ctx, p.name) + if err != nil { + return fmt.Errorf("failed to retrieve token from docker pass: %w", err) + } + + // Refresh token using oauth2 library. + // The redirect URI value does not matter for refresh token grants -- the + // Go oauth2 library does not include redirect_uri in the token refresh + // request. We pass DefaultRedirectURI only because NewDCRProvider requires + // a value; the actual localhost redirect used during authorization is not + // persisted in the DCR client struct. + provider := NewDCRProvider(dcrClient, DefaultRedirectURI) + config := provider.Config() + + refreshedToken, err := config.TokenSource(ctx, token).Token() + if err != nil { + return fmt.Errorf("token refresh failed: %w", err) + } + + // Save refreshed token to docker pass + if err := SaveTokenToDockerPass(ctx, p.name, refreshedToken); err != nil { + return fmt.Errorf("failed to save refreshed token to docker pass: %w", err) + } + + log.Logf("- Successfully refreshed token for %s (docker pass)", p.name) + return nil +} + // refreshTokenCE refreshes an OAuth token in CE mode // Uses the same oauth2 library refresh mechanism as Desktop func (p *Provider) refreshTokenCE() error {