Skip to content

Commit ea5ccb7

Browse files
committed
fix/centralizing anchor rendering in html link message
1 parent b2a6692 commit ea5ccb7

5 files changed

Lines changed: 148 additions & 48 deletions

File tree

app/config.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
_UNSET = object()
2525
_PENDING_SINGLETON_CLEANUPS: deque[tuple[str, Callable[[Any], Any], Any]] = deque()
2626
_PENDING_SINGLETON_CLEANUPS_LOCK = Lock()
27+
_SCHEDULED_SINGLETON_CLEANUPS: set[asyncio.Task[None]] = set()
28+
_SCHEDULED_SINGLETON_CLEANUPS_LOCK = Lock()
2729
_REGISTERED_SINGLETON_CLEARERS: list[Callable[[], None]] = []
2830
_REGISTERED_SINGLETON_CLEARERS_LOCK = Lock()
2931
_cleanup_logger = logging.getLogger(__name__)
@@ -410,6 +412,18 @@ async def _run_singleton_cleanup(name: str, cleanup: Callable[[Any], Any], value
410412
_cleanup_logger.exception("reloadable_singleton_cleanup_failed", extra={"name": name})
411413

412414

415+
def _track_singleton_cleanup_task(task: asyncio.Task[None]) -> None:
416+
"""Track one in-loop singleton cleanup task until it finishes."""
417+
with _SCHEDULED_SINGLETON_CLEANUPS_LOCK:
418+
_SCHEDULED_SINGLETON_CLEANUPS.add(task)
419+
420+
def _discard(completed_task: asyncio.Task[None]) -> None:
421+
with _SCHEDULED_SINGLETON_CLEANUPS_LOCK:
422+
_SCHEDULED_SINGLETON_CLEANUPS.discard(completed_task)
423+
424+
task.add_done_callback(_discard)
425+
426+
413427
def _schedule_singleton_cleanup(name: str, cleanup: Callable[[Any], Any], value: Any) -> None:
414428
"""Run cleanup immediately when possible or enqueue it for later draining."""
415429
try:
@@ -419,21 +433,50 @@ def _schedule_singleton_cleanup(name: str, cleanup: Callable[[Any], Any], value:
419433
_PENDING_SINGLETON_CLEANUPS.append((name, cleanup, value))
420434
return
421435

422-
loop.create_task(_run_singleton_cleanup(name, cleanup, value))
436+
task = loop.create_task(_run_singleton_cleanup(name, cleanup, value))
437+
_track_singleton_cleanup_task(task)
438+
439+
440+
async def _await_scheduled_singleton_cleanups() -> None:
441+
"""Await tracked cleanup tasks bound to the current running loop."""
442+
running_loop = asyncio.get_running_loop()
443+
444+
while True:
445+
with _SCHEDULED_SINGLETON_CLEANUPS_LOCK:
446+
tasks = [
447+
task
448+
for task in _SCHEDULED_SINGLETON_CLEANUPS
449+
if not task.done() and task.get_loop() is running_loop
450+
]
451+
452+
if not tasks:
453+
return
454+
455+
await asyncio.gather(*tasks)
423456

424457

425458
async def flush_pending_singleton_cleanups() -> None:
426-
"""Drain any deferred singleton cleanup work."""
459+
"""Drain deferred and in-flight singleton cleanup work."""
427460
while True:
428461
with _PENDING_SINGLETON_CLEANUPS_LOCK:
429-
if not _PENDING_SINGLETON_CLEANUPS:
430-
return
431462
pending = list(_PENDING_SINGLETON_CLEANUPS)
432463
_PENDING_SINGLETON_CLEANUPS.clear()
433464

434465
for name, cleanup, value in pending:
435466
await _run_singleton_cleanup(name, cleanup, value)
436467

468+
await _await_scheduled_singleton_cleanups()
469+
470+
with _PENDING_SINGLETON_CLEANUPS_LOCK:
471+
has_pending = bool(_PENDING_SINGLETON_CLEANUPS)
472+
with _SCHEDULED_SINGLETON_CLEANUPS_LOCK:
473+
has_scheduled = any(
474+
not task.done() and task.get_loop() is asyncio.get_running_loop()
475+
for task in _SCHEDULED_SINGLETON_CLEANUPS
476+
)
477+
if not has_pending and not has_scheduled:
478+
return
479+
437480

438481
def clear_reloadable_singletons() -> None:
439482
"""Clear all registered reloadable singleton caches."""

app/services/lifecycle_service.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import asyncio
6+
import html
67
import secrets
78
import smtplib
89
from dataclasses import dataclass
@@ -102,9 +103,9 @@ async def send_verification_email(self, to_email: str, verification_link: str) -
102103
to_email=to_email,
103104
subject=subject,
104105
body=f"Open this link to verify your account: {verification_link}",
105-
html_body=(
106-
"<p>Open this link to verify your account:</p>"
107-
f'<p><a href="{verification_link}">{verification_link}</a></p>'
106+
html_body=self._html_link_message(
107+
intro_text="Open this link to verify your account:",
108+
link=verification_link,
108109
),
109110
)
110111

@@ -115,9 +116,9 @@ async def send_password_reset_email(self, to_email: str, reset_link: str) -> Non
115116
to_email=to_email,
116117
subject="Reset your password",
117118
body=f"Open this link to reset your password: {reset_link}",
118-
html_body=(
119-
"<p>Open this link to reset your password:</p>"
120-
f'<p><a href="{reset_link}">{reset_link}</a></p>'
119+
html_body=self._html_link_message(
120+
intro_text="Open this link to reset your password:",
121+
link=reset_link,
121122
),
122123
)
123124

@@ -149,6 +150,13 @@ def _send_blocking(
149150
with smtplib.SMTP(self.host, self.port, timeout=10) as smtp:
150151
smtp.send_message(message)
151152

153+
@staticmethod
154+
def _html_link_message(*, intro_text: str, link: str) -> str:
155+
"""Build one HTML paragraph pair with a safely escaped anchor."""
156+
escaped_intro = html.escape(intro_text, quote=False)
157+
escaped_link = html.escape(link, quote=True)
158+
return f'<p>{escaped_intro}</p><p><a href="{escaped_link}">{escaped_link}</a></p>'
159+
152160

153161
class LifecycleService:
154162
"""Lifecycle orchestration for signup verification and resend flows."""

tests/integration/conftest.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -96,39 +96,6 @@ def _clear_dependency_caches() -> None:
9696
get_webhook_service.cache_clear()
9797

9898

99-
async def _close_async_client(client: Any) -> None:
100-
"""Close async client instances regardless of redis-py close API version."""
101-
close = getattr(client, "aclose", None)
102-
if callable(close):
103-
await close()
104-
return
105-
106-
close = getattr(client, "close", None)
107-
if callable(close):
108-
result = close()
109-
if hasattr(result, "__await__"):
110-
await result
111-
112-
113-
async def _dispose_async_singletons() -> None:
114-
"""Dispose loop-bound async resources before changing event loops."""
115-
from app.core.sessions import get_redis_client
116-
from app.db.session import dispose_engine, get_engine
117-
from app.middleware.rate_limit import get_rate_limit_redis_client
118-
119-
redis_client = get_redis_client() if get_redis_client.cache_info().currsize else None
120-
rate_limit_client = (
121-
get_rate_limit_redis_client() if get_rate_limit_redis_client.cache_info().currsize else None
122-
)
123-
124-
if redis_client is not None:
125-
await _close_async_client(redis_client)
126-
if rate_limit_client is not None and rate_limit_client is not redis_client:
127-
await _close_async_client(rate_limit_client)
128-
if get_engine.cache_info().currsize:
129-
await dispose_engine()
130-
131-
13299
def _redis_connection_url(redis: RedisContainer) -> str:
133100
"""Return a redis:// URL across testcontainers versions."""
134101
get_url = getattr(redis, "get_connection_url", None)
@@ -307,6 +274,7 @@ async def reset_state(
307274
) -> Iterator[None]:
308275
"""Clear DB tables and flush Redis; isolate async singletons per event loop."""
309276
del integration_env
277+
from app.config import shutdown_reloadable_singletons
310278
from app.core.sessions import get_redis_client
311279
from app.db.session import get_session_factory
312280
from app.models.api_key import APIKey
@@ -318,8 +286,7 @@ async def reset_state(
318286
from app.models.webhook_delivery import WebhookDelivery
319287
from app.models.webhook_endpoint import WebhookEndpoint
320288

321-
await _dispose_async_singletons()
322-
_clear_dependency_caches()
289+
await shutdown_reloadable_singletons()
323290

324291
session_factory = get_session_factory()
325292
async with session_factory() as session:
@@ -339,8 +306,7 @@ async def reset_state(
339306
try:
340307
yield
341308
finally:
342-
await _dispose_async_singletons()
343-
_clear_dependency_caches()
309+
await shutdown_reloadable_singletons()
344310

345311

346312
@pytest.fixture(scope="function")

tests/unit/test_lifecycle_service_edge_cases.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,59 @@ def send_message(self, message) -> None: # type: ignore[no-untyped-def]
237237
assert messages[1]["Subject"] == "Your password has been reset"
238238

239239

240+
@pytest.mark.asyncio
241+
async def test_mailhog_sender_escapes_links_in_html_email_parts(monkeypatch) -> None:
242+
"""Lifecycle emails HTML-escape links before interpolating them into anchor tags."""
243+
sender = MailhogVerificationEmailSender(
244+
host="mailhog", port=1025, email_from="from@example.com"
245+
)
246+
messages: list[object] = []
247+
248+
async def _fake_to_thread(func, **kwargs): # type: ignore[no-untyped-def]
249+
return func(**kwargs)
250+
251+
class _SMTP:
252+
def __init__(self, host: str, port: int, timeout: int) -> None:
253+
assert host == "mailhog"
254+
assert port == 1025
255+
assert timeout == 10
256+
257+
def __enter__(self) -> _SMTP:
258+
return self
259+
260+
def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[no-untyped-def]
261+
return None
262+
263+
def send_message(self, message) -> None: # type: ignore[no-untyped-def]
264+
messages.append(message)
265+
266+
monkeypatch.setattr("app.services.lifecycle_service.asyncio.to_thread", _fake_to_thread)
267+
monkeypatch.setattr("app.services.lifecycle_service.smtplib.SMTP", _SMTP)
268+
269+
verification_link = 'https://verify.example/token?value="quoted"&x=<tag>'
270+
reset_link = 'https://reset.example/token?value="quoted"&x=<tag>'
271+
272+
await sender.send_verification_email("user@example.com", verification_link)
273+
await sender.send_password_reset_email("user@example.com", reset_link)
274+
275+
verification_html = messages[0].get_body(preferencelist=("html",)).get_content()
276+
reset_html = messages[1].get_body(preferencelist=("html",)).get_content()
277+
278+
assert (
279+
'href="https://verify.example/token?value=&quot;quoted&quot;&amp;x=&lt;tag&gt;"'
280+
in verification_html
281+
)
282+
assert (
283+
">https://verify.example/token?value=&quot;quoted&quot;&amp;x=&lt;tag&gt;<"
284+
in verification_html
285+
)
286+
assert (
287+
'href="https://reset.example/token?value=&quot;quoted&quot;&amp;x=&lt;tag&gt;"'
288+
in reset_html
289+
)
290+
assert ">https://reset.example/token?value=&quot;quoted&quot;&amp;x=&lt;tag&gt;<" in reset_html
291+
292+
240293
@pytest.mark.asyncio
241294
async def test_signup_verify_and_resend_cover_rollbacks_and_invalid_paths() -> None:
242295
"""Lifecycle signup/verify/resend cover rollback and invalid-token branches."""

tests/unit/test_reloadable_singletons.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import app.core.sessions as sessions_module
1010
import app.db.session as db_session_module
1111
import app.middleware.rate_limit as rate_limit_module
12-
from app.config import get_settings
12+
from app.config import flush_pending_singleton_cleanups, get_settings, reloadable_singleton
1313

1414

1515
class _AsyncRedisClientStub:
@@ -230,3 +230,33 @@ def _from_url(url: str, **kwargs: object) -> _AsyncRedisClientStub:
230230
assert second._access_token_ttl_seconds == 450
231231
assert first._refresh_token_ttl_seconds == 600
232232
assert second._refresh_token_ttl_seconds == 1200
233+
234+
235+
@pytest.mark.asyncio
236+
async def test_flush_pending_singleton_cleanups_waits_for_in_loop_cleanup_tasks() -> None:
237+
"""Flushing singleton cleanups should await tasks already scheduled on the active loop."""
238+
cleanup_started = asyncio.Event()
239+
release_cleanup = asyncio.Event()
240+
cleaned_values: list[object] = []
241+
value = object()
242+
243+
async def _cleanup(stale_value: object) -> None:
244+
cleanup_started.set()
245+
await release_cleanup.wait()
246+
cleaned_values.append(stale_value)
247+
248+
@reloadable_singleton(cleanup=_cleanup)
249+
def _get_singleton() -> object:
250+
return value
251+
252+
_get_singleton()
253+
_get_singleton.cache_clear() # type: ignore[attr-defined]
254+
await cleanup_started.wait()
255+
256+
flush_task = asyncio.create_task(flush_pending_singleton_cleanups())
257+
await asyncio.sleep(0)
258+
259+
assert not flush_task.done()
260+
release_cleanup.set()
261+
await flush_task
262+
assert cleaned_values == [value]

0 commit comments

Comments
 (0)