Skip to content

Commit 73e123e

Browse files
Add middleware
1 parent 00ebe23 commit 73e123e

File tree

2 files changed

+229
-47
lines changed

2 files changed

+229
-47
lines changed

nexus/operation.go

Lines changed: 149 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,25 @@ import (
1515
// )}
1616
type NoValue *struct{}
1717

18+
// OperationOptions contains the general options for an operation, across different handlers.
19+
type OperationOptions struct {
20+
// ServiceName is the name of the service that contains the operation.
21+
ServiceName string
22+
// OperationName is the name of the operation.
23+
OperationName string
24+
// Header contains the request header fields either received by the server or to be sent by the client.
25+
//
26+
// Header will always be non empty in server methods and can be optionally set in the client API.
27+
//
28+
// Header values set here will overwrite any SDK-provided values for the same key.
29+
//
30+
// Header keys with the "content-" prefix are reserved for [Serializer] headers and should not be set in the
31+
// client API; they are not available to server [Handler] and [Operation] implementations.
32+
Header Header
33+
}
34+
35+
type MiddlewareFunc func(OperationOptions, Operation[any, any]) (Operation[any, any], error)
36+
1837
// OperationReference provides a typed interface for invoking operations. Every [Operation] is also an
1938
// [OperationReference]. Callers may create references using [NewOperationReference] when the implementation is not
2039
// available.
@@ -184,11 +203,15 @@ func (s *Service) Operation(name string) RegisterableOperation {
184203

185204
// A ServiceRegistry registers services and constructs a [Handler] that dispatches operations requests to those services.
186205
type ServiceRegistry struct {
187-
services map[string]*Service
206+
services map[string]*Service
207+
middleware []MiddlewareFunc
188208
}
189209

190210
func NewServiceRegistry() *ServiceRegistry {
191-
return &ServiceRegistry{services: make(map[string]*Service)}
211+
return &ServiceRegistry{
212+
services: make(map[string]*Service),
213+
middleware: make([]MiddlewareFunc, 0),
214+
}
192215
}
193216

194217
// Register one or more service.
@@ -214,6 +237,11 @@ func (r *ServiceRegistry) Register(services ...*Service) error {
214237
return nil
215238
}
216239

240+
// Use registers middleware to be applied to all operations. Middleware is called for each operation invocation.
241+
func (s *ServiceRegistry) Use(middleWare ...MiddlewareFunc) {
242+
s.middleware = append(s.middleware, middleWare...)
243+
}
244+
217245
// NewHandler creates a [Handler] that dispatches requests to registered operations based on their name.
218246
func (r *ServiceRegistry) NewHandler() (Handler, error) {
219247
if len(r.services) == 0 {
@@ -225,76 +253,148 @@ func (r *ServiceRegistry) NewHandler() (Handler, error) {
225253
}
226254
}
227255

228-
return &registryHandler{services: r.services}, nil
256+
return &registryHandler{services: r.services, middlewares: r.middleware}, nil
229257
}
230258

231259
type registryHandler struct {
232260
UnimplementedHandler
233261

234-
services map[string]*Service
262+
services map[string]*Service
263+
middlewares []MiddlewareFunc
235264
}
236265

237-
// CancelOperation implements Handler.
238-
func (r *registryHandler) CancelOperation(ctx context.Context, service, operation string, token string, options CancelOperationOptions) error {
239-
s, ok := r.services[service]
266+
func (r *registryHandler) getOperation(options OperationOptions) (Operation[any, any], error) {
267+
s, ok := r.services[options.ServiceName]
240268
if !ok {
241-
return HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
269+
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", options.ServiceName)
242270
}
243-
h, ok := s.operations[operation]
271+
h, ok := s.operations[options.OperationName]
244272
if !ok {
245-
return HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
273+
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", options.OperationName)
246274
}
247275

276+
var handler Operation[any, any]
277+
handler = &rootOperationHandler{h: h}
278+
if h != nil && len(r.middlewares) > 0 {
279+
for i := len(r.middlewares) - 1; i >= 0; i-- {
280+
var err error
281+
handler, err = r.middlewares[i](options, handler)
282+
if err != nil {
283+
return nil, err
284+
}
285+
}
286+
}
287+
return handler, nil
288+
}
289+
290+
type rootOperationHandler struct {
291+
h RegisterableOperation
292+
}
293+
294+
func (r *rootOperationHandler) Cancel(ctx context.Context, token string, options CancelOperationOptions) error {
248295
// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
249296
// worth it since we need reflection for the generic methods.
250-
m, _ := reflect.TypeOf(h).MethodByName("Cancel")
251-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
297+
m, _ := reflect.TypeOf(r.h).MethodByName("Cancel")
298+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
252299
if values[0].IsNil() {
253300
return nil
254301
}
255302
return values[0].Interface().(error)
256303
}
257304

258-
// GetOperationInfo implements Handler.
259-
func (r *registryHandler) GetOperationInfo(ctx context.Context, service, operation string, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
260-
s, ok := r.services[service]
261-
if !ok {
262-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
263-
}
264-
h, ok := s.operations[operation]
265-
if !ok {
266-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
267-
}
268-
305+
func (r *rootOperationHandler) GetInfo(ctx context.Context, token string, options GetOperationInfoOptions) (*OperationInfo, error) {
269306
// NOTE: We could avoid reflection here if we put the Cancel method on RegisterableOperation but it doesn't seem
270307
// worth it since we need reflection for the generic methods.
271-
m, _ := reflect.TypeOf(h).MethodByName("GetInfo")
272-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
308+
m, _ := reflect.TypeOf(r.h).MethodByName("GetInfo")
309+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
273310
if !values[1].IsNil() {
274311
return nil, values[1].Interface().(error)
275312
}
276313
ret := values[0].Interface()
277314
return ret.(*OperationInfo), nil
278315
}
279316

280-
// GetOperationResult implements Handler.
281-
func (r *registryHandler) GetOperationResult(ctx context.Context, service, operation string, token string, options GetOperationResultOptions) (any, error) {
282-
s, ok := r.services[service]
283-
if !ok {
284-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
285-
}
286-
h, ok := s.operations[operation]
287-
if !ok {
288-
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
317+
func (r *rootOperationHandler) GetResult(ctx context.Context, token string, options GetOperationResultOptions) (interface{}, error) {
318+
m, _ := reflect.TypeOf(r.h).MethodByName("GetResult")
319+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
320+
if !values[1].IsNil() {
321+
return nil, values[1].Interface().(error)
289322
}
323+
ret := values[0].Interface()
324+
return ret, nil
325+
}
290326

291-
m, _ := reflect.TypeOf(h).MethodByName("GetResult")
292-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), reflect.ValueOf(token), reflect.ValueOf(options)})
327+
func (r *rootOperationHandler) Start(ctx context.Context, input interface{}, options StartOperationOptions) (HandlerStartOperationResult[interface{}], error) {
328+
m, _ := reflect.TypeOf(r.h).MethodByName("Start")
329+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h), reflect.ValueOf(ctx), reflect.ValueOf(input), reflect.ValueOf(options)})
293330
if !values[1].IsNil() {
294331
return nil, values[1].Interface().(error)
295332
}
296333
ret := values[0].Interface()
297-
return ret, nil
334+
return ret.(HandlerStartOperationResult[any]), nil
335+
}
336+
337+
func (r *rootOperationHandler) InputType() reflect.Type {
338+
m, _ := reflect.TypeOf(r.h).MethodByName("InputType")
339+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h)})
340+
ret := values[0].Interface()
341+
return ret.(reflect.Type)
342+
}
343+
344+
func (r *rootOperationHandler) OutputType() reflect.Type {
345+
m, _ := reflect.TypeOf(r.h).MethodByName("OutputType")
346+
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h)})
347+
ret := values[0].Interface()
348+
return ret.(reflect.Type)
349+
}
350+
351+
func (r *rootOperationHandler) inferType(input, output any) {}
352+
353+
func (r *rootOperationHandler) Name() string {
354+
return r.h.Name()
355+
}
356+
357+
func (r *rootOperationHandler) mustEmbedUnimplementedOperation() {}
358+
359+
var _ Operation[any, any] = &rootOperationHandler{}
360+
361+
// CancelOperation implements Handler.
362+
func (r *registryHandler) CancelOperation(ctx context.Context, service, operation string, operationID string, options CancelOperationOptions) error {
363+
h, err := r.getOperation(OperationOptions{
364+
ServiceName: service,
365+
OperationName: operation,
366+
Header: options.Header,
367+
})
368+
if err != nil {
369+
return err
370+
}
371+
return h.Cancel(ctx, operationID, options)
372+
}
373+
374+
// GetOperationInfo implements Handler.
375+
func (r *registryHandler) GetOperationInfo(ctx context.Context, service, operation string, operationID string, options GetOperationInfoOptions) (*OperationInfo, error) {
376+
h, err := r.getOperation(OperationOptions{
377+
ServiceName: service,
378+
OperationName: operation,
379+
Header: options.Header,
380+
})
381+
if err != nil {
382+
return nil, err
383+
}
384+
return h.GetInfo(ctx, operationID, options)
385+
}
386+
387+
// GetOperationResult implements Handler.
388+
func (r *registryHandler) GetOperationResult(ctx context.Context, service, operation string, operationID string, options GetOperationResultOptions) (any, error) {
389+
h, err := r.getOperation(OperationOptions{
390+
ServiceName: service,
391+
OperationName: operation,
392+
Header: options.Header,
393+
})
394+
if err != nil {
395+
return nil, err
396+
}
397+
return h.GetResult(ctx, operationID, options)
298398
}
299399

300400
// StartOperation implements Handler.
@@ -303,26 +403,28 @@ func (r *registryHandler) StartOperation(ctx context.Context, service, operation
303403
if !ok {
304404
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", service)
305405
}
306-
h, ok := s.operations[operation]
406+
ro, ok := s.operations[operation]
307407
if !ok {
308408
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", operation)
309409
}
310410

311-
m, _ := reflect.TypeOf(h).MethodByName("Start")
411+
h, err := r.getOperation(OperationOptions{
412+
ServiceName: service,
413+
OperationName: operation,
414+
Header: options.Header,
415+
})
416+
if err != nil {
417+
return nil, err
418+
}
419+
420+
m, _ := reflect.TypeOf(ro).MethodByName("Start")
312421
inputType := m.Type.In(2)
313422
iptr := reflect.New(inputType).Interface()
314423
if err := input.Consume(iptr); err != nil {
315424
// TODO: log the error? Do we need to accept a logger for this single line?
316425
return nil, HandlerErrorf(HandlerErrorTypeBadRequest, "invalid input")
317426
}
318-
i := reflect.ValueOf(iptr).Elem()
319-
320-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(h), reflect.ValueOf(ctx), i, reflect.ValueOf(options)})
321-
if !values[1].IsNil() {
322-
return nil, values[1].Interface().(error)
323-
}
324-
ret := values[0].Interface()
325-
return ret.(HandlerStartOperationResult[any]), nil
427+
return h.Start(ctx, reflect.ValueOf(iptr).Elem().Interface(), options)
326428
}
327429

328430
var _ Handler = &registryHandler{}

nexus/operation_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,83 @@ func TestInputOutputType(t *testing.T) {
267267
require.True(t, reflect.TypeOf(3).AssignableTo(numberValidatorOperation.OutputType()))
268268
require.False(t, reflect.TypeOf("s").AssignableTo(numberValidatorOperation.OutputType()))
269269
}
270+
271+
func TestOperationInterceptor(t *testing.T) {
272+
registry := NewServiceRegistry()
273+
svc := NewService(testService)
274+
require.NoError(t, svc.Register(
275+
asyncNumberValidatorOperationInstance,
276+
))
277+
278+
var logger []string
279+
// Register the logging middleware after the auth middleware to ensure the auth middleware is called first.
280+
// any middleware that returns an error will prevent the operation from being called.
281+
registry.Use(newAuthMiddleware("auth-key"), newLoggingMiddleware(func(log string) {
282+
logger = append(logger, log)
283+
}))
284+
require.NoError(t, registry.Register(svc))
285+
286+
handler, err := registry.NewHandler()
287+
require.NoError(t, err)
288+
289+
ctx, client, teardown := setup(t, handler)
290+
defer teardown()
291+
292+
_, err = StartOperation(ctx, client, asyncNumberValidatorOperationInstance, 3, StartOperationOptions{})
293+
require.ErrorContains(t, err, "unauthorized")
294+
295+
authHeader := map[string]string{"authorization": "auth-key"}
296+
result, err := StartOperation(ctx, client, asyncNumberValidatorOperationInstance, 3, StartOperationOptions{
297+
Header: authHeader,
298+
})
299+
require.NoError(t, err)
300+
require.ErrorContains(t, result.Pending.Cancel(ctx, CancelOperationOptions{}), "unauthorized")
301+
require.NoError(t, result.Pending.Cancel(ctx, CancelOperationOptions{Header: authHeader}))
302+
// Assert the logger only contains calls from successful operations.
303+
require.Len(t, logger, 2)
304+
require.Contains(t, logger[0], "starting operation async-number-validator")
305+
require.Contains(t, logger[1], "cancel operation async-number-validator")
306+
}
307+
308+
func newAuthMiddleware(authKey string) MiddlewareFunc {
309+
return func(oo OperationOptions, uo Operation[any, any]) (Operation[any, any], error) {
310+
if oo.Header.Get("authorization") != authKey {
311+
return nil, HandlerErrorf(HandlerErrorTypeUnauthorized, "unauthorized")
312+
}
313+
return uo, nil
314+
}
315+
}
316+
317+
type loggingOperation struct {
318+
Operation[any, any]
319+
output func(string)
320+
}
321+
322+
func (lo *loggingOperation) Start(ctx context.Context, input any, options StartOperationOptions) (HandlerStartOperationResult[any], error) {
323+
lo.output(fmt.Sprintf("starting operation %s", lo.Operation.Name()))
324+
return lo.Operation.Start(ctx, input, options)
325+
}
326+
327+
func (lo *loggingOperation) GetResult(ctx context.Context, id string, options GetOperationResultOptions) (any, error) {
328+
lo.output(fmt.Sprintf("getting result for operation %s", lo.Operation.Name()))
329+
return lo.Operation.GetResult(ctx, id, options)
330+
}
331+
332+
func (lo *loggingOperation) Cancel(ctx context.Context, id string, options CancelOperationOptions) error {
333+
lo.output(fmt.Sprintf("cancel operation %s", lo.Operation.Name()))
334+
return lo.Operation.Cancel(ctx, id, options)
335+
}
336+
337+
func (lo *loggingOperation) GetInfo(ctx context.Context, id string, options GetOperationInfoOptions) (*OperationInfo, error) {
338+
lo.output(fmt.Sprintf("getting info for operation %s", lo.Operation.Name()))
339+
return lo.Operation.GetInfo(ctx, id, options)
340+
}
341+
342+
func newLoggingMiddleware(output func(string)) MiddlewareFunc {
343+
return func(oo OperationOptions, uo Operation[any, any]) (Operation[any, any], error) {
344+
return &loggingOperation{
345+
uo,
346+
output,
347+
}, nil
348+
}
349+
}

0 commit comments

Comments
 (0)