Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2724,9 +2724,8 @@ func TestSubrouterMatching(t *testing.T) {
r.Methods("POST").Subrouter().Methods("GET")
},
[]request{
{"matches before", newRequest("POST", "/"), none},
{"matches merged methods", newRequest("GET", "/"), stdOnly},
{"no match other", newRequest("HEAD", "/"), none},
{"matches override", newRequest("GET", "/"), none},
},
},
{
Expand Down
32 changes: 29 additions & 3 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,31 @@ func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool {
// Methods adds a matcher for HTTP methods.
// It accepts a sequence of one or more methods to be matched, e.g.:
// "GET", "POST", "PUT".
//
// Repeated calls merge into the same matcher instead of requiring every method
// to match independently.
func (r *Route) Methods(methods ...string) *Route {
for k, v := range methods {
methods[k] = strings.ToUpper(v)
}
for i, m := range r.matchers {
if existing, ok := m.(methodMatcher); ok {
r.matchers[i] = mergeMethodMatcher(existing, methods...)
return r
}
}
return r.addMatcher(methodMatcher(methods))
}

func mergeMethodMatcher(existing methodMatcher, methods ...string) methodMatcher {
for _, method := range methods {
if !matchInArray(existing, method) {
existing = append(existing, method)
}
}
return existing
}

// Path -----------------------------------------------------------------------

// Path adds a matcher for the URL path.
Expand Down Expand Up @@ -768,12 +786,20 @@ func (r *Route) GetMethods() ([]string, error) {
if r.err != nil {
return nil, r.err
}
var methods []string
for _, m := range r.matchers {
if methods, ok := m.(methodMatcher); ok {
return []string(methods), nil
if mm, ok := m.(methodMatcher); ok {
for _, method := range mm {
if !matchInArray(methods, method) {
methods = append(methods, method)
}
}
}
}
return nil, errors.New("mux: route doesn't have methods")
if len(methods) == 0 {
return nil, errors.New("mux: route doesn't have methods")
}
return methods, nil
}

// GetHostTemplate returns the template used to build the
Expand Down
38 changes: 38 additions & 0 deletions route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,41 @@ func TestRouteMetadata(t *testing.T) {
router.ServeHTTP(rw, req)
})
}

func TestRouteMethodsRepeatedCalls(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

route := NewRouter().NewRoute()
route.Path("/test").Handler(handler).Methods("PUT").Methods("PATCH")
route.Methods("OPTIONS")

methods, err := route.GetMethods()
if err != nil {
t.Fatalf("GetMethods failed: %v", err)
}
if len(methods) != 3 {
t.Fatalf("expected 3 methods, got %v", methods)
}

for _, method := range []string{"PUT", "PATCH", "OPTIONS"} {
req := newRequest(method, "/test")
match := &RouteMatch{}
if !route.Match(req, match) {
t.Fatalf("%s should match route", method)
}
if match.MatchErr != nil {
t.Fatalf("%s match returned error: %v", method, match.MatchErr)
}
}

req := newRequest("GET", "/test")
match := &RouteMatch{}
if route.Match(req, match) {
t.Fatal("GET should not match route")
}
if match.MatchErr != ErrMethodMismatch {
t.Fatalf("expected ErrMethodMismatch, got %v", match.MatchErr)
}
}