Skip to content

Commit 4177ab4

Browse files
vhespanhagaby
andauthored
🔥 feat: Add support for context.Context in keyauth middleware (#3287)
* feat(middleware): add support to context.Context in keyauth middleware pretty straightforward option to use context.Context instead of just fiber.Ctx, tests added accordingly. * fix(middleware): include import that was missing from previous commit * fix(middleware): include missing import * Replace logger with panic * Update keyauth_test.go * Update keyauth_test.go --------- Co-authored-by: Juan Calderon-Perez <[email protected]>
1 parent 208b9e3 commit 4177ab4

File tree

2 files changed

+75
-29
lines changed

2 files changed

+75
-29
lines changed

middleware/keyauth/keyauth.go

+17-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package keyauth
33

44
import (
5+
"context"
56
"errors"
67
"fmt"
78
"net/url"
@@ -59,7 +60,10 @@ func New(config ...Config) fiber.Handler {
5960
valid, err := cfg.Validator(c, key)
6061

6162
if err == nil && valid {
63+
// Store in both Locals and Context
6264
c.Locals(tokenKey, key)
65+
ctx := context.WithValue(c.Context(), tokenKey, key)
66+
c.SetContext(ctx)
6367
return cfg.SuccessHandler(c)
6468
}
6569
return cfg.ErrorHandler(c, err)
@@ -68,12 +72,20 @@ func New(config ...Config) fiber.Handler {
6872

6973
// TokenFromContext returns the bearer token from the request context.
7074
// returns an empty string if the token does not exist
71-
func TokenFromContext(c fiber.Ctx) string {
72-
token, ok := c.Locals(tokenKey).(string)
73-
if !ok {
74-
return ""
75+
func TokenFromContext(c any) string {
76+
switch ctx := c.(type) {
77+
case context.Context:
78+
if token, ok := ctx.Value(tokenKey).(string); ok {
79+
return token
80+
}
81+
case fiber.Ctx:
82+
if token, ok := ctx.Locals(tokenKey).(string); ok {
83+
return token
84+
}
85+
default:
86+
panic("unsupported context type, expected fiber.Ctx or context.Context")
7587
}
76-
return token
88+
return ""
7789
}
7890

7991
// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found

middleware/keyauth/keyauth_test.go

+58-24
Original file line numberDiff line numberDiff line change
@@ -503,33 +503,67 @@ func Test_TokenFromContext_None(t *testing.T) {
503503
}
504504

505505
func Test_TokenFromContext(t *testing.T) {
506-
app := fiber.New()
507-
// Wire up keyauth middleware to set TokenFromContext now
508-
app.Use(New(Config{
509-
KeyLookup: "header:Authorization",
510-
AuthScheme: "Basic",
511-
Validator: func(_ fiber.Ctx, key string) (bool, error) {
512-
if key == CorrectKey {
513-
return true, nil
514-
}
515-
return false, ErrMissingOrMalformedAPIKey
516-
},
517-
}))
518-
// Define a test handler that checks TokenFromContext
519-
app.Get("/", func(c fiber.Ctx) error {
520-
return c.SendString(TokenFromContext(c))
506+
// Test that TokenFromContext returns the correct token
507+
t.Run("fiber.Ctx", func(t *testing.T) {
508+
app := fiber.New()
509+
app.Use(New(Config{
510+
KeyLookup: "header:Authorization",
511+
AuthScheme: "Basic",
512+
Validator: func(_ fiber.Ctx, key string) (bool, error) {
513+
if key == CorrectKey {
514+
return true, nil
515+
}
516+
return false, ErrMissingOrMalformedAPIKey
517+
},
518+
}))
519+
app.Get("/", func(c fiber.Ctx) error {
520+
return c.SendString(TokenFromContext(c))
521+
})
522+
523+
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
524+
req.Header.Add("Authorization", "Basic "+CorrectKey)
525+
res, err := app.Test(req)
526+
require.NoError(t, err)
527+
528+
body, err := io.ReadAll(res.Body)
529+
require.NoError(t, err)
530+
require.Equal(t, CorrectKey, string(body))
521531
})
522532

523-
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
524-
req.Header.Add("Authorization", "Basic "+CorrectKey)
525-
// Send
526-
res, err := app.Test(req)
527-
require.NoError(t, err)
533+
t.Run("context.Context", func(t *testing.T) {
534+
app := fiber.New()
535+
app.Use(New(Config{
536+
KeyLookup: "header:Authorization",
537+
AuthScheme: "Basic",
538+
Validator: func(_ fiber.Ctx, key string) (bool, error) {
539+
if key == CorrectKey {
540+
return true, nil
541+
}
542+
return false, ErrMissingOrMalformedAPIKey
543+
},
544+
}))
545+
// Verify that TokenFromContext works with context.Context
546+
app.Get("/", func(c fiber.Ctx) error {
547+
ctx := c.Context()
548+
token := TokenFromContext(ctx)
549+
return c.SendString(token)
550+
})
528551

529-
// Read the response body into a string
530-
body, err := io.ReadAll(res.Body)
531-
require.NoError(t, err)
532-
require.Equal(t, CorrectKey, string(body))
552+
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
553+
req.Header.Add("Authorization", "Basic "+CorrectKey)
554+
res, err := app.Test(req)
555+
require.NoError(t, err)
556+
557+
body, err := io.ReadAll(res.Body)
558+
require.NoError(t, err)
559+
require.Equal(t, CorrectKey, string(body))
560+
})
561+
562+
t.Run("invalid context type", func(t *testing.T) {
563+
require.Panics(t, func() {
564+
_ = TokenFromContext("invalid")
565+
})
566+
})
533567
}
534568

535569
func Test_AuthSchemeToken(t *testing.T) {

0 commit comments

Comments
 (0)