Skip to content

Commit 1607d87

Browse files
authored
fix(middleware/cors): Categorize requests correctly (#2921)
* fix(middleware/cors): categorise requests correctly * test(middleware/cors): improve test coverage for request types * test(middleware/cors): Add subdomain matching tests * test(middleware/cors): parallel tests for CORS headers based on request type * test(middleware/cors): Add benchmark for CORS subdomain matching * test(middleware/cors): cover additiona test cases * refactor(middleware/cors): origin validation and normalization
1 parent 1aac6f6 commit 1607d87

File tree

3 files changed

+238
-32
lines changed

3 files changed

+238
-32
lines changed

middleware/cors/cors.go

+16-24
Original file line numberDiff line numberDiff line change
@@ -119,33 +119,23 @@ func New(config ...Config) fiber.Handler {
119119
allowSOrigins := []subdomain{}
120120
allowAllOrigins := false
121121

122-
// processOrigin processes an origin string, normalizes it and checks its validity
123-
// it will panic if the origin is invalid
124-
processOrigin := func(origin string) (string, bool) {
125-
trimmedOrigin := strings.TrimSpace(origin)
126-
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
127-
if !isValid {
128-
log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin)
129-
panic("[CORS] Invalid origin provided in configuration")
130-
}
131-
return normalizedOrigin, true
132-
}
133-
134122
// Validate and normalize static AllowOrigins
135123
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
136124
origins := strings.Split(cfg.AllowOrigins, ",")
137125
for _, origin := range origins {
138126
if i := strings.Index(origin, "://*."); i != -1 {
139-
normalizedOrigin, isValid := processOrigin(origin[:i+3] + origin[i+4:])
127+
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
128+
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
140129
if !isValid {
141-
continue
130+
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
142131
}
143132
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
144133
allowSOrigins = append(allowSOrigins, sd)
145134
} else {
146-
normalizedOrigin, isValid := processOrigin(origin)
135+
trimmedOrigin := strings.TrimSpace(origin)
136+
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
147137
if !isValid {
148-
continue
138+
panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin)
149139
}
150140
allowOrigins = append(allowOrigins, normalizedOrigin)
151141
}
@@ -172,8 +162,9 @@ func New(config ...Config) fiber.Handler {
172162
// Get originHeader header
173163
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
174164

175-
// If the request does not have an Origin header, the request is outside the scope of CORS
176-
if originHeader == "" {
165+
// If the request does not have Origin and Access-Control-Request-Method
166+
// headers, the request is outside the scope of CORS
167+
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
177168
return c.Next()
178169
}
179170

@@ -211,8 +202,9 @@ func New(config ...Config) fiber.Handler {
211202
}
212203

213204
// Simple request
205+
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
214206
if c.Method() != fiber.MethodOptions {
215-
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
207+
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
216208
return c.Next()
217209
}
218210

@@ -233,14 +225,14 @@ func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, expos
233225

234226
if cfg.AllowCredentials {
235227
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
236-
if allowOrigin != "*" && allowOrigin != "" {
237-
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
238-
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
239-
} else if allowOrigin == "*" {
228+
if allowOrigin == "*" {
240229
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
241230
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
231+
} else if allowOrigin != "" {
232+
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
233+
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
242234
}
243-
} else if len(allowOrigin) > 0 {
235+
} else if allowOrigin != "" {
244236
// For non-credential requests, it's safe to set to '*' or specific origins
245237
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
246238
}

0 commit comments

Comments
 (0)