Skip to content

Commit c38d573

Browse files
Merge pull request #7 from karthikbhandary2/karthik
Ratelimiting added
2 parents 3db1d5f + 3b58366 commit c38d573

File tree

12 files changed

+209
-11
lines changed

12 files changed

+209
-11
lines changed

bin/build-errors.log

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

bin/main

50.1 KB
Binary file not shown.

cmd/api/api.go

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
package main
22

33
import (
4+
"errors"
45
"fmt"
56
"net/http"
7+
"os"
8+
"os/signal"
9+
"syscall"
610
"time"
711

812
"github.com/go-chi/chi/v5"
@@ -12,10 +16,12 @@ import (
1216
_ "github.com/karthikbhandary2/Social/docs"
1317
"github.com/karthikbhandary2/Social/internal/auth"
1418
"github.com/karthikbhandary2/Social/internal/mailer"
19+
"github.com/karthikbhandary2/Social/internal/ratelimiter"
1520
"github.com/karthikbhandary2/Social/internal/store"
1621
"github.com/karthikbhandary2/Social/internal/store/cache"
1722
"github.com/swaggo/http-swagger/v2"
1823
"go.uber.org/zap"
24+
"golang.org/x/net/context"
1925
)
2026

2127
type application struct {
@@ -25,6 +31,7 @@ type application struct {
2531
logger *zap.SugaredLogger
2632
mailer mailer.Client
2733
authenticator auth.Authenticator
34+
rateLimiter ratelimiter.Limiter
2835
}
2936

3037
type config struct {
@@ -36,6 +43,7 @@ type config struct {
3643
frontendURL string
3744
auth authConfig
3845
redisCfg redisConfig
46+
rateLimiter ratelimiter.Config
3947
}
4048

4149
type redisConfig struct {
@@ -94,10 +102,10 @@ func (app *application) mount() http.Handler {
94102
r.Use(middleware.Logger)
95103
r.Use(middleware.RealIP)
96104
r.Use(middleware.RequestID)
97-
r.Use(middleware.Timeout(60 * time.Second))
105+
r.Use(app.RateLimiterMiddleware)
98106

99107
r.Route("/v1", func(r chi.Router) {
100-
r.With(app.BasicAuthMiddleware()).Get("/health", app.healthCheckHandler)
108+
r.Get("/health", app.healthCheckHandler)
101109

102110
docsURL := fmt.Sprintf("%s/swagger/doc.json", app.config.addr)
103111
r.Get("/swagger/*", httpSwagger.Handler(httpSwagger.URL(docsURL)))
@@ -149,6 +157,30 @@ func (app *application) run(mux http.Handler) error {
149157
ReadTimeout: time.Second * 10, // It should be less than writes
150158
IdleTimeout: time.Minute,
151159
}
160+
161+
shutdown := make(chan error)
162+
go func() {
163+
quit := make(chan os.Signal, 1)
164+
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
165+
s := <-quit
166+
167+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
168+
defer cancel()
169+
170+
app.logger.Infow("signal caught", "signal", s.String())
171+
172+
shutdown <- srv.Shutdown(ctx)
173+
}()
152174
app.logger.Infow("server has started", "addr", app.config.addr, "env", app.config.env)
153-
return srv.ListenAndServe()
175+
176+
err := srv.ListenAndServe()
177+
if !errors.Is(err, http.ErrServerClosed) {
178+
return err
179+
}
180+
err = <-shutdown
181+
if err != nil {
182+
return err
183+
}
184+
app.logger.Infow("server has stopped", "addr", app.config.addr, "env", app.config.env)
185+
return nil
154186
}

cmd/api/api_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package main
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
"time"
8+
9+
"github.com/karthikbhandary2/Social/internal/ratelimiter"
10+
)
11+
12+
func TestRateLimiterMiddleware(t *testing.T) {
13+
cfg := config{
14+
rateLimiter: ratelimiter.Config{
15+
RequestsPerTimeFrame: 20,
16+
TimeFrame: time.Second * 5,
17+
Enabled: true,
18+
},
19+
addr: ":8080",
20+
}
21+
22+
app := newTestApplication(t, cfg)
23+
ts := httptest.NewServer(app.mount())
24+
defer ts.Close()
25+
26+
client := &http.Client{}
27+
mockIP := "192.168.1.1"
28+
marginOfError := 2
29+
30+
for i := 0; i < cfg.rateLimiter.RequestsPerTimeFrame+marginOfError; i++ {
31+
req, err := http.NewRequest("GET", ts.URL+"/v1/health", nil)
32+
if err != nil {
33+
t.Fatalf("could not create request: %v", err)
34+
}
35+
36+
req.Header.Set("X-Forwarded-For", mockIP)
37+
38+
resp, err := client.Do(req)
39+
if err != nil {
40+
t.Fatalf("could not send request: %v", err)
41+
}
42+
defer resp.Body.Close()
43+
44+
if i < cfg.rateLimiter.RequestsPerTimeFrame {
45+
if resp.StatusCode != http.StatusOK {
46+
t.Errorf("expected status OK; got %v", resp.Status)
47+
}
48+
} else {
49+
if resp.StatusCode != http.StatusTooManyRequests {
50+
t.Errorf("expected status Too Many Requests; got %v", resp.Status)
51+
}
52+
}
53+
}
54+
}

cmd/api/auth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import (
1010

1111
"github.com/golang-jwt/jwt/v5"
1212
"github.com/google/uuid"
13-
"github.com/karthikbhandary2/Social/internal/store"
1413
"github.com/karthikbhandary2/Social/internal/mailer"
14+
"github.com/karthikbhandary2/Social/internal/store"
1515
)
1616

1717
type RegisterUserPayload struct {

cmd/api/errors.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,11 @@ func (app *application) notFound(w http.ResponseWriter, r *http.Request, err err
3939
app.logger.Warnf("not found error", "method", r.Method,"path", r.URL.Path,"error",err)
4040
writeJSON(w, http.StatusNotFound, err.Error())
4141
}
42+
43+
func (app *application) rateLimitExceeded(w http.ResponseWriter, r *http.Request, retryAfter string) {
44+
app.logger.Warnw("rate limit exceeded", "method", r.Method, "path", r.URL.Path)
45+
46+
w.Header().Set("Retry-After", retryAfter)
47+
48+
writeJSONError(w, http.StatusTooManyRequests, "rate limit exceeded, retry after: "+retryAfter)
49+
}

cmd/api/main.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/karthikbhandary2/Social/internal/db"
99
"github.com/karthikbhandary2/Social/internal/env"
1010
"github.com/karthikbhandary2/Social/internal/mailer"
11+
"github.com/karthikbhandary2/Social/internal/ratelimiter"
1112
"github.com/karthikbhandary2/Social/internal/store"
1213
"github.com/karthikbhandary2/Social/internal/store/cache"
1314
"go.uber.org/zap"
@@ -72,6 +73,11 @@ func main() {
7273
iss: "social",
7374
},
7475
},
76+
rateLimiter: ratelimiter.Config{
77+
RequestsPerTimeFrame: env.GetInt("RATELIMITER_REQUESTS_COUNT", 20),
78+
TimeFrame: time.Second * 5,
79+
Enabled: env.GetBool("RATE_LIMITER_ENABLED", true),
80+
},
7581
}
7682

7783
//logger
@@ -93,6 +99,13 @@ func main() {
9399
rdb = cache.NewRedisClient(cfg.redisCfg.addr, cfg.redisCfg.pw, cfg.redisCfg.db)
94100
logger.Info("Redis connection pool established")
95101
}
102+
103+
//ratelimiter
104+
rateLimiter := ratelimiter.NewFixedWindowLimiter(
105+
cfg.rateLimiter.RequestsPerTimeFrame,
106+
cfg.rateLimiter.TimeFrame,
107+
)
108+
96109
store := store.NewStorage(db)
97110
cacheStorage := cache.NewRedisStorage(rdb)
98111
mailer := mailer.NewSendgrid(cfg.mail.sendGrid.apiKey, cfg.mail.fromEmail)
@@ -105,6 +118,7 @@ func main() {
105118
logger: logger,
106119
mailer: mailer,
107120
authenticator: jwtAuthenticator,
121+
rateLimiter: rateLimiter,
108122
}
109123

110124
mux := app.mount()

cmd/api/middleware.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,16 @@ func (app *application) getUser(ctx context.Context, userID int64) (*store.User,
144144
return user, nil
145145

146146

147+
}
148+
149+
func (app *application) RateLimiterMiddleware(next http.Handler) http.Handler {
150+
return http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
151+
if app.config.rateLimiter.Enabled {
152+
if allow, retryAfter := app.rateLimiter.Allow(r.RemoteAddr); !allow {
153+
app.rateLimitExceeded(w, r, retryAfter.String())
154+
return
155+
}
156+
}
157+
next.ServeHTTP(w, r)
158+
})
147159
}

cmd/api/test_utils.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,40 @@ import (
66
"testing"
77

88
"github.com/karthikbhandary2/Social/internal/auth"
9+
"github.com/karthikbhandary2/Social/internal/ratelimiter"
910
"github.com/karthikbhandary2/Social/internal/store"
1011
"github.com/karthikbhandary2/Social/internal/store/cache"
1112
"go.uber.org/zap"
1213
)
1314

14-
func newTestApplication(t *testing.T) *application {
15+
func newTestApplication(t *testing.T, cfg config) *application {
1516
t.Helper()
1617

18+
logger := zap.NewNop().Sugar()
19+
// Uncomment to enable logs
1720
// logger := zap.Must(zap.NewProduction()).Sugar()
18-
logger := zap.Must(zap.NewProduction()).Sugar()
1921
mockStore := store.NewMockStore()
2022
mockCacheStore := cache.NewMockStore()
23+
2124
testAuth := &auth.TestAuthenticator{}
25+
26+
// Rate limiter
27+
rateLimiter := ratelimiter.NewFixedWindowLimiter(
28+
cfg.rateLimiter.RequestsPerTimeFrame,
29+
cfg.rateLimiter.TimeFrame,
30+
)
31+
2232
return &application{
23-
logger: logger,
24-
store: mockStore,
25-
cacheStorage: mockCacheStore,
33+
logger: logger,
34+
store: mockStore,
35+
cacheStorage: mockCacheStore,
2636
authenticator: testAuth,
37+
config: cfg,
38+
rateLimiter: rateLimiter,
2739
}
2840
}
2941

42+
3043
func executeRequest(req *http.Request, mux http.Handler) *httptest.ResponseRecorder {
3144
rr := httptest.NewRecorder()
3245
mux.ServeHTTP(rr, req)

cmd/api/user_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ import (
77
)
88

99
func TestGetUser(t *testing.T) {
10-
app := newTestApplication(t)
10+
withRedis := config{
11+
redisCfg: redisConfig{
12+
enabled: true,
13+
},
14+
}
15+
app := newTestApplication(t, withRedis)
1116
mux := app.mount()
1217
testToken, err := app.authenticator.GenerateToken(nil)
1318
if err != nil {

0 commit comments

Comments
 (0)