-
Notifications
You must be signed in to change notification settings - Fork 219
Expand file tree
/
Copy pathrunner.go
More file actions
948 lines (824 loc) · 35.7 KB
/
runner.go
File metadata and controls
948 lines (824 loc) · 35.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
// SPDX-FileCopyrightText: Copyright 2026 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0
// Package runner provides functionality for running MCP servers.
package runner
import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"strings"
"time"
"golang.org/x/oauth2"
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/auth/remote"
authsecrets "github.com/stacklok/toolhive/pkg/auth/secrets"
"github.com/stacklok/toolhive/pkg/auth/upstreamtoken"
authserverrunner "github.com/stacklok/toolhive/pkg/authserver/runner"
"github.com/stacklok/toolhive/pkg/authserver/server/keys"
"github.com/stacklok/toolhive/pkg/client"
"github.com/stacklok/toolhive/pkg/config"
ct "github.com/stacklok/toolhive/pkg/container"
rt "github.com/stacklok/toolhive/pkg/container/runtime"
"github.com/stacklok/toolhive/pkg/labels"
"github.com/stacklok/toolhive/pkg/process"
"github.com/stacklok/toolhive/pkg/runtime"
"github.com/stacklok/toolhive/pkg/secrets"
"github.com/stacklok/toolhive/pkg/telemetry"
"github.com/stacklok/toolhive/pkg/transport"
"github.com/stacklok/toolhive/pkg/transport/session"
"github.com/stacklok/toolhive/pkg/transport/types"
"github.com/stacklok/toolhive/pkg/workloads/statuses"
)
// ErrContainerExitedRestartNeeded is the sentinel error Run returns when the
// transport stops while the workload has not been confirmed as removed, which
// signals to the caller that the workload should be restarted. Callers should
// match it with errors.Is.
var (
	ErrContainerExitedRestartNeeded = errors.New("container exited, restart needed")
)
// Runner is responsible for running an MCP server with the provided configuration.
//
// A Runner is created with NewRunner and driven by Run. Besides the RunConfig it
// holds state produced while Run sets things up: middleware instances and the
// handlers their factories register, the embedded auth server, and remote-auth
// token monitoring. Some of that state (middleware, telemetry provider,
// monitoring context) is released by Cleanup.
type Runner struct {
	// Config is the configuration for the runner
	Config *RunConfig
	// telemetryProvider is the OpenTelemetry provider for cleanup
	telemetryProvider *telemetry.Provider
	// supportedMiddleware is a map of supported middleware types to their factory functions.
	// NewRunner fills it from GetSupportedMiddlewareFactories; Run looks up each
	// MiddlewareConfigs entry's Type here and fails on unknown types.
	supportedMiddleware map[string]types.MiddlewareFactory
	// middlewares is a slice of created middleware instances for cleanup
	middlewares []types.Middleware
	// namedMiddlewares is a slice of named middleware to apply to the transport.
	// Entries are appended by AddMiddleware in call order, and Run copies the
	// slice into the transport config.
	namedMiddlewares []types.NamedMiddleware
	// authInfoHandler is the authentication info handler set by auth middleware
	authInfoHandler http.Handler
	// prometheusHandler is the Prometheus metrics handler set by telemetry middleware
	prometheusHandler http.Handler
	// statusManager records workload status and PID. Remote-auth monitoring also
	// uses it, wrapped in statusManagerAdapter.
	statusManager statuses.StatusManager
	// authenticatedTokenSource is the wrapped token source for remote workloads with authentication monitoring
	authenticatedTokenSource *auth.MonitoredTokenSource
	// monitoringCtx is the context for background authentication monitoring
	// It is cancelled during Cleanup() to stop monitoring
	monitoringCtx context.Context
	// monitoringCancel cancels monitoringCtx. It stays nil unless Run obtained a
	// non-nil remote token source and started monitoring.
	monitoringCancel context.CancelFunc
	// embeddedAuthServer is the embedded OAuth/OIDC authorization server.
	// Only initialized when Config.EmbeddedAuthServerConfig is set.
	embeddedAuthServer *authserverrunner.EmbeddedAuthServer
	// upstreamTokenReader provides read-only access to upstream tokens for
	// identity enrichment in auth middleware. Set when the embedded auth
	// server is initialized in Run().
	// Nil when no embedded auth server is configured.
	upstreamTokenReader upstreamtoken.TokenReader
	// keyProvider provides in-process JWKS key lookups from the embedded
	// auth server, eliminating self-referential HTTP calls.
	// Nil when no embedded auth server is configured.
	keyProvider keys.PublicKeyProvider
}
// statusManagerAdapter lets a statuses.StatusManager be handed to
// auth.NewMonitoredTokenSource, which expects an auth.StatusUpdater.
type statusManagerAdapter struct {
	sm statuses.StatusManager
}

