@@ -17,15 +17,21 @@ import (
1717 mcpclient "github.com/docker/mcp-gateway/cmd/docker-mcp/internal/mcp"
1818)
1919
20+ type clientKey struct {
21+ serverName string
22+ session * mcp.ServerSession
23+ }
24+
2025type keptClient struct {
21- Name string
22- Getter * clientGetter
23- Config catalog.ServerConfig
26+ Name string
27+ Getter * clientGetter
28+ Config * catalog.ServerConfig
29+ ClientConfig * clientConfig
2430}
2531
2632type clientPool struct {
2733 Options
28- keptClients [ ]keptClient
34+ keptClients map [ clientKey ]keptClient
2935 clientLock sync.RWMutex
3036 networks []string
3137 docker docker.Client
@@ -41,20 +47,42 @@ func newClientPool(options Options, docker docker.Client) *clientPool {
4147 return & clientPool {
4248 Options : options ,
4349 docker : docker ,
44- keptClients : []keptClient {},
50+ keptClients : make (map [clientKey ]keptClient ),
51+ }
52+ }
53+
54+ func (cp * clientPool ) UpdateRoots (ss * mcp.ServerSession , roots []* mcp.Root ) {
55+ cp .clientLock .RLock ()
56+ defer cp .clientLock .RUnlock ()
57+
58+ for _ , kc := range cp .keptClients {
59+ if kc .ClientConfig != nil && (kc .ClientConfig .serverSession == ss ) {
60+ client , err := kc .Getter .GetClient (context .TODO ()) // should be cached
61+ if err == nil {
62+ client .AddRoots (roots )
63+ }
64+ }
4565 }
4666}
4767
48- func (cp * clientPool ) AcquireClient (ctx context.Context , serverConfig catalog.ServerConfig , config * clientConfig ) (mcpclient.Client , error ) {
68+ func (cp * clientPool ) longLived (serverConfig * catalog.ServerConfig , config * clientConfig ) bool {
69+ keep := config != nil && config .serverSession != nil && (serverConfig .Spec .LongLived || cp .LongLived )
70+ return keep
71+ }
72+
73+ func (cp * clientPool ) AcquireClient (ctx context.Context , serverConfig * catalog.ServerConfig , config * clientConfig ) (mcpclient.Client , error ) {
4974 var getter * clientGetter
75+ c := ctx
5076
5177 // Check if client is kept, can be returned immediately
78+ var session * mcp.ServerSession
79+ if config != nil {
80+ session = config .serverSession
81+ }
82+ key := clientKey {serverName : serverConfig .Name , session : session }
5283 cp .clientLock .RLock ()
53- for _ , kc := range cp .keptClients {
54- if kc .Name == serverConfig .Name {
55- getter = kc .Getter
56- break
57- }
84+ if kc , exists := cp .keptClients [key ]; exists {
85+ getter = kc .Getter
5886 }
5987 cp .clientLock .RUnlock ()
6088
@@ -63,30 +91,27 @@ func (cp *clientPool) AcquireClient(ctx context.Context, serverConfig catalog.Se
6391 getter = newClientGetter (serverConfig , cp , config )
6492
6593 // If the client is long running, save it for later
66- if serverConfig .Spec .LongLived || cp .LongLived {
94+ if cp .longLived (serverConfig , config ) {
95+ c = context .Background ()
6796 cp .clientLock .Lock ()
68- cp .keptClients = append (cp .keptClients , keptClient {
69- Name : serverConfig .Name ,
70- Getter : getter ,
71- Config : serverConfig ,
72- })
97+ cp .keptClients [key ] = keptClient {
98+ Name : serverConfig .Name ,
99+ Getter : getter ,
100+ Config : serverConfig ,
101+ ClientConfig : config ,
102+ }
73103 cp .clientLock .Unlock ()
74104 }
75105 }
76106
77- client , err := getter .GetClient (ctx ) // first time creates the client, can take some time
107+ client , err := getter .GetClient (c ) // first time creates the client, can take some time
78108 if err != nil {
79109 cp .clientLock .Lock ()
80110 defer cp .clientLock .Unlock ()
81111
82112 // Wasn't successful, remove it
83- if serverConfig .Spec .LongLived || cp .LongLived {
84- for i , kc := range cp .keptClients {
85- if kc .Getter == getter {
86- cp .keptClients = append (cp .keptClients [:i ], cp .keptClients [i + 1 :]... )
87- break
88- }
89- }
113+ if cp .longLived (serverConfig , config ) {
114+ delete (cp .keptClients , key )
90115 }
91116
92117 return nil , err
@@ -111,14 +136,12 @@ func (cp *clientPool) ReleaseClient(client mcpclient.Client) {
111136 client .Session ().Close ()
112137 return
113138 }
114-
115- // Otherwise, leave the client as is
116139}
117140
118141func (cp * clientPool ) Close () {
119142 cp .clientLock .Lock ()
120143 existingMap := cp .keptClients
121- cp .keptClients = [ ]keptClient {}
144+ cp .keptClients = make ( map [ clientKey ]keptClient )
122145 cp .clientLock .Unlock ()
123146
124147 // Close all clients
@@ -215,7 +238,7 @@ func (cp *clientPool) baseArgs(name string) []string {
215238 return args
216239}
217240
218- func (cp * clientPool ) argsAndEnv (serverConfig catalog.ServerConfig , readOnly * bool , targetConfig proxies.TargetConfig ) ([]string , []string ) {
241+ func (cp * clientPool ) argsAndEnv (serverConfig * catalog.ServerConfig , readOnly * bool , targetConfig proxies.TargetConfig ) ([]string , []string ) {
219242 args := cp .baseArgs (serverConfig .Name )
220243 var env []string
221244
@@ -308,13 +331,13 @@ type clientGetter struct {
308331 client mcpclient.Client
309332 err error
310333
311- serverConfig catalog.ServerConfig
334+ serverConfig * catalog.ServerConfig
312335 cp * clientPool
313336
314337 clientConfig * clientConfig
315338}
316339
317- func newClientGetter (serverConfig catalog.ServerConfig , cp * clientPool , config * clientConfig ) * clientGetter {
340+ func newClientGetter (serverConfig * catalog.ServerConfig , cp * clientPool , config * clientConfig ) * clientGetter {
318341 return & clientGetter {
319342 serverConfig : serverConfig ,
320343 cp : cp ,
@@ -388,6 +411,7 @@ func (cg *clientGetter) GetClient(ctx context.Context) (mcpclient.Client, error)
388411 // ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
389412 // defer cancel()
390413
414+ // TODO add initial roots
391415 if err := client .Initialize (ctx , initParams , cg .cp .Verbose , ss , server ); err != nil {
392416 return nil , err
393417 }
0 commit comments