Skip to content

Commit b0e1f50

Browse files
committed
feat(middleware): add match wrappers
1 parent 936a01a commit b0e1f50

File tree

3 files changed

+85
-7
lines changed

3 files changed

+85
-7
lines changed

middleware/match.go

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package middleware
2+
3+
import (
4+
"regexp"
5+
6+
"github.com/ogen-go/ogen/internal/xmaps"
7+
)
8+
9+
// OperationID calls the next middleware if request operation ID matches the given operationID.
10+
func OperationID(m Middleware, operationID ...string) Middleware {
11+
switch len(operationID) {
12+
case 0:
13+
return justCallNext
14+
case 1:
15+
val := operationID[0]
16+
return func(req Request, next Next) (Response, error) {
17+
if req.OperationID == val {
18+
return m(req, next)
19+
}
20+
return next(req)
21+
}
22+
default:
23+
set := xmaps.BuildSet(operationID...)
24+
return func(req Request, next Next) (Response, error) {
25+
if _, ok := set[req.OperationID]; ok {
26+
return m(req, next)
27+
}
28+
return next(req)
29+
}
30+
}
31+
}
32+
33+
// OperationName calls the next middleware if request operation name matches the given operationName.
34+
func OperationName(m Middleware, operationName ...string) Middleware {
35+
switch len(operationName) {
36+
case 0:
37+
return justCallNext
38+
case 1:
39+
val := operationName[0]
40+
return func(req Request, next Next) (Response, error) {
41+
if req.OperationName == val {
42+
return m(req, next)
43+
}
44+
return next(req)
45+
}
46+
default:
47+
set := xmaps.BuildSet(operationName...)
48+
return func(req Request, next Next) (Response, error) {
49+
if _, ok := set[req.OperationName]; ok {
50+
return m(req, next)
51+
}
52+
return next(req)
53+
}
54+
}
55+
}
56+
57+
// PathRegex calls the next middleware if request path matches the given regex.
58+
func PathRegex(re *regexp.Regexp, m Middleware) Middleware {
59+
if re == nil {
60+
return justCallNext
61+
}
62+
63+
return func(req Request, next Next) (Response, error) {
64+
if re.MatchString(req.Raw.URL.Path) {
65+
return m(req, next)
66+
}
67+
return next(req)
68+
}
69+
}
70+
71+
// BodyType calls the next middleware if request body type matches the given type.
72+
func BodyType[T any](m Middleware) Middleware {
73+
return func(req Request, next Next) (Response, error) {
74+
if _, ok := req.Body.(T); ok {
75+
return m(req, next)
76+
}
77+
return next(req)
78+
}
79+
}

middleware/middleware.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@ type (
3535
Middleware func(req Request, next Next) (Response, error)
3636
)
3737

38+
func justCallNext(req Request, next Next) (Response, error) {
39+
return next(req)
40+
}
41+
3842
// ChainMiddlewares chains middlewares into a single middleware, which will be executed in the order they are passed.
3943
func ChainMiddlewares(m ...Middleware) Middleware {
4044
if len(m) == 0 {
41-
return func(req Request, next Next) (Response, error) {
42-
return next(req)
43-
}
45+
return justCallNext
4446
}
4547
tail := ChainMiddlewares(m[1:]...)
4648
return func(req Request, next Next) (Response, error) {

middleware/middleware_test.go

+1-4
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,11 @@ func TestChainMiddlewares(t *testing.T) {
5858

5959
func BenchmarkChainMiddlewares(b *testing.B) {
6060
const N = 20
61-
noop := func(req Request, next Next) (Response, error) {
62-
return next(req)
63-
}
6461

6562
var (
6663
chain = ChainMiddlewares(func() (r []Middleware) {
6764
for i := 0; i < N; i++ {
68-
r = append(r, noop)
65+
r = append(r, justCallNext)
6966
}
7067
return r
7168
}()...)

0 commit comments

Comments
 (0)