Skip to content

Commit ad69ec5

Browse files
committed
Fix CORS header duplication under proxy chains
1 parent dac56bc commit ad69ec5

2 files changed

Lines changed: 89 additions & 12 deletions

File tree

middleware/cors.go

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,6 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
193193
res := c.Response()
194194
origin := req.Header.Get(echo.HeaderOrigin)
195195

196-
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
197-
198196
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
199197
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
200198
// For simplicity we just consider method type and later `Origin` header.
@@ -217,8 +215,12 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
217215
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
218216
if origin == "" {
219217
if preflight { // req.Method=OPTIONS
218+
addVaryHeader(res.Header(), echo.HeaderOrigin)
220219
return c.NoContent(http.StatusNoContent)
221220
}
221+
res.Before(func() {
222+
addVaryHeader(res.Header(), echo.HeaderOrigin)
223+
})
222224
return next(c) // let non-browser calls through
223225
}
224226

@@ -239,30 +241,44 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
239241
// no CORS middleware should block non-preflight requests;
240242
// such requests should be let through. One reason is that not all requests that
241243
// carry an Origin header participate in the CORS protocol.
244+
res.Before(func() {
245+
addVaryHeader(res.Header(), echo.HeaderOrigin)
246+
})
242247
return next(c)
243248
}
244249

245250
// Origin existed and was allowed
246251

247-
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
248-
if config.AllowCredentials {
249-
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
250-
}
251-
252252
// Simple request will be let though
253253
if !preflight {
254-
if exposeHeaders != "" {
255-
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
256-
}
254+
res.Before(func() {
255+
addVaryHeader(res.Header(), echo.HeaderOrigin)
256+
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
257+
if config.AllowCredentials {
258+
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
259+
} else {
260+
res.Header().Del(echo.HeaderAccessControlAllowCredentials)
261+
}
262+
if exposeHeaders != "" {
263+
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
264+
}
265+
})
257266
return next(c)
258267
}
259268
// Below code is for Preflight (OPTIONS) request
260269
//
261270
// Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if
262271
// at the end of handler chain is actual OPTIONS route or 404/405 route which
263272
// response code will confuse browsers
264-
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
265-
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
273+
addVaryHeader(res.Header(), echo.HeaderOrigin)
274+
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
275+
if config.AllowCredentials {
276+
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
277+
} else {
278+
res.Header().Del(echo.HeaderAccessControlAllowCredentials)
279+
}
280+
addVaryHeader(res.Header(), echo.HeaderAccessControlRequestMethod)
281+
addVaryHeader(res.Header(), echo.HeaderAccessControlRequestHeaders)
266282

267283
if !hasCustomAllowMethods && routerAllowMethods != "" {
268284
res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
@@ -298,3 +314,18 @@ func (config CORSConfig) defaultAllowOriginFunc(c *echo.Context, origin string)
298314
}
299315
return "", false, nil
300316
}
317+
318+
func addVaryHeader(h http.Header, value string) {
319+
if h.Get(echo.HeaderVary) == "" {
320+
h.Set(echo.HeaderVary, value)
321+
return
322+
}
323+
for _, v := range h.Values(echo.HeaderVary) {
324+
for _, part := range strings.Split(v, ",") {
325+
if strings.EqualFold(strings.TrimSpace(part), value) {
326+
return
327+
}
328+
}
329+
}
330+
h.Add(echo.HeaderVary, value)
331+
}

middleware/cors_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,3 +626,49 @@ func Test_allowOriginFunc(t *testing.T) {
626626
}
627627
}
628628
}
629+
630+
func TestCORSProxyChainedHeaders(t *testing.T) {
631+
e := echo.New()
632+
633+
// CORS middleware on the proxy
634+
cors := CORSWithConfig(CORSConfig{
635+
AllowOrigins: []string{"http://example.com"},
636+
})
637+
638+
// Proxy handler simulating upstream call that also returns CORS headers
639+
proxyHandler := func(c *echo.Context) error {
640+
// Mock upstream copying headers to response
641+
// This simulates the behavior of httputil.ReverseProxy which copies headers from upstream
642+
c.Response().Header().Add(echo.HeaderAccessControlAllowOrigin, "http://example.com")
643+
c.Response().Header().Add(echo.HeaderVary, echo.HeaderOrigin)
644+
c.Response().WriteHeader(http.StatusOK)
645+
return nil
646+
}
647+
648+
h := cors(proxyHandler)
649+
650+
req := httptest.NewRequest(http.MethodGet, "/", nil)
651+
req.Header.Set(echo.HeaderOrigin, "http://example.com")
652+
rec := httptest.NewRecorder()
653+
c := e.NewContext(req, rec)
654+
655+
err := h(c)
656+
assert.NoError(t, err)
657+
658+
// Verify that Access-Control-Allow-Origin is not duplicated
659+
acaoHeaders := rec.Header()[echo.HeaderAccessControlAllowOrigin]
660+
assert.Len(t, acaoHeaders, 1, "Access-Control-Allow-Origin should not be duplicated")
661+
assert.Equal(t, "http://example.com", acaoHeaders[0])
662+
663+
// Verify that Vary: Origin is not duplicated
664+
varyHeaders := rec.Header()[echo.HeaderVary]
665+
originCount := 0
666+
for _, v := range varyHeaders {
667+
for _, part := range strings.Split(v, ",") {
668+
if strings.EqualFold(strings.TrimSpace(part), echo.HeaderOrigin) {
669+
originCount++
670+
}
671+
}
672+
}
673+
assert.Equal(t, 1, originCount, "Vary Origin should not be duplicated")
674+
}

0 commit comments

Comments
 (0)