Skip to content

Commit 25fdfd5

Browse files
committed
feat: scope middleware to groups in Use() and MiddlewareGroup() functions
Enhance middleware handling to support scoping within `Group()`, `NamedGroup()`, and `MiddlewareGroup()`, ensuring middleware applies only to routes within the defined group. Add tests to verify expected behavior.
1 parent c005aec commit 25fdfd5

2 files changed

Lines changed: 94 additions & 4 deletions

File tree

pkg/teapot/router.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type Router struct {
2525
optimizedHandlers *[]*optimizedHandler // for finalization optimization
2626
finalized bool
2727
debugLog bool // enable debug logging for auto-promotion
28+
isGroupContext bool // true when inside Group()/NamedGroup()/MiddlewareGroup() — Use() appends to r.middlewares instead of r.mux
2829

2930
// Homing support for late propagation
3031
parents []parentRouter // parent routers to notify of new routes
@@ -451,6 +452,7 @@ func (r *Router) MiddlewareGroup(fn func(r *Router), middlewares ...func(http.Ha
451452
middlewares: append(append([]func(http.Handler) http.Handler{}, r.middlewares...), middlewares...), // Parent + new
452453
optimizedHandlers: r.optimizedHandlers,
453454
finalized: r.finalized,
455+
isGroupContext: true, // Use() inside groups appends to r.middlewares, not r.mux
454456
}
455457

456458
fn(subRouter)
@@ -470,6 +472,7 @@ func (r *Router) NamedGroup(pattern, namePrefix string, fn func(r *Router)) {
470472
middlewares: append([]func(http.Handler) http.Handler{}, r.middlewares...), // Copy parent middlewares
471473
optimizedHandlers: r.optimizedHandlers,
472474
debugLog: r.debugLog,
475+
isGroupContext: true, // Use() inside groups appends to r.middlewares, not r.mux
473476
}
474477

475478
// Trim trailing dot if namePrefix is empty
@@ -521,11 +524,24 @@ func (r *Router) Route(pattern string, fn func(r *Router)) {
521524
})
522525
}
523526

524-
// Use adds global middleware to the router
527+
// Use adds middleware to the router.
528+
//
529+
// At the root level, middleware is registered globally with chi and applies to all routes.
530+
//
531+
// Inside a Group(), NamedGroup(), or MiddlewareGroup() block, middleware is scoped to
532+
// routes registered within that group. It is appended to the group's middleware chain
533+
// and copied to each route at registration time.
534+
//
535+
// Use() must be called before registering routes or nested groups for the middleware
536+
// to take effect on those routes (same convention as chi).
525537
func (r *Router) Use(middlewares ...func(http.Handler) http.Handler) {
526-
// Only add to chi.Mux for truly global middleware
527-
// Don't add to r.middlewares as that would duplicate with route-specific
528-
r.mux.Use(middlewares...)
538+
if r.isGroupContext {
539+
// Group context: store in r.middlewares so it's copied to routes at registration
540+
r.middlewares = append(r.middlewares, middlewares...)
541+
} else {
542+
// Root/Route context: delegate to chi for global middleware
543+
r.mux.Use(middlewares...)
544+
}
529545
}
530546

531547
// MountNamed is like Mount, but allows specifying a name prefix for the sub-router's routes.

pkg/teapot/router_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,3 +1060,77 @@ func TestMiddlewareGroupAfterFinalization(t *testing.T) {
10601060
asserts.True(routeNames["after1"])
10611061
asserts.True(routeNames["after2"])
10621062
}
1063+
1064+
// Test: r.Use() inside Group() scopes middleware to the group, not globally
1065+
func TestGroupUseIsScoped(t *testing.T) {
1066+
r := teapot.New()
1067+
1068+
var log []string
1069+
groupMiddleware := func(next http.Handler) http.Handler {
1070+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
1071+
log = append(log, "group:"+req.URL.Path)
1072+
next.ServeHTTP(w, req)
1073+
})
1074+
}
1075+
1076+
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1077+
w.WriteHeader(http.StatusOK)
1078+
})
1079+
1080+
r.Group("/api", func(r *teapot.Router) {
1081+
r.Use(groupMiddleware)
1082+
r.GET("/users", handler)
1083+
})
1084+
r.GET("/public", handler)
1085+
1086+
srv := httptest.NewServer(r)
1087+
defer srv.Close()
1088+
1089+
// Hit the public route — middleware must NOT run
1090+
log = nil
1091+
resp, _ := http.Get(srv.URL + "/public")
1092+
_ = resp.Body.Close()
1093+
assert.Empty(t, log, "group middleware must not fire for /public")
1094+
1095+
// Hit the group route — middleware MUST run
1096+
log = nil
1097+
resp, _ = http.Get(srv.URL + "/api/users")
1098+
_ = resp.Body.Close()
1099+
assert.Equal(t, []string{"group:/api/users"}, log, "group middleware must fire for /api/users")
1100+
}
1101+
1102+
// Test: r.Use() inside MiddlewareGroup() scopes middleware to that group's routes
1103+
func TestMiddlewareGroupUseIsScoped(t *testing.T) {
1104+
r := teapot.New()
1105+
1106+
var log []string
1107+
mw := func(next http.Handler) http.Handler {
1108+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
1109+
log = append(log, "mw:"+req.URL.Path)
1110+
next.ServeHTTP(w, req)
1111+
})
1112+
}
1113+
1114+
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1115+
w.WriteHeader(http.StatusOK)
1116+
})
1117+
1118+
r.MiddlewareGroup(func(r *teapot.Router) {
1119+
r.Use(mw)
1120+
r.GET("/protected", handler)
1121+
})
1122+
r.GET("/open", handler)
1123+
1124+
srv := httptest.NewServer(r)
1125+
defer srv.Close()
1126+
1127+
log = nil
1128+
resp, _ := http.Get(srv.URL + "/open")
1129+
_ = resp.Body.Close()
1130+
assert.Empty(t, log, "mw must not fire for /open")
1131+
1132+
log = nil
1133+
resp, _ = http.Get(srv.URL + "/protected")
1134+
_ = resp.Body.Close()
1135+
assert.Equal(t, []string{"mw:/protected"}, log, "mw must fire for /protected")
1136+
}

0 commit comments

Comments
 (0)