Skip to content

Add middleware/interceptors #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,52 @@ result, _ := nexus.StartOperation(ctx, client, operation, MyInput{Field: "value"
fmt.Println("got result with backlinks", result.Links)
```

### Middleware

The ServiceRegistry supports middleware registration via the `Use` method. The registry's handler will invoke every
registered middleware in registration order. Typical use cases for middleware include global enforcement of
authorization and logging.

Middleware is implemented as a function that takes the current context and the next handler in the invocation chain and
returns a new handler to invoke. The function can pass through the given handler or return an error to abort the
execution. The registered middleware function has access to common handler information such as the current service,
operation, and request headers. To get access to more specific handler method information, such as inputs and operation
tokens, wrap the given handler.

**Example**

```go
type loggingOperation struct {
nexus.UnimplementedOperation[any, any] // All OperationHandlers must embed this.
next nexus.OperationHandler[any, any]
}

func (lo *loggingOperation) Start(ctx context.Context, input any, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) {
log.Println("starting operation", ExtractHandlerInfo(ctx).Operation)
return lo.next.Start(ctx, input, options)
}

func (lo *loggingOperation) GetResult(ctx context.Context, token string, options nexus.GetOperationResultOptions) (any, error) {
log.Println("getting result for operation", ExtractHandlerInfo(ctx).Operation)
return lo.next.GetResult(ctx, token, options)
}

func (lo *loggingOperation) Cancel(ctx context.Context, token string, options nexus.CancelOperationOptions) error {
log.Printf("canceling operation", ExtractHandlerInfo(ctx).Operation)
return lo.next.Cancel(ctx, token, options)
}

func (lo *loggingOperation) GetInfo(ctx context.Context, token string, options nexus.GetOperationInfoOptions) (*nexus.OperationInfo, error) {
log.Println("getting info for operation", ExtractHandlerInfo(ctx).Operation)
return lo.next.GetInfo(ctx, token, options)
}

registry.Use(func(ctx context.Context, next nexus.OperationHandler[any, any]) (nexus.OperationHandler[any, any], error) {
// Optionally call ExtractHandlerInfo(ctx) here.
return &loggingOperation{next: next}, nil
})
```

## Contributing

### Prerequisites
Expand Down
3 changes: 2 additions & 1 deletion nexus/handler_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func TestHandlerContext(t *testing.T) {
ctx := nexus.WithHandlerContext(context.Background())
ctx := nexus.WithHandlerContext(context.Background(), nexus.HandlerInfo{Operation: "test"})
require.True(t, nexus.IsHandlerContext(ctx))
initial := []nexus.Link{{Type: "foo"}, {Type: "bar"}}
nexus.AddHandlerLinks(ctx, initial...)
Expand All @@ -18,4 +18,5 @@ func TestHandlerContext(t *testing.T) {
require.Equal(t, append(initial, additional), nexus.HandlerLinks(ctx))
nexus.SetHandlerLinks(ctx, initial...)
require.Equal(t, initial, nexus.HandlerLinks(ctx))
require.Equal(t, nexus.HandlerInfo{Operation: "test"}, nexus.ExtractHandlerInfo(ctx))
}
183 changes: 123 additions & 60 deletions nexus/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,21 @@ type RegisterableOperation interface {
//
// Operation implementations must embed the [UnimplementedOperation].
//
// All Operation methods can return a [HandlerError] to fail requests with a custom [HandlerErrorType] and structured [Failure].
// Arbitrary errors from handler methods are turned into [HandlerErrorTypeInternal],their details are logged and hidden
// from the caller.
// See [OperationHandler] for more information.
type Operation[I, O any] interface {
RegisterableOperation
OperationReference[I, O]
OperationHandler[I, O]
}

// OperationHandler is the interface for the core operation methods. OperationHandler implementations must embed
// [UnimplementedOperation].
//
// All Operation methods can return a [HandlerError] to fail requests with a custom [HandlerErrorType] and structured [Failure].
// Arbitrary errors from handler methods are turned into [HandlerErrorTypeInternal], when using the Nexus SDK's
// HTTP handler, their details are logged and hidden from the caller. Other handler implementations may expose internal
// error information to callers.
type OperationHandler[I, O any] interface {
// Start handles requests for starting an operation. Return [HandlerStartOperationResultSync] to respond
// successfully - inline, or [HandlerStartOperationResultAsync] to indicate that an asynchronous operation was
// started. Return an [OperationError] to indicate that an operation completed as failed or
Expand Down Expand Up @@ -101,6 +109,8 @@ type Operation[I, O any] interface {
// ignored by the underlying operation implemention.
// 2. idempotent - implementors should ignore duplicate cancelations for the same operation.
Cancel(ctx context.Context, token string, options CancelOperationOptions) error

mustEmbedUnimplementedOperation()
}

type syncOperation[I, O any] struct {
Expand Down Expand Up @@ -186,13 +196,26 @@ func (s *Service) Operation(name string) RegisterableOperation {
return s.operations[name]
}

// MiddlewareFunc is a function which receives an OperationHandler and returns another OperationHandler.
// If the middleware wants to stop the chain before any handler is called, it can return an error.
//
// To get [HandlerInfo] for the current handler, call [ExtractHandlerInfo] with the given context.
//
// NOTE: Experimental
type MiddlewareFunc func(ctx context.Context, next OperationHandler[any, any]) (OperationHandler[any, any], error)

// A ServiceRegistry registers services and constructs a [Handler] that dispatches operations requests to those services.
type ServiceRegistry struct {
services map[string]*Service
services map[string]*Service
middleware []MiddlewareFunc
}

// NewServiceRegistry constructs an empty [ServiceRegistry].
func NewServiceRegistry() *ServiceRegistry {
return &ServiceRegistry{services: make(map[string]*Service)}
return &ServiceRegistry{
services: make(map[string]*Service),
middleware: make([]MiddlewareFunc, 0),
}
}

// Register one or more service.
Expand All @@ -218,6 +241,15 @@ func (r *ServiceRegistry) Register(services ...*Service) error {
return nil
}

// Use registers one or more middleware to be applied to all operation method invocations across all registered
// services. Middleware is applied in registration order. If called multiple times, newly registered middleware will be
// applied after any previously registered ones.
//
// NOTE: Experimental
func (s *ServiceRegistry) Use(middleware ...MiddlewareFunc) {
s.middleware = append(s.middleware, middleware...)
}

// NewHandler creates a [Handler] that dispatches requests to registered operations based on their name.
func (r *ServiceRegistry) NewHandler() (Handler, error) {
if len(r.services) == 0 {
Expand All @@ -229,76 +261,64 @@ func (r *ServiceRegistry) NewHandler() (Handler, error) {
}
}

return &registryHandler{services: r.services}, nil
return &registryHandler{services: r.services, middlewares: r.middleware}, nil
}

type registryHandler struct {
UnimplementedHandler

services map[string]*Service
services map[string]*Service
middlewares []MiddlewareFunc
}

// CancelOperation implements Handler.
func (r *registryHandler) CancelOperation(ctx context.Context, service, operation string, token string, options CancelOperationOptions) error {
s, ok := r.services[service]
func (r *registryHandler) operationHandler(ctx context.Context) (OperationHandler[any, any], error) {
options := ExtractHandlerInfo(ctx)
s, ok := r.services[options.Service]
if !ok {
return HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", options.Service)
}
h, ok := s.operations[operation]
h, ok := s.operations[options.Operation]
if !ok {
return HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", options.Operation)
}

// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
// worth it since we need reflection for the generic methods.
m, _ := reflect.TypeOf(h).MethodByName("Cancel")
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
if values[0].IsNil() {
return nil
var handler OperationHandler[any, any]
handler = &rootOperationHandler{h: h}
for i := len(r.middlewares) - 1; i >= 0; i-- {
var err error
handler, err = r.middlewares[i](ctx, handler)
if err != nil {
return nil, err
}
}
return values[0].Interface().(error)
return handler, nil
}

// GetOperationInfo implements Handler.
func (r *registryHandler) GetOperationInfo(ctx context.Context, service, operation string, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
s, ok := r.services[service]
if !ok {
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
}
h, ok := s.operations[operation]
if !ok {
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
}

// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
// worth it since we need reflection for the generic methods.
m, _ := reflect.TypeOf(h).MethodByName("GetInfo")
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
if !values[1].IsNil() {
return nil, values[1].Interface().(error)
// CancelOperation implements Handler.
func (r *registryHandler) CancelOperation(ctx context.Context, service, operation, token string, options CancelOperationOptions) error {
h, err := r.operationHandler(ctx)
if err != nil {
return err
}
ret := values[0].Interface()
return ret.(*OperationInfo), nil
return h.Cancel(ctx, token, options)
}

// GetOperationResult implements Handler.
func (r *registryHandler) GetOperationResult(ctx context.Context, service, operation string, token string, options GetOperationResultOptions) (any, error) {
s, ok := r.services[service]
if !ok {
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
}
h, ok := s.operations[operation]
if !ok {
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
// operationHandlerInfo implements Handler.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: fix the docstring?

func (r *registryHandler) GetOperationInfo(ctx context.Context, service, operation, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
h, err := r.operationHandler(ctx)
if err != nil {
return nil, err
}
return h.GetInfo(ctx, token, options)
}

m, _ := reflect.TypeOf(h).MethodByName("GetResult")
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
if !values[1].IsNil() {
return nil, values[1].Interface().(error)
// operationHandlerResult implements Handler.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and here.

func (r *registryHandler) GetOperationResult(ctx context.Context, service, operation, token string, options GetOperationResultOptions) (any, error) {
h, err := r.operationHandler(ctx)
if err != nil {
return nil, err
}
ret := values[0].Interface()
return ret, nil
return h.GetResult(ctx, token, options)
}

// StartOperation implements Handler.
Expand All @@ -307,29 +327,72 @@ func (r *registryHandler) StartOperation(ctx context.Context, service, operation
if !ok {
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
}
h, ok := s.operations[operation]
ro, ok := s.operations[operation]
if !ok {
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
}

m, _ := reflect.TypeOf(h).MethodByName("Start")
h, err := r.operationHandler(ctx)
if err != nil {
return nil, err
}
m, _ := reflect.TypeOf(ro).MethodByName("Start")
inputType := m.Type.In(2)
iptr := reflect.New(inputType).Interface()
if err := input.Consume(iptr); err != nil {
// TODO: log the error? Do we need to accept a logger for this single line?
return nil, HandlerErrorf(HandlerErrorTypeBadRequest, "invalid input")
}
i := reflect.ValueOf(iptr).Elem()
return h.Start(ctx, reflect.ValueOf(iptr).Elem().Interface(), options)
}

values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), i, reflect.ValueOf(options)})
type rootOperationHandler struct {
UnimplementedOperation[any, any]
h RegisterableOperation
}

func (r *rootOperationHandler) Cancel(ctx context.Context, token string, options CancelOperationOptions) error {
// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
// worth it since we need reflection for the generic methods.
m, _ := reflect.TypeOf(r.h).MethodByName("Cancel")
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
if values[0].IsNil() {
return nil
}
return values[0].Interface().(error)
}

func (r *rootOperationHandler) GetInfo(ctx context.Context, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
// NOTE: We could avoid reflection here if we put the GetInfo method on RegisterableOperation but it doesn't
// seem worth it since we need reflection for the generic methods.
m, _ := reflect.TypeOf(r.h).MethodByName("GetInfo")
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
if !values[1].IsNil() {
return nil, values[1].Interface().(error)
}
ret := values[0].Interface()
return ret.(HandlerStartOperationResult[any]), nil
return ret.(*OperationInfo), nil
}

func (r *rootOperationHandler) GetResult(ctx context.Context, token string, options GetOperationResultOptions) (any, error) {
m, _ := reflect.TypeOf(r.h).MethodByName("GetResult")
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
if !values[1].IsNil() {
return nil, values[1].Interface().(error)
}
ret := values[0].Interface()
return ret, nil
}

var _ Handler = &registryHandler{}
func (r *rootOperationHandler) Start(ctx context.Context, input any, options StartOperationOptions) (HandlerStartOperationResult[any], error) {
m, _ := reflect.TypeOf(r.h).MethodByName("Start")
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(input), reflect.ValueOf(options)})
if !values[1].IsNil() {
return nil, values[1].Interface().(error)
}
ret := values[0].Interface()
return ret.(HandlerStartOperationResult[any]), nil
}

// ExecuteOperation is the type safe version of [HTTPClient.ExecuteOperation].
// It accepts input of type I and returns output of type O, removing the need to consume the [LazyValue] returned by the
Expand Down
Loading
Loading