Skip to content

Commit d1f86a6

Browse files
refactor middleware interface
1 parent 73e123e commit d1f86a6

File tree

2 files changed

+17
-34
lines changed

2 files changed

+17
-34
lines changed

nexus/operation.go

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type OperationOptions struct {
3232
Header Header
3333
}
3434

35-
type MiddlewareFunc func(OperationOptions, Operation[any, any]) (Operation[any, any], error)
35+
type MiddlewareFunc func(OperationOptions, OperationInvoker[any, any]) (OperationInvoker[any, any], error)
3636

3737
// OperationReference provides a typed interface for invoking operations. Every [Operation] is also an
3838
// [OperationReference]. Callers may create references using [NewOperationReference] when the implementation is not
@@ -89,7 +89,10 @@ type RegisterableOperation interface {
8989
type Operation[I, O any] interface {
9090
RegisterableOperation
9191
OperationReference[I, O]
92+
OperationInvoker[I, O]
93+
}
9294

95+
type OperationInvoker[I, O any] interface {
9396
// Start handles requests for starting an operation. Return [HandlerStartOperationResultSync] to respond
9497
// successfully - inline, or [HandlerStartOperationResultAsync] to indicate that an asynchronous operation was
9598
// started. Return an [OperationError] to indicate that an operation completed as failed or
@@ -263,7 +266,7 @@ type registryHandler struct {
263266
middlewares []MiddlewareFunc
264267
}
265268

266-
func (r *registryHandler) getOperation(options OperationOptions) (Operation[any, any], error) {
269+
func (r *registryHandler) getOperation(options OperationOptions) (OperationInvoker[any, any], error) {
267270
s, ok := r.services[options.ServiceName]
268271
if !ok {
269272
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "service %q not found", options.ServiceName)
@@ -273,7 +276,7 @@ func (r *registryHandler) getOperation(options OperationOptions) (Operation[any,
273276
return nil, HandlerErrorf(HandlerErrorTypeNotFound, "operation %q not found", options.OperationName)
274277
}
275278

276-
var handler Operation[any, any]
279+
var handler OperationInvoker[any, any]
277280
handler = &rootOperationHandler{h: h}
278281
if h != nil && len(r.middlewares) > 0 {
279282
for i := len(r.middlewares) - 1; i >= 0; i-- {
@@ -334,29 +337,7 @@ func (r *rootOperationHandler) Start(ctx context.Context, input interface{}, opt
334337
return ret.(HandlerStartOperationResult[any]), nil
335338
}
336339

337-
func (r *rootOperationHandler) InputType() reflect.Type {
338-
m, _ := reflect.TypeOf(r.h).MethodByName("InputType")
339-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h)})
340-
ret := values[0].Interface()
341-
return ret.(reflect.Type)
342-
}
343-
344-
func (r *rootOperationHandler) OutputType() reflect.Type {
345-
m, _ := reflect.TypeOf(r.h).MethodByName("OutputType")
346-
values := m.Func.Call([]reflect.Value{reflect.ValueOf(r.h)})
347-
ret := values[0].Interface()
348-
return ret.(reflect.Type)
349-
}
350-
351-
func (r *rootOperationHandler) inferType(input, output any) {}
352-
353-
func (r *rootOperationHandler) Name() string {
354-
return r.h.Name()
355-
}
356-
357-
func (r *rootOperationHandler) mustEmbedUnimplementedOperation() {}
358-
359-
var _ Operation[any, any] = &rootOperationHandler{}
340+
var _ OperationInvoker[any, any] = &rootOperationHandler{}
360341

361342
// CancelOperation implements Handler.
362343
func (r *registryHandler) CancelOperation(ctx context.Context, service, operation string, operationID string, options CancelOperationOptions) error {

nexus/operation_test.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ func TestOperationInterceptor(t *testing.T) {
306306
}
307307

308308
func newAuthMiddleware(authKey string) MiddlewareFunc {
309-
return func(oo OperationOptions, uo Operation[any, any]) (Operation[any, any], error) {
309+
return func(oo OperationOptions, uo OperationInvoker[any, any]) (OperationInvoker[any, any], error) {
310310
if oo.Header.Get("authorization") != authKey {
311311
return nil, HandlerErrorf(HandlerErrorTypeUnauthorized, "unauthorized")
312312
}
@@ -315,34 +315,36 @@ func newAuthMiddleware(authKey string) MiddlewareFunc {
315315
}
316316

317317
type loggingOperation struct {
318-
Operation[any, any]
319-
output func(string)
318+
Operation OperationInvoker[any, any]
319+
name string
320+
output func(string)
320321
}
321322

322323
func (lo *loggingOperation) Start(ctx context.Context, input any, options StartOperationOptions) (HandlerStartOperationResult[any], error) {
323-
lo.output(fmt.Sprintf("starting operation %s", lo.Operation.Name()))
324+
lo.output(fmt.Sprintf("starting operation %s", lo.name))
324325
return lo.Operation.Start(ctx, input, options)
325326
}
326327

327328
func (lo *loggingOperation) GetResult(ctx context.Context, id string, options GetOperationResultOptions) (any, error) {
328-
lo.output(fmt.Sprintf("getting result for operation %s", lo.Operation.Name()))
329+
lo.output(fmt.Sprintf("getting result for operation %s", lo.name))
329330
return lo.Operation.GetResult(ctx, id, options)
330331
}
331332

332333
func (lo *loggingOperation) Cancel(ctx context.Context, id string, options CancelOperationOptions) error {
333-
lo.output(fmt.Sprintf("cancel operation %s", lo.Operation.Name()))
334+
lo.output(fmt.Sprintf("cancel operation %s", lo.name))
334335
return lo.Operation.Cancel(ctx, id, options)
335336
}
336337

337338
func (lo *loggingOperation) GetInfo(ctx context.Context, id string, options GetOperationInfoOptions) (*OperationInfo, error) {
338-
lo.output(fmt.Sprintf("getting info for operation %s", lo.Operation.Name()))
339+
lo.output(fmt.Sprintf("getting info for operation %s", lo.name))
339340
return lo.Operation.GetInfo(ctx, id, options)
340341
}
341342

342343
func newLoggingMiddleware(output func(string)) MiddlewareFunc {
343-
return func(oo OperationOptions, uo Operation[any, any]) (Operation[any, any], error) {
344+
return func(oo OperationOptions, uo OperationInvoker[any, any]) (OperationInvoker[any, any], error) {
344345
return &loggingOperation{
345346
uo,
347+
oo.OperationName,
346348
output,
347349
}, nil
348350
}

0 commit comments

Comments
 (0)