Skip to content

Commit 234baee

Browse files
committed
update getClaimsFromHeader
1 parent 8788beb commit 234baee

File tree

3 files changed

+126
-89
lines changed

3 files changed

+126
-89
lines changed

internal/auth/generic/generic.go

Lines changed: 5 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -121,16 +121,14 @@ func (a AuthService) GetName() string {
121121

122122
// Verifies generic JWT access token inside the Authorization header
123123
func (a AuthService) GetClaimsFromHeader(ctx context.Context, h http.Header) (map[string]any, error) {
124-
authHeader := h.Get("Authorization")
125-
if authHeader == "" {
126-
return nil, nil // Return nil, nil if no authorization header is found
124+
if a.McpEnabled {
125+
return nil, nil
127126
}
128127

129-
parts := strings.Split(authHeader, " ")
130-
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
131-
return nil, fmt.Errorf("authorization header format must be Bearer {token}")
128+
tokenString := h.Get(a.Name + "_token")
129+
if tokenString == "" {
130+
return nil, nil
132131
}
133-
tokenString := parts[1]
134132

135133
// Parse and verify the token signature
136134
token, err := jwt.Parse(tokenString, a.kf.Keyfunc)
@@ -161,45 +159,10 @@ func (a AuthService) GetClaimsFromHeader(ctx context.Context, h http.Header) (ma
161159
}
162160
}
163161

164-
// Some IDPs use 'client_id' instead of 'aud' or put it as a single string, checking that if aud not found or not matched
165-
if !isAudValid {
166-
if clientIDClaim, ok := claims["client_id"].(string); ok && clientIDClaim == a.Audience {
167-
isAudValid = true
168-
} else if audStr, ok := claims["aud"].(string); ok && audStr == a.Audience {
169-
isAudValid = true
170-
}
171-
}
172-
173162
if !isAudValid {
174163
return nil, fmt.Errorf("audience validation failed: expected %s, got %v", a.Audience, aud)
175164
}
176165

177-
// Validate 'scope' claim against ScopesRequired
178-
if len(a.ScopesRequired) > 0 {
179-
var tokenScopes []string
180-
181-
switch s := claims["scope"].(type) {
182-
case string:
183-
tokenScopes = strings.Split(s, " ") // space-separated string is common
184-
case []interface{}:
185-
for _, v := range s {
186-
if str, ok := v.(string); ok {
187-
tokenScopes = append(tokenScopes, str)
188-
}
189-
}
190-
}
191-
192-
scopeMap := make(map[string]bool)
193-
for _, s := range tokenScopes {
194-
scopeMap[s] = true
195-
}
196-
197-
for _, requiredScope := range a.ScopesRequired {
198-
if !scopeMap[requiredScope] {
199-
return nil, fmt.Errorf("missing required scope: %s", requiredScope)
200-
}
201-
}
202-
}
203166

204167
// Return claims dynamically
205168
claimsMap := make(map[string]any)

internal/auth/generic/generic_test.go

Lines changed: 6 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ func TestGetClaimsFromHeader(t *testing.T) {
9595
Name: "test-generic-auth",
9696
Type: "generic",
9797
Audience: "my-audience",
98-
McpEnabled: true,
98+
McpEnabled: false,
9999
AuthorizationServerURL: server.URL,
100100
ScopesRequired: []string{"read:files"},
101101
}
@@ -129,7 +129,7 @@ func TestGetClaimsFromHeader(t *testing.T) {
129129
"exp": time.Now().Add(time.Hour).Unix(),
130130
})
131131
header := http.Header{}
132-
header.Set("Authorization", "Bearer "+token)
132+
header.Set("test-generic-auth_token", token)
133133
return header
134134
},
135135
wantError: false,
@@ -151,16 +151,7 @@ func TestGetClaimsFromHeader(t *testing.T) {
151151
}
152152
},
153153
},
154-
{
155-
name: "invalid formatting",
156-
setupHeader: func() http.Header {
157-
header := http.Header{}
158-
header.Set("Authorization", "Token something")
159-
return header
160-
},
161-
wantError: true,
162-
errContains: "authorization header format must be Bearer {token}",
163-
},
154+
164155
{
165156
name: "wrong audience",
166157
setupHeader: func() http.Header {
@@ -170,41 +161,13 @@ func TestGetClaimsFromHeader(t *testing.T) {
170161
"exp": time.Now().Add(time.Hour).Unix(),
171162
})
172163
header := http.Header{}
173-
header.Set("Authorization", "Bearer "+token)
164+
header.Set("test-generic-auth_token", token)
174165
return header
175166
},
176167
wantError: true,
177168
errContains: "audience validation failed",
178169
},
179-
{
180-
name: "missing required scope",
181-
setupHeader: func() http.Header {
182-
token := generateValidToken(t, privateKey, keyID, jwt.MapClaims{
183-
"aud": "my-audience",
184-
"scope": "some:other_scope",
185-
"exp": time.Now().Add(time.Hour).Unix(),
186-
})
187-
header := http.Header{}
188-
header.Set("Authorization", "Bearer "+token)
189-
return header
190-
},
191-
wantError: true,
192-
errContains: "missing required scope: read:files",
193-
},
194-
{
195-
name: "client_id used instead of aud (valid)",
196-
setupHeader: func() http.Header {
197-
token := generateValidToken(t, privateKey, keyID, jwt.MapClaims{
198-
"client_id": "my-audience",
199-
"scope": []interface{}{"read:files"}, // Testing slice type scopes
200-
"exp": time.Now().Add(time.Hour).Unix(),
201-
})
202-
header := http.Header{}
203-
header.Set("Authorization", "Bearer "+token)
204-
return header
205-
},
206-
wantError: false,
207-
},
170+
208171
{
209172
name: "expired token",
210173
setupHeader: func() http.Header {
@@ -214,7 +177,7 @@ func TestGetClaimsFromHeader(t *testing.T) {
214177
"exp": time.Now().Add(-1 * time.Hour).Unix(),
215178
})
216179
header := http.Header{}
217-
header.Set("Authorization", "Bearer "+token)
180+
header.Set("test-generic-auth_token", token)
218181
return header
219182
},
220183
wantError: true,
@@ -231,8 +194,6 @@ func TestGetClaimsFromHeader(t *testing.T) {
231194
if err == nil {
232195
t.Fatalf("expected error, got nil")
233196
}
234-
// We don't check for exact prefix because jwt library errors can be complex,
235-
// check for substring or simple failure instead.
236197
if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
237198
t.Errorf("expected error containing %q, got: %v", tc.errContains, err)
238199
}