// SetWorkloadStatus logs the requested status change at debug level, then
// forwards the call unchanged to the wrapped StatusManager and returns its error.
func (a *statusManagerAdapter) SetWorkloadStatus(
	ctx context.Context,
	workloadName string,
	status rt.WorkloadStatus,
	reason string,
) error {
	slog.Debug("setting workload status",
		"workload", workloadName,
		"status", status,
		"reason", reason,
	)
	manager := a.sm
	return manager.SetWorkloadStatus(ctx, workloadName, status, reason)
}
// NewRunner builds a Runner for runConfig.
//
// The returned Runner reports workload state through statusManager. It is
// preloaded with the default middleware factory table from
// GetSupportedMiddlewareFactories. Nothing is started until Run is called.
func NewRunner(runConfig *RunConfig, statusManager statuses.StatusManager) *Runner {
	runner := &Runner{}
	runner.Config = runConfig
	runner.statusManager = statusManager
	runner.supportedMiddleware = GetSupportedMiddlewareFactories()
	return runner
}
// AddMiddleware registers a middleware instance under name.
//
// The instance is kept so it can be cleaned up later. Its handler function is
// recorded as a named middleware, in call order, for Run to install on the
// transport.
func (r *Runner) AddMiddleware(name string, middleware types.Middleware) {
	r.middlewares = append(r.middlewares, middleware)

	entry := types.NamedMiddleware{
		Name:     name,
		Function: middleware.Handler(),
	}
	r.namedMiddlewares = append(r.namedMiddlewares, entry)
}
// SetAuthInfoHandler stores the handler that serves authentication info.
// Run passes it to the transport configuration once every middleware factory
// has run.
func (r *Runner) SetAuthInfoHandler(h http.Handler) {
	r.authInfoHandler = h
}

// SetPrometheusHandler stores the handler that serves Prometheus metrics.
// Run passes it to the transport configuration once every middleware factory
// has run.
func (r *Runner) SetPrometheusHandler(h http.Handler) {
	r.prometheusHandler = h
}
// GetConfig exposes the runner's RunConfig to middleware factories through
// the types.RunnerConfig interface.
//
// NOTE(review): if r.Config is nil, the result is a non-nil interface value
// wrapping a nil *RunConfig, so a `== nil` check will not detect it.
func (r *Runner) GetConfig() types.RunnerConfig {
	var cfg types.RunnerConfig = r.Config
	return cfg
}

// GetUpstreamTokenReader returns the read-only upstream token reader that the
// auth middleware uses for identity enrichment. It is nil unless Run
// initialized an embedded auth server.
func (r *Runner) GetUpstreamTokenReader() (reader upstreamtoken.TokenReader) {
	reader = r.upstreamTokenReader
	return reader
}

// GetKeyProvider returns the embedded auth server's public key provider, which
// enables in-process JWKS lookups. It is nil unless Run initialized an
// embedded auth server.
func (r *Runner) GetKeyProvider() (provider keys.PublicKeyProvider) {
	provider = r.keyProvider
	return provider
}
// GetName reports the MCP server's configured name. It is part of the
// types.RunnerConfig implementation.
func (c *RunConfig) GetName() (name string) {
	name = c.Name
	return name
}

