diff --git a/cmd/docker-mcp/oauth/revoke.go b/cmd/docker-mcp/oauth/revoke.go index e7c03f698..448b1ae9a 100644 --- a/cmd/docker-mcp/oauth/revoke.go +++ b/cmd/docker-mcp/oauth/revoke.go @@ -26,7 +26,7 @@ func Revoke(ctx context.Context, app string) error { func revokeDesktopMode(ctx context.Context, app string) error { client := desktop.NewAuthClient() - // Revoke tokens + // Revoke tokens via Docker Desktop if err := client.DeleteOAuthApp(ctx, app); err != nil { return fmt.Errorf("failed to revoke OAuth access: %w", err) } diff --git a/cmd/docker-mcp/server/enable.go b/cmd/docker-mcp/server/enable.go index 62cbd834c..08787c212 100644 --- a/cmd/docker-mcp/server/enable.go +++ b/cmd/docker-mcp/server/enable.go @@ -61,7 +61,7 @@ func update(ctx context.Context, docker docker.Client, dockerCli command.Cli, ad } // DCR flag enabled AND type="remote" AND oauth present - if mcpOAuthDcrEnabled && server.IsRemoteOAuthServer() { + if mcpOAuthDcrEnabled && server.HasExplicitOAuthProviders() { // In CE mode, skip lazy setup - DCR happens during oauth authorize if pkgoauth.IsCEMode() { fmt.Printf("OAuth server %s enabled. Run 'docker mcp oauth authorize %s' to authenticate\n", serverName, serverName) @@ -74,11 +74,20 @@ func update(ctx context.Context, docker docker.Client, dockerCli command.Cli, ad fmt.Printf("OAuth provider configured for %s - use 'docker mcp oauth authorize %s' to authenticate\n", serverName, serverName) } } - } else if !mcpOAuthDcrEnabled && server.IsRemoteOAuthServer() { + } else if !mcpOAuthDcrEnabled && server.HasExplicitOAuthProviders() { // Provide guidance when DCR is needed but disabled fmt.Printf("Server %s requires OAuth authentication but DCR is disabled.\n", serverName) fmt.Printf(" To enable automatic OAuth setup, run: docker mcp feature enable mcp-oauth-dcr\n") fmt.Printf(" Or set up OAuth manually using: docker mcp oauth authorize %s\n", serverName) + } else if mcpOAuthDcrEnabled && server.Type == "remote" && !server.IsOAuthServer() && server.Remote.URL != "" { + // Community server without oauth.providers — probe for OAuth + if pkgoauth.IsCEMode() { + fmt.Printf("Remote server %s enabled. Run 'docker mcp oauth authorize %s' if authentication is required\n", serverName, serverName) + } else { + if err := pkgoauth.RegisterProviderForDynamicDiscovery(ctx, serverName, server.Remote.URL); err != nil { + fmt.Printf("Warning: Dynamic OAuth discovery failed for %s: %v\n", serverName, err) + } + } } } else { return fmt.Errorf("server %s not found in catalog", serverName) diff --git a/pkg/catalog/types.go b/pkg/catalog/types.go index e08d7223f..9ee7768c6 100644 --- a/pkg/catalog/types.go +++ b/pkg/catalog/types.go @@ -63,7 +63,10 @@ func (s *Server) IsOAuthServer() bool { return s.OAuth != nil && len(s.OAuth.Providers) > 0 } -func (s *Server) IsRemoteOAuthServer() bool { +// HasExplicitOAuthProviders returns true if this is a remote server with +// explicit OAuth provider metadata in the catalog (e.g. oauth.providers YAML). +// Community servers that discover OAuth dynamically will return false here. +func (s *Server) HasExplicitOAuthProviders() bool { return s.Type == "remote" && s.IsOAuthServer() } diff --git a/pkg/gateway/clientpool.go b/pkg/gateway/clientpool.go index 66f97231d..b9afe3268 100644 --- a/pkg/gateway/clientpool.go +++ b/pkg/gateway/clientpool.go @@ -179,24 +179,23 @@ func (cp *clientPool) InvalidateOAuthClients(provider string) { var invalidatedKeys []clientKey for key, keptClient := range cp.keptClients { - // Check if this client uses OAuth for the specified provider - if keptClient.Config.Spec.OAuth != nil { - // Match by server name (for DCR providers, server name matches provider) - if keptClient.Config.Name == provider { - log.Log(fmt.Sprintf("ClientPool: Closing OAuth connection for server: %s", keptClient.Config.Name)) - - // Close the connection - client, err := keptClient.Getter.GetClient(context.TODO()) - if err == nil { - client.Session().Close() - log.Log(fmt.Sprintf("ClientPool: Successfully closed connection for %s", keptClient.Config.Name)) - } else { - log.Log(fmt.Sprintf("ClientPool: Warning - failed to get client for %s during invalidation: %v", keptClient.Config.Name, err)) - } - - // Mark for removal from kept clients - invalidatedKeys = append(invalidatedKeys, key) + // Check if this remote client matches the OAuth provider + // Matches both catalog servers (explicit OAuth metadata) and community servers + // (dynamic OAuth discovery via DCR without Spec.OAuth) + if keptClient.Config.Name == provider && keptClient.Config.IsRemote() { + log.Log(fmt.Sprintf("ClientPool: Closing OAuth connection for server: %s", keptClient.Config.Name)) + + // Close the connection + client, err := keptClient.Getter.GetClient(context.TODO()) + if err == nil { + client.Session().Close() + log.Log(fmt.Sprintf("ClientPool: Successfully closed connection for %s", keptClient.Config.Name)) + } else { + log.Log(fmt.Sprintf("ClientPool: Warning - failed to get client for %s during invalidation: %v", keptClient.Config.Name, err)) } + + // Mark for removal from kept clients + invalidatedKeys = append(invalidatedKeys, key) } } diff --git a/pkg/gateway/clientpool_test.go b/pkg/gateway/clientpool_test.go index ef1fbccd5..96a149752 100644 --- a/pkg/gateway/clientpool_test.go +++ b/pkg/gateway/clientpool_test.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "fmt" "os" "testing" "time" @@ -277,6 +278,197 @@ func parseConfig(t *testing.T, contentYAML string) map[string]any { return config } +func TestInvalidateOAuthClients_MatchesCommunityServer(t *testing.T) { + // Community server: remote URL set, but no Spec.OAuth metadata. + // This verifies Gap 3: InvalidateOAuthClients matches community servers + // that use dynamic OAuth discovery without explicit OAuth config. + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + getter := &clientGetter{} + getter.once.Do(func() {}) // mark as executed + getter.err = fmt.Errorf("mock: no real client") + + key := clientKey{serverName: "com-notion-mcp"} + cp.keptClients[key] = keptClient{ + Name: "com-notion-mcp", + Getter: getter, + Config: &catalog.ServerConfig{ + Name: "com-notion-mcp", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{ + URL: "https://mcp.notion.so/mcp", + Transport: "streamable-http", + }, + // No OAuth field - community server + }, + }, + } + + cp.InvalidateOAuthClients("com-notion-mcp") + + assert.Empty(t, cp.keptClients, "community server should be invalidated by name") +} + +func TestInvalidateOAuthClients_MatchesCatalogServer(t *testing.T) { + // Catalog server: remote URL set WITH Spec.OAuth metadata. + // Verifies backward compatibility: catalog servers still get invalidated. + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + getter := &clientGetter{} + getter.once.Do(func() {}) + getter.err = fmt.Errorf("mock: no real client") + + key := clientKey{serverName: "notion-remote"} + cp.keptClients[key] = keptClient{ + Name: "notion-remote", + Getter: getter, + Config: &catalog.ServerConfig{ + Name: "notion-remote", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{ + URL: "https://mcp.notion.so/mcp", + Transport: "streamable-http", + }, + OAuth: &catalog.OAuth{ + Providers: []catalog.OAuthProvider{{Provider: "notion"}}, + }, + }, + }, + } + + cp.InvalidateOAuthClients("notion-remote") + + assert.Empty(t, cp.keptClients, "catalog server should be invalidated by name") +} + +func TestInvalidateOAuthClients_SkipsNonRemoteServer(t *testing.T) { + // Docker container server: not remote, should NOT be invalidated. + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + getter := &clientGetter{} + getter.once.Do(func() {}) + getter.err = fmt.Errorf("mock: no real client") + + key := clientKey{serverName: "my-container-server"} + cp.keptClients[key] = keptClient{ + Name: "my-container-server", + Getter: getter, + Config: &catalog.ServerConfig{ + Name: "my-container-server", + Spec: catalog.Server{ + Type: "server", + Image: "mcp/my-server:latest", + // Not remote - no URL + }, + }, + } + + cp.InvalidateOAuthClients("my-container-server") + + assert.Len(t, cp.keptClients, 1, "non-remote server should NOT be invalidated") +} + +func TestInvalidateOAuthClients_SkipsMismatchedName(t *testing.T) { + // Remote server with different name: should NOT be invalidated. + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + getter := &clientGetter{} + getter.once.Do(func() {}) + getter.err = fmt.Errorf("mock: no real client") + + key := clientKey{serverName: "other-server"} + cp.keptClients[key] = keptClient{ + Name: "other-server", + Getter: getter, + Config: &catalog.ServerConfig{ + Name: "other-server", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{ + URL: "https://other.example.com/mcp", + }, + }, + }, + } + + cp.InvalidateOAuthClients("com-notion-mcp") + + assert.Len(t, cp.keptClients, 1, "server with different name should NOT be invalidated") +} + +func TestInvalidateOAuthClients_OnlyMatchingRemoved(t *testing.T) { + // Multiple clients: only the matching remote server should be removed. + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + makeGetter := func() *clientGetter { + g := &clientGetter{} + g.once.Do(func() {}) + g.err = fmt.Errorf("mock: no real client") + return g + } + + // Community OAuth server (should be invalidated) + cp.keptClients[clientKey{serverName: "com-notion-mcp"}] = keptClient{ + Name: "com-notion-mcp", + Getter: makeGetter(), + Config: &catalog.ServerConfig{ + Name: "com-notion-mcp", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{URL: "https://mcp.notion.so/mcp"}, + }, + }, + } + + // Different remote server (should NOT be invalidated) + cp.keptClients[clientKey{serverName: "github-remote"}] = keptClient{ + Name: "github-remote", + Getter: makeGetter(), + Config: &catalog.ServerConfig{ + Name: "github-remote", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{URL: "https://mcp.github.com/mcp"}, + }, + }, + } + + // Docker container server (should NOT be invalidated) + cp.keptClients[clientKey{serverName: "local-server"}] = keptClient{ + Name: "local-server", + Getter: makeGetter(), + Config: &catalog.ServerConfig{ + Name: "local-server", + Spec: catalog.Server{ + Type: "server", + Image: "mcp/local:latest", + }, + }, + } + + cp.InvalidateOAuthClients("com-notion-mcp") + + assert.Len(t, cp.keptClients, 2, "only the matching remote server should be removed") + _, hasNotion := cp.keptClients[clientKey{serverName: "com-notion-mcp"}] + assert.False(t, hasNotion, "com-notion-mcp should have been removed") + _, hasGithub := cp.keptClients[clientKey{serverName: "github-remote"}] + assert.True(t, hasGithub, "github-remote should remain") + _, hasLocal := cp.keptClients[clientKey{serverName: "local-server"}] + assert.True(t, hasLocal, "local-server should remain") +} + func TestStdioClientInitialization(t *testing.T) { // This is an integration test that requires Docker if testing.Short() { diff --git a/pkg/gateway/mcpadd.go b/pkg/gateway/mcpadd.go index e89f26490..a1eacbd3b 100644 --- a/pkg/gateway/mcpadd.go +++ b/pkg/gateway/mcpadd.go @@ -286,10 +286,12 @@ func addServerHandler(g *Gateway, clientConfig *clientConfig) mcp.ToolHandler { } } - // Handle OAuth DCR only when the client supports elicitation (e.g. not stdio-based clients) + // Handle OAuth DCR for any remote server — covers both catalog servers + // (explicit OAuth metadata) and community servers (dynamic discovery). + // getRemoteOAuthServerStatus handles the case where OAuth is not needed. if g.McpOAuthDcrEnabled && serverConfig != nil && - serverConfig.Spec.IsRemoteOAuthServer() { + serverConfig.IsRemote() { init := req.Session.InitializeParams() if init != nil && @@ -444,7 +446,25 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str if !providerExists { // Register DCR client with DD so user can authorize if err := oauth.RegisterProviderForLazySetup(ctx, serverName); err != nil { - log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err) + // Fallback: try dynamic discovery for community servers without oauth.providers + if serverConfig, _, found := g.configuration.Find(serverName); found && serverConfig.Spec.Remote.URL != "" { + if err := oauth.RegisterProviderForDynamicDiscovery(ctx, serverName, serverConfig.Spec.Remote.URL); err != nil { + log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err) + } + } else { + log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err) + } + } + + // Verify DCR entry was created — dynamic discovery may have found no OAuth requirement. + // Distinguish "not found" (server doesn't need OAuth) from transient API errors. + authClient := desktop.NewAuthClient() + if _, err := authClient.GetDCRClient(ctx, serverName); err != nil { + if strings.Contains(err.Error(), "HTTP 404") { + return true, "" // Server doesn't require OAuth + } + log.Logf("Warning: Failed to verify DCR entry for %s (may be transient): %v", serverName, err) + return true, "" // Fail open — avoid blocking the add flow on transient errors } // Start provider (CE mode only - Desktop mode doesn't need polling) diff --git a/pkg/gateway/run.go b/pkg/gateway/run.go index fe53c344b..6a4776dea 100644 --- a/pkg/gateway/run.go +++ b/pkg/gateway/run.go @@ -352,13 +352,25 @@ func (g *Gateway) Run(ctx context.Context) error { // Start OAuth provider for each OAuth server. // Each provider runs in its own goroutine with dynamic timing based on token expiry. log.Log("- Starting OAuth provider loops...") + credHelper := oauth.NewOAuthCredentialHelper() for _, serverName := range configuration.ServerNames() { serverConfig, _, found := configuration.Find(serverName) - if !found || serverConfig == nil || !serverConfig.Spec.IsRemoteOAuthServer() { + if !found || serverConfig == nil { continue } - g.startProvider(ctx, serverName) + if serverConfig.Spec.HasExplicitOAuthProviders() { + g.startProvider(ctx, serverName) + } else if serverConfig.IsRemote() { + // Community servers: start provider if they have a stored OAuth token + // from dynamic discovery (DCR without explicit OAuth metadata) + if exists, err := credHelper.TokenExists(ctx, serverName); err != nil { + log.Logf("Warning: Failed to check OAuth token for %s: %v", serverName, err) + } else if exists { + log.Logf("- Starting OAuth provider for community server: %s", serverName) + g.startProvider(ctx, serverName) + } + } } } @@ -697,7 +709,9 @@ func (g *Gateway) routeEventToProvider(event oauth.Event) { g.clientPool.InvalidateOAuthClients(event.Provider) case oauth.EventLogoutSuccess: - // User logged out - stop provider if exists + // Invalidate cached OAuth client connections (clear stale bearer tokens) + g.clientPool.InvalidateOAuthClients(event.Provider) + // Stop provider if exists if exists { log.Logf("- Stopping provider for %s after logout", event.Provider) g.stopProvider(event.Provider) diff --git a/pkg/mcp/remote.go b/pkg/mcp/remote.go index 9931c1c2f..2bf706187 100644 --- a/pkg/mcp/remote.go +++ b/pkg/mcp/remote.go @@ -93,6 +93,17 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeParam } else if token != "" { headers["Authorization"] = "Bearer " + token } + } else if c.config.Spec.Remote.URL != "" { + // Community servers may have OAuth tokens via dynamic discovery (DCR) + // without explicit OAuth metadata in the catalog. Try to get a stored token. + credHelper := oauth.NewOAuthCredentialHelper() + token, err := credHelper.GetOAuthToken(ctx, c.config.Name) + if err == nil && token != "" { + if verbose { + log.Logf(" - Using dynamic OAuth token for: %s", c.config.Name) + } + headers["Authorization"] = "Bearer " + token + } } var mcpTransport mcp.Transport diff --git a/pkg/mcp/remote_test.go b/pkg/mcp/remote_test.go new file mode 100644 index 000000000..53bdc2467 --- /dev/null +++ b/pkg/mcp/remote_test.go @@ -0,0 +1,144 @@ +package mcp + +import ( + "context" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// roundTripFunc is an adapter to use functions as http.RoundTripper. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestHeaderRoundTripper_AttachesAuthorizationHeader(t *testing.T) { + // Verifies that headerRoundTripper propagates Authorization headers to requests. + // This is the mechanism through which OAuth tokens (both catalog and dynamic) reach + // the remote MCP server. + var capturedReq *http.Request + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + rt := &headerRoundTripper{ + base: base, + headers: map[string]string{ + "Authorization": "Bearer test-oauth-token", + }, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, capturedReq) + assert.Equal(t, "Bearer test-oauth-token", capturedReq.Header.Get("Authorization")) +} + +func TestHeaderRoundTripper_DoesNotOverrideExistingAccept(t *testing.T) { + var capturedReq *http.Request + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + rt := &headerRoundTripper{ + base: base, + headers: map[string]string{ + "Accept": "application/json", + "Authorization": "Bearer token", + }, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/mcp", nil) + require.NoError(t, err) + req.Header.Set("Accept", "text/event-stream") + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, capturedReq) + assert.Equal(t, "text/event-stream", capturedReq.Header.Get("Accept"), + "Accept should not be overridden when already set") + assert.Equal(t, "Bearer token", capturedReq.Header.Get("Authorization"), + "Authorization should still be set") +} + +func TestHeaderRoundTripper_DoesNotMutateOriginalRequest(t *testing.T) { + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + rt := &headerRoundTripper{ + base: base, + headers: map[string]string{ + "Authorization": "Bearer token", + }, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + assert.Empty(t, req.Header.Get("Authorization"), + "original request should not be mutated") +} + +func TestHeaderRoundTripper_MultipleCustomHeaders(t *testing.T) { + var capturedReq *http.Request + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + rt := &headerRoundTripper{ + base: base, + headers: map[string]string{ + "Authorization": "Bearer dynamic-oauth-token", + "X-Custom": "custom-value", + }, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, capturedReq) + assert.Equal(t, "Bearer dynamic-oauth-token", capturedReq.Header.Get("Authorization")) + assert.Equal(t, "custom-value", capturedReq.Header.Get("X-Custom")) +} + +func TestHeaderRoundTripper_EmptyHeaders(t *testing.T) { + var capturedReq *http.Request + base := roundTripFunc(func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + rt := &headerRoundTripper{ + base: base, + headers: map[string]string{}, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com/mcp", nil) + require.NoError(t, err) + + _, err = rt.RoundTrip(req) + require.NoError(t, err) + + require.NotNil(t, capturedReq) + assert.Empty(t, capturedReq.Header.Get("Authorization"), + "no Authorization header when headers map is empty") +} diff --git a/pkg/oauth/dcr_registration.go b/pkg/oauth/dcr_registration.go index c6039dd2f..d7070cb8b 100644 --- a/pkg/oauth/dcr_registration.go +++ b/pkg/oauth/dcr_registration.go @@ -4,10 +4,32 @@ import ( "context" "fmt" + oauthhelpers "github.com/docker/mcp-gateway-oauth-helpers" + "github.com/docker/mcp-gateway/pkg/catalog" "github.com/docker/mcp-gateway/pkg/desktop" + "github.com/docker/mcp-gateway/pkg/log" ) +// dcrRegistrationClient is the subset of desktop.Tools used for DCR registration. +// Extracted as an interface to enable testing. +type dcrRegistrationClient interface { + GetDCRClient(ctx context.Context, app string) (*desktop.DCRClient, error) + RegisterDCRClientPending(ctx context.Context, app string, req desktop.RegisterDCRRequest) error +} + +// oauthProber abstracts OAuth discovery to enable testing. +type oauthProber interface { + DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*oauthhelpers.Discovery, error) +} + +// defaultOAuthProber wraps the package-level function. +type defaultOAuthProber struct{} + +func (defaultOAuthProber) DiscoverOAuthRequirements(ctx context.Context, serverURL string) (*oauthhelpers.Discovery, error) { + return oauthhelpers.DiscoverOAuthRequirements(ctx, serverURL) +} + // RegisterProviderForLazySetup registers a DCR provider with Docker Desktop // This allows 'docker mcp oauth authorize' to work before full DCR is complete // Idempotent - safe to call multiple times for the same server @@ -32,7 +54,7 @@ func RegisterProviderForLazySetup(ctx context.Context, serverName string) error } // Verify this is a remote OAuth server (Type="remote" && OAuth providers exist) - if !server.IsRemoteOAuthServer() { + if !server.HasExplicitOAuthProviders() { return fmt.Errorf("server %s is not a remote OAuth server", serverName) } @@ -46,6 +68,38 @@ func RegisterProviderForLazySetup(ctx context.Context, serverName string) error return client.RegisterDCRClientPending(ctx, serverName, dcrRequest) } +// RegisterProviderForDynamicDiscovery probes a remote server for OAuth support +// and creates a pending DCR entry if the server requires OAuth. +// This is used for community servers that lack oauth.providers metadata in the catalog. +// Idempotent - safe to call multiple times for the same server. +func RegisterProviderForDynamicDiscovery(ctx context.Context, serverName, serverURL string) error { + return registerProviderForDynamicDiscovery(ctx, serverName, serverURL, desktop.NewAuthClient(), defaultOAuthProber{}) +} + +func registerProviderForDynamicDiscovery(ctx context.Context, serverName, serverURL string, client dcrRegistrationClient, prober oauthProber) error { + // Idempotent check - already registered? + _, err := client.GetDCRClient(ctx, serverName) + if err == nil { + return nil // Already registered + } + + // Probe the server to discover OAuth requirements. + // The discovery library uses its own 30s HTTP timeout internally. + discovery, err := prober.DiscoverOAuthRequirements(ctx, serverURL) + if err != nil { + log.Logf("Dynamic OAuth discovery failed for %s: %v", serverName, err) + return nil // Probe failed, not fatal + } + if discovery == nil || !discovery.RequiresOAuth { + return nil // Server doesn't need OAuth + } + + // Register with DD (pending DCR state) using server name as provider name + return client.RegisterDCRClientPending(ctx, serverName, desktop.RegisterDCRRequest{ + ProviderName: serverName, + }) +} + // RegisterProviderWithSnapshot registers a DCR provider using OAuth metadata from the server snapshot // This avoids querying the catalog since the snapshot already contains all necessary OAuth information // Idempotent - safe to call multiple times for the same server diff --git a/pkg/oauth/dcr_registration_test.go b/pkg/oauth/dcr_registration_test.go new file mode 100644 index 000000000..a1d8ae986 --- /dev/null +++ b/pkg/oauth/dcr_registration_test.go @@ -0,0 +1,108 @@ +package oauth + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + oauthhelpers "github.com/docker/mcp-gateway-oauth-helpers" + "github.com/docker/mcp-gateway/pkg/desktop" +) + +// mockDCRClient implements dcrRegistrationClient for testing. +type mockDCRClient struct { + clients map[string]*desktop.DCRClient + registered map[string]desktop.RegisterDCRRequest +} + +func newMockDCRClient() *mockDCRClient { + return &mockDCRClient{ + clients: make(map[string]*desktop.DCRClient), + registered: make(map[string]desktop.RegisterDCRRequest), + } +} + +func (m *mockDCRClient) GetDCRClient(_ context.Context, app string) (*desktop.DCRClient, error) { + c, ok := m.clients[app] + if !ok { + return nil, errors.New("not found") + } + return c, nil +} + +func (m *mockDCRClient) RegisterDCRClientPending(_ context.Context, app string, req desktop.RegisterDCRRequest) error { + m.registered[app] = req + return nil +} + +// mockProber implements oauthProber for testing. +type mockProber struct { + discovery *oauthhelpers.Discovery + err error +} + +func (m *mockProber) DiscoverOAuthRequirements(_ context.Context, _ string) (*oauthhelpers.Discovery, error) { + return m.discovery, m.err +} + +func TestRegisterProviderForDynamicDiscovery_SkipsAlreadyRegistered(t *testing.T) { + client := newMockDCRClient() + client.clients["my-server"] = &desktop.DCRClient{State: "unregistered"} + + prober := &mockProber{} // should not be called + + err := registerProviderForDynamicDiscovery(t.Context(), "my-server", "https://example.com/mcp", client, prober) + require.NoError(t, err) + assert.Empty(t, client.registered, "should not register when already exists") +} + +func TestRegisterProviderForDynamicDiscovery_RegistersWhenOAuthRequired(t *testing.T) { + client := newMockDCRClient() + prober := &mockProber{ + discovery: &oauthhelpers.Discovery{RequiresOAuth: true}, + } + + err := registerProviderForDynamicDiscovery(t.Context(), "ai-kubit-mcp-server", "https://mcp.kubit.ai/mcp", client, prober) + require.NoError(t, err) + + req, ok := client.registered["ai-kubit-mcp-server"] + require.True(t, ok, "should have registered DCR client") + assert.Equal(t, "ai-kubit-mcp-server", req.ProviderName, "provider name should match server name") +} + +func TestRegisterProviderForDynamicDiscovery_SkipsWhenNoOAuthRequired(t *testing.T) { + client := newMockDCRClient() + prober := &mockProber{ + discovery: &oauthhelpers.Discovery{RequiresOAuth: false}, + } + + err := registerProviderForDynamicDiscovery(t.Context(), "my-server", "https://example.com/mcp", client, prober) + require.NoError(t, err) + assert.Empty(t, client.registered, "should not register when OAuth not required") +} + +func TestRegisterProviderForDynamicDiscovery_SkipsOnNilDiscovery(t *testing.T) { + client := newMockDCRClient() + prober := &mockProber{ + discovery: nil, + err: nil, + } + + err := registerProviderForDynamicDiscovery(t.Context(), "my-server", "https://example.com/mcp", client, prober) + require.NoError(t, err) + assert.Empty(t, client.registered, "should not register when discovery is nil") +} + +func TestRegisterProviderForDynamicDiscovery_SkipsOnProbeError(t *testing.T) { + client := newMockDCRClient() + prober := &mockProber{ + err: errors.New("connection refused"), + } + + err := registerProviderForDynamicDiscovery(t.Context(), "my-server", "https://unreachable.example.com/mcp", client, prober) + require.NoError(t, err) + assert.Empty(t, client.registered, "should not register when probe fails") +} diff --git a/pkg/oauth/provider.go b/pkg/oauth/provider.go index 92e28547b..1256edf51 100644 --- a/pkg/oauth/provider.go +++ b/pkg/oauth/provider.go @@ -152,10 +152,14 @@ func (p *Provider) Run(ctx context.Context) { // Trigger refresh if needed if shouldTriggerRefresh { if IsCEMode() { - // CE mode: Refresh token directly + // CE mode: Refresh token directly, then reload server connection go func() { if err := p.refreshTokenCE(); err != nil { log.Logf("! Token refresh failed for %s: %v", p.name, err) + return + } + if err := p.reloadFn(ctx, p.name); err != nil { + log.Logf("! Failed to reload %s after token refresh: %v", p.name, err) } }() } else { diff --git a/pkg/workingset/oauth.go b/pkg/workingset/oauth.go index fa3f37dd0..d019e0e6b 100644 --- a/pkg/workingset/oauth.go +++ b/pkg/workingset/oauth.go @@ -24,15 +24,20 @@ func RegisterOAuthProvidersForServers(ctx context.Context, servers []Server) { if server.Snapshot == nil { continue } - if !server.Snapshot.Server.IsRemoteOAuthServer() { - continue - } - - serverName := server.Snapshot.Server.Name - providerName := server.Snapshot.Server.OAuth.Providers[0].Provider + if server.Snapshot.Server.HasExplicitOAuthProviders() { + serverName := server.Snapshot.Server.Name + providerName := server.Snapshot.Server.OAuth.Providers[0].Provider - if err := oauth.RegisterProviderWithSnapshot(ctx, serverName, providerName); err != nil { - log.Log(fmt.Sprintf("Warning: Failed to register OAuth provider for %s: %v", serverName, err)) + if err := oauth.RegisterProviderWithSnapshot(ctx, serverName, providerName); err != nil { + log.Log(fmt.Sprintf("Warning: Failed to register OAuth provider for %s: %v", serverName, err)) + } + } else if server.Snapshot.Server.Type == "remote" && server.Snapshot.Server.Remote.URL != "" { + // Community servers without oauth.providers: probe for OAuth dynamically + serverName := server.Snapshot.Server.Name + serverURL := server.Snapshot.Server.Remote.URL + if err := oauth.RegisterProviderForDynamicDiscovery(ctx, serverName, serverURL); err != nil { + log.Log(fmt.Sprintf("Warning: Failed dynamic OAuth discovery for %s: %v", serverName, err)) + } } } } diff --git a/pkg/workingset/server.go b/pkg/workingset/server.go index 437cad066..cd9dd2eb7 100644 --- a/pkg/workingset/server.go +++ b/pkg/workingset/server.go @@ -167,7 +167,8 @@ type dcrClient interface { } // CleanupOrphanedDCREntries removes DCR entries for servers that no longer -// exist in any profile. This prevents stale OAuth entries from accumulating. +// exist in any profile and are not authorized. This prevents stale OAuth +// entries from accumulating. func CleanupOrphanedDCREntries(ctx context.Context, dao db.DAO, serverNames []string) { if oauth.IsCEMode() { return