tests/http/http_integration_test.go

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ package http
1717
import (
1818
"bytes"
1919
"context"
20+
"crypto/rand"
21+
"crypto/rsa"
2022
"encoding/json"
2123
"fmt"
2224
"io"
@@ -28,6 +30,8 @@ import (
2830
"testing"
2931
"time"
3032

33+
"github.com/MicahParks/jwkset"
34+
"github.com/golang-jwt/jwt/v5"
3135
"github.com/googleapis/genai-toolbox/internal/testutils"
3236
"github.com/googleapis/genai-toolbox/internal/util/parameters"
3337
"github.com/googleapis/genai-toolbox/tests"
@@ -307,9 +311,40 @@ func TestHttpToolEndpoints(t *testing.T) {
307311
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
308312
defer cancel()
309313

314+
// Set up generic auth mock server
315+
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
316+
if err != nil {
317+
t.Fatalf("failed to create RSA private key: %v", err)
318+
}
319+
jwksServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
320+
if r.URL.Path == "/.well-known/openid-configuration" {
321+
w.Header().Set("Content-Type", "application/json")
322+
_ = json.NewEncoder(w).Encode(map[string]interface{}{
323+
"issuer": "https://example.com",
324+
"jwks_uri": "http://" + r.Host + "/jwks",
325+
})
326+
return
327+
}
328+
if r.URL.Path == "/jwks" {
329+
options := jwkset.JWKOptions{
330+
Metadata: jwkset.JWKMetadataOptions{
331+
KID: "test-key-id",
332+
},
333+
}
334+
jwk, _ := jwkset.NewJWKFromKey(privateKey.Public(), options)
335+
w.Header().Set("Content-Type", "application/json")
336+
_ = json.NewEncoder(w).Encode(map[string]interface{}{
337+
"keys": []jwkset.JWKMarshal{jwk.Marshal()},
338+
})
339+
return
340+
}
341+
http.NotFound(w, r)
342+
}))
343+
defer jwksServer.Close()
344+
310345
var args []string
311346

