Skip to content

Commit ba10e68

Browse files
authored
test(middleware/csrf): Fix Benchmark Tests (#2932)
* test(middleware/csrf): fix Benchmark_Middleware_CSRF_* * fix(middleware/csrf): update refererMatchesHost()
1 parent 1607d87 commit ba10e68

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

middleware/csrf/csrf.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"errors"
55
"net/url"
66
"reflect"
7+
"strings"
78
"time"
89

910
"github.com/gofiber/fiber/v2"
@@ -220,7 +221,7 @@ func isCsrfFromCookie(extractor interface{}) bool {
220221
// returns an error if the referer header is not present or is invalid
221222
// returns nil if the referer header is valid
222223
func refererMatchesHost(c *fiber.Ctx) error {
223-
referer := c.Get(fiber.HeaderReferer)
224+
referer := strings.ToLower(c.Get(fiber.HeaderReferer))
224225
if referer == "" {
225226
return ErrNoReferer
226227
}
@@ -230,9 +231,9 @@ func refererMatchesHost(c *fiber.Ctx) error {
230231
return ErrBadReferer
231232
}
232233

233-
if refererURL.Scheme+"://"+refererURL.Host != c.Protocol()+"://"+c.Hostname() {
234-
return ErrBadReferer
234+
if refererURL.Scheme == c.Protocol() && refererURL.Host == c.Hostname() {
235+
return nil
235236
}
236237

237-
return nil
238+
return ErrBadReferer
238239
}

middleware/csrf/csrf_test.go

+19-6
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,10 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) {
992992
return c.SendStatus(fiber.StatusTeapot)
993993
})
994994

995-
fctx := &fasthttp.RequestCtx{}
995+
app.Post("/", func(c *fiber.Ctx) error {
996+
return c.SendStatus(fiber.StatusTeapot)
997+
})
998+
996999
h := app.Handler()
9971000
ctx := &fasthttp.RequestCtx{}
9981001

@@ -1002,17 +1005,27 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) {
10021005
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
10031006
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
10041007

1008+
// Test Correct Referer POST
1009+
ctx.Request.Reset()
1010+
ctx.Response.Reset()
10051011
ctx.Request.Header.SetMethod(fiber.MethodPost)
1012+
ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
1013+
ctx.Request.URI().SetScheme("https")
1014+
ctx.Request.URI().SetHost("example.com")
1015+
ctx.Request.Header.SetProtocol("https")
1016+
ctx.Request.Header.SetHost("example.com")
1017+
ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com")
10061018
ctx.Request.Header.Set(HeaderName, token)
1019+
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
10071020

10081021
b.ReportAllocs()
10091022
b.ResetTimer()
10101023

10111024
for n := 0; n < b.N; n++ {
1012-
h(fctx)
1025+
h(ctx)
10131026
}
10141027

1015-
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
1028+
utils.AssertEqual(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode())
10161029
}
10171030

10181031
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
@@ -1024,7 +1037,6 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
10241037
return c.SendStatus(fiber.StatusTeapot)
10251038
})
10261039

1027-
fctx := &fasthttp.RequestCtx{}
10281040
h := app.Handler()
10291041
ctx := &fasthttp.RequestCtx{}
10301042

@@ -1034,8 +1046,9 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
10341046
b.ResetTimer()
10351047

10361048
for n := 0; n < b.N; n++ {
1037-
h(fctx)
1049+
h(ctx)
10381050
}
10391051

1040-
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
1052+
// Ensure the GET request returns a 418 status code
1053+
utils.AssertEqual(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode())
10411054
}

0 commit comments

Comments
 (0)