Skip to content

Commit aa76e06

Browse files
Move profile loading into middleware
It's better if this happens before the client sends the initialized notification
1 parent e7e09f1 commit aa76e06

File tree

2 files changed

+85
-91
lines changed

2 files changed

+85
-91
lines changed

pkg/gateway/activateprofile.go

Lines changed: 51 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"github.com/docker/mcp-gateway/pkg/db"
1414
"github.com/docker/mcp-gateway/pkg/log"
1515
"github.com/docker/mcp-gateway/pkg/oci"
16-
"github.com/docker/mcp-gateway/pkg/workingset"
1716
)
1817

1918
// ActivateProfileResult contains the result of profile activation
@@ -25,46 +24,36 @@ type ActivateProfileResult struct {
2524

2625
// ActivateProfile activates a profile by name, loading its servers into the gateway
2726
func (g *Gateway) ActivateProfile(ctx context.Context, profileName string) error {
28-
// Load profile from database
27+
// Create database connection
2928
dao, err := db.New()
3029
if err != nil {
3130
return fmt.Errorf("failed to create database client: %w", err)
3231
}
3332
defer dao.Close()
3433

35-
dbWorkingSet, err := dao.GetWorkingSet(ctx, profileName)
36-
if err != nil {
37-
return fmt.Errorf("profile '%s' not found", profileName)
38-
}
34+
// Create a temporary WorkingSetConfiguration to load the profile
35+
wsConfig := NewWorkingSetConfiguration(
36+
Config{WorkingSet: profileName},
37+
oci.NewService(),
38+
g.docker,
39+
)
3940

40-
// Convert and resolve snapshots
41-
ws := workingset.NewFromDb(dbWorkingSet)
42-
43-
// Resolve server snapshots (OCI metadata)
44-
ociService := oci.NewService()
45-
if err := ws.EnsureSnapshotsResolved(ctx, ociService); err != nil {
46-
return fmt.Errorf("failed to resolve server snapshots: %w", err)
41+
// Load the full profile configuration using the existing readOnce method
42+
profileConfig, err := wsConfig.readOnce(ctx, dao)
43+
if err != nil {
44+
return fmt.Errorf("failed to load profile '%s': %w", profileName, err)
4745
}
4846

49-
// Filter servers: only process image and remote servers that are not already active
50-
var serversToActivate []workingset.Server
47+
// Filter servers: only activate servers that are not already active
48+
var serversToActivate []string
5149
var skippedServers []string
5250

53-
for _, server := range ws.Servers {
54-
// Skip registry servers (not supported for direct activation)
55-
if server.Type != workingset.ServerTypeImage && server.Type != workingset.ServerTypeRemote {
56-
continue
57-
}
58-
59-
serverName := server.Snapshot.Server.Name
60-
61-
// Skip servers that are already active
51+
for _, serverName := range profileConfig.serverNames {
6252
if slices.Contains(g.configuration.serverNames, serverName) {
6353
skippedServers = append(skippedServers, serverName)
64-
continue
54+
} else {
55+
serversToActivate = append(serversToActivate, serverName)
6556
}
66-
67-
serversToActivate = append(serversToActivate, server)
6857
}
6958

7059
// If no servers to activate, return early
@@ -80,52 +69,21 @@ func (g *Gateway) ActivateProfile(ctx context.Context, profileName string) error
8069
// Validate ALL servers before activating any (all-or-nothing)
8170
var validationErrors []serverValidation
8271

83-
for _, server := range serversToActivate {
84-
serverName := server.Snapshot.Server.Name
85-
serverConfig := server.Snapshot.Server
72+
for _, serverName := range serversToActivate {
73+
serverConfig := profileConfig.servers[serverName]
8674
validation := serverValidation{serverName: serverName}
8775

88-
// Temporarily add server to configuration to fetch updated secrets
89-
originalServerNames := slices.Clone(g.configuration.serverNames)
90-
g.configuration.serverNames = append(g.configuration.serverNames, serverName)
91-
92-
// Add server to servers map for secret resolution
93-
g.configuration.servers[serverName] = serverConfig
94-
95-
// Fetch updated secrets for validation
96-
if g.configurator != nil {
97-
updatedSecrets, err := g.configurator.readDockerDesktopSecrets(ctx, g.configuration.servers, g.configuration.serverNames)
98-
if err != nil {
99-
log.Log(fmt.Errorf("failed to read DockerDesktop secrets: %w", err))
100-
} else {
101-
g.configuration.secrets = updatedSecrets
102-
}
103-
}
104-
10576
// Check if all required secrets are set
10677
for _, secret := range serverConfig.Secrets {
107-
secretName := secret.Name
108-
// Handle namespaced secrets from profile
109-
if server.Secrets != "" {
110-
secretName = server.Secrets + "_" + secret.Name
111-
}
112-
113-
if value, exists := g.configuration.secrets[secretName]; !exists || value == "" {
78+
if value, exists := profileConfig.secrets[secret.Name]; !exists || value == "" {
11479
validation.missingSecrets = append(validation.missingSecrets, secret.Name)
11580
}
11681
}
11782

11883
// Check if all required config values are set and validate against schema
11984
if len(serverConfig.Config) > 0 {
120-
canonicalServerName := oci.CanonicalizeServerName(serverName)
121-
122-
// Get config from profile or existing configuration
123-
var serverConfigMap map[string]any
124-
if server.Config != nil {
125-
serverConfigMap = server.Config
126-
} else if g.configuration.config != nil {
127-
serverConfigMap = g.configuration.config[canonicalServerName]
128-
}
85+
// Get config from profile
86+
serverConfigMap := profileConfig.config[serverName]
12987

13088
for _, configItem := range serverConfig.Config {
13189
// Config items should be schema objects with a "name" property
@@ -188,10 +146,6 @@ func (g *Gateway) ActivateProfile(ctx context.Context, profileName string) error
188146
}
189147
}
190148

191-
// Restore original server names and servers map (rollback temporary changes)
192-
g.configuration.serverNames = originalServerNames
193-
delete(g.configuration.servers, serverName)
194-
195149
// Collect validation errors
196150
if len(validation.missingSecrets) > 0 || len(validation.missingConfig) > 0 || validation.imagePullError != nil {
197151
validationErrors = append(validationErrors, validation)
@@ -222,44 +176,56 @@ func (g *Gateway) ActivateProfile(ctx context.Context, profileName string) error
222176
return fmt.Errorf("%s", strings.Join(errorMessages, "\n"))
223177
}
224178

225-
// All validations passed - activate all servers
179+
// All validations passed - merge configuration into current gateway
226180
var activatedServers []string
227181

228-
for _, server := range serversToActivate {
229-
serverName := server.Snapshot.Server.Name
230-
serverConfig := server.Snapshot.Server
182+
// Merge secrets once (they're already namespaced in profileConfig)
183+
for secretName, secretValue := range profileConfig.secrets {
184+
g.configuration.secrets[secretName] = secretValue
185+
}
231186

232-
// Add server to configuration
187+
for _, serverName := range serversToActivate {
188+
// Add server name to the list
233189
g.configuration.serverNames = append(g.configuration.serverNames, serverName)
234-
g.configuration.servers[serverName] = serverConfig
235190

236-
// Add server config from profile
237-
if server.Config != nil {
191+
// Add server definition
192+
g.configuration.servers[serverName] = profileConfig.servers[serverName]
193+
194+
// Merge server config
195+
if profileConfig.config[serverName] != nil {
238196
if g.configuration.config == nil {
239197
g.configuration.config = make(map[string]map[string]any)
240198
}
241-
canonicalServerName := oci.CanonicalizeServerName(serverName)
242-
g.configuration.config[canonicalServerName] = server.Config
199+
g.configuration.config[serverName] = profileConfig.config[serverName]
243200
}
244201

245-
// Refresh secrets for the updated server list
246-
if g.configurator != nil {
247-
updatedSecrets, err := g.configurator.readDockerDesktopSecrets(ctx, g.configuration.servers, g.configuration.serverNames)
248-
if err == nil {
249-
g.configuration.secrets = updatedSecrets
250-
} else {
251-
log.Log("Warning: Failed to update secrets:", err)
202+
// Merge tools configuration
203+
if tools, exists := profileConfig.tools.ServerTools[serverName]; exists {
204+
if g.configuration.tools.ServerTools == nil {
205+
g.configuration.tools.ServerTools = make(map[string][]string)
252206
}
207+
g.configuration.tools.ServerTools[serverName] = tools
253208
}
254209

255210
// Reload server capabilities
256-
_, err := g.reloadServerCapabilities(ctx, serverName, nil)
211+
oldCaps, err := g.reloadServerCapabilities(ctx, serverName, nil)
257212
if err != nil {
258213
log.Log(fmt.Sprintf("Warning: Failed to reload capabilities for server '%s': %v", serverName, err))
259214
// Continue with other servers even if this one fails
260215
continue
261216
}
262217

218+
// Update g.mcpServer with the new capabilities
219+
g.capabilitiesMu.Lock()
220+
newCaps := g.allCapabilities(serverName)
221+
if err := g.updateServerCapabilities(serverName, oldCaps, newCaps, nil); err != nil {
222+
g.capabilitiesMu.Unlock()
223+
log.Log(fmt.Sprintf("Warning: Failed to update server capabilities for '%s': %v", serverName, err))
224+
// Continue with other servers even if this one fails
225+
continue
226+
}
227+
g.capabilitiesMu.Unlock()
228+
263229
activatedServers = append(activatedServers, serverName)
264230
}
265231

pkg/gateway/run.go

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,6 @@ func (g *Gateway) Run(ctx context.Context) error {
269269
log.Log(fmt.Sprintf("- Initialize request:\n %s", string(initJSON)))
270270
}
271271
}
272-
273-
// Load profiles from profiles.json if client is claude-code
274-
if g.UseProfiles {
275-
// LoadProfilesForClient handles Claude Code detection and profile loading
276-
_ = project.LoadProfilesForClient(ctx, clientInfo, g)
277-
}
278272
},
279273
HasPrompts: true,
280274
HasResources: true,
@@ -283,6 +277,12 @@ func (g *Gateway) Run(ctx context.Context) error {
283277

284278
// Add interceptor middleware to the server (includes telemetry)
285279
middlewares := interceptors.Callbacks(g.LogCalls, g.BlockSecrets, g.OAuthInterceptorEnabled, parsedInterceptors)
280+
281+
// Add profile loading middleware for initialize method
282+
if g.UseProfiles {
283+
middlewares = append(middlewares, g.profileLoadingMiddleware())
284+
}
285+
286286
if len(middlewares) > 0 {
287287
g.mcpServer.AddReceivingMiddleware(middlewares...)
288288
}
@@ -724,3 +724,31 @@ func (g *Gateway) ReloadConfiguration(ctx context.Context, configuration Configu
724724
func (g *Gateway) PullAndVerify(ctx context.Context, configuration Configuration) error {
725725
return g.pullAndVerify(ctx, configuration)
726726
}
727+
728+
// profileLoadingMiddleware creates middleware that loads profiles when the initialize method is received
729+
func (g *Gateway) profileLoadingMiddleware() mcp.Middleware {
730+
return func(next mcp.MethodHandler) mcp.MethodHandler {
731+
return func(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) {
732+
// Only handle initialize method
733+
if method != "initialize" {
734+
return next(ctx, method, req)
735+
}
736+
737+
// Call the next handler first
738+
result, err := next(ctx, method, req)
739+
740+
// After the initialize method has been handled, load profiles
741+
session := req.GetSession()
742+
if serverSession, ok := session.(*mcp.ServerSession); ok {
743+
initParams := serverSession.InitializeParams()
744+
if initParams != nil && initParams.ClientInfo != nil {
745+
// Load profiles from profiles.json if client is claude-code
746+
_ = project.LoadProfilesForClient(ctx, initParams.ClientInfo, g)
747+
}
748+
}
749+
750+
return result, err
751+
}
752+
}
753+
}
754+

0 commit comments

Comments
 (0)