Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/mattn/go-isatty v0.0.16
github.com/mattn/go-runewidth v0.0.14
github.com/valyala/fasthttp v1.41.0
golang.org/x/sync v0.10.0
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7Fw
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220906165146-f3363e06e74c/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand Down
187 changes: 167 additions & 20 deletions middleware/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,33 @@ import (

"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"golang.org/x/sync/singleflight"
)

// timestampUpdatePeriod is the period which is used to check the cache expiration.
// It should not be too long to provide more or less acceptable expiration error, and in the same
// time it should not be too short to avoid overwhelming of the system
// timestampUpdatePeriod is the period that is used to check the cache expiration.
// It should not be too long to provide more or less acceptable expiration error, and,
// at the same time, it should not be too short to avoid overwhelming the system.
const timestampUpdatePeriod = 300 * time.Millisecond

// loadResult holds the response data returned from a singleflight load so waiters
// can apply it to their context without running the handler.
type loadResult struct {
Body []byte
Status int
Ctype []byte
Cencoding []byte
Headers map[string][]byte
Exp uint64
}

// cache status
// unreachable: when cache is bypass, or invalid
// hit: cache is served
// miss: do not have cache record
const (
// cacheUnreachable: when cache was bypassed or is invalid
cacheUnreachable = "unreachable"
cacheHit = "hit"
cacheMiss = "miss"
// cacheHit: cache served
cacheHit = "hit"
// cacheMiss: no cache record for the given key
cacheMiss = "miss"
)

// directives
Expand All @@ -43,11 +55,14 @@ var ignoreHeaders = map[string]interface{}{
"Trailers": nil,
"Transfer-Encoding": nil,
"Upgrade": nil,
"Content-Type": nil, // already stored explicitely by the cache manager
"Content-Encoding": nil, // already stored explicitely by the cache manager
"Content-Type": nil, // already stored explicitly by the cache manager
"Content-Encoding": nil, // already stored explicitly by the cache manager
}

// New creates a new middleware handler
// New creates a new middleware handler. When Config.SingleFlight is true, concurrent
// cache misses for the same key are coalesced (single-flight): only one request runs
// the handler and populates the cache; others wait and share the result, preventing
// cache stampede. Recommend SingleFlight: true for high-concurrency deployments.
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
Expand All @@ -63,12 +78,13 @@ func New(config ...Config) fiber.Handler {
// Cache settings
mux = &sync.RWMutex{}
timestamp = uint64(time.Now().Unix())
sf singleflight.Group
)
// Create manager to simplify storage operations ( see manager.go )
// Create a manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Create indexed heap for tracking expirations ( see heap.go )
// Create an indexed heap to track expirations ( see heap.go )
heap := &indexedHeap{}
// count stored bytes (sizes of response bodies)
// Count bytes stored (sizes of response bodies)
var storedBytes uint = 0

// Update timestamp in the configured interval
Expand All @@ -79,22 +95,24 @@ func New(config ...Config) fiber.Handler {
}
}()

// Delete key from both manager and storage
// Delete a key from both manager and storage
deleteKey := func(dkey string) {
manager.delete(dkey)
// External storage saves body data with different key
// External storage saves body data with a different key
if cfg.Storage != nil {
manager.delete(dkey + "_body")
}
}

