Skip to content

Rate-limit emails sent by the same user in a period of time #9462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions h/models/notification.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import enum
from enum import StrEnum

from sqlalchemy import ForeignKey, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
Expand All @@ -8,11 +8,20 @@
from h.models import helpers


class NotificationType(enum.StrEnum):
class NotificationType(StrEnum):
MENTION = "mention"
REPLY = "reply"


class EmailTag(StrEnum):
ACTIVATION = "activation"
FLAG_NOTIFICATION = "flag_notification"
REPLY_NOTIFICATION = "reply_notification"
RESET_PASSWORD = "reset_password" # noqa: S105
MENTION_NOTIFICATION = "mention_notification"
TEST = "test"


class Notification(Base, Timestamps): # pragma: no cover
__tablename__ = "notification"

Expand Down
6 changes: 5 additions & 1 deletion h/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from h.services.mention import MentionService
from h.services.notification import NotificationService
from h.services.subscription import SubscriptionService
from h.services.task_done import TaskDoneService


def includeme(config): # pragma: no cover
def includeme(config): # pragma: no cover # noqa: PLR0915
# Annotation related services
config.register_service_factory(
"h.services.annotation_delete.annotation_delete_service_factory",
Expand Down Expand Up @@ -174,3 +175,6 @@ def includeme(config): # pragma: no cover
"h.services.analytics.analytics_service_factory", name="analytics"
)
config.register_service_factory("h.services.email.factory", iface=EmailService)
config.register_service_factory(
"h.services.task_done.factory", iface=TaskDoneService
)
72 changes: 42 additions & 30 deletions h/services/email.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
# noqa: A005

import smtplib
from dataclasses import asdict, dataclass, field
from enum import StrEnum
from typing import Any
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta

import pyramid_mailer
import pyramid_mailer.message
from pyramid.request import Request
from pyramid_mailer import IMailer
from sqlalchemy.orm import Session

from h.models import TaskDone
from h.models.notification import EmailTag
from h.services.task_done import TaskData, TaskDoneService
from h.tasks.celery import get_task_logger

logger = get_task_logger(__name__)


class EmailTag(StrEnum):
ACTIVATION = "activation"
FLAG_NOTIFICATION = "flag_notification"
REPLY_NOTIFICATION = "reply_notification"
RESET_PASSWORD = "reset_password" # noqa: S105
MENTION_NOTIFICATION = "mention_notification"
TEST = "test"
# Limit for the number of mention emails sent by a single user in a day to prevent abuse
DAILY_SENDER_MENTION_LIMIT = 100


@dataclass(frozen=True)
Expand All @@ -49,27 +43,25 @@ def message(self) -> pyramid_mailer.message.Message:
)


@dataclass(frozen=True)
class TaskData:
tag: EmailTag
sender_id: int
recipient_ids: list[int] = field(default_factory=list)
extra: dict[str, Any] = field(default_factory=dict)

@property
def formatted_extra(self) -> str:
return ", ".join(f"{k}={v!r}" for k, v in self.extra.items() if v is not None)


class EmailService:
"""A service for sending emails."""

def __init__(self, debug: bool, session: Session, mailer: IMailer) -> None: # noqa: FBT001
def __init__(
self,
debug: bool, # noqa: FBT001
session: Session,
mailer: IMailer,
task_done_service: TaskDoneService,
) -> None:
self._debug = debug
self._session = session
self._mailer = mailer
self._task_done_service = task_done_service

def send(self, email_data: EmailData, task_data: TaskData) -> None:
if not self._allow_sending(task_data):
return

if self._debug: # pragma: no cover
logger.info("emailing in debug mode: check the `mail/` directory")
try:
Expand All @@ -91,13 +83,33 @@ def send(self, email_data: EmailData, task_data: TaskData) -> None:
separator,
task_data.formatted_extra,
)
self._create_task_done(task_data)
self._task_done_service.create(task_data)

def _allow_sending(self, task_data: TaskData) -> bool:
if (
task_data.tag == EmailTag.MENTION_NOTIFICATION
and self._sender_limit_reached(task_data)
):
logger.warning(
"Not sending email: tag=%r sender_id=%s recipient_ids=%s. Sender limit reached.",
task_data.tag,
task_data.sender_id,
task_data.recipient_ids,
)
return False
return True

def _create_task_done(self, task_data: TaskData) -> None:
task_done = TaskDone(data=asdict(task_data))
self._session.add(task_done)
def _sender_limit_reached(self, task_data: TaskData) -> bool:
after = datetime.now(UTC) - timedelta(days=1)
count = self._task_done_service.sender_mention_count(task_data.sender_id, after)
return count >= DAILY_SENDER_MENTION_LIMIT


def factory(_context, request: Request) -> EmailService:
mailer = pyramid_mailer.get_mailer(request)
return EmailService(debug=request.debug, session=request.db, mailer=mailer)
return EmailService(
debug=request.debug,
session=request.db,
mailer=mailer,
task_done_service=request.find_service(TaskDoneService),
)
43 changes: 43 additions & 0 deletions h/services/task_done.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any

from pyramid.request import Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql.functions import count

from h.models import TaskDone
from h.models.notification import EmailTag


@dataclass(frozen=True)
class TaskData:
tag: EmailTag
sender_id: int
recipient_ids: list[int] = field(default_factory=list)
extra: dict[str, Any] = field(default_factory=dict)

@property
def formatted_extra(self) -> str:
return ", ".join(f"{k}={v!r}" for k, v in self.extra.items() if v is not None)


class TaskDoneService:
def __init__(self, session: Session) -> None:
self._session = session

def create(self, task_data: TaskData) -> None:
task_done = TaskDone(data=asdict(task_data))
self._session.add(task_done)

def sender_mention_count(self, sender_id: str, after: datetime) -> int:
stmt = select(count(TaskDone.id)).where(
TaskDone.data["sender_id"].astext == str(sender_id),
TaskDone.created >= after,
)
return self._session.execute(stmt).scalar_one()


def factory(_context, request: Request) -> TaskDoneService:
return TaskDoneService(session=request.db)
5 changes: 5 additions & 0 deletions tests/common/factories/task_done.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@ class Meta:
def expires_at(self, _create, extracted, **_kwargs):
if extracted:
self.expires_at = extracted

@post_generation
def created(self, _create, extracted, **_kwargs):
if extracted:
self.created = extracted
7 changes: 7 additions & 0 deletions tests/common/fixtures/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from h.services.organization import OrganizationService
from h.services.search_index import SearchIndexService
from h.services.subscription import SubscriptionService
from h.services.task_done import TaskDoneService
from h.services.url_migration import URLMigrationService
from h.services.user import UserService
from h.services.user_delete import UserDeleteService
Expand Down Expand Up @@ -85,6 +86,7 @@
"queue_service",
"search_index",
"subscription_service",
"task_done_service",
"url_migration_service",
"user_delete_service",
"user_password_service",
Expand Down Expand Up @@ -340,3 +342,8 @@ def feature_service(mock_service):
@pytest.fixture
def email_service(mock_service):
return mock_service(EmailService)


@pytest.fixture
def task_done_service(mock_service):
return mock_service(TaskDoneService)
Loading
Loading