Skip to content

Commit 9d12d4c

Browse files
add pat rate limiting
1 parent 719283c commit 9d12d4c

File tree

4 files changed

+484
-2
lines changed

4 files changed

+484
-2
lines changed

management/server/http/handler.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@ import (
44
"context"
55
"fmt"
66
"net/http"
7+
"os"
8+
"strconv"
9+
"time"
710

811
"github.com/gorilla/mux"
912
"github.com/rs/cors"
13+
log "github.com/sirupsen/logrus"
1014

1115
"github.com/netbirdio/management-integrations/integrations"
1216

@@ -38,7 +42,12 @@ import (
3842
"github.com/netbirdio/netbird/management/server/telemetry"
3943
)
4044

41-
const apiPrefix = "/api"
45+
const (
46+
apiPrefix = "/api"
47+
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
48+
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
49+
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
50+
)
4251

4352
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
4453
func NewAPIHandler(
@@ -58,11 +67,42 @@ func NewAPIHandler(
5867
settingsManager settings.Manager,
5968
) (http.Handler, error) {
6069

70+
var rateLimitingConfig *middleware.RateLimiterConfig
71+
if os.Getenv(rateLimitingEnabledKey) == "true" {
72+
rpm := 6
73+
if v := os.Getenv(rateLimitingRPMKey); v != "" {
74+
value, err := strconv.Atoi(v)
75+
if err != nil {
76+
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
77+
} else {
78+
rpm = value
79+
}
80+
}
81+
82+
burst := 500
83+
if v := os.Getenv(rateLimitingBurstKey); v != "" {
84+
value, err := strconv.Atoi(v)
85+
if err != nil {
86+
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
87+
} else {
88+
burst = value
89+
}
90+
}
91+
92+
rateLimitingConfig = &middleware.RateLimiterConfig{
93+
RequestsPerMinute: float64(rpm),
94+
Burst: burst,
95+
CleanupInterval: 6 * time.Hour,
96+
LimiterTTL: 24 * time.Hour,
97+
}
98+
}
99+
61100
authMiddleware := middleware.NewAuthMiddleware(
62101
authManager,
63102
accountManager.GetAccountIDFromUserAuth,
64103
accountManager.SyncUserJWTGroups,
65104
accountManager.GetUserFromUserAuth,
105+
rateLimitingConfig,
66106
)
67107

68108
corsMiddleware := cors.AllowAll()

management/server/http/middleware/auth_middleware.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type AuthMiddleware struct {
2929
ensureAccount EnsureAccountFunc
3030
getUserFromUserAuth GetUserFromUserAuthFunc
3131
syncUserJWTGroups SyncUserJWTGroupsFunc
32+
rateLimiter *APIRateLimiter
3233
}
3334

3435
// NewAuthMiddleware instance constructor
@@ -37,12 +38,19 @@ func NewAuthMiddleware(
3738
ensureAccount EnsureAccountFunc,
3839
syncUserJWTGroups SyncUserJWTGroupsFunc,
3940
getUserFromUserAuth GetUserFromUserAuthFunc,
41+
rateLimiterConfig *RateLimiterConfig,
4042
) *AuthMiddleware {
43+
var rateLimiter *APIRateLimiter
44+
if rateLimiterConfig != nil {
45+
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
46+
}
47+
4148
return &AuthMiddleware{
4249
authManager: authManager,
4350
ensureAccount: ensureAccount,
4451
syncUserJWTGroups: syncUserJWTGroups,
4552
getUserFromUserAuth: getUserFromUserAuth,
53+
rateLimiter: rateLimiter,
4654
}
4755
}
4856

@@ -145,6 +153,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
145153
return r, fmt.Errorf("error extracting token: %w", err)
146154
}
147155

156+
if m.rateLimiter != nil {
157+
if !m.rateLimiter.Allow(token) {
158+
return r, fmt.Errorf("too many requests")
159+
}
160+
}
161+
148162
ctx := r.Context()
149163
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
150164
if err != nil {

0 commit comments

Comments
 (0)