4
4
# This source code is licensed under the terms described in the LICENSE file in
5
5
# the root directory of this source tree.
6
6
7
+ import base64
7
8
from unittest .mock import AsyncMock , patch
8
9
9
10
import pytest
16
17
AuthProviderConfig ,
17
18
AuthProviderType ,
18
19
TokenValidationResult ,
20
+ get_attributes_from_claims ,
19
21
)
20
22
21
23
@@ -435,7 +437,7 @@ async def mock_jwks_response(*args, **kwargs):
435
437
"kty" : "oct" ,
436
438
"alg" : "HS256" ,
437
439
"use" : "sig" ,
438
- "k" : "MTIzNDU2Nzg5MA" , # Base64-encoded "1234567890"
440
+ "k" : base64 . b64encode ( b"foobarbaz" ). decode (),
439
441
}
440
442
]
441
443
},
@@ -446,15 +448,14 @@ async def mock_jwks_response(*args, **kwargs):
446
448
def jwt_token_valid ():
447
449
from jose import jwt
448
450
449
- # correctly signed jwt token with "kid" in header
450
451
return jwt .encode (
451
452
{
452
453
"sub" : "my-user" ,
453
454
"groups" : ["group1" , "group2" ],
454
455
"scope" : "foo bar" ,
455
456
"aud" : "llama-stack" ,
456
457
},
457
- key = "1234567890 " ,
458
+ key = "foobarbaz " ,
458
459
algorithm = "HS256" ,
459
460
headers = {"kid" : "1234567890" },
460
461
)
@@ -467,4 +468,52 @@ def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
467
468
assert response .json () == {"message" : "Authentication successful" }
468
469
469
470
471
+ @patch ("httpx.AsyncClient.get" , new = mock_jwks_response )
472
+ def test_invalid_oauth2_authentication (oauth2_client , invalid_token ):
473
+ response = oauth2_client .get ("/test" , headers = {"Authorization" : f"Bearer { invalid_token } " })
474
+ assert response .status_code == 401
475
+ assert "Invalid JWT token" in response .json ()["error" ]["message" ]
476
+
477
+
478
+ def test_get_attributes_from_claims ():
479
+ claims = {
480
+ "sub" : "my-user" ,
481
+ "groups" : ["group1" , "group2" ],
482
+ "scope" : "foo bar" ,
483
+ "aud" : "llama-stack" ,
484
+ }
485
+ attributes = get_attributes_from_claims (claims , {"sub" : "roles" , "groups" : "teams" })
486
+ assert attributes .roles == ["my-user" ]
487
+ assert attributes .teams == ["group1" , "group2" ]
488
+
489
+ claims = {
490
+ "sub" : "my-user" ,
491
+ "tenant" : "my-tenant" ,
492
+ }
493
+ attributes = get_attributes_from_claims (claims , {"sub" : "roles" , "tenant" : "namespaces" })
494
+ assert attributes .roles == ["my-user" ]
495
+ assert attributes .namespaces == ["my-tenant" ]
496
+
497
+ claims = {
498
+ "sub" : "my-user" ,
499
+ "username" : "my-username" ,
500
+ "tenant" : "my-tenant" ,
501
+ "groups" : ["group1" , "group2" ],
502
+ "team" : "my-team" ,
503
+ }
504
+ attributes = get_attributes_from_claims (
505
+ claims ,
506
+ {
507
+ "sub" : "roles" ,
508
+ "tenant" : "namespaces" ,
509
+ "username" : "roles" ,
510
+ "team" : "teams" ,
511
+ "groups" : "teams" ,
512
+ },
513
+ )
514
+ assert set (attributes .roles ) == {"my-user" , "my-username" }
515
+ assert set (attributes .teams ) == {"my-team" , "group1" , "group2" }
516
+ assert attributes .namespaces == ["my-tenant" ]
517
+
518
+
470
519
# TODO: add more tests for oauth2 token provider
0 commit comments