diff --git a/h/models/notification.py b/h/models/notification.py index 6f37f03c912..32825dbb9ff 100644 --- a/h/models/notification.py +++ b/h/models/notification.py @@ -1,4 +1,4 @@ -import enum +from enum import StrEnum from sqlalchemy import ForeignKey, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -8,11 +8,20 @@ from h.models import helpers -class NotificationType(enum.StrEnum): +class NotificationType(StrEnum): MENTION = "mention" REPLY = "reply" +class EmailTag(StrEnum): + ACTIVATION = "activation" + FLAG_NOTIFICATION = "flag_notification" + REPLY_NOTIFICATION = "reply_notification" + RESET_PASSWORD = "reset_password" # noqa: S105 + MENTION_NOTIFICATION = "mention_notification" + TEST = "test" + + class Notification(Base, Timestamps): # pragma: no cover __tablename__ = "notification" diff --git a/h/services/__init__.py b/h/services/__init__.py index 88be8b3683c..c33f0dd22a6 100644 --- a/h/services/__init__.py +++ b/h/services/__init__.py @@ -16,9 +16,10 @@ from h.services.mention import MentionService from h.services.notification import NotificationService from h.services.subscription import SubscriptionService +from h.services.task_done import TaskDoneService -def includeme(config): # pragma: no cover +def includeme(config): # pragma: no cover # noqa: PLR0915 # Annotation related services config.register_service_factory( "h.services.annotation_delete.annotation_delete_service_factory", @@ -174,3 +175,6 @@ def includeme(config): # pragma: no cover "h.services.analytics.analytics_service_factory", name="analytics" ) config.register_service_factory("h.services.email.factory", iface=EmailService) + config.register_service_factory( + "h.services.task_done.factory", iface=TaskDoneService + ) diff --git a/h/services/email.py b/h/services/email.py index 7323cc1d26f..31a7fbb2ca9 100644 --- a/h/services/email.py +++ b/h/services/email.py @@ -1,9 +1,8 @@ # noqa: A005 import smtplib -from dataclasses import asdict, dataclass, field -from enum import StrEnum -from typing import Any +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta import pyramid_mailer import pyramid_mailer.message @@ -11,19 +10,14 @@ from pyramid_mailer import IMailer from sqlalchemy.orm import Session -from h.models import TaskDone +from h.models.notification import EmailTag +from h.services.task_done import TaskData, TaskDoneService from h.tasks.celery import get_task_logger logger = get_task_logger(__name__) - -class EmailTag(StrEnum): - ACTIVATION = "activation" - FLAG_NOTIFICATION = "flag_notification" - REPLY_NOTIFICATION = "reply_notification" - RESET_PASSWORD = "reset_password" # noqa: S105 - MENTION_NOTIFICATION = "mention_notification" - TEST = "test" +# Limit for the number of mention emails sent by a single user in a day to prevent abuse +DAILY_SENDER_MENTION_LIMIT = 100 @dataclass(frozen=True) @@ -49,27 +43,25 @@ def message(self) -> pyramid_mailer.message.Message: ) -@dataclass(frozen=True) -class TaskData: - tag: EmailTag - sender_id: int - recipient_ids: list[int] = field(default_factory=list) - extra: dict[str, Any] = field(default_factory=dict) - - @property - def formatted_extra(self) -> str: - return ", ".join(f"{k}={v!r}" for k, v in self.extra.items() if v is not None) - - class EmailService: """A service for sending emails.""" - def __init__(self, debug: bool, session: Session, mailer: IMailer) -> None: # noqa: FBT001 + def __init__( + self, + debug: bool, # noqa: FBT001 + session: Session, + mailer: IMailer, + task_done_service: TaskDoneService, + ) -> None: self._debug = debug self._session = session self._mailer = mailer + self._task_done_service = task_done_service def send(self, email_data: EmailData, task_data: TaskData) -> None: + if not self._allow_sending(task_data): + return + if self._debug: # pragma: no cover logger.info("emailing in debug mode: check the `mail/` directory") try: @@ -91,13 +83,33 @@ def send(self, email_data: EmailData, task_data: TaskData) -> None: separator, task_data.formatted_extra, ) - self._create_task_done(task_data) + self._task_done_service.create(task_data) + + def _allow_sending(self, task_data: TaskData) -> bool: + if ( + task_data.tag == EmailTag.MENTION_NOTIFICATION + and self._sender_limit_reached(task_data) + ): + logger.warning( + "Not sending email: tag=%r sender_id=%s recipient_ids=%s. Sender limit reached.", + task_data.tag, + task_data.sender_id, + task_data.recipient_ids, + ) + return False + return True - def _create_task_done(self, task_data: TaskData) -> None: - task_done = TaskDone(data=asdict(task_data)) - self._session.add(task_done) + def _sender_limit_reached(self, task_data: TaskData) -> bool: + after = datetime.now(UTC) - timedelta(days=1) + count = self._task_done_service.sender_mention_count(task_data.sender_id, after) + return count >= DAILY_SENDER_MENTION_LIMIT def factory(_context, request: Request) -> EmailService: mailer = pyramid_mailer.get_mailer(request) - return EmailService(debug=request.debug, session=request.db, mailer=mailer) + return EmailService( + debug=request.debug, + session=request.db, + mailer=mailer, + task_done_service=request.find_service(TaskDoneService), + ) diff --git a/h/services/task_done.py b/h/services/task_done.py new file mode 100644 index 00000000000..3f8016c9316 --- /dev/null +++ b/h/services/task_done.py @@ -0,0 +1,43 @@ +from dataclasses import asdict, dataclass, field +from datetime import datetime +from typing import Any + +from pyramid.request import Request +from sqlalchemy import select +from sqlalchemy.orm import Session +from sqlalchemy.sql.functions import count + +from h.models import TaskDone +from h.models.notification import EmailTag + + +@dataclass(frozen=True) +class TaskData: + tag: EmailTag + sender_id: int + recipient_ids: list[int] = field(default_factory=list) + extra: dict[str, Any] = field(default_factory=dict) + + @property + def formatted_extra(self) -> str: + return ", ".join(f"{k}={v!r}" for k, v in self.extra.items() if v is not None) + + +class TaskDoneService: + def __init__(self, session: Session) -> None: + self._session = session + + def create(self, task_data: TaskData) -> None: + task_done = TaskDone(data=asdict(task_data)) + self._session.add(task_done) + + def sender_mention_count(self, sender_id: str, after: datetime) -> int: + stmt = select(count(TaskDone.id)).where( + TaskDone.data["sender_id"].astext == str(sender_id), + TaskDone.created >= after, + ) + return self._session.execute(stmt).scalar_one() + + +def factory(_context, request: Request) -> TaskDoneService: + return TaskDoneService(session=request.db) diff --git a/tests/common/factories/task_done.py b/tests/common/factories/task_done.py index 541d59d837f..2a3a951d16a 100644 --- a/tests/common/factories/task_done.py +++ b/tests/common/factories/task_done.py @@ -15,3 +15,8 @@ class Meta: def expires_at(self, _create, extracted, **_kwargs): if extracted: self.expires_at = extracted + + @post_generation + def created(self, _create, extracted, **_kwargs): + if extracted: + self.created = extracted diff --git a/tests/common/fixtures/services.py b/tests/common/fixtures/services.py index dc0d0ebd5cf..585d2b4f60b 100644 --- a/tests/common/fixtures/services.py +++ b/tests/common/fixtures/services.py @@ -39,6 +39,7 @@ from h.services.organization import OrganizationService from h.services.search_index import SearchIndexService from h.services.subscription import SubscriptionService +from h.services.task_done import TaskDoneService from h.services.url_migration import URLMigrationService from h.services.user import UserService from h.services.user_delete import UserDeleteService @@ -85,6 +86,7 @@ "queue_service", "search_index", "subscription_service", + "task_done_service", "url_migration_service", "user_delete_service", "user_password_service", @@ -340,3 +342,8 @@ def feature_service(mock_service): @pytest.fixture def email_service(mock_service): return mock_service(EmailService) + + +@pytest.fixture +def task_done_service(mock_service): + return mock_service(TaskDoneService) diff --git a/tests/unit/h/services/email_test.py b/tests/unit/h/services/email_test.py index debc35c7bf7..bfba37d8d90 100644 --- a/tests/unit/h/services/email_test.py +++ b/tests/unit/h/services/email_test.py @@ -1,12 +1,17 @@ import smtplib -from dataclasses import asdict +from unittest import mock from unittest.mock import sentinel import pytest -from sqlalchemy import select -from h.models import TaskDone -from h.services.email import EmailData, EmailService, EmailTag, TaskData, factory +from h.services.email import ( + DAILY_SENDER_MENTION_LIMIT, + EmailData, + EmailService, + EmailTag, + TaskData, + factory, +) class TestEmailService: @@ -24,16 +29,17 @@ def test_send_creates_email_message( ) def test_send_creates_email_message_with_html_body( - self, task_data, email_service, pyramid_mailer + self, email_service, task_data, pyramid_mailer ): - email = EmailData( + email_data = EmailData( recipients=["foo@example.com"], subject="My email subject", body="Some text body", tag=EmailTag.TEST, html="

