diff --git a/packages/backend/src/helpers/__tests__/get-client-ip.test.ts b/packages/backend/src/helpers/__tests__/get-client-ip.test.ts new file mode 100644 index 0000000000..b70721205f --- /dev/null +++ b/packages/backend/src/helpers/__tests__/get-client-ip.test.ts @@ -0,0 +1,74 @@ +import type { Request } from 'express' +import { describe, expect, it } from 'vitest' + +import { getClientIp } from '../get-client-ip' + +describe('getClientIp', () => { + it('should return Cloudflare IP when cf-connecting-ip header is present', () => { + const mockReq = { + headers: { + 'cf-connecting-ip': '203.0.113.42', + }, + socket: { + remoteAddress: '192.168.1.1', + }, + } as unknown as Request + + expect(getClientIp(mockReq)).toBe('203.0.113.42') + }) + + it('should return socket remote address when no CF header', () => { + const mockReq = { + headers: {}, + socket: { + remoteAddress: '192.168.1.1', + }, + } as unknown as Request + + expect(getClientIp(mockReq)).toBe('192.168.1.1') + }) + + it('should trim and return first IP when multiple IPs in remote address', () => { + const mockReq = { + headers: {}, + socket: { + remoteAddress: '192.168.1.1, 10.0.0.1, 172.16.0.1', + }, + } as unknown as Request + + expect(getClientIp(mockReq)).toBe('192.168.1.1') + }) + + it('should return "unknown" when no IP information is available', () => { + const mockReq = { + headers: {}, + socket: {}, + } as unknown as Request + + expect(getClientIp(mockReq)).toBe('unknown') + }) + + it('should prioritize CF header over socket address', () => { + const mockReq = { + headers: { + 'cf-connecting-ip': '203.0.113.1', + }, + socket: { + remoteAddress: '192.168.1.100', + }, + } as unknown as Request + + expect(getClientIp(mockReq)).toBe('203.0.113.1') + }) + + it('should handle IPv6 addresses', () => { + const mockReq = { + headers: {}, + socket: { + remoteAddress: '2001:db8::1', + }, + } as unknown as Request + + expect(getClientIp(mockReq)).toBe('2001:db8::1') + }) +}) diff --git a/packages/backend/src/helpers/authentication.ts b/packages/backend/src/helpers/authentication.ts index ffcdd83b09..49393f88d2 100644 --- a/packages/backend/src/helpers/authentication.ts +++ b/packages/backend/src/helpers/authentication.ts @@ -8,6 +8,7 @@ import { getLoggedInUser, parseAdminToken, } from '@/helpers/auth' +import { getClientIp } from '@/helpers/get-client-ip' import { UnauthenticatedContext } from '@/types/express/context' export const setCurrentUserContext = async ({ @@ -64,11 +65,7 @@ const isAdminOperation = rule()( const rateLimitRule = createRateLimitRule({ identifyContext: (ctx: UnauthenticatedContext) => { - // get ip address of request in this order: cf-connecting-ip -> remoteAddress - const userIp = - (ctx.req.headers['cf-connecting-ip'] as string) || - ctx.req.socket.remoteAddress.split(',')[0].trim() - return userIp + return getClientIp(ctx.req) }, // recommended flag: https://github.com/teamplanes/graphql-rate-limit#enablebatchrequestcache enableBatchRequestCache: true, diff --git a/packages/backend/src/helpers/get-client-ip.ts b/packages/backend/src/helpers/get-client-ip.ts new file mode 100644 index 0000000000..c91af01826 --- /dev/null +++ b/packages/backend/src/helpers/get-client-ip.ts @@ -0,0 +1,23 @@ +import type { Request } from 'express' + +/** + * Get client IP address from request. + * Checks in this order: + * 1. Cloudflare header (cf-connecting-ip) + * 2. Socket remote address + * + * This is the same logic used in GraphQL authentication. + */ +export function getClientIp(req: Request): string { + const cfIp = req.headers['cf-connecting-ip'] as string + if (cfIp) { + return cfIp + } + + const remoteAddress = req.socket.remoteAddress + if (remoteAddress) { + return remoteAddress.split(',')[0].trim() + } + + return 'unknown' +} diff --git a/packages/backend/src/routes/api/__tests__/middleware/authentication.test.ts b/packages/backend/src/routes/api/__tests__/middleware/authentication.test.ts new file mode 100644 index 0000000000..55b4aa3bc6 --- /dev/null +++ b/packages/backend/src/routes/api/__tests__/middleware/authentication.test.ts @@ -0,0 +1,243 @@ +import type { NextFunction, Request, Response } from 'express' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import type { UnauthenticatedContext } from '@/types/express/context' + +import { + getAuthenticatedContext, + requireAuthentication, + setCurrentUserContext, +} from '../../middleware/authentication' + +const mocks = vi.hoisted(() => ({ + setGraphQLContext: vi.fn(), +})) + +vi.mock('@/helpers/authentication', () => ({ + setCurrentUserContext: mocks.setGraphQLContext, +})) + +describe('API Authentication Middleware', () => { + let mockReq: Partial + let mockRes: Partial + let mockNext: NextFunction + + beforeEach(() => { + mockReq = { + headers: {}, + context: undefined, + } as Partial + + mockRes = { + status: vi.fn().mockReturnThis(), + json: vi.fn(), + } as Partial + + mockNext = vi.fn() + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe('setCurrentUserContext', () => { + it('should call GraphQL context function and attach context to request', async () => { + const mockContext: UnauthenticatedContext = { + req: mockReq as Request, + res: mockRes as Response, + currentUser: { + id: 'test-user-id', + email: 'test@plumber.gov.sg', + } as any, + isAdminOperation: false, + } + + mocks.setGraphQLContext.mockResolvedValueOnce(mockContext) + + await setCurrentUserContext( + mockReq as Request, + mockRes as Response, + mockNext, + ) + + expect(mocks.setGraphQLContext).toHaveBeenCalledWith({ + req: mockReq, + res: mockRes, + }) + expect(mockReq.context).toEqual(mockContext) + expect(mockNext).toHaveBeenCalledOnce() + }) + + it('should handle admin operations', async () => { + const mockContext: UnauthenticatedContext = { + req: mockReq as Request, + res: mockRes as Response, + currentUser: { + id: 'admin-user-id', + email: 'admin@plumber.gov.sg', + } as any, + isAdminOperation: true, + } + + mocks.setGraphQLContext.mockResolvedValueOnce(mockContext) + + await setCurrentUserContext( + mockReq as Request, + mockRes as Response, + mockNext, + ) + + expect(mockReq.context?.isAdminOperation).toBe(true) + expect(mockNext).toHaveBeenCalledOnce() + }) + + it('should attach context even when user is null', async () => { + const mockContext: UnauthenticatedContext = { + req: mockReq as Request, + res: mockRes as Response, + currentUser: null, + isAdminOperation: false, + } + + mocks.setGraphQLContext.mockResolvedValueOnce(mockContext) + + await setCurrentUserContext( + mockReq as Request, + mockRes as Response, + mockNext, + ) + + expect(mockReq.context).toEqual(mockContext) + expect(mockReq.context?.currentUser).toBeNull() + expect(mockNext).toHaveBeenCalledOnce() + }) + }) + + describe('requireAuthentication', () => { + it('should call next() when user is authenticated', () => { + mockReq.context = { + req: mockReq as Request, + res: mockRes as Response, + currentUser: { + id: 'test-user-id', + email: 'test@plumber.gov.sg', + } as any, + isAdminOperation: false, + } + + requireAuthentication(mockReq as Request, mockRes as Response, mockNext) + + expect(mockNext).toHaveBeenCalledOnce() + expect(mockRes.status).not.toHaveBeenCalled() + expect(mockRes.json).not.toHaveBeenCalled() + }) + + it('should return 401 when context is undefined', () => { + mockReq.context = undefined + + requireAuthentication(mockReq as Request, mockRes as Response, mockNext) + + expect(mockRes.status).toHaveBeenCalledWith(401) + expect(mockRes.json).toHaveBeenCalledWith({ + error: 'Not Authorised!', + }) + expect(mockNext).not.toHaveBeenCalled() + }) + + it('should return 401 when currentUser is null', () => { + mockReq.context = { + req: mockReq as Request, + res: mockRes as Response, + currentUser: null, + isAdminOperation: false, + } + + requireAuthentication(mockReq as Request, mockRes as Response, mockNext) + + expect(mockRes.status).toHaveBeenCalledWith(401) + expect(mockRes.json).toHaveBeenCalledWith({ + error: 'Not Authorised!', + }) + expect(mockNext).not.toHaveBeenCalled() + }) + + it('should allow admin operations through', () => { + mockReq.context = { + req: mockReq as Request, + res: mockRes as Response, + currentUser: { + id: 'admin-user-id', + email: 'admin@plumber.gov.sg', + } as any, + isAdminOperation: true, + } + + requireAuthentication(mockReq as Request, mockRes as Response, mockNext) + + expect(mockNext).toHaveBeenCalledOnce() + expect(mockRes.status).not.toHaveBeenCalled() + }) + }) + + describe('getAuthenticatedContext', () => { + it('should return context when user is authenticated', () => { + const mockContext = { + req: mockReq as Request, + res: mockRes as Response, + currentUser: { + id: 'test-user-id', + email: 'test@plumber.gov.sg', + } as any, + isAdminOperation: false, + } + + mockReq.context = mockContext + + const result = getAuthenticatedContext(mockReq as Request) + + expect(result).toEqual(mockContext) + expect(result.currentUser).toBeDefined() + expect(result.currentUser.id).toBe('test-user-id') + }) + + it('should throw error when context is undefined', () => { + mockReq.context = undefined + + expect(() => getAuthenticatedContext(mockReq as Request)).toThrow( + 'User must be authenticated', + ) + }) + + it('should throw error when currentUser is null', () => { + mockReq.context = { + req: mockReq as Request, + res: mockRes as Response, + currentUser: null, + isAdminOperation: false, + } + + expect(() => getAuthenticatedContext(mockReq as Request)).toThrow( + 'User must be authenticated', + ) + }) + + it('should return context for admin users', () => { + const mockContext = { + req: mockReq as Request, + res: mockRes as Response, + currentUser: { + id: 'admin-user-id', + email: 'admin@plumber.gov.sg', + } as any, + isAdminOperation: true, + } + + mockReq.context = mockContext + + const result = getAuthenticatedContext(mockReq as Request) + + expect(result).toEqual(mockContext) + expect(result.isAdminOperation).toBe(true) + }) + }) +}) diff --git a/packages/backend/src/routes/api/__tests__/middleware/rate-limit.test.ts b/packages/backend/src/routes/api/__tests__/middleware/rate-limit.test.ts new file mode 100644 index 0000000000..a88f274c25 --- /dev/null +++ b/packages/backend/src/routes/api/__tests__/middleware/rate-limit.test.ts @@ -0,0 +1,201 @@ +import type { NextFunction, Request, Response } from 'express' +import { RateLimiterRes } from 'rate-limiter-flexible' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import { rateLimitApi } from '../../middleware/rate-limit' + +const mocks = vi.hoisted(() => ({ + rateLimiterRedis: { + consume: vi.fn(), + }, + logger: { + warn: vi.fn(), + error: vi.fn(), + }, + createRedisClient: vi.fn(), +})) + +vi.mock('rate-limiter-flexible', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + RateLimiterRedis: vi.fn(function () { + return mocks.rateLimiterRedis + }), + } +}) + +vi.mock('@/helpers/logger', () => ({ + default: mocks.logger, +})) + +vi.mock('@/config/redis', () => ({ + createRedisClient: mocks.createRedisClient, + REDIS_DB_INDEX: { + RATE_LIMIT: 'rate-limit', + }, +})) + +describe('Rate Limiting Middleware', () => { + let mockReq: Partial + let mockRes: Partial + let mockNext: NextFunction + + beforeEach(() => { + mockReq = { + headers: {}, + socket: { + remoteAddress: '127.0.0.1', + } as any, + context: { + currentUser: { + id: 'test-user-id', + email: 'test@plumber.gov.sg', + } as any, + isAdminOperation: false, + } as any, + } + + mockRes = { + status: vi.fn().mockReturnThis(), + json: vi.fn(), + } as Partial + + mockNext = vi.fn() + + vi.clearAllMocks() + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe('rateLimitApi', () => { + it('should allow request when under rate limit', async () => { + mocks.rateLimiterRedis.consume.mockResolvedValueOnce({}) + + await rateLimitApi(mockReq as Request, mockRes as Response, mockNext) + + expect(mocks.rateLimiterRedis.consume).toHaveBeenCalledWith( + 'test-user-id', + ) + expect(mockNext).toHaveBeenCalledOnce() + expect(mockRes.status).not.toHaveBeenCalled() + }) + + it('should use user ID for rate limiting when user is authenticated', async () => { + mocks.rateLimiterRedis.consume.mockResolvedValueOnce({}) + + await rateLimitApi(mockReq as Request, mockRes as Response, mockNext) + + expect(mocks.rateLimiterRedis.consume).toHaveBeenCalledWith( + 'test-user-id', + ) + expect(mockNext).toHaveBeenCalledOnce() + }) + + it('should use IP address when user is not authenticated', async () => { + mockReq.context = { + currentUser: null, + isAdminOperation: false, + } as any + mockReq.socket = { + remoteAddress: '192.168.1.1', + } as any + + mocks.rateLimiterRedis.consume.mockResolvedValueOnce({}) + + await rateLimitApi(mockReq as Request, mockRes as Response, mockNext) + + expect(mocks.rateLimiterRedis.consume).toHaveBeenCalledWith('192.168.1.1') + expect(mockNext).toHaveBeenCalledOnce() + }) + + it('should use Cloudflare IP when available', async () => { + mockReq.context = { + currentUser: null, + isAdminOperation: false, + } as any + mockReq.headers = { + 'cf-connecting-ip': '203.0.113.42', + } + + mocks.rateLimiterRedis.consume.mockResolvedValueOnce({}) + + await rateLimitApi(mockReq as Request, mockRes as Response, mockNext) + + expect(mocks.rateLimiterRedis.consume).toHaveBeenCalledWith( + '203.0.113.42', + ) + expect(mockNext).toHaveBeenCalledOnce() + }) + + it('should return 429 when rate limit is exceeded', async () => { + const rateLimitError = Object.assign(new RateLimiterRes(), { + msBeforeNext: 5000, + }) + + mocks.rateLimiterRedis.consume.mockRejectedValueOnce(rateLimitError) + + await rateLimitApi(mockReq as Request, mockRes as Response, mockNext) + + expect(mockRes.status).toHaveBeenCalledWith(429) + expect(mockRes.json).toHaveBeenCalledWith({ + error: 'Too many requests', + message: 'Rate limit exceeded. Please try again later.', + }) + expect(mockNext).not.toHaveBeenCalled() + }) + + it('should log rate limit violations', async () => { + const rateLimitError = Object.assign(new RateLimiterRes(), { + msBeforeNext: 3000, + }) + + mocks.rateLimiterRedis.consume.mockRejectedValueOnce(rateLimitError) + + await rateLimitApi(mockReq as Request, mockRes as Response, mockNext) + + expect(mocks.logger.warn).toHaveBeenCalledWith( + 'API endpoint rate limited', + expect.objectContaining({ + event: 'api-rate-limited', + userId: 'test-user-id', + remainingMs: 3000, + }), + ) + }) + + it('should handle errors gracefully and continue', async () => { + const genericError = new Error('Redis connection failed') + mocks.rateLimiterRedis.consume.mockRejectedValueOnce(genericError) + + await rateLimitApi(mockReq as Request, mockRes as Response, mockNext) + + expect(mocks.logger.error).toHaveBeenCalledWith( + 'Error in rate limiting middleware', + { error: genericError }, + ) + expect(mockNext).toHaveBeenCalledOnce() + expect(mockRes.status).not.toHaveBeenCalled() + }) + + it('should use IP fallback when no user is available', async () => { + mockReq.context = { + currentUser: null, + isAdminOperation: false, + } as any + mockReq.socket = { + remoteAddress: '10.0.0.1', + } as any + mockReq.headers = {} + + mocks.rateLimiterRedis.consume.mockResolvedValueOnce({}) + + await rateLimitApi(mockReq as Request, mockRes as Response, mockNext) + + expect(mocks.rateLimiterRedis.consume).toHaveBeenCalledWith('10.0.0.1') + expect(mockNext).toHaveBeenCalledOnce() + }) + }) +}) diff --git a/packages/backend/src/routes/api/index.ts b/packages/backend/src/routes/api/index.ts new file mode 100644 index 0000000000..c74e04a30b --- /dev/null +++ b/packages/backend/src/routes/api/index.ts @@ -0,0 +1,12 @@ +import { Router } from 'express' + +const router = Router() + +// Mount individual API routes + +// Future routes can be added here: +// router.use('/users', usersRouter) +// router.use('/analytics', analyticsRouter) +// etc. + +export default router diff --git a/packages/backend/src/routes/api/middleware/authentication.ts b/packages/backend/src/routes/api/middleware/authentication.ts new file mode 100644 index 0000000000..dfe8c8b0b7 --- /dev/null +++ b/packages/backend/src/routes/api/middleware/authentication.ts @@ -0,0 +1,51 @@ +import type { NextFunction, Request, Response } from 'express' + +import { setCurrentUserContext as setGraphQLContext } from '@/helpers/authentication' +import type Context from '@/types/express/context' + +/** + * Middleware to set the current user context on the request object. + * This reuses the same authentication logic as GraphQL mutations. + */ +export async function setCurrentUserContext( + req: Request, + res: Response, + next: NextFunction, +) { + // Reuse the GraphQL context creation logic + const context = await setGraphQLContext({ req, res }) + + // Attach context to the request object + req.context = context as Context + + next() +} + +/** + * Middleware to ensure the user is authenticated before allowing the request to proceed. + * Returns 401 if the user is not authenticated. + */ +export function requireAuthentication( + req: Request, + res: Response, + next: NextFunction, +) { + if (!req.context?.currentUser) { + res.status(401).json({ error: 'Not Authorised!' }) + return + } + + next() +} + +/** + * Type guard to ensure the request has an authenticated context. + * Use this in route handlers to get type-safe access to currentUser. + */ +export function getAuthenticatedContext(req: Request): Context { + if (!req.context?.currentUser) { + throw new Error('User must be authenticated') + } + + return req.context as Context +} diff --git a/packages/backend/src/routes/api/middleware/rate-limit.ts b/packages/backend/src/routes/api/middleware/rate-limit.ts new file mode 100644 index 0000000000..be557321b7 --- /dev/null +++ b/packages/backend/src/routes/api/middleware/rate-limit.ts @@ -0,0 +1,60 @@ +import type { NextFunction, Request, Response } from 'express' +import { RateLimiterRedis, RateLimiterRes } from 'rate-limiter-flexible' + +import { createRedisClient, REDIS_DB_INDEX } from '@/config/redis' +import { getClientIp } from '@/helpers/get-client-ip' +import logger from '@/helpers/logger' + +// Create rate limiter for API routes +// Allow 10 requests per minute per user/IP +// NOTE: this works now because we only have 1 route, +// it may need to be updated if we add more routes +const apiRateLimiter = new RateLimiterRedis({ + points: 10, // number of requests + duration: 60, // per 60 seconds (1 minute) + keyPrefix: 'api-rate', + storeClient: createRedisClient(REDIS_DB_INDEX.RATE_LIMIT), +}) + +/** + * Rate limiting middleware for API routes. + * Limits requests per user/IP to prevent abuse of expensive endpoints. + * + * Rate limit: 10 requests per minute per user/IP + * Uses the same IP identification logic as GraphQL authentication. + */ +export async function rateLimitApi( + req: Request, + res: Response, + next: NextFunction, +) { + // Use user ID if available, otherwise fall back to IP address + const userId = req.context?.currentUser?.id + const userIp = getClientIp(req) + + const rateLimitKey = userId || userIp + + try { + await apiRateLimiter.consume(rateLimitKey) + next() + } catch (error) { + if (error instanceof RateLimiterRes) { + logger.warn('API endpoint rate limited', { + event: 'api-rate-limited', + userId, + userIp, + remainingMs: error.msBeforeNext, + }) + + res.status(429).json({ + error: 'Too many requests', + message: 'Rate limit exceeded. Please try again later.', + }) + return + } + + // If it's not a rate limit error, log it and continue + logger.error('Error in rate limiting middleware', { error }) + next() + } +} diff --git a/packages/backend/src/routes/index.ts b/packages/backend/src/routes/index.ts index 524c651232..fc721120c6 100644 --- a/packages/backend/src/routes/index.ts +++ b/packages/backend/src/routes/index.ts @@ -2,10 +2,12 @@ import { Router } from 'express' import graphQLInstance from '@/helpers/graphql-instance' +import apiRouter from './api' import webhooksRouter from './webhooks' const router = Router() +router.use('/api', apiRouter) router.use('/graphql', graphQLInstance) router.use('/webhooks', webhooksRouter) diff --git a/packages/backend/src/types/express.d.ts b/packages/backend/src/types/express.d.ts new file mode 100644 index 0000000000..0f96fd537b --- /dev/null +++ b/packages/backend/src/types/express.d.ts @@ -0,0 +1,11 @@ +import type { UnauthenticatedContext } from './express/context' + +declare global { + namespace Express { + interface Request { + context?: UnauthenticatedContext + } + } +} + +export {} diff --git a/packages/frontend/vite.config.ts b/packages/frontend/vite.config.ts index bd1388c9b4..3dc89773bf 100644 --- a/packages/frontend/vite.config.ts +++ b/packages/frontend/vite.config.ts @@ -20,6 +20,11 @@ export default defineConfig({ changeOrigin: true, secure: false, }, + '/api': { + target: 'http://localhost:3000', + changeOrigin: true, + secure: false, + }, '/apps': { target: 'http://localhost:3000', changeOrigin: true,