312-
toolsFile := getHTTPToolsConfig(sourceConfig, HttpToolType)
347+
toolsFile := getHTTPToolsConfig(sourceConfig, HttpToolType, jwksServer.URL)
313348
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
314349
if err != nil {
315350
t.Fatalf("command initialization returned an error: %s", err)
@@ -329,6 +364,69 @@ func TestHttpToolEndpoints(t *testing.T) {
329364
tests.RunToolInvokeTest(t, `"hello world"`, tests.DisableArrayTest())
330365
runAdvancedHTTPInvokeTest(t)
331366
runQueryParamInvokeTest(t)
367+
runGenericAuthInvokeTest(t, privateKey)
368+
}
369+
370+
func runGenericAuthInvokeTest(t *testing.T, privateKey *rsa.PrivateKey) {
371+
// Generate valid token
372+
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
373+
"aud": "test-audience",
374+
"scope": "read:files",
375+
"sub": "test-user",
376+
"exp": time.Now().Add(time.Hour).Unix(),
377+
})
378+
token.Header["kid"] = "test-key-id"
379+
signedString, err := token.SignedString(privateKey)
380+
if err != nil {
381+
t.Fatalf("failed to sign token: %v", err)
382+
}
383+
384+
api := "http://127.0.0.1:5000/api/tool/my-auth-required-generic-tool/invoke"
385+
386+
// Test without auth header (should fail)
387+
t.Run("invoke generic auth tool without token", func(t *testing.T) {
388+
req, _ := http.NewRequest(http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)))
389+
req.Header.Add("Content-type", "application/json")
390+
resp, err := http.DefaultClient.Do(req)
391+
if err != nil {
392+
t.Fatalf("unable to send request: %s", err)
393+
}
394+
defer resp.Body.Close()
395+
396+
var body map[string]interface{}
397+
json.NewDecoder(resp.Body).Decode(&body)
398+
resultStr, _ := body["result"].(string)
399+
if !strings.Contains(resultStr, "unauthorized") && !strings.Contains(resultStr, "missing") {
400+
bodyBytes, _ := json.Marshal(body)
401+
t.Fatalf("expected unauthorized error, got: %s", string(bodyBytes))
402+
}
403+
})
404+
405+
// Test with valid token
406+
t.Run("invoke generic auth tool with valid token", func(t *testing.T) {
407+
req, _ := http.NewRequest(http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)))
408+
req.Header.Add("Content-type", "application/json")
409+
req.Header.Add("my-generic-auth_token", signedString)
410+
411+
resp, err := http.DefaultClient.Do(req)
412+
if err != nil {
413+
t.Fatalf("unable to send request: %s", err)
414+
}
415+
defer resp.Body.Close()
416+
417+
if resp.StatusCode != http.StatusOK {
418+
bodyBytes, _ := io.ReadAll(resp.Body)
419+
t.Fatalf("expected status 200, got %d: %s", resp.StatusCode, string(bodyBytes))
420+
}
421+
422+
var body map[string]interface{}
423+
json.NewDecoder(resp.Body).Decode(&body)
424+
got, ok := body["result"].(string)
425+
if !ok || got != `"hello world"` {
426+
bodyBytes, _ := json.Marshal(body)
427+
t.Fatalf("unexpected result: %s", string(bodyBytes))
428+
}
429+
})
332430
}
333431

334432
// runQueryParamInvokeTest runs the tool invoke endpoint for the query param test tool
@@ -500,7 +598,7 @@ func runAdvancedHTTPInvokeTest(t *testing.T) {
500598
}
501599

502600
// getHTTPToolsConfig returns a mock HTTP tool's config file
503-
func getHTTPToolsConfig(sourceConfig map[string]any, toolType string) map[string]any {
601+
func getHTTPToolsConfig(sourceConfig map[string]any, toolType string, jwksURL string) map[string]any {
504602
// Write config into a file and pass it to command
505603
otherSourceConfig := make(map[string]any)
506604
for k, v := range sourceConfig {
@@ -519,6 +617,12 @@ func getHTTPToolsConfig(sourceConfig map[string]any, toolType string) map[string
519617
"type": "google",
520618
"clientId": tests.ClientId,
521619
},
620+
"my-generic-auth": map[string]any{
621+
"type": "generic",
622+
"audience": "test-audience",
623+
"authorizationServerUrl": jwksURL,
624+
"scopesRequired": []string{"read:files"},
625+
},
522626
},
523627
"tools": map[string]any{
524628
"my-simple-tool": map[string]any{
@@ -598,6 +702,15 @@ func getHTTPToolsConfig(sourceConfig map[string]any, toolType string) map[string
598702
"requestBody": "{}",
599703
"authRequired": []string{"my-google-auth"},
600704
},
705+
"my-auth-required-generic-tool": map[string]any{
706+
"type": toolType,
707+
"source": "my-instance",
708+
"method": "POST",
709+
"path": "/tool0",
710+
"description": "some description",
711+
"requestBody": "{}",
712+
"authRequired": []string{"my-generic-auth"},
713+
},
601714
"my-advanced-tool": map[string]any{
602715
"type": toolType,
603716
"source": "other-instance",

0 commit comments

Comments
 (0)