An HTML body

", ) - email_service.send(email, task_data) + + email_service.send(email_data, task_data) pyramid_mailer.message.Message.assert_called_once_with( recipients=["foo@example.com"], @@ -63,6 +69,48 @@ def test_send_creates_email_message_with_subaccount( extra_headers={"X-MC-Tags": EmailTag.TEST, "X-MC-Subaccount": "subaccount"}, ) + def test_send_creates_mention_email_when_sender_limit_not_reached( + self, + mention_email_data, + mention_task_data, + email_service, + pyramid_mailer, + task_done_service, + ): + task_done_service.sender_mention_count.return_value = ( + DAILY_SENDER_MENTION_LIMIT - 1 + ) + + email_service.send(mention_email_data, mention_task_data) + + task_done_service.sender_mention_count.assert_called_once_with( + mention_task_data.sender_id, mock.ANY + ) + pyramid_mailer.message.Message.assert_called_once_with( + recipients=["foo@example.com"], + subject="My email subject", + body="Some text body", + html=None, + extra_headers={"X-MC-Tags": EmailTag.MENTION_NOTIFICATION}, + ) + + def test_send_does_not_create_mention_email_when_sender_limit_reached( + self, + mention_email_data, + mention_task_data, + email_service, + pyramid_mailer, + task_done_service, + ): + task_done_service.sender_mention_count.return_value = DAILY_SENDER_MENTION_LIMIT + + email_service.send(mention_email_data, mention_task_data) + + task_done_service.sender_mention_count.assert_called_once_with( + mention_task_data.sender_id, mock.ANY + ) + pyramid_mailer.message.Message.assert_not_called() + def test_send_dispatches_email_using_request_mailer( self, email_data, task_data, email_service, pyramid_mailer ): @@ -90,34 +138,51 @@ def test_send_logging(self, email_data, task_data, email_service, info_caplog): ] def test_send_logging_with_extra(self, email_data, email_service, info_caplog): - user_id = 123 + sender_id = 123 + recipient_id = 124 annotation_id = "annotation_id" task_data = TaskData( tag=email_data.tag, - sender_id=user_id, - recipient_ids=[user_id], + sender_id=sender_id, + recipient_ids=[recipient_id], extra={"annotation_id": annotation_id}, ) + email_service.send(email_data, task_data) assert info_caplog.messages == [ - f"Sent email: tag={task_data.tag!r}, sender_id={user_id}, recipient_ids={[user_id]}, annotation_id={annotation_id!r}" + f"Sent email: tag={task_data.tag!r}, sender_id={sender_id}, recipient_ids={[recipient_id]}, annotation_id={annotation_id!r}" + ] + + def test_sender_limit_reached_logging( + self, + mention_email_data, + mention_task_data, + email_service, + task_done_service, + info_caplog, + ): + task_done_service.sender_mention_count.return_value = DAILY_SENDER_MENTION_LIMIT + + email_service.send(mention_email_data, mention_task_data) + + assert info_caplog.messages == [ + f"Not sending email: tag={mention_task_data.tag!r} sender_id={mention_task_data.sender_id} recipient_ids={mention_task_data.recipient_ids}. Sender limit reached." ] def test_send_creates_task_done( - self, email_data, task_data, email_service, db_session + self, email_data, task_data, email_service, task_done_service ): task_data = TaskData( tag=email_data.tag, sender_id=123, - recipient_ids=[123], + recipient_ids=[124], extra={"annotation_id": "annotation_id"}, ) + email_service.send(email_data, task_data) - task_dones = db_session.execute(select(TaskDone)).scalars().all() - assert len(task_dones) == 1 - assert task_dones[0].data == asdict(task_data) + task_done_service.create.assert_called_once_with(task_data) @pytest.fixture def email_data(self): @@ -133,13 +198,35 @@ def task_data(self): return TaskData( tag=EmailTag.TEST, sender_id=123, - recipient_ids=[123], + recipient_ids=[124], + ) + + @pytest.fixture + def mention_email_data(self): + return EmailData( + recipients=["foo@example.com"], + subject="My email subject", + body="Some text body", + tag=EmailTag.MENTION_NOTIFICATION, + ) + + @pytest.fixture + def mention_task_data(self): + return TaskData( + tag=EmailTag.MENTION_NOTIFICATION, + sender_id=123, + recipient_ids=[124], ) @pytest.fixture - def email_service(self, pyramid_request, pyramid_mailer): + def email_service(self, pyramid_request, pyramid_mailer, task_done_service): request_mailer = pyramid_mailer.get_mailer.return_value - return EmailService(pyramid_request.debug, pyramid_request.db, request_mailer) + return EmailService( + debug=pyramid_request.debug, + session=pyramid_request.db, + mailer=request_mailer, + task_done_service=task_done_service, + ) @pytest.fixture def info_caplog(self, caplog): @@ -148,13 +235,14 @@ def info_caplog(self, caplog): class TestFactory: - def test_it(self, pyramid_request, pyramid_mailer, EmailService): + def test_it(self, pyramid_request, pyramid_mailer, EmailService, task_done_service): service = factory(sentinel.context, pyramid_request) EmailService.assert_called_once_with( debug=pyramid_request.debug, session=pyramid_request.db, mailer=pyramid_mailer.get_mailer.return_value, + task_done_service=task_done_service, ) assert service == EmailService.return_value diff --git a/tests/unit/h/services/task_done_test.py b/tests/unit/h/services/task_done_test.py new file mode 100644 index 00000000000..c0424e80a49 --- /dev/null +++ b/tests/unit/h/services/task_done_test.py @@ -0,0 +1,54 @@ +from dataclasses import asdict +from datetime import datetime, timedelta +from unittest.mock import sentinel + +import pytest +from sqlalchemy import select + +from h.models import TaskDone +from h.models.notification import EmailTag +from h.services.task_done import TaskData, TaskDoneService, factory + + +class TestTaskDoneService: + def test_create(self, task_done_service, db_session): + task_data = TaskData( + tag=EmailTag.TEST, + sender_id=123, + recipient_ids=[124], + ) + + task_done_service.create(task_data) + + task_dones = db_session.execute(select(TaskDone)).scalars().all() + assert len(task_dones) == 1 + assert task_dones[0].data == asdict(task_data) + + def test_sender_mention_count(self, task_done_service, factories): + task_data = TaskData( + tag=EmailTag.MENTION_NOTIFICATION, + sender_id=123, + recipient_ids=[124], + ) + created = datetime.fromisoformat("2023-05-04 12:12:01+00:00") + _task_done = factories.TaskDone(created=created, data=asdict(task_data)) + + after = created - timedelta(seconds=1) + assert task_done_service.sender_mention_count(task_data.sender_id, after) == 1 + + @pytest.fixture + def task_done_service(self, pyramid_request): + return TaskDoneService(session=pyramid_request.db) + + +class TestFactory: + def test_it(self, pyramid_request, TaskDoneService): + service = factory(sentinel.context, pyramid_request) + + TaskDoneService.assert_called_once_with(session=pyramid_request.db) + + assert service == TaskDoneService.return_value + + @pytest.fixture(autouse=True) + def TaskDoneService(self, patch): + return patch("h.services.task_done.TaskDoneService")