diff --git a/go.mod b/go.mod index d9239e49194..04bdc15d210 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/mattn/go-isatty v0.0.16 github.com/mattn/go-runewidth v0.0.14 github.com/valyala/fasthttp v1.41.0 + golang.org/x/sync v0.10.0 golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab ) diff --git a/go.sum b/go.sum index b84535bd1f3..bb65ffb54b2 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7Fw golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220906165146-f3363e06e74c/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index f6db49e2cf1..993b31b734b 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -11,21 +11,33 @@ import ( "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/utils" + "golang.org/x/sync/singleflight" ) -// timestampUpdatePeriod is the period which is used to check the cache expiration. -// It should not be too long to provide more or less acceptable expiration error, and in the same -// time it should not be too short to avoid overwhelming of the system +// timestampUpdatePeriod is the period that is used to check the cache expiration. +// It should not be too long to provide more or less acceptable expiration error, and, +// at the same time, it should not be too short to avoid overwhelming the system. const timestampUpdatePeriod = 300 * time.Millisecond +// loadResult holds the response data returned from a singleflight load so waiters +// can apply it to their context without running the handler. +type loadResult struct { + Body []byte + Status int + Ctype []byte + Cencoding []byte + Headers map[string][]byte + Exp uint64 +} + // cache status -// unreachable: when cache is bypass, or invalid -// hit: cache is served -// miss: do not have cache record const ( + // cacheUnreachable: when cache was bypassed or is invalid cacheUnreachable = "unreachable" - cacheHit = "hit" - cacheMiss = "miss" + // cacheHit: cache served + cacheHit = "hit" + // cacheMiss: no cache record for the given key + cacheMiss = "miss" ) // directives @@ -43,11 +55,14 @@ var ignoreHeaders = map[string]interface{}{ "Trailers": nil, "Transfer-Encoding": nil, "Upgrade": nil, - "Content-Type": nil, // already stored explicitely by the cache manager - "Content-Encoding": nil, // already stored explicitely by the cache manager + "Content-Type": nil, // already stored explicitly by the cache manager + "Content-Encoding": nil, // already stored explicitly by the cache manager } -// New creates a new middleware handler +// New creates a new middleware handler. When Config.SingleFlight is true, concurrent +// cache misses for the same key are coalesced (single-flight): only one request runs +// the handler and populates the cache; others wait and share the result, preventing +// cache stampede. Recommend SingleFlight: true for high-concurrency deployments. func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) @@ -63,12 +78,13 @@ func New(config ...Config) fiber.Handler { // Cache settings mux = &sync.RWMutex{} timestamp = uint64(time.Now().Unix()) + sf singleflight.Group ) - // Create manager to simplify storage operations ( see manager.go ) + // Create a manager to simplify storage operations ( see manager.go ) manager := newManager(cfg.Storage) - // Create indexed heap for tracking expirations ( see heap.go ) + // Create an indexed heap to track expirations ( see heap.go ) heap := &indexedHeap{} - // count stored bytes (sizes of response bodies) + // Count bytes stored (sizes of response bodies) var storedBytes uint = 0 // Update timestamp in the configured interval @@ -79,22 +95,24 @@ func New(config ...Config) fiber.Handler { } }() - // Delete key from both manager and storage + // Delete a key from both manager and storage deleteKey := func(dkey string) { manager.delete(dkey) - // External storage saves body data with different key + // External storage saves body data with a different key if cfg.Storage != nil { manager.delete(dkey + "_body") } } - // Return new handler + // Return a new handler return func(c *fiber.Ctx) error { + // ------------------------------------------------------------------------- // Refrain from caching if hasRequestDirective(c, noStore) { return c.Next() } + // ------------------------------------------------------------------------- // Only cache selected methods var isExists bool for _, method := range cfg.Methods { @@ -108,6 +126,7 @@ func New(config ...Config) fiber.Handler { return c.Next() } + // ------------------------------------------------------------------------- // Get key from request // TODO(allocation optimization): try to minimize the allocation from 2 to 1 key := cfg.KeyGenerator(c) + "_" + c.Method() @@ -121,7 +140,7 @@ func New(config ...Config) fiber.Handler { // Get timestamp ts := atomic.LoadUint64(×tamp) - // Check if entry is expired + // Check if entry has expired if e.exp != 0 && ts >= e.exp { deleteKey(key) if cfg.MaxBytes > 0 { @@ -134,6 +153,7 @@ func New(config ...Config) fiber.Handler { if cfg.Storage != nil { e.body = manager.getRaw(key + "_body") } + // Set response headers from cache c.Response().SetBodyRaw(e.body) c.Response().SetStatusCode(e.status) @@ -146,6 +166,7 @@ func New(config ...Config) fiber.Handler { c.Response().Header.SetBytesV(k, v) } } + // Set Cache-Control header if enabled if cfg.CacheControl { maxAge := strconv.FormatUint(e.exp-ts, 10) @@ -163,7 +184,133 @@ func New(config ...Config) fiber.Handler { // make sure we're not blocking concurrent requests - do unlock mux.Unlock() - // Continue stack, return err to Fiber if exist + // ------------------------------------------------------------------------- + // Single-flight path (optional) + // Handle concurrent cache misses (single-flight) -> mitigate cache stampede + if cfg.SingleFlight { + // Single-flight: one request runs the handler and populates cache; others wait and share the result. + v, err, shared := sf.Do(key, func() (any, error) { + if err := c.Next(); err != nil { + return nil, err + } + + // Begin critical section: lock entry and timestamp + mux.Lock() + defer mux.Unlock() + ts := atomic.LoadUint64(×tamp) + e := manager.get(key) + bodySize := uint(len(c.Response().Body())) + + expiration := cfg.Expiration + if cfg.ExpirationGenerator != nil { + expiration = cfg.ExpirationGenerator(c, &cfg) + } + exp := ts + uint64(expiration.Seconds()) + res := loadResult{ + Body: utils.CopyBytes(c.Response().Body()), + Status: c.Response().StatusCode(), + Ctype: utils.CopyBytes(c.Response().Header.ContentType()), + Cencoding: utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding)), + Exp: exp, + } + + // Store response headers if enabled + if cfg.StoreResponseHeaders { + res.Headers = make(map[string][]byte) + c.Response().Header.VisitAll( + func(k []byte, v []byte) { + keyS := string(k) + if _, ok := ignoreHeaders[keyS]; !ok { + res.Headers[keyS] = utils.CopyBytes(v) + } + }, + ) + } + + // If middleware marks request for bypass, return result without caching. + if cfg.Next != nil && cfg.Next(c) { + return res, nil + } + // Skip caching if body won't fit into cache. + if cfg.MaxBytes > 0 && bodySize > cfg.MaxBytes { + return res, nil + } + // Evict oldest entries if cache is full. + if cfg.MaxBytes > 0 { + for storedBytes+bodySize > cfg.MaxBytes { + removedKey, size := heap.removeFirst() + deleteKey(removedKey) + storedBytes -= size + } + } + + // Overwrite pool entry with the new result. + e.body = res.Body + e.status = res.Status + e.ctype = res.Ctype + e.cencoding = res.Cencoding + e.headers = res.Headers + e.exp = res.Exp + + // Update cache size tracking if enabled. + if cfg.MaxBytes > 0 { + e.heapidx = heap.put(key, e.exp, bodySize) + storedBytes += bodySize + } + + // Store entry in external storage if enabled. + if cfg.Storage != nil { + manager.setRaw(key+"_body", e.body, expiration) + // Avoid body msgp encoding. + e.body = nil + manager.set(key, e, expiration) + manager.release(e) + } else { + // Store entry in memory. + manager.set(key, e, expiration) + } + return res, nil + }) + if err != nil { + return err + } + + // If result was shared (other request already populated cache), apply it to our context. + if shared { + // Waiter: apply shared result to our context + res := v.(loadResult) + c.Response().SetBodyRaw(res.Body) + c.Response().SetStatusCode(res.Status) + c.Response().Header.SetContentTypeBytes(res.Ctype) + + // Set content encoding if defined. + if len(res.Cencoding) > 0 { + c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, res.Cencoding) + } + + // Pass headers if defined. + if res.Headers != nil { + for k, v := range res.Headers { + c.Response().Header.SetBytesV(k, v) + } + } + + // Set Cache-Control header if enabled. + if cfg.CacheControl { + ts := atomic.LoadUint64(×tamp) + maxAge := strconv.FormatUint(res.Exp-ts, 10) + c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge) + } + } + + // Set cache status header. + c.Set(cfg.CacheHeader, cacheMiss) + return nil + } + + // Otherwise, the default non-single-flight path. + + // Continue stack, return err to Fiber if exists if err := c.Next(); err != nil { return err } @@ -248,7 +395,7 @@ func New(config ...Config) fiber.Handler { } } -// Check if request has directive +// Check if request has a directive. func hasRequestDirective(c *fiber.Ctx, directive string) bool { return strings.Contains(c.Get(fiber.HeaderCacheControl), directive) } diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index e2610238207..d9ec1ded8aa 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -11,6 +11,8 @@ import ( "net/http/httptest" "os" "strconv" + "sync" + "sync/atomic" "testing" "time" @@ -108,6 +110,77 @@ func Test_Cache(t *testing.T) { utils.AssertEqual(t, cachedBody, body) } +// Test_Cache_SingleFlight verifies that with SingleFlight enabled, concurrent +// misses for the same key result in exactly one handler invocation and all +// requesters receive the same response (stampede prevention). +func Test_Cache_SingleFlight(t *testing.T) { + t.Parallel() + + var handlerCalls int64 + app := fiber.New() + app.Use(New(Config{ + Expiration: 10 * time.Second, + SingleFlight: true, + KeyGenerator: func(c *fiber.Ctx) string { return "/singleflight" }, + })) + + app.Get("/singleflight", func(c *fiber.Ctx) error { + n := atomic.AddInt64(&handlerCalls, 1) + return c.SendString(fmt.Sprintf("ok-%d", n)) + }) + + // Cold cache: fire many concurrent requests for the same key. Only one + // handler run should occur; all requesters get the same body. + const concurrency = 50 + var wg sync.WaitGroup + bodies := make([][]byte, concurrency) + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + req := httptest.NewRequest("GET", "/singleflight", nil) + resp, err := app.Test(req) + if err != nil { + t.Errorf("request %d: %v", idx, err) + return + } + body, _ := io.ReadAll(resp.Body) + bodies[idx] = body + }(i) + } + wg.Wait() + + utils.AssertEqual(t, int64(1), atomic.LoadInt64(&handlerCalls), "handler should be invoked exactly once") + expectedBody := []byte("ok-1") + for i := 0; i < concurrency; i++ { + utils.AssertEqual(t, expectedBody, bodies[i], fmt.Sprintf("request %d body", i)) + } +} + +// Test_Cache_DefaultConfig_BackwardsCompatible ensures default config (SingleFlight false) +// keeps existing behavior: no coalescing; existing tests pass unchanged. +func Test_Cache_DefaultConfig_BackwardsCompatible(t *testing.T) { + t.Parallel() + + app := fiber.New() + app.Use(New()) // SingleFlight defaults to false + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("default") + }) + + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + body, _ := io.ReadAll(resp.Body) + utils.AssertEqual(t, []byte("default"), body) + + resp2, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + body2, _ := io.ReadAll(resp2.Body) + utils.AssertEqual(t, []byte("default"), body2) + utils.AssertEqual(t, cacheHit, resp2.Header.Get("X-Cache")) +} + // go test -run Test_Cache_WithNoCacheRequestDirective func Test_Cache_WithNoCacheRequestDirective(t *testing.T) { t.Parallel() diff --git a/middleware/cache/config.go b/middleware/cache/config.go index 12f81e2ae8d..b62cad2e386 100644 --- a/middleware/cache/config.go +++ b/middleware/cache/config.go @@ -72,6 +72,20 @@ type Config struct { // // Default: []string{fiber.MethodGet, fiber.MethodHead} Methods []string + + // SingleFlight, when true, prevents cache stampede by coalescing concurrent + // misses for the same key: only one request runs the handler and populates + // the cache; others wait and share the result. Recommend true for high-concurrency setups. + // + // Default: false + SingleFlight bool + + // StaleWhileRevalidate, when > 0, allows serving stale responses for expired + // entries while one revalidation runs (one handler run per key). 0 disables. + // Full stale-while-revalidate behavior may be added in a follow-up. + // + // Default: 0 + StaleWhileRevalidate time.Duration } // ConfigDefault is the default config @@ -88,6 +102,8 @@ var ConfigDefault = Config{ Storage: nil, MaxBytes: 0, Methods: []string{fiber.MethodGet, fiber.MethodHead}, + SingleFlight: false, + StaleWhileRevalidate: 0, } // Helper function to set default values @@ -124,5 +140,6 @@ func configDefault(config ...Config) Config { if len(cfg.Methods) == 0 { cfg.Methods = ConfigDefault.Methods } + // SingleFlight and StaleWhileRevalidate have zero-value defaults; no need to set return cfg }