@@ -13,6 +13,8 @@ import (
1313 "testing"
1414 "time"
1515
16+ "github.com/lestrrat-go/jwx/v2/jws"
17+
1618 "github.com/go-chi/chi/v5"
1719 "github.com/go-chi/jwtauth/v5"
1820 "github.com/lestrrat-go/jwx/v2/jwa"
@@ -41,6 +43,27 @@ MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw
4143DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ==
4244-----END PUBLIC KEY-----
4345`
46+
47+ KeySet = `{
48+ "keys": [
49+ {
50+ "kty": "RSA",
51+ "n": "vGjc8KMXDhCOA5fTpAIkgkGddc2IRjAMvHFrn_tDIfrLvucJFDInfHdTAX2tQPREKyniw11fmQ5D09TIfI60JQ",
52+ "e": "AQAB",
53+ "alg": "RS256",
54+ "kid": "1",
55+ "use": "sig"
56+ },
57+ {
58+ "kty": "RSA",
59+ "n": "foo",
60+ "e": "AQAB",
61+ "alg": "RS256",
62+ "kid": "2",
63+ "use": "sig"
64+ }
65+ ]
66+ }`
4467)
4568
4669func init () {
@@ -51,6 +74,59 @@ func init() {
5174// Tests
5275//
5376
77+ func TestNewKeySet (t * testing.T ) {
78+ _ , err := jwtauth .NewKeySet ([]byte ("not a valid key set" ))
79+ if err == nil {
80+ t .Fatal ("The error should not be nil" )
81+ }
82+
83+ _ , err = jwtauth .NewKeySet ([]byte (KeySet ))
84+ if err != nil {
85+ t .Fatalf (err .Error ())
86+ }
87+ }
88+
89+ func TestKeySetRSA (t * testing.T ) {
90+ privateKeyBlock , _ := pem .Decode ([]byte (PrivateKeyRS256String ))
91+
92+ privateKey , err := x509 .ParsePKCS1PrivateKey (privateKeyBlock .Bytes )
93+
94+ if err != nil {
95+ t .Fatalf (err .Error ())
96+ }
97+
98+ KeySetAuth , _ := jwtauth .NewKeySet ([]byte (KeySet ))
99+ claims := map [string ]interface {}{
100+ "key" : "val" ,
101+ "key2" : "val2" ,
102+ "key3" : "val3" ,
103+ }
104+
105+ signed := newJwtRSAToken (jwa .RS256 , privateKey , "1" , claims )
106+
107+ token , err := KeySetAuth .Decode (signed )
108+
109+ if err != nil {
110+ t .Fatalf ("Failed to decode token string %s\n " , err .Error ())
111+ }
112+
113+ tokenClaims , err := token .AsMap (context .Background ())
114+ if err != nil {
115+ t .Fatal (err .Error ())
116+ }
117+
118+ if ! reflect .DeepEqual (claims , tokenClaims ) {
119+ t .Fatalf ("The decoded claims don't match the original ones\n " )
120+ }
121+
122+ _ , _ , err = KeySetAuth .Encode (claims )
123+ if err .Error () != "no signing key provided" {
124+ t .Fatalf ("Expect error to equal %s. Found: %s." , "no signing key provided" , err .Error ())
125+ }
126+ fmt .Println (token .PrivateClaims ())
127+
128+ }
129+
54130func TestSimple (t * testing.T ) {
55131 r := chi .NewRouter ()
56132
@@ -279,20 +355,118 @@ func TestMore(t *testing.T) {
279355 }
280356}
281357
282- func TestEncodeClaims (t * testing.T ) {
358+ func TestKeySet (t * testing.T ) {
359+ privateKeyBlock , _ := pem .Decode ([]byte (PrivateKeyRS256String ))
360+ privateKey , err := x509 .ParsePKCS1PrivateKey (privateKeyBlock .Bytes )
361+ if err != nil {
362+ t .Fatalf (err .Error ())
363+ }
364+
365+ r := chi .NewRouter ()
366+
367+ keySet , err := jwtauth .NewKeySet ([]byte (KeySet ))
368+ if err != nil {
369+ t .Fatalf (err .Error ())
370+ }
371+
372+ // Protected routes
373+ r .Group (func (r chi.Router ) {
374+ r .Use (jwtauth .Verifier (keySet ))
375+
376+ authenticator := func (next http.Handler ) http.Handler {
377+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
378+ token , _ , err := jwtauth .FromContext (r .Context ())
379+
380+ if err != nil {
381+ http .Error (w , jwtauth .ErrorReason (err ).Error (), http .StatusUnauthorized )
382+ return
383+ }
384+
385+ if err := jwt .Validate (token ); err != nil {
386+ http .Error (w , jwtauth .ErrorReason (err ).Error (), http .StatusUnauthorized )
387+ return
388+ }
389+
390+ // Token is authenticated, pass it through
391+ next .ServeHTTP (w , r )
392+ })
393+ }
394+ r .Use (authenticator )
395+
396+ r .Get ("/admin" , func (w http.ResponseWriter , r * http.Request ) {
397+ _ , claims , err := jwtauth .FromContext (r .Context ())
398+
399+ if err != nil {
400+ w .Write ([]byte (fmt .Sprintf ("error! %v" , err )))
401+ return
402+ }
403+
404+ w .Write ([]byte (fmt .Sprintf ("protected, user:%v" , claims ["user_id" ])))
405+ })
406+ })
407+
408+ // Public routes
409+ r .Group (func (r chi.Router ) {
410+ r .Get ("/" , func (w http.ResponseWriter , r * http.Request ) {
411+ w .Write ([]byte ("welcome" ))
412+ })
413+ })
414+
415+ ts := httptest .NewServer (r )
416+ defer ts .Close ()
417+
418+ h := http.Header {}
419+ h .Set ("Authorization" , "BEARER " + newJwtRSAToken (jwa .RS256 , privateKey , "1" , map [string ]interface {}{"user_id" : 31337 , "exp" : jwtauth .ExpireIn (5 * time .Minute )}))
420+ if status , resp := testRequest (t , ts , "GET" , "/admin" , h , nil ); status != 200 || resp != "protected, user:31337" {
421+ t .Fatalf (resp )
422+ }
423+ }
424+
425+ func TestEncodeInvalidClaim (t * testing.T ) {
426+ ja := jwtauth .New ("HS256" , []byte ("secretpass" ), nil )
283427 claims := map [string ]interface {}{
284- "key1" : "val1" ,
285- "key2" : 2 ,
286- "key3" : time .Now (),
287- "key4" : []string {"1" , "2" },
428+ "key1" : "val1" ,
429+ "key2" : 2 ,
430+ "key3" : time .Now (),
431+ "key4" : []string {"1" , "2" },
432+ jwt .JwtIDKey : 1 , // This is invalid becasue it should be a string
288433 }
289- claims [jwt .JwtIDKey ] = 1
290- if _ , _ , err := TokenAuthHS256 .Encode (claims ); err == nil {
434+ _ , _ , err := ja .Encode (claims )
435+ if err == nil {
436+
291437 t .Fatal ("encoding invalid claims succeeded" )
292438 }
293- claims [jwt .JwtIDKey ] = "123"
294- if _ , _ , err := TokenAuthHS256 .Encode (claims ); err != nil {
295- t .Fatalf ("unexpected error encoding valid claims: %v" , err )
439+ }
440+ func TestEncode (t * testing.T ) {
441+ ja := jwtauth .New ("HS256" , []byte ("secretpass" ), nil )
442+
443+ claims := map [string ]interface {}{
444+ "sub" : "1234567890" ,
445+ "name" : "John Doe" ,
446+ "iat" : 1516239022 ,
447+ }
448+
449+ token , tokenString , err := ja .Encode (claims )
450+ if err != nil {
451+ t .Fatalf ("Failed to encode claims: %s" , err .Error ())
452+ }
453+
454+ if token == nil {
455+ t .Fatal ("Token should not be nil" )
456+ }
457+
458+ if tokenString == "" {
459+ t .Fatal ("Token string should not be empty" )
460+ }
461+
462+ // Verify the token string
463+ verifiedToken , err := ja .Decode (tokenString )
464+ if err != nil {
465+ t .Fatalf ("Failed to decode token string: %s" , err .Error ())
466+ }
467+
468+ if ! reflect .DeepEqual (token , verifiedToken ) {
469+ t .Fatal ("Decoded token does not match the original token" )
296470 }
297471}
298472
@@ -357,6 +531,29 @@ func newJwt512Token(secret []byte, claims ...map[string]interface{}) string {
357531 return string (tokenPayload )
358532}
359533
534+ func newJwtRSAToken (alg jwa.SignatureAlgorithm , secret interface {}, kid string , claims ... map [string ]interface {}) string {
535+ token := jwt .New ()
536+ if len (claims ) > 0 {
537+ for k , v := range claims [0 ] {
538+ token .Set (k , v )
539+ }
540+ }
541+
542+ headers := jws .NewHeaders ()
543+ if kid != "" {
544+ err := headers .Set ("kid" , kid )
545+ if err != nil {
546+ log .Fatal (err )
547+ }
548+ }
549+
550+ tokenPayload , err := jwt .Sign (token , jwt .WithKey (alg , secret , jws .WithProtectedHeaders (headers )))
551+ if err != nil {
552+ log .Fatal (err )
553+ }
554+ return string (tokenPayload )
555+ }
556+
360557func newAuthHeader (claims ... map [string ]interface {}) http.Header {
361558 h := http.Header {}
362559 h .Set ("Authorization" , "BEARER " + newJwtToken (TokenSecret , claims ... ))
0 commit comments