Skip to content

Commit d455a56

Browse files
authored
Merge pull request #61 from basvanbeek/master
middleware: Improved http.Handler logic and added RequestSampler option
2 parents 8a54c36 + 8e1f1f4 commit d455a56

File tree

2 files changed

+306
-1
lines changed

2 files changed

+306
-1
lines changed

middleware/http/server.go

Lines changed: 241 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package http
22

33
import (
4+
"io"
45
"net/http"
56
"strconv"
67
"sync/atomic"
@@ -16,6 +17,7 @@ type handler struct {
1617
next http.Handler
1718
tagResponseSize bool
1819
defaultTags map[string]string
20+
requestSampler func(r *http.Request) bool
1921
}
2022

2123
// ServerOption allows Middleware to be optionally configured.
@@ -46,6 +48,14 @@ func SpanName(name string) ServerOption {
4648
}
4749
}
4850

51+
// RequestSampler allows one to set the sampling decision based on the details
52+
// found in the http.Request.
53+
func RequestSampler(sampleFunc func(r *http.Request) bool) ServerOption {
54+
return func(h *handler) {
55+
h.requestSampler = sampleFunc
56+
}
57+
}
58+
4959
// NewServerMiddleware returns a http.Handler middleware with Zipkin tracing.
5060
func NewServerMiddleware(t *zipkin.Tracer, options ...ServerOption) func(http.Handler) http.Handler {
5161
return func(next http.Handler) http.Handler {
@@ -67,6 +77,11 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
6777
// try to extract B3 Headers from upstream
6878
sc := h.tracer.Extract(b3.ExtractHTTP(r))
6979

80+
if h.requestSampler != nil && sc.Sampled == nil {
81+
sample := h.requestSampler(r)
82+
sc.Sampled = &sample
83+
}
84+
7085
remoteEndpoint, _ := zipkin.NewEndpoint("", r.RemoteAddr)
7186

7287
if len(h.name) == 0 {
@@ -114,7 +129,7 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
114129
}()
115130

116131
// call next http Handler func using our updated context.
117-
h.next.ServeHTTP(ri, r.WithContext(ctx))
132+
h.next.ServeHTTP(ri.wrap(), r.WithContext(ctx))
118133
}
119134

120135
// rwInterceptor intercepts the ResponseWriter so it can track response size
@@ -147,3 +162,228 @@ func (r *rwInterceptor) getStatusCode() int {
147162
func (r *rwInterceptor) getResponseSize() string {
148163
return strconv.FormatUint(atomic.LoadUint64(&r.size), 10)
149164
}
165+
166+
func (r *rwInterceptor) wrap() http.ResponseWriter {
167+
var (
168+
hj, i0 = r.w.(http.Hijacker)
169+
cn, i1 = r.w.(http.CloseNotifier)
170+
pu, i2 = r.w.(http.Pusher)
171+
fl, i3 = r.w.(http.Flusher)
172+
rf, i4 = r.w.(io.ReaderFrom)
173+
)
174+
175+
switch {
176+
case !i0 && !i1 && !i2 && !i3 && !i4:
177+
return struct {
178+
http.ResponseWriter
179+
}{r}
180+
case !i0 && !i1 && !i2 && !i3 && i4:
181+
return struct {
182+
http.ResponseWriter
183+
io.ReaderFrom
184+
}{r, rf}
185+
case !i0 && !i1 && !i2 && i3 && !i4:
186+
return struct {
187+
http.ResponseWriter
188+
http.Flusher
189+
}{r, fl}
190+
case !i0 && !i1 && !i2 && i3 && i4:
191+
return struct {
192+
http.ResponseWriter
193+
http.Flusher
194+
io.ReaderFrom
195+
}{r, fl, rf}
196+
case !i0 && !i1 && i2 && !i3 && !i4:
197+
return struct {
198+
http.ResponseWriter
199+
http.Pusher
200+
}{r, pu}
201+
case !i0 && !i1 && i2 && !i3 && i4:
202+
return struct {
203+
http.ResponseWriter
204+
http.Pusher
205+
io.ReaderFrom
206+
}{r, pu, rf}
207+
case !i0 && !i1 && i2 && i3 && !i4:
208+
return struct {
209+
http.ResponseWriter
210+
http.Pusher
211+
http.Flusher
212+
}{r, pu, fl}
213+
case !i0 && !i1 && i2 && i3 && i4:
214+
return struct {
215+
http.ResponseWriter
216+
http.Pusher
217+
http.Flusher
218+
io.ReaderFrom
219+
}{r, pu, fl, rf}
220+
case !i0 && i1 && !i2 && !i3 && !i4:
221+
return struct {
222+
http.ResponseWriter
223+
http.CloseNotifier
224+
}{r, cn}
225+
case !i0 && i1 && !i2 && !i3 && i4:
226+
return struct {
227+
http.ResponseWriter
228+
http.CloseNotifier
229+
io.ReaderFrom
230+
}{r, cn, rf}
231+
case !i0 && i1 && !i2 && i3 && !i4:
232+
return struct {
233+
http.ResponseWriter
234+
http.CloseNotifier
235+
http.Flusher
236+
}{r, cn, fl}
237+
case !i0 && i1 && !i2 && i3 && i4:
238+
return struct {
239+
http.ResponseWriter
240+
http.CloseNotifier
241+
http.Flusher
242+
io.ReaderFrom
243+
}{r, cn, fl, rf}
244+
case !i0 && i1 && i2 && !i3 && !i4:
245+
return struct {
246+
http.ResponseWriter
247+
http.CloseNotifier
248+
http.Pusher
249+
}{r, cn, pu}
250+
case !i0 && i1 && i2 && !i3 && i4:
251+
return struct {
252+
http.ResponseWriter
253+
http.CloseNotifier
254+
http.Pusher
255+
io.ReaderFrom
256+
}{r, cn, pu, rf}
257+
case !i0 && i1 && i2 && i3 && !i4:
258+
return struct {
259+
http.ResponseWriter
260+
http.CloseNotifier
261+
http.Pusher
262+
http.Flusher
263+
}{r, cn, pu, fl}
264+
case !i0 && i1 && i2 && i3 && i4:
265+
return struct {
266+
http.ResponseWriter
267+
http.CloseNotifier
268+
http.Pusher
269+
http.Flusher
270+
io.ReaderFrom
271+
}{r, cn, pu, fl, rf}
272+
case i0 && !i1 && !i2 && !i3 && !i4:
273+
return struct {
274+
http.ResponseWriter
275+
http.Hijacker
276+
}{r, hj}
277+
case i0 && !i1 && !i2 && !i3 && i4:
278+
return struct {
279+
http.ResponseWriter
280+
http.Hijacker
281+
io.ReaderFrom
282+
}{r, hj, rf}
283+
case i0 && !i1 && !i2 && i3 && !i4:
284+
return struct {
285+
http.ResponseWriter
286+
http.Hijacker
287+
http.Flusher
288+
}{r, hj, fl}
289+
case i0 && !i1 && !i2 && i3 && i4:
290+
return struct {
291+
http.ResponseWriter
292+
http.Hijacker
293+
http.Flusher
294+
io.ReaderFrom
295+
}{r, hj, fl, rf}
296+
case i0 && !i1 && i2 && !i3 && !i4:
297+
return struct {
298+
http.ResponseWriter
299+
http.Hijacker
300+
http.Pusher
301+
}{r, hj, pu}
302+
case i0 && !i1 && i2 && !i3 && i4:
303+
return struct {
304+
http.ResponseWriter
305+
http.Hijacker
306+
http.Pusher
307+
io.ReaderFrom
308+
}{r, hj, pu, rf}
309+
case i0 && !i1 && i2 && i3 && !i4:
310+
return struct {
311+
http.ResponseWriter
312+
http.Hijacker
313+
http.Pusher
314+
http.Flusher
315+
}{r, hj, pu, fl}
316+
case i0 && !i1 && i2 && i3 && i4:
317+
return struct {
318+
http.ResponseWriter
319+
http.Hijacker
320+
http.Pusher
321+
http.Flusher
322+
io.ReaderFrom
323+
}{r, hj, pu, fl, rf}
324+
case i0 && i1 && !i2 && !i3 && !i4:
325+
return struct {
326+
http.ResponseWriter
327+
http.Hijacker
328+
http.CloseNotifier
329+
}{r, hj, cn}
330+
case i0 && i1 && !i2 && !i3 && i4:
331+
return struct {
332+
http.ResponseWriter
333+
http.Hijacker
334+
http.CloseNotifier
335+
io.ReaderFrom
336+
}{r, hj, cn, rf}
337+
case i0 && i1 && !i2 && i3 && !i4:
338+
return struct {
339+
http.ResponseWriter
340+
http.Hijacker
341+
http.CloseNotifier
342+
http.Flusher
343+
}{r, hj, cn, fl}
344+
case i0 && i1 && !i2 && i3 && i4:
345+
return struct {
346+
http.ResponseWriter
347+
http.Hijacker
348+
http.CloseNotifier
349+
http.Flusher
350+
io.ReaderFrom
351+
}{r, hj, cn, fl, rf}
352+
case i0 && i1 && i2 && !i3 && !i4:
353+
return struct {
354+
http.ResponseWriter
355+
http.Hijacker
356+
http.CloseNotifier
357+
http.Pusher
358+
}{r, hj, cn, pu}
359+
case i0 && i1 && i2 && !i3 && i4:
360+
return struct {
361+
http.ResponseWriter
362+
http.Hijacker
363+
http.CloseNotifier
364+
http.Pusher
365+
io.ReaderFrom
366+
}{r, hj, cn, pu, rf}
367+
case i0 && i1 && i2 && i3 && !i4:
368+
return struct {
369+
http.ResponseWriter
370+
http.Hijacker
371+
http.CloseNotifier
372+
http.Pusher
373+
http.Flusher
374+
}{r, hj, cn, pu, fl}
375+
case i0 && i1 && i2 && i3 && i4:
376+
return struct {
377+
http.ResponseWriter
378+
http.Hijacker
379+
http.CloseNotifier
380+
http.Pusher
381+
http.Flusher
382+
io.ReaderFrom
383+
}{r, hj, cn, pu, fl, rf}
384+
default:
385+
return struct {
386+
http.ResponseWriter
387+
}{r}
388+
}
389+
}

middleware/http/server_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,68 @@ func TestHTTPDefaultSpanName(t *testing.T) {
139139
t.Errorf("Expected span name %s, got %s", want, have)
140140
}
141141
}
142+
143+
func TestHTTPRequestSampler(t *testing.T) {
144+
var (
145+
spanRecorder = &recorder.ReporterRecorder{}
146+
httpRecorder = httptest.NewRecorder()
147+
requestBuf = bytes.NewBufferString("incoming data")
148+
methodType = "POST"
149+
httpHandlerFunc = http.HandlerFunc(httpHandler(200, nil, bytes.NewBufferString("")))
150+
)
151+
152+
samplers := [](func(r *http.Request) bool){
153+
nil,
154+
func(r *http.Request) bool { return true },
155+
func(r *http.Request) bool { return false },
156+
}
157+
158+
for _, sampler := range samplers {
159+
tr, _ := zipkin.NewTracer(spanRecorder, zipkin.WithLocalEndpoint(lep), zipkin.WithSampler(zipkin.AlwaysSample))
160+
161+
request, err := http.NewRequest(methodType, "/test", requestBuf)
162+
if err != nil {
163+
t.Fatalf("unable to create request")
164+
}
165+
166+
handler := mw.NewServerMiddleware(tr, mw.RequestSampler(sampler))(httpHandlerFunc)
167+
168+
handler.ServeHTTP(httpRecorder, request)
169+
170+
spans := spanRecorder.Flush()
171+
172+
sampledSpans := 0
173+
if sampler == nil || sampler(request) {
174+
sampledSpans = 1
175+
}
176+
177+
if want, have := sampledSpans, len(spans); want != have {
178+
t.Errorf("Expected %d spans, got %d", want, have)
179+
}
180+
}
181+
182+
for _, sampler := range samplers {
183+
tr, _ := zipkin.NewTracer(spanRecorder, zipkin.WithLocalEndpoint(lep), zipkin.WithSampler(zipkin.NeverSample))
184+
185+
request, err := http.NewRequest(methodType, "/test", requestBuf)
186+
if err != nil {
187+
t.Fatalf("unable to create request")
188+
}
189+
190+
handler := mw.NewServerMiddleware(tr, mw.RequestSampler(sampler))(httpHandlerFunc)
191+
192+
handler.ServeHTTP(httpRecorder, request)
193+
194+
spans := spanRecorder.Flush()
195+
196+
sampledSpans := 0
197+
if sampler != nil && sampler(request) {
198+
sampledSpans = 1
199+
}
200+
201+
if want, have := sampledSpans, len(spans); want != have {
202+
t.Errorf("Expected %d spans, got %d", want, have)
203+
}
204+
}
205+
206+
}

0 commit comments

Comments
 (0)