Skip to content

Commit fd5ddc1

Browse files
committed
fix: correct route Name/Action lookup for late-registered sub-routers
1 parent 59705d5 commit fd5ddc1

2 files changed

Lines changed: 147 additions & 8 deletions

File tree

pkg/teapot/router.go

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,12 @@ func (r *Router) propagateRoute(pathPrefix, namePrefix string, rt *core.Route) {
606606
r.nameIndex[newRoute.Name] = newRoute
607607
}
608608

609+
// Also track in directRoutes and dispatchers for findMatchingRoute
610+
dispatcherKey := newRoute.Method + ":" + newRoute.ChiPattern
611+
if len(newRoute.QueryMatchers) == 0 {
612+
r.directRoutes[dispatcherKey] = newRoute
613+
}
614+
609615
// Propagate further up if this router also has parents
610616
for _, p := range r.parents {
611617
p.router.propagateRoute(p.pathPrefix, p.namePrefix, newRoute)
@@ -700,36 +706,78 @@ func (r *Router) SubRouter(prefix string) *Router {
700706
// findMatchingRoute manually matches a request against registered routes
701707
// This is used as a fallback when Chi's RouteContext isn't available (e.g., in global middleware)
702708
func (r *Router) findMatchingRoute(method, path string) *core.Route {
709+
type candidate struct {
710+
route *core.Route
711+
pattern string
712+
}
713+
var matches []candidate
714+
703715
// Check all direct routes
704716
for key, route := range r.directRoutes {
705717
if strings.HasPrefix(key, method+":") {
706718
pattern := strings.TrimPrefix(key, method+":")
707719
if r.matchPattern(pattern, path) {
708-
return route
720+
matches = append(matches, candidate{route: route, pattern: pattern})
709721
}
710722
}
711723
}
712724

713-
// Check dispatcher routes (return fallback route for query-multiplexed)
725+
// Check dispatcher routes
714726
for key, disp := range r.dispatchers {
715727
if strings.HasPrefix(key, method+":") {
716728
pattern := strings.TrimPrefix(key, method+":")
717729
if r.matchPattern(pattern, path) {
718-
// Return fallback route (no query matchers)
730+
// For dispatchers, we want the fallback route for name/action resolution in middleware
731+
// when we don't know the query params yet.
732+
var fallback *core.Route
719733
for _, rt := range disp.Routes {
720734
if len(rt.QueryMatchers) == 0 {
721-
return rt
735+
fallback = rt
736+
break
722737
}
723738
}
724-
// If no fallback, return first route
725-
if len(disp.Routes) > 0 {
726-
return disp.Routes[0]
739+
if fallback == nil && len(disp.Routes) > 0 {
740+
fallback = disp.Routes[0]
741+
}
742+
if fallback != nil {
743+
matches = append(matches, candidate{route: fallback, pattern: pattern})
744+
}
745+
}
746+
}
747+
}
748+
749+
if len(matches) == 0 {
750+
return nil
751+
}
752+
753+
// If multiple matches, prioritize literal matches (no {param} or *)
754+
if len(matches) > 1 {
755+
bestIdx := 0
756+
bestScore := -1
757+
758+
for i, match := range matches {
759+
score := 0
760+
if !strings.ContainsAny(match.pattern, "{}*") {
761+
score = 100 // Exact literal match
762+
} else {
763+
// Count literal parts (non-parameters)
764+
parts := strings.Split(match.pattern, "/")
765+
for _, p := range parts {
766+
if p != "" && !strings.ContainsAny(p, "{}*") {
767+
score++
768+
}
727769
}
728770
}
771+
772+
if score > bestScore {
773+
bestScore = score
774+
bestIdx = i
775+
}
729776
}
777+
return matches[bestIdx].route
730778
}
731779

732-
return nil
780+
return matches[0].route
733781
}
734782

735783
// matchPattern checks if a Chi pattern matches a path

tests/subrouter_resolution_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package tests
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
10+
"github.com/mallardduck/teapot-router/pkg/teapot"
11+
)
12+
13+
func TestSubRouterResolutionBug(t *testing.T) {
14+
t.Run("subrouter with parameter shadows literal", func(t *testing.T) {
15+
r := teapot.New()
16+
17+
// Global middleware that checks route name/action
18+
r.Use(teapot.RouteContextMiddleware(r))
19+
20+
var resolvedName string
21+
var resolvedAction string
22+
r.Use(func(next http.Handler) http.Handler {
23+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
24+
next.ServeHTTP(w, req)
25+
resolvedName = teapot.GetRouteName(req)
26+
resolvedAction = teapot.GetAction(req)
27+
})
28+
})
29+
r.GET("/{greedy}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
30+
w.WriteHeader(http.StatusOK)
31+
}))
32+
33+
// Sub-router mounted at /api
34+
api := teapot.New()
35+
36+
// Route with parameter at first position
37+
api.GET("/{id}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38+
w.WriteHeader(http.StatusOK)
39+
})).Name("api.show").Action("ShowAction")
40+
41+
// Literal route that should match /api/users
42+
api.GET("/users", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
43+
w.WriteHeader(http.StatusOK)
44+
})).Name("api.users").Action("UsersAction")
45+
46+
r.Mount("/api", api)
47+
r.Finalize()
48+
49+
// Request to /api/users
50+
req := httptest.NewRequest("GET", "/api/users", nil)
51+
rec := httptest.NewRecorder()
52+
r.ServeHTTP(rec, req)
53+
54+
assert.Equal(t, http.StatusOK, rec.Code)
55+
assert.Equal(t, "api.users", resolvedName, "Should resolve to api.users, not api.show")
56+
assert.Equal(t, "UsersAction", resolvedAction)
57+
})
58+
59+
t.Run("late registration with parameter shadows literal", func(t *testing.T) {
60+
r := teapot.New()
61+
62+
// Global middleware
63+
r.Use(teapot.RouteContextMiddleware(r))
64+
var resolvedName string
65+
r.Use(func(next http.Handler) http.Handler {
66+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
67+
next.ServeHTTP(w, req)
68+
resolvedName = teapot.GetRouteName(req)
69+
})
70+
})
71+
72+
// Parameter route first
73+
r.GET("/{id}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
74+
w.WriteHeader(http.StatusOK)
75+
})).Name("show")
76+
77+
// Literal route later
78+
r.GET("/users", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
79+
w.WriteHeader(http.StatusOK)
80+
})).Name("users")
81+
82+
r.Finalize()
83+
84+
// Request to /users
85+
req := httptest.NewRequest("GET", "/users", nil)
86+
rec := httptest.NewRecorder()
87+
r.ServeHTTP(rec, req)
88+
89+
assert.Equal(t, "users", resolvedName)
90+
})
91+
}

0 commit comments

Comments
 (0)