diff --git a/cmd/internal/flags.go b/cmd/internal/flags.go index 1c13b23d37ba..9c222a929d71 100644 --- a/cmd/internal/flags.go +++ b/cmd/internal/flags.go @@ -64,7 +64,7 @@ func ServeFlags(flags *pflag.FlagSet, opts *ToolboxOptions) { flags.BoolVar(&opts.Cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.") flags.BoolVar(&opts.Cfg.UI, "ui", false, "Launches the Toolbox UI web server.") flags.BoolVar(&opts.Cfg.EnableAPI, "enable-api", false, "Enable the /api endpoint.") - + flags.StringVar(&opts.Cfg.ToolboxUrl, "toolbox-url", "", "Specifies the Toolbox URL. Used as the resource field in the MCP PRM file when MCP Auth is enabled. Falls back to TOOLBOX_URL environment variable.") flags.StringSliceVar(&opts.Cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.") flags.StringSliceVar(&opts.Cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.") } diff --git a/cmd/root.go b/cmd/root.go index 04ba8199bee5..85f768cd94f3 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -37,6 +37,7 @@ import ( "github.com/googleapis/genai-toolbox/cmd/internal/serve" "github.com/googleapis/genai-toolbox/cmd/internal/skills" "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/auth/generic" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server" @@ -450,6 +451,21 @@ func run(cmd *cobra.Command, opts *internal.ToolboxOptions) error { return err } + // Validate ToolboxUrl if MCP Auth is enabled + for _, authSvc := range opts.Cfg.AuthServiceConfigs { + if genCfg, ok := authSvc.(generic.Config); ok && genCfg.McpEnabled { + if opts.Cfg.ToolboxUrl == "" { + opts.Cfg.ToolboxUrl = os.Getenv("TOOLBOX_URL") + } + if opts.Cfg.ToolboxUrl == "" { + errMsg := fmt.Errorf("MCP Auth is enabled but Toolbox URL is missing. Please provide it via --toolbox-url flag or TOOLBOX_URL environment variable") + opts.Logger.ErrorContext(ctx, errMsg.Error()) + return errMsg + } + break + } + } + // start server s, err := server.NewServer(ctx, opts.Cfg) if err != nil { diff --git a/docs/en/documentation/configuration/toolbox_mcp_auth.md b/docs/en/documentation/configuration/toolbox_mcp_auth.md new file mode 100644 index 000000000000..a6653a25c807 --- /dev/null +++ b/docs/en/documentation/configuration/toolbox_mcp_auth.md @@ -0,0 +1,121 @@ +--- +title: "Toolbox with MCP Authorization" +type: docs +weight: 4 +description: > + How to set up and configure Toolbox with [MCP Authorization](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization). +--- + +## Overview + +Toolbox supports integration with Model Context Protocol (MCP) clients by acting as a Resource Server that implements OAuth 2.1 authorization. This enables Toolbox to validate JWT-based Bearer tokens before processing requests for resources or tool executions. + +This guide details the specific configuration steps required to deploy Toolbox with MCP Auth enabled. + +## Step 1: Configure the `generic` Auth Service + +Update your `tools.yaml` file to use a `generic` authorization service with `mcpEnabled` set to `true`. This instructs Toolbox to intercept requests on the `/mcp` routes and validate Bearer tokens using the JWKS (JSON Web Key Set) fetched from your OIDC provider endpoint (`authorizationServer`). + +```yaml +kind: authServices +name: my-mcp-auth +type: generic +mcpEnabled: true +authorizationServer: "https://accounts.google.com" # Your authorization server URL +audience: "your-mcp-audience" # Matches the `aud` claim in the JWT +scopesRequired: + - "mcp:tools" +``` + +When `mcpEnabled` is true, Toolbox also provisions the `/.well-known/oauth-protected-resource` Protected Resource Metadata (PRM) endpoint automatically using the `authorizationServer`. + +## Step 2: Deployment + +Deploying Toolbox with MCP auth requires defining the `TOOLBOX_URL` that the deployed service will use, as this URL must be included as the `resource` field in the PRM returned to the client. + +You can set this either through the `TOOLBOX_URL` environment variable or the `--toolbox-url` command-line flag during deployment. + +### Local Deployment + +To run Toolbox locally with MCP auth enabled, simply export the `TOOLBOX_URL` referencing your local port before running the binary: + +```bash +export TOOLBOX_URL="http://127.0.0.1:5000" +./toolbox --tools-file tools.yaml +``` + +If you prefer to use the `--toolbox-url` flag explicitly: + +```bash +./toolbox --tools-file tools.yaml --toolbox-url "http://127.0.0.1:5000" +``` + +### Cloud Run Deployment + +```bash +export IMAGE="us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest" + +# Pass your target Cloud Run URL to the `--toolbox-url` flag +gcloud run deploy toolbox \ + --image $IMAGE \ + --service-account toolbox-identity \ + --region us-central1 \ + --set-secrets "/app/tools.yaml=tools:latest" \ + --args="--tools-file=/app/tools.yaml","--address=0.0.0.0","--port=8080","--toolbox-url=${CLOUD_RUN_TOOLBOX_URL}" +``` + +### Alternative: Manual PRM File Override + +If you strictly need to define your own Protected Resource Metadata instead of auto-generating it from the `AuthService` config, you can use the `--mcp-prm-file ` flag. + +1. Create a `prm.json` containing the RFC-9207 compliant metadata. Note that the `resource` field must match the `TOOLBOX_URL`: + ```json + { + "resource": "https://toolbox-service-123456789-uc.a.run.app", + "authorization_servers": ["https://your-auth-server.example.com"], + "scopes_supported": ["mcp:tools"], + "bearer_methods_supported": ["header"] + } + ``` +2. Set the `--mcp-prm-file` flag to the path of the PRM file. + + - If you are using local deployment, you can just provide the path to the file directly: + ```bash + ./toolbox --tools-file tools.yaml --mcp-prm-file prm.json + ``` + - If you are using Cloud Run, upload it to GCP Secret Manager and Attach the secret to the Cloud Run deployment and provide the flag. + ```bash + gcloud secrets create prm_file --data-file=prm.json + + gcloud run deploy toolbox \ + # ... previous args + --set-secrets "/app/tools.yaml=tools:latest,/app/prm.json=prm_file:latest" \ + --args="--tools-file=/app/tools.yaml","--mcp-prm-file=/app/prm.json","--port=8080" + ``` + +## Step 3: Connecting to the Secure MCP Endpoint + +Once the Cloud Run instance is deployed, your MCP client must obtain a valid JWT token from your authorization server (the `authorizationServer` in `tools.yaml`). + +The client should provide this JWT via the standard HTTP `Authorization` header when connecting to the Streamable HTTP or SSE endpoint (`/mcp`): + +```bash +{ + "mcpServers": { + "toolbox-secure": { + "type": "http", + "url": "https://toolbox-service-123456789-uc.a.run.app/mcp", + "headers": { + "Authorization": "Bearer " + } + } + } +} +``` +Important: The token provided in the Authorization header must be a JWT token (issued by the auth server you configured previously), not a Google Cloud Run access token. + +Toolbox will intercept incoming connections, fetch the latest JWKS from your authorizationServer, and validate that the aud (audience), signature, and scopes on the JWT match the requirements defined by your mcpEnabled auth service. + +If your Cloud Run service also requires IAM authentication, you must pass the Cloud Run identity token using [Cloud Run's alternate auth header][cloud-run-alternate-auth-header] to avoid conflicting with Toolbox's internal authentication. + +[cloud-run-alternate-auth-header]: https://docs.cloud.google.com/run/docs/authenticating/service-to-service#acquire-token diff --git a/internal/auth/generic/generic_test.go b/internal/auth/generic/generic_test.go index 9a4f91f87a2b..0238eb1b6c6e 100644 --- a/internal/auth/generic/generic_test.go +++ b/internal/auth/generic/generic_test.go @@ -151,7 +151,6 @@ func TestGetClaimsFromHeader(t *testing.T) { } }, }, - { name: "wrong audience", setupHeader: func() http.Header { @@ -167,7 +166,6 @@ func TestGetClaimsFromHeader(t *testing.T) { wantError: true, errContains: "audience validation failed", }, - { name: "expired token", setupHeader: func() http.Header { diff --git a/internal/server/config.go b/internal/server/config.go index 77ac088e2d5f..402e9a68c940 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -72,6 +72,8 @@ type ServerConfig struct { UI bool // EnableAPI indicates if the /api endpoint is enabled. EnableAPI bool + // ToolboxUrl specifies the URL to advertise in the MCP PRM file as the resource field. + ToolboxUrl string // Specifies a list of origins permitted to access this server. AllowedOrigins []string // Specifies a list of hosts permitted to access this server. @@ -80,8 +82,6 @@ type ServerConfig struct { UserAgentMetadata []string // PollInterval sets the polling frequency for configuration file updates. PollInterval int - // ToolboxUrl specifies the Toolbox URL. Used as the resource field in the MCP PRM file when MCP Auth is enabled. - ToolboxUrl string } type logFormat string @@ -258,7 +258,6 @@ func UnmarshalYAMLAuthServiceConfig(ctx context.Context, name string, r map[stri if !ok { return nil, fmt.Errorf("missing 'type' field or it is not a string") } - dec, err := util.NewStrictDecoder(r) if err != nil { return nil, fmt.Errorf("error creating decoder: %s", err) diff --git a/internal/server/mcp.go b/internal/server/mcp.go index 17f359fee02c..22f77a8bf7f4 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -30,6 +30,7 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" "github.com/google/uuid" + "github.com/googleapis/genai-toolbox/internal/auth/generic" "github.com/googleapis/genai-toolbox/internal/server/mcp" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" mcputil "github.com/googleapis/genai-toolbox/internal/server/mcp/util" @@ -760,3 +761,41 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers return "", result, err } } + +type prmResponse struct { + Resource string `json:"resource"` + AuthorizationServers []string `json:"authorization_servers"` + ScopesSupported []string `json:"scopes_supported,omitempty"` + BearerMethodsSupported []string `json:"bearer_methods_supported"` +} + +// prmHandler generates the Protected Resource Metadata (PRM) file for MCP Authorization. +func prmHandler(s *Server, w http.ResponseWriter, r *http.Request) { + var server string + scopes := []string{} + for _, authSvc := range s.ResourceMgr.GetAuthServiceMap() { + cfg := authSvc.ToConfig() + if genCfg, ok := cfg.(generic.Config); ok { + if genCfg.McpEnabled { + server = genCfg.AuthorizationServer + if genCfg.ScopesRequired != nil { + scopes = genCfg.ScopesRequired + } + break + } + } + } + + res := prmResponse{ + Resource: s.toolboxUrl, + AuthorizationServers: []string{server}, + ScopesSupported: scopes, + BearerMethodsSupported: []string{"header"}, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(res); err != nil { + s.logger.ErrorContext(r.Context(), fmt.Sprintf("Failed to encode PRM response: %v", err)) + http.Error(w, "Failed to encode PRM response", http.StatusInternalServerError) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 53a0523e812d..48f815fb04ba 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -381,6 +381,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { instrumentation: instrumentation, sseManager: sseManager, ResourceMgr: resourceManager, + toolboxUrl: cfg.ToolboxUrl, } // cors @@ -410,6 +411,20 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { } r.Use(hostCheck(allowedHostsMap)) + // Host OAuth Protected Resource Metadata endpoint + mcpAuthEnabled := false + for _, authSvc := range s.ResourceMgr.GetAuthServiceMap() { + if genCfg, ok := authSvc.ToConfig().(generic.Config); ok && genCfg.McpEnabled { + mcpAuthEnabled = true + break + } + } + if mcpAuthEnabled { + r.Get("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, req *http.Request) { + prmHandler(s, w, req) + }) + } + // control plane mcpR, err := mcpRouter(s) if err != nil { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index ab809fc579d6..7ca27ebdbccf 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -16,15 +16,19 @@ package server_test import ( "context" + "encoding/json" "fmt" "io" "net/http" + "net/http/httptest" "os" + "reflect" "strings" "testing" "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/auth/generic" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/prompts" @@ -259,3 +263,124 @@ func TestNameValidation(t *testing.T) { }) } } + +func TestPRMEndpoint(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup telemetry and logging + otelShutdown, err := telemetry.SetupOTel(ctx, "0.0.0", "", false, "toolbox") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer func() { + if err := otelShutdown(ctx); err != nil { + t.Fatalf("unexpected error shutting down otel: %s", err) + } + }() + + testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + ctx = util.WithLogger(ctx, testLogger) + + instrumentation, err := telemetry.CreateTelemetryInstrumentation("0.0.0") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + ctx = util.WithInstrumentation(ctx, instrumentation) + + // Create a mock OIDC server to bypass JWKS discovery during init + mockOIDC := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"issuer": "http://%s", "jwks_uri": "http://%s/jwks"}`, r.Host, r.Host) + return + } + if r.URL.Path == "/jwks" { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"keys": []}`) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer mockOIDC.Close() + + // Configure the server + addr, port := "127.0.0.1", 5001 + cfg := server.ServerConfig{ + Version: "0.0.0", + Address: addr, + Port: port, + ToolboxUrl: "https://my-toolbox.example.com", + AllowedHosts: []string{"*"}, + AuthServiceConfigs: map[string]auth.AuthServiceConfig{ + "generic1": generic.Config{ + Name: "generic1", + Type: generic.AuthServiceType, + McpEnabled: true, + AuthorizationServer: mockOIDC.URL, // Injecting the mock server URL here + ScopesRequired: []string{"read", "write"}, + }, + }, + } + + // Initialize and start the server + s, err := server.NewServer(ctx, cfg) + if err != nil { + t.Fatalf("unable to initialize server: %v", err) + } + + if err := s.Listen(ctx); err != nil { + t.Fatalf("unable to start server: %v", err) + } + + errCh := make(chan error) + go func() { + defer close(errCh) + if err := s.Serve(ctx); err != nil && err != http.ErrServerClosed { + errCh <- err + } + }() + defer func() { + if err := s.Shutdown(ctx); err != nil { + t.Errorf("failed to cleanly shutdown server: %v", err) + } + }() + + // Test the PRM endpoint + url := fmt.Sprintf("http://%s:%d/.well-known/oauth-protected-resource", addr, port) + resp, err := http.Get(url) + if err != nil { + t.Fatalf("error when sending a request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unexpected error reading body: %s", err) + } + + var got map[string]any + if err := json.Unmarshal(body, &got); err != nil { + t.Fatalf("unexpected error unmarshalling body: %s", err) + } + + want := map[string]any{ + "resource": "https://my-toolbox.example.com", + "authorization_servers": []any{ + mockOIDC.URL, + }, + "scopes_supported": []any{"read", "write"}, + "bearer_methods_supported": []any{"header"}, + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("unexpected PRM response:\ngot %+v\nwant %+v", got, want) + } +}