@@ -10,7 +10,6 @@ import (
10
10
type CORSOption func (* cors ) error
11
11
12
12
type cors struct {
13
- h http.Handler
14
13
allowedHeaders []string
15
14
allowedMethods []string
16
15
allowedOrigins []string
@@ -47,93 +46,95 @@ const (
47
46
corsOriginMatchAll string = "*"
48
47
)
49
48
50
- func (ch * cors ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
51
- origin := r .Header .Get (corsOriginHeader )
52
- if ! ch .isOriginAllowed (origin ) {
53
- if r .Method != corsOptionMethod || ch .ignoreOptions {
54
- ch .h .ServeHTTP (w , r )
55
- }
56
-
57
- return
58
- }
59
-
60
- if r .Method == corsOptionMethod {
61
- if ch .ignoreOptions {
62
- ch .h .ServeHTTP (w , r )
63
- return
64
- }
49
+ func (ch * cors ) wrap (h http.Handler ) http.Handler {
50
+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
51
+ origin := r .Header .Get (corsOriginHeader )
52
+ if ! ch .isOriginAllowed (origin ) {
53
+ if r .Method != corsOptionMethod || ch .ignoreOptions {
54
+ h .ServeHTTP (w , r )
55
+ }
65
56
66
- if _ , ok := r .Header [corsRequestMethodHeader ]; ! ok {
67
- w .WriteHeader (http .StatusBadRequest )
68
57
return
69
58
}
70
59
71
- method := r . Header . Get ( corsRequestMethodHeader )
72
- if ! ch .isMatch ( method , ch . allowedMethods ) {
73
- w . WriteHeader ( http . StatusMethodNotAllowed )
74
- return
75
- }
60
+ if r . Method == corsOptionMethod {
61
+ if ch .ignoreOptions {
62
+ h . ServeHTTP ( w , r )
63
+ return
64
+ }
76
65
77
- requestHeaders := strings .Split (r .Header .Get (corsRequestHeadersHeader ), "," )
78
- allowedHeaders := []string {}
79
- for _ , v := range requestHeaders {
80
- canonicalHeader := http .CanonicalHeaderKey (strings .TrimSpace (v ))
81
- if canonicalHeader == "" || ch .isMatch (canonicalHeader , defaultCorsHeaders ) {
82
- continue
66
+ if _ , ok := r .Header [corsRequestMethodHeader ]; ! ok {
67
+ w .WriteHeader (http .StatusBadRequest )
68
+ return
83
69
}
84
70
85
- if ! ch .isMatch (canonicalHeader , ch .allowedHeaders ) {
86
- w .WriteHeader (http .StatusForbidden )
71
+ method := r .Header .Get (corsRequestMethodHeader )
72
+ if ! ch .isMatch (method , ch .allowedMethods ) {
73
+ w .WriteHeader (http .StatusMethodNotAllowed )
87
74
return
88
75
}
89
76
90
- allowedHeaders = append (allowedHeaders , canonicalHeader )
91
- }
77
+ requestHeaders := strings .Split (r .Header .Get (corsRequestHeadersHeader ), "," )
78
+ allowedHeaders := []string {}
79
+ for _ , v := range requestHeaders {
80
+ canonicalHeader := http .CanonicalHeaderKey (strings .TrimSpace (v ))
81
+ if canonicalHeader == "" || ch .isMatch (canonicalHeader , defaultCorsHeaders ) {
82
+ continue
83
+ }
92
84
93
- if len (allowedHeaders ) > 0 {
94
- w .Header ().Set (corsAllowHeadersHeader , strings .Join (allowedHeaders , "," ))
95
- }
85
+ if ! ch .isMatch (canonicalHeader , ch .allowedHeaders ) {
86
+ w .WriteHeader (http .StatusForbidden )
87
+ return
88
+ }
96
89
97
- if ch .maxAge > 0 {
98
- w .Header ().Set (corsMaxAgeHeader , strconv .Itoa (ch .maxAge ))
99
- }
90
+ allowedHeaders = append (allowedHeaders , canonicalHeader )
91
+ }
92
+
93
+ if len (allowedHeaders ) > 0 {
94
+ w .Header ().Set (corsAllowHeadersHeader , strings .Join (allowedHeaders , "," ))
95
+ }
96
+
97
+ if ch .maxAge > 0 {
98
+ w .Header ().Set (corsMaxAgeHeader , strconv .Itoa (ch .maxAge ))
99
+ }
100
100
101
- if ! ch .isMatch (method , defaultCorsMethods ) {
102
- w .Header ().Set (corsAllowMethodsHeader , method )
101
+ if ! ch .isMatch (method , defaultCorsMethods ) {
102
+ w .Header ().Set (corsAllowMethodsHeader , method )
103
+ }
104
+ } else if len (ch .exposedHeaders ) > 0 {
105
+ w .Header ().Set (corsExposeHeadersHeader , strings .Join (ch .exposedHeaders , "," ))
103
106
}
104
- } else if len (ch .exposedHeaders ) > 0 {
105
- w .Header ().Set (corsExposeHeadersHeader , strings .Join (ch .exposedHeaders , "," ))
106
- }
107
107
108
- if ch .allowCredentials {
109
- w .Header ().Set (corsAllowCredentialsHeader , "true" )
110
- }
108
+ if ch .allowCredentials {
109
+ w .Header ().Set (corsAllowCredentialsHeader , "true" )
110
+ }
111
111
112
- if len (ch .allowedOrigins ) > 1 {
113
- w .Header ().Set (corsVaryHeader , corsOriginHeader )
114
- }
112
+ if len (ch .allowedOrigins ) > 1 {
113
+ w .Header ().Set (corsVaryHeader , corsOriginHeader )
114
+ }
115
115
116
- returnOrigin := origin
117
- if ch .allowedOriginValidator == nil && len (ch .allowedOrigins ) == 0 {
118
- returnOrigin = "*"
119
- } else {
120
- for _ , o := range ch .allowedOrigins {
121
- // A configuration of * is different than explicitly setting an allowed
122
- // origin. Returning arbitrary origin headers in an access control allow
123
- // origin header is unsafe and is not required by any use case.
124
- if o == corsOriginMatchAll {
125
- returnOrigin = "*"
126
- break
116
+ returnOrigin := origin
117
+ if ch .allowedOriginValidator == nil && len (ch .allowedOrigins ) == 0 {
118
+ returnOrigin = "*"
119
+ } else {
120
+ for _ , o := range ch .allowedOrigins {
121
+ // A configuration of * is different than explicitly setting an allowed
122
+ // origin. Returning arbitrary origin headers in an access control allow
123
+ // origin header is unsafe and is not required by any use case.
124
+ if o == corsOriginMatchAll {
125
+ returnOrigin = "*"
126
+ break
127
+ }
127
128
}
128
129
}
129
- }
130
- w .Header ().Set (corsAllowOriginHeader , returnOrigin )
130
+ w .Header ().Set (corsAllowOriginHeader , returnOrigin )
131
131
132
- if r .Method == corsOptionMethod {
133
- w .WriteHeader (ch .optionStatusCode )
134
- return
135
- }
136
- ch .h .ServeHTTP (w , r )
132
+ if r .Method == corsOptionMethod {
133
+ w .WriteHeader (ch .optionStatusCode )
134
+ return
135
+ }
136
+ h .ServeHTTP (w , r )
137
+ })
137
138
}
138
139
139
140
// CORS provides Cross-Origin Resource Sharing middleware.
@@ -155,10 +156,9 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
155
156
// http.ListenAndServe(":8000", handlers.CORS()(r))
156
157
// }
157
158
func CORS (opts ... CORSOption ) func (http.Handler ) http.Handler {
159
+ ch := parseCORSOptions (opts ... )
158
160
return func (h http.Handler ) http.Handler {
159
- ch := parseCORSOptions (opts ... )
160
- ch .h = h
161
- return ch
161
+ return ch .wrap (h )
162
162
}
163
163
}
164
164
0 commit comments