From 811732cd5084599a79724686d5d307a82bc01f12 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Wed, 19 Mar 2025 12:02:24 -0300 Subject: [PATCH] =?UTF-8?q?Revert=20"=F0=9F=94=A5=20feat:=20Add=20support?= =?UTF-8?q?=20for=20context.Context=20in=20keyauth=20middleware=20(#3287)"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 4177ab4086a97648553f34bcff2ff81a137d31f3. --- middleware/keyauth/keyauth.go | 22 ++------ middleware/keyauth/keyauth_test.go | 82 +++++++++--------------------- 2 files changed, 29 insertions(+), 75 deletions(-) diff --git a/middleware/keyauth/keyauth.go b/middleware/keyauth/keyauth.go index 54ecdbe513..e245ba4247 100644 --- a/middleware/keyauth/keyauth.go +++ b/middleware/keyauth/keyauth.go @@ -2,7 +2,6 @@ package keyauth import ( - "context" "errors" "fmt" "net/url" @@ -60,10 +59,7 @@ func New(config ...Config) fiber.Handler { valid, err := cfg.Validator(c, key) if err == nil && valid { - // Store in both Locals and Context c.Locals(tokenKey, key) - ctx := context.WithValue(c.Context(), tokenKey, key) - c.SetContext(ctx) return cfg.SuccessHandler(c) } return cfg.ErrorHandler(c, err) @@ -72,20 +68,12 @@ func New(config ...Config) fiber.Handler { // TokenFromContext returns the bearer token from the request context. // returns an empty string if the token does not exist -func TokenFromContext(c any) string { - switch ctx := c.(type) { - case context.Context: - if token, ok := ctx.Value(tokenKey).(string); ok { - return token - } - case fiber.Ctx: - if token, ok := ctx.Locals(tokenKey).(string); ok { - return token - } - default: - panic("unsupported context type, expected fiber.Ctx or context.Context") +func TokenFromContext(c fiber.Ctx) string { + token, ok := c.Locals(tokenKey).(string) + if !ok { + return "" } - return "" + return token } // MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found diff --git a/middleware/keyauth/keyauth_test.go b/middleware/keyauth/keyauth_test.go index 27c4e5a024..72c9d3c1b4 100644 --- a/middleware/keyauth/keyauth_test.go +++ b/middleware/keyauth/keyauth_test.go @@ -503,67 +503,33 @@ func Test_TokenFromContext_None(t *testing.T) { } func Test_TokenFromContext(t *testing.T) { - // Test that TokenFromContext returns the correct token - t.Run("fiber.Ctx", func(t *testing.T) { - app := fiber.New() - app.Use(New(Config{ - KeyLookup: "header:Authorization", - AuthScheme: "Basic", - Validator: func(_ fiber.Ctx, key string) (bool, error) { - if key == CorrectKey { - return true, nil - } - return false, ErrMissingOrMalformedAPIKey - }, - })) - app.Get("/", func(c fiber.Ctx) error { - return c.SendString(TokenFromContext(c)) - }) - - req := httptest.NewRequest(fiber.MethodGet, "/", nil) - req.Header.Add("Authorization", "Basic "+CorrectKey) - res, err := app.Test(req) - require.NoError(t, err) - - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.Equal(t, CorrectKey, string(body)) + app := fiber.New() + // Wire up keyauth middleware to set TokenFromContext now + app.Use(New(Config{ + KeyLookup: "header:Authorization", + AuthScheme: "Basic", + Validator: func(_ fiber.Ctx, key string) (bool, error) { + if key == CorrectKey { + return true, nil + } + return false, ErrMissingOrMalformedAPIKey + }, + })) + // Define a test handler that checks TokenFromContext + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(TokenFromContext(c)) }) - t.Run("context.Context", func(t *testing.T) { - app := fiber.New() - app.Use(New(Config{ - KeyLookup: "header:Authorization", - AuthScheme: "Basic", - Validator: func(_ fiber.Ctx, key string) (bool, error) { - if key == CorrectKey { - return true, nil - } - return false, ErrMissingOrMalformedAPIKey - }, - })) - // Verify that TokenFromContext works with context.Context - app.Get("/", func(c fiber.Ctx) error { - ctx := c.Context() - token := TokenFromContext(ctx) - return c.SendString(token) - }) - - req := httptest.NewRequest(fiber.MethodGet, "/", nil) - req.Header.Add("Authorization", "Basic "+CorrectKey) - res, err := app.Test(req) - require.NoError(t, err) - - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.Equal(t, CorrectKey, string(body)) - }) + req := httptest.NewRequest(fiber.MethodGet, "/", nil) + req.Header.Add("Authorization", "Basic "+CorrectKey) + // Send + res, err := app.Test(req) + require.NoError(t, err) - t.Run("invalid context type", func(t *testing.T) { - require.Panics(t, func() { - _ = TokenFromContext("invalid") - }) - }) + // Read the response body into a string + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, CorrectKey, string(body)) } func Test_AuthSchemeToken(t *testing.T) {