Skip to content

Commit 5e63e2c

Browse files
Run middleware for OPTIONS fallbacks
1 parent 98d99d5 commit 5e63e2c

2 files changed

Lines changed: 85 additions & 0 deletions

File tree

middleware/cors_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,35 @@ func TestCorsHeaders(t *testing.T) {
581581
}
582582
}
583583

584+
func TestCORSWithConfig_GroupPreflightWithoutOptionsRoute(t *testing.T) {
585+
e := echo.New()
586+
g := e.Group("/myroute", CORSWithConfig(CORSConfig{
587+
AllowOrigins: []string{"https://example.com"},
588+
AllowHeaders: []string{
589+
echo.HeaderOrigin,
590+
echo.HeaderContentType,
591+
echo.HeaderAccept,
592+
echo.HeaderAuthorization,
593+
},
594+
}))
595+
g.GET("", func(c *echo.Context) error {
596+
return c.String(http.StatusOK, "OK")
597+
})
598+
599+
req := httptest.NewRequest(http.MethodOptions, "/myroute", nil)
600+
req.Header.Set(echo.HeaderOrigin, "https://example.com")
601+
req.Header.Set(echo.HeaderAccessControlRequestMethod, http.MethodGet)
602+
rec := httptest.NewRecorder()
603+
604+
e.ServeHTTP(rec, req)
605+
606+
assert.Equal(t, http.StatusNoContent, rec.Code)
607+
assert.Equal(t, "https://example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
608+
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(echo.HeaderAccessControlAllowMethods))
609+
assert.Equal(t, "Origin,Content-Type,Accept,Authorization", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
610+
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(echo.HeaderAllow))
611+
}
612+
584613
func Test_allowOriginFunc(t *testing.T) {
585614
returnTrue := func(c *echo.Context, origin string) (string, bool, error) {
586615
return origin, true, nil

router.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ const (
142142
type routeMethod struct {
143143
*RouteInfo
144144
handler HandlerFunc
145+
middlewares []MiddlewareFunc
145146
orgRouteInfo RouteInfo
146147
}
147148

@@ -298,6 +299,54 @@ func (m *routeMethods) updateAllowHeader() {
298299
m.allowHeader = buf.String()
299300
}
300301

302+
func (m *routeMethods) optionsFallbackHandler(requestedMethod string) *routeMethod {
303+
if requestedMethod != "" {
304+
if h := m.find(requestedMethod, true); h != nil {
305+
return h
306+
}
307+
}
308+
if m.connect != nil {
309+
return m.connect
310+
}
311+
if m.delete != nil {
312+
return m.delete
313+
}
314+
if m.get != nil {
315+
return m.get
316+
}
317+
if m.head != nil {
318+
return m.head
319+
}
320+
if m.options != nil {
321+
return m.options
322+
}
323+
if m.patch != nil {
324+
return m.patch
325+
}
326+
if m.post != nil {
327+
return m.post
328+
}
329+
if m.propfind != nil {
330+
return m.propfind
331+
}
332+
if m.put != nil {
333+
return m.put
334+
}
335+
if m.trace != nil {
336+
return m.trace
337+
}
338+
if m.report != nil {
339+
return m.report
340+
}
341+
if m.any != nil {
342+
return m.any
343+
}
344+
for _, r := range m.anyOther {
345+
return r
346+
}
347+
return nil
348+
}
349+
301350
func (m *routeMethods) isHandler() bool {
302351
return m.get != nil ||
303352
m.post != nil ||
@@ -488,6 +537,7 @@ func (r *DefaultRouter) Add(route Route) (RouteInfo, error) {
488537
rm := routeMethod{
489538
RouteInfo: &RouteInfo{Method: method, Path: originalPath, Parameters: paramNames, Name: route.Name},
490539
handler: h,
540+
middlewares: append([]MiddlewareFunc(nil), route.Middlewares...),
491541
orgRouteInfo: ri,
492542
}
493543
r.insert(paramKind, path[:i], method, rm)
@@ -503,6 +553,7 @@ func (r *DefaultRouter) Add(route Route) (RouteInfo, error) {
503553
rm := routeMethod{
504554
RouteInfo: &RouteInfo{Method: method, Path: originalPath, Parameters: paramNames, Name: route.Name},
505555
handler: h,
556+
middlewares: append([]MiddlewareFunc(nil), route.Middlewares...),
506557
orgRouteInfo: ri,
507558
}
508559
r.insert(anyKind, path[:i+1], method, rm)
@@ -516,6 +567,7 @@ func (r *DefaultRouter) Add(route Route) (RouteInfo, error) {
516567
rm := routeMethod{
517568
RouteInfo: &RouteInfo{Method: method, Path: originalPath, Parameters: paramNames, Name: route.Name},
518569
handler: h,
570+
middlewares: append([]MiddlewareFunc(nil), route.Middlewares...),
519571
orgRouteInfo: ri,
520572
}
521573
r.insert(staticKind, path, method, rm)
@@ -1011,6 +1063,10 @@ func (r *DefaultRouter) Route(c *Context) HandlerFunc {
10111063
rHandler = r.methodNotAllowedHandler
10121064
if req.Method == http.MethodOptions {
10131065
rHandler = r.optionsMethodHandler
1066+
requestedMethod := req.Header.Get(HeaderAccessControlRequestMethod)
1067+
if fallbackMethod := currentNode.methods.optionsFallbackHandler(requestedMethod); fallbackMethod != nil {
1068+
rHandler = applyMiddleware(rHandler, fallbackMethod.middlewares...)
1069+
}
10141070
}
10151071
}
10161072
}

0 commit comments

Comments
 (0)