Skip to content

Commit a6a9308

Browse files
committed
Add task done service
1 parent 65d2d47 commit a6a9308

File tree

8 files changed

+154
-45
lines changed

8 files changed

+154
-45
lines changed

Diff for: h/models/notification.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import enum
1+
from enum import StrEnum
22

33
from sqlalchemy import ForeignKey, UniqueConstraint
44
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -8,11 +8,20 @@
88
from h.models import helpers
99

1010

11-
class NotificationType(enum.StrEnum):
11+
class NotificationType(StrEnum):
1212
MENTION = "mention"
1313
REPLY = "reply"
1414

1515

16+
class EmailTag(StrEnum):
17+
ACTIVATION = "activation"
18+
FLAG_NOTIFICATION = "flag_notification"
19+
REPLY_NOTIFICATION = "reply_notification"
20+
RESET_PASSWORD = "reset_password" # noqa: S105
21+
MENTION_NOTIFICATION = "mention_notification"
22+
TEST = "test"
23+
24+
1625
class Notification(Base, Timestamps): # pragma: no cover
1726
__tablename__ = "notification"
1827

Diff for: h/services/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
from h.services.mention import MentionService
1717
from h.services.notification import NotificationService
1818
from h.services.subscription import SubscriptionService
19+
from h.services.task_done import TaskDoneService
1920

2021

21-
def includeme(config): # pragma: no cover
22+
def includeme(config): # pragma: no cover # noqa: PLR0915
2223
# Annotation related services
2324
config.register_service_factory(
2425
"h.services.annotation_delete.annotation_delete_service_factory",
@@ -174,3 +175,6 @@ def includeme(config): # pragma: no cover
174175
"h.services.analytics.analytics_service_factory", name="analytics"
175176
)
176177
config.register_service_factory("h.services.email.factory", iface=EmailService)
178+
config.register_service_factory(
179+
"h.services.task_done.factory", iface=TaskDoneService
180+
)

Diff for: h/services/email.py

+18-32
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,21 @@
11
# noqa: A005
22

33
import smtplib
4-
from dataclasses import asdict, dataclass, field
5-
from enum import StrEnum
6-
from typing import Any
4+
from dataclasses import dataclass
75

86
import pyramid_mailer
97
import pyramid_mailer.message
108
from pyramid.request import Request
119
from pyramid_mailer import IMailer
1210
from sqlalchemy.orm import Session
1311

14-
from h.models import TaskDone
12+
from h.models.notification import EmailTag
13+
from h.services.task_done import TaskData, TaskDoneService
1514
from h.tasks.celery import get_task_logger
1615

1716
logger = get_task_logger(__name__)
1817

1918

20-
class EmailTag(StrEnum):
21-
ACTIVATION = "activation"
22-
FLAG_NOTIFICATION = "flag_notification"
23-
REPLY_NOTIFICATION = "reply_notification"
24-
RESET_PASSWORD = "reset_password" # noqa: S105
25-
MENTION_NOTIFICATION = "mention_notification"
26-
TEST = "test"
27-
28-
2919
@dataclass(frozen=True)
3020
class EmailData:
3121
recipients: list[str]
@@ -49,25 +39,20 @@ def message(self) -> pyramid_mailer.message.Message:
4939
)
5040

5141

52-
@dataclass(frozen=True)
53-
class TaskData:
54-
tag: EmailTag
55-
sender_id: int
56-
recipient_ids: list[int] = field(default_factory=list)
57-
extra: dict[str, Any] = field(default_factory=dict)
58-
59-
@property
60-
def formatted_extra(self) -> str:
61-
return ", ".join(f"{k}={v!r}" for k, v in self.extra.items() if v is not None)
62-
63-
6442
class EmailService:
6543
"""A service for sending emails."""
6644

67-
def __init__(self, debug: bool, session: Session, mailer: IMailer) -> None: # noqa: FBT001
45+
def __init__(
46+
self,
47+
debug: bool, # noqa: FBT001
48+
session: Session,
49+
mailer: IMailer,
50+
task_done_service: TaskDoneService,
51+
) -> None:
6852
self._debug = debug
6953
self._session = session
7054
self._mailer = mailer
55+
self._task_done_service = task_done_service
7156

7257
def send(self, email_data: EmailData, task_data: TaskData) -> None:
7358
if self._debug: # pragma: no cover
@@ -91,13 +76,14 @@ def send(self, email_data: EmailData, task_data: TaskData) -> None:
9176
separator,
9277
task_data.formatted_extra,
9378
)
94-
self._create_task_done(task_data)
95-
96-
def _create_task_done(self, task_data: TaskData) -> None:
97-
task_done = TaskDone(data=asdict(task_data))
98-
self._session.add(task_done)
79+
self._task_done_service.create(task_data)
9980

10081

10182
def factory(_context, request: Request) -> EmailService:
10283
mailer = pyramid_mailer.get_mailer(request)
103-
return EmailService(debug=request.debug, session=request.db, mailer=mailer)
84+
return EmailService(
85+
debug=request.debug,
86+
session=request.db,
87+
mailer=mailer,
88+
task_done_service=request.find_service(TaskDoneService),
89+
)

Diff for: h/services/task_done.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from dataclasses import asdict, dataclass, field
2+
from datetime import datetime
3+
from typing import Any
4+
5+
from pyramid.request import Request
6+
from sqlalchemy import select
7+
from sqlalchemy.orm import Session
8+
from sqlalchemy.sql.functions import count
9+
10+
from h.models import TaskDone
11+
from h.models.notification import EmailTag
12+
13+
14+
@dataclass(frozen=True)
15+
class TaskData:
16+
tag: EmailTag
17+
sender_id: int
18+
recipient_ids: list[int] = field(default_factory=list)
19+
extra: dict[str, Any] = field(default_factory=dict)
20+
21+
@property
22+
def formatted_extra(self) -> str:
23+
return ", ".join(f"{k}={v!r}" for k, v in self.extra.items() if v is not None)
24+
25+
26+
class TaskDoneService:
27+
def __init__(self, session: Session) -> None:
28+
self._session = session
29+
30+
def create(self, task_data: TaskData) -> None:
31+
task_done = TaskDone(data=asdict(task_data))
32+
self._session.add(task_done)
33+
34+
def sender_mention_count(self, sender_id: str, after: datetime) -> int:
35+
stmt = select(count(TaskDone.id)).where(
36+
TaskDone.data["sender_id"].astext == str(sender_id),
37+
TaskDone.created >= after,
38+
)
39+
return self._session.execute(stmt).scalar_one()
40+
41+
42+
def factory(_context, request: Request) -> TaskDoneService:
43+
return TaskDoneService(session=request.db)

Diff for: tests/common/factories/task_done.py

+5
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,8 @@ class Meta:
1515
def expires_at(self, _create, extracted, **_kwargs):
1616
if extracted:
1717
self.expires_at = extracted
18+
19+
@post_generation
20+
def created(self, _create, extracted, **_kwargs):
21+
if extracted:
22+
self.created = extracted

Diff for: tests/common/fixtures/services.py

+7
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from h.services.organization import OrganizationService
4040
from h.services.search_index import SearchIndexService
4141
from h.services.subscription import SubscriptionService
42+
from h.services.task_done import TaskDoneService
4243
from h.services.url_migration import URLMigrationService
4344
from h.services.user import UserService
4445
from h.services.user_delete import UserDeleteService
@@ -85,6 +86,7 @@
8586
"queue_service",
8687
"search_index",
8788
"subscription_service",
89+
"task_done_service",
8890
"url_migration_service",
8991
"user_delete_service",
9092
"user_password_service",
@@ -340,3 +342,8 @@ def feature_service(mock_service):
340342
@pytest.fixture
341343
def email_service(mock_service):
342344
return mock_service(EmailService)
345+
346+
347+
@pytest.fixture
348+
def task_done_service(mock_service):
349+
return mock_service(TaskDoneService)

Diff for: tests/unit/h/services/email_test.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
import smtplib
2-
from dataclasses import asdict
32
from unittest.mock import sentinel
43

54
import pytest
6-
from sqlalchemy import select
75

8-
from h.models import TaskDone
96
from h.services.email import EmailData, EmailService, EmailTag, TaskData, factory
107

118

@@ -105,7 +102,7 @@ def test_send_logging_with_extra(self, email_data, email_service, info_caplog):
105102
]
106103

107104
def test_send_creates_task_done(
108-
self, email_data, task_data, email_service, db_session
105+
self, email_data, task_data, email_service, task_done_service
109106
):
110107
task_data = TaskData(
111108
tag=email_data.tag,
@@ -115,9 +112,7 @@ def test_send_creates_task_done(
115112
)
116113
email_service.send(email_data, task_data)
117114

118-
task_dones = db_session.execute(select(TaskDone)).scalars().all()
119-
assert len(task_dones) == 1
120-
assert task_dones[0].data == asdict(task_data)
115+
task_done_service.create.assert_called_once_with(task_data)
121116

122117
@pytest.fixture
123118
def email_data(self):
@@ -137,9 +132,14 @@ def task_data(self):
137132
)
138133

139134
@pytest.fixture
140-
def email_service(self, pyramid_request, pyramid_mailer):
135+
def email_service(self, pyramid_request, pyramid_mailer, task_done_service):
141136
request_mailer = pyramid_mailer.get_mailer.return_value
142-
return EmailService(pyramid_request.debug, pyramid_request.db, request_mailer)
137+
return EmailService(
138+
debug=pyramid_request.debug,
139+
session=pyramid_request.db,
140+
mailer=request_mailer,
141+
task_done_service=task_done_service,
142+
)
143143

144144
@pytest.fixture
145145
def info_caplog(self, caplog):
@@ -148,13 +148,14 @@ def info_caplog(self, caplog):
148148

149149

150150
class TestFactory:
151-
def test_it(self, pyramid_request, pyramid_mailer, EmailService):
151+
def test_it(self, pyramid_request, pyramid_mailer, EmailService, task_done_service):
152152
service = factory(sentinel.context, pyramid_request)
153153

154154
EmailService.assert_called_once_with(
155155
debug=pyramid_request.debug,
156156
session=pyramid_request.db,
157157
mailer=pyramid_mailer.get_mailer.return_value,
158+
task_done_service=task_done_service,
158159
)
159160

160161
assert service == EmailService.return_value

Diff for: tests/unit/h/services/task_done_test.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from dataclasses import asdict
2+
from datetime import datetime, timedelta
3+
from unittest.mock import sentinel
4+
5+
import pytest
6+
from sqlalchemy import select
7+
8+
from h.models import TaskDone
9+
from h.models.notification import EmailTag
10+
from h.services.task_done import TaskData, TaskDoneService, factory
11+
12+
13+
class TestTaskDoneService:
14+
def test_create(self, task_done_service, db_session):
15+
task_data = TaskData(
16+
tag=EmailTag.TEST,
17+
sender_id=123,
18+
recipient_ids=[124],
19+
)
20+
21+
task_done_service.create(task_data)
22+
23+
task_dones = db_session.execute(select(TaskDone)).scalars().all()
24+
assert len(task_dones) == 1
25+
assert task_dones[0].data == asdict(task_data)
26+
27+
def test_sender_mention_count(self, task_done_service, factories):
28+
task_data = TaskData(
29+
tag=EmailTag.MENTION_NOTIFICATION,
30+
sender_id=123,
31+
recipient_ids=[124],
32+
)
33+
created = datetime.fromisoformat("2023-05-04 12:12:01+00:00")
34+
_task_done = factories.TaskDone(created=created, data=asdict(task_data))
35+
36+
after = created - timedelta(seconds=1)
37+
assert task_done_service.sender_mention_count(task_data.sender_id, after) == 1
38+
39+
@pytest.fixture
40+
def task_done_service(self, pyramid_request):
41+
return TaskDoneService(session=pyramid_request.db)
42+
43+
44+
class TestFactory:
45+
def test_it(self, pyramid_request, TaskDoneService):
46+
service = factory(sentinel.context, pyramid_request)
47+
48+
TaskDoneService.assert_called_once_with(session=pyramid_request.db)
49+
50+
assert service == TaskDoneService.return_value
51+
52+
@pytest.fixture(autouse=True)
53+
def TaskDoneService(self, patch):
54+
return patch("h.services.task_done.TaskDoneService")

0 commit comments

Comments
 (0)