Skip to content

Commit ef1bfcb

Browse files
Quinn-With-Two-Nsbergundy
authored andcommitted
Add middleware
1 parent 036065c commit ef1bfcb

File tree

5 files changed

+321
-83
lines changed

5 files changed

+321
-83
lines changed

README.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,52 @@ result, _ := nexus.StartOperation(ctx, client, operation, MyInput{Field: "value"
422422
fmt.Println("got result with backlinks", result.Links)
423423
```
424424

425+
### Middleware
426+
427+
The ServiceRegistry supports middleware registration via the `Use` method. The registry's handler will invoke every
428+
registered middleware in registration order. Typical use cases for middleware include global enforcement of
429+
authorization and logging.
430+
431+
Middleware is implemented as a function that takes the current context and the next handler in the invocation chain and
432+
returns a new handler to invoke. The function can pass through the given handler or return an error to abort the
433+
execution. The registered middleware function has access to common handler information such as the current service,
434+
operation and request headers. To get access to more specific handler method information, such as inputs and operation
435+
tokens, wrap the given handler.
436+
437+
**Example**
438+
439+
```go
440+
type loggingOperation struct {
441+
nexus.UnimplementedOperation[any, any] // All OperationHandlers must embed this.
442+
next nexus.OperationHandler[any, any]
443+
}
444+
445+
func (lo *loggingOperation) Start(ctx context.Context, input any, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) {
446+
log.Println("starting operation", ExtractHandlerInfo(ctx).Operation)
447+
return lo.next.Start(ctx, input, options)
448+
}
449+
450+
func (lo *loggingOperation) GetResult(ctx context.Context, token string, options nexus.GetOperationResultOptions) (any, error) {
451+
log.Println("getting result for operation", ExtractHandlerInfo(ctx).Operation)
452+
return lo.next.GetResult(ctx, token, options)
453+
}
454+
455+
func (lo *loggingOperation) Cancel(ctx context.Context, token string, options nexus.CancelOperationOptions) error {
456+
log.Printf("canceling operation", ExtractHandlerInfo(ctx).Operation)
457+
return lo.next.Cancel(ctx, token, options)
458+
}
459+
460+
func (lo *loggingOperation) GetInfo(ctx context.Context, token string, options nexus.GetOperationInfoOptions) (*nexus.OperationInfo, error) {
461+
log.Println("getting info for operation", ExtractHandlerInfo(ctx).Operation)
462+
return lo.next.GetInfo(ctx, token, options)
463+
}
464+
465+
registry.Use(func(ctx context.Context, next nexus.OperationHandler[any, any]) (nexus.OperationHandler[any, any], error) {
466+
// Optionally call ExtractHandlerInfo(ctx) here.
467+
return &loggingOperation{next: next}, nil
468+
})
469+
```
470+
425471
## Contributing
426472

427473
### Prerequisites

nexus/handler_context_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
func TestHandlerContext(t *testing.T) {
12-
ctx := nexus.WithHandlerContext(context.Background())
12+
ctx := nexus.WithHandlerContext(context.Background(), nexus.HandlerInfo{Operation: "test"})
1313
require.True(t, nexus.IsHandlerContext(ctx))
1414
initial := []nexus.Link{{Type: "foo"}, {Type: "bar"}}
1515
nexus.AddHandlerLinks(ctx, initial...)
@@ -18,4 +18,5 @@ func TestHandlerContext(t *testing.T) {
1818
require.Equal(t, append(initial, additional), nexus.HandlerLinks(ctx))
1919
nexus.SetHandlerLinks(ctx, initial...)
2020
require.Equal(t, initial, nexus.HandlerLinks(ctx))
21+
require.Equal(t, nexus.HandlerInfo{Operation: "test"}, nexus.ExtractHandlerInfo(ctx))
2122
}

nexus/operation.go

Lines changed: 124 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ import (
1515
// )}
1616
type NoValue *struct{}
1717

18+
// MiddlewareFunc is a function which receives an OperationHandler and returns another OperationHandler.
19+
// If the middleware wants to stop the chain before any handler is called, it can return an error.
20+
//
21+
// To get [HandlerInfo] for the current handler, call [ExtractHandlerInfo] with the given context.
22+
//
23+
// NOTE: Experimental
24+
type MiddlewareFunc func(ctx context.Context, next OperationHandler[any, any]) (OperationHandler[any, any], error)
25+
1826
// OperationReference provides a typed interface for invoking operations. Every [Operation] is also an
1927
// [OperationReference]. Callers may create references using [NewOperationReference] when the implementation is not
2028
// available.
@@ -64,13 +72,20 @@ type RegisterableOperation interface {
6472
//
6573
// Operation implementations must embed the [UnimplementedOperation].
6674
//
67-
// All Operation methods can return a [HandlerError] to fail requests with a custom [HandlerErrorType] and structured [Failure].
68-
// Arbitrary errors from handler methods are turned into [HandlerErrorTypeInternal],their details are logged and hidden
69-
// from the caller.
75+
// See [OperationHandler] for more information.
7076
type Operation[I, O any] interface {
7177
RegisterableOperation
7278
OperationReference[I, O]
79+
OperationHandler[I, O]
80+
}
7381

82+
// OperationHandler is the interface for the core operation methods. OperationHandler implementations must embed
83+
// [UnimplementedOperation].
84+
//
85+
// All Operation methods can return a [HandlerError] to fail requests with a custom [HandlerErrorType] and structured [Failure].
86+
// Arbitrary errors from handler methods are turned into [HandlerErrorTypeInternal],their details are logged and hidden
87+
// from the caller.
88+
type OperationHandler[I, O any] interface {
7489
// Start handles requests for starting an operation. Return [HandlerStartOperationResultSync] to respond
7590
// successfully - inline, or [HandlerStartOperationResultAsync] to indicate that an asynchronous operation was
7691
// started. Return an [OperationError] to indicate that an operation completed as failed or
@@ -101,6 +116,8 @@ type Operation[I, O any] interface {
101116
// ignored by the underlying operation implemention.
102117
// 2. idempotent - implementors should ignore duplicate cancelations for the same operation.
103118
Cancel(ctx context.Context, token string, options CancelOperationOptions) error
119+
120+
mustEmbedUnimplementedOperation()
104121
}
105122

106123
type syncOperation[I, O any] struct {
@@ -188,11 +205,15 @@ func (s *Service) Operation(name string) RegisterableOperation {
188205

189206
// A ServiceRegistry registers services and constructs a [Handler] that dispatches operations requests to those services.
190207
type ServiceRegistry struct {
191-
services map[string]*Service
208+
services map[string]*Service
209+
middleware []MiddlewareFunc
192210
}
193211

194212
func NewServiceRegistry() *ServiceRegistry {
195-
return &ServiceRegistry{services: make(map[string]*Service)}
213+
return &ServiceRegistry{
214+
services: make(map[string]*Service),
215+
middleware: make([]MiddlewareFunc, 0),
216+
}
196217
}
197218

198219
// Register one or more service.
@@ -218,6 +239,15 @@ func (r *ServiceRegistry) Register(services ...*Service) error {
218239
return nil
219240
}
220241

242+
// Use registers one or more middleware to be applied to all operation method invocations across all registered
243+
// services. Middleware is applied in registration order. If called multiple times, newly registered middleware will be
244+
// applied after any previously registered ones.
245+
//
246+
// NOTE: Experimental
247+
func (s *ServiceRegistry) Use(middleware ...MiddlewareFunc) {
248+
s.middleware = append(s.middleware, middleware...)
249+
}
250+
221251
// NewHandler creates a [Handler] that dispatches requests to registered operations based on their name.
222252
func (r *ServiceRegistry) NewHandler() (Handler, error) {
223253
if len(r.services) == 0 {
@@ -229,76 +259,67 @@ func (r *ServiceRegistry) NewHandler() (Handler, error) {
229259
}
230260
}
231261

232-
return &registryHandler{services: r.services}, nil
262+
return &registryHandler{services: r.services, middlewares: r.middleware}, nil
233263
}
234264

235265
type registryHandler struct {
236266
UnimplementedHandler
237267

238-
services map[string]*Service
268+
services map[string]*Service
269+
middlewares []MiddlewareFunc
239270
}
240271

241-
// CancelOperation implements Handler.
242-
func (r *registryHandler) CancelOperation(ctx context.Context, service, operation string, token string, options CancelOperationOptions) error {
243-
s, ok := r.services[service]
272+
func (r *registryHandler) operationHandler(ctx context.Context) (OperationHandler[any, any], error) {
273+
options := ExtractHandlerInfo(ctx)
274+
s, ok := r.services[options.Service]
244275
if !ok {
245-
return HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
276+
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", options.Service)
246277
}
247-
h, ok := s.operations[operation]
278+
h, ok := s.operations[options.Operation]
248279
if !ok {
249-
return HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
280+
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", options.Operation)
250281
}
251282

252-
// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
253-
// worth it since we need reflection for the generic methods.
254-
m, _ := reflect.TypeOf(h).MethodByName("Cancel")
255-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
256-
if values[0].IsNil() {
257-
return nil
283+
var handler OperationHandler[any, any]
284+
handler = &rootOperationHandler{h: h}
285+
// TODO: h == nil should probably be checked at registration time.
286+
if h != nil && len(r.middlewares) > 0 {
287+
for i := len(r.middlewares) - 1; i >= 0; i-- {
288+
var err error
289+
handler, err = r.middlewares[i](ctx, handler)
290+
if err != nil {
291+
return nil, err
292+
}
293+
}
258294
}
259-
return values[0].Interface().(error)
295+
return handler, nil
260296
}
261297

262-
// GetOperationInfo implements Handler.
263-
func (r *registryHandler) GetOperationInfo(ctx context.Context, service, operation string, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
264-
s, ok := r.services[service]
265-
if !ok {
266-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
267-
}
268-
h, ok := s.operations[operation]
269-
if !ok {
270-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
271-
}
272-
273-
// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
274-
// worth it since we need reflection for the generic methods.
275-
m, _ := reflect.TypeOf(h).MethodByName("GetInfo")
276-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
277-
if !values[1].IsNil() {
278-
return nil, values[1].Interface().(error)
298+
// CancelOperation implements Handler.
299+
func (r *registryHandler) CancelOperation(ctx context.Context, service, operation, token string, options CancelOperationOptions) error {
300+
h, err := r.operationHandler(ctx)
301+
if err != nil {
302+
return err
279303
}
280-
ret := values[0].Interface()
281-
return ret.(*OperationInfo), nil
304+
return h.Cancel(ctx, token, options)
282305
}
283306

284-
// GetOperationResult implements Handler.
285-
func (r *registryHandler) GetOperationResult(ctx context.Context, service, operation string, token string, options GetOperationResultOptions) (any, error) {
286-
s, ok := r.services[service]
287-
if !ok {
288-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
289-
}
290-
h, ok := s.operations[operation]
291-
if !ok {
292-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
307+
// operationHandlerInfo implements Handler.
308+
func (r *registryHandler) GetOperationInfo(ctx context.Context, service, operation, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
309+
h, err := r.operationHandler(ctx)
310+
if err != nil {
311+
return nil, err
293312
}
313+
return h.GetInfo(ctx, token, options)
314+
}
294315

295-
m, _ := reflect.TypeOf(h).MethodByName("GetResult")
296-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
297-
if !values[1].IsNil() {
298-
return nil, values[1].Interface().(error)
316+
// operationHandlerResult implements Handler.
317+
func (r *registryHandler) GetOperationResult(ctx context.Context, service, operation, token string, options GetOperationResultOptions) (any, error) {
318+
h, err := r.operationHandler(ctx)
319+
if err != nil {
320+
return nil, err
299321
}
300-
ret := values[0].Interface()
301-
return ret, nil
322+
return h.GetResult(ctx, token, options)
302323
}
303324

304325
// StartOperation implements Handler.
@@ -307,29 +328,72 @@ func (r *registryHandler) StartOperation(ctx context.Context, service, operation
307328
if !ok {
308329
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
309330
}
310-
h, ok := s.operations[operation]
331+
ro, ok := s.operations[operation]
311332
if !ok {
312333
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
313334
}
314335

315-
m, _ := reflect.TypeOf(h).MethodByName("Start")
336+
h, err := r.operationHandler(ctx)
337+
if err != nil {
338+
return nil, err
339+
}
340+
m, _ := reflect.TypeOf(ro).MethodByName("Start")
316341
inputType := m.Type.In(2)
317342
iptr := reflect.New(inputType).Interface()
318343
if err := input.Consume(iptr); err != nil {
319344
// TODO: log the error? Do we need to accept a logger for this single line?
320345
return nil, HandlerErrorf(HandlerErrorTypeBadRequest, "invalid input")
321346
}
322-
i := reflect.ValueOf(iptr).Elem()
347+
return h.Start(ctx, reflect.ValueOf(iptr).Elem().Interface(), options)
348+
}
323349

324-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), i, reflect.ValueOf(options)})
350+
type rootOperationHandler struct {
351+
UnimplementedOperation[any, any]
352+
h RegisterableOperation
353+
}
354+
355+
func (r *rootOperationHandler) Cancel(ctx context.Context, token string, options CancelOperationOptions) error {
356+
// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
357+
// worth it since we need reflection for the generic methods.
358+
m, _ := reflect.TypeOf(r.h).MethodByName("Cancel")
359+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
360+
if values[0].IsNil() {
361+
return nil
362+
}
363+
return values[0].Interface().(error)
364+
}
365+
366+
func (r *rootOperationHandler) GetInfo(ctx context.Context, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
367+
// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
368+
// worth it since we need reflection for the generic methods.
369+
m, _ := reflect.TypeOf(r.h).MethodByName("GetInfo")
370+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
325371
if !values[1].IsNil() {
326372
return nil, values[1].Interface().(error)
327373
}
328374
ret := values[0].Interface()
329-
return ret.(HandlerStartOperationResult[any]), nil
375+
return ret.(*OperationInfo), nil
330376
}
331377

332-
var _ Handler = &registryHandler{}
378+
func (r *rootOperationHandler) GetResult(ctx context.Context, token string, options GetOperationResultOptions) (any, error) {
379+
m, _ := reflect.TypeOf(r.h).MethodByName("GetResult")
380+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
381+
if !values[1].IsNil() {
382+
return nil, values[1].Interface().(error)
383+
}
384+
ret := values[0].Interface()
385+
return ret, nil
386+
}
387+
388+
func (r *rootOperationHandler) Start(ctx context.Context, input any, options StartOperationOptions) (HandlerStartOperationResult[any], error) {
389+
m, _ := reflect.TypeOf(r.h).MethodByName("Start")
390+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(input), reflect.ValueOf(options)})
391+
if !values[1].IsNil() {
392+
return nil, values[1].Interface().(error)
393+
}
394+
ret := values[0].Interface()
395+
return ret.(HandlerStartOperationResult[any]), nil
396+
}
333397

334398
// ExecuteOperation is the type safe version of [HTTPClient.ExecuteOperation].
335399
// It accepts input of type I and returns output of type O, removing the need to consume the [LazyValue] returned by the

0 commit comments

Comments
 (0)