Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion Sources/Grodt/Application/routes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,16 @@ func routes(_ app: Application) async throws {
let transactionChangedHandler = TransactionChangedHandler(portfolioRepository: PostgresPortfolioRepository(database: app.db),
historicalPerformanceUpdater: portfolioPerformanceUpdater)

let globalRateLimiter = RateLimiterMiddleware(maxRequests: 100, perSeconds: 60)
let loginRateLimiter = RateLimiterMiddleware(maxRequests: 3, perSeconds: 60)

app.middleware.use(app.sessions.middleware)
app.middleware.use(globalRateLimiter)

try app.group("") { routeBuilder in
try routeBuilder.register(collection: UserController(dtoMapper: loginResponseDTOMapper))
try routeBuilder
.grouped(loginRateLimiter)
.register(collection: UserController(dtoMapper: loginResponseDTOMapper))
}

let tokenAuthMiddleware = UserToken.authenticator()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import Vapor

actor ClientRequestStore {
private struct ClientData {
var requestCount: Int
var windowStart: Date
}

private var clients: [String: ClientData] = [:]
private let timeWindow: TimeInterval
private let expirationDuration: TimeInterval
private var lastCleanupTime: Date
private let cleanupInterval: TimeInterval

init(timeWindow: TimeInterval, expirationDuration: TimeInterval, cleanupInterval: TimeInterval) {
self.timeWindow = timeWindow
self.expirationDuration = expirationDuration
self.cleanupInterval = cleanupInterval
self.lastCleanupTime = Date()
}

func incrementRequestCount(for clientID: String, at time: Date) async -> Int {
// Perform cleanup periodically
if time.timeIntervalSince(lastCleanupTime) > cleanupInterval {
cleanUpExpiredClients(currentTime: time)
lastCleanupTime = time
}

var clientData = clients[clientID] ?? ClientData(requestCount: 0, windowStart: time)

if time.timeIntervalSince(clientData.windowStart) > timeWindow {
// Start a new time window
clientData.requestCount = 1
clientData.windowStart = time
} else {
// Increment request count within the current time window
clientData.requestCount += 1
}

clients[clientID] = clientData
return clientData.requestCount
}

private func cleanUpExpiredClients(currentTime: Date) {
clients = clients.filter { _, data in
currentTime.timeIntervalSince(data.windowStart) < expirationDuration
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import Vapor

struct RateLimiterMiddleware: AsyncMiddleware {
private let maxRequests: Int
private let store: ClientRequestStore

init(maxRequests: Int, perSeconds timeWindow: TimeInterval) {
self.maxRequests = maxRequests

let expirationDuration = timeWindow * 3
let cleanupInterval = timeWindow

self.store = ClientRequestStore(
timeWindow: timeWindow,
expirationDuration: expirationDuration,
cleanupInterval: cleanupInterval
)
}

func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
let clientID = extractClientIdentifier(from: request)
let currentTime = Date()

let requestCount = await store.incrementRequestCount(for: clientID, at: currentTime)

if requestCount > maxRequests {
throw Abort(.tooManyRequests, reason: "Too many requests. Please try again later.")
}

return try await next.respond(to: request)
}

private func extractClientIdentifier(from request: Request) -> String {
if let userID = request.auth.get(User.self)?.id?.uuidString {
return "user:\(userID)"
} else if let sessionID = request.session.id?.string {
return "session:\(sessionID)"
} else {
return extractClientIP(from: request)
}
}

private func extractClientIP(from request: Request) -> String {
// Try to get the client IP from the X-Forwarded-For header
if let forwardedFor = request.headers["X-Forwarded-For"].first {
// The X-Forwarded-For header can contain multiple IPs, the first one is the client's IP
if let clientIP = forwardedFor.split(separator: ",").first?.trimmingCharacters(in: .whitespaces) {
return clientIP
}
}

// Fallback to the remote address
return request.remoteAddress?.ipAddress ?? "unknown"
}
}