diff --git a/Dockerfile b/Dockerfile
index 73da3eba..ad7bbb9d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -3,11 +3,13 @@ FROM golang:1.10
 RUN mkdir -p /go/src/github.com/openfaas-incubator/of-watchdog
 WORKDIR /go/src/github.com/openfaas-incubator/of-watchdog
 
-COPY vendor vendor
-COPY config config
-COPY executor executor
-COPY metrics metrics
-COPY main.go .
+COPY vendor vendor
+COPY config config
+COPY executor executor
+COPY metrics metrics
+COPY concurrency-limiter concurrency-limiter
+COPY metrics metrics
+COPY main.go .
 
 # Run a gofmt and exclude all vendored code.
 RUN test -z "$(gofmt -l $(find . -type f -name '*.go' -not -path "./vendor/*"))"
diff --git a/README.md b/README.md
index 6ade7880..e8eb3a38 100644
--- a/README.md
+++ b/README.md
@@ -152,6 +152,6 @@ Environmental variables:
 | `suppress_lock` | Yes | When set to `false` the watchdog will attempt to write a lockfile to /tmp/ for healthchecks. Default `false` |
 | `upstream_url` | Yes | `http` mode only - where to forward requests i.e. `127.0.0.1:5000` |
 | `buffer_http` | Yes | `http` mode only - buffers request body to memory before forwarding. Use if your upstream HTTP server does not accept `Transfer-Encoding: chunked` Default: `false` |
-
+| `max_inflight` | Yes | Limit the maximum number of requests in flight. Requests over this limit receive an immediate HTTP 429 response. Default: `0` (unlimited) |
 
 > Note: the .lock file is implemented for health-checking, but cannot be disabled yet. You must create this file in /tmp/.
diff --git a/concurrency-limiter/concurrency_limiter.go b/concurrency-limiter/concurrency_limiter.go
new file mode 100644
index 00000000..30c75d89
--- /dev/null
+++ b/concurrency-limiter/concurrency_limiter.go
@@ -0,0 +1,67 @@
+package limiter
+
+import (
+	"fmt"
+	"net/http"
+	"sync/atomic"
+)
+
+type ConcurrencyLimiter struct {
+	backendHTTPHandler http.Handler
+	/*
+		We keep two counters here so that the tests can tell when a request has completed.
+		We could wrap these up in a condvar so that there is no need to spin, but that
+		seems like overkill for testing.
+
+		This is effectively a fancy semaphore built for optimistic concurrency only, and
+		with spinlocks. If you want to add timeouts / pessimistic concurrency, signaling
+		needs to be added (something condvar-like) to wake up waiters once they stop spinning.
+
+		Otherwise, the tests would need all sorts of futzing to make sure that the
+		concurrency limiter handler has completed.
+		The arithmetic still works when the counters overflow:
+			var x, y uint64
+			x = (1 << 64 - 1)
+			y = (1 << 64 - 1)
+			x++
+			fmt.Println(x)
+			fmt.Println(y)
+			fmt.Println(x - y)
+		Prints:
+			0
+			18446744073709551615
+			1
+	*/
+	requestsStarted   uint64
+	requestsCompleted uint64
+
+	maxInflightRequests uint64
+}
+
+func (cl *ConcurrencyLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	requestsStarted := atomic.AddUint64(&cl.requestsStarted, 1)
+	completedRequested := atomic.LoadUint64(&cl.requestsCompleted)
+	if requestsStarted-completedRequested > cl.maxInflightRequests {
+		// This is a failure path; count the request as completed immediately, since it never reaches the backend
+		atomic.AddUint64(&cl.requestsCompleted, 1)
+		w.WriteHeader(http.StatusTooManyRequests)
+		fmt.Fprintf(w, "Concurrent request limit exceeded. Max concurrent requests: %d\n", cl.maxInflightRequests)
+		return
+	}
+	cl.backendHTTPHandler.ServeHTTP(w, r)
+	atomic.AddUint64(&cl.requestsCompleted, 1)
+}
+
+// NewConcurrencyLimiter creates a handler which limits the number of active,
+// concurrent requests. If the concurrency limit is less than, or equal to, 0,
+// it returns the handler passed to it unmodified.
+func NewConcurrencyLimiter(handler http.Handler, concurrencyLimit int) http.Handler {
+	if concurrencyLimit <= 0 {
+		return handler
+	}
+
+	return &ConcurrencyLimiter{
+		backendHTTPHandler:  handler,
+		maxInflightRequests: uint64(concurrencyLimit),
+	}
+}
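For orientation, here is a minimal sketch (not part of this diff) of how the new package is used on its own. The standalone `main`, the port, and the dummy handler are illustrative assumptions; inside the watchdog the wiring happens in `buildRequestHandler` (see the `main.go` changes below).

```go
package main

import (
	"fmt"
	"log"
	"net/http"

	limiter "github.com/openfaas-incubator/of-watchdog/concurrency-limiter"
)

func main() {
	upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "hello")
	})

	// At most 2 requests are served concurrently; any extra request receives
	// an immediate HTTP 429 from the limiter rather than queuing.
	handler := limiter.NewConcurrencyLimiter(upstream, 2)

	log.Fatal(http.ListenAndServe(":8080", handler))
}
```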
diff --git a/concurrency-limiter/concurrency_limiter_test.go b/concurrency-limiter/concurrency_limiter_test.go
new file mode 100644
index 00000000..70146e1b
--- /dev/null
+++ b/concurrency-limiter/concurrency_limiter_test.go
@@ -0,0 +1,204 @@
+package limiter
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+func makeFakeHandler(ctx context.Context, completeInFlightRequestChan chan struct{}) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		select {
+		case <-ctx.Done():
+			w.WriteHeader(http.StatusServiceUnavailable)
+		case <-completeInFlightRequestChan:
+			w.WriteHeader(http.StatusOK)
+		}
+	}
+}
+
+func doRRandRequest(ctx context.Context, wg *sync.WaitGroup, cl http.Handler) *httptest.ResponseRecorder {
+	// Issue the request on a background goroutine and return the recorder to the caller immediately
+	rr := httptest.NewRecorder()
+	// This should never fail unless we're out of memory or something
+	req, err := http.NewRequest("GET", "/", nil)
+	if err != nil {
+		panic(err)
+	}
+	req = req.WithContext(ctx)
+
+	wg.Add(1)
+	go func() {
+		// Callers that need to know the request has reached the handler poll the limiter's counters
+		cl.ServeHTTP(rr, req)
+		// Unblock wg.Wait once the request has been served or rejected
+		wg.Done()
+	}()
+
+	return rr
+}
+
+func TestConcurrencyLimitUnderLimit(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	completeInFlightRequestChan := make(chan struct{})
+	handler := makeFakeHandler(ctx, completeInFlightRequestChan)
+	cl := NewConcurrencyLimiter(http.Handler(handler), 2)
+
+	wg := &sync.WaitGroup{}
+	rr1 := doRRandRequest(ctx, wg, cl)
+	// This will "release" the request rr1
+	completeInFlightRequestChan <- struct{}{}
+
+	// This should never take more than the timeout
+	wg.Wait()
+
+	// We want to access the response recorder directly, so we don't accidentally get an implicitly correct answer
+	if rr1.Code != http.StatusOK {
+		t.Fatalf("Want response code %d, got: %d", http.StatusOK, rr1.Code)
+	}
+}
+
+func TestConcurrencyLimitAtLimit(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	completeInFlightRequestChan := make(chan struct{})
+	handler := makeFakeHandler(ctx, completeInFlightRequestChan)
+
+	cl := NewConcurrencyLimiter(http.Handler(handler), 2)
+
+	wg := &sync.WaitGroup{}
+	rr1 := doRRandRequest(ctx, wg, cl)
+	rr2 := doRRandRequest(ctx, wg, cl)
+
+	completeInFlightRequestChan <- struct{}{}
+	completeInFlightRequestChan <- struct{}{}
+
+	wg.Wait()
+
+	if rr1.Code != http.StatusOK {
+		t.Fatalf("Want response code %d, got: %d", http.StatusOK, rr1.Code)
+	}
+	if rr2.Code != http.StatusOK {
+		t.Fatalf("Want response code %d, got: %d", http.StatusOK, rr2.Code)
+	}
+}
+
+func count(r *httptest.ResponseRecorder, code200s, code429s *int) {
+	switch r.Code {
+	case http.StatusTooManyRequests:
+		*code429s = *code429s + 1
+	case http.StatusOK:
+		*code200s = *code200s + 1
+	default:
+		panic(fmt.Sprintf("Unknown code: %d", r.Code))
+	}
+}
+
+func TestConcurrencyLimitOverLimit(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	completeInFlightRequestChan := make(chan struct{}, 3)
+	handler := makeFakeHandler(ctx, completeInFlightRequestChan)
+
+	cl := NewConcurrencyLimiter(http.Handler(handler), 2).(*ConcurrencyLimiter)
+
+	wg := &sync.WaitGroup{}
+
+	rr1 := doRRandRequest(ctx, wg, cl)
+	rr2 := doRRandRequest(ctx, wg, cl)
+	for ctx.Err() == nil {
+		if requestsStarted := atomic.LoadUint64(&cl.requestsStarted); requestsStarted == 2 {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	rr3 := doRRandRequest(ctx, wg, cl)
+	for ctx.Err() == nil {
+		if requestsStarted := atomic.LoadUint64(&cl.requestsStarted); requestsStarted == 3 {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	completeInFlightRequestChan <- struct{}{}
+	completeInFlightRequestChan <- struct{}{}
+	completeInFlightRequestChan <- struct{}{}
+
+	wg.Wait()
+
+	code200s := 0
+	code429s := 0
+	count(rr1, &code200s, &code429s)
+	count(rr2, &code200s, &code429s)
+	count(rr3, &code200s, &code429s)
+	if code200s != 2 || code429s != 1 {
+		t.Fatalf("code 200s: %d, and code429s: %d", code200s, code429s)
+	}
+}
+
+func TestConcurrencyLimitOverLimitAndRecover(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	completeInFlightRequestChan := make(chan struct{}, 4)
+	handler := makeFakeHandler(ctx, completeInFlightRequestChan)
+	cl := NewConcurrencyLimiter(http.Handler(handler), 2).(*ConcurrencyLimiter)
+
+	wg := &sync.WaitGroup{}
+
+	rr1 := doRRandRequest(ctx, wg, cl)
+	rr2 := doRRandRequest(ctx, wg, cl)
+	for ctx.Err() == nil {
+		if requestsStarted := atomic.LoadUint64(&cl.requestsStarted); requestsStarted == 2 {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	// This will 429
+	rr3 := doRRandRequest(ctx, wg, cl)
+	for ctx.Err() == nil {
+		if requestsStarted := atomic.LoadUint64(&cl.requestsStarted); requestsStarted == 3 {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	completeInFlightRequestChan <- struct{}{}
+	completeInFlightRequestChan <- struct{}{}
+	completeInFlightRequestChan <- struct{}{}
+	// Although we could do another wg.Wait here, I don't think we should because
+	// it might provide a false sense of confidence
+	for ctx.Err() == nil {
+		if requestsCompleted := atomic.LoadUint64(&cl.requestsCompleted); requestsCompleted == 3 {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	rr4 := doRRandRequest(ctx, wg, cl)
+	completeInFlightRequestChan <- struct{}{}
+	wg.Wait()
+
+	code200s := 0
+	code429s := 0
+	count(rr1, &code200s, &code429s)
+	count(rr2, &code200s, &code429s)
+	count(rr3, &code200s, &code429s)
+	count(rr4, &code200s, &code429s)
+
+	if code200s != 3 || code429s != 1 {
+		t.Fatalf("code 200s: %d, and code429s: %d", code200s, code429s)
+	}
+}
diff --git a/config/config.go b/config/config.go
index f5e33290..82165ea5 100644
--- a/config/config.go
+++ b/config/config.go
@@ -28,6 +28,12 @@ type WatchdogConfig struct {
 
 	// MetricsPort TCP port on which to serve HTTP Prometheus metrics
 	MetricsPort int
+
+	// MaxInflight limits the number of simultaneous
+	// requests that the watchdog allows in flight.
+	// Any request which exceeds this limit is given
+	// an immediate 429 response.
+	MaxInflight int
 }
 
 // Process returns a string for the process and a slice for the arguments from the FunctionProcess.
@@ -81,6 +87,7 @@ func New(env []string) (WatchdogConfig, error) {
 		UpstreamURL:    upstreamURL,
 		BufferHTTPBody: getBool(envMap, "buffer_http"),
 		MetricsPort:    8081,
+		MaxInflight:    getInt(envMap, "max_inflight", 0),
 	}
 
 	if val := envMap["mode"]; len(val) > 0 {
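A small sketch (not part of this diff, with hypothetical environment values) of how the new field is populated: `config.New` reads `max_inflight` from the process environment, defaulting to `0`, which the watchdog treats as "no limit".

```go
package main

import (
	"fmt"
	"os"

	"github.com/openfaas-incubator/of-watchdog/config"
)

func main() {
	// Hypothetical settings; fprocess is unrelated to max_inflight and is set
	// here only because a function process is normally configured as well.
	os.Setenv("fprocess", "cat")
	os.Setenv("max_inflight", "10")

	cfg, err := config.New(os.Environ())
	if err != nil {
		panic(err)
	}

	fmt.Println(cfg.MaxInflight) // expected to print 10
}
```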
diff --git a/main.go b/main.go
index d0cc23d0..95c089a5 100644
--- a/main.go
+++ b/main.go
@@ -15,10 +15,10 @@ import (
 	"syscall"
 	"time"
 
-	"github.com/openfaas-incubator/of-watchdog/metrics"
-
+	limiter "github.com/openfaas-incubator/of-watchdog/concurrency-limiter"
 	"github.com/openfaas-incubator/of-watchdog/config"
 	"github.com/openfaas-incubator/of-watchdog/executor"
+	"github.com/openfaas-incubator/of-watchdog/metrics"
 )
 
 var (
@@ -45,7 +45,6 @@ func main() {
 	httpMetrics := metrics.NewHttp()
 
 	http.HandleFunc("/", metrics.InstrumentHandler(requestHandler, httpMetrics))
-
 	http.HandleFunc("/_/health", makeHealthHandler())
 
 	metricsServer := metrics.MetricsServer{}
@@ -129,7 +128,7 @@ func listenUntilShutdown(shutdownTimeout time.Duration, s *http.Server, suppress
 	<-idleConnsClosed
 }
 
-func buildRequestHandler(watchdogConfig config.WatchdogConfig) http.HandlerFunc {
+func buildRequestHandler(watchdogConfig config.WatchdogConfig) http.Handler {
 	var requestHandler http.HandlerFunc
 
 	switch watchdogConfig.OperationalMode {
@@ -150,6 +149,10 @@ func buildRequestHandler(watchdogConfig config.WatchdogConfig) http.HandlerFunc
 		break
 	}
 
+	if watchdogConfig.MaxInflight > 0 {
+		return limiter.NewConcurrencyLimiter(requestHandler, watchdogConfig.MaxInflight)
+	}
+
 	return requestHandler
 }
diff --git a/metrics/metrics.go b/metrics/metrics.go
index d5283d96..efdda8a6 100644
--- a/metrics/metrics.go
+++ b/metrics/metrics.go
@@ -59,7 +59,7 @@ func (m *MetricsServer) Serve(cancel chan bool) {
 
 // InstrumentHandler returns a handler which records HTTP requests
 // as they are made
-func InstrumentHandler(next http.HandlerFunc, _http Http) http.HandlerFunc {
+func InstrumentHandler(next http.Handler, _http Http) http.HandlerFunc {
 	return promhttp.InstrumentHandlerCounter(_http.RequestsTotal,
 		promhttp.InstrumentHandlerDuration(_http.RequestDurationHistogram, next))
 }
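Taken together, a rough sketch (not part of this diff; the standalone `main`, port, and dummy handler are assumptions) of the resulting handler chain: the Prometheus instrumentation wraps the limiter, which wraps the mode-specific handler, mirroring what `main.go` does after this change. This composition is why `buildRequestHandler` now returns an `http.Handler` and `InstrumentHandler` accepts one.

```go
package main

import (
	"fmt"
	"log"
	"net/http"

	limiter "github.com/openfaas-incubator/of-watchdog/concurrency-limiter"
	"github.com/openfaas-incubator/of-watchdog/metrics"
)

func main() {
	// Stand-in for the mode-specific handler normally built by buildRequestHandler.
	requestHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "function output")
	})

	// The limiter returns an http.Handler rather than an http.HandlerFunc,
	// hence the widened signatures above.
	limited := limiter.NewConcurrencyLimiter(requestHandler, 5)

	httpMetrics := metrics.NewHttp()
	http.HandleFunc("/", metrics.InstrumentHandler(limited, httpMetrics))

	log.Fatal(http.ListenAndServe(":8080", nil))
}
```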