Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions cmd/docker-mcp/secret-management/secret/credstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,45 @@ func DeleteDefaultSecret(ctx context.Context, id string) error {
}
return nil
}

// SetOAuthToken stores an OAuth token via docker pass at docker/mcp/oauth/{serverName}.
// The value should be base64-encoded JSON of the full oauth2.Token.
func SetOAuthToken(ctx context.Context, serverName string, value string) error {
c := cmd(ctx, "set", GetOAuthKey(serverName))
c.Stdin = strings.NewReader(value)
out, err := c.CombinedOutput()
if err != nil {
return fmt.Errorf("could not store OAuth token for %s: %s\n%s", serverName, bytes.TrimSpace(out), err)
}
return nil
}

// DeleteOAuthToken removes an OAuth token from docker pass.
func DeleteOAuthToken(ctx context.Context, serverName string) error {
out, err := cmd(ctx, "rm", GetOAuthKey(serverName)).CombinedOutput()
if err != nil {
return fmt.Errorf("could not delete OAuth token for %s: %s\n%s", serverName, bytes.TrimSpace(out), err)
}
return nil
}

// SetDCRClient stores a DCR client config via docker pass at docker/mcp/oauth-dcr/{serverName}.
// The value should be base64-encoded JSON of the DCR client.
func SetDCRClient(ctx context.Context, serverName string, value string) error {
c := cmd(ctx, "set", GetDCRKey(serverName))
c.Stdin = strings.NewReader(value)
out, err := c.CombinedOutput()
if err != nil {
return fmt.Errorf("could not store DCR client for %s: %s\n%s", serverName, bytes.TrimSpace(out), err)
}
return nil
}

// DeleteDCRClient removes a DCR client config from docker pass.
func DeleteDCRClient(ctx context.Context, serverName string) error {
out, err := cmd(ctx, "rm", GetDCRKey(serverName)).CombinedOutput()
if err != nil {
return fmt.Errorf("could not delete DCR client for %s: %s\n%s", serverName, bytes.TrimSpace(out), err)
}
return nil
}
40 changes: 40 additions & 0 deletions pkg/catalog/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,46 @@ func setupTestCatalogs(t *testing.T, homeDir string) {
require.NoError(t, err)
}

func TestServer_IsCommunity(t *testing.T) {
tests := []struct {
name string
server Server
expected bool
}{
{
name: "nil metadata",
server: Server{},
expected: false,
},
{
name: "empty tags",
server: Server{Metadata: &Metadata{Tags: []string{}}},
expected: false,
},
{
name: "has community tag",
server: Server{Metadata: &Metadata{Tags: []string{"community"}}},
expected: true,
},
{
name: "community among other tags",
server: Server{Metadata: &Metadata{Tags: []string{"featured", "community", "ai"}}},
expected: true,
},
{
name: "no community tag",
server: Server{Metadata: &Metadata{Tags: []string{"featured", "official"}}},
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.server.IsCommunity())
})
}
}

