Skip to content

Commit c2eef54

Browse files
authored
Allow middleware output to be buffered and discarded (#521)
* Add log func to HTTP * Add log func to GRPC * Add tests * Enhance tests * Respond to peer feedback * Fix lint error
1 parent 3bbae3e commit c2eef54

File tree

4 files changed

+177
-11
lines changed

4 files changed

+177
-11
lines changed

log/grpc.go

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type (
2727
iserr func(codes.Code) bool
2828
disableCallLogging bool
2929
disableCallID bool
30+
logFunc func(ctx context.Context, keyvals ...Fielder)
3031
}
3132
)
3233

@@ -43,6 +44,12 @@ func UnaryServerInterceptor(logCtx context.Context, opts ...GRPCLogOption) grpc.
4344
for _, opt := range opts {
4445
opt(o)
4546
}
47+
48+
logFunc := Print
49+
if o.logFunc != nil {
50+
logFunc = o.logFunc
51+
}
52+
4653
return func(
4754
ctx context.Context,
4855
req interface{},
@@ -59,7 +66,7 @@ func UnaryServerInterceptor(logCtx context.Context, opts ...GRPCLogOption) grpc.
5966
then := time.Now()
6067
svcKV := KV{K: GRPCServiceKey, V: path.Dir(info.FullMethod)[1:]}
6168
methKV := KV{K: GRPCMethodKey, V: path.Base(info.FullMethod)}
62-
Print(ctx, KV{MessageKey, "start"}, svcKV, methKV)
69+
logFunc(ctx, KV{MessageKey, "start"}, svcKV, methKV)
6370

6471
res, err := handler(ctx, req)
6572

@@ -72,7 +79,7 @@ func UnaryServerInterceptor(logCtx context.Context, opts ...GRPCLogOption) grpc.
7279
Error(ctx, err, svcKV, methKV, statKV, codeKV, durKV)
7380
return res, err
7481
}
75-
Print(ctx, KV{MessageKey, "end"}, svcKV, methKV, codeKV, durKV)
82+
logFunc(ctx, KV{MessageKey, "end"}, svcKV, methKV, codeKV, durKV)
7683
return res, err
7784
}
7885
}
@@ -87,6 +94,12 @@ func StreamServerInterceptor(logCtx context.Context, opts ...GRPCLogOption) grpc
8794
for _, opt := range opts {
8895
opt(o)
8996
}
97+
98+
logFunc := Print
99+
if o.logFunc != nil {
100+
logFunc = o.logFunc
101+
}
102+
90103
return func(
91104
srv interface{},
92105
stream grpc.ServerStream,
@@ -104,7 +117,7 @@ func StreamServerInterceptor(logCtx context.Context, opts ...GRPCLogOption) grpc
104117
then := time.Now()
105118
svcKV := KV{K: GRPCServiceKey, V: path.Dir(info.FullMethod)[1:]}
106119
methKV := KV{K: GRPCMethodKey, V: path.Base(info.FullMethod)}
107-
Print(ctx, KV{MessageKey, "start"}, svcKV, methKV)
120+
logFunc(ctx, KV{MessageKey, "start"}, svcKV, methKV)
108121

109122
err := handler(srv, stream)
110123

@@ -117,7 +130,7 @@ func StreamServerInterceptor(logCtx context.Context, opts ...GRPCLogOption) grpc
117130
Error(ctx, err, svcKV, methKV, statKV, codeKV, durKV)
118131
return err
119132
}
120-
Print(ctx, KV{MessageKey, "end"}, svcKV, methKV, codeKV, durKV)
133+
logFunc(ctx, KV{MessageKey, "end"}, svcKV, methKV, codeKV, durKV)
121134
return err
122135
}
123136
}
@@ -129,6 +142,12 @@ func UnaryClientInterceptor(opts ...GRPCLogOption) grpc.UnaryClientInterceptor {
129142
for _, opt := range opts {
130143
opt(o)
131144
}
145+
146+
logFunc := Print
147+
if o.logFunc != nil {
148+
logFunc = o.logFunc
149+
}
150+
132151
return func(
133152
ctx context.Context,
134153
fullmethod string,
@@ -140,7 +159,7 @@ func UnaryClientInterceptor(opts ...GRPCLogOption) grpc.UnaryClientInterceptor {
140159
then := time.Now()
141160
svcKV := KV{K: GRPCServiceKey, V: path.Dir(fullmethod)[1:]}
142161
methKV := KV{K: GRPCMethodKey, V: path.Base(fullmethod)}
143-
Print(ctx, KV{K: MessageKey, V: "start"}, svcKV, methKV)
162+
logFunc(ctx, KV{K: MessageKey, V: "start"}, svcKV, methKV)
144163

145164
err := invoker(ctx, fullmethod, req, reply, cc, opts...)
146165

@@ -153,7 +172,7 @@ func UnaryClientInterceptor(opts ...GRPCLogOption) grpc.UnaryClientInterceptor {
153172
Error(ctx, err, svcKV, methKV, statKV, codeKV, durKV)
154173
return err
155174
}
156-
Print(ctx, KV{K: MessageKey, V: "end"}, svcKV, methKV, codeKV, durKV)
175+
logFunc(ctx, KV{K: MessageKey, V: "end"}, svcKV, methKV, codeKV, durKV)
157176
return err
158177
}
159178
}
@@ -165,6 +184,12 @@ func StreamClientInterceptor(opts ...GRPCLogOption) grpc.StreamClientInterceptor
165184
for _, opt := range opts {
166185
opt(o)
167186
}
187+
188+
logFunc := Print
189+
if o.logFunc != nil {
190+
logFunc = o.logFunc
191+
}
192+
168193
return func(
169194
ctx context.Context,
170195
desc *grpc.StreamDesc,
@@ -176,7 +201,7 @@ func StreamClientInterceptor(opts ...GRPCLogOption) grpc.StreamClientInterceptor
176201
then := time.Now()
177202
svcKV := KV{K: GRPCServiceKey, V: path.Dir(fullmethod)[1:]}
178203
methKV := KV{K: GRPCMethodKey, V: path.Base(fullmethod)}
179-
Print(ctx, KV{K: MessageKey, V: "start"}, svcKV, methKV)
204+
logFunc(ctx, KV{K: MessageKey, V: "start"}, svcKV, methKV)
180205

181206
stream, err := streamer(ctx, desc, cc, fullmethod, opts...)
182207

@@ -189,7 +214,7 @@ func StreamClientInterceptor(opts ...GRPCLogOption) grpc.StreamClientInterceptor
189214
Error(ctx, err, svcKV, methKV, statKV, codeKV, durKV)
190215
return stream, err
191216
}
192-
Print(ctx, KV{K: MessageKey, V: "end"}, svcKV, methKV, codeKV, durKV)
217+
logFunc(ctx, KV{K: MessageKey, V: "end"}, svcKV, methKV, codeKV, durKV)
193218
return stream, err
194219
}
195220
}
@@ -218,6 +243,14 @@ func WithDisableCallID() GRPCLogOption {
218243
}
219244
}
220245

