4
4
from ddt import data , ddt
5
5
from jwkest import JWKESTException
6
6
from jwkest .jws import JWS
7
- from jwkest .jwt import JWT
7
+ from jwkest .jwt import JWT , BadSyntax
8
8
from mock import MagicMock , ANY , patch , call
9
9
from launchkey .transports import JOSETransport , RequestsTransport
10
10
from launchkey .transports .base import APIResponse , APIErrorResponse
70
70
"kid" : faux_kid
71
71
}
72
72
73
+ transport_request_headers = {"X-IOV-KEY-ID" : faux_kid }
74
+
73
75
74
76
class TestJOSETransport3rdParty (unittest .TestCase ):
75
77
76
78
def setUp (self ):
77
79
self ._transport = JOSETransport ()
78
80
self ._transport .get = MagicMock (return_value = MagicMock (spec = APIResponse ))
79
- public_key = APIResponse (valid_private_key , {} , 200 )
81
+ public_key = APIResponse (valid_public_key , transport_request_headers , 200 )
80
82
self ._transport .get .return_value = public_key
81
83
self ._transport ._server_time_difference = 0 , time ()
82
84
@@ -152,7 +154,7 @@ class TestJWKESTSupportedAlgs(unittest.TestCase):
152
154
def setUp (self ):
153
155
self ._transport = JOSETransport ()
154
156
self ._transport .get = MagicMock (return_value = MagicMock (spec = APIResponse ))
155
- public_key = APIResponse (valid_private_key , {} , 200 )
157
+ public_key = APIResponse (valid_private_key , transport_request_headers , 200 )
156
158
self ._transport .get .return_value = public_key
157
159
self ._transport ._server_time_difference = 0 , time ()
158
160
@@ -265,7 +267,7 @@ class TestJOSETransportJWTResponse(unittest.TestCase):
265
267
def setUp (self ):
266
268
self ._transport = JOSETransport ()
267
269
self ._transport .get = MagicMock (return_value = MagicMock (spec = APIResponse ))
268
- public_key = APIResponse (valid_private_key , {} , 200 )
270
+ public_key = APIResponse (valid_public_key , transport_request_headers , 200 )
269
271
self ._transport .get .return_value = public_key
270
272
self ._transport ._server_time_difference = 0 , time ()
271
273
self .issuer = "svc"
@@ -421,7 +423,7 @@ class TestJOSETransportJWTRequest(unittest.TestCase):
421
423
def setUp (self ):
422
424
self ._transport = JOSETransport ()
423
425
self ._transport .get = MagicMock (return_value = MagicMock (spec = APIResponse ))
424
- public_key = APIResponse (valid_private_key , {} , 200 )
426
+ public_key = APIResponse (valid_private_key , transport_request_headers , 200 )
425
427
self ._transport .get .return_value = public_key
426
428
self ._transport ._server_time_difference = 0 , time ()
427
429
self .issuer = "svc"
@@ -808,7 +810,7 @@ def setUp(self):
808
810
}
809
811
810
812
self ._requests_transport = MagicMock (spec = RequestsTransport )
811
- self ._requests_transport .get .return_value = APIResponse (valid_private_key , {} , 200 )
813
+ self ._requests_transport .get .return_value = APIResponse (valid_public_key , transport_request_headers , 200 )
812
814
self ._transport = JOSETransport (http_client = self ._requests_transport )
813
815
self ._import_rsa_key_patch = patch (
814
816
"launchkey.transports.jose_auth.import_rsa_key" ,
@@ -850,8 +852,7 @@ def test_key_is_retrieved_by_id_when_key_changed(self):
850
852
"kid" : "jwt2keyid"
851
853
}
852
854
853
- self ._jwt_patch .return_value .unpack .side_effect = [jwt1 , jwt1 , jwt2 , jwt2 ]
854
-
855
+ self ._jwt_patch .return_value .unpack .side_effect = [jwt1 , jwt2 ]
855
856
self ._transport .verify_jwt_response (MagicMock (), self .jti , ANY , None )
856
857
self ._transport .verify_jwt_response (MagicMock (), self .jti , ANY , None )
857
858
self ._requests_transport .get .assert_has_calls ([
@@ -865,15 +866,15 @@ def test_key_retrieved_is_used_to_verify_payload(self, rsa_key_patch):
865
866
self ._requests_transport .get .return_value .data = valid_public_key
866
867
self ._transport .verify_jwt_response (MagicMock (), self .jti , ANY , None )
867
868
868
- # Verify that verify_compact is called one time with key created by our jwkest key patch
869
- self ._jws_patch .return_value .verify_compact .assert_called_once_with (ANY , keys = [rsa_key_patch .return_value ])
870
-
871
- # Assert that the jwkest key patch is built using the import_rsa_key patch return value and the key id
872
- # from the header
873
- rsa_key_patch .assert_called_with (key = self ._import_rsa_key_patch .return_value , kid = faux_kid )
869
+ # Verify that verify_compact is called one time with key created
870
+ # by our jwkest key patch
871
+ self ._jws_patch .return_value .verify_compact \
872
+ .assert_called_once_with (ANY , keys = [rsa_key_patch .return_value ])
874
873
875
- # Verify that we used the correct key to retrieve the key id from the header
876
- self ._requests_transport .get .return_value .headers .get .assert_called_with ("X-IOV-JWT" )
874
+ # Assert that the jwkest key patch is built using the import_rsa_key
875
+ # patch return value and the key id from the header
876
+ rsa_key_patch .assert_called_with (
877
+ key = self ._import_rsa_key_patch .return_value , kid = faux_kid )
877
878
878
879
def test_raises_when_kid_header_is_missing (self ):
879
880
headers_without_kid = {"alg" : "RS512" , "typ" : "JWT" }
@@ -883,6 +884,31 @@ def test_raises_when_kid_header_is_missing(self):
883
884
with self .assertRaises (JWTValidationFailure ):
884
885
self ._transport .verify_jwt_response (MagicMock (), self .jti , ANY , None )
885
886
887
+ def test_raises_when_jwt_unpack_returns_badsyntax (self ):
888
+ self ._jwt_patch .return_value .unpack .side_effect = BadSyntax ("test" , "error" )
889
+ with self .assertRaises (UnexpectedAPIResponse ):
890
+ self ._transport .verify_jwt_response (
891
+ MagicMock (), self .jti , ANY , None )
892
+
893
+ def test_raises_when_jwt_unpack_returns_valueerror (self ):
894
+ self ._jwt_patch .return_value .unpack .side_effect = ValueError ()
895
+ with self .assertRaises (UnexpectedAPIResponse ):
896
+ self ._transport .verify_jwt_response (
897
+ MagicMock (), self .jti , ANY , None )
898
+
899
+ def test_raises_when_jwt_unpack_returns_indexerror (self ):
900
+ self ._jwt_patch .return_value .unpack .side_effect = IndexError ()
901
+ with self .assertRaises (UnexpectedAPIResponse ):
902
+ self ._transport .verify_jwt_response (
903
+ MagicMock (), self .jti , ANY , None )
904
+
905
+ def test_raises_when_kid_header_is_missing_from_http_headers (self ):
906
+ self ._requests_transport .get .return_value = APIResponse (
907
+ valid_public_key , {}, 200 )
908
+ with self .assertRaises (UnexpectedAPIResponse ):
909
+ self ._transport .verify_jwt_response (
910
+ MagicMock (), self .jti , ANY , None )
911
+
886
912
def test_raises_when_kid_header_is_malformed (self ):
887
913
headers_with_kid_of_wrong_type = {"alg" : "RS512" , "typ" : "JWT" , "kid" : 1234 }
888
914
jwt = MagicMock ()
0 commit comments