@@ -139,3 +139,68 @@ func TestHTTPDefaultSpanName(t *testing.T) {
139139 t .Errorf ("Expected span name %s, got %s" , want , have )
140140 }
141141}
142+
143+ func TestHTTPRequestSampler (t * testing.T ) {
144+ var (
145+ spanRecorder = & recorder.ReporterRecorder {}
146+ httpRecorder = httptest .NewRecorder ()
147+ requestBuf = bytes .NewBufferString ("incoming data" )
148+ methodType = "POST"
149+ httpHandlerFunc = http .HandlerFunc (httpHandler (200 , nil , bytes .NewBufferString ("" )))
150+ )
151+
152+ samplers := [](func (r * http.Request ) bool ){
153+ nil ,
154+ func (r * http.Request ) bool { return true },
155+ func (r * http.Request ) bool { return false },
156+ }
157+
158+ for _ , sampler := range samplers {
159+ tr , _ := zipkin .NewTracer (spanRecorder , zipkin .WithLocalEndpoint (lep ), zipkin .WithSampler (zipkin .AlwaysSample ))
160+
161+ request , err := http .NewRequest (methodType , "/test" , requestBuf )
162+ if err != nil {
163+ t .Fatalf ("unable to create request" )
164+ }
165+
166+ handler := mw .NewServerMiddleware (tr , mw .RequestSampler (sampler ))(httpHandlerFunc )
167+
168+ handler .ServeHTTP (httpRecorder , request )
169+
170+ spans := spanRecorder .Flush ()
171+
172+ sampledSpans := 0
173+ if sampler == nil || sampler (request ) {
174+ sampledSpans = 1
175+ }
176+
177+ if want , have := sampledSpans , len (spans ); want != have {
178+ t .Errorf ("Expected %d spans, got %d" , want , have )
179+ }
180+ }
181+
182+ for _ , sampler := range samplers {
183+ tr , _ := zipkin .NewTracer (spanRecorder , zipkin .WithLocalEndpoint (lep ), zipkin .WithSampler (zipkin .NeverSample ))
184+
185+ request , err := http .NewRequest (methodType , "/test" , requestBuf )
186+ if err != nil {
187+ t .Fatalf ("unable to create request" )
188+ }
189+
190+ handler := mw .NewServerMiddleware (tr , mw .RequestSampler (sampler ))(httpHandlerFunc )
191+
192+ handler .ServeHTTP (httpRecorder , request )
193+
194+ spans := spanRecorder .Flush ()
195+
196+ sampledSpans := 0
197+ if sampler != nil && sampler (request ) {
198+ sampledSpans = 1
199+ }
200+
201+ if want , have := sampledSpans , len (spans ); want != have {
202+ t .Errorf ("Expected %d spans, got %d" , want , have )
203+ }
204+ }
205+
206+ }
0 commit comments