-
Notifications
You must be signed in to change notification settings - Fork 65
Expand file tree
/
Copy pathmain.go
More file actions
408 lines (357 loc) · 13.6 KB
/
main.go
File metadata and controls
408 lines (357 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
// main implements the CLI for the MCP broker/router.
package main
import (
"context"
"flag"
"fmt"
"log"
"log/slog"
"net"
"net/http"
"os"
"os/signal"
"sync"
"time"
mcpv1alpha1 "github.com/Kuadrant/mcp-gateway/api/v1alpha1"
"github.com/Kuadrant/mcp-gateway/internal/broker"
"github.com/Kuadrant/mcp-gateway/internal/clients"
config "github.com/Kuadrant/mcp-gateway/internal/config"
"github.com/Kuadrant/mcp-gateway/internal/idmap"
mcpRouter "github.com/Kuadrant/mcp-gateway/internal/mcp-router"
mcpotel "github.com/Kuadrant/mcp-gateway/internal/otel"
"github.com/Kuadrant/mcp-gateway/internal/session"
goenv "github.com/caitlinelfring/go-env-default"
extProcV3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"github.com/fsnotify/fsnotify"
"github.com/mark3labs/mcp-go/server"
redis "github.com/redis/go-redis/v9"
"github.com/spf13/viper"
"google.golang.org/grpc"
"k8s.io/apimachinery/pkg/runtime"
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
gatewayv1 "sigs.k8s.io/gateway-api/apis/v1"
)
var (
version = "dev"
gitSHA = "unknown"
dirty = ""
)
var (
mcpConfig = &config.MCPServersConfig{}
mutex sync.RWMutex
logger = slog.New(slog.NewTextHandler(os.Stdout, nil))
scheme = runtime.NewScheme()
)
func init() {
_ = clientgoscheme.AddToScheme(scheme)
_ = mcpv1alpha1.AddToScheme(scheme)
_ = gatewayv1.Install(scheme)
}
var (
mcpRouterAddrFlag string
mcpBrokerAddrFlag string
mcpRoutePublicHost string
mcpRoutePrivateHost string
mcpRouterKey string
cacheConnectionStringFlag string
mcpConfigFile string
jwtSigningKeyFlag string
sessionDurationInMins int64
brokerWriteTimeoutSecs int64
managerTickerIntervalSecs int64
loglevel int
logFormat string
enforceToolFilteringFlag bool
invalidToolPolicyFlag string
discoveryToolThresholdFlag int
)
func main() {
flag.StringVar(
&mcpRouterAddrFlag,
"mcp-router-address",
"0.0.0.0:50051",
"The address for MCP router",
)
flag.StringVar(
&mcpBrokerAddrFlag,
"mcp-broker-public-address",
"0.0.0.0:8080",
"The public address for MCP broker",
)
flag.StringVar(
&mcpRoutePublicHost,
"mcp-gateway-public-host",
"",
"The public host the MCP Gateway is exposing MCP servers on. The gateway router will always set the :authority header to this value to ensure the broker component cannot be bypassed.",
)
flag.StringVar(
&mcpRoutePrivateHost,
"mcp-gateway-private-host",
"mcp-gateway-istio.gateway-system.svc.cluster.local:8080",
"The private host the MCP Gateway. The gateway router will use this to hairpin request to initialize MCP servers etc.",
)
// TODO ick not sure how to describe this
flag.StringVar(
&mcpRouterKey,
"mcp-router-key",
goenv.GetDefault("MCP_ROUTER_API_KEY", "secret-api-key"),
"this key is used to allow the router to send request through the gateway and be trusted by the router",
)
flag.StringVar(
&mcpConfigFile,
"mcp-gateway-config",
"./config/mcp-system/config.yaml",
"where to locate the mcp server config",
)
flag.IntVar(
&loglevel,
"log-level",
int(slog.LevelInfo),
"set the log level 0=info, 4=warn , 8=error and -4=debug",
)
flag.StringVar(&jwtSigningKeyFlag,
"session-signing-key",
goenv.GetDefault("JWT_SESSION_SIGNING_KEY", ""),
"JWT signing key for session tokens (env: JWT_SESSION_SIGNING_KEY)",
)
//"redis://redis.mcp-system.svc.cluster.local:6379
flag.StringVar(&cacheConnectionStringFlag,
"cache-connection-string",
goenv.GetDefault("CACHE_CONNECTION_STRING", ""),
"redis based cache connection string redis://<user>:<pass>@localhost:6379/<db> (env: CACHE_CONNECTION_STRING). If not set defaults to in memory storage",
)
flag.StringVar(&logFormat, "log-format", "txt", "switch to json logs with --log-format=json")
flag.Int64Var(&sessionDurationInMins, "session-length", 60*24, "default session length with the gateway in minutes. Default 24h")
flag.Int64Var(&brokerWriteTimeoutSecs, "mcp-broker-write-timeout", 0, "HTTP write timeout in seconds for the broker. Default 0 (disabled) for SSE notification support. Set > 0 to enable timeout.")
flag.Int64Var(&managerTickerIntervalSecs, "mcp-check-interval", 60, "interval in seconds for MCP manager backend health checks. Default 60 seconds.")
flag.BoolVar(&enforceToolFilteringFlag, "enforce-tool-filtering", false, "when enabled an x-authorized-tools header will be needed to return any tools")
flag.StringVar(&invalidToolPolicyFlag, "invalid-tool-policy", "FilterOut", "policy for upstream tools with invalid schemas: FilterOut (default) or RejectServer")
flag.IntVar(&discoveryToolThresholdFlag, "discovery-tool-threshold", 10, "when total registered tools exceeds this threshold, new sessions only see discovery meta-tools until select_tools is called. Set 0 to always require discovery.")
flag.Parse()
loggerOpts := &slog.HandlerOptions{}
switch loglevel {
case 0:
loggerOpts.Level = slog.LevelInfo
case 8:
loggerOpts.Level = slog.LevelError
case -4:
loggerOpts.Level = slog.LevelDebug
default:
loggerOpts.Level = slog.LevelDebug
}
jsonFormat := logFormat == "json"
logger = mcpotel.NewTracingLogger(os.Stdout, loggerOpts, jsonFormat, nil)
ctx := context.Background()
otelShutdown, loggerProvider, err := mcpotel.SetupOTelSDK(ctx, gitSHA, dirty, version, logger)
if err != nil {
logger.Error("failed to setup OpenTelemetry", "error", err)
}
if loggerProvider != nil {
logger = mcpotel.NewTracingLogger(os.Stdout, loggerOpts, jsonFormat, loggerProvider)
logger.Info("Logger upgraded with OTLP export")
}
var redisClient *redis.Client
if cacheConnectionStringFlag != "" {
logger.Info("cache using external redis store")
redisOpt, err := redis.ParseURL(cacheConnectionStringFlag)
if err != nil {
panic("failed to parse redis connection string: " + err.Error())
}
redisClient = redis.NewClient(redisOpt)
if err := redisClient.Ping(ctx).Err(); err != nil {
panic("failed to connect to redis: " + err.Error())
}
}
sessionCache, err := session.NewCache(session.WithRedisClient(redisClient))
if err != nil {
panic("failed to setup session cache: " + err.Error())
}
var jwtSessionMgr *session.JWTManager
if jwtSigningKeyFlag == "" {
panic("JWT_SESSION_SIGNING_KEY is required but not set. " +
"When running via the controller, this is managed automatically. " +
"For standalone use, set the JWT_SESSION_SIGNING_KEY environment variable.")
}
jwtmgr, err := session.NewJWTManager(jwtSigningKeyFlag, sessionDurationInMins, logger, sessionCache)
if err != nil {
panic("failed to setup jwt manager " + err.Error())
}
jwtSessionMgr = jwtmgr
sessionTTL := time.Duration(sessionDurationInMins) * time.Minute
elicitationMap, err := idmap.New(idmap.WithRedisClient(redisClient), idmap.WithEntryTTL(sessionTTL))
if err != nil {
panic("failed to setup elicitation map: " + err.Error())
}
invalidToolPolicy := mcpv1alpha1.InvalidToolPolicy(invalidToolPolicyFlag)
if invalidToolPolicy != mcpv1alpha1.InvalidToolPolicyFilterOut && invalidToolPolicy != mcpv1alpha1.InvalidToolPolicyRejectServer {
panic("--invalid-tool-policy must be FilterOut or RejectServer")
}
managerTickerInterval := time.Duration(managerTickerIntervalSecs) * time.Second
if managerTickerInterval <= 0 {
panic("flag mcp-check-interval cannot be 0 or less seconds")
}
mcpBroker := broker.NewBroker(logger.With("component", "broker"),
broker.WithEnforceToolFilter(enforceToolFilteringFlag),
broker.WithTrustedHeadersPublicKey(os.Getenv("TRUSTED_HEADER_PUBLIC_KEY")),
broker.WithManagerTickerInterval(managerTickerInterval),
broker.WithInvalidToolPolicy(invalidToolPolicy),
broker.WithDiscoveryToolThreshold(discoveryToolThresholdFlag),
)
brokerServer, mcpServer := setUpHTTPServer(mcpBrokerAddrFlag, mcpBroker, jwtSessionMgr, brokerWriteTimeoutSecs)
routerGRPCServer, router := setUpRouter(mcpBroker, logger, jwtSessionMgr, sessionCache, elicitationMap)
mcpConfig.RegisterObserver(router)
mcpConfig.RegisterObserver(mcpBroker)
if mcpRoutePublicHost == "" {
panic("--mcp-gateway-public-host cannot be empty. The mcp gateway needs to be informed of what public host to expect requests from so it can ensure routing and session mgmt happens. Set --mcp-gateway-public-host")
}
mcpConfig.MCPGatewayExternalHostname = mcpRoutePublicHost
mcpConfig.MCPGatewayInternalHostname = mcpRoutePrivateHost
mcpConfig.RouterAPIKey = mcpRouterKey
// Only load config and run broker/router in standalone mode
mutex.Lock()
// will panic if fails
LoadConfig(mcpConfigFile)
mutex.Unlock()
mcpConfig.Notify(ctx)
viper.WatchConfig()
// set up our change event handler
viper.OnConfigChange(func(in fsnotify.Event) {
logger.Info("OnConfigChange mcp servers config changed ", "config file", in.Name)
mutex.Lock()
defer mutex.Unlock()
LoadConfig(mcpConfigFile)
logger.Info("OnConfigChange: notifying observers of config change")
mcpConfig.Notify(ctx)
})
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt)
grpcAddr := mcpRouterAddrFlag
lc := net.ListenConfig{}
lis, err := lc.Listen(ctx, "tcp", grpcAddr)
if err != nil {
log.Fatalf("[grpc] listen error: %v", err)
}
go func() {
logger.Info("[grpc] starting MCP Router", "listening", grpcAddr)
log.Fatal(routerGRPCServer.Serve(lis))
}()
go func() {
logger.Info("[http] starting MCP Broker (public)", "listening", brokerServer.Addr)
if err := brokerServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("[http] Cannot start public broker: %v", err)
}
}()
<-stop
// handle shutdown
logger.Info("shutting down MCP Broker and MCP Router")
shutdownCtx, shutdownRelease := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownRelease()
if err := otelShutdown(shutdownCtx); err != nil {
logger.Error("OpenTelemetry shutdown error", "error", err)
}
if err := brokerServer.Shutdown(shutdownCtx); err != nil {
log.Fatalf("HTTP shutdown error: %v", err)
}
if err := mcpServer.Shutdown(shutdownCtx); err != nil {
log.Printf("MCP shutdown error: %v; ignoring", err)
}
routerGRPCServer.GracefulStop()
if redisClient != nil {
if err := redisClient.Close(); err != nil {
logger.Error("redis close error", "error", err)
}
}
}
func setUpHTTPServer(address string, mcpBroker broker.MCPBroker, sessionManager *session.JWTManager, writeTimeoutSecs int64) (*http.Server, *server.StreamableHTTPServer) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
_, _ = fmt.Fprint(w, "Hello, World! BTW, the MCP server is on /mcp")
})
oauthHandler := broker.ProtectedResourceHandler{Logger: logger}
mux.HandleFunc("/.well-known/oauth-protected-resource", oauthHandler.Handle)
// WriteTimeout of 0 (disabled) is important for SSE connections (GET /mcp).
// SSE streams notifications indefinitely - any write timeout would kill the connection.
writeTimeout := time.Duration(writeTimeoutSecs) * time.Second
httpSrv := &http.Server{
Addr: address,
Handler: mux,
ReadTimeout: 5 * time.Second,
WriteTimeout: writeTimeout,
}
streamableHTTPOpts := []server.StreamableHTTPOption{
server.WithStreamableHTTPServer(httpSrv),
}
if sessionManager != nil {
logger.Info("jwt session manager configured")
streamableHTTPOpts = append(streamableHTTPOpts, server.WithSessionIdManager(sessionManager))
}
streamableHTTPServer := server.NewStreamableHTTPServer(mcpBroker.MCPServer(), streamableHTTPOpts...)
// Allow direct connections with MCP Inspector
mux.HandleFunc("OPTIONS /mcp", func(w http.ResponseWriter, r *http.Request) {
logger.Debug("Handling OPTIONS", "Mcp-Session-Id", r.Header.Get("Mcp-Session-Id"))
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/status", mcpBroker.HandleStatusRequest)
mux.HandleFunc("/status/", mcpBroker.HandleStatusRequest)
mux.Handle("/mcp", streamableHTTPServer)
return httpSrv, streamableHTTPServer
}
func setUpRouter(broker broker.MCPBroker, logger *slog.Logger, jwtManager *session.JWTManager, sessionCache *session.Cache, elicitationMap idmap.Map) (*grpc.Server, *mcpRouter.ExtProcServer) {
grpcSrv := grpc.NewServer()
server := &mcpRouter.ExtProcServer{
RoutingConfig: mcpConfig,
Logger: logger.With("component", "router"),
JWTManager: jwtManager,
InitForClient: clients.Initialize,
SessionCache: sessionCache,
ElicitationMap: elicitationMap,
Broker: broker, // TODO we shouldn't need a handle to broker in the router
}
extProcV3.RegisterExternalProcessorServer(grpcSrv, server)
return grpcSrv, server
}
// config
func LoadConfig(path string) {
viper.SetConfigFile(path)
logger.Debug("loading config", "path", viper.ConfigFileUsed())
err := viper.ReadInConfig()
if err != nil {
log.Fatalf("Error reading config file: %s", err)
}
// reset the servers to avoid old configs being written to
mcpConfig.Servers = []*config.MCPServer{}
err = viper.UnmarshalKey("servers", &mcpConfig.Servers)
if err != nil {
log.Fatalf("Unable to decode server config into struct: %s", err)
}
mcpConfig.VirtualServers = []*config.VirtualServer{}
// Load virtualServers if present - this is optional
if viper.IsSet("virtualServers") {
err = viper.UnmarshalKey("virtualServers", &mcpConfig.VirtualServers)
if err != nil {
log.Fatal("Failed to parse virtualServers configuration", "error", err)
}
} else {
logger.Debug("No virtualServers section found in configuration")
}
logger.Debug("config successfully loaded", "# servers", len(mcpConfig.Servers))
for _, s := range mcpConfig.Servers {
logger.Debug(
"server config",
"server name",
s.Name,
"server prefix",
s.ToolPrefix,
"enabled",
s.Enabled,
"backend url",
s.URL,
"routable host",
s.Hostname,
"category",
s.Category,
"hint",
s.Hint,
)
}
}