Skip to content

Commit cd7e067

Browse files
Add middleware/interceptors (#41)
Taken over by @bergundy (thanks @Quinn-With-Two-Ns for kicking this off). 💥 BREAKING CHANGE 💥 The experimental `WithHandlerContext` method signature was changed to accept the associated `HandlerInfo`. Do not merge until the Temporal SDK work that depends on this is ready for review.
1 parent 036065c commit cd7e067

File tree

5 files changed

+320
-83
lines changed

5 files changed

+320
-83
lines changed

README.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,52 @@ result, _ := nexus.StartOperation(ctx, client, operation, MyInput{Field: "value"
422422
fmt.Println("got result with backlinks", result.Links)
423423
```
424424

425+
### Middleware
426+
427+
The ServiceRegistry supports middleware registration via the `Use` method. The registry's handler will invoke every
428+
registered middleware in registration order. Typical use cases for middleware include global enforcement of
429+
authorization and logging.
430+
431+
Middleware is implemented as a function that takes the current context and the next handler in the invocation chain and
432+
returns a new handler to invoke. The function can pass through the given handler or return an error to abort the
433+
execution. The registered middleware function has access to common handler information such as the current service,
434+
operation, and request headers. To get access to more specific handler method information, such as inputs and operation
435+
tokens, wrap the given handler.
436+
437+
**Example**
438+
439+
```go
440+
type loggingOperation struct {
441+
nexus.UnimplementedOperation[any, any] // All OperationHandlers must embed this.
442+
next nexus.OperationHandler[any, any]
443+
}
444+
445+
func (lo *loggingOperation) Start(ctx context.Context, input any, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) {
446+
log.Println("starting operation", ExtractHandlerInfo(ctx).Operation)
447+
return lo.next.Start(ctx, input, options)
448+
}
449+
450+
func (lo *loggingOperation) GetResult(ctx context.Context, token string, options nexus.GetOperationResultOptions) (any, error) {
451+
log.Println("getting result for operation", ExtractHandlerInfo(ctx).Operation)
452+
return lo.next.GetResult(ctx, token, options)
453+
}
454+
455+
func (lo *loggingOperation) Cancel(ctx context.Context, token string, options nexus.CancelOperationOptions) error {
456+
log.Printf("canceling operation", ExtractHandlerInfo(ctx).Operation)
457+
return lo.next.Cancel(ctx, token, options)
458+
}
459+
460+
func (lo *loggingOperation) GetInfo(ctx context.Context, token string, options nexus.GetOperationInfoOptions) (*nexus.OperationInfo, error) {
461+
log.Println("getting info for operation", ExtractHandlerInfo(ctx).Operation)
462+
return lo.next.GetInfo(ctx, token, options)
463+
}
464+
465+
registry.Use(func(ctx context.Context, next nexus.OperationHandler[any, any]) (nexus.OperationHandler[any, any], error) {
466+
// Optionally call ExtractHandlerInfo(ctx) here.
467+
return &loggingOperation{next: next}, nil
468+
})
469+
```
470+
425471
## Contributing
426472

427473
### Prerequisites

nexus/handler_context_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
func TestHandlerContext(t *testing.T) {
12-
ctx := nexus.WithHandlerContext(context.Background())
12+
ctx := nexus.WithHandlerContext(context.Background(), nexus.HandlerInfo{Operation: "test"})
1313
require.True(t, nexus.IsHandlerContext(ctx))
1414
initial := []nexus.Link{{Type: "foo"}, {Type: "bar"}}
1515
nexus.AddHandlerLinks(ctx, initial...)
@@ -18,4 +18,5 @@ func TestHandlerContext(t *testing.T) {
1818
require.Equal(t, append(initial, additional), nexus.HandlerLinks(ctx))
1919
nexus.SetHandlerLinks(ctx, initial...)
2020
require.Equal(t, initial, nexus.HandlerLinks(ctx))
21+
require.Equal(t, nexus.HandlerInfo{Operation: "test"}, nexus.ExtractHandlerInfo(ctx))
2122
}

nexus/operation.go

Lines changed: 123 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,21 @@ type RegisterableOperation interface {
6464
//
6565
// Operation implementations must embed the [UnimplementedOperation].
6666
//
67-
// All Operation methods can return a [HandlerError] to fail requests with a custom [HandlerErrorType] and structured [Failure].
68-
// Arbitrary errors from handler methods are turned into [HandlerErrorTypeInternal],their details are logged and hidden
69-
// from the caller.
67+
// See [OperationHandler] for more information.
7068
type Operation[I, O any] interface {
7169
RegisterableOperation
7270
OperationReference[I, O]
71+
OperationHandler[I, O]
72+
}
7373

74+
// OperationHandler is the interface for the core operation methods. OperationHandler implementations must embed
75+
// [UnimplementedOperation].
76+
//
77+
// All Operation methods can return a [HandlerError] to fail requests with a custom [HandlerErrorType] and structured [Failure].
78+
// Arbitrary errors from handler methods are turned into [HandlerErrorTypeInternal], when using the Nexus SDK's
79+
// HTTP handler, their details are logged and hidden from the caller. Other handler implementations may expose internal
80+
// error information to callers.
81+
type OperationHandler[I, O any] interface {
7482
// Start handles requests for starting an operation. Return [HandlerStartOperationResultSync] to respond
7583
// successfully - inline, or [HandlerStartOperationResultAsync] to indicate that an asynchronous operation was
7684
// started. Return an [OperationError] to indicate that an operation completed as failed or
@@ -101,6 +109,8 @@ type Operation[I, O any] interface {
101109
// ignored by the underlying operation implemention.
102110
// 2. idempotent - implementors should ignore duplicate cancelations for the same operation.
103111
Cancel(ctx context.Context, token string, options CancelOperationOptions) error
112+
113+
mustEmbedUnimplementedOperation()
104114
}
105115

106116
type syncOperation[I, O any] struct {
@@ -186,13 +196,26 @@ func (s *Service) Operation(name string) RegisterableOperation {
186196
return s.operations[name]
187197
}
188198

199+
// MiddlewareFunc is a function which receives an OperationHandler and returns another OperationHandler.
200+
// If the middleware wants to stop the chain before any handler is called, it can return an error.
201+
//
202+
// To get [HandlerInfo] for the current handler, call [ExtractHandlerInfo] with the given context.
203+
//
204+
// NOTE: Experimental
205+
type MiddlewareFunc func(ctx context.Context, next OperationHandler[any, any]) (OperationHandler[any, any], error)
206+
189207
// A ServiceRegistry registers services and constructs a [Handler] that dispatches operations requests to those services.
190208
type ServiceRegistry struct {
191-
services map[string]*Service
209+
services map[string]*Service
210+
middleware []MiddlewareFunc
192211
}
193212

213+
// NewServiceRegistry constructs an empty [ServiceRegistry].
194214
func NewServiceRegistry() *ServiceRegistry {
195-
return &ServiceRegistry{services: make(map[string]*Service)}
215+
return &ServiceRegistry{
216+
services: make(map[string]*Service),
217+
middleware: make([]MiddlewareFunc, 0),
218+
}
196219
}
197220

198221
// Register one or more service.
@@ -218,6 +241,15 @@ func (r *ServiceRegistry) Register(services ...*Service) error {
218241
return nil
219242
}
220243

244+
// Use registers one or more middleware to be applied to all operation method invocations across all registered
245+
// services. Middleware is applied in registration order. If called multiple times, newly registered middleware will be
246+
// applied after any previously registered ones.
247+
//
248+
// NOTE: Experimental
249+
func (s *ServiceRegistry) Use(middleware ...MiddlewareFunc) {
250+
s.middleware = append(s.middleware, middleware...)
251+
}
252+
221253
// NewHandler creates a [Handler] that dispatches requests to registered operations based on their name.
222254
func (r *ServiceRegistry) NewHandler() (Handler, error) {
223255
if len(r.services) == 0 {
@@ -229,76 +261,64 @@ func (r *ServiceRegistry) NewHandler() (Handler, error) {
229261
}
230262
}
231263

232-
return &registryHandler{services: r.services}, nil
264+
return &registryHandler{services: r.services, middlewares: r.middleware}, nil
233265
}
234266

235267
type registryHandler struct {
236268
UnimplementedHandler
237269

238-
services map[string]*Service
270+
services map[string]*Service
271+
middlewares []MiddlewareFunc
239272
}
240273

241-
// CancelOperation implements Handler.
242-
func (r *registryHandler) CancelOperation(ctx context.Context, service, operation string, token string, options CancelOperationOptions) error {
243-
s, ok := r.services[service]
274+
func (r *registryHandler) operationHandler(ctx context.Context) (OperationHandler[any, any], error) {
275+
options := ExtractHandlerInfo(ctx)
276+
s, ok := r.services[options.Service]
244277
if !ok {
245-
return HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
278+
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", options.Service)
246279
}
247-
h, ok := s.operations[operation]
280+
h, ok := s.operations[options.Operation]
248281
if !ok {
249-
return HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
282+
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", options.Operation)
250283
}
251284

252-
// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
253-
// worth it since we need reflection for the generic methods.
254-
m, _ := reflect.TypeOf(h).MethodByName("Cancel")
255-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
256-
if values[0].IsNil() {
257-
return nil
285+
var handler OperationHandler[any, any]
286+
handler = &rootOperationHandler{h: h}
287+
for i := len(r.middlewares) - 1; i >= 0; i-- {
288+
var err error
289+
handler, err = r.middlewares[i](ctx, handler)
290+
if err != nil {
291+
return nil, err
292+
}
258293
}
259-
return values[0].Interface().(error)
294+
return handler, nil
260295
}
261296

262-
// GetOperationInfo implements Handler.
263-
func (r *registryHandler) GetOperationInfo(ctx context.Context, service, operation string, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
264-
s, ok := r.services[service]
265-
if !ok {
266-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
267-
}
268-
h, ok := s.operations[operation]
269-
if !ok {
270-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
271-
}
272-
273-
// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
274-
// worth it since we need reflection for the generic methods.
275-
m, _ := reflect.TypeOf(h).MethodByName("GetInfo")
276-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
277-
if !values[1].IsNil() {
278-
return nil, values[1].Interface().(error)
297+
// CancelOperation implements Handler.
298+
func (r *registryHandler) CancelOperation(ctx context.Context, service, operation, token string, options CancelOperationOptions) error {
299+
h, err := r.operationHandler(ctx)
300+
if err != nil {
301+
return err
279302
}
280-
ret := values[0].Interface()
281-
return ret.(*OperationInfo), nil
303+
return h.Cancel(ctx, token, options)
282304
}
283305

284-
// GetOperationResult implements Handler.
285-
func (r *registryHandler) GetOperationResult(ctx context.Context, service, operation string, token string, options GetOperationResultOptions) (any, error) {
286-
s, ok := r.services[service]
287-
if !ok {
288-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
289-
}
290-
h, ok := s.operations[operation]
291-
if !ok {
292-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
306+
// operationHandlerInfo implements Handler.
307+
func (r *registryHandler) GetOperationInfo(ctx context.Context, service, operation, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
308+
h, err := r.operationHandler(ctx)
309+
if err != nil {
310+
return nil, err
293311
}
312+
return h.GetInfo(ctx, token, options)
313+
}
294314

295-
m, _ := reflect.TypeOf(h).MethodByName("GetResult")
296-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
297-
if !values[1].IsNil() {
298-
return nil, values[1].Interface().(error)
315+
// operationHandlerResult implements Handler.
316+
func (r *registryHandler) GetOperationResult(ctx context.Context, service, operation, token string, options GetOperationResultOptions) (any, error) {
317+
h, err := r.operationHandler(ctx)
318+
if err != nil {
319+
return nil, err
299320
}
300-
ret := values[0].Interface()
301-
return ret, nil
321+
return h.GetResult(ctx, token, options)
302322
}
303323

304324
// StartOperation implements Handler.
@@ -307,29 +327,72 @@ func (r *registryHandler) StartOperation(ctx context.Context, service, operation
307327
if !ok {
308328
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
309329
}
310-
h, ok := s.operations[operation]
330+
ro, ok := s.operations[operation]
311331
if !ok {
312332
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
313333
}
314334

315-
m, _ := reflect.TypeOf(h).MethodByName("Start")
335+
h, err := r.operationHandler(ctx)
336+
if err != nil {
337+
return nil, err
338+
}
339+
m, _ := reflect.TypeOf(ro).MethodByName("Start")
316340
inputType := m.Type.In(2)
317341
iptr := reflect.New(inputType).Interface()
318342
if err := input.Consume(iptr); err != nil {
319343
// TODO: log the error? Do we need to accept a logger for this single line?
320344
return nil, HandlerErrorf(HandlerErrorTypeBadRequest, "invalid input")
321345
}
322-
i := reflect.ValueOf(iptr).Elem()
346+
return h.Start(ctx, reflect.ValueOf(iptr).Elem().Interface(), options)
347+
}
323348

324-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), i, reflect.ValueOf(options)})
349+
type rootOperationHandler struct {
350+
UnimplementedOperation[any, any]
351+
h RegisterableOperation
352+
}
353+
354+
func (r *rootOperationHandler) Cancel(ctx context.Context, token string, options CancelOperationOptions) error {
355+
// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
356+
// worth it since we need reflection for the generic methods.
357+
m, _ := reflect.TypeOf(r.h).MethodByName("Cancel")
358+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
359+
if values[0].IsNil() {
360+
return nil
361+
}
362+
return values[0].Interface().(error)
363+
}
364+
365+
func (r *rootOperationHandler) GetInfo(ctx context.Context, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
366+
// NOTE: We could avoid reflection here if we put the GetInfo method on RegisterableOperation but it doesn't
367+
// seem worth it since we need reflection for the generic methods.
368+
m, _ := reflect.TypeOf(r.h).MethodByName("GetInfo")
369+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
325370
if !values[1].IsNil() {
326371
return nil, values[1].Interface().(error)
327372
}
328373
ret := values[0].Interface()
329-
return ret.(HandlerStartOperationResult[any]), nil
374+
return ret.(*OperationInfo), nil
375+
}
376+
377+
func (r *rootOperationHandler) GetResult(ctx context.Context, token string, options GetOperationResultOptions) (any, error) {
378+
m, _ := reflect.TypeOf(r.h).MethodByName("GetResult")
379+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
380+
if !values[1].IsNil() {
381+
return nil, values[1].Interface().(error)
382+
}
383+
ret := values[0].Interface()
384+
return ret, nil
330385
}
331386

332-
var _ Handler = &registryHandler{}
387+
func (r *rootOperationHandler) Start(ctx context.Context, input any, options StartOperationOptions) (HandlerStartOperationResult[any], error) {
388+
m, _ := reflect.TypeOf(r.h).MethodByName("Start")
389+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(input), reflect.ValueOf(options)})
390+
if !values[1].IsNil() {
391+
return nil, values[1].Interface().(error)
392+
}
393+
ret := values[0].Interface()
394+
return ret.(HandlerStartOperationResult[any]), nil
395+
}
333396

334397
// ExecuteOperation is the type safe version of [HTTPClient.ExecuteOperation].
335398
// It accepts input of type I and returns output of type O, removing the need to consume the [LazyValue] returned by the

0 commit comments

Comments
 (0)