// Return new handler
// Return a new handler
return func(c *fiber.Ctx) error {
// -------------------------------------------------------------------------
// Refrain from caching
if hasRequestDirective(c, noStore) {
return c.Next()
}

// -------------------------------------------------------------------------
// Only cache selected methods
var isExists bool
for _, method := range cfg.Methods {
Expand All @@ -108,6 +126,7 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}

// -------------------------------------------------------------------------
// Get key from request
// TODO(allocation optimization): try to minimize the allocation from 2 to 1
key := cfg.KeyGenerator(c) + "_" + c.Method()
Expand All @@ -121,7 +140,7 @@ func New(config ...Config) fiber.Handler {
// Get timestamp
ts := atomic.LoadUint64(&timestamp)

// Check if entry is expired
// Check if entry has expired
if e.exp != 0 && ts >= e.exp {
deleteKey(key)
if cfg.MaxBytes > 0 {
Expand All @@ -134,6 +153,7 @@ func New(config ...Config) fiber.Handler {
if cfg.Storage != nil {
e.body = manager.getRaw(key + "_body")
}

// Set response headers from cache
c.Response().SetBodyRaw(e.body)
c.Response().SetStatusCode(e.status)
Expand All @@ -146,6 +166,7 @@ func New(config ...Config) fiber.Handler {
c.Response().Header.SetBytesV(k, v)
}
}

// Set Cache-Control header if enabled
if cfg.CacheControl {
maxAge := strconv.FormatUint(e.exp-ts, 10)
Expand All @@ -163,7 +184,133 @@ func New(config ...Config) fiber.Handler {
// make sure we're not blocking concurrent requests - do unlock
mux.Unlock()

// Continue stack, return err to Fiber if exist
// -------------------------------------------------------------------------
// Single-flight path (optional)
// Handle concurrent cache misses (single-flight) -> mitigate cache stampede
if cfg.SingleFlight {
// Single-flight: one request runs the handler and populates cache; others wait and share the result.
v, err, shared := sf.Do(key, func() (any, error) {
if err := c.Next(); err != nil {
return nil, err
}

// Begin critical section: lock entry and timestamp
mux.Lock()
defer mux.Unlock()
ts := atomic.LoadUint64(&timestamp)
e := manager.get(key)
bodySize := uint(len(c.Response().Body()))

expiration := cfg.Expiration
if cfg.ExpirationGenerator != nil {
expiration = cfg.ExpirationGenerator(c, &cfg)
}
exp := ts + uint64(expiration.Seconds())
res := loadResult{
Body: utils.CopyBytes(c.Response().Body()),
Status: c.Response().StatusCode(),
Ctype: utils.CopyBytes(c.Response().Header.ContentType()),
Cencoding: utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding)),
Exp: exp,
}

// Store response headers if enabled
if cfg.StoreResponseHeaders {
res.Headers = make(map[string][]byte)
c.Response().Header.VisitAll(
func(k []byte, v []byte) {
keyS := string(k)
if _, ok := ignoreHeaders[keyS]; !ok {
res.Headers[keyS] = utils.CopyBytes(v)
}
},
)
}

// If middleware marks request for bypass, return result without caching.
if cfg.Next != nil && cfg.Next(c) {
return res, nil
}
// Skip caching if body won't fit into cache.
if cfg.MaxBytes > 0 && bodySize > cfg.MaxBytes {
return res, nil
}
// Evict oldest entries if cache is full.
if cfg.MaxBytes > 0 {
for storedBytes+bodySize > cfg.MaxBytes {
removedKey, size := heap.removeFirst()
deleteKey(removedKey)
storedBytes -= size
}
}

// Overwrite pool entry with the new result.
e.body = res.Body
e.status = res.Status
e.ctype = res.Ctype
e.cencoding = res.Cencoding
e.headers = res.Headers
e.exp = res.Exp

// Update cache size tracking if enabled.
if cfg.MaxBytes > 0 {
e.heapidx = heap.put(key, e.exp, bodySize)
storedBytes += bodySize
}

// Store entry in external storage if enabled.
if cfg.Storage != nil {
manager.setRaw(key+"_body", e.body, expiration)
// Avoid body msgp encoding.
e.body = nil
manager.set(key, e, expiration)
manager.release(e)
} else {
// Store entry in memory.
manager.set(key, e, expiration)
}
return res, nil
})
if err != nil {
return err
}

// If result was shared (other request already populated cache), apply it to our context.
if shared {
// Waiter: apply shared result to our context
res := v.(loadResult)
c.Response().SetBodyRaw(res.Body)
c.Response().SetStatusCode(res.Status)
c.Response().Header.SetContentTypeBytes(res.Ctype)

// Set content encoding if defined.
if len(res.Cencoding) > 0 {
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, res.Cencoding)
}

// Pass headers if defined.
if res.Headers != nil {
for k, v := range res.Headers {
c.Response().Header.SetBytesV(k, v)
}
}

// Set Cache-Control header if enabled.
if cfg.CacheControl {
ts := atomic.LoadUint64(&timestamp)
maxAge := strconv.FormatUint(res.Exp-ts, 10)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
}
}

// Set cache status header.
c.Set(cfg.CacheHeader, cacheMiss)
return nil
}

// Otherwise, the default non-single-flight path.

// Continue stack, return err to Fiber if exists
if err := c.Next(); err != nil {
return err
}
Expand Down Expand Up @@ -248,7 +395,7 @@ func New(config ...Config) fiber.Handler {
}
}

// Check if request has directive
// Check if request has a directive.
func hasRequestDirective(c *fiber.Ctx, directive string) bool {
return strings.Contains(c.Get(fiber.HeaderCacheControl), directive)
}
73 changes: 73 additions & 0 deletions middleware/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"net/http/httptest"
"os"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -108,6 +110,77 @@ func Test_Cache(t *testing.T) {
utils.AssertEqual(t, cachedBody, body)
}

// Test_Cache_SingleFlight verifies that with SingleFlight enabled, concurrent
// misses for the same key result in exactly one handler invocation and all
// requesters receive the same response (stampede prevention).
func Test_Cache_SingleFlight(t *testing.T) {
t.Parallel()

var handlerCalls int64
app := fiber.New()
app.Use(New(Config{
Expiration: 10 * time.Second,
SingleFlight: true,
KeyGenerator: func(c *fiber.Ctx) string { return "/singleflight" },
}))

app.Get("/singleflight", func(c *fiber.Ctx) error {
n := atomic.AddInt64(&handlerCalls, 1)
return c.SendString(fmt.Sprintf("ok-%d", n))
})

// Cold cache: fire many concurrent requests for the same key. Only one
// handler run should occur; all requesters get the same body.
const concurrency = 50
var wg sync.WaitGroup
bodies := make([][]byte, concurrency)
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
req := httptest.NewRequest("GET", "/singleflight", nil)
resp, err := app.Test(req)
if err != nil {
t.Errorf("request %d: %v", idx, err)
return
}
body, _ := io.ReadAll(resp.Body)
bodies[idx] = body
}(i)
}
wg.Wait()

utils.AssertEqual(t, int64(1), atomic.LoadInt64(&handlerCalls), "handler should be invoked exactly once")
expectedBody := []byte("ok-1")
for i := 0; i < concurrency; i++ {
utils.AssertEqual(t, expectedBody, bodies[i], fmt.Sprintf("request %d body", i))
}
}

// Test_Cache_DefaultConfig_BackwardsCompatible ensures default config (SingleFlight false)
// keeps existing behavior: no coalescing; existing tests pass unchanged.
func Test_Cache_DefaultConfig_BackwardsCompatible(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New()) // SingleFlight defaults to false

app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("default")
})

resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)
body, _ := io.ReadAll(resp.Body)
utils.AssertEqual(t, []byte("default"), body)

resp2, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)
body2, _ := io.ReadAll(resp2.Body)
utils.AssertEqual(t, []byte("default"), body2)
utils.AssertEqual(t, cacheHit, resp2.Header.Get("X-Cache"))
}

// go test -run Test_Cache_WithNoCacheRequestDirective
func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
t.Parallel()
Expand Down
Loading