Skip to content

Commit bfb8201

Browse files
committed
process the CORS options once
The CORS middleware resulting from handlers.CORS would unnecessarily re-process the CORS options every time it would be applied to a handler. With this change, the CORS options are processed only once, when the middleware is created.
1 parent 9c61bd8 commit bfb8201

File tree

1 file changed

+71
-71
lines changed

1 file changed

+71
-71
lines changed

cors.go

Lines changed: 71 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
type CORSOption func(*cors) error
1111

1212
type cors struct {
13-
h http.Handler
1413
allowedHeaders []string
1514
allowedMethods []string
1615
allowedOrigins []string
@@ -47,93 +46,95 @@ const (
4746
corsOriginMatchAll string = "*"
4847
)
4948

50-
func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
51-
origin := r.Header.Get(corsOriginHeader)
52-
if !ch.isOriginAllowed(origin) {
53-
if r.Method != corsOptionMethod || ch.ignoreOptions {
54-
ch.h.ServeHTTP(w, r)
55-
}
56-
57-
return
58-
}
59-
60-
if r.Method == corsOptionMethod {
61-
if ch.ignoreOptions {
62-
ch.h.ServeHTTP(w, r)
63-
return
64-
}
49+
func (ch *cors) wrap(h http.Handler) http.Handler {
50+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
51+
origin := r.Header.Get(corsOriginHeader)
52+
if !ch.isOriginAllowed(origin) {
53+
if r.Method != corsOptionMethod || ch.ignoreOptions {
54+
h.ServeHTTP(w, r)
55+
}
6556

66-
if _, ok := r.Header[corsRequestMethodHeader]; !ok {
67-
w.WriteHeader(http.StatusBadRequest)
6857
return
6958
}
7059

71-
method := r.Header.Get(corsRequestMethodHeader)
72-
if !ch.isMatch(method, ch.allowedMethods) {
73-
w.WriteHeader(http.StatusMethodNotAllowed)
74-
return
75-
}
60+
if r.Method == corsOptionMethod {
61+
if ch.ignoreOptions {
62+
h.ServeHTTP(w, r)
63+
return
64+
}
7665

77-
requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
78-
allowedHeaders := []string{}
79-
for _, v := range requestHeaders {
80-
canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
81-
if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
82-
continue
66+
if _, ok := r.Header[corsRequestMethodHeader]; !ok {
67+
w.WriteHeader(http.StatusBadRequest)
68+
return
8369
}
8470

85-
if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
86-
w.WriteHeader(http.StatusForbidden)
71+
method := r.Header.Get(corsRequestMethodHeader)
72+
if !ch.isMatch(method, ch.allowedMethods) {
73+
w.WriteHeader(http.StatusMethodNotAllowed)
8774
return
8875
}
8976

90-
allowedHeaders = append(allowedHeaders, canonicalHeader)
91-
}
77+
requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
78+
allowedHeaders := []string{}
79+
for _, v := range requestHeaders {
80+
canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
81+
if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
82+
continue
83+
}
9284

93-
if len(allowedHeaders) > 0 {
94-
w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
95-
}
85+
if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
86+
w.WriteHeader(http.StatusForbidden)
87+
return
88+
}
9689

97-
if ch.maxAge > 0 {
98-
w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
99-
}
90+
allowedHeaders = append(allowedHeaders, canonicalHeader)
91+
}
92+
93+
if len(allowedHeaders) > 0 {
94+
w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
95+
}
96+
97+
if ch.maxAge > 0 {
98+
w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
99+
}
100100

101-
if !ch.isMatch(method, defaultCorsMethods) {
102-
w.Header().Set(corsAllowMethodsHeader, method)
101+
if !ch.isMatch(method, defaultCorsMethods) {
102+
w.Header().Set(corsAllowMethodsHeader, method)
103+
}
104+
} else if len(ch.exposedHeaders) > 0 {
105+
w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
103106
}
104-
} else if len(ch.exposedHeaders) > 0 {
105-
w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
106-
}
107107

108-
if ch.allowCredentials {
109-
w.Header().Set(corsAllowCredentialsHeader, "true")
110-
}
108+
if ch.allowCredentials {
109+
w.Header().Set(corsAllowCredentialsHeader, "true")
110+
}
111111

112-
if len(ch.allowedOrigins) > 1 {
113-
w.Header().Set(corsVaryHeader, corsOriginHeader)
114-
}
112+
if len(ch.allowedOrigins) > 1 {
113+
w.Header().Set(corsVaryHeader, corsOriginHeader)
114+
}
115115

116-
returnOrigin := origin
117-
if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
118-
returnOrigin = "*"
119-
} else {
120-
for _, o := range ch.allowedOrigins {
121-
// A configuration of * is different than explicitly setting an allowed
122-
// origin. Returning arbitrary origin headers in an access control allow
123-
// origin header is unsafe and is not required by any use case.
124-
if o == corsOriginMatchAll {
125-
returnOrigin = "*"
126-
break
116+
returnOrigin := origin
117+
if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
118+
returnOrigin = "*"
119+
} else {
120+
for _, o := range ch.allowedOrigins {
121+
// A configuration of * is different than explicitly setting an allowed
122+
// origin. Returning arbitrary origin headers in an access control allow
123+
// origin header is unsafe and is not required by any use case.
124+
if o == corsOriginMatchAll {
125+
returnOrigin = "*"
126+
break
127+
}
127128
}
128129
}
129-
}
130-
w.Header().Set(corsAllowOriginHeader, returnOrigin)
130+
w.Header().Set(corsAllowOriginHeader, returnOrigin)
131131

132-
if r.Method == corsOptionMethod {
133-
w.WriteHeader(ch.optionStatusCode)
134-
return
135-
}
136-
ch.h.ServeHTTP(w, r)
132+
if r.Method == corsOptionMethod {
133+
w.WriteHeader(ch.optionStatusCode)
134+
return
135+
}
136+
h.ServeHTTP(w, r)
137+
})
137138
}
138139

139140
// CORS provides Cross-Origin Resource Sharing middleware.
@@ -155,10 +156,9 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
155156
// http.ListenAndServe(":8000", handlers.CORS()(r))
156157
// }
157158
func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
159+
ch := parseCORSOptions(opts...)
158160
return func(h http.Handler) http.Handler {
159-
ch := parseCORSOptions(opts...)
160-
ch.h = h
161-
return ch
161+
return ch.wrap(h)
162162
}
163163
}
164164

0 commit comments

Comments
 (0)