// GetPort reports the configured proxy port. It is part of the
// types.RunnerConfig implementation.
func (c *RunConfig) GetPort() (port int) {
	port = c.Port
	return port
}
// Run runs the MCP server with the provided configuration and blocks until the
// server stops or ctx is cancelled.
//
// Setup proceeds in this order:
//   - resolve the effective session TTL (session.DefaultSessionTTL when the
//     configured value is not positive);
//   - resolve secrets via RunConfig.WithSecrets when regular, remote-auth, or
//     header-forward secrets are configured;
//   - populate middleware configs, or append the header-forward middleware when
//     configs were already populated;
//   - initialize the embedded authorization server when configured (single
//     upstream only) and expose its token reader, key provider, and routes;
//   - instantiate every configured middleware through its registered factory;
//   - check the active policy gate, deploy the workload for local servers, and
//     attach Redis-backed session storage when configured;
//   - create the transport; for remote servers configure the remote URL,
//     authentication, and status callbacks; enable stateless mode if
//     requested; then start the transport;
//   - for non-stdio transports, wait up to 5 minutes for initialize to succeed
//     (a failure is only logged), then register the server with client configs;
//   - record this process's PID and mark the workload running, unless it is
//     already unauthenticated.
//
// Run returns nil when ctx is cancelled (after stopping the transport and
// calling Cleanup). It also returns nil when the transport stops and the
// workload is found to no longer exist. It returns
// ErrContainerExitedRestartNeeded when the transport stops while the workload
// still exists, or when its existence cannot be checked. Setup failures are
// returned as wrapped errors.
//
// NOTE(review): the error returns during setup do not call Cleanup or stop
// anything already created. That includes the SetWorkloadStatus failure,
// which happens after the transport has started. Presumably the caller handles
// this — confirm.
//
//nolint:gocyclo // This function is complex but manageable
func (r *Runner) Run(ctx context.Context) error {
	// Resolve session TTL once so both the transport proxy and Redis storage use
	// the same effective value, rather than each applying their own zero-fallback
	// independently.
	effectiveSessionTTL := r.Config.SessionTTL
	if effectiveSessionTTL <= 0 {
		effectiveSessionTTL = session.DefaultSessionTTL
	}
	// Create transport with runtime
	transportConfig := types.Config{
		Type:              r.Config.Transport,
		ProxyPort:         r.Config.Port,
		TargetPort:        r.Config.TargetPort,
		Host:              r.Config.Host,
		TargetHost:        r.Config.TargetHost,
		Deployer:          r.Config.Deployer,
		Debug:             r.Config.Debug,
		TrustProxyHeaders: r.Config.TrustProxyHeaders,
		EndpointPrefix:    r.Config.EndpointPrefix,
		SessionTTL:        effectiveSessionTTL,
	}
	// Set proxy mode for stdio transport
	transportConfig.ProxyMode = r.Config.ProxyMode
	// Process secrets before middleware population so that resolved values
	// (e.g., header forward secrets) are available to middleware factories.
	hasRegularSecrets := len(r.Config.Secrets) > 0
	hasRemoteAuthSecret := r.Config.RemoteAuthConfig != nil &&
		(r.Config.RemoteAuthConfig.ClientSecret != "" || r.Config.RemoteAuthConfig.BearerToken != "")
	hasHeaderForwardSecrets := r.Config.HeaderForward != nil && len(r.Config.HeaderForward.AddHeadersFromSecret) > 0
	slog.Debug("secret processing check",
		"has_regular_secrets", hasRegularSecrets,
		"has_remote_auth_secret", hasRemoteAuthSecret,
		"has_header_forward_secrets", hasHeaderForwardSecrets)
	// Log only which kinds of remote credentials are present, never their values.
	if hasRemoteAuthSecret {
		if r.Config.RemoteAuthConfig.ClientSecret != "" {
			slog.Debug("remote auth config has client secret configured")
		}
		if r.Config.RemoteAuthConfig.BearerToken != "" {
			slog.Debug("remote auth config has bearer token configured")
		}
	}
	if hasRegularSecrets || hasRemoteAuthSecret || hasHeaderForwardSecrets {
		slog.Debug("calling WithSecrets to process secrets")
		cfgprovider := config.NewDefaultProvider()
		cfg := cfgprovider.GetConfig()
		providerType, err := cfg.Secrets.GetProviderType()
		if err != nil {
			return fmt.Errorf("error determining secrets provider type: %w", err)
		}
		// Two providers of the same backend type: one scoped to workloads (system)
		// and one user-facing; WithSecrets receives both.
		systemProvider, err := secrets.CreateProvider(providerType, secrets.WithScope(secrets.ScopeWorkloads))
		if err != nil {
			return fmt.Errorf("error instantiating system secret manager: %w", err)
		}
		userProvider, err := secrets.CreateProvider(providerType, secrets.WithUserFacing())
		if err != nil {
			return fmt.Errorf("error instantiating user secret manager: %w", err)
		}
		// Process secrets (including RemoteAuthConfig and header forward secret resolution)
		if _, err = r.Config.WithSecrets(ctx, systemProvider, userProvider); err != nil {
			return err
		}
	}
	// Populate default middlewares from config fields if not already populated.
	// This runs after WithSecrets so resolved values are available.
	if len(r.Config.MiddlewareConfigs) == 0 {
		if err := PopulateMiddlewareConfigs(r.Config); err != nil {
			return fmt.Errorf("failed to populate middleware configs: %w", err)
		}
	} else {
		// MiddlewareConfigs was pre-populated (e.g., by WithMiddlewareFromFlags).
		// Header forward is appended here (consistent with PopulateMiddlewareConfigs
		// which also places it at the end) after secret resolution, because
		// secret-backed header values are not available at builder time.
		var err error
		r.Config.MiddlewareConfigs, err = addHeaderForwardMiddleware(r.Config.MiddlewareConfigs, r.Config)
		if err != nil {
			return fmt.Errorf("failed to add header forward middleware: %w", err)
		}
	}
	// Initialize embedded auth server if configured.
	// This must happen before middleware creation so that the upstream token
	// service is available to middleware factories (e.g., upstreamswap).
	if r.Config.EmbeddedAuthServerConfig != nil {
		// Proxy runner supports only single-upstream configs; multi-upstream
		// requires VirtualMCPServer.
		if len(r.Config.EmbeddedAuthServerConfig.Upstreams) > 1 {
			return fmt.Errorf(
				"proxy runner does not support multiple upstream providers (found %d); "+
					"use VirtualMCPServer for multi-upstream deployments",
				len(r.Config.EmbeddedAuthServerConfig.Upstreams),
			)
		}
		var err error
		r.embeddedAuthServer, err = authserverrunner.NewEmbeddedAuthServer(ctx, r.Config.EmbeddedAuthServerConfig)
		if err != nil {
			return fmt.Errorf("failed to create embedded auth server: %w", err)
		}
		slog.Debug("embedded authorization server initialized")
		// Create the upstream token service eagerly now that the auth server exists.
		// IDPTokenStorage is guaranteed non-nil after successful construction.
		// UpstreamTokenRefresher may be nil if no upstream IDP is configured;
		// InProcessService handles this gracefully (returns ErrNoRefreshToken).
		stor := r.embeddedAuthServer.IDPTokenStorage()
		refresher := r.embeddedAuthServer.UpstreamTokenRefresher()
		r.upstreamTokenReader = upstreamtoken.NewInProcessService(stor, refresher)
		// Expose key provider for in-process JWKS lookups (avoids self-referential HTTP)
		r.keyProvider = r.embeddedAuthServer.KeyProvider()
		// Mount auth server routes at specific prefixes to avoid conflicts with MCP endpoints
		// (e.g., /.well-known/oauth-protected-resource is an MCP endpoint, not auth server)
		transportConfig.PrefixHandlers = r.embeddedAuthServer.Routes()
	}
	// Create middleware from the MiddlewareConfigs instances in the RunConfig.
	for _, middlewareConfig := range r.Config.MiddlewareConfigs {
		// First, get the correct factory function for the middleware type.
		factory, ok := r.supportedMiddleware[middlewareConfig.Type]
		if !ok {
			return fmt.Errorf("unsupported middleware type: %s", middlewareConfig.Type)
		}
		// Create the middleware instance using the factory function.
		// The factory will add the middleware to the runner and handle any special configuration.
		// NOTE(review): &middlewareConfig is the address of the loop variable (a
		// copy of the slice element). Retaining it in a factory is only safe with
		// Go 1.22+ per-iteration loop variables.
		if err := factory(&middlewareConfig, r); err != nil {
			return fmt.Errorf("failed to create middleware of type %s: %w", middlewareConfig.Type, err)
		}
	}
	// Set all named middleware and handlers on transport config
	transportConfig.Middlewares = r.namedMiddlewares
	transportConfig.AuthInfoHandler = r.authInfoHandler
	transportConfig.PrometheusHandler = r.prometheusHandler
	// Set up the transport
	slog.Debug("setting up transport", "transport", r.Config.Transport)
	// Prepare transport options based on workload type
	var transportOpts []transport.Option
	var setupResult *runtime.SetupResult
	// Check policy gate before creating the server (applies to both local and remote)
	if err := ActivePolicyGate().CheckCreateServer(ctx, r.Config); err != nil {
		return fmt.Errorf("server creation blocked by policy: %w", err)
	}
	if r.Config.RemoteURL == "" {
		// For local workloads, deploy the container using runtime.Setup first
		var scalingConfig *rt.ScalingConfig
		if r.Config.ScalingConfig != nil {
			scalingConfig = &rt.ScalingConfig{
				BackendReplicas: r.Config.ScalingConfig.BackendReplicas,
			}
		}
		result, err := runtime.Setup(
			ctx,
			r.Config.Transport,
			r.Config.Deployer,
			r.Config.ContainerName,
			r.Config.Image,
			r.Config.CmdArgs,
			r.Config.EnvVars,
			r.Config.ContainerLabels,
			r.Config.PermissionProfile,
			r.Config.K8sPodTemplatePatch,
			r.Config.IsolateNetwork,
			r.Config.AllowDockerGateway,
			r.Config.IgnoreConfig,
			r.Config.Host,
			r.Config.TargetPort,
			r.Config.TargetHost,
			r.Config.Publish,
			scalingConfig,
			r.Config.MCPServerGeneration,
		)
		if err != nil {
			return fmt.Errorf("failed to set up workload: %w", err)
		}
		setupResult = result
		// Configure the transport with the setup results using options
		transportOpts = append(transportOpts, transport.WithContainerName(setupResult.ContainerName))
		if setupResult.TargetURI != "" {
			transportOpts = append(transportOpts, transport.WithTargetURI(setupResult.TargetURI))
		}
	}
	// When Redis session storage is configured, create a Redis-backed session store
	// so sessions are shared across proxy replicas instead of being pod-local.
	if r.Config.ScalingConfig != nil && r.Config.ScalingConfig.SessionRedis != nil {
		redisCfg := r.Config.ScalingConfig.SessionRedis
		keyPrefix := redisCfg.KeyPrefix
		if keyPrefix == "" {
			keyPrefix = "thv:proxy:session:"
		}
		// The Redis password is read from the environment rather than from the
		// RunConfig.
		storage, err := session.NewRedisStorage(ctx, session.RedisConfig{
			Addr:      redisCfg.Address,
			Password:  os.Getenv(session.RedisPasswordEnvVar),
			DB:        int(redisCfg.DB),
			KeyPrefix: keyPrefix,
		}, effectiveSessionTTL)
		if err != nil {
			return fmt.Errorf("failed to create Redis session storage: %w", err)
		}
		slog.Info("using Redis session storage",
			"address", redisCfg.Address,
			"db", redisCfg.DB,
			"key_prefix", keyPrefix,
		)
		transportConfig.SessionStorage = storage
	}
	// Create transport with options
	transportHandler, err := transport.NewFactory().Create(transportConfig, transportOpts...)
	if err != nil {
		return fmt.Errorf("failed to create transport: %w", err)
	}
	// For remote MCP servers, set the remote URL on HTTP transports
	if r.Config.RemoteURL != "" {
		transportHandler.SetRemoteURL(r.Config.RemoteURL)
		// Handle remote authentication if configured
		tokenSource, err := r.handleRemoteAuthentication(ctx)
		if err != nil {
			return fmt.Errorf("failed to authenticate to remote server: %w", err)
		}
		// Wrap the token source with authentication monitoring for remote workloads
		if tokenSource != nil {
			// Create a child context for monitoring that can be cancelled during cleanup
			r.monitoringCtx, r.monitoringCancel = context.WithCancel(ctx)
			// Create adapter to bridge statuses.StatusManager to auth.StatusUpdater
			adapter := &statusManagerAdapter{sm: r.statusManager}
			r.authenticatedTokenSource = auth.NewMonitoredTokenSource(r.monitoringCtx, tokenSource, r.Config.BaseName, adapter)
			tokenSource = r.authenticatedTokenSource
			r.authenticatedTokenSource.StartBackgroundMonitoring()
		}
		// Set the token source on the transport
		transportHandler.SetTokenSource(tokenSource)
		// Set the health check failure callback for remote servers
		transportHandler.SetOnHealthCheckFailed(func() {
			slog.Warn("health check failed for remote server, marking as unhealthy", "server", r.Config.BaseName)
			// Use Background context for status update callback - this is triggered by health check
			// failure and is independent of any request context. The callback is fired asynchronously
			// and needs its own lifecycle separate from the transport's parent context.
			if err := r.statusManager.SetWorkloadStatus(
				context.Background(),
				r.Config.BaseName,
				rt.WorkloadStatusUnhealthy,
				"Health check failed",
			); err != nil {
				slog.Error("failed to update workload status", "error", err)
			}
		})
		// Set the unauthorized response callback for bearer token authentication
		errorMsg := "Bearer token authentication failed. Please restart the server with a new token"
		transportHandler.SetOnUnauthorizedResponse(func() {
			slog.Warn("received 401 Unauthorized response for remote server, marking as unauthenticated", "server", r.Config.BaseName)
			// Use Background context for status update callback - this is triggered by 401 response
			// and is independent of any request context. The callback is fired asynchronously
			// and needs its own lifecycle separate from the transport's parent context.
			if err := r.statusManager.SetWorkloadStatus(
				context.Background(),
				r.Config.BaseName,
				rt.WorkloadStatusUnauthenticated,
				errorMsg,
			); err != nil {
				slog.Error("failed to update workload status", "error", err)
			}
		})
	}
	// Configure stateless mode if requested. Stateless mode applies to any
	// streamable-HTTP server (remote or local container) where the upstream
	// only accepts POST and does not support SSE-based sessions.
	if r.Config.Stateless {
		httpT, ok := transportHandler.(*transport.HTTPTransport)
		if !ok {
			return fmt.Errorf("--stateless requires streamable-HTTP or SSE transport, got %T", transportHandler)
		}
		httpT.SetStateless(true)
	}
	// Start the transport (which also starts the container and monitoring)
	slog.Debug("starting transport", "transport", r.Config.Transport, "container", r.Config.ContainerName)
	if err := transportHandler.Start(ctx); err != nil {
		return fmt.Errorf("failed to start transport: %w", err)
	}
	slog.Debug("mcp server started successfully", "container", r.Config.ContainerName)
	// Wait for the MCP server to accept initialize requests before updating client configurations.
	// This prevents timing issues where clients try to connect before the server is fully ready.
	// We repeatedly call initialize until it succeeds (up to 5 minutes).
	// Note: We skip this check for pure STDIO transport because STDIO servers may reject
	// multiple initialize calls (see #1982).
	transportType := labels.GetTransportType(r.Config.ContainerLabels)
	serverURL := transport.GenerateMCPServerURL(
		transportType,
		string(r.Config.ProxyMode),
		"localhost",
		r.Config.Port,
		r.Config.ContainerName,
		r.Config.RemoteURL)
	// Only wait for initialization on non-STDIO transports
	// STDIO servers communicate directly via stdin/stdout and calling initialize multiple times
	// can cause issues as the behavior is not specified by the MCP spec
	if transportType != "stdio" {
		// Repeatedly try calling initialize until it succeeds (up to 5 minutes)
		// Some servers (like mcp-optimizer) can take significant time to start up
		if err := waitForInitializeSuccess(ctx, serverURL, transportType, 5*time.Minute); err != nil {
			slog.Warn("initialize not successful, but continuing", "error", err)
			// Continue anyway to maintain backward compatibility, but log a warning
		}
	} else {
		slog.Debug("skipping initialize check for STDIO transport")
	}
	// Update client configurations with the MCP server URL.
	// Note that this function checks the configuration to determine which
	// clients should be updated, if any.
	// Failures here are logged and do not abort Run.
	clientManager, err := client.NewManager(ctx)
	if err != nil {
		slog.Warn("failed to create client manager", "error", err)
	} else {
		if err := clientManager.AddServerToClients(ctx, r.Config.ContainerName, serverURL, transportType, r.Config.Group); err != nil {
			slog.Warn("failed to add server to client configurations", "error", err)
		}
	}
	// Define a function to stop the MCP server
	stopMCPServer := func(reason string) {
		// Use Background context for cleanup operations. The parent context may already be
		// cancelled when this cleanup function runs (e.g., on graceful shutdown or context
		// cancellation). We need a fresh context with its own timeout to ensure cleanup
		// operations complete successfully regardless of the parent context state.
		cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 1*time.Minute)
		defer cleanupCancel()
		slog.Debug("stopping MCP server", "reason", reason)
		// Stop the transport (which also stops the container, monitoring, and handles removal)
		slog.Debug("stopping transport", "transport", r.Config.Transport)
		if err := transportHandler.Stop(cleanupCtx); err != nil {
			slog.Warn("failed to stop transport", "error", err)
		}
		// Cleanup telemetry provider
		if err := r.Cleanup(cleanupCtx); err != nil {
			slog.Warn("failed to cleanup telemetry", "error", err)
		}
		// Remove the PID file if it exists. Use PID-guarded reset so that a
		// dying process does not clobber the PID of a replacement process that
		// started in the meantime (e.g. during thv rm + thv run).
		if err := r.statusManager.ResetWorkloadPIDIfMatch(cleanupCtx, r.Config.BaseName, os.Getpid()); err != nil {
			slog.Warn("failed to reset workload PID", "container", r.Config.ContainerName, "error", err)
		}
		slog.Debug("mcp server stopped", "container", r.Config.ContainerName)
	}
	// Record this process as the workload's owner; failure is logged, not fatal.
	if err := r.statusManager.SetWorkloadPID(ctx, r.Config.BaseName, os.Getpid()); err != nil {
		slog.Warn("failed to set workload PID", "error", err)
	}
	if process.IsDetached() {
		// We're a detached process. The PID was already recorded via
		// SetWorkloadPID above; here it is only logged for diagnostics.
		slog.Info("running as detached process", "pid", os.Getpid())
	} else {
		// Notify the user that the workload has started successfully when using --foreground
		slog.Info("workload started successfully, press Ctrl+C to stop")
	}
	// Create a done channel to signal when the server has been stopped
	doneCh := make(chan struct{})
	// Start a goroutine to monitor the transport's running state
	// NOTE(review): this loop never observes ctx. It exits only when the
	// transport reports it is no longer running (or is nil). If IsRunning keeps
	// returning errors after Stop, the goroutine never terminates. Consider
	// selecting on ctx.Done() as well.
	go func() {
		for {
			// Safely check if transportHandler is nil
			if transportHandler == nil {
				slog.Debug("transport handler is nil, exiting monitoring routine")
				close(doneCh)
				return
			}
			// Check if the transport is still running
			running, err := transportHandler.IsRunning()
			if err != nil {
				slog.Error("error checking transport status", "error", err)
				// Don't exit immediately on error, try again after pause
				time.Sleep(1 * time.Second)
				continue
			}
			if !running {
				// Transport is no longer running (container exited or was stopped)
				slog.Warn("transport is no longer running, attempting automatic restart")
				close(doneCh)
				return
			}
			// Sleep for a short time before checking again
			time.Sleep(1 * time.Second)
		}
	}()
	// At this point, we can consider the workload started successfully.
	// However, we should preserve unauthenticated status if it was already set
	// (e.g., if bearer token authentication failed during initialization)
	currentWorkload, err := r.statusManager.GetWorkload(ctx, r.Config.BaseName)
	if err != nil && !errors.Is(err, rt.ErrWorkloadNotFound) {
		slog.Warn("failed to get current workload status", "error", err)
	}
	// Only set status to running if it's not already unauthenticated
	// This preserves the unauthenticated state when bearer token authentication fails
	if err == nil && currentWorkload.Status == rt.WorkloadStatusUnauthenticated {
		slog.Debug("preserving unauthenticated status for workload", "workload", r.Config.BaseName)
	} else {
		if err := r.statusManager.SetWorkloadStatus(ctx, r.Config.BaseName, rt.WorkloadStatusRunning, ""); err != nil {
			// If we can't set the status to `running` - treat it as a fatal error.
			return fmt.Errorf("failed to set workload status: %w", err)
		}
	}
	// Wait for either a signal or the done channel to be closed
	select {
	case <-ctx.Done():
		stopMCPServer("Context cancelled")
	case <-doneCh:
		// The transport has already been stopped (likely by the container exit)
		// Remove the old PID from the state file. Use PID-guarded reset to
		// avoid clobbering a replacement process's PID.
		// NOTE(review): unlike stopMCPServer, this path does not call r.Cleanup,
		// so middleware and telemetry are not released here. Confirm the
		// caller does this before restarting.
		if err := r.statusManager.ResetWorkloadPIDIfMatch(ctx, r.Config.BaseName, os.Getpid()); err != nil {
			slog.Warn("failed to reset workload PID", "workload", r.Config.BaseName, "error", err)
		}
		// Check if workload still exists (using status manager and runtime)
		// If it doesn't exist, it was removed - clean up client config
		// If it exists, it exited unexpectedly - signal restart needed
		exists, checkErr := r.doesWorkloadExist(ctx, r.Config.BaseName)
		if checkErr != nil {
			slog.Warn("failed to check if workload exists", "error", checkErr)
			// Assume restart needed if we can't check
		} else if !exists {
			// Workload doesn't exist in `thv ls` - it was removed
			slog.Debug("Workload no longer exists, removing from client configurations",
				"workload", r.Config.BaseName)
			// A client-manager creation failure is silently ignored here (best effort).
			clientManager, clientErr := client.NewManager(ctx)
			if clientErr == nil {
				removeErr := clientManager.RemoveServerFromClients(
					ctx,
					r.Config.ContainerName,
					r.Config.Group,
				)
				if removeErr != nil {
					slog.Warn("failed to remove from client config", "error", removeErr)
				} else {
					slog.Debug("Successfully removed from client configurations",
						"container", r.Config.ContainerName)
				}
			}
			slog.Debug("MCP server stopped and cleaned up", "container", r.Config.ContainerName)
			return nil // Exit gracefully, no restart
		}
		// Workload still exists - signal restart needed
		slog.Debug("MCP server stopped, restart needed", "container", r.Config.ContainerName)
		return ErrContainerExitedRestartNeeded
	}
	return nil
}
// doesWorkloadExist reports whether workloadName is still known to ToolHive.
//
// The verdict depends on the workload kind:
//   - Not found in the status manager: removed, so the result is (false, nil).
//   - Any other status-manager error: returned wrapped.
//   - Remote workloads: the status manager is trusted, and the workload exists
//     unless its status is WorkloadStatusError.
//   - Container workloads: the runtime is consulted as well. If no runtime
//     client can be created, the status-manager verdict is used instead.
//
// NOTE(review): any error from GetWorkloadInfo is treated as "container gone".
// A transient runtime failure therefore looks like a removal to the caller,
// which then cleans up client configs instead of restarting. Confirm this is
// intended.
func (r *Runner) doesWorkloadExist(ctx context.Context, workloadName string) (bool, error) {
	workload, err := r.statusManager.GetWorkload(ctx, workloadName)
	switch {
	case errors.Is(err, rt.ErrWorkloadNotFound):
		return false, nil
	case err != nil:
		return false, fmt.Errorf("failed to check if workload exists: %w", err)
	}

	// Status-manager-only verdict, used when the runtime is not (or cannot be) consulted.
	notErrored := workload.Status != rt.WorkloadStatusError

	if workload.Remote {
		// Remote workloads have no local container to look up.
		return notErrored, nil
	}

	backend, err := ct.NewFactory().Create(ctx)
	if err != nil {
		slog.Warn("Failed to create runtime to check container existence", "error", err)
		return notErrored, nil
	}

	// GetWorkloadInfo is expected to succeed for a container that exists,
	// whether running or stopped.
	if _, infoErr := backend.GetWorkloadInfo(ctx, workloadName); infoErr != nil {
		slog.Debug("Container not found in runtime", "workload", workloadName, "error", infoErr)
		return false, nil
	}
	return true, nil
}
// handleRemoteAuthentication authenticates to the remote MCP server at
// Config.RemoteURL and returns the resulting OAuth2 token source.
//
// It returns (nil, nil) when no RemoteAuthConfig is configured.
//
// When a secrets manager is available, it is used as the handler's source of
// cached secrets, and two persister callbacks are registered:
//   - Refresh tokens are stored in the secrets manager. Only the generated
//     secret name and the token expiry are written into RemoteAuthConfig.
//   - Dynamic Client Registration credentials are stored as follows: the
//     client ID goes directly into RemoteAuthConfig, and a non-empty client
//     secret goes into the secrets manager, referenced by name.
//
// Both callbacks save the run config after updating it. Without a secrets
// manager, authentication still proceeds but nothing is persisted. Errors from
// the authentication itself are wrapped and returned.
func (r *Runner) handleRemoteAuthentication(ctx context.Context) (oauth2.TokenSource, error) {
	if r.Config.RemoteAuthConfig == nil {
		return nil, nil
	}

	secretManager, err := authsecrets.GetSecretsManager()
	if err != nil {
		// Non-fatal: OAuth still works for this process, but tokens cannot
		// survive a restart.
		slog.Warn("Secret manager not available, OAuth tokens will not be persisted", "error", err)
	}

	authHandler := remote.NewHandler(r.Config.RemoteAuthConfig)

	if secretManager != nil {
		persistRefreshToken := func(refreshToken string, expiry time.Time) error {
			// Each workload gets its own uniquely named refresh-token secret.
			secretName, genErr := authsecrets.GenerateUniqueSecretNameWithPrefix(
				r.Config.Name,
				"OAUTH_REFRESH_TOKEN_",
				secretManager,
			)
			if genErr != nil {
				return fmt.Errorf("failed to generate secret name: %w", genErr)
			}
			if storeErr := authsecrets.StoreSecretInManagerWithProvider(ctx, secretName, refreshToken, secretManager); storeErr != nil {
				return fmt.Errorf("failed to store refresh token: %w", storeErr)
			}

			// Only the secret's name, never the token itself, goes into the config.
			r.Config.RemoteAuthConfig.CachedRefreshTokenRef = secretName
			r.Config.RemoteAuthConfig.CachedTokenExpiry = expiry
			if saveErr := r.Config.SaveState(ctx); saveErr != nil {
				return fmt.Errorf("failed to save config with token reference: %w", saveErr)
			}

			slog.Debug("Stored OAuth refresh token in secret manager", "secret_name", secretName)
			return nil
		}

		persistClientCredentials := func(clientID, clientSecret string) error {
			// The client ID is public information and is kept in the config as-is.
			r.Config.RemoteAuthConfig.CachedClientID = clientID

			// PKCE-only registrations may have no secret; store one only if present.
			if clientSecret != "" {
				secretName, genErr := authsecrets.GenerateUniqueSecretNameWithPrefix(
					r.Config.Name,
					"OAUTH_CLIENT_SECRET_",
					secretManager,
				)
				if genErr != nil {
					return fmt.Errorf("failed to generate client secret secret name: %w", genErr)
				}
				if storeErr := authsecrets.StoreSecretInManagerWithProvider(ctx, secretName, clientSecret, secretManager); storeErr != nil {
					return fmt.Errorf("failed to store client secret: %w", storeErr)
				}
				r.Config.RemoteAuthConfig.CachedClientSecretRef = secretName
			}

			if saveErr := r.Config.SaveState(ctx); saveErr != nil {
				return fmt.Errorf("failed to save config with client credentials: %w", saveErr)
			}

			slog.Debug("Stored DCR client credentials", "client_id", clientID)
			return nil
		}

		authHandler.SetSecretProvider(secretManager)
		authHandler.SetTokenPersister(persistRefreshToken)
		authHandler.SetClientCredentialsPersister(persistClientCredentials)
	}

	tokenSource, err := authHandler.Authenticate(ctx, r.Config.RemoteURL)
	if err != nil {
		return nil, fmt.Errorf("remote authentication failed: %w", err)
	}
	return tokenSource, nil
}
// Cleanup performs cleanup operations for the runner, including shutting down all middleware.
//
// Every step is attempted even if an earlier one fails, in this order:
//   - close each middleware;
//   - close the embedded auth server (if set);
//   - shut down the legacy telemetry provider (if set);
//   - cancel the background authentication-monitoring context (if set).
//
// Each failure is logged as a warning. The returned error combines every
// failure (each remains matchable with errors.Is / errors.As), or is nil when
// all steps succeed.
func (r *Runner) Cleanup(ctx context.Context) error {
	// Accumulate all failures rather than keeping just one, so no cleanup
	// error is silently dropped from the result.
	var cleanupErr error
	record := func(err error) {
		if cleanupErr == nil {
			cleanupErr = err
			return
		}
		cleanupErr = fmt.Errorf("%w; %w", cleanupErr, err)
	}

	// Clean up all middleware instances
	for i, middleware := range r.middlewares {
		if err := middleware.Close(); err != nil {
			slog.Warn("Failed to close middleware", "index", i, "error", err)
			record(err)
		}
	}

	// Close embedded auth server
	if r.embeddedAuthServer != nil {
		if err := r.embeddedAuthServer.Close(); err != nil {
			slog.Warn("Failed to close embedded auth server", "error", err)
			record(err)
		}
	}

	// Legacy telemetry provider cleanup (will be removed when telemetry middleware handles it)
	if r.telemetryProvider != nil {
		slog.Debug("Shutting down telemetry provider")
		if err := r.telemetryProvider.Shutdown(ctx); err != nil {
			slog.Warn("failed to shutdown telemetry provider", "error", err)
			record(err)
		}
	}

	// Stop background authentication monitoring for remote workloads.
	// Cancel the monitoring context to stop the background goroutine; clearing
	// the field makes a repeated Cleanup call a no-op for this step.
	if r.monitoringCancel != nil {
		r.monitoringCancel()
		r.monitoringCancel = nil
	}

	return cleanupErr
}
// waitForInitializeSuccess repeatedly checks if the MCP server is ready to accept requests.
// This prevents timing issues where clients try to connect before the server is fully ready.
//
// Attempts are retried with exponential backoff: the delay starts at 100ms, doubles
// after each attempt, and is capped at 2s. Retries stop when an attempt gets HTTP
// 200 OK or when maxWaitTime has elapsed.
//
// The check depends on transportType:
//   - "streamable-http" / "streamable": POST a JSON-RPC initialize request to serverURL.
//   - "sse": GET serverURL with any "#fragment" suffix removed.
//   - anything else: no HTTP check is performed and nil is returned immediately.
//
// Note: This function should not be called for STDIO transport.
//
// It returns nil once a 200 OK is received. It returns an error when maxWaitTime
// elapses first, or an error wrapping ctx.Err() if ctx is done while waiting
// between attempts.
func waitForInitializeSuccess(ctx context.Context, serverURL, transportType string, maxWaitTime time.Duration) error {
	// Determine the endpoint and method to use based on transport type
	var endpoint string
	var method string
	var payload string
	switch transportType {
	case "streamable-http", "streamable":
		// For streamable-http, send initialize request to /mcp endpoint
		// Format: http://localhost:port/mcp
		endpoint = serverURL
		method = "POST"
		payload = `{"jsonrpc":"2.0","method":"initialize","id":"toolhive-init-check",` +
			`"params":{"protocolVersion":"2024-11-05","capabilities":{},` +
			`"clientInfo":{"name":"toolhive","version":"1.0"}}}`
	case "sse":
		// For SSE, just check if the SSE endpoint is available.
		// We can't easily call initialize without establishing a full SSE connection,
		// so we just verify the endpoint responds.
		// Format: http://localhost:port/sse#container-name -> http://localhost:port/sse
		endpoint = serverURL
		// Remove fragment if present (everything after #)
		if idx := strings.Index(endpoint, "#"); idx != -1 {
			endpoint = endpoint[:idx]
		}
		method = "GET"
		payload = ""
	default:
		// For other transports, no HTTP check is needed
		slog.Debug("Skipping readiness check for transport type", "transport", transportType)
		return nil
	}

	// Setup retry logic with exponential backoff
	startTime := time.Now()
	attempt := 0
	delay := 100 * time.Millisecond
	maxDelay := 2 * time.Second // Cap at 2 seconds between retries

	slog.Info("Waiting for MCP server to be ready", "endpoint", endpoint, "timeout", maxWaitTime)

	// Create HTTP client with a reasonable timeout for requests
	httpClient := &http.Client{
		Timeout: 10 * time.Second,
	}

	for {
		attempt++

		// Make the readiness check request. A fresh body reader is needed for
		// every attempt because a request body can only be consumed once.
		var req *http.Request
		var err error
		if payload != "" {
			req, err = http.NewRequestWithContext(ctx, method, endpoint, bytes.NewBufferString(payload))
		} else {
			req, err = http.NewRequestWithContext(ctx, method, endpoint, nil)
		}

		if err != nil {
			slog.Debug("Failed to create request", "attempt", attempt, "error", err)
		} else {
			if method == "POST" {
				req.Header.Set("Content-Type", "application/json")
				req.Header.Set("Accept", "application/json, text/event-stream")
				req.Header.Set("MCP-Protocol-Version", "2024-11-05")
			}

			statusCode, doErr := doReadinessProbe(httpClient, req)
			switch {
			case doErr != nil:
				slog.Debug("Failed to reach endpoint", "attempt", attempt, "error", doErr)
			case statusCode == http.StatusOK:
				// Both GET (SSE) and POST (streamable-http) treat 200 OK as ready.
				elapsed := time.Since(startTime)
				slog.Debug("MCP server is ready", "elapsed", elapsed, "attempt", attempt)
				return nil
			default:
				slog.Debug("Server returned status", //nolint:gosec // G706: status code and attempt are integers
					"status_code", statusCode, "attempt", attempt)
			}
		}

		// Check if we've exceeded the maximum wait time
		elapsed := time.Since(startTime)
		if elapsed >= maxWaitTime {
			return fmt.Errorf("initialize not successful after %v (%d attempts)", elapsed, attempt)
		}

		// Wait before retrying
		select {
		case <-ctx.Done():
			// Wrap ctx.Err() so callers can distinguish cancellation from deadline expiry.
			return fmt.Errorf("context cancelled while waiting for initialize: %w", ctx.Err())
		case <-time.After(delay):
			// Continue to next attempt
		}

		// Update delay for next iteration with exponential backoff
		delay *= 2
		if delay > maxDelay {
			delay = maxDelay
		}
	}
}

// doReadinessProbe sends req with client and returns the response status code.
// The response body is closed before this function returns. This keeps a non-ready
// response from holding its connection open while the caller waits to retry.
// The body is intentionally not read, since an SSE endpoint may keep the stream
// open and reading it could block.
func doReadinessProbe(client *http.Client, req *http.Request) (int, error) {
	resp, err := client.Do(req) // #nosec G704 -- endpoint is the local MCP server readiness URL
	if err != nil {
		return 0, err
	}
	//nolint:errcheck // Close error is irrelevant; only the status code is used
	defer resp.Body.Close()
	return resp.StatusCode, nil
}