Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
name: build
on: [push, pull_request]
name: Test
jobs:
test:
env:
GOPATH: ${{ github.workspace }}
GO111MODULE: on

on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
defaults:
run:
working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}

jobs:
strategy:
matrix:
go-version: [1.22.x,1.23.x,1.24.x]
os: [ubuntu-latest, macos-latest]

build:
name: Build
runs-on: ubuntu-latest
steps:
runs-on: ${{ matrix.os }}

- name: Set up Go 1.x
uses: actions/setup-go@v2
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
go-version: ^1.14
id: go

- name: Check out code into the Go module directory
uses: actions/checkout@v2
path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }}

- name: Get dependencies
run: go get -v -t -d ./...
- name: Install Go
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go-version }}

- name: Build
run: go build -v .
run: |
go build ./...

- name: Test
run: go test -v ./...
run: |
go test -v -race -run= ./...
37 changes: 37 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
SHELL = bash -o pipefail
TEST_FLAGS ?= -v -race

all:
@echo "make <cmd>"
@echo ""
@echo "commands:"
@echo ""
@echo " + Development:"
@echo " - build"
@echo " - test"
@echo " - todo"
@echo " - clean"
@echo ""
@echo ""


##
## Development
##
build:
go build ./...

clean:
go clean -cache -testcache

test: test-clean
GOGC=off go test $(TEST_FLAGS) -run=$(TEST) ./...

test-clean:
GOGC=off go clean -testcache

bench:
@go test -timeout=25m -bench=.

todo:
@git grep TODO -- './*' ':!./vendor/' ':!./Makefile' || :
3 changes: 2 additions & 1 deletion _example/simple/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"log/slog"
"net/http"
"time"

Expand All @@ -18,7 +19,7 @@ func main() {
w.Write([]byte("index"))
})

cached := stampede.Handler(512, 1*time.Second)
cached := stampede.Handler(slog.Default(), 512, 1*time.Second)

r.With(cached).Get("/cached", func(w http.ResponseWriter, r *http.Request) {
// processing..
Expand Down
7 changes: 4 additions & 3 deletions _example/with-key/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"log/slog"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -67,11 +68,11 @@ func main() {
})

// Include anything user specific, e.g. Authorization Token
customKeyFunc := func(r *http.Request) uint64 {
customKeyFunc := func(r *http.Request) (uint64, error) {
token := r.Header.Get("Authorization")
return stampede.StringToHash(r.Method, strings.ToLower(strings.ToLower(token)))
return stampede.StringToHash(r.Method, strings.ToLower(strings.ToLower(token))), nil
}
cached := stampede.HandlerWithKey(512, 1*time.Second, customKeyFunc)
cached := stampede.HandlerWithKey(slog.Default(), 512, 1*time.Second, customKeyFunc)

r.With(cached).Get("/me", func(w http.ResponseWriter, r *http.Request) {
// processing..
Expand Down
95 changes: 49 additions & 46 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,35 @@ package stampede

import (
"bytes"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
)

var stripOutHeaders = []string{
"Access-Control-Allow-Credentials",
"Access-Control-Allow-Headers",
"Access-Control-Allow-Methods",
"Access-Control-Allow-Origin",
"Access-Control-Expose-Headers",
"Access-Control-Max-Age",
"Access-Control-Request-Headers",
"Access-Control-Request-Method",
}

func Handler(cacheSize int, ttl time.Duration, paths ...string) func(next http.Handler) http.Handler {
defaultKeyFunc := func(r *http.Request) uint64 {
func Handler(logger *slog.Logger, cacheSize int, ttl time.Duration, paths ...string) func(next http.Handler) http.Handler {
defaultKeyFunc := func(r *http.Request) (uint64, error) {
// Read the request payload, and then setup buffer for future reader
var err error
var buf []byte
if r.Body != nil {
buf, _ = io.ReadAll(r.Body)
buf, err = io.ReadAll(r.Body)
if err != nil {
return 0, err
}
r.Body = io.NopCloser(bytes.NewBuffer(buf))
}

// Prepare cache key based on request URL path and the request data payload.
key := BytesToHash([]byte(strings.ToLower(r.URL.Path)), buf)
return key
return key, nil
}

return HandlerWithKey(cacheSize, ttl, defaultKeyFunc, paths...)
return HandlerWithKey(logger, cacheSize, ttl, defaultKeyFunc, paths...)
}

func HandlerWithKey(cacheSize int, ttl time.Duration, keyFunc func(r *http.Request) uint64, paths ...string) func(next http.Handler) http.Handler {
func HandlerWithKey(logger *slog.Logger, cacheSize int, ttl time.Duration, keyFunc CacheKeyFunc, paths ...string) func(next http.Handler) http.Handler {
// mapping of url paths that are cacheable by the stampede handler
pathMap := map[string]struct{}{}
for _, path := range paths {
Expand All @@ -51,7 +44,7 @@ func HandlerWithKey(cacheSize int, ttl time.Duration, keyFunc func(r *http.Reque
// executes, and the remaining handlers will use the response from
// the first request. The content thereafter will be cached for up to
// ttl time for subsequent requests for further caching.
h := stampede(cacheSize, ttl, keyFunc)
h := stampede(logger, cacheSize, ttl, keyFunc)

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -74,20 +67,28 @@ func HandlerWithKey(cacheSize int, ttl time.Duration, keyFunc func(r *http.Reque
}
}

func stampede(cacheSize int, ttl time.Duration, keyFunc func(r *http.Request) uint64) func(next http.Handler) http.Handler {
type CacheKeyFunc func(r *http.Request) (uint64, error)

func stampede(logger *slog.Logger, cacheSize int, ttl time.Duration, keyFunc CacheKeyFunc) func(next http.Handler) http.Handler {
cache := NewCacheKV[uint64, responseValue](cacheSize, ttl, ttl*2)

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

// cache key for the request
key := keyFunc(r)
key, err := keyFunc(r)
if err != nil {
logger.Warn("stampede: fail to compute cache key", "err", err)
next.ServeHTTP(w, r)
return
}

// mark the request that actually processes the response
first := false

// process request (single flight)
respVal, err := cache.GetFresh(r.Context(), key, func() (responseValue, error) {
// process request (single flight) – this will block all subsequent requests
// until the first request is processed
cachedVal, err := cache.GetFresh(r.Context(), key, func() (responseValue, error) {
first = true
buf := bytes.NewBuffer(nil)
ww := &responseWriter{ResponseWriter: w, tee: buf}
Expand All @@ -101,7 +102,7 @@ func stampede(cacheSize int, ttl time.Duration, keyFunc func(r *http.Request) ui

// the handler may not write header and body in some logic,
// while writing only the body, an attempt is made to write the default header (http.StatusOK)
skip: ww.IsHeaderWrong(),
skip: !ww.IsValid(),
}
return val, nil
})
Expand All @@ -112,35 +113,37 @@ func stampede(cacheSize int, ttl time.Duration, keyFunc func(r *http.Request) ui
return
}

// handle response for other listeners
// handle response for subsequent requests
if err != nil {
// TODO: perhaps just log error and execute standard handler..?
panic(fmt.Sprintf("stampede: fail to get value, %v", err))
logger.Error("stampede: fail to get value, serving standard request handler", "err", err)
next.ServeHTTP(w, r)
return
}

if respVal.skip {
// if the handler did not write a header, then serve the next handler
// a standard request handler
if cachedVal.skip {
next.ServeHTTP(w, r)
return
}

header := w.Header()

nextHeader:
for k := range respVal.headers {
for _, match := range stripOutHeaders {
// Prevent any header in stripOutHeaders to override the current
// value of that header. This is important when you don't want a
// header to affect all subsequent requests (for instance, when
// working with several CORS domains, you don't want the first domain
// to be recorded an to be printed in all responses)
if match == k {
continue nextHeader
}
// copy headers from the first request to the response writer
respHeader := w.Header()
for k, v := range cachedVal.headers {
// Prevent certain headers to override the current
// value of that header. This is important when you don't want a
// header to affect all subsequent requests (for instance, when
// working with several CORS domains, you don't want the first domain
// to be recorded an to be printed in all responses)
headerKey := strings.ToLower(k)
if strings.HasPrefix(headerKey, "access-control-") {
continue
}
header[k] = respVal.headers[k]
respHeader[k] = v
}

w.WriteHeader(respVal.status)
w.Write(respVal.body)
w.WriteHeader(cachedVal.status)
w.Write(cachedVal.body)
})
}
}
Expand Down Expand Up @@ -169,8 +172,8 @@ func (b *responseWriter) WriteHeader(code int) {
}
}

func (b *responseWriter) IsHeaderWrong() bool {
return !b.wroteHeader && (b.code < 100 || b.code > 999)
func (b *responseWriter) IsValid() bool {
return b.wroteHeader && (b.code >= 100 && b.code < 999)
}

func (b *responseWriter) Write(buf []byte) (int, error) {
Expand Down
21 changes: 9 additions & 12 deletions stampede_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"io"
"log"
"log/slog"
"net/http"
"net/http/httptest"
"runtime"
Expand All @@ -26,8 +27,6 @@ func TestGet(t *testing.T) {
// time.Sleep(1 * time.Second)

var wg sync.WaitGroup
numGoroutines := runtime.NumGoroutine()

n := 10
ctx := context.Background()

Expand Down Expand Up @@ -60,8 +59,6 @@ func TestGet(t *testing.T) {

// confirm same before/after num of goroutines
t.Logf("numGoroutines now %d", runtime.NumGoroutine())
assert.Equal(t, numGoroutines, runtime.NumGoroutine())

}
}

Expand Down Expand Up @@ -109,7 +106,7 @@ func TestHandler(t *testing.T) {
})
}

h := stampede.Handler(512, 1*time.Second)
h := stampede.Handler(slog.Default(), 512, 1*time.Second)

ts := httptest.NewServer(counter(recoverer(h(http.HandlerFunc(app)))))
defer ts.Close()
Expand All @@ -122,12 +119,12 @@ func TestHandler(t *testing.T) {
defer wg.Done()
resp, err := http.Get(ts.URL)
if err != nil {
t.Fatal(err)
panic(err)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
panic(err)
}
defer resp.Body.Close()

Expand Down Expand Up @@ -190,7 +187,7 @@ func TestBypassCORSHeaders(t *testing.T) {
atomic.AddUint64(&count, 1)
}

h := stampede.Handler(512, 1*time.Second)
h := stampede.Handler(slog.Default(), 512, 1*time.Second)
c := cors.New(cors.Options{
AllowedOrigins: domains,
AllowedMethods: []string{"GET"},
Expand All @@ -217,12 +214,12 @@ func TestBypassCORSHeaders(t *testing.T) {

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
panic(err)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
panic(err)
}
defer resp.Body.Close()

Expand Down Expand Up @@ -255,9 +252,9 @@ func TestBypassCORSHeaders(t *testing.T) {
}
}

func TestPanic(t *testing.T) {
func TestEmptyHandlerFunc(t *testing.T) {
mux := http.NewServeMux()
middleware := stampede.Handler(100, 1*time.Hour)
middleware := stampede.Handler(slog.Default(), 100, 1*time.Hour)
mux.Handle("/", middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
t.Log(r.Method, r.URL)
})))
Expand Down
Loading