Skip to content

Commit 957d367

Browse files
authored
#78 - Custom FastAPI and ENP Class (#99)
1 parent 1a7ba79 commit 957d367

File tree

13 files changed

+148
-56
lines changed

13 files changed

+148
-56
lines changed

.talismanrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ fileignoreconfig:
1414
- filename: ci/docker-compose-local.yml
1515
checksum: c73beda98c39232441d86f6d6f6f2858274f9703f8a462eac1effacfbb9aa39d
1616
- filename: poetry.lock
17-
checksum: 5387ec3e4cc64686ab820f306c61241342c29317f6b13596f15bc4eef309a798
17+
checksum: 6ffe240c15e0287b1b05aacda82fef9f5c8b677772e854a9f7aa5f05546d1669
1818
- filename: tests/app/providers/__init__.py
1919
checksum: 0e3ae2fd3a50245a8c143d31c4316b164b51161d2db6e660aa956b78eda1b4d8
2020
- filename: tests/app/providers/test_provider_aws.py

app/legacy/v2/notifications/rest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def create_push_notification(
6060
notification = await dao_create_notification(Notification(personalization=json.dumps(personalization)))
6161

6262
background_tasks.add_task(
63-
send_push_notification_helper, personalization, icn, template, request.app.state.providers['aws']
63+
send_push_notification_helper, personalization, icn, template, request.app.enp_state.providers['aws']
6464
)
6565

6666
logger.info(

app/main.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,43 @@
33
import os
44
from collections.abc import AsyncIterator
55
from contextlib import asynccontextmanager
6-
from typing import Annotated, Never
6+
from typing import Annotated, Any, AsyncContextManager, Callable, Mapping, Never
77

88
from fastapi import Depends, FastAPI, status
99
from fastapi.staticfiles import StaticFiles
1010
from loguru import logger
1111
from sqlalchemy import select
1212
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session
1313

14-
from app.db.db_init import close_db, get_read_session_with_depends, get_write_session_with_depends, init_db
14+
from app.db.db_init import (
15+
close_db,
16+
get_read_session_with_depends,
17+
get_write_session_with_depends,
18+
init_db,
19+
)
1520
from app.legacy.v2.notifications.rest import v2_notification_router
1621
from app.logging.logging_config import CustomizeLogger
17-
from app.providers.provider_aws import ProviderAWS
22+
from app.state import ENPState
1823
from app.v3 import api_router as v3_router
1924

2025
MKDOCS_DIRECTORY = 'site'
2126

2227

28+
class CustomFastAPI(FastAPI):
29+
"""Custom FastAPI class to include ENPState."""
30+
31+
def __init__(self, lifespan: Callable[['CustomFastAPI'], AsyncContextManager[Mapping[str, Any]]]) -> None:
32+
"""Initialize the CustomFastAPI instance with ENPState.
33+
34+
Args:
35+
lifespan: The lifespan context manager for the application.
36+
"""
37+
super().__init__(lifespan=lifespan)
38+
self.enp_state = ENPState()
39+
40+
2341
@asynccontextmanager
24-
async def lifespan(app: FastAPI) -> AsyncIterator[Never]:
42+
async def lifespan(app: CustomFastAPI) -> AsyncIterator[Never]:
2543
"""Initialize the database, and populate the providers dictionary.
2644
2745
https://fastapi.tiangolo.com/advanced/events/?h=life#lifespan
@@ -34,25 +52,21 @@ async def lifespan(app: FastAPI) -> AsyncIterator[Never]:
3452
3553
"""
3654
await init_db()
37-
# Route handlers should access this dictionary to send notifications using
38-
# various third-party services, such as AWS, Twilio, etc.
39-
app.state.providers = {'aws': ProviderAWS()}
4055

4156
yield # type: ignore
4257

43-
app.state.providers.clear()
58+
app.enp_state.clear_providers()
4459
await close_db()
4560

4661

47-
def create_app() -> FastAPI:
62+
def create_app() -> CustomFastAPI:
4863
"""Create and configure the FastAPI app.
4964
5065
Returns:
5166
CustomFastAPI: The FastAPI application instance with custom logging.
52-
5367
"""
5468
CustomizeLogger.make_logger()
55-
app = FastAPI(lifespan=lifespan)
69+
app = CustomFastAPI(lifespan=lifespan)
5670
app.include_router(v3_router)
5771
app.include_router(v2_notification_router)
5872

@@ -64,7 +78,7 @@ def create_app() -> FastAPI:
6478
return app
6579

6680

67-
app: FastAPI = create_app()
81+
app: CustomFastAPI = create_app()
6882

6983

7084
@app.get('/')

app/state.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""This module manages state for the application."""
2+
3+
from typing import Dict
4+
5+
from app.providers.provider_aws import ProviderAWS
6+
from app.providers.provider_base import ProviderBase
7+
8+
9+
class ENPState:
10+
"""Custom application state class."""
11+
12+
def __init__(self) -> None:
13+
"""Initialize ENPState with a default set of providers."""
14+
# Route handlers should access this dictionary to send notifications using
15+
# various third-party services, such as AWS, Twilio, etc.
16+
self.providers: Dict[str, ProviderBase] = {'aws': ProviderAWS()}
17+
18+
def clear_providers(self) -> None:
19+
"""Clear the providers dictionary."""
20+
self.providers.clear()

app/v3/device_registrations/rest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def create_device_registration(
3333
"""
3434
logger.debug('Received device registration request: {}', request)
3535

36-
provider = fastapi_request.app.state.providers['aws']
36+
provider = fastapi_request.app.enp_state.providers['aws']
3737
logger.debug('Loaded provider: {}', provider)
3838

3939
device_registration_model = DeviceRegistrationModel(

poetry.lock

Lines changed: 19 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ optional = true
2929
pytest-cov = "*"
3030
pytest = "*"
3131
pytest-asyncio = "*"
32+
pytest-mock = "*"
3233

3334
[tool.poetry.group.static_tools]
3435
optional = true

tests/app/legacy/v2/notifications/test_rest.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
import pytest
66
from fastapi import BackgroundTasks, status
7-
from fastapi.testclient import TestClient
87

98
from app.constants import IdentifierType, MobileAppType
109
from app.db.models import Template
1110
from app.legacy.v2.notifications.route_schema import (
1211
V2PostPushRequestModel,
1312
V2PostPushResponseModel,
1413
)
14+
from tests.conftest import ENPTestClient
1515

1616

1717
@pytest.mark.asyncio
@@ -26,15 +26,15 @@ async def test_router_returns_400_with_invalid_request_data(
2626
mock_validate_template: AsyncMock,
2727
mock_dao_create_notification: AsyncMock,
2828
mock_background_task: AsyncMock,
29-
client: TestClient,
29+
client: ENPTestClient,
3030
) -> None:
3131
"""Test route can return 400.
3232
3333
Args:
3434
mock_validate_template (AsyncMock): Mock call to validate_template
3535
mock_dao_create_notification (AsyncMock): Mock call to create notification in the database
3636
mock_background_task (AsyncMock): Mock call to add a background task
37-
client (TestClient): FastAPI client fixture
37+
client (ENPTestClient): Custom FastAPI client fixture
3838
3939
"""
4040
invalid_request = {
@@ -56,15 +56,15 @@ async def test_router_returns_500_when_other_exception_thrown(
5656
mock_validate_template: AsyncMock,
5757
mock_dao_create_notification: AsyncMock,
5858
mock_background_task: AsyncMock,
59-
client: TestClient,
59+
client: ENPTestClient,
6060
) -> None:
6161
"""Test route can return 500.
6262
6363
Args:
6464
mock_validate_template (AsyncMock): Mock call to validate_template
6565
mock_dao_create_notification (AsyncMock): Mock call to create notification in the database
6666
mock_background_task (AsyncMock): Mock call to add a background task
67-
client (TestClient): FastAPI client fixture
67+
client (ENPTestClient): Custom FastAPI client fixture
6868
6969
"""
7070
mock_validate_template.return_value = Template(name='test_template')
@@ -97,15 +97,15 @@ async def test_post_push_returns_201(
9797
mock_validate_template: AsyncMock,
9898
mock_dao_create_notification: AsyncMock,
9999
mock_background_task: AsyncMock,
100-
client: TestClient,
100+
client: ENPTestClient,
101101
) -> None:
102102
"""Test route can return 201.
103103
104104
Args:
105105
mock_validate_template (AsyncMock): Mock call to validate_template
106106
mock_dao_create_notification (AsyncMock): Mock call to create notification in the database
107107
mock_background_task (AsyncMock): Mock call to add a background task
108-
client (TestClient): FastAPI client fixture
108+
client (ENPTestClient): Custom FastAPI client fixture
109109
110110
"""
111111
mock_validate_template.return_value = Template(name='test_template')
@@ -130,15 +130,15 @@ async def test_post_push_returns_400_when_unable_to_validate_template(
130130
mock_validate_template: AsyncMock,
131131
mock_dao_create_notification: AsyncMock,
132132
mock_background_task: AsyncMock,
133-
client: TestClient,
133+
client: ENPTestClient,
134134
) -> None:
135135
"""Test route returns 400 when there is an exception thrown trying to validate the template.
136136
137137
Args:
138138
mock_validate_template (AsyncMock): Mock call to validate_template
139139
mock_dao_create_notification (AsyncMock): Mock call to create notification in the database
140140
mock_background_task (AsyncMock): Mock call to add a background task
141-
client (TestClient): FastAPI client fixture
141+
client (ENPTestClient): Custom FastAPI client fixture
142142
143143
"""
144144
mock_validate_template.side_effect = Exception()

tests/app/test_main.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22

33
from unittest.mock import Mock, patch
44

5-
from fastapi.testclient import TestClient
65
from starlette import status
76

7+
from tests.conftest import ENPTestClient
88

9-
def test_simple_route(client: TestClient) -> None:
9+
10+
def test_simple_route(client: ENPTestClient) -> None:
1011
"""Test GET / to return Hello World.
1112
1213
Args:
13-
client (TestClient): FastAPI client fixture
14+
client (ENPTestClient): Custom FastAPI client fixture
1415
1516
"""
1617
resp = client.get('/')
@@ -19,12 +20,12 @@ def test_simple_route(client: TestClient) -> None:
1920

2021

2122
@patch('app.main.logger.info')
22-
def test_simple_route_logs_hello_world(mock_logger: Mock, client: TestClient) -> None:
23+
def test_simple_route_logs_hello_world(mock_logger: Mock, client: ENPTestClient) -> None:
2324
"""Test that GET / logs 'Hello World' as an info log.
2425
2526
Args:
2627
mock_logger (Mock): Mocked logger for capturing log calls.
27-
client (TestClient): FastAPI client fixture
28+
client (ENPTestClient): Custom FastAPI client fixture
2829
2930
"""
3031
client.get('/')

tests/app/test_state.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Test for ENPState Module."""
2+
3+
from app.providers.provider_aws import ProviderAWS
4+
from app.state import ENPState
5+
6+
7+
def test_enp_state_initialization() -> None:
8+
"""Test to make sure ENPState can have provider attribute."""
9+
state = ENPState()
10+
assert isinstance(state.providers, dict)
11+
assert 'aws' in state.providers
12+
assert isinstance(state.providers['aws'], ProviderAWS)
13+
14+
15+
def test_clear_providers() -> None:
16+
"""Test the clear_providers method to ensure it clears the providers dictionary."""
17+
state = ENPState()
18+
19+
assert len(state.providers) == 1
20+
state.clear_providers()
21+
assert len(state.providers) == 0

0 commit comments

Comments
 (0)