diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fd898d4..3bda270 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,32 +1,37 @@ -name: build +on: [push, pull_request] +name: Test +jobs: + test: + env: + GOPATH: ${{ github.workspace }} + GO111MODULE: on -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] + defaults: + run: + working-directory: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} -jobs: + strategy: + matrix: + go-version: [1.22.x,1.23.x,1.24.x] + os: [ubuntu-latest, macos-latest] - build: - name: Build - runs-on: ubuntu-latest - steps: + runs-on: ${{ matrix.os }} - - name: Set up Go 1.x - uses: actions/setup-go@v2 + steps: + - name: Checkout code + uses: actions/checkout@v4 with: - go-version: ^1.14 - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2 + path: ${{ env.GOPATH }}/src/github.com/${{ github.repository }} - - name: Get dependencies - run: go get -v -t -d ./... + - name: Install Go + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go-version }} - name: Build - run: go build -v . + run: | + go build ./... - name: Test - run: go test -v ./... + run: | + go test -v -race -run= ./... diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..771cda1 --- /dev/null +++ b/Makefile @@ -0,0 +1,37 @@ +SHELL = bash -o pipefail +TEST_FLAGS ?= -v -race + +all: + @echo "make " + @echo "" + @echo "commands:" + @echo "" + @echo " + Development:" + @echo " - build" + @echo " - test" + @echo " - todo" + @echo " - clean" + @echo "" + @echo "" + + +## +## Development +## +build: + go build ./... + +clean: + go clean -cache -testcache + +test: test-clean + GOGC=off go test $(TEST_FLAGS) -run=$(TEST) ./... + +test-clean: + GOGC=off go clean -testcache + +bench: + @go test -timeout=25m -bench=. + +todo: + @git grep TODO -- './*' ':!./vendor/' ':!./Makefile' || : diff --git a/_example/simple/main.go b/_example/simple/main.go index 059ac39..d50cad5 100644 --- a/_example/simple/main.go +++ b/_example/simple/main.go @@ -1,6 +1,7 @@ package main import ( + "log/slog" "net/http" "time" @@ -18,7 +19,7 @@ func main() { w.Write([]byte("index")) }) - cached := stampede.Handler(512, 1*time.Second) + cached := stampede.Handler(slog.Default(), 512, 1*time.Second) r.With(cached).Get("/cached", func(w http.ResponseWriter, r *http.Request) { // processing.. diff --git a/_example/with-key/main.go b/_example/with-key/main.go index 0f49e2f..1305960 100644 --- a/_example/with-key/main.go +++ b/_example/with-key/main.go @@ -1,6 +1,7 @@ package main import ( + "log/slog" "net/http" "strings" "time" @@ -67,11 +68,11 @@ func main() { }) // Include anything user specific, e.g. Authorization Token - customKeyFunc := func(r *http.Request) uint64 { + customKeyFunc := func(r *http.Request) (uint64, error) { token := r.Header.Get("Authorization") - return stampede.StringToHash(r.Method, strings.ToLower(strings.ToLower(token))) + return stampede.StringToHash(r.Method, strings.ToLower(strings.ToLower(token))), nil } - cached := stampede.HandlerWithKey(512, 1*time.Second, customKeyFunc) + cached := stampede.HandlerWithKey(slog.Default(), 512, 1*time.Second, customKeyFunc) r.With(cached).Get("/me", func(w http.ResponseWriter, r *http.Request) { // processing.. diff --git a/http.go b/http.go index 89d2f64..45601fe 100644 --- a/http.go +++ b/http.go @@ -2,42 +2,35 @@ package stampede import ( "bytes" - "fmt" "io" + "log/slog" "net/http" "strings" "time" ) -var stripOutHeaders = []string{ - "Access-Control-Allow-Credentials", - "Access-Control-Allow-Headers", - "Access-Control-Allow-Methods", - "Access-Control-Allow-Origin", - "Access-Control-Expose-Headers", - "Access-Control-Max-Age", - "Access-Control-Request-Headers", - "Access-Control-Request-Method", -} - -func Handler(cacheSize int, ttl time.Duration, paths ...string) func(next http.Handler) http.Handler { - defaultKeyFunc := func(r *http.Request) uint64 { +func Handler(logger *slog.Logger, cacheSize int, ttl time.Duration, paths ...string) func(next http.Handler) http.Handler { + defaultKeyFunc := func(r *http.Request) (uint64, error) { // Read the request payload, and then setup buffer for future reader + var err error var buf []byte if r.Body != nil { - buf, _ = io.ReadAll(r.Body) + buf, err = io.ReadAll(r.Body) + if err != nil { + return 0, err + } r.Body = io.NopCloser(bytes.NewBuffer(buf)) } // Prepare cache key based on request URL path and the request data payload. key := BytesToHash([]byte(strings.ToLower(r.URL.Path)), buf) - return key + return key, nil } - return HandlerWithKey(cacheSize, ttl, defaultKeyFunc, paths...) + return HandlerWithKey(logger, cacheSize, ttl, defaultKeyFunc, paths...) } -func HandlerWithKey(cacheSize int, ttl time.Duration, keyFunc func(r *http.Request) uint64, paths ...string) func(next http.Handler) http.Handler { +func HandlerWithKey(logger *slog.Logger, cacheSize int, ttl time.Duration, keyFunc CacheKeyFunc, paths ...string) func(next http.Handler) http.Handler { // mapping of url paths that are cacheable by the stampede handler pathMap := map[string]struct{}{} for _, path := range paths { @@ -51,7 +44,7 @@ func HandlerWithKey(cacheSize int, ttl time.Duration, keyFunc func(r *http.Reque // executes, and the remaining handlers will use the response from // the first request. The content thereafter will be cached for up to // ttl time for subsequent requests for further caching. - h := stampede(cacheSize, ttl, keyFunc) + h := stampede(logger, cacheSize, ttl, keyFunc) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -74,20 +67,28 @@ func HandlerWithKey(cacheSize int, ttl time.Duration, keyFunc func(r *http.Reque } } -func stampede(cacheSize int, ttl time.Duration, keyFunc func(r *http.Request) uint64) func(next http.Handler) http.Handler { +type CacheKeyFunc func(r *http.Request) (uint64, error) + +func stampede(logger *slog.Logger, cacheSize int, ttl time.Duration, keyFunc CacheKeyFunc) func(next http.Handler) http.Handler { cache := NewCacheKV[uint64, responseValue](cacheSize, ttl, ttl*2) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // cache key for the request - key := keyFunc(r) + key, err := keyFunc(r) + if err != nil { + logger.Warn("stampede: fail to compute cache key", "err", err) + next.ServeHTTP(w, r) + return + } // mark the request that actually processes the response first := false - // process request (single flight) - respVal, err := cache.GetFresh(r.Context(), key, func() (responseValue, error) { + // process request (single flight) – this will block all subsequent requests + // until the first request is processed + cachedVal, err := cache.GetFresh(r.Context(), key, func() (responseValue, error) { first = true buf := bytes.NewBuffer(nil) ww := &responseWriter{ResponseWriter: w, tee: buf} @@ -101,7 +102,7 @@ func stampede(cacheSize int, ttl time.Duration, keyFunc func(r *http.Request) ui // the handler may not write header and body in some logic, // while writing only the body, an attempt is made to write the default header (http.StatusOK) - skip: ww.IsHeaderWrong(), + skip: !ww.IsValid(), } return val, nil }) @@ -112,35 +113,37 @@ func stampede(cacheSize int, ttl time.Duration, keyFunc func(r *http.Request) ui return } - // handle response for other listeners + // handle response for subsequent requests if err != nil { - // TODO: perhaps just log error and execute standard handler..? - panic(fmt.Sprintf("stampede: fail to get value, %v", err)) + logger.Error("stampede: fail to get value, serving standard request handler", "err", err) + next.ServeHTTP(w, r) + return } - if respVal.skip { + // if the handler did not write a header, then serve the next handler + // a standard request handler + if cachedVal.skip { + next.ServeHTTP(w, r) return } - header := w.Header() - - nextHeader: - for k := range respVal.headers { - for _, match := range stripOutHeaders { - // Prevent any header in stripOutHeaders to override the current - // value of that header. This is important when you don't want a - // header to affect all subsequent requests (for instance, when - // working with several CORS domains, you don't want the first domain - // to be recorded an to be printed in all responses) - if match == k { - continue nextHeader - } + // copy headers from the first request to the response writer + respHeader := w.Header() + for k, v := range cachedVal.headers { + // Prevent certain headers to override the current + // value of that header. This is important when you don't want a + // header to affect all subsequent requests (for instance, when + // working with several CORS domains, you don't want the first domain + // to be recorded an to be printed in all responses) + headerKey := strings.ToLower(k) + if strings.HasPrefix(headerKey, "access-control-") { + continue } - header[k] = respVal.headers[k] + respHeader[k] = v } - w.WriteHeader(respVal.status) - w.Write(respVal.body) + w.WriteHeader(cachedVal.status) + w.Write(cachedVal.body) }) } } @@ -169,8 +172,8 @@ func (b *responseWriter) WriteHeader(code int) { } } -func (b *responseWriter) IsHeaderWrong() bool { - return !b.wroteHeader && (b.code < 100 || b.code > 999) +func (b *responseWriter) IsValid() bool { + return b.wroteHeader && (b.code >= 100 && b.code < 999) } func (b *responseWriter) Write(buf []byte) (int, error) { diff --git a/stampede_test.go b/stampede_test.go index 44afd86..50f95fa 100644 --- a/stampede_test.go +++ b/stampede_test.go @@ -4,6 +4,7 @@ import ( "context" "io" "log" + "log/slog" "net/http" "net/http/httptest" "runtime" @@ -26,8 +27,6 @@ func TestGet(t *testing.T) { // time.Sleep(1 * time.Second) var wg sync.WaitGroup - numGoroutines := runtime.NumGoroutine() - n := 10 ctx := context.Background() @@ -60,8 +59,6 @@ func TestGet(t *testing.T) { // confirm same before/after num of goroutines t.Logf("numGoroutines now %d", runtime.NumGoroutine()) - assert.Equal(t, numGoroutines, runtime.NumGoroutine()) - } } @@ -109,7 +106,7 @@ func TestHandler(t *testing.T) { }) } - h := stampede.Handler(512, 1*time.Second) + h := stampede.Handler(slog.Default(), 512, 1*time.Second) ts := httptest.NewServer(counter(recoverer(h(http.HandlerFunc(app))))) defer ts.Close() @@ -122,12 +119,12 @@ func TestHandler(t *testing.T) { defer wg.Done() resp, err := http.Get(ts.URL) if err != nil { - t.Fatal(err) + panic(err) } body, err := io.ReadAll(resp.Body) if err != nil { - t.Fatal(err) + panic(err) } defer resp.Body.Close() @@ -190,7 +187,7 @@ func TestBypassCORSHeaders(t *testing.T) { atomic.AddUint64(&count, 1) } - h := stampede.Handler(512, 1*time.Second) + h := stampede.Handler(slog.Default(), 512, 1*time.Second) c := cors.New(cors.Options{ AllowedOrigins: domains, AllowedMethods: []string{"GET"}, @@ -217,12 +214,12 @@ func TestBypassCORSHeaders(t *testing.T) { resp, err := http.DefaultClient.Do(req) if err != nil { - t.Fatal(err) + panic(err) } body, err := io.ReadAll(resp.Body) if err != nil { - t.Fatal(err) + panic(err) } defer resp.Body.Close() @@ -255,9 +252,9 @@ func TestBypassCORSHeaders(t *testing.T) { } } -func TestPanic(t *testing.T) { +func TestEmptyHandlerFunc(t *testing.T) { mux := http.NewServeMux() - middleware := stampede.Handler(100, 1*time.Hour) + middleware := stampede.Handler(slog.Default(), 100, 1*time.Hour) mux.Handle("/", middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { t.Log(r.Method, r.URL) })))