diff --git a/_example/main.go b/_example/main.go index a6d7559..082beb0 100644 --- a/_example/main.go +++ b/_example/main.go @@ -68,6 +68,18 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" ) +type dynamicTokenAuth struct { + keySet []byte +} + +func (d *dynamicTokenAuth) JWTAuth() (*jwtauth.JWTAuth, error) { + keySet, err := jwtauth.NewKeySet(d.keySet) + if err != nil { + return nil, err + } + return keySet, nil +} + var tokenAuth *jwtauth.JWTAuth func init() { @@ -76,7 +88,8 @@ func init() { // For debugging/example purposes, we generate and print // a sample jwt token with claims `user_id:123` here: _, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123}) - fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) + fmt.Printf("DEBUG: a sample jwt for /admin is %s\n\n", tokenString) + fmt.Printf("DEBUG: a sample jwt for /rotate is %s\n\n", sampleJWTRotate) } func main() { @@ -105,6 +118,23 @@ func router() http.Handler { }) }) + r.Group(func(r chi.Router) { + dynamicTokenAuth := dynamicTokenAuth{keySet: keySet} + // Seek, verify and validate JWT tokens based on keys returned by the callback function + r.Use(jwtauth.VerifierDynamic(dynamicTokenAuth.JWTAuth)) + + // Handle valid / invalid tokens. In this example, we use + // the provided authenticator middleware, but you can write your + // own very easily, look at the Authenticator method in jwtauth.go + // and tweak it, its not scary. + r.Use(jwtauth.Authenticator) + + r.Get("/rotate", func(w http.ResponseWriter, r *http.Request) { + _, claims, _ := jwtauth.FromContext(r.Context()) + w.Write([]byte(fmt.Sprintf("protected area. hi %v", claims["user_id"]))) + }) + }) + // Public routes r.Group(func(r chi.Router) { r.Get("/", func(w http.ResponseWriter, r *http.Request) { @@ -114,3 +144,20 @@ func router() http.Handler { return r } + +var ( + keySet = []byte(`{ + "keys": [ + { + "kty": "RSA", + "alg": "RS256", + "kid": "kid", + "use": "sig", + "n": "rgzO_v14UXJ33MvccKI8aIw3YpknVJbRB-m1z1X4j3gaTmmzmb7_naEd1TOKhF6Z1BGupvAKhCs8uHtp5e1PCrp52kzrjv7nqQfDpdppPZmKpwf-OD_lVgLLuCljB71mX9w7T5vI_WiVknuNhm48y0TJQNslpDZum4E2e0BLKUDRKKlo25foGoDuQN535_Xso861U8KsA80jX37BJplQ6IHewV_bbe04NYTVqaFcmLaZCAzh2f8L1h4xt76Y0xF_u8FXt2-rgcWlz17CtZzxC8ZXNI_92pX8CY5LY2eQf_B_n5Rhd5TQvEIdoI1GNBrcKUI9pMeEC4pErcOGgKGH7w", + "e": "AQAB" + } + ] +}`) + + sampleJWTRotate = `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImtpZCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.APC4bUOmfbcXjBnZnmyiGBpXqlboTB4Qbh_sqJrgSU5AEQlwzjvDJ79eBlty8h6kfq3i5ffy87s-g82ZoRsHqMjwCIvTOVnoEyDgVu68s9lE32uaA0cc2-hbA13DIBsyIUGjehh9c3h93BrUoUr7n0CHgoKgx2OEw1Bq8vm4EqvmFGF-mr_0qi32uudPy3I15SyP1NJfU0ogQEFUdDHww3c8omDmrTPiGlWZAl9AiBMroDu0nq3UOtC4d5Se-361NEGiZ9J_kHcVWGdoMwsi5KEB0Uf3wAfXK3wcXeRu1pTXYKOV3X3g_2ss6mh65bNMsSx-MZUnQv5v6qZMOxMBUA` +) diff --git a/jwtauth.go b/jwtauth.go index a03731a..4fd6eb2 100644 --- a/jwtauth.go +++ b/jwtauth.go @@ -2,12 +2,15 @@ package jwtauth import ( "context" + "encoding/json" "errors" + "fmt" "net/http" "strings" "time" "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" ) @@ -17,6 +20,7 @@ type JWTAuth struct { verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms verifier jwt.ParseOption validateOptions []jwt.ValidateOption + keySet jwk.Set } var ( @@ -50,6 +54,19 @@ func New(alg string, signKey interface{}, verifyKey interface{}, validateOptions return ja } +func NewKeySet(keySet []byte) (*JWTAuth, error) { + ks := jwk.NewSet() + err := json.Unmarshal(keySet, &ks) + if err != nil { + return nil, err + } + + ja := &JWTAuth{keySet: ks} + ja.verifier = jwt.WithKeySet(ks) + + return ja, nil +} + // Verifier http middleware handler will verify a JWT string from a http request. // // Verifier will search for a JWT token in a http request, in the order: @@ -120,6 +137,10 @@ func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) { } func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) { + if ja.keySet != nil { + return nil, "", fmt.Errorf("encode not supported") + } + t = jwt.New() for k, v := range claims { t.Set(k, v) diff --git a/jwtauth_test.go b/jwtauth_test.go index e2580cb..1f819c8 100644 --- a/jwtauth_test.go +++ b/jwtauth_test.go @@ -13,6 +13,8 @@ import ( "testing" "time" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/go-chi/chi/v5" "github.com/go-chi/jwtauth/v5" "github.com/lestrrat-go/jwx/v2/jwa" @@ -41,6 +43,27 @@ MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ== -----END PUBLIC KEY----- ` + + KeySet = `{ + "keys": [ + { + "kty": "RSA", + "n": "vGjc8KMXDhCOA5fTpAIkgkGddc2IRjAMvHFrn_tDIfrLvucJFDInfHdTAX2tQPREKyniw11fmQ5D09TIfI60JQ", + "e": "AQAB", + "alg": "RS256", + "kid": "1", + "use": "sig" + }, + { + "kty": "RSA", + "n": "foo", + "e": "AQAB", + "alg": "RS256", + "kid": "2", + "use": "sig" + } + ] +}` ) func init() { @@ -51,6 +74,57 @@ func init() { // Tests // +func TestNewKeySet(t *testing.T) { + _, err := jwtauth.NewKeySet([]byte("not a valid key set")) + if err == nil { + t.Fatal("The error should not be nil") + } + + _, err = jwtauth.NewKeySet([]byte(KeySet)) + if err != nil { + t.Fatalf(err.Error()) + } +} + +func TestKeySetRSA(t *testing.T) { + privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String)) + + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + + if err != nil { + t.Fatalf(err.Error()) + } + + KeySetAuth, _ := jwtauth.NewKeySet([]byte(KeySet)) + claims := map[string]interface{}{ + "key": "val", + "key2": "val2", + "key3": "val3", + } + + signed := newJwtRSAToken(jwa.RS256, privateKey, "1", claims) + + token, err := KeySetAuth.Decode(signed) + + if err != nil { + t.Fatalf("Failed to decode token string %s\n", err.Error()) + } + + tokenClaims, err := token.AsMap(context.Background()) + if err != nil { + t.Fatal(err.Error()) + } + + if !reflect.DeepEqual(claims, tokenClaims) { + t.Fatalf("The decoded claims don't match the original ones\n") + } + + _, _, err = KeySetAuth.Encode(claims) + if err.Error() != "encode not supported" { + t.Fatalf("Expect error to equal %s. Found: %s.", "encode not supported", err.Error()) + } +} + func TestSimple(t *testing.T) { r := chi.NewRouter() @@ -279,6 +353,73 @@ func TestMore(t *testing.T) { } } +func TestKeySet(t *testing.T) { + privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String)) + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + if err != nil { + t.Fatalf(err.Error()) + } + + r := chi.NewRouter() + + keySet, err := jwtauth.NewKeySet([]byte(KeySet)) + if err != nil { + t.Fatalf(err.Error()) + } + + // Protected routes + r.Group(func(r chi.Router) { + r.Use(jwtauth.Verifier(keySet)) + + authenticator := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token, _, err := jwtauth.FromContext(r.Context()) + + if err != nil { + http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized) + return + } + + if err := jwt.Validate(token); err != nil { + http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized) + return + } + + // Token is authenticated, pass it through + next.ServeHTTP(w, r) + }) + } + r.Use(authenticator) + + r.Get("/admin", func(w http.ResponseWriter, r *http.Request) { + _, claims, err := jwtauth.FromContext(r.Context()) + + if err != nil { + w.Write([]byte(fmt.Sprintf("error! %v", err))) + return + } + + w.Write([]byte(fmt.Sprintf("protected, user:%v", claims["user_id"]))) + }) + }) + + // Public routes + r.Group(func(r chi.Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("welcome")) + }) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + h := http.Header{} + h.Set("Authorization", "BEARER "+newJwtRSAToken(jwa.RS256, privateKey, "1", map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)})) + if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" { + t.Fatalf(resp) + } +} + // // Test helper functions // @@ -340,6 +481,29 @@ func newJwt512Token(secret []byte, claims ...map[string]interface{}) string { return string(tokenPayload) } +func newJwtRSAToken(alg jwa.SignatureAlgorithm, secret interface{}, kid string, claims ...map[string]interface{}) string { + token := jwt.New() + if len(claims) > 0 { + for k, v := range claims[0] { + token.Set(k, v) + } + } + + headers := jws.NewHeaders() + if kid != "" { + err := headers.Set("kid", kid) + if err != nil { + log.Fatal(err) + } + } + + tokenPayload, err := jwt.Sign(token, jwt.WithKey(alg, secret, jws.WithProtectedHeaders(headers))) + if err != nil { + log.Fatal(err) + } + return string(tokenPayload) +} + func newAuthHeader(claims ...map[string]interface{}) http.Header { h := http.Header{} h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))