diff --git a/Sources/Grodt/Application/routes.swift b/Sources/Grodt/Application/routes.swift index 40823be..ea2d7a7 100644 --- a/Sources/Grodt/Application/routes.swift +++ b/Sources/Grodt/Application/routes.swift @@ -26,8 +26,16 @@ func routes(_ app: Application) async throws { let transactionChangedHandler = TransactionChangedHandler(portfolioRepository: PostgresPortfolioRepository(database: app.db), historicalPerformanceUpdater: portfolioPerformanceUpdater) + let globalRateLimiter = RateLimiterMiddleware(maxRequests: 100, perSeconds: 60) + let loginRateLimiter = RateLimiterMiddleware(maxRequests: 3, perSeconds: 60) + + app.middleware.use(app.sessions.middleware) + app.middleware.use(globalRateLimiter) + try app.group("") { routeBuilder in - try routeBuilder.register(collection: UserController(dtoMapper: loginResponseDTOMapper)) + try routeBuilder + .grouped(loginRateLimiter) + .register(collection: UserController(dtoMapper: loginResponseDTOMapper)) } let tokenAuthMiddleware = UserToken.authenticator() diff --git a/Sources/Grodt/BusinessLogic/RateLimiterMiddleware/ClientRequestStore.swift b/Sources/Grodt/BusinessLogic/RateLimiterMiddleware/ClientRequestStore.swift new file mode 100644 index 0000000..eda3313 --- /dev/null +++ b/Sources/Grodt/BusinessLogic/RateLimiterMiddleware/ClientRequestStore.swift @@ -0,0 +1,49 @@ +import Vapor + +actor ClientRequestStore { + private struct ClientData { + var requestCount: Int + var windowStart: Date + } + + private var clients: [String: ClientData] = [:] + private let timeWindow: TimeInterval + private let expirationDuration: TimeInterval + private var lastCleanupTime: Date + private let cleanupInterval: TimeInterval + + init(timeWindow: TimeInterval, expirationDuration: TimeInterval, cleanupInterval: TimeInterval) { + self.timeWindow = timeWindow + self.expirationDuration = expirationDuration + self.cleanupInterval = cleanupInterval + self.lastCleanupTime = Date() + } + + func incrementRequestCount(for clientID: String, at time: Date) async -> Int { + // Perform cleanup periodically + if time.timeIntervalSince(lastCleanupTime) > cleanupInterval { + cleanUpExpiredClients(currentTime: time) + lastCleanupTime = time + } + + var clientData = clients[clientID] ?? ClientData(requestCount: 0, windowStart: time) + + if time.timeIntervalSince(clientData.windowStart) > timeWindow { + // Start a new time window + clientData.requestCount = 1 + clientData.windowStart = time + } else { + // Increment request count within the current time window + clientData.requestCount += 1 + } + + clients[clientID] = clientData + return clientData.requestCount + } + + private func cleanUpExpiredClients(currentTime: Date) { + clients = clients.filter { _, data in + currentTime.timeIntervalSince(data.windowStart) < expirationDuration + } + } +} diff --git a/Sources/Grodt/BusinessLogic/RateLimiterMiddleware/RateLimiterMiddleware.swift b/Sources/Grodt/BusinessLogic/RateLimiterMiddleware/RateLimiterMiddleware.swift new file mode 100644 index 0000000..46d022a --- /dev/null +++ b/Sources/Grodt/BusinessLogic/RateLimiterMiddleware/RateLimiterMiddleware.swift @@ -0,0 +1,55 @@ +import Vapor + +struct RateLimiterMiddleware: AsyncMiddleware { + private let maxRequests: Int + private let store: ClientRequestStore + + init(maxRequests: Int, perSeconds timeWindow: TimeInterval) { + self.maxRequests = maxRequests + + let expirationDuration = timeWindow * 3 + let cleanupInterval = timeWindow + + self.store = ClientRequestStore( + timeWindow: timeWindow, + expirationDuration: expirationDuration, + cleanupInterval: cleanupInterval + ) + } + + func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response { + let clientID = extractClientIdentifier(from: request) + let currentTime = Date() + + let requestCount = await store.incrementRequestCount(for: clientID, at: currentTime) + + if requestCount > maxRequests { + throw Abort(.tooManyRequests, reason: "Too many requests. Please try again later.") + } + + return try await next.respond(to: request) + } + + private func extractClientIdentifier(from request: Request) -> String { + if let userID = request.auth.get(User.self)?.id?.uuidString { + return "user:\(userID)" + } else if let sessionID = request.session.id?.string { + return "session:\(sessionID)" + } else { + return extractClientIP(from: request) + } + } + + private func extractClientIP(from request: Request) -> String { + // Try to get the client IP from the X-Forwarded-For header + if let forwardedFor = request.headers["X-Forwarded-For"].first { + // The X-Forwarded-For header can contain multiple IPs, the first one is the client's IP + if let clientIP = forwardedFor.split(separator: ",").first?.trimmingCharacters(in: .whitespaces) { + return clientIP + } + } + + // Fallback to the remote address + return request.remoteAddress?.ipAddress ?? "unknown" + } +}