@@ -14,9 +14,11 @@ import (
14
14
"github.com/didip/tollbooth/v5/limiter"
15
15
jwt "github.com/golang-jwt/jwt/v5"
16
16
"github.com/stretchr/testify/assert"
17
+ "github.com/stretchr/testify/mock"
17
18
"github.com/stretchr/testify/require"
18
19
"github.com/stretchr/testify/suite"
19
20
"github.com/supabase/auth/internal/conf"
21
+ "github.com/supabase/auth/internal/storage"
20
22
)
21
23
22
24
const (
@@ -443,3 +445,66 @@ func (ts *MiddlewareTestSuite) TestLimitHandler() {
443
445
ts .API .limitHandler (lmt ).handler (okHandler ).ServeHTTP (w , req )
444
446
require .Equal (ts .T (), http .StatusTooManyRequests , w .Code )
445
447
}
448
+
449
+ type MockCleanup struct {
450
+ mock.Mock
451
+ }
452
+
453
+ func (m * MockCleanup ) Clean (db * storage.Connection ) (int , error ) {
454
+ m .Called (db )
455
+ return 0 , nil
456
+ }
457
+
458
+ func (ts * MiddlewareTestSuite ) TestDatabaseCleanup () {
459
+ testHandler := func (statusCode int ) http.HandlerFunc {
460
+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
461
+ w .WriteHeader (statusCode )
462
+ b , _ := json .Marshal (map [string ]interface {}{"message" : "ok" })
463
+ w .Write ([]byte (b ))
464
+ })
465
+ }
466
+
467
+ cases := []struct {
468
+ desc string
469
+ statusCode int
470
+ method string
471
+ }{
472
+ {
473
+ desc : "Run cleanup successfully" ,
474
+ statusCode : http .StatusOK ,
475
+ method : http .MethodPost ,
476
+ },
477
+ {
478
+ desc : "Skip cleanup if GET" ,
479
+ statusCode : http .StatusOK ,
480
+ method : http .MethodGet ,
481
+ },
482
+ {
483
+ desc : "Skip cleanup if 3xx" ,
484
+ statusCode : http .StatusSeeOther ,
485
+ method : http .MethodPost ,
486
+ },
487
+ {
488
+ desc : "Skip cleanup if 4xx" ,
489
+ statusCode : http .StatusBadRequest ,
490
+ method : http .MethodPost ,
491
+ },
492
+ {
493
+ desc : "Skip cleanup if 5xx" ,
494
+ statusCode : http .StatusInternalServerError ,
495
+ method : http .MethodPost ,
496
+ },
497
+ }
498
+
499
+ mockCleanup := new (MockCleanup )
500
+ mockCleanup .On ("Clean" , mock .Anything ).Return (0 , nil )
501
+ for _ , c := range cases {
502
+ ts .Run ("DatabaseCleanup" , func () {
503
+ req := httptest .NewRequest (c .method , "http://localhost" , nil )
504
+ w := httptest .NewRecorder ()
505
+ ts .API .databaseCleanup (mockCleanup )(testHandler (c .statusCode )).ServeHTTP (w , req )
506
+ require .Equal (ts .T (), c .statusCode , w .Code )
507
+ })
508
+ }
509
+ mockCleanup .AssertNumberOfCalls (ts .T (), "Clean" , 1 )
510
+ }
0 commit comments