diff --git a/cors.go b/cors.go index 8af9c09..65a0e7d 100644 --- a/cors.go +++ b/cors.go @@ -10,7 +10,6 @@ import ( type CORSOption func(*cors) error type cors struct { - h http.Handler allowedHeaders []string allowedMethods []string allowedOrigins []string @@ -47,93 +46,95 @@ const ( corsOriginMatchAll string = "*" ) -func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get(corsOriginHeader) - if !ch.isOriginAllowed(origin) { - if r.Method != corsOptionMethod || ch.ignoreOptions { - ch.h.ServeHTTP(w, r) - } - - return - } - - if r.Method == corsOptionMethod { - if ch.ignoreOptions { - ch.h.ServeHTTP(w, r) - return - } +func (ch *cors) wrap(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get(corsOriginHeader) + if !ch.isOriginAllowed(origin) { + if r.Method != corsOptionMethod || ch.ignoreOptions { + h.ServeHTTP(w, r) + } - if _, ok := r.Header[corsRequestMethodHeader]; !ok { - w.WriteHeader(http.StatusBadRequest) return } - method := r.Header.Get(corsRequestMethodHeader) - if !ch.isMatch(method, ch.allowedMethods) { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } + if r.Method == corsOptionMethod { + if ch.ignoreOptions { + h.ServeHTTP(w, r) + return + } - requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") - allowedHeaders := []string{} - for _, v := range requestHeaders { - canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) - if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { - continue + if _, ok := r.Header[corsRequestMethodHeader]; !ok { + w.WriteHeader(http.StatusBadRequest) + return } - if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { - w.WriteHeader(http.StatusForbidden) + method := r.Header.Get(corsRequestMethodHeader) + if !ch.isMatch(method, ch.allowedMethods) { + w.WriteHeader(http.StatusMethodNotAllowed) return } - allowedHeaders = append(allowedHeaders, canonicalHeader) - } + requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") + allowedHeaders := []string{} + for _, v := range requestHeaders { + canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) + if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { + continue + } - if len(allowedHeaders) > 0 { - w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) - } + if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { + w.WriteHeader(http.StatusForbidden) + return + } - if ch.maxAge > 0 { - w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) - } + allowedHeaders = append(allowedHeaders, canonicalHeader) + } + + if len(allowedHeaders) > 0 { + w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) + } + + if ch.maxAge > 0 { + w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) + } - if !ch.isMatch(method, defaultCorsMethods) { - w.Header().Set(corsAllowMethodsHeader, method) + if !ch.isMatch(method, defaultCorsMethods) { + w.Header().Set(corsAllowMethodsHeader, method) + } + } else if len(ch.exposedHeaders) > 0 { + w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) } - } else if len(ch.exposedHeaders) > 0 { - w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) - } - if ch.allowCredentials { - w.Header().Set(corsAllowCredentialsHeader, "true") - } + if ch.allowCredentials { + w.Header().Set(corsAllowCredentialsHeader, "true") + } - if len(ch.allowedOrigins) > 1 { - w.Header().Set(corsVaryHeader, corsOriginHeader) - } + if len(ch.allowedOrigins) > 1 { + w.Header().Set(corsVaryHeader, corsOriginHeader) + } - returnOrigin := origin - if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 { - returnOrigin = "*" - } else { - for _, o := range ch.allowedOrigins { - // A configuration of * is different than explicitly setting an allowed - // origin. Returning arbitrary origin headers in an access control allow - // origin header is unsafe and is not required by any use case. - if o == corsOriginMatchAll { - returnOrigin = "*" - break + returnOrigin := origin + if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 { + returnOrigin = "*" + } else { + for _, o := range ch.allowedOrigins { + // A configuration of * is different than explicitly setting an allowed + // origin. Returning arbitrary origin headers in an access control allow + // origin header is unsafe and is not required by any use case. + if o == corsOriginMatchAll { + returnOrigin = "*" + break + } } } - } - w.Header().Set(corsAllowOriginHeader, returnOrigin) + w.Header().Set(corsAllowOriginHeader, returnOrigin) - if r.Method == corsOptionMethod { - w.WriteHeader(ch.optionStatusCode) - return - } - ch.h.ServeHTTP(w, r) + if r.Method == corsOptionMethod { + w.WriteHeader(ch.optionStatusCode) + return + } + h.ServeHTTP(w, r) + }) } // CORS provides Cross-Origin Resource Sharing middleware. @@ -155,11 +156,7 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { // http.ListenAndServe(":8000", handlers.CORS()(r)) // } func CORS(opts ...CORSOption) func(http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - ch := parseCORSOptions(opts...) - ch.h = h - return ch - } + return parseCORSOptions(opts...).wrap } func parseCORSOptions(opts ...CORSOption) *cors {