246+
// WithCallLogFunc returns a HTTP middleware option that configures the logger to use
247+
// the given log function instead of log.Print() as default.
248+
func WithCallLogFunc(logFunc func(ctx context.Context, keyvals ...Fielder)) GRPCLogOption {
249+
return func(o *grpcOptions) {
250+
o.logFunc = logFunc
251+
}
252+
}
253+
221254
func defaultGRPCOptions() *grpcOptions {
222255
return &grpcOptions{
223256
iserr: func(c codes.Code) bool {

log/grpc_test.go

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,94 @@ func TestStreamClientInterceptor(t *testing.T) {
257257
}
258258
}
259259

260-
func logUnaryMethod(ctx context.Context, _ *testsvc.Fields) (*testsvc.Fields, error) {
260+
func TestWithCallLogFunc(t *testing.T) {
261+
var loggedKeyvals []Fielder
262+
customLogFunc := func(ctx context.Context, keyvals ...Fielder) {
263+
loggedKeyvals = append(loggedKeyvals, keyvals...)
264+
}
265+
266+
var buf bytes.Buffer
267+
ctx := Context(context.Background(), WithOutput(&buf), WithFormat(FormatJSON))
268+
handler := UnaryServerInterceptor(ctx, WithCallLogFunc(customLogFunc))
269+
cli, stop := testsvc.SetupGRPC(t,
270+
testsvc.WithServerOptions(grpc.UnaryInterceptor(handler)),
271+
testsvc.WithUnaryFunc(silentUnaryMethod))
272+
273+
_, err := cli.GRPCMethod(context.Background(), &testsvc.Fields{})
274+
require.NoError(t, err)
275+
276+
// Verify matching start/end messages were logged
277+
assert.Contains(t, loggedKeyvals, KV{K: MessageKey, V: "start"})
278+
assert.Contains(t, loggedKeyvals, KV{K: MessageKey, V: "end"})
279+
280+
// Verify that nothing was written to the buffer since we're using custom log func
281+
assert.Empty(t, buf.String())
282+
283+
stop()
284+
}
285+
286+
func TestWithCallLogFuncStream(t *testing.T) {
287+
var loggedKeyvals []Fielder
288+
customLogFunc := func(ctx context.Context, keyvals ...Fielder) {
289+
loggedKeyvals = append(loggedKeyvals, keyvals...)
290+
}
291+
292+
var buf bytes.Buffer
293+
ctx := Context(context.Background(), WithOutput(&buf), WithFormat(FormatJSON))
294+
handler := StreamServerInterceptor(ctx, WithCallLogFunc(customLogFunc))
295+
cli, stop := testsvc.SetupGRPC(t,
296+
testsvc.WithServerOptions(grpc.StreamInterceptor(handler)),
297+
testsvc.WithStreamFunc(dummyStreamMethod()))
298+
299+
stream, err := cli.GRPCStream(context.Background())
300+
require.NoError(t, err)
301+
err = stream.Send(&testsvc.Fields{})
302+
require.NoError(t, err)
303+
err = stream.Close()
304+
require.NoError(t, err)
305+
stop()
306+
307+
// Verify matching start/end messages were logged
308+
assert.Contains(t, loggedKeyvals, KV{K: MessageKey, V: "start"})
309+
assert.Contains(t, loggedKeyvals, KV{K: MessageKey, V: "end"})
310+
311+
// Verify that nothing was written to the buffer since we're using custom log func
312+
assert.Empty(t, buf.String())
313+
}
314+
315+
func TestWithCallLogFuncClient(t *testing.T) {
316+
var loggedKeyvals []Fielder
317+
customLogFunc := func(ctx context.Context, keyvals ...Fielder) {
318+
loggedKeyvals = append(loggedKeyvals, keyvals...)
319+
}
320+
321+
var buf bytes.Buffer
322+
ctx := Context(context.Background(), WithOutput(&buf), WithFormat(FormatJSON))
323+
324+
// Test unary client interceptor
325+
cli, stop := testsvc.SetupGRPC(t,
326+
testsvc.WithDialOptions(grpc.WithUnaryInterceptor(UnaryClientInterceptor(WithCallLogFunc(customLogFunc)))),
327+
testsvc.WithUnaryFunc(silentUnaryMethod))
328+
329+
_, err := cli.GRPCMethod(ctx, &testsvc.Fields{})
330+
require.NoError(t, err)
331+
332+
// Verify matching start/end messages were logged
333+
assert.Contains(t, loggedKeyvals, KV{K: MessageKey, V: "start"})
334+
assert.Contains(t, loggedKeyvals, KV{K: MessageKey, V: "end"})
335+
336+
// Verify that nothing was written to the buffer since we're using custom log func
337+
assert.Empty(t, buf.String())
338+
339+
stop()
340+
}
341+
342+
func logUnaryMethod(ctx context.Context, fields *testsvc.Fields) (*testsvc.Fields, error) {
261343
Print(ctx, KV{"key1", "value1"}, KV{"key2", "value2"})
344+
return silentUnaryMethod(ctx, fields)
345+
}
346+
347+
func silentUnaryMethod(_ context.Context, _ *testsvc.Fields) (*testsvc.Fields, error) {
262348
return &testsvc.Fields{}, nil
263349
}
264350

log/http.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type (
2727
pathFilters []*regexp.Regexp
2828
disableRequestLogging bool
2929
disableRequestID bool
30+
logFunc func(ctx context.Context, keyvals ...Fielder)
3031
}
3132

3233
httpClientOptions struct {
@@ -63,6 +64,12 @@ func HTTP(logCtx context.Context, opts ...HTTPLogOption) func(http.Handler) http
6364
o(&options)
6465
}
6566
}
67+
68+
logFunc := Print
69+
if options.logFunc != nil {
70+
logFunc = options.logFunc
71+
}
72+
6673
return func(h http.Handler) http.Handler {
6774
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
6875
for _, opt := range options.pathFilters {
@@ -82,7 +89,7 @@ func HTTP(logCtx context.Context, opts ...HTTPLogOption) func(http.Handler) http
8289
methKV := KV{K: HTTPMethodKey, V: req.Method}
8390
urlKV := KV{K: HTTPURLKey, V: req.URL.String()}
8491
fromKV := KV{K: HTTPFromKey, V: from(req)}
85-
Print(ctx, KV{K: MessageKey, V: "start"}, methKV, urlKV, fromKV)
92+
logFunc(ctx, KV{K: MessageKey, V: "start"}, methKV, urlKV, fromKV)
8693

8794
rw := &responseCapture{ResponseWriter: w}
8895
started := timeNow()
@@ -91,7 +98,7 @@ func HTTP(logCtx context.Context, opts ...HTTPLogOption) func(http.Handler) http
9198
statusKV := KV{K: HTTPStatusKey, V: rw.StatusCode}
9299
durKV := KV{K: HTTPDurationKey, V: timeSince(started).Milliseconds()}
93100
bytesKV := KV{K: HTTPBytesKey, V: rw.ContentLength}
94-
Print(ctx, KV{K: MessageKey, V: "end"}, methKV, urlKV, statusKV, durKV, bytesKV)
101+
logFunc(ctx, KV{K: MessageKey, V: "end"}, methKV, urlKV, statusKV, durKV, bytesKV)
95102
})
96103
}
97104
}
@@ -163,6 +170,14 @@ func WithDisableRequestID() HTTPLogOption {
163170
}
164171
}
165172

173+
// WithRequestLogFunc returns a HTTP middleware option that configures the logger to use
174+
// the given log function instead of log.Print() as default.
175+
func WithRequestLogFunc(logFunc func(ctx context.Context, keyvals ...Fielder)) HTTPLogOption {
176+
return func(o *httpLogOptions) {
177+
o.logFunc = logFunc
178+
}
179+
}
180+
166181
// RoundTrip executes the given HTTP request and logs the request and response. The
167182
// request context must be initialized with a clue logger.
168183
func (c *client) RoundTrip(req *http.Request) (resp *http.Response, err error) {

log/http_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,38 @@ func TestWithPathFilter(t *testing.T) {
173173
assert.Empty(t, buf.String())
174174
}
175175

176+
func TestWithRequestLogFunc(t *testing.T) {
177+
now := timeNow
178+
timeNow = func() time.Time { return time.Date(2022, time.January, 9, 20, 29, 45, 0, time.UTC) }
179+
defer func() { timeNow = now }()
180+
timeSince = func(_ time.Time) time.Duration { return 42 * time.Millisecond }
181+
defer func() { timeSince = time.Since }()
182+
shortID = func() string { return "test-request-id" }
183+
defer func() { shortID = randShortID }()
184+
185+
var loggedKeyvals []Fielder
186+
customLogFunc := func(ctx context.Context, keyvals ...Fielder) {
187+
loggedKeyvals = append(loggedKeyvals, keyvals...)
188+
}
189+
190+
var handler http.Handler = http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
191+
})
192+
var buf bytes.Buffer
193+
ctx := Context(context.Background(), WithOutput(&buf), WithFormat(FormatJSON))
194+
195+
handler = HTTP(ctx, WithRequestLogFunc(customLogFunc))(handler)
196+
197+
req, _ := http.NewRequest("GET", "http://example.com", nil)
198+
handler.ServeHTTP(nil, req)
199+
200+
// Verify matching start/end
201+
assert.Contains(t, loggedKeyvals, KV{K: MessageKey, V: "start"})
202+
assert.Contains(t, loggedKeyvals, KV{K: MessageKey, V: "end"})
203+
204+
// Verify that nothing was written to the buffer since we're using custom log func
205+
assert.Empty(t, buf.String())
206+
}
207+
176208
type errorClient struct {
177209
err error
178210
}

0 commit comments

Comments
 (0)