Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion _example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ import (
"github.com/lestrrat-go/jwx/v2/jwt"
)

type dynamicTokenAuth struct {
keySet []byte
}

func (d *dynamicTokenAuth) JWTAuth() (*jwtauth.JWTAuth, error) {
keySet, err := jwtauth.NewKeySet(d.keySet)
if err != nil {
return nil, err
}
return keySet, nil
}

var tokenAuth *jwtauth.JWTAuth

func init() {
Expand All @@ -76,7 +88,8 @@ func init() {
// For debugging/example purposes, we generate and print
// a sample jwt token with claims `user_id:123` here:
_, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123})
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
fmt.Printf("DEBUG: a sample jwt for /admin is %s\n\n", tokenString)
fmt.Printf("DEBUG: a sample jwt for /rotate is %s\n\n", sampleJWTRotate)
}

func main() {
Expand Down Expand Up @@ -105,6 +118,23 @@ func router() http.Handler {
})
})

r.Group(func(r chi.Router) {
dynamicTokenAuth := dynamicTokenAuth{keySet: keySet}
// Seek, verify and validate JWT tokens based on keys returned by the callback function
r.Use(jwtauth.VerifierDynamic(dynamicTokenAuth.JWTAuth))

// Handle valid / invalid tokens. In this example, we use
// the provided authenticator middleware, but you can write your
// own very easily, look at the Authenticator method in jwtauth.go
// and tweak it, its not scary.
r.Use(jwtauth.Authenticator)

r.Get("/rotate", func(w http.ResponseWriter, r *http.Request) {
_, claims, _ := jwtauth.FromContext(r.Context())
w.Write([]byte(fmt.Sprintf("protected area. hi %v", claims["user_id"])))
})
})

// Public routes
r.Group(func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -114,3 +144,20 @@ func router() http.Handler {

return r
}

var (
keySet = []byte(`{
"keys": [
{
"kty": "RSA",
"alg": "RS256",
"kid": "kid",
"use": "sig",
"n": "rgzO_v14UXJ33MvccKI8aIw3YpknVJbRB-m1z1X4j3gaTmmzmb7_naEd1TOKhF6Z1BGupvAKhCs8uHtp5e1PCrp52kzrjv7nqQfDpdppPZmKpwf-OD_lVgLLuCljB71mX9w7T5vI_WiVknuNhm48y0TJQNslpDZum4E2e0BLKUDRKKlo25foGoDuQN535_Xso861U8KsA80jX37BJplQ6IHewV_bbe04NYTVqaFcmLaZCAzh2f8L1h4xt76Y0xF_u8FXt2-rgcWlz17CtZzxC8ZXNI_92pX8CY5LY2eQf_B_n5Rhd5TQvEIdoI1GNBrcKUI9pMeEC4pErcOGgKGH7w",
"e": "AQAB"
}
]
}`)

sampleJWTRotate = `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImtpZCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.APC4bUOmfbcXjBnZnmyiGBpXqlboTB4Qbh_sqJrgSU5AEQlwzjvDJ79eBlty8h6kfq3i5ffy87s-g82ZoRsHqMjwCIvTOVnoEyDgVu68s9lE32uaA0cc2-hbA13DIBsyIUGjehh9c3h93BrUoUr7n0CHgoKgx2OEw1Bq8vm4EqvmFGF-mr_0qi32uudPy3I15SyP1NJfU0ogQEFUdDHww3c8omDmrTPiGlWZAl9AiBMroDu0nq3UOtC4d5Se-361NEGiZ9J_kHcVWGdoMwsi5KEB0Uf3wAfXK3wcXeRu1pTXYKOV3X3g_2ss6mh65bNMsSx-MZUnQv5v6qZMOxMBUA`
)
21 changes: 21 additions & 0 deletions jwtauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package jwtauth

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
)

Expand All @@ -17,6 +20,7 @@ type JWTAuth struct {
verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms
verifier jwt.ParseOption
validateOptions []jwt.ValidateOption
keySet jwk.Set
}

var (
Expand Down Expand Up @@ -50,6 +54,19 @@ func New(alg string, signKey interface{}, verifyKey interface{}, validateOptions
return ja
}

func NewKeySet(keySet []byte) (*JWTAuth, error) {
ks := jwk.NewSet()
err := json.Unmarshal(keySet, &ks)
if err != nil {
return nil, err
}

ja := &JWTAuth{keySet: ks}
ja.verifier = jwt.WithKeySet(ks)

return ja, nil
}

// Verifier http middleware handler will verify a JWT string from a http request.
//
// Verifier will search for a JWT token in a http request, in the order:
Expand Down Expand Up @@ -120,6 +137,10 @@ func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) {
}

func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
if ja.keySet != nil {
return nil, "", fmt.Errorf("encode not supported")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm working on rebasing and adding more documentation. I don't understand why encode isn't supported in this case and I think we need a better error message for it. Can you explain what this is for?

Copy link
Copy Markdown
Author

@davidallendj davidallendj Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's the wrong error message probably from some copy-pasta. I think the message should just reflect that the JWKS isn't set so nothing can be encoded here.

Edit: This may have been something added by the original author, so I'm not 100% sure of the intent here, so I'm speculating.

Copy link
Copy Markdown

@alexlovelltroy alexlovelltroy Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on my understanding, I'm considering updating Encode like this:

// Encode generates a JWT token string with the provided claims.
// It returns the encoded token as a string, along with the token object and any error encountered.
// If the JWTAuth instance has a key set, encoding is not supported and an error is returned.
func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
	if ja.keySet != nil {
		return nil, "", fmt.Errorf("encoding is not supported with key set")
	}

	t = jwt.New()
	for k, v := range claims {
		if err := t.Set(k, v); err != nil {
			return nil, "", err
		}
	}
	payload, err := ja.sign(t)
	if err != nil {
		return nil, "", err
	}
	tokenString = string(payload)
	return
}

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks fine to me. I was reading the code wrong and was thinking the ja.keySet was an error. I'll have to go back and look at the entire PR to recall the context why this was done.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm refactoring it a bit now that I understand it better.

}

t = jwt.New()
for k, v := range claims {
t.Set(k, v)
Expand Down
164 changes: 164 additions & 0 deletions jwtauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"testing"
"time"

"github.com/lestrrat-go/jwx/v2/jws"

"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
Expand Down Expand Up @@ -41,6 +43,27 @@ MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw
DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ==
-----END PUBLIC KEY-----
`

KeySet = `{
"keys": [
{
"kty": "RSA",
"n": "vGjc8KMXDhCOA5fTpAIkgkGddc2IRjAMvHFrn_tDIfrLvucJFDInfHdTAX2tQPREKyniw11fmQ5D09TIfI60JQ",
"e": "AQAB",
"alg": "RS256",
"kid": "1",
"use": "sig"
},
{
"kty": "RSA",
"n": "foo",
"e": "AQAB",
"alg": "RS256",
"kid": "2",
"use": "sig"
}
]
}`
)

func init() {
Expand All @@ -51,6 +74,57 @@ func init() {
// Tests
//

func TestNewKeySet(t *testing.T) {
_, err := jwtauth.NewKeySet([]byte("not a valid key set"))
if err == nil {
t.Fatal("The error should not be nil")
}

_, err = jwtauth.NewKeySet([]byte(KeySet))
if err != nil {
t.Fatalf(err.Error())
}
}

func TestKeySetRSA(t *testing.T) {
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))

privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)

if err != nil {
t.Fatalf(err.Error())
}

KeySetAuth, _ := jwtauth.NewKeySet([]byte(KeySet))
claims := map[string]interface{}{
"key": "val",
"key2": "val2",
"key3": "val3",
}

signed := newJwtRSAToken(jwa.RS256, privateKey, "1", claims)

token, err := KeySetAuth.Decode(signed)

if err != nil {
t.Fatalf("Failed to decode token string %s\n", err.Error())
}

tokenClaims, err := token.AsMap(context.Background())
if err != nil {
t.Fatal(err.Error())
}

if !reflect.DeepEqual(claims, tokenClaims) {
t.Fatalf("The decoded claims don't match the original ones\n")
}

_, _, err = KeySetAuth.Encode(claims)
if err.Error() != "encode not supported" {
t.Fatalf("Expect error to equal %s. Found: %s.", "encode not supported", err.Error())
}
}

func TestSimple(t *testing.T) {
r := chi.NewRouter()

Expand Down Expand Up @@ -279,6 +353,73 @@ func TestMore(t *testing.T) {
}
}

func TestKeySet(t *testing.T) {
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
if err != nil {
t.Fatalf(err.Error())
}

r := chi.NewRouter()

keySet, err := jwtauth.NewKeySet([]byte(KeySet))
if err != nil {
t.Fatalf(err.Error())
}

// Protected routes
r.Group(func(r chi.Router) {
r.Use(jwtauth.Verifier(keySet))

authenticator := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, _, err := jwtauth.FromContext(r.Context())

if err != nil {
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
return
}

if err := jwt.Validate(token); err != nil {
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
return
}

// Token is authenticated, pass it through
next.ServeHTTP(w, r)
})
}
r.Use(authenticator)

r.Get("/admin", func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())

if err != nil {
w.Write([]byte(fmt.Sprintf("error! %v", err)))
return
}

w.Write([]byte(fmt.Sprintf("protected, user:%v", claims["user_id"])))
})
})

// Public routes
r.Group(func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("welcome"))
})
})

ts := httptest.NewServer(r)
defer ts.Close()

h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtRSAToken(jwa.RS256, privateKey, "1", map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
t.Fatalf(resp)
}
}

//
// Test helper functions
//
Expand Down Expand Up @@ -340,6 +481,29 @@ func newJwt512Token(secret []byte, claims ...map[string]interface{}) string {
return string(tokenPayload)
}

func newJwtRSAToken(alg jwa.SignatureAlgorithm, secret interface{}, kid string, claims ...map[string]interface{}) string {
token := jwt.New()
if len(claims) > 0 {
for k, v := range claims[0] {
token.Set(k, v)
}
}

headers := jws.NewHeaders()
if kid != "" {
err := headers.Set("kid", kid)
if err != nil {
log.Fatal(err)
}
}

tokenPayload, err := jwt.Sign(token, jwt.WithKey(alg, secret, jws.WithProtectedHeaders(headers)))
if err != nil {
log.Fatal(err)
}
return string(tokenPayload)
}

func newAuthHeader(claims ...map[string]interface{}) http.Header {
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
Expand Down