Skip to content

Commit 7552877

Browse files
Add NewKeySet method to JWTAuth
This commit adds support for KeySets through a new method `NewKeySet` to the `JWTAuth` struct. It includes tests and comments that seek to explain how it works inline. There's also an example in the _example directory that shows how to use and rotate a KeySet.
1 parent b5d850b commit 7552877

3 files changed

Lines changed: 284 additions & 11 deletions

File tree

_example/main.go

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,18 @@ import (
6868
"github.com/lestrrat-go/jwx/v2/jwt"
6969
)
7070

71+
type dynamicTokenAuth struct {
72+
keySet []byte
73+
}
74+
75+
func (d *dynamicTokenAuth) JWTAuth() (*jwtauth.JWTAuth, error) {
76+
keySet, err := jwtauth.NewKeySet(d.keySet)
77+
if err != nil {
78+
return nil, err
79+
}
80+
return keySet, nil
81+
}
82+
7183
var tokenAuth *jwtauth.JWTAuth
7284

7385
func init() {
@@ -76,7 +88,8 @@ func init() {
7688
// For debugging/example purposes, we generate and print
7789
// a sample jwt token with claims `user_id:123` here:
7890
_, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123})
79-
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
91+
fmt.Printf("DEBUG: a sample jwt for /admin is %s\n\n", tokenString)
92+
fmt.Printf("DEBUG: a sample jwt for /rotate is %s\n\n", sampleJWTRotate)
8093
}
8194

8295
func main() {
@@ -105,6 +118,23 @@ func router() http.Handler {
105118
})
106119
})
107120

121+
r.Group(func(r chi.Router) {
122+
dynamicTokenAuth := dynamicTokenAuth{keySet: keySet}
123+
// Seek, verify and validate JWT tokens based on keys returned by the callback function
124+
r.Use(jwtauth.VerifierDynamic(dynamicTokenAuth.JWTAuth))
125+
126+
// Handle valid / invalid tokens. In this example, we use
127+
// the provided authenticator middleware, but you can write your
128+
// own very easily, look at the Authenticator method in jwtauth.go
129+
// and tweak it, its not scary.
130+
r.Use(jwtauth.Authenticator)
131+
132+
r.Get("/rotate", func(w http.ResponseWriter, r *http.Request) {
133+
_, claims, _ := jwtauth.FromContext(r.Context())
134+
w.Write([]byte(fmt.Sprintf("protected area. hi %v", claims["user_id"])))
135+
})
136+
})
137+
108138
// Public routes
109139
r.Group(func(r chi.Router) {
110140
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
@@ -114,3 +144,20 @@ func router() http.Handler {
114144

115145
return r
116146
}
147+
148+
var (
149+
keySet = []byte(`{
150+
"keys": [
151+
{
152+
"kty": "RSA",
153+
"alg": "RS256",
154+
"kid": "kid",
155+
"use": "sig",
156+
"n": "rgzO_v14UXJ33MvccKI8aIw3YpknVJbRB-m1z1X4j3gaTmmzmb7_naEd1TOKhF6Z1BGupvAKhCs8uHtp5e1PCrp52kzrjv7nqQfDpdppPZmKpwf-OD_lVgLLuCljB71mX9w7T5vI_WiVknuNhm48y0TJQNslpDZum4E2e0BLKUDRKKlo25foGoDuQN535_Xso861U8KsA80jX37BJplQ6IHewV_bbe04NYTVqaFcmLaZCAzh2f8L1h4xt76Y0xF_u8FXt2-rgcWlz17CtZzxC8ZXNI_92pX8CY5LY2eQf_B_n5Rhd5TQvEIdoI1GNBrcKUI9pMeEC4pErcOGgKGH7w",
157+
"e": "AQAB"
158+
}
159+
]
160+
}`)
161+
162+
sampleJWTRotate = `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImtpZCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.APC4bUOmfbcXjBnZnmyiGBpXqlboTB4Qbh_sqJrgSU5AEQlwzjvDJ79eBlty8h6kfq3i5ffy87s-g82ZoRsHqMjwCIvTOVnoEyDgVu68s9lE32uaA0cc2-hbA13DIBsyIUGjehh9c3h93BrUoUr7n0CHgoKgx2OEw1Bq8vm4EqvmFGF-mr_0qi32uudPy3I15SyP1NJfU0ogQEFUdDHww3c8omDmrTPiGlWZAl9AiBMroDu0nq3UOtC4d5Se-361NEGiZ9J_kHcVWGdoMwsi5KEB0Uf3wAfXK3wcXeRu1pTXYKOV3X3g_2ss6mh65bNMsSx-MZUnQv5v6qZMOxMBUA`
163+
)

jwtauth.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ package jwtauth
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"net/http"
78
"strings"
89
"time"
910

1011
"github.com/lestrrat-go/jwx/v2/jwa"
12+
"github.com/lestrrat-go/jwx/v2/jwk"
1113
"github.com/lestrrat-go/jwx/v2/jwt"
1214
)
1315

@@ -17,6 +19,7 @@ type JWTAuth struct {
1719
verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms
1820
verifier jwt.ParseOption
1921
validateOptions []jwt.ValidateOption
22+
keySet jwk.Set
2023
}
2124

2225
var (
@@ -50,6 +53,24 @@ func New(alg string, signKey interface{}, verifyKey interface{}, validateOptions
5053
return ja
5154
}
5255

56+
// NewKeySet initializes a new JWTAuth instance with the provided key set.
57+
// It takes a keySet parameter, which is a byte slice containing the key set in JSON format.
58+
// The function returns a pointer to JWTAuth and an error.
59+
// If the key set cannot be unmarshaled from the byte slice, an error is returned.
60+
// Otherwise, the JWTAuth instance is created with the unmarshaled key set and a verifier is set using the key set.
61+
func NewKeySet(keySet []byte) (*JWTAuth, error) {
62+
ks := jwk.NewSet()
63+
err := json.Unmarshal(keySet, &ks)
64+
if err != nil {
65+
return nil, err
66+
}
67+
68+
ja := &JWTAuth{keySet: ks}
69+
ja.verifier = jwt.WithKeySet(ks)
70+
71+
return ja, nil
72+
}
73+
5374
// Verifier http middleware handler will verify a JWT string from a http request.
5475
//
5576
// Verifier will search for a JWT token in a http request, in the order:
@@ -119,13 +140,21 @@ func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) {
119140
return token, nil
120141
}
121142

143+
// Encode generates a JWT token string with the provided claims.
144+
// It returns the encoded token as a string, along with the token object and any error encountered.
122145
func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
123146
t = jwt.New()
124147
for k, v := range claims {
125148
if err := t.Set(k, v); err != nil {
126149
return nil, "", err
127150
}
128151
}
152+
// ja.sign() isn't going to work if ja.signKey is nil
153+
if ja.signKey == nil {
154+
// This generally means that you've called Encode on a KeySet
155+
// which can't be supported.
156+
return nil, "", errors.New("no signing key provided")
157+
}
129158
payload, err := ja.sign(t)
130159
if err != nil {
131160
return nil, "", err

jwtauth_test.go

Lines changed: 207 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"testing"
1414
"time"
1515

16+
"github.com/lestrrat-go/jwx/v2/jws"
17+
1618
"github.com/go-chi/chi/v5"
1719
"github.com/go-chi/jwtauth/v5"
1820
"github.com/lestrrat-go/jwx/v2/jwa"
@@ -41,6 +43,27 @@ MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw
4143
DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ==
4244
-----END PUBLIC KEY-----
4345
`
46+
47+
KeySet = `{
48+
"keys": [
49+
{
50+
"kty": "RSA",
51+
"n": "vGjc8KMXDhCOA5fTpAIkgkGddc2IRjAMvHFrn_tDIfrLvucJFDInfHdTAX2tQPREKyniw11fmQ5D09TIfI60JQ",
52+
"e": "AQAB",
53+
"alg": "RS256",
54+
"kid": "1",
55+
"use": "sig"
56+
},
57+
{
58+
"kty": "RSA",
59+
"n": "foo",
60+
"e": "AQAB",
61+
"alg": "RS256",
62+
"kid": "2",
63+
"use": "sig"
64+
}
65+
]
66+
}`
4467
)
4568

4669
func init() {
@@ -51,6 +74,59 @@ func init() {
5174
// Tests
5275
//
5376

77+
func TestNewKeySet(t *testing.T) {
78+
_, err := jwtauth.NewKeySet([]byte("not a valid key set"))
79+
if err == nil {
80+
t.Fatal("The error should not be nil")
81+
}
82+
83+
_, err = jwtauth.NewKeySet([]byte(KeySet))
84+
if err != nil {
85+
t.Fatalf(err.Error())
86+
}
87+
}
88+
89+
func TestKeySetRSA(t *testing.T) {
90+
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))
91+
92+
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
93+
94+
if err != nil {
95+
t.Fatalf(err.Error())
96+
}
97+
98+
KeySetAuth, _ := jwtauth.NewKeySet([]byte(KeySet))
99+
claims := map[string]interface{}{
100+
"key": "val",
101+
"key2": "val2",
102+
"key3": "val3",
103+
}
104+
105+
signed := newJwtRSAToken(jwa.RS256, privateKey, "1", claims)
106+
107+
token, err := KeySetAuth.Decode(signed)
108+
109+
if err != nil {
110+
t.Fatalf("Failed to decode token string %s\n", err.Error())
111+
}
112+
113+
tokenClaims, err := token.AsMap(context.Background())
114+
if err != nil {
115+
t.Fatal(err.Error())
116+
}
117+
118+
if !reflect.DeepEqual(claims, tokenClaims) {
119+
t.Fatalf("The decoded claims don't match the original ones\n")
120+
}
121+
122+
_, _, err = KeySetAuth.Encode(claims)
123+
if err.Error() != "no signing key provided" {
124+
t.Fatalf("Expect error to equal %s. Found: %s.", "no signing key provided", err.Error())
125+
}
126+
fmt.Println(token.PrivateClaims())
127+
128+
}
129+
54130
func TestSimple(t *testing.T) {
55131
r := chi.NewRouter()
56132

@@ -279,20 +355,118 @@ func TestMore(t *testing.T) {
279355
}
280356
}
281357

282-
func TestEncodeClaims(t *testing.T) {
358+
func TestKeySet(t *testing.T) {
359+
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))
360+
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
361+
if err != nil {
362+
t.Fatalf(err.Error())
363+
}
364+
365+
r := chi.NewRouter()
366+
367+
keySet, err := jwtauth.NewKeySet([]byte(KeySet))
368+
if err != nil {
369+
t.Fatalf(err.Error())
370+
}
371+
372+
// Protected routes
373+
r.Group(func(r chi.Router) {
374+
r.Use(jwtauth.Verifier(keySet))
375+
376+
authenticator := func(next http.Handler) http.Handler {
377+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
378+
token, _, err := jwtauth.FromContext(r.Context())
379+
380+
if err != nil {
381+
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
382+
return
383+
}
384+
385+
if err := jwt.Validate(token); err != nil {
386+
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
387+
return
388+
}
389+
390+
// Token is authenticated, pass it through
391+
next.ServeHTTP(w, r)
392+
})
393+
}
394+
r.Use(authenticator)
395+
396+
r.Get("/admin", func(w http.ResponseWriter, r *http.Request) {
397+
_, claims, err := jwtauth.FromContext(r.Context())
398+
399+
if err != nil {
400+
w.Write([]byte(fmt.Sprintf("error! %v", err)))
401+
return
402+
}
403+
404+
w.Write([]byte(fmt.Sprintf("protected, user:%v", claims["user_id"])))
405+
})
406+
})
407+
408+
// Public routes
409+
r.Group(func(r chi.Router) {
410+
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
411+
w.Write([]byte("welcome"))
412+
})
413+
})
414+
415+
ts := httptest.NewServer(r)
416+
defer ts.Close()
417+
418+
h := http.Header{}
419+
h.Set("Authorization", "BEARER "+newJwtRSAToken(jwa.RS256, privateKey, "1", map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
420+
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
421+
t.Fatalf(resp)
422+
}
423+
}
424+
425+
func TestEncodeInvalidClaim(t *testing.T) {
426+
ja := jwtauth.New("HS256", []byte("secretpass"), nil)
283427
claims := map[string]interface{}{
284-
"key1": "val1",
285-
"key2": 2,
286-
"key3": time.Now(),
287-
"key4": []string{"1", "2"},
428+
"key1": "val1",
429+
"key2": 2,
430+
"key3": time.Now(),
431+
"key4": []string{"1", "2"},
432+
jwt.JwtIDKey: 1, // This is invalid becasue it should be a string
288433
}
289-
claims[jwt.JwtIDKey] = 1
290-
if _, _, err := TokenAuthHS256.Encode(claims); err == nil {
434+
_, _, err := ja.Encode(claims)
435+
if err == nil {
436+
291437
t.Fatal("encoding invalid claims succeeded")
292438
}
293-
claims[jwt.JwtIDKey] = "123"
294-
if _, _, err := TokenAuthHS256.Encode(claims); err != nil {
295-
t.Fatalf("unexpected error encoding valid claims: %v", err)
439+
}
440+
func TestEncode(t *testing.T) {
441+
ja := jwtauth.New("HS256", []byte("secretpass"), nil)
442+
443+
claims := map[string]interface{}{
444+
"sub": "1234567890",
445+
"name": "John Doe",
446+
"iat": 1516239022,
447+
}
448+
449+
token, tokenString, err := ja.Encode(claims)
450+
if err != nil {
451+
t.Fatalf("Failed to encode claims: %s", err.Error())
452+
}
453+
454+
if token == nil {
455+
t.Fatal("Token should not be nil")
456+
}
457+
458+
if tokenString == "" {
459+
t.Fatal("Token string should not be empty")
460+
}
461+
462+
// Verify the token string
463+
verifiedToken, err := ja.Decode(tokenString)
464+
if err != nil {
465+
t.Fatalf("Failed to decode token string: %s", err.Error())
466+
}
467+
468+
if !reflect.DeepEqual(token, verifiedToken) {
469+
t.Fatal("Decoded token does not match the original token")
296470
}
297471
}
298472

@@ -357,6 +531,29 @@ func newJwt512Token(secret []byte, claims ...map[string]interface{}) string {
357531
return string(tokenPayload)
358532
}
359533

534+
func newJwtRSAToken(alg jwa.SignatureAlgorithm, secret interface{}, kid string, claims ...map[string]interface{}) string {
535+
token := jwt.New()
536+
if len(claims) > 0 {
537+
for k, v := range claims[0] {
538+
token.Set(k, v)
539+
}
540+
}
541+
542+
headers := jws.NewHeaders()
543+
if kid != "" {
544+
err := headers.Set("kid", kid)
545+
if err != nil {
546+
log.Fatal(err)
547+
}
548+
}
549+
550+
tokenPayload, err := jwt.Sign(token, jwt.WithKey(alg, secret, jws.WithProtectedHeaders(headers)))
551+
if err != nil {
552+
log.Fatal(err)
553+
}
554+
return string(tokenPayload)
555+
}
556+
360557
func newAuthHeader(claims ...map[string]interface{}) http.Header {
361558
h := http.Header{}
362559
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))

0 commit comments

Comments
 (0)