1- from fastapi import Request , Response
2- from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
3- from loguru import logger
4- from time import time
1+ from collections .abc import Sequence
2+ from datetime import datetime
53from sys import getsizeof
6- from typing import List
4+ from time import time
5+ from typing import Final
6+
7+ from fastapi import FastAPI , Request , Response
8+ from loguru import logger
9+ from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
10+
711
8- #add any endpoints we don't want logged here, e.g "/example/endpoint"
9- excluded_endpoints : List [str ] = []
1012class LoggerMiddleware (BaseHTTPMiddleware ):
1113 """Middleware that logs the request and response"""
1214
15+ def __init__ (self , app : FastAPI , excluded_endpoints : Sequence [str ] = ()) -> None :
16+ super ().__init__ (app )
17+ self .excluded_endpoints = excluded_endpoints
18+
1319 async def dispatch (self , request : Request , call_next : RequestResponseEndpoint ) -> Response :
20+ """Logs the request and response"""
21+ request_time = datetime .now ().strftime ("%Y-%m-%d %H:%M:%S.%f" )[:- 3 ]
1422 start_time = time ()
1523 response = await call_next (request )
1624 process_time = time () - start_time
17-
18- if request .url .path in excluded_endpoints :
25+
26+ if request .url .path in self . excluded_endpoints :
1927 return response
20-
28+
2129 request_body = await request .body ()
2230 request_size = getsizeof (request_body )
2331 # TODO: update this based on userID header name
2432 request_user_id = request .headers .get ("user_id" , "Anonymous" )
2533 request_params = dict (request .query_params )
2634
27- logger .info (f"REQUEST | Method: { request .method } | URL: { request .url .path } | User id: { request_user_id } | Params: { request_params } | Size: { request_size } bytes." )
28-
29- if response .status_code >= 500 :
35+ logger .info (
36+ " | " .join (
37+ [
38+ f"REQUEST | Method: { request .method } " ,
39+ f"URL: { request .url .path } " ,
40+ f"User id: { request_user_id } " ,
41+ f"Params: { request_params } " ,
42+ f"Time: { request_time } " ,
43+ f"Size: { request_size } bytes." ,
44+ ]
45+ )
46+ )
47+
48+ http_status_code_error_server : Final [int ] = 500
49+ http_status_code_error : Final [int ] = 400
50+
51+ if response .status_code >= http_status_code_error_server :
3052 logger_severity = logger .critical
31- elif response .status_code >= 400 :
53+ elif response .status_code >= http_status_code_error :
3254 logger_severity = logger .error
3355 else :
3456 logger_severity = logger .info
35-
36- response_body = b'' .join ([chunk async for chunk in response .body_iterator ])
57+
58+ response_body = b"" .join ([chunk async for chunk in response .body_iterator ])
3759 response_size = getsizeof (response_body )
38-
39- logger_severity (f"RESPONSE | Status: { response .status_code } | Response: { response_body .decode (errors = 'ignore' )} | Size: { response_size } bytes | Time Elasped: { process_time :.3f} ." )
40-
60+
61+ logger_severity (
62+ " | " .join (
63+ [
64+ f"RESPONSE | Status: { response .status_code } " ,
65+ f"Response: { response_body .decode (errors = 'ignore' )} " ,
66+ f"Size: { response_size } bytes" ,
67+ f"Time Elasped: { process_time :.3f} ." ,
68+ ]
69+ )
70+ )
71+
4172 return Response (
42- content = response_body ,
43- status_code = response .status_code ,
44- headers = dict (response .headers ),
73+ content = response_body ,
74+ status_code = response .status_code ,
75+ headers = dict (response .headers ),
4576 media_type = response .media_type ,
4677 )
47-
0 commit comments