diff --git a/cmd/internal/config.go b/cmd/internal/config.go index 2d28fe685dbf..2e6c3e4162e7 100644 --- a/cmd/internal/config.go +++ b/cmd/internal/config.go @@ -27,6 +27,7 @@ import ( "github.com/goccy/go-yaml" "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/auth/generic" "github.com/googleapis/genai-toolbox/internal/server" ) @@ -309,6 +310,18 @@ func mergeConfigs(files ...Config) (Config, error) { return Config{}, fmt.Errorf("resource conflicts detected:\n - %s\n\nPlease ensure each source, authService, tool, toolset and prompt has a unique name across all files", strings.Join(conflicts, "\n - ")) } + // Ensure only one authService has mcpEnabled = true + var mcpEnabledAuthServers []string + for name, authService := range merged.AuthServices { + // Only generic type has McpEnabled right now + if genericService, ok := authService.(generic.Config); ok && genericService.McpEnabled { + mcpEnabledAuthServers = append(mcpEnabledAuthServers, name) + } + } + if len(mcpEnabledAuthServers) > 1 { + return Config{}, fmt.Errorf("multiple authServices with mcpEnabled=true detected: %s. Only one MCP authorization server is currently supported", strings.Join(mcpEnabledAuthServers, ", ")) + } + return merged, nil } diff --git a/cmd/internal/config_test.go b/cmd/internal/config_test.go index fa977f819cae..9e0d42785412 100644 --- a/cmd/internal/config_test.go +++ b/cmd/internal/config_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/auth/generic" "github.com/googleapis/genai-toolbox/internal/auth/google" "github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini" "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" @@ -616,38 +617,48 @@ func TestParseConfig(t *testing.T) { type: google clientId: testing-id --- - kind: embeddingModel - name: gemini-model - type: gemini - model: gemini-embedding-001 - apiKey: some-key - dimension: 768 + kind: authService + name: my-generic-auth + type: generic + audience: testings + authorizationServer: https://testings + mcpEnabled: true + scopesRequired: + - read:files + - write:files --- - kind: tool - name: example_tool - type: postgres-sql - source: my-pg-instance - description: some description - statement: | - SELECT * FROM SQL_STATEMENT; - parameters: - - name: country - type: string - description: some description + kind: embeddingModel + name: gemini-model + type: gemini + model: gemini-embedding-001 + apiKey: some-key + dimension: 768 --- - kind: toolset - name: example_toolset - tools: - - example_tool + kind: tool + name: example_tool + type: postgres-sql + source: my-pg-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: country + type: string + description: some description --- - kind: prompt - name: code_review - description: ask llm to analyze code quality - messages: - - content: "please review the following code for quality: {{.code}}" - arguments: - - name: code - description: the code to review + kind: toolset + name: example_toolset + tools: + - example_tool +--- + kind: prompt + name: code_review + description: ask llm to analyze code quality + messages: + - content: "please review the following code for quality: {{.code}}" + arguments: + - name: code + description: the code to review `, wantConfig: Config{ Sources: server.SourceConfigs{ @@ -669,6 +680,14 @@ func TestParseConfig(t *testing.T) { Type: google.AuthServiceType, ClientID: "testing-id", }, + "my-generic-auth": generic.Config{ + Name: "my-generic-auth", + Type: generic.AuthServiceType, + Audience: "testings", + McpEnabled: true, + AuthorizationServer: "https://testings", + ScopesRequired: []string{"read:files", "write:files"}, + }, }, EmbeddingModels: server.EmbeddingModelConfigs{ "gemini-model": gemini.Config{ @@ -2029,12 +2048,19 @@ func TestMergeConfigs(t *testing.T) { Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}}, } + fileMcp1 := Config{ + AuthServices: server.AuthServiceConfigs{"generic1": generic.Config{Name: "generic1", McpEnabled: true}}, + } + fileMcp2 := Config{ + AuthServices: server.AuthServiceConfigs{"generic2": generic.Config{Name: "generic2", McpEnabled: true}}, + } testCases := []struct { - name string - files []Config - want Config - wantErr bool + name string + files []Config + want Config + wantErr bool + errString string }{ { name: "merge two distinct files", @@ -2054,6 +2080,12 @@ func TestMergeConfigs(t *testing.T) { files: []Config{file1, file2, fileWithConflicts}, wantErr: true, }, + { + name: "merge multiple mcp enabled generic", + files: []Config{fileMcp1, fileMcp2}, + wantErr: true, + errString: "multiple authServices with mcpEnabled=true detected", + }, { name: "merge single file", files: []Config{file1}, @@ -2094,7 +2126,9 @@ func TestMergeConfigs(t *testing.T) { if err == nil { t.Fatal("expected an error for conflicting files but got none") } - if !strings.Contains(err.Error(), "resource conflicts detected") { + if tc.errString != "" && !strings.Contains(err.Error(), tc.errString) { + t.Errorf("expected error %q, but got: %v", tc.errString, err) + } else if tc.errString == "" && !strings.Contains(err.Error(), "resource conflicts detected") { t.Errorf("expected conflict error, but got: %v", err) } } diff --git a/docs/en/resources/authServices/generic.md b/docs/en/resources/authServices/generic.md new file mode 100644 index 000000000000..c12adf90e041 --- /dev/null +++ b/docs/en/resources/authServices/generic.md @@ -0,0 +1,67 @@ +--- +title: "Generic OIDC Auth" +type: docs +weight: 2 +description: > + Use a Generic OpenID Connect (OIDC) provider for OAuth 2.0 flow and token + lifecycle. +--- + +## Getting Started + +The Generic Auth Service allows you to integrate with any OpenID Connect (OIDC) +compliant identity provider (IDP). It discovers the JWKS (JSON Web Key Set) URL +either through the provider's `/.well-known/openid-configuration` endpoint or +directly via the provided `authorizationServer`. + +To configure this auth service, you need to provide the `audience` (typically +your client ID or the intended audience for the token), the +`authorizationServer` of your identity provider, and optionally a list of +`scopesRequired` that must be present in the token's claims. + +## Behavior + +### Token Validation + +When a request is received, the service will: + +1. Extract the token from the `_token` header (e.g., + `my-generic-auth_token`). +2. Fetch the JWKS from the configured `authorizationServer` (caching it in the + background) to verify the token's signature. +3. Validate that the token is not expired and its signature is valid. +4. Verify that the `aud` (audience) claim matches the configured `audience`. + claim contains all required scopes. +5. Return the validated claims to be used for [Authenticated + Parameters][auth-params] or [Authorized Invocations][auth-invoke]. + +[auth-invoke]: ../tools/#authorized-invocations +[auth-params]: ../tools/#authenticated-parameters + +## Example + +```yaml +kind: authServices +name: my-generic-auth +type: generic +audience: ${YOUR_OIDC_AUDIENCE} +authorizationServer: https://your-idp.example.com +mcpEnabled: false +scopesRequired: + - read + - write +``` + +{{< notice tip >}} Use environment variable replacement with the format +${ENV_NAME} instead of hardcoding your secrets into the configuration file. +{{< /notice >}} + +## Reference + +| **field** | **type** | **required** | **description** | +| ------------------- | :------: | :----------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| type | string | true | Must be "generic". | +| audience | string | true | The expected audience (`aud` claim) in the JWT token. This ensures the token was minted specifically for your application. | +| authorizationServer | string | true | The base URL of your OIDC provider. The service will append `/.well-known/openid-configuration` to discover the JWKS URI. HTTP is allowed but logs a warning. | +| mcpEnabled | bool | false | Indicates if MCP endpoint authentication should be applied. Defaults to false. | +| scopesRequired | []string | false | A list of required scopes that must be present in the token's `scope` claim to be considered valid. | diff --git a/go.mod b/go.mod index 1384140396e3..26ae4a652153 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,8 @@ require ( github.com/ClickHouse/clickhouse-go/v2 v2.43.0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.31.0 + github.com/MicahParks/jwkset v0.11.0 + github.com/MicahParks/keyfunc/v3 v3.8.0 github.com/apache/cassandra-gocql-driver/v2 v2.0.0 github.com/cenkalti/backoff/v5 v5.0.3 github.com/cockroachdb/cockroach-go/v2 v2.4.3 @@ -36,6 +38,7 @@ require ( github.com/go-sql-driver/mysql v1.9.3 github.com/goccy/go-yaml v1.19.2 github.com/godror/godror v0.50.0 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.9.1 @@ -172,7 +175,6 @@ require ( github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect github.com/godror/knownpb v0.3.0 // indirect github.com/gofrs/flock v0.13.0 // indirect - github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect diff --git a/go.sum b/go.sum index 9dabb055df86..f59f628d384d 100644 --- a/go.sum +++ b/go.sum @@ -89,6 +89,10 @@ github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0/go.mod h1:vB2GH9GAYYJTO3mEn8oYwzEdhlayZIdQz6zdzgUIRvA= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 h1:0s6TxfCu2KHkkZPnBfsQ2y5qia0jl3MMrmBhu3nCOYk= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0/go.mod h1:Mf6O40IAyB9zR/1J8nGDDPirZQQPbYJni8Yisy7NTMc= +github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= +github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.8.0 h1:Hx2dgIjAXGk9slakM6rV9BOeaWDPEXXZ4Us8guNBfds= +github.com/MicahParks/keyfunc/v3 v3.8.0/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= diff --git a/internal/auth/generic/generic.go b/internal/auth/generic/generic.go new file mode 100644 index 000000000000..1f60ceb384e9 --- /dev/null +++ b/internal/auth/generic/generic.go @@ -0,0 +1,282 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package generic + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/MicahParks/keyfunc/v3" + "github.com/golang-jwt/jwt/v5" + "github.com/googleapis/genai-toolbox/internal/auth" +) + +const AuthServiceType string = "generic" + +// validate interface +var _ auth.AuthServiceConfig = Config{} + +// Auth service configuration +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Audience string `yaml:"audience" validate:"required"` + McpEnabled bool `yaml:"mcpEnabled"` + AuthorizationServer string `yaml:"authorizationServer" validate:"required"` + ScopesRequired []string `yaml:"scopesRequired"` +} + +// Returns the auth service type +func (cfg Config) AuthServiceConfigType() string { + return AuthServiceType +} + +// Initialize a generic auth service +func (cfg Config) Initialize() (auth.AuthService, error) { + // Discover the JWKS URL from the OIDC configuration endpoint + jwksURL, err := discoverJWKSURL(cfg.AuthorizationServer) + if err != nil { + return nil, fmt.Errorf("failed to discover JWKS URL: %w", err) + } + + // Create the keyfunc to fetch and cache the JWKS in the background + kf, err := keyfunc.NewDefault([]string{jwksURL}) + if err != nil { + return nil, fmt.Errorf("failed to create keyfunc from JWKS URL %s: %w", jwksURL, err) + } + + a := &AuthService{ + Config: cfg, + kf: kf, + } + return a, nil +} + +func discoverJWKSURL(AuthorizationServer string) (string, error) { + u, err := url.Parse(AuthorizationServer) + if err != nil { + return "", fmt.Errorf("invalid auth URL") + } + if u.Scheme != "https" { + log.Printf("WARNING: HTTP instead of HTTPS is being used for AuthorizationServer: %s", AuthorizationServer) + } + + oidcConfigURL, err := url.JoinPath(AuthorizationServer, ".well-known/openid-configuration") + if err != nil { + return "", err + } + + // HTTP Client + client := &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + MaxIdleConns: 10, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 5 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + // Prevent redirect loops or redirects to internal sites + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := client.Get(oidcConfigURL) + if err != nil { + return "", fmt.Errorf("failed to fetch OIDC config: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status: %d", resp.StatusCode) + } + + // Limit read size to 1MB to prevent memory exhaustion + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", err + } + + var config struct { + JWKSURI string `json:"jwks_uri"` + } + if err := json.Unmarshal(body, &config); err != nil { + return "", err + } + + if config.JWKSURI == "" { + return "", fmt.Errorf("jwks_uri not found in config") + } + + // Sanitize the resulting JWKS URI before returning it + parsedJWKS, err := url.Parse(config.JWKSURI) + if err != nil { + return "", fmt.Errorf("invalid jwks_uri detected") + } + if parsedJWKS.Scheme != "https" { + log.Printf("WARNING: HTTP instead of HTTPS is being used for JWKS URI: %s", config.JWKSURI) + } + + return config.JWKSURI, nil +} + +var _ auth.AuthService = AuthService{} + +// struct used to store auth service info +type AuthService struct { + Config + kf keyfunc.Keyfunc +} + +// Returns the auth service type +func (a AuthService) AuthServiceType() string { + return AuthServiceType +} + +func (a AuthService) ToConfig() auth.AuthServiceConfig { + return a.Config +} + +// Returns the name of the auth service +func (a AuthService) GetName() string { + return a.Name +} + +// Verifies generic JWT access token inside the Authorization header +func (a AuthService) GetClaimsFromHeader(ctx context.Context, h http.Header) (map[string]any, error) { + if a.McpEnabled { + return nil, nil + } + + tokenString := h.Get(a.Name + "_token") + if tokenString == "" { + return nil, nil + } + + // Parse and verify the token signature + token, err := jwt.Parse(tokenString, a.kf.Keyfunc) + if err != nil { + return nil, fmt.Errorf("failed to parse and verify JWT token: %w", err) + } + + if !token.Valid { + return nil, fmt.Errorf("invalid JWT token") + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, fmt.Errorf("invalid JWT claims format") + } + + // Validate 'aud' (audience) claim + aud, err := claims.GetAudience() + if err != nil { + return nil, fmt.Errorf("could not parse audience from token: %w", err) + } + + isAudValid := false + for _, audItem := range aud { + if audItem == a.Audience { + isAudValid = true + break + } + } + + if !isAudValid { + return nil, fmt.Errorf("audience validation failed: expected %s, got %v", a.Audience, aud) + } + + return claims, nil +} + +// MCPAuthError represents an error during MCP authentication validation. +type MCPAuthError struct { + Code int + Message string + ScopesRequired []string +} + +func (e *MCPAuthError) Error() string { return e.Message } + +// ValidateMCPAuth handles MCP auth token validation +func (a AuthService) ValidateMCPAuth(ctx context.Context, h http.Header) error { + tokenString := h.Get("Authorization") + if tokenString == "" { + return &MCPAuthError{Code: http.StatusUnauthorized, Message: "missing access token", ScopesRequired: a.ScopesRequired} + } + + headerParts := strings.Split(tokenString, " ") + if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" { + return &MCPAuthError{Code: http.StatusUnauthorized, Message: "authorization header must be in the format 'Bearer '", ScopesRequired: a.ScopesRequired} + } + + token, err := jwt.Parse(headerParts[1], a.kf.Keyfunc) + if err != nil || !token.Valid { + return &MCPAuthError{Code: http.StatusUnauthorized, Message: "invalid or expired token", ScopesRequired: a.ScopesRequired} + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return &MCPAuthError{Code: http.StatusUnauthorized, Message: "invalid JWT claims format", ScopesRequired: a.ScopesRequired} + } + + // Validate audience + aud, err := claims.GetAudience() + if err != nil { + return &MCPAuthError{Code: http.StatusUnauthorized, Message: "could not parse audience from token", ScopesRequired: a.ScopesRequired} + } + + isAudValid := false + for _, audItem := range aud { + if audItem == a.Audience { + isAudValid = true + break + } + } + + if !isAudValid { + return &MCPAuthError{Code: http.StatusUnauthorized, Message: "audience validation failed", ScopesRequired: a.ScopesRequired} + } + + // Check scopes + if len(a.ScopesRequired) > 0 { + scopeClaim, ok := claims["scope"].(string) + if !ok { + return &MCPAuthError{Code: http.StatusForbidden, Message: "insufficient scopes", ScopesRequired: a.ScopesRequired} + } + + tokenScopes := strings.Split(scopeClaim, " ") + scopeMap := make(map[string]bool) + for _, s := range tokenScopes { + scopeMap[s] = true + } + + for _, requiredScope := range a.ScopesRequired { + if !scopeMap[requiredScope] { + return &MCPAuthError{Code: http.StatusForbidden, Message: "insufficient scopes", ScopesRequired: a.ScopesRequired} + } + } + } + + return nil +} diff --git a/internal/auth/generic/generic_test.go b/internal/auth/generic/generic_test.go new file mode 100644 index 000000000000..9a4f91f87a2b --- /dev/null +++ b/internal/auth/generic/generic_test.go @@ -0,0 +1,210 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package generic + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/MicahParks/jwkset" + "github.com/golang-jwt/jwt/v5" +) + +func generateRSAPrivateKey(t *testing.T) *rsa.PrivateKey { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to create RSA private key: %v", err) + } + return key +} + +func setupJWKSMockServer(t *testing.T, key *rsa.PrivateKey, keyID string) *httptest.Server { + t.Helper() + + jwksHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": "https://example.com", + "jwks_uri": "http://" + r.Host + "/jwks", + }) + return + } + + if r.URL.Path == "/jwks" { + options := jwkset.JWKOptions{ + Metadata: jwkset.JWKMetadataOptions{ + KID: keyID, + }, + } + jwk, err := jwkset.NewJWKFromKey(key.Public(), options) + if err != nil { + t.Fatalf("failed to create JWK: %v", err) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "keys": []jwkset.JWKMarshal{jwk.Marshal()}, + }) + return + } + + http.NotFound(w, r) + }) + + return httptest.NewServer(jwksHandler) +} + +func generateValidToken(t *testing.T, key *rsa.PrivateKey, keyID string, claims jwt.MapClaims) string { + t.Helper() + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = keyID + signedString, err := token.SignedString(key) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + return signedString +} + +func TestGetClaimsFromHeader(t *testing.T) { + privateKey := generateRSAPrivateKey(t) + keyID := "test-key-id" + server := setupJWKSMockServer(t, privateKey, keyID) + defer server.Close() + + cfg := Config{ + Name: "test-generic-auth", + Type: "generic", + Audience: "my-audience", + McpEnabled: false, + AuthorizationServer: server.URL, + ScopesRequired: []string{"read:files"}, + } + + authService, err := cfg.Initialize() + if err != nil { + t.Fatalf("failed to initialize auth service: %v", err) + } + + genericAuth, ok := authService.(*AuthService) + if !ok { + t.Fatalf("expected *AuthService, got %T", authService) + } + + ctx := context.Background() + + tests := []struct { + name string + setupHeader func() http.Header + wantError bool + errContains string + validate func(claims map[string]any) + }{ + { + name: "valid token", + setupHeader: func() http.Header { + token := generateValidToken(t, privateKey, keyID, jwt.MapClaims{ + "aud": "my-audience", + "scope": "read:files write:files", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }) + header := http.Header{} + header.Set("test-generic-auth_token", token) + return header + }, + wantError: false, + validate: func(claims map[string]any) { + if sub, ok := claims["sub"].(string); !ok || sub != "test-user" { + t.Errorf("expected sub=test-user, got %v", claims["sub"]) + } + }, + }, + { + name: "no header", + setupHeader: func() http.Header { + return http.Header{} + }, + wantError: false, + validate: func(claims map[string]any) { + if claims != nil { + t.Errorf("expected nil claims on missing header, got %v", claims) + } + }, + }, + + { + name: "wrong audience", + setupHeader: func() http.Header { + token := generateValidToken(t, privateKey, keyID, jwt.MapClaims{ + "aud": "wrong-audience", + "scope": "read:files", + "exp": time.Now().Add(time.Hour).Unix(), + }) + header := http.Header{} + header.Set("test-generic-auth_token", token) + return header + }, + wantError: true, + errContains: "audience validation failed", + }, + + { + name: "expired token", + setupHeader: func() http.Header { + token := generateValidToken(t, privateKey, keyID, jwt.MapClaims{ + "aud": "my-audience", + "scope": "read:files", + "exp": time.Now().Add(-1 * time.Hour).Unix(), + }) + header := http.Header{} + header.Set("test-generic-auth_token", token) + return header + }, + wantError: true, + errContains: "token has invalid claims: token is expired", // Custom JWT err string + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + header := tc.setupHeader() + claims, err := genericAuth.GetClaimsFromHeader(ctx, header) + + if tc.wantError { + if err == nil { + t.Fatalf("expected error, got nil") + } + if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("expected error containing %q, got: %v", tc.errContains, err) + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc.validate != nil { + tc.validate(claims) + } + } + }) + } +} diff --git a/internal/server/config.go b/internal/server/config.go index eb02760d86fb..77ac088e2d5f 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -23,6 +23,7 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/auth/generic" "github.com/googleapis/genai-toolbox/internal/auth/google" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini" @@ -79,6 +80,8 @@ type ServerConfig struct { UserAgentMetadata []string // PollInterval sets the polling frequency for configuration file updates. PollInterval int + // ToolboxUrl specifies the Toolbox URL. Used as the resource field in the MCP PRM file when MCP Auth is enabled. + ToolboxUrl string } type logFormat string @@ -255,18 +258,28 @@ func UnmarshalYAMLAuthServiceConfig(ctx context.Context, name string, r map[stri if !ok { return nil, fmt.Errorf("missing 'type' field or it is not a string") } - if resourceType != google.AuthServiceType { - return nil, fmt.Errorf("%s is not a valid type of auth service", resourceType) - } + dec, err := util.NewStrictDecoder(r) if err != nil { return nil, fmt.Errorf("error creating decoder: %s", err) } - actual := google.Config{Name: name} - if err := dec.DecodeContext(ctx, &actual); err != nil { - return nil, fmt.Errorf("unable to parse as %s: %w", name, err) + + switch resourceType { + case google.AuthServiceType: + actual := google.Config{Name: name} + if err := dec.DecodeContext(ctx, &actual); err != nil { + return nil, fmt.Errorf("unable to parse as %s: %w", name, err) + } + return actual, nil + case generic.AuthServiceType: + actual := generic.Config{Name: name} + if err := dec.DecodeContext(ctx, &actual); err != nil { + return nil, fmt.Errorf("unable to parse as %s: %w", name, err) + } + return actual, nil + default: + return nil, fmt.Errorf("%s is not a valid type of auth service", resourceType) } - return actual, nil } func UnmarshalYAMLEmbeddingModelConfig(ctx context.Context, name string, r map[string]any) (embeddingmodels.EmbeddingModelConfig, error) { diff --git a/internal/server/mcp.go b/internal/server/mcp.go index 70e6b6e8b96c..17f359fee02c 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -332,6 +332,7 @@ func mcpRouter(s *Server) (chi.Router, error) { r.Use(middleware.AllowContentType("application/json", "application/json-rpc", "application/jsonrequest")) r.Use(middleware.StripSlashes) r.Use(render.SetContentType(render.ContentTypeJSON)) + r.Use(mcpAuthMiddleware(s)) r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) }) r.Get("/", func(w http.ResponseWriter, r *http.Request) { methodNotAllowed(s, w, r) }) diff --git a/internal/server/server.go b/internal/server/server.go index 21188f2f3a4f..53a0523e812d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -16,6 +16,7 @@ package server import ( "context" + "errors" "fmt" "io" "net" @@ -30,6 +31,7 @@ import ( "github.com/go-chi/cors" "github.com/go-chi/httplog/v3" "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/auth/generic" "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/prompts" @@ -45,6 +47,7 @@ import ( // Server contains info for running an instance of Toolbox. Should be instantiated with NewServer(). type Server struct { version string + toolboxUrl string srv *http.Server listener net.Listener root chi.Router @@ -412,6 +415,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { if err != nil { return nil, err } + r.Mount("/mcp", mcpR) if cfg.EnableAPI { apiR, err := apiRouter(s) @@ -435,6 +439,49 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { return s, nil } +func mcpAuthMiddleware(s *Server) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Find McpEnabled auth service + var mcpSvc *generic.AuthService + for _, authSvc := range s.ResourceMgr.GetAuthServiceMap() { + if genSvc, ok := authSvc.(*generic.AuthService); ok && genSvc.McpEnabled { + mcpSvc = genSvc + break + } + } + + // MCP Auth not enabled + if mcpSvc == nil { + next.ServeHTTP(w, r) + return + } + + if err := mcpSvc.ValidateMCPAuth(r.Context(), r.Header); err != nil { + var mcpErr *generic.MCPAuthError + if errors.As(err, &mcpErr) { + switch mcpErr.Code { + case http.StatusUnauthorized: + scopesArg := "" + if len(mcpErr.ScopesRequired) > 0 { + scopesArg = fmt.Sprintf(`, scope="%s"`, strings.Join(mcpErr.ScopesRequired, " ")) + } + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"%s`, s.toolboxUrl+"/.well-known/oauth-protected-resource", scopesArg)) + http.Error(w, mcpErr.Message, http.StatusUnauthorized) + return + case http.StatusForbidden: + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer error="insufficient_scope", scope="%s", resource_metadata="%s", error_description="%s"`, strings.Join(mcpErr.ScopesRequired, " "), s.toolboxUrl+"/.well-known/oauth-protected-resource", mcpErr.Message)) + http.Error(w, mcpErr.Message, http.StatusForbidden) + return + } + } + } + + next.ServeHTTP(w, r) + }) + } +} + // Listen starts a listener for the given Server instance. func (s *Server) Listen(ctx context.Context) error { ctx, cancel := context.WithCancel(ctx) diff --git a/tests/auth/auth_integration_test.go b/tests/auth/auth_integration_test.go new file mode 100644 index 000000000000..78db4e4830b7 --- /dev/null +++ b/tests/auth/auth_integration_test.go @@ -0,0 +1,170 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "regexp" + "strings" + "testing" + "time" + + "github.com/MicahParks/jwkset" + "github.com/golang-jwt/jwt/v5" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/tests" +) + +// TestMcpAuth test for MCP Authorization +func TestMcpAuth(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + // Set up generic auth mock server + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to create RSA private key: %v", err) + } + jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": "https://example.com", + "jwks_uri": "http://" + r.Host + "/jwks", + }) + return + } + if r.URL.Path == "/jwks" { + options := jwkset.JWKOptions{ + Metadata: jwkset.JWKMetadataOptions{ + KID: "test-key-id", + }, + } + jwk, _ := jwkset.NewJWKFromKey(privateKey.Public(), options) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "keys": []jwkset.JWKMarshal{jwk.Marshal()}, + }) + return + } + http.NotFound(w, r) + })) + defer jwksServer.Close() + + toolsFile := map[string]any{ + "sources": map[string]any{}, + "authServices": map[string]any{ + "my-generic-auth": map[string]any{ + "type": "generic", + "audience": "test-audience", + "authorizationServer": jwksServer.URL, + "scopesRequired": []string{"read:files"}, + "mcpEnabled": true, + }, + }, + "tools": map[string]any{}, + } + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + api := "http://127.0.0.1:5000/mcp/sse" + + t.Run("401 Unauthorized without token", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, api, nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } + authHeader := resp.Header.Get("WWW-Authenticate") + if !strings.Contains(authHeader, `resource_metadata="/.well-known/oauth-protected-resource"`) || !strings.Contains(authHeader, `scope="read:files"`) { + t.Fatalf("expected WWW-Authenticate header to contain resource_metadata and scope, got: %s", authHeader) + } + }) + + t.Run("403 Forbidden with insufficient scopes", func(t *testing.T) { + // Generate valid token but wrong scopes + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "aud": "test-audience", + "scope": "wrong:scope", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-id" + signedString, _ := token.SignedString(privateKey) + + req, _ := http.NewRequest(http.MethodGet, api, nil) + req.Header.Add("Authorization", "Bearer "+signedString) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("expected 403, got %d", resp.StatusCode) + } + authHeader := resp.Header.Get("WWW-Authenticate") + if !strings.Contains(authHeader, `resource_metadata="/.well-known/oauth-protected-resource"`) || !strings.Contains(authHeader, `scope="read:files"`) || !strings.Contains(authHeader, `error="insufficient_scope"`) { + t.Fatalf("expected WWW-Authenticate header to contain error, scope, and resource_metadata, got: %s", authHeader) + } + }) + + t.Run("200 OK with valid token", func(t *testing.T) { + // Generate valid token with correct scopes + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "aud": "test-audience", + "scope": "read:files", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-id" + signedString, _ := token.SignedString(privateKey) + + req, _ := http.NewRequest(http.MethodGet, api, nil) + req.Header.Add("Authorization", "Bearer "+signedString) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + }) +} diff --git a/tests/http/http_integration_test.go b/tests/http/http_integration_test.go index a3486561ad80..f7c18b89ab6e 100644 --- a/tests/http/http_integration_test.go +++ b/tests/http/http_integration_test.go @@ -17,6 +17,8 @@ package http import ( "bytes" "context" + "crypto/rand" + "crypto/rsa" "encoding/json" "fmt" "io" @@ -28,6 +30,8 @@ import ( "testing" "time" + "github.com/MicahParks/jwkset" + "github.com/golang-jwt/jwt/v5" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/googleapis/genai-toolbox/tests" @@ -307,9 +311,40 @@ func TestHttpToolEndpoints(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() + // Set up generic auth mock server + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to create RSA private key: %v", err) + } + jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": "https://example.com", + "jwks_uri": "http://" + r.Host + "/jwks", + }) + return + } + if r.URL.Path == "/jwks" { + options := jwkset.JWKOptions{ + Metadata: jwkset.JWKMetadataOptions{ + KID: "test-key-id", + }, + } + jwk, _ := jwkset.NewJWKFromKey(privateKey.Public(), options) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "keys": []jwkset.JWKMarshal{jwk.Marshal()}, + }) + return + } + http.NotFound(w, r) + })) + defer jwksServer.Close() + var args []string - toolsFile := getHTTPToolsConfig(sourceConfig, HttpToolType) + toolsFile := getHTTPToolsConfig(sourceConfig, HttpToolType, jwksServer.URL) cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) if err != nil { t.Fatalf("command initialization returned an error: %s", err) @@ -329,6 +364,70 @@ func TestHttpToolEndpoints(t *testing.T) { tests.RunToolInvokeTest(t, `"hello world"`, tests.DisableArrayTest()) runAdvancedHTTPInvokeTest(t) runQueryParamInvokeTest(t) + runGenericAuthInvokeTest(t, privateKey) +} + +func runGenericAuthInvokeTest(t *testing.T, privateKey *rsa.PrivateKey) { + // Generate valid token + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "aud": "test-audience", + "scope": "read:files", + "sub": "test-user", + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-id" + signedString, err := token.SignedString(privateKey) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + + api := "http://127.0.0.1:5000/api/tool/my-auth-required-generic-tool/invoke" + + // Test without auth header (should fail) + t.Run("invoke generic auth tool without token", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, api, bytes.NewBuffer([]byte(`{}`))) + req.Header.Add("Content-type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + var body map[string]interface{} + _ = json.NewDecoder(resp.Body).Decode(&body) + errorStr, _ := body["error"].(string) + statusStr, _ := body["status"].(string) + if !strings.Contains(strings.ToLower(errorStr), "not authorized") && !strings.Contains(strings.ToLower(statusStr), "unauthorized") { + bodyBytes, _ := json.Marshal(body) + t.Fatalf("expected unauthorized error, got: %s", string(bodyBytes)) + } + }) + + // Test with valid token + t.Run("invoke generic auth tool with valid token", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, api, bytes.NewBuffer([]byte(`{}`))) + req.Header.Add("Content-type", "application/json") + req.Header.Add("my-generic-auth_token", signedString) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected status 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]interface{} + _ = json.NewDecoder(resp.Body).Decode(&body) + got, ok := body["result"].(string) + if !ok || got != `"hello world"` { + bodyBytes, _ := json.Marshal(body) + t.Fatalf("unexpected result: %s", string(bodyBytes)) + } + }) } // runQueryParamInvokeTest runs the tool invoke endpoint for the query param test tool @@ -500,7 +599,7 @@ func runAdvancedHTTPInvokeTest(t *testing.T) { } // getHTTPToolsConfig returns a mock HTTP tool's config file -func getHTTPToolsConfig(sourceConfig map[string]any, toolType string) map[string]any { +func getHTTPToolsConfig(sourceConfig map[string]any, toolType string, jwksURL string) map[string]any { // Write config into a file and pass it to command otherSourceConfig := make(map[string]any) for k, v := range sourceConfig { @@ -524,6 +623,12 @@ func getHTTPToolsConfig(sourceConfig map[string]any, toolType string) map[string "type": "google", "clientId": clientID, }, + "my-generic-auth": map[string]any{ + "type": "generic", + "audience": "test-audience", + "authorizationServer": jwksURL, + "scopesRequired": []string{"read:files"}, + }, }, "tools": map[string]any{ "my-simple-tool": map[string]any{ @@ -603,6 +708,15 @@ func getHTTPToolsConfig(sourceConfig map[string]any, toolType string) map[string "requestBody": "{}", "authRequired": []string{"my-google-auth"}, }, + "my-auth-required-generic-tool": map[string]any{ + "type": toolType, + "source": "my-instance", + "method": "POST", + "path": "/tool0", + "description": "some description", + "requestBody": "{}", + "authRequired": []string{"my-generic-auth"}, + }, "my-advanced-tool": map[string]any{ "type": toolType, "source": "other-instance",