diff --git a/docs/api/log.md b/docs/api/log.md index b1457caf091..c015f833455 100644 --- a/docs/api/log.md +++ b/docs/api/log.md @@ -198,7 +198,84 @@ commonLogger := log.WithContext(ctx) commonLogger.Info("info") ``` -Context binding adds request-specific data for easier tracing. +Context binding adds request-specific data for easier tracing. The method accepts `fiber.Ctx`, `*fasthttp.RequestCtx`, or `context.Context`. When using standard `context.Context` instances (such as `c.Context()`), enable `PassLocalsToContext` in the app config so that values stored in `fiber.Ctx.Locals` are propagated through the context chain. + +### Automatic Context Fields + +Middleware that stores values in the request context can register extractors so that `log.WithContext` automatically includes those values in every log entry. The following middlewares register extractors when their `New()` constructor is called: + +| Middleware | Log Field | Description | +| ------------ | -------------- | -------------------------------- | +| `requestid` | `request-id` | Request identifier | +| `basicauth` | `username` | Authenticated username | +| `keyauth` | `api-key` | API key token (redacted) | +| `csrf` | `csrf-token` | CSRF token (redacted) | +| `session` | `session-id` | Session identifier (redacted) | + +```go +app.Use(requestid.New()) + +app.Get("/", func(c fiber.Ctx) error { + // Automatically includes request-id= in the log output + log.WithContext(c).Info("processing request") + return c.SendString("OK") +}) +``` + +**Example output:** + +```text +2026/03/17 12:00:00.123456 main.go:15: [Info] request-id=abc-123 processing request +``` + +The context fields (`request-id=abc-123`) are automatically prepended to the log message. You can use multiple middlewares and all their fields will be included: + +```go +app.Use(requestid.New()) +app.Use(basicauth.New(basicauth.Config{ + Users: map[string]string{"admin": "password"}, +})) + +app.Get("/", func(c fiber.Ctx) error { + log.WithContext(c).Info("user action") + return c.SendString("OK") +}) +``` + +**Example output:** + +```text +2026/03/17 12:00:00.123456 main.go:20: [Info] request-id=abc-123 username=admin user action +``` + +:::note +**Context fields and logger compatibility:** + +- Context fields are prepended to the log message in `key=value` format +- The fields are extracted once per log call and added before the message +- Works with any logger that implements the `AllLogger` interface +- For JSON or structured logging, use the `Logw` methods (e.g., `log.WithContext(c).Infow("message", "key", "value")`) which preserve field structure +- Context fields are always included when using `log.WithContext()`, regardless of how many times you call the logger in a handler + +::: + +### Custom Context Extractors + +Use `log.RegisterContextExtractor` to register your own extractors. Each extractor receives the bound context and returns a field name, value, and success flag: + +```go +log.RegisterContextExtractor(func(ctx any) (string, any, bool) { + // Use fiber.ValueFromContext to extract from any supported context type + if traceID, ok := fiber.ValueFromContext[string](ctx, traceIDKey); ok && traceID != "" { + return "trace-id", traceID, true + } + return "", nil, false +}) +``` + +:::note +`RegisterContextExtractor` may be called at any time, including while your application is handling requests and emitting logs. Registrations are safe to perform concurrently with logging. In practice, register extractors during program initialization (e.g. in an `init` function or middleware constructor) so that they are in place before requests are processed. +::: ## Logger diff --git a/docs/whats_new.md b/docs/whats_new.md index 56ee957b69d..a22d67a3a9a 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -1215,6 +1215,24 @@ app.Use(logger.New(logger.Config{ })) ``` +### Context-Aware Logging + +`log.WithContext` now automatically includes context fields extracted by middleware. The method accepts `fiber.Ctx`, `*fasthttp.RequestCtx`, or `context.Context`, making it flexible and consistent with Fiber's context handling throughout the framework. + +Middleware such as `requestid`, `basicauth`, `keyauth`, `csrf`, and `session` register extractors when their `New()` constructor is called. When you pass a context to `log.WithContext`, registered fields are prepended to every log entry. + +```go +app.Use(requestid.New()) + +app.Get("/", func(c fiber.Ctx) error { + // Output: [Info] request-id=abc-123 processing request + log.WithContext(c).Info("processing request") + return c.SendString("OK") +}) +``` + +Custom extractors can be registered via `log.RegisterContextExtractor`. See [Log API docs](./api/log.md#custom-context-extractors) for details. + ## 📦 Storage Interface The storage interface has been updated to include new subset of methods with `WithContext` suffix. These methods allow you to pass a context to the storage operations, enabling better control over timeouts and cancellation if needed. This is particularly useful when storage implementations used outside of the Fiber core, such as in background jobs or long-running tasks. diff --git a/log/default.go b/log/default.go index 971644b0b5f..d7723d3c32c 100644 --- a/log/default.go +++ b/log/default.go @@ -1,7 +1,6 @@ package log import ( - "context" "fmt" "io" "log" @@ -14,11 +13,30 @@ import ( var _ AllLogger[*log.Logger] = (*defaultLogger)(nil) type defaultLogger struct { + ctx any stdlog *log.Logger level Level depth int } +// writeContextFields appends extracted context key-value pairs to buf. +// Each pair is written as "key=value " (trailing space included). +func (l *defaultLogger) writeContextFields(buf *bytebufferpool.ByteBuffer) { + if l.ctx == nil { + return + } + extractors := loadContextExtractors() + for _, extractor := range extractors { + key, value, ok := extractor(l.ctx) + if ok && key != "" { + buf.WriteString(key) + buf.WriteByte('=') + buf.WriteString(utils.ToString(value)) + buf.WriteByte(' ') + } + } +} + // privateLog logs a message at a given level log the default logger. // when the level is fatal, it will exit the program. func (l *defaultLogger) privateLog(lv Level, fmtArgs []any) { @@ -28,6 +46,7 @@ func (l *defaultLogger) privateLog(lv Level, fmtArgs []any) { level := lv.toString() buf := bytebufferpool.Get() buf.WriteString(level) + l.writeContextFields(buf) fmt.Fprint(buf, fmtArgs...) _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error @@ -51,6 +70,7 @@ func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []any) { level := lv.toString() buf := bytebufferpool.Get() buf.WriteString(level) + l.writeContextFields(buf) if len(fmtArgs) > 0 { _, _ = fmt.Fprintf(buf, format, fmtArgs...) @@ -78,8 +98,7 @@ func (l *defaultLogger) privateLogw(lv Level, format string, keysAndValues []any level := lv.toString() buf := bytebufferpool.Get() buf.WriteString(level) - - // Write format privateLog buffer + l.writeContextFields(buf) if format != "" { buf.WriteString(format) } @@ -220,12 +239,16 @@ func (l *defaultLogger) Panicw(msg string, keysAndValues ...any) { l.privateLogw(LevelPanic, msg, keysAndValues) } -// WithContext returns a logger that shares the underlying output but adjusts the call depth. -func (l *defaultLogger) WithContext(_ context.Context) CommonLogger { +// WithContext returns a logger that shares the underlying output but carries +// the provided context. Any registered ContextExtractor functions will be +// called at log time to prepend key-value fields extracted from the context. +// The ctx parameter can be fiber.Ctx, *fasthttp.RequestCtx, or context.Context. +func (l *defaultLogger) WithContext(ctx any) CommonLogger { return &defaultLogger{ stdlog: l.stdlog, level: l.level, depth: l.depth - 1, + ctx: ctx, } } diff --git a/log/default_test.go b/log/default_test.go index 0cd2c60cead..dbfec7cb714 100644 --- a/log/default_test.go +++ b/log/default_test.go @@ -118,6 +118,124 @@ func Test_CtxLogger(t *testing.T) { "[Panic] work panic\n", string(w.b)) } +type testContextKey struct{} + +func Test_WithContextExtractor(t *testing.T) { + // Save and restore global extractors using the mutex for correctness. + contextExtractorsMu.Lock() + saved := contextExtractors + contextExtractors = nil + contextExtractorsMu.Unlock() + + defer func() { + contextExtractorsMu.Lock() + contextExtractors = saved + contextExtractorsMu.Unlock() + }() + + RegisterContextExtractor(func(ctx any) (string, any, bool) { + if ctxTyped, ok := ctx.(context.Context); ok { + if v, ok := ctxTyped.Value(testContextKey{}).(string); ok && v != "" { + return "request-id", v, true + } + } + return "", nil, false + }) + + t.Run("Info with context field", func(t *testing.T) { + var buf bytes.Buffer + l := &defaultLogger{ + stdlog: log.New(&buf, "", 0), + level: LevelTrace, + depth: 4, + } + ctx := context.WithValue(context.Background(), testContextKey{}, "abc-123") + l.WithContext(ctx).Info("hello") + + require.Equal(t, "[Info] request-id=abc-123 hello\n", buf.String()) + }) + + t.Run("Infof with context field", func(t *testing.T) { + var buf bytes.Buffer + l := &defaultLogger{ + stdlog: log.New(&buf, "", 0), + level: LevelTrace, + depth: 4, + } + ctx := context.WithValue(context.Background(), testContextKey{}, "abc-123") + l.WithContext(ctx).Infof("hello %s", "world") + + require.Equal(t, "[Info] request-id=abc-123 hello world\n", buf.String()) + }) + + t.Run("Infow with context field", func(t *testing.T) { + var buf bytes.Buffer + l := &defaultLogger{ + stdlog: log.New(&buf, "", 0), + level: LevelTrace, + depth: 4, + } + ctx := context.WithValue(context.Background(), testContextKey{}, "abc-123") + l.WithContext(ctx).Infow("hello", "key", "value") + + require.Equal(t, "[Info] request-id=abc-123 hello key=value\n", buf.String()) + }) + + t.Run("no context field when value absent", func(t *testing.T) { + var buf bytes.Buffer + l := &defaultLogger{ + stdlog: log.New(&buf, "", 0), + level: LevelTrace, + depth: 4, + } + ctx := context.Background() + l.WithContext(ctx).Info("hello") + + require.Equal(t, "[Info] hello\n", buf.String()) + }) + + t.Run("no context field without WithContext", func(t *testing.T) { + var buf bytes.Buffer + l := &defaultLogger{ + stdlog: log.New(&buf, "", 0), + level: LevelTrace, + depth: 4, + } + l.Info("hello") + + require.Equal(t, "[Info] hello\n", buf.String()) + }) + + t.Run("empty key extractor is skipped", func(t *testing.T) { + // Save and restore extractors for this subtest using the mutex. + contextExtractorsMu.Lock() + savedInner := contextExtractors + contextExtractorsMu.Unlock() + + defer func() { + contextExtractorsMu.Lock() + contextExtractors = savedInner + contextExtractorsMu.Unlock() + }() + + // Add an extractor that returns ok=true but key="" + RegisterContextExtractor(func(_ any) (string, any, bool) { + return "", "should-not-appear", true + }) + + var buf bytes.Buffer + l := &defaultLogger{ + stdlog: log.New(&buf, "", 0), + level: LevelTrace, + depth: 4, + } + ctx := context.WithValue(context.Background(), testContextKey{}, "abc-123") + l.WithContext(ctx).Info("hello") + + require.Equal(t, "[Info] request-id=abc-123 hello\n", buf.String()) + }) +} + func Test_LogfKeyAndValues(t *testing.T) { tests := []struct { name string diff --git a/log/log.go b/log/log.go index df21e0f971d..7dd9a7be46d 100644 --- a/log/log.go +++ b/log/log.go @@ -1,20 +1,65 @@ package log import ( - "context" "fmt" "io" "log" "os" + "sync" ) +// ContextExtractor extracts a key-value pair from the given context for +// inclusion in log output when using WithContext. +// It returns the log field name, its value, and whether extraction succeeded. +// The ctx parameter can be fiber.Ctx, *fasthttp.RequestCtx, or context.Context. +type ContextExtractor func(ctx any) (string, any, bool) + +// contextExtractorsMu guards contextExtractors for concurrent registration +// and snapshot reads. +var contextExtractorsMu sync.RWMutex + +// contextExtractors holds all registered context field extractors. +// Use loadContextExtractors to obtain a safe snapshot for iteration. +var contextExtractors []ContextExtractor + +// loadContextExtractors returns an immutable snapshot of the registered +// extractors. The returned slice must not be modified. +func loadContextExtractors() []ContextExtractor { + contextExtractorsMu.RLock() + snapshot := contextExtractors + contextExtractorsMu.RUnlock() + return snapshot +} + +// RegisterContextExtractor registers a function that extracts a key-value pair +// from context for inclusion in log output when using WithContext. +// +// This function is safe to call concurrently with logging and with other +// registrations. All calls to RegisterContextExtractor should happen during +// program initialization (e.g. in an init function or middleware constructor) +// so that extractors are in place before requests are processed. +func RegisterContextExtractor(extractor ContextExtractor) { + if extractor == nil { + panic("log: RegisterContextExtractor called with nil extractor") + } + contextExtractorsMu.Lock() + // Copy-on-write: always allocate a new backing array so snapshots taken + // by concurrent readers remain stable. + n := len(contextExtractors) + next := make([]ContextExtractor, n+1) + copy(next, contextExtractors) + next[n] = extractor + contextExtractors = next + contextExtractorsMu.Unlock() +} + // baseLogger defines the minimal logger functionality required by the package. // It allows storing any logger implementation regardless of its generic type. type baseLogger interface { CommonLogger SetLevel(Level) SetOutput(io.Writer) - WithContext(ctx context.Context) CommonLogger + WithContext(ctx any) CommonLogger } var logger baseLogger = &defaultLogger{ @@ -84,7 +129,8 @@ type AllLogger[T any] interface { ConfigurableLogger[T] // WithContext returns a new logger with the given context. - WithContext(ctx context.Context) CommonLogger + // The ctx parameter can be fiber.Ctx, *fasthttp.RequestCtx, or context.Context. + WithContext(ctx any) CommonLogger } // Level defines the priority of a log message. diff --git a/middleware/basicauth/basicauth.go b/middleware/basicauth/basicauth.go index e8878e0f9e7..e3748f66791 100644 --- a/middleware/basicauth/basicauth.go +++ b/middleware/basicauth/basicauth.go @@ -4,10 +4,12 @@ import ( "encoding/base64" "errors" "strings" + "sync" "unicode" "unicode/utf8" "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" "golang.org/x/text/unicode/norm" ) @@ -23,11 +25,24 @@ const ( const basicScheme = "Basic" +// registerExtractor ensures the log context extractor for the authenticated +// username is registered exactly once. +var registerExtractor sync.Once + // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) + // Register a log context extractor so that log.WithContext(c) automatically + // includes the authenticated username when basicauth middleware is in use. + registerExtractor.Do(func() { + log.RegisterContextExtractor(func(ctx any) (string, any, bool) { + username := UsernameFromContext(ctx) + return "username", username, username != "" + }) + }) + var cerr base64.CorruptInputError // Return new handler diff --git a/middleware/basicauth/basicauth_test.go b/middleware/basicauth/basicauth_test.go index dc5a0ab8173..a41819b27f1 100644 --- a/middleware/basicauth/basicauth_test.go +++ b/middleware/basicauth/basicauth_test.go @@ -1,6 +1,7 @@ package basicauth import ( + "bytes" "crypto/sha256" "crypto/sha512" "encoding/base64" @@ -9,10 +10,12 @@ import ( "io" "net/http" "net/http/httptest" + "os" "strings" "testing" "github.com/gofiber/fiber/v3" + fiberlog "github.com/gofiber/fiber/v3/log" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "golang.org/x/crypto/bcrypt" @@ -630,3 +633,32 @@ func Test_BasicAuth_HashVariants_Invalid(t *testing.T) { require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) } } + +func Test_BasicAuth_LogWithContext(t *testing.T) { + hashedJohn := sha256Hash("doe") + app := fiber.New() + app.Use(New(Config{ + Users: map[string]string{ + "john": hashedJohn, + }, + })) + + var logOutput bytes.Buffer + fiberlog.SetOutput(&logOutput) + defer fiberlog.SetOutput(os.Stderr) + + app.Get("/", func(c fiber.Ctx) error { + fiberlog.WithContext(c).Info("basicauth test") + return c.SendStatus(fiber.StatusOK) + }) + + creds := base64.StdEncoding.EncodeToString([]byte("john:doe")) + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) + + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + require.Contains(t, logOutput.String(), "username=john") + require.Contains(t, logOutput.String(), "basicauth test") +} diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index 559d05ea331..55e527cc054 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -6,6 +6,7 @@ import ( "net/url" "slices" "strings" + "sync" "time" "github.com/gofiber/utils/v2" @@ -13,6 +14,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" + "github.com/gofiber/fiber/v3/log" ) var ( @@ -46,6 +48,10 @@ const ( handlerKey ) +// registerExtractor ensures the log context extractor for CSRF tokens is +// registered exactly once, regardless of how many times New() is called. +var registerExtractor sync.Once + // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config @@ -53,6 +59,21 @@ func New(config ...Config) fiber.Handler { redactKeys := !cfg.DisableValueRedaction + // Register a log context extractor so that log.WithContext(c) automatically + // includes a redacted CSRF token when the csrf middleware is in use. + // CSRF tokens are always redacted in log output regardless of DisableValueRedaction, + // because they are bearer secrets and must never appear in plain text in logs. + // An empty token (no middleware or middleware skipped) is omitted. + registerExtractor.Do(func() { + log.RegisterContextExtractor(func(ctx any) (string, any, bool) { + token := TokenFromContext(ctx) + if token == "" { + return "", nil, false + } + return "csrf-token", redactedKey, true + }) + }) + maskValue := func(value string) string { if redactKeys { return redactedKey diff --git a/middleware/csrf/csrf_test.go b/middleware/csrf/csrf_test.go index 5637edae570..7946666d7d5 100644 --- a/middleware/csrf/csrf_test.go +++ b/middleware/csrf/csrf_test.go @@ -1,17 +1,20 @@ package csrf import ( + "bytes" "context" "errors" "net" "net/http" "net/http/httptest" + "os" "strings" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" + fiberlog "github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/middleware/session" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/require" @@ -2484,3 +2487,23 @@ func Test_CSRF_Extractors_ErrorTypes(t *testing.T) { }) } } + +func Test_CSRF_LogWithContext(t *testing.T) { + app := fiber.New() + app.Use(New()) + + var logOutput bytes.Buffer + fiberlog.SetOutput(&logOutput) + defer fiberlog.SetOutput(os.Stderr) + + app.Get("/", func(c fiber.Ctx) error { + fiberlog.WithContext(c).Info("csrf test") + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + require.Contains(t, logOutput.String(), "csrf-token=[redacted]") + require.Contains(t, logOutput.String(), "csrf test") +} diff --git a/middleware/keyauth/keyauth.go b/middleware/keyauth/keyauth.go index 109c82d20c1..4eadd50e756 100644 --- a/middleware/keyauth/keyauth.go +++ b/middleware/keyauth/keyauth.go @@ -4,9 +4,11 @@ import ( "errors" "fmt" "strings" + "sync" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) @@ -22,11 +24,28 @@ const ( // ErrMissingOrMalformedAPIKey is returned when the API key is missing or invalid. var ErrMissingOrMalformedAPIKey = errors.New("missing or invalid API Key") +// registerExtractor ensures the log context extractor for API keys is +// registered exactly once, regardless of how many times New() is called. +var registerExtractor sync.Once + // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Init config cfg := configDefault(config...) + // Register a log context extractor so that log.WithContext(c) automatically + // includes a redacted API key when the keyauth middleware is in use. + // An empty token (no middleware or middleware skipped) is omitted. + registerExtractor.Do(func() { + log.RegisterContextExtractor(func(ctx any) (string, any, bool) { + token := TokenFromContext(ctx) + if token == "" { + return "", nil, false + } + return "api-key", redactValue(token), true + }) + }) + // Determine the auth schemes from the extractor chain. authSchemes := getAuthSchemes(cfg.Extractor) @@ -117,3 +136,13 @@ func getAuthSchemes(e extractors.Extractor) []string { } return schemes } + +// redactValue returns a masked version of a sensitive value for safe logging. +// It shows the first 4 characters followed by "****" for values longer than +// 8 characters, or "****" for shorter values. +func redactValue(s string) string { + if len(s) > 8 { + return s[:4] + "****" + } + return "****" +} diff --git a/middleware/keyauth/keyauth_test.go b/middleware/keyauth/keyauth_test.go index 20771e6615f..0842f67f2a1 100644 --- a/middleware/keyauth/keyauth_test.go +++ b/middleware/keyauth/keyauth_test.go @@ -1,12 +1,14 @@ package keyauth import ( + "bytes" "context" "errors" "io" "net/http" "net/http/httptest" "net/url" + "os" "strings" "testing" @@ -15,6 +17,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" + fiberlog "github.com/gofiber/fiber/v3/log" ) const CorrectKey = "correct-token_123./~+" @@ -1130,3 +1133,29 @@ func Test_New_ErrorURIAbsolute(t *testing.T) { }) }) } + +func Test_KeyAuth_LogWithContext(t *testing.T) { + app := fiber.New() + app.Use(New(Config{ + Validator: func(_ fiber.Ctx, key string) (bool, error) { + return key == CorrectKey, nil + }, + })) + + var logOutput bytes.Buffer + fiberlog.SetOutput(&logOutput) + defer fiberlog.SetOutput(os.Stderr) + + app.Get("/", func(c fiber.Ctx) error { + fiberlog.WithContext(c).Info("keyauth test") + return c.SendStatus(fiber.StatusOK) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.Header.Set("Authorization", "Bearer "+CorrectKey) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + require.Contains(t, logOutput.String(), "api-key=corr****") + require.Contains(t, logOutput.String(), "keyauth test") +} diff --git a/middleware/requestid/requestid.go b/middleware/requestid/requestid.go index d998670a5c8..d73e98ad9c5 100644 --- a/middleware/requestid/requestid.go +++ b/middleware/requestid/requestid.go @@ -1,7 +1,10 @@ package requestid import ( + "sync" + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) @@ -14,11 +17,25 @@ const ( requestIDKey contextKey = iota ) +// registerExtractor ensures the log context extractor for request IDs is +// registered exactly once, regardless of how many times New() is called. +var registerExtractor sync.Once + // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) + // Register a log context extractor so that log.WithContext(c) automatically + // includes the request ID when the requestid middleware is in use. + // An empty request ID (no middleware or middleware skipped) is omitted. + registerExtractor.Do(func() { + log.RegisterContextExtractor(func(ctx any) (string, any, bool) { + rid := FromContext(ctx) + return "request-id", rid, rid != "" + }) + }) + // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true diff --git a/middleware/requestid/requestid_test.go b/middleware/requestid/requestid_test.go index e6684d40d0a..b2e71602cd7 100644 --- a/middleware/requestid/requestid_test.go +++ b/middleware/requestid/requestid_test.go @@ -1,11 +1,14 @@ package requestid import ( + "bytes" "net/http" "net/http/httptest" + "os" "testing" "github.com/gofiber/fiber/v3" + fiberlog "github.com/gofiber/fiber/v3/log" "github.com/stretchr/testify/require" ) @@ -233,3 +236,60 @@ func Test_RequestID_FromContext_Types(t *testing.T) { require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } + +// Test_RequestID_LogWithContext_FiberCtx verifies that log.WithContext(c) +// automatically includes the request ID when a fiber.Ctx is passed. +func Test_RequestID_LogWithContext_FiberCtx(t *testing.T) { + reqID := "test-request-id-fiber" + + app := fiber.New() + app.Use(New(Config{ + Generator: func() string { + return reqID + }, + })) + + var logOutput bytes.Buffer + fiberlog.SetOutput(&logOutput) + defer fiberlog.SetOutput(os.Stderr) + + app.Get("/", func(c fiber.Ctx) error { + fiberlog.WithContext(c).Info("hello from handler") + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + require.Contains(t, logOutput.String(), "request-id="+reqID) + require.Contains(t, logOutput.String(), "hello from handler") +} + +// Test_RequestID_LogWithContext_ContextContext verifies that log.WithContext +// works with a context.Context obtained via c.Context() when PassLocalsToContext +// is enabled. +func Test_RequestID_LogWithContext_ContextContext(t *testing.T) { + reqID := "test-request-id-context" + + app := fiber.New(fiber.Config{PassLocalsToContext: true}) + app.Use(New(Config{ + Generator: func() string { + return reqID + }, + })) + + var logOutput bytes.Buffer + fiberlog.SetOutput(&logOutput) + defer fiberlog.SetOutput(os.Stderr) + + app.Get("/", func(c fiber.Ctx) error { + fiberlog.WithContext(c.Context()).Info("hello via context.Context") + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + require.Contains(t, logOutput.String(), "request-id="+reqID) + require.Contains(t, logOutput.String(), "hello via context.Context") +} diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 78c2ef05384..1fb7dca875a 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" ) // Middleware holds session data and configuration. @@ -36,6 +37,10 @@ var ( return &Middleware{} }, } + + // registerExtractor ensures the log context extractor for session IDs is + // registered exactly once. + registerExtractor sync.Once ) // New initializes session middleware with optional configuration. @@ -81,6 +86,24 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { cfg.Store = NewStore(cfg) } + // Register a log context extractor so that log.WithContext(c) automatically + // includes a redacted session ID when the session middleware is in use. + // Session IDs are bearer secrets, so only the first 4 characters are logged + // to enable correlation without exposing the full token. + registerExtractor.Do(func() { + log.RegisterContextExtractor(func(ctx any) (string, any, bool) { + m := FromContext(ctx) + if m == nil || m.Session == nil { + return "", nil, false + } + id := m.Session.ID() + if id == "" { + return "", nil, false + } + return "session-id", redactSessionID(id), true + }) + }) + handler := func(c fiber.Ctx) error { if cfg.Next != nil && cfg.Next(c) { return c.Next() @@ -184,6 +207,18 @@ func FromContext(ctx any) *Middleware { return nil } +// redactSessionID returns a masked version of a session ID for safe logging. +// Session IDs are bearer secrets; for IDs longer than 8 characters, only the +// first 4 characters are retained and the remainder is masked so that log +// entries can still be correlated without exposing the full token. Shorter +// IDs are fully redacted. +func redactSessionID(id string) string { + if len(id) > 8 { + return id[:4] + "****" + } + return "****" +} + // Set sets a key-value pair in the session. // // Parameters: diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index a48b6b1b49c..db3c7c31362 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -1,9 +1,11 @@ package session import ( + "bytes" "fmt" "net/http" "net/http/httptest" + "os" "sort" "strings" "sync" @@ -12,6 +14,7 @@ import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" + fiberlog "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" @@ -590,3 +593,56 @@ func Test_Session_Middleware_Store(t *testing.T) { h(ctx) require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) } + +func Test_Session_LogWithContext(t *testing.T) { + app := fiber.New() + app.Use(New()) + + var logOutput bytes.Buffer + fiberlog.SetOutput(&logOutput) + defer fiberlog.SetOutput(os.Stderr) + + var capturedID string + + app.Get("/", func(c fiber.Ctx) error { + sess := FromContext(c) + require.NotNil(t, sess) + capturedID = sess.Session.ID() + fiberlog.WithContext(c).Info("session test") + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + // Session ID must appear in log but be redacted (only first 4 chars + ****) + logStr := logOutput.String() + expectedRedacted := redactSessionID(capturedID) + require.Contains(t, logStr, "session-id="+expectedRedacted, "redacted session ID must appear in log") + require.Contains(t, logStr, "session test") + require.NotContains(t, logStr, capturedID, "full session ID must not appear in log") +} + +func Test_redactSessionID(t *testing.T) { + t.Parallel() + + t.Run("long ID is redacted", func(t *testing.T) { + t.Parallel() + require.Equal(t, "abcd****", redactSessionID("abcdefghij")) + }) + + t.Run("short ID is fully redacted", func(t *testing.T) { + t.Parallel() + require.Equal(t, "****", redactSessionID("short")) + }) + + t.Run("exactly 8 chars is fully redacted", func(t *testing.T) { + t.Parallel() + require.Equal(t, "****", redactSessionID("12345678")) + }) + + t.Run("empty string is fully redacted", func(t *testing.T) { + t.Parallel() + require.Equal(t, "****", redactSessionID("")) + }) +}