Skip to content

Commit

Permalink
#78 - Custom FastAPI and ENP Class (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
MackHalliday authored Dec 23, 2024
1 parent 1a7ba79 commit 957d367
Show file tree
Hide file tree
Showing 13 changed files with 148 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .talismanrc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fileignoreconfig:
- filename: ci/docker-compose-local.yml
checksum: c73beda98c39232441d86f6d6f6f2858274f9703f8a462eac1effacfbb9aa39d
- filename: poetry.lock
checksum: 5387ec3e4cc64686ab820f306c61241342c29317f6b13596f15bc4eef309a798
checksum: 6ffe240c15e0287b1b05aacda82fef9f5c8b677772e854a9f7aa5f05546d1669
- filename: tests/app/providers/__init__.py
checksum: 0e3ae2fd3a50245a8c143d31c4316b164b51161d2db6e660aa956b78eda1b4d8
- filename: tests/app/providers/test_provider_aws.py
Expand Down
2 changes: 1 addition & 1 deletion app/legacy/v2/notifications/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def create_push_notification(
notification = await dao_create_notification(Notification(personalization=json.dumps(personalization)))

background_tasks.add_task(
send_push_notification_helper, personalization, icn, template, request.app.state.providers['aws']
send_push_notification_helper, personalization, icn, template, request.app.enp_state.providers['aws']
)

logger.info(
Expand Down
38 changes: 26 additions & 12 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,43 @@
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Annotated, Never
from typing import Annotated, Any, AsyncContextManager, Callable, Mapping, Never

from fastapi import Depends, FastAPI, status
from fastapi.staticfiles import StaticFiles
from loguru import logger
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session

from app.db.db_init import close_db, get_read_session_with_depends, get_write_session_with_depends, init_db
from app.db.db_init import (
close_db,
get_read_session_with_depends,
get_write_session_with_depends,
init_db,
)
from app.legacy.v2.notifications.rest import v2_notification_router
from app.logging.logging_config import CustomizeLogger
from app.providers.provider_aws import ProviderAWS
from app.state import ENPState
from app.v3 import api_router as v3_router

MKDOCS_DIRECTORY = 'site'


class CustomFastAPI(FastAPI):
"""Custom FastAPI class to include ENPState."""

def __init__(self, lifespan: Callable[['CustomFastAPI'], AsyncContextManager[Mapping[str, Any]]]) -> None:
"""Initialize the CustomFastAPI instance with ENPState.
Args:
lifespan: The lifespan context manager for the application.
"""
super().__init__(lifespan=lifespan)
self.enp_state = ENPState()


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[Never]:
async def lifespan(app: CustomFastAPI) -> AsyncIterator[Never]:
"""Initialize the database, and populate the providers dictionary.
https://fastapi.tiangolo.com/advanced/events/?h=life#lifespan
Expand All @@ -34,25 +52,21 @@ async def lifespan(app: FastAPI) -> AsyncIterator[Never]:
"""
await init_db()
# Route handlers should access this dictionary to send notifications using
# various third-party services, such as AWS, Twilio, etc.
app.state.providers = {'aws': ProviderAWS()}

yield # type: ignore

app.state.providers.clear()
app.enp_state.clear_providers()
await close_db()


def create_app() -> FastAPI:
def create_app() -> CustomFastAPI:
"""Create and configure the FastAPI app.
Returns:
CustomFastAPI: The FastAPI application instance with custom logging.
"""
CustomizeLogger.make_logger()
app = FastAPI(lifespan=lifespan)
app = CustomFastAPI(lifespan=lifespan)
app.include_router(v3_router)
app.include_router(v2_notification_router)

Expand All @@ -64,7 +78,7 @@ def create_app() -> FastAPI:
return app


app: FastAPI = create_app()
app: CustomFastAPI = create_app()


@app.get('/')
Expand Down
20 changes: 20 additions & 0 deletions app/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""This module manages state for the application."""

from typing import Dict

from app.providers.provider_aws import ProviderAWS
from app.providers.provider_base import ProviderBase


class ENPState:
"""Custom application state class."""

def __init__(self) -> None:
"""Initialize ENPState with a default set of providers."""
# Route handlers should access this dictionary to send notifications using
# various third-party services, such as AWS, Twilio, etc.
self.providers: Dict[str, ProviderBase] = {'aws': ProviderAWS()}

def clear_providers(self) -> None:
"""Clear the providers dictionary."""
self.providers.clear()
2 changes: 1 addition & 1 deletion app/v3/device_registrations/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def create_device_registration(
"""
logger.debug('Received device registration request: {}', request)

provider = fastapi_request.app.state.providers['aws']
provider = fastapi_request.app.enp_state.providers['aws']
logger.debug('Loaded provider: {}', provider)

device_registration_model = DeviceRegistrationModel(
Expand Down
21 changes: 19 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ optional = true
pytest-cov = "*"
pytest = "*"
pytest-asyncio = "*"
pytest-mock = "*"

[tool.poetry.group.static_tools]
optional = true
Expand Down
18 changes: 9 additions & 9 deletions tests/app/legacy/v2/notifications/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import pytest
from fastapi import BackgroundTasks, status
from fastapi.testclient import TestClient

from app.constants import IdentifierType, MobileAppType
from app.db.models import Template
from app.legacy.v2.notifications.route_schema import (
V2PostPushRequestModel,
V2PostPushResponseModel,
)
from tests.conftest import ENPTestClient


@pytest.mark.asyncio
Expand All @@ -26,15 +26,15 @@ async def test_router_returns_400_with_invalid_request_data(
mock_validate_template: AsyncMock,
mock_dao_create_notification: AsyncMock,
mock_background_task: AsyncMock,
client: TestClient,
client: ENPTestClient,
) -> None:
"""Test route can return 400.
Args:
mock_validate_template (AsyncMock): Mock call to validate_template
mock_dao_create_notification (AsyncMock): Mock call to create notification in the database
mock_background_task (AsyncMock): Mock call to add a background task
client (TestClient): FastAPI client fixture
client (ENPTestClient): Custom FastAPI client fixture
"""
invalid_request = {
Expand All @@ -56,15 +56,15 @@ async def test_router_returns_500_when_other_exception_thrown(
mock_validate_template: AsyncMock,
mock_dao_create_notification: AsyncMock,
mock_background_task: AsyncMock,
client: TestClient,
client: ENPTestClient,
) -> None:
"""Test route can return 500.
Args:
mock_validate_template (AsyncMock): Mock call to validate_template
mock_dao_create_notification (AsyncMock): Mock call to create notification in the database
mock_background_task (AsyncMock): Mock call to add a background task
client (TestClient): FastAPI client fixture
client (ENPTestClient): Custom FastAPI client fixture
"""
mock_validate_template.return_value = Template(name='test_template')
Expand Down Expand Up @@ -97,15 +97,15 @@ async def test_post_push_returns_201(
mock_validate_template: AsyncMock,
mock_dao_create_notification: AsyncMock,
mock_background_task: AsyncMock,
client: TestClient,
client: ENPTestClient,
) -> None:
"""Test route can return 201.
Args:
mock_validate_template (AsyncMock): Mock call to validate_template
mock_dao_create_notification (AsyncMock): Mock call to create notification in the database
mock_background_task (AsyncMock): Mock call to add a background task
client (TestClient): FastAPI client fixture
client (ENPTestClient): Custom FastAPI client fixture
"""
mock_validate_template.return_value = Template(name='test_template')
Expand All @@ -130,15 +130,15 @@ async def test_post_push_returns_400_when_unable_to_validate_template(
mock_validate_template: AsyncMock,
mock_dao_create_notification: AsyncMock,
mock_background_task: AsyncMock,
client: TestClient,
client: ENPTestClient,
) -> None:
"""Test route returns 400 when there is an exception thrown trying to validate the template.
Args:
mock_validate_template (AsyncMock): Mock call to validate_template
mock_dao_create_notification (AsyncMock): Mock call to create notification in the database
mock_background_task (AsyncMock): Mock call to add a background task
client (TestClient): FastAPI client fixture
client (ENPTestClient): Custom FastAPI client fixture
"""
mock_validate_template.side_effect = Exception()
Expand Down
11 changes: 6 additions & 5 deletions tests/app/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@

from unittest.mock import Mock, patch

from fastapi.testclient import TestClient
from starlette import status

from tests.conftest import ENPTestClient

def test_simple_route(client: TestClient) -> None:

def test_simple_route(client: ENPTestClient) -> None:
"""Test GET / to return Hello World.
Args:
client (TestClient): FastAPI client fixture
client (ENPTestClient): Custom FastAPI client fixture
"""
resp = client.get('/')
Expand All @@ -19,12 +20,12 @@ def test_simple_route(client: TestClient) -> None:


@patch('app.main.logger.info')
def test_simple_route_logs_hello_world(mock_logger: Mock, client: TestClient) -> None:
def test_simple_route_logs_hello_world(mock_logger: Mock, client: ENPTestClient) -> None:
"""Test that GET / logs 'Hello World' as an info log.
Args:
mock_logger (Mock): Mocked logger for capturing log calls.
client (TestClient): FastAPI client fixture
client (ENPTestClient): Custom FastAPI client fixture
"""
client.get('/')
Expand Down
21 changes: 21 additions & 0 deletions tests/app/test_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Test for ENPState Module."""

from app.providers.provider_aws import ProviderAWS
from app.state import ENPState


def test_enp_state_initialization() -> None:
"""Test to make sure ENPState can have provider attribute."""
state = ENPState()
assert isinstance(state.providers, dict)
assert 'aws' in state.providers
assert isinstance(state.providers['aws'], ProviderAWS)


def test_clear_providers() -> None:
"""Test the clear_providers method to ensure it clears the providers dictionary."""
state = ENPState()

assert len(state.providers) == 1
state.clear_providers()
assert len(state.providers) == 0
18 changes: 11 additions & 7 deletions tests/app/v3/device_registrations/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Test cases for the device-registrations REST API."""

from unittest.mock import AsyncMock

import pytest
from fastapi import status
from fastapi.testclient import TestClient

from app.constants import MobileAppType, OSPlatformType
from app.v3.device_registrations.route_schema import DeviceRegistrationRequest
from tests.conftest import ENPTestClient


# Valid applications are VA_FLAGSHIP_APP, VETEXT. Valid platforms are IOS, ANDROID.
Expand All @@ -25,27 +27,29 @@
],
)
def test_post(
client: TestClient,
client: ENPTestClient,
application: MobileAppType,
platform: OSPlatformType,
payload: dict[str, str],
mocker: AsyncMock,
) -> None:
"""Test POST /v3/device-registration.
The endpoint should return a 201 status code, and the response should include
the endpoint sid.
Args:
client(TestClient): FastAPI client fixture
client(ENPTestClient): Custom FastAPI client fixture
application(str): The application name, either VA_FLAGSHIP_APP or VETEXT
platform(str): The platform name, either IOS or ANDROID
payload(dict): The request payload
mocker(AsyncMock): Mock fixture for async dependencies
"""
client.app.state.providers[ # type: ignore
'aws'
].register_device.return_value = (
'arn:aws:sns:us-east-1:000000000000:endpoint/APNS/notify/00000000-0000-0000-0000-000000000000'
mocker.patch.object(
client.app.enp_state.providers['aws'],
'register_device',
return_value='arn:aws:sns:us-east-1:000000000000:endpoint/APNS/notify/00000000-0000-0000-0000-000000000000',
)

request = DeviceRegistrationRequest(**payload, app_name=application, os_name=platform)
Expand Down
Loading

0 comments on commit 957d367

Please sign in to comment.