func setupOverlappingCatalogs(t *testing.T, homeDir string) {
t.Helper()

Expand Down
16 changes: 15 additions & 1 deletion pkg/catalog/types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package catalog

import "github.com/docker/mcp-gateway/pkg/policy"
import (
"slices"

"github.com/docker/mcp-gateway/pkg/policy"
)

type Catalog struct {
Servers map[string]Server
Expand Down Expand Up @@ -59,6 +63,16 @@ type Metadata struct {
RegistryURL string `yaml:"registryUrl,omitempty" json:"registryUrl,omitempty"`
}

// IsCommunity returns true if this server was sourced from the community MCP
// registry. Community servers are tagged with "community" in Metadata.Tags by
// catalog_next/create.go when importing from the community registry.
func (s *Server) IsCommunity() bool {
if s.Metadata == nil {
return false
}
return slices.Contains(s.Metadata.Tags, "community")
}

func (s *Server) IsOAuthServer() bool {
return s.OAuth != nil && len(s.OAuth.Providers) > 0
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/gateway/mcpadd.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str

// Start provider (CE mode only - Desktop mode doesn't need polling)
if oauth.IsCEMode() {
g.startProvider(ctx, serverName)
g.startProvider(ctx, serverName, oauth.ModeCE)
}
}

Expand Down
41 changes: 31 additions & 10 deletions pkg/gateway/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,23 +355,39 @@ func (g *Gateway) Run(ctx context.Context) error {
// Start OAuth provider for each OAuth server.
// Each provider runs in its own goroutine with dynamic timing based on token expiry.
log.Log("- Starting OAuth provider loops...")
credHelper := oauth.NewOAuthCredentialHelper()

// Pre-flight check: verify docker pass availability once for all
// community servers that may need it. Avoids repeated shell-outs.
hasDockerPass := desktop.CheckHasDockerPass(ctx) == nil

for _, serverName := range configuration.ServerNames() {
serverConfig, _, found := configuration.Find(serverName)
if !found || serverConfig == nil {
continue
}

isCommunity := serverConfig.Spec.IsCommunity()
mode := oauth.DetermineMode(ctx, isCommunity)

// Community mode requires docker pass. If unavailable, fall
// back to Desktop mode so the server is not left unmanaged.
if mode == oauth.ModeCommunity && !hasDockerPass {
log.Logf("! docker pass unavailable -- falling back to Desktop OAuth for community server %s", serverName)
mode = oauth.ModeDesktop
}

credHelper := oauth.NewOAuthCredentialHelperWithMode(mode)

if serverConfig.Spec.HasExplicitOAuthProviders() {
g.startProvider(ctx, serverName)
g.startProvider(ctx, serverName, mode)
} else if serverConfig.IsRemote() {
// Community servers: start provider if they have a stored OAuth token
// Community/remote servers: start provider if they have a stored OAuth token
// from dynamic discovery (DCR without explicit OAuth metadata)
if exists, err := credHelper.TokenExists(ctx, serverName); err != nil {
log.Logf("Warning: Failed to check OAuth token for %s: %v", serverName, err)
} else if exists {
log.Logf("- Starting OAuth provider for community server: %s", serverName)
g.startProvider(ctx, serverName)
log.Logf("- Starting OAuth provider for remote server: %s (mode=%s)", serverName, mode)
g.startProvider(ctx, serverName, mode)
}
}
}
Expand Down Expand Up @@ -611,8 +627,11 @@ func (g *Gateway) periodicMetricExport(ctx context.Context) {

// OAuth Provider Management Methods

// startProvider creates and starts an OAuth provider goroutine for a server
func (g *Gateway) startProvider(ctx context.Context, serverName string) {
// startProvider creates and starts an OAuth provider goroutine for a server.
// The mode parameter controls which credential storage backend the provider
// uses. Pass ModeAuto when the caller does not know the server type
// (backward compat); the provider will fall back to runtime IsCEMode() detection.
func (g *Gateway) startProvider(ctx context.Context, serverName string, mode oauth.Mode) {
g.providersMu.Lock()
defer g.providersMu.Unlock()

Expand Down Expand Up @@ -649,7 +668,7 @@ func (g *Gateway) startProvider(ctx context.Context, serverName string) {
}

// Create and start provider
provider := oauth.NewProvider(serverName, reloadFn)
provider := oauth.NewProvider(serverName, mode, reloadFn)
g.oauthProviders[serverName] = provider

// Wrapper goroutine handles cleanup after provider exits
Expand Down Expand Up @@ -684,10 +703,12 @@ func (g *Gateway) routeEventToProvider(event oauth.Event) {

switch event.Type {
case oauth.EventLoginSuccess:
// User just authorized - ensure provider exists
// User just authorized via Desktop SSE - ensure provider exists.
// SSE events are only received in Desktop mode (the notification
// monitor is skipped in CE mode), so ModeDesktop is correct.
if !exists {
log.Logf("- Creating provider for %s after login", event.Provider)
g.startProvider(context.Background(), event.Provider)
g.startProvider(context.Background(), event.Provider, oauth.ModeDesktop)
}

// Always send event to trigger reload (connects server and lists tools)
Expand Down
105 changes: 78 additions & 27 deletions pkg/oauth/credhelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,73 @@ import (
"github.com/docker/mcp-gateway/pkg/oauth/dcr"
)

// CredentialHelper provides secure access to OAuth tokens via credential helpers
// CredentialHelper provides secure access to OAuth tokens via credential helpers.
// The mode field controls which storage backend is used: Secrets Engine (Desktop
// catalog), credential helper (CE), or docker pass (Desktop community).
type CredentialHelper struct {
credentialHelper credentials.Helper
mode Mode
}

// GetHelper returns the underlying credential helper
func (h *CredentialHelper) GetHelper() credentials.Helper {
return h.credentialHelper
}

// NewOAuthCredentialHelper creates a new OAuth credential helper
// NewOAuthCredentialHelper creates a new OAuth credential helper with
// auto-detected mode (preserves existing behavior for callers that have not
// been updated to pass an explicit mode).
func NewOAuthCredentialHelper() *CredentialHelper {
return &CredentialHelper{
credentialHelper: newOAuthHelper(),
mode: ModeAuto,
}
}

// NewOAuthCredentialHelperWithMode creates a credential helper that uses the
// specified storage mode. Use this when the caller knows whether the server
// is Desktop-catalog, CE, or Desktop-community.
func NewOAuthCredentialHelperWithMode(mode Mode) *CredentialHelper {
return &CredentialHelper{
credentialHelper: newOAuthHelper(),
mode: mode,
}
}

// resolveMode returns the effective Mode. When mode is ModeAuto, the
// runtime IsCEMode() check determines the backend (CE or Desktop). Explicit
// modes are returned as-is.
func (h *CredentialHelper) resolveMode() Mode {
if h.mode == ModeAuto {
if IsCEMode() {
return ModeCE
}
return ModeDesktop
}
return h.mode
}

// TokenStatus represents the validity status of an OAuth token
type TokenStatus struct {
Valid bool
ExpiresAt time.Time
NeedsRefresh bool
}

// GetOAuthToken retrieves an OAuth token for the specified server
// GetOAuthToken retrieves an OAuth token for the specified server.
// Routes to the appropriate storage backend based on the resolved mode:
// - CE: credential helper (base64-encoded JSON)
// - Desktop: Secrets Engine (raw access token)
// - Community: docker pass via Secrets Engine (base64-encoded JSON)
func (h *CredentialHelper) GetOAuthToken(ctx context.Context, serverName string) (string, error) {
if IsCEMode() {
switch h.resolveMode() {
case ModeCE:
return h.getOAuthTokenCE(serverName)
case ModeCommunity:
return h.getOAuthTokenDockerPass(ctx, serverName)
default:
return h.getOAuthTokenDesktop(ctx, serverName)
}
return h.getOAuthTokenDesktop(ctx, serverName)
}

// getOAuthTokenCE retrieves OAuth token in CE mode using credential helper.
Expand Down Expand Up @@ -114,26 +151,35 @@ func (h *CredentialHelper) getOAuthTokenDesktop(ctx context.Context, serverName
}

// TokenExists checks if an OAuth token exists for the specified server.
// This is the appropriate check for Desktop mode where Secrets Engine returns
// raw tokens without validity or expiry metadata - we can only verify existence.
// For CE mode, this also works but GetTokenStatus provides more detail.
// Routes to the appropriate storage backend based on the resolved mode.
func (h *CredentialHelper) TokenExists(ctx context.Context, serverName string) (bool, error) {
if IsCEMode() {
// CE mode: check credential helper
dcrMgr := dcr.NewManager(h.credentialHelper, "")
client, err := dcrMgr.GetDCRClient(serverName)
if err != nil {
return false, nil // No DCR client = no token
}
credentialKey := fmt.Sprintf("%s/%s", client.AuthorizationEndpoint, client.ProviderName)
_, tokenSecret, err := h.credentialHelper.Get(credentialKey)
if err != nil || tokenSecret == "" {
return false, nil
}
return true, nil
switch h.resolveMode() {
case ModeCE:
return h.tokenExistsCE(serverName)
case ModeCommunity:
return h.tokenExistsDockerPass(ctx, serverName)
default:
return h.tokenExistsDesktop(ctx, serverName)
}
}

// tokenExistsCE checks if a token exists in the credential helper (CE mode).
func (h *CredentialHelper) tokenExistsCE(serverName string) (bool, error) {
dcrMgr := dcr.NewManager(h.credentialHelper, "")
client, err := dcrMgr.GetDCRClient(serverName)
if err != nil {
return false, nil // No DCR client = no token
}
credentialKey := fmt.Sprintf("%s/%s", client.AuthorizationEndpoint, client.ProviderName)
_, tokenSecret, err := h.credentialHelper.Get(credentialKey)
if err != nil || tokenSecret == "" {
return false, nil
}
return true, nil
}

// Desktop mode: check Secrets Engine
// tokenExistsDesktop checks if a token exists in the Secrets Engine (Desktop mode).
func (h *CredentialHelper) tokenExistsDesktop(ctx context.Context, serverName string) (bool, error) {
oauthID := secret.GetOAuthKey(serverName)
env, err := secret.GetSecret(ctx, oauthID)
if errors.Is(err, secret.ErrSecretNotFound) {
Expand All @@ -146,14 +192,19 @@ func (h *CredentialHelper) TokenExists(ctx context.Context, serverName string) (
}

// GetTokenStatus checks token validity and expiry for refresh scheduling.
// Works in both CE and Desktop modes:
// - CE mode: reads token JSON from credential helper and parses expiry
// - Desktop mode: queries Secrets Engine and reads ExpiryAt from response metadata
// Routes to the appropriate storage backend based on the resolved mode:
// - CE: reads token JSON from credential helper, parses expiry
// - Desktop: queries Secrets Engine, reads ExpiryAt from response metadata
// - Community: reads token JSON from docker pass via Secrets Engine, parses expiry
func (h *CredentialHelper) GetTokenStatus(ctx context.Context, serverName string) (TokenStatus, error) {
if IsCEMode() {
switch h.resolveMode() {
case ModeCE:
return h.getTokenStatusCE(serverName)
case ModeCommunity:
return h.getTokenStatusDockerPass(ctx, serverName)
default:
return h.getTokenStatusDesktop(ctx, serverName)
}
return h.getTokenStatusDesktop(ctx, serverName)
}

// getTokenStatusDesktop retrieves token status in Desktop mode using Secrets Engine metadata.
Expand Down
Loading
Loading