Skip to content

Commit d0cd970

Browse files
committed
fix/security headers now adds cache-control and pragma
1 parent 3e82d1b commit d0cd970

3 files changed

Lines changed: 70 additions & 0 deletions

File tree

app/middleware/security_headers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,19 @@
1010
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
1111
"""Ensure every response carries mandatory security headers."""
1212

13+
_NO_STORE_EXACT_PATHS = frozenset(
14+
{
15+
"/auth/csrf",
16+
"/auth/login",
17+
"/auth/token",
18+
"/auth/logout",
19+
"/auth/reauth",
20+
"/auth/otp/verify/login",
21+
"/auth/otp/verify/action",
22+
"/auth/oauth/google/callback",
23+
"/auth/saml/callback",
24+
}
25+
)
1326
_DEFAULT_CONTENT_SECURITY_POLICY = (
1427
"default-src 'none'; frame-ancestors 'none'; base-uri 'none'; form-action 'self'"
1528
)
@@ -39,6 +52,9 @@ def _headers_for_path(cls, path: str) -> dict[str, str]:
3952
if path.startswith("/docs")
4053
else cls._DEFAULT_CONTENT_SECURITY_POLICY
4154
)
55+
if path in cls._NO_STORE_EXACT_PATHS:
56+
headers["Cache-Control"] = "no-store"
57+
headers["Pragma"] = "no-cache"
4258
return headers
4359

4460
async def dispatch(self, request: Request, call_next) -> Response:

tests/integration/test_auth_router_real.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
from app.core.jwt import get_jwt_service
1010

1111

12+
def _assert_no_store_headers(headers: dict[str, str]) -> None:
13+
"""Assert auth-state responses are marked as non-cacheable."""
14+
assert headers.get("cache-control") == "no-store"
15+
assert headers.get("pragma") == "no-cache"
16+
17+
1218
@pytest.mark.asyncio
1319
async def test_auth_login_refresh_logout_happy_path(
1420
app_factory,
@@ -27,6 +33,7 @@ async def test_auth_login_refresh_logout_happy_path(
2733
json={"email": "alice@example.com", "password": "Password123!"},
2834
)
2935
assert login_response.status_code == 200
36+
_assert_no_store_headers(dict(login_response.headers))
3037
login_payload = login_response.json()
3138
assert login_payload["access_token"]
3239
assert login_payload["refresh_token"]
@@ -46,6 +53,7 @@ async def test_auth_login_refresh_logout_happy_path(
4653
json={"refresh_token": login_payload["refresh_token"]},
4754
)
4855
assert refresh_response.status_code == 200
56+
_assert_no_store_headers(dict(refresh_response.headers))
4957
refresh_payload = refresh_response.json()
5058
assert refresh_payload["refresh_token"] != login_payload["refresh_token"]
5159
refresh_access_claims = jwt_service.verify_token(
@@ -64,12 +72,14 @@ async def test_auth_login_refresh_logout_happy_path(
6472
headers={"authorization": f"Bearer {refresh_payload['access_token']}"},
6573
)
6674
assert logout_response.status_code == 204
75+
_assert_no_store_headers(dict(logout_response.headers))
6776

6877
refresh_after_logout = await client.post(
6978
"/auth/token",
7079
json={"refresh_token": refresh_payload["refresh_token"]},
7180
)
7281
assert refresh_after_logout.status_code == 401
82+
_assert_no_store_headers(dict(refresh_after_logout.headers))
7383
assert refresh_after_logout.json()["code"] == "session_expired"
7484

7585

@@ -88,6 +98,7 @@ async def test_auth_cookie_login_refresh_logout_happy_path(
8898
) as client:
8999
csrf_response = await client.get("/auth/csrf")
90100
assert csrf_response.status_code == 200
101+
_assert_no_store_headers(dict(csrf_response.headers))
91102
csrf_token = csrf_response.json()["csrf_token"]
92103

93104
login_response = await client.post(
@@ -99,6 +110,7 @@ async def test_auth_cookie_login_refresh_logout_happy_path(
99110
},
100111
)
101112
assert login_response.status_code == 200
113+
_assert_no_store_headers(dict(login_response.headers))
102114
assert login_response.json() == {
103115
"authenticated": True,
104116
"session_transport": "cookie",
@@ -122,6 +134,7 @@ async def test_auth_cookie_login_refresh_logout_happy_path(
122134
},
123135
)
124136
assert refresh_response.status_code == 200
137+
_assert_no_store_headers(dict(refresh_response.headers))
125138
assert refresh_response.json() == {
126139
"authenticated": True,
127140
"session_transport": "cookie",
@@ -146,6 +159,7 @@ async def test_auth_cookie_login_refresh_logout_happy_path(
146159
},
147160
)
148161
assert logout_response.status_code == 204
162+
_assert_no_store_headers(dict(logout_response.headers))
149163

150164
replacement_csrf = await client.get("/auth/csrf")
151165
assert replacement_csrf.status_code == 200

tests/integration/test_middleware_stack.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
"x-content-type-options": "nosniff",
2222
"strict-transport-security": "max-age=63072000; includeSubDomains; preload",
2323
}
24+
_NO_STORE_HEADERS = {
25+
"cache-control": "no-store",
26+
"pragma": "no-cache",
27+
}
2428

2529

2630
class _InMemoryRateLimitRedis:
@@ -100,6 +104,10 @@ async def server_error() -> None:
100104
async def login() -> dict[str, bool]:
101105
return {"ok": True}
102106

107+
@app.post("/auth/token")
108+
async def token() -> dict[str, bool]:
109+
return {"ok": True}
110+
103111
return app
104112

105113

@@ -109,6 +117,12 @@ def _assert_security_headers(headers: dict[str, str]) -> None:
109117
assert headers.get(header_name) == expected_value
110118

111119

120+
def _assert_no_store_headers(headers: dict[str, str]) -> None:
121+
"""Assert token-bearing responses are marked as non-cacheable."""
122+
for header_name, expected_value in _NO_STORE_HEADERS.items():
123+
assert headers.get(header_name) == expected_value
124+
125+
112126
@pytest.mark.asyncio
113127
async def test_headers_present_on_success_and_error_responses() -> None:
114128
"""Correlation ID and security headers are present on 2xx/4xx/5xx."""
@@ -128,10 +142,34 @@ async def test_headers_present_on_success_and_error_responses() -> None:
128142
assert client_error.status_code == 401
129143
assert client_error.headers.get("x-correlation-id")
130144
_assert_security_headers(dict(client_error.headers))
145+
for header_name in _NO_STORE_HEADERS:
146+
assert header_name not in client_error.headers
131147

132148
assert server_error.status_code == 500
133149
assert server_error.headers.get("x-correlation-id")
134150
_assert_security_headers(dict(server_error.headers))
151+
for header_name in _NO_STORE_HEADERS:
152+
assert header_name not in server_error.headers
153+
154+
155+
@pytest.mark.asyncio
156+
async def test_token_bearing_auth_routes_are_marked_no_store() -> None:
157+
"""Login and token routes include Cache-Control no-store and Pragma no-cache."""
158+
app = _build_test_app()
159+
160+
async with AsyncClient(
161+
transport=ASGITransport(app=app), base_url="http://testserver"
162+
) as client:
163+
login_response = await client.post("/auth/login")
164+
token_response = await client.post("/auth/token")
165+
166+
assert login_response.status_code == 200
167+
_assert_security_headers(dict(login_response.headers))
168+
_assert_no_store_headers(dict(login_response.headers))
169+
170+
assert token_response.status_code == 200
171+
_assert_security_headers(dict(token_response.headers))
172+
_assert_no_store_headers(dict(token_response.headers))
135173

136174

137175
@pytest.mark.asyncio
@@ -148,6 +186,8 @@ async def test_rate_limit_rejects_auth_login_with_required_error_code() -> None:
148186
assert first.status_code == 200
149187
assert second.status_code == 429
150188
assert second.json() == {"detail": "Rate limit exceeded.", "code": "rate_limited"}
189+
_assert_no_store_headers(dict(first.headers))
190+
_assert_no_store_headers(dict(second.headers))
151191

152192

153193
@pytest.mark.asyncio

0 commit comments

Comments
 (0)