Skip to content

Commit 75edc99

Browse files
authored
Add rate limiting (#5)
1 parent ea3f8bb commit 75edc99

File tree

3 files changed

+113
-1
lines changed

3 files changed

+113
-1
lines changed

Sources/Grodt/Application/routes.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,16 @@ func routes(_ app: Application) async throws {
2626
let transactionChangedHandler = TransactionChangedHandler(portfolioRepository: PostgresPortfolioRepository(database: app.db),
2727
historicalPerformanceUpdater: portfolioPerformanceUpdater)
2828

29+
let globalRateLimiter = RateLimiterMiddleware(maxRequests: 100, perSeconds: 60)
30+
let loginRateLimiter = RateLimiterMiddleware(maxRequests: 3, perSeconds: 60)
31+
32+
app.middleware.use(app.sessions.middleware)
33+
app.middleware.use(globalRateLimiter)
34+
2935
try app.group("") { routeBuilder in
30-
try routeBuilder.register(collection: UserController(dtoMapper: loginResponseDTOMapper))
36+
try routeBuilder
37+
.grouped(loginRateLimiter)
38+
.register(collection: UserController(dtoMapper: loginResponseDTOMapper))
3139
}
3240

3341
let tokenAuthMiddleware = UserToken.authenticator()
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import Vapor
2+
3+
actor ClientRequestStore {
4+
private struct ClientData {
5+
var requestCount: Int
6+
var windowStart: Date
7+
}
8+
9+
private var clients: [String: ClientData] = [:]
10+
private let timeWindow: TimeInterval
11+
private let expirationDuration: TimeInterval
12+
private var lastCleanupTime: Date
13+
private let cleanupInterval: TimeInterval
14+
15+
init(timeWindow: TimeInterval, expirationDuration: TimeInterval, cleanupInterval: TimeInterval) {
16+
self.timeWindow = timeWindow
17+
self.expirationDuration = expirationDuration
18+
self.cleanupInterval = cleanupInterval
19+
self.lastCleanupTime = Date()
20+
}
21+
22+
func incrementRequestCount(for clientID: String, at time: Date) async -> Int {
23+
// Perform cleanup periodically
24+
if time.timeIntervalSince(lastCleanupTime) > cleanupInterval {
25+
cleanUpExpiredClients(currentTime: time)
26+
lastCleanupTime = time
27+
}
28+
29+
var clientData = clients[clientID] ?? ClientData(requestCount: 0, windowStart: time)
30+
31+
if time.timeIntervalSince(clientData.windowStart) > timeWindow {
32+
// Start a new time window
33+
clientData.requestCount = 1
34+
clientData.windowStart = time
35+
} else {
36+
// Increment request count within the current time window
37+
clientData.requestCount += 1
38+
}
39+
40+
clients[clientID] = clientData
41+
return clientData.requestCount
42+
}
43+
44+
private func cleanUpExpiredClients(currentTime: Date) {
45+
clients = clients.filter { _, data in
46+
currentTime.timeIntervalSince(data.windowStart) < expirationDuration
47+
}
48+
}
49+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import Vapor
2+
3+
struct RateLimiterMiddleware: AsyncMiddleware {
4+
private let maxRequests: Int
5+
private let store: ClientRequestStore
6+
7+
init(maxRequests: Int, perSeconds timeWindow: TimeInterval) {
8+
self.maxRequests = maxRequests
9+
10+
let expirationDuration = timeWindow * 3
11+
let cleanupInterval = timeWindow
12+
13+
self.store = ClientRequestStore(
14+
timeWindow: timeWindow,
15+
expirationDuration: expirationDuration,
16+
cleanupInterval: cleanupInterval
17+
)
18+
}
19+
20+
func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
21+
let clientID = extractClientIdentifier(from: request)
22+
let currentTime = Date()
23+
24+
let requestCount = await store.incrementRequestCount(for: clientID, at: currentTime)
25+
26+
if requestCount > maxRequests {
27+
throw Abort(.tooManyRequests, reason: "Too many requests. Please try again later.")
28+
}
29+
30+
return try await next.respond(to: request)
31+
}
32+
33+
private func extractClientIdentifier(from request: Request) -> String {
34+
if let userID = request.auth.get(User.self)?.id?.uuidString {
35+
return "user:\(userID)"
36+
} else if let sessionID = request.session.id?.string {
37+
return "session:\(sessionID)"
38+
} else {
39+
return extractClientIP(from: request)
40+
}
41+
}
42+
43+
private func extractClientIP(from request: Request) -> String {
44+
// Try to get the client IP from the X-Forwarded-For header
45+
if let forwardedFor = request.headers["X-Forwarded-For"].first {
46+
// The X-Forwarded-For header can contain multiple IPs, the first one is the client's IP
47+
if let clientIP = forwardedFor.split(separator: ",").first?.trimmingCharacters(in: .whitespaces) {
48+
return clientIP
49+
}
50+
}
51+
52+
// Fallback to the remote address
53+
return request.remoteAddress?.ipAddress ?? "unknown"
54+
}
55+
}

0 commit comments

Comments
 (0)