Skip to content

static type checking (but loose-like around dependencies) #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ repos:
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: '9db9854e3041219b1eb619872a2dfaf58adfb20b' # v1.9.0
hooks:
- id: mypy
2 changes: 1 addition & 1 deletion addon_imps/storage/box_dot_com.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _params_from_cursor(self, cursor: str = "") -> dict[str, str]:
# https://developer.box.com/guides/api-calls/pagination/offset-based/
try:
_cursor = OffsetCursor.from_str(cursor)
return {"offset": _cursor.offset, "limit": _cursor.limit}
return {"offset": str(_cursor.offset), "limit": str(_cursor.limit)}
except ValueError:
return {}

Expand Down
14 changes: 10 additions & 4 deletions addon_service/addon_operation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
get_imp_name,
)
from addon_service.common.static_dataclass_model import StaticDataclassModel
from addon_toolkit import AddonOperationImp
from addon_toolkit.json_arguments import jsonschema_for_signature_params
from addon_toolkit import (
AddonCapabilities,
AddonOperationImp,
)
from addon_toolkit.json_arguments import (
JsonableDict,
jsonschema_for_signature_params,
)
from addon_toolkit.operation import AddonOperationType


Expand Down Expand Up @@ -42,11 +48,11 @@ def implementation_docstring(self) -> str:
return self.operation_imp.imp_function.__doc__ or ""

@cached_property
def capability(self) -> str:
def capability(self) -> AddonCapabilities:
return self.operation_imp.declaration.capability

@cached_property
def params_jsonschema(self) -> dict:
def params_jsonschema(self) -> JsonableDict:
return jsonschema_for_signature_params(
self.operation_imp.declaration.call_signature
)
Expand Down
2 changes: 1 addition & 1 deletion addon_service/addon_operation/serializers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from rest_framework_json_api import serializers
from rest_framework_json_api.utils import get_resource_type_from_model

from addon_service.addon_imp.models import AddonImp
from addon_service.common import view_names
from addon_service.common.enums.serializers import EnumNameChoiceField
from addon_service.common.serializer_fields import DataclassRelatedDataField
from addon_toolkit import AddonCapabilities
from addon_toolkit.imp import AddonImp

from .models import AddonOperationModel

Expand Down
32 changes: 19 additions & 13 deletions addon_service/common/hmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import hashlib
import hmac
import re
import urllib
import urllib.parse
from datetime import (
UTC,
datetime,
Expand All @@ -16,7 +16,7 @@
)


def _sign_message(message: str, hmac_key: str = None) -> str:
def _sign_message(message: str, hmac_key: str | None = None) -> str:
key = hmac_key or settings.DEFAULT_HMAC_KEY
encoded_message = base64.b64encode(message.encode())
return hmac.new(
Expand All @@ -30,25 +30,31 @@ def _get_signed_components(
parsed_url = urllib.parse.urlparse(request_url)
if isinstance(body, str):
body = body.encode()
content_hash = hashlib.sha256(body).hexdigest if body else None
auth_timestamp = datetime.now(UTC)
content_hash = hashlib.sha256(body).hexdigest() if body else None
auth_timestamp = str(datetime.now(UTC))
# Filter out query string and content_hash if none present
signed_segments = [
request_method,
parsed_url.path,
parsed_url.query,
str(auth_timestamp),
content_hash,
segment
for segment in [
request_method,
parsed_url.path,
parsed_url.query,
auth_timestamp,
content_hash,
]
if segment
]
# Filter out query string and content_hash if none present
signed_segments = [segment for segment in signed_segments if segment]
signed_headers = {"X-Authorization-Timestamp": auth_timestamp}
signed_headers: dict[str, str] = {"X-Authorization-Timestamp": auth_timestamp}
if content_hash:
signed_headers["X-Content-SHA256"] = content_hash
return signed_segments, signed_headers


def make_signed_headers(
request_url: str, request_method: str, body: str | bytes = "", hmac_key: str = None
request_url: str,
request_method: str,
body: str | bytes = "",
hmac_key: str | None = None,
) -> dict:
signed_string_segments, signed_headers = _get_signed_components(
request_url, request_method, body
Expand Down
12 changes: 7 additions & 5 deletions addon_service/common/jsonapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class JSONAPIQueryParam:
"""Dataclass for describing the contents of a JSON:API-compliant Query Parameter."""

family: str
args: tuple[str] = ()
args: tuple[str, ...] = ()
value: str = ""

# Matches any alphanumeric string followed by an open bracket or end of input
Expand All @@ -29,7 +29,7 @@ def from_key_value_pair(cls, query_param_name: str, query_param_value: str) -> S
return cls(family, args, query_param_value)

@classmethod
def parse_param_name(cls, query_param_name: str) -> tuple[str, tuple[str]]:
def parse_param_name(cls, query_param_name: str) -> tuple[str, tuple[str, ...]]:
"""Parses a query parameter name into its family and bracketed args.

>>> JSONAPIQueryParam.parse_param_name('filter')
Expand All @@ -43,7 +43,9 @@ def parse_param_name(cls, query_param_name: str) -> tuple[str, tuple[str]]:
"""
if not cls._param_name_is_valid(query_param_name):
raise ValueError(f"Invalid query param name: {query_param_name}")
family = cls.FAMILY_REGEX.match(query_param_name).group()
family_match = cls.FAMILY_REGEX.match(query_param_name)
assert family_match is not None
family = family_match.group()
args = cls.ARG_REGEX.findall(query_param_name)
return (family, tuple(args))

Expand Down Expand Up @@ -82,7 +84,7 @@ def __str__(self):
return f"{self.family}{args}={self.value}"


QueryParamFamilies = dict[str, Iterable[JSONAPIQueryParam]]
QueryParamFamilies = dict[str, list[JSONAPIQueryParam]]


def group_query_params_by_family(
Expand All @@ -93,7 +95,7 @@ def group_query_params_by_family(
Data should be pre-normalized before calling, such as by using the results of
`urllib.parse.parse_qs(...).items()` or `django.utils.QueryDict.lists()`
"""
grouped_query_params = QueryParamFamilies()
grouped_query_params: QueryParamFamilies = {}
for _unparsed_name, _param_values in query_items:
# Handle wsgiref.headers.Headers-style multi-dicts
if isinstance(_param_values, str):
Expand Down
2 changes: 1 addition & 1 deletion addon_service/common/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(

# abstract method from HttpRequestor:
@contextlib.asynccontextmanager
async def do_send(self, request: HttpRequestInfo):
async def send_request(self, request: HttpRequestInfo):
try:
async with self._try_send(request) as _response:
yield _response
Expand Down
16 changes: 13 additions & 3 deletions addon_service/common/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,20 @@ def has_permission(self, request, view):
expiration_time = datetime.now(UTC) - timedelta(
seconds=self.REQUEST_EXPIRATION_SECONDS
)
request_timestamp = request.headers.get("X-Authorization-Timestamp")
if not request_timestamp or request_timestamp < expiration_time:
request_timestamp_str = request.headers.get("X-Authorization-Timestamp")
if not request_timestamp_str:
raise exceptions.PermissionDenied(
"Missing required 'X-Authorization-Timestamp' value"
)
try:
_request_timestamp = datetime.fromisoformat(request_timestamp_str)
except ValueError:
raise exceptions.PermissionDenied(
"Invalid 'X-Authorization-Timestamp' value"
)
if _request_timestamp < expiration_time:
raise exceptions.PermissionDenied("HMAC Signed Request is expired")
elif request_timestamp > datetime.now(UTC):
if _request_timestamp > datetime.now(UTC):
raise exceptions.PermissionDenied(
"HMAC Signed Request provided a timestamp from the future"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _make_post_payload(
}
credentials = credentials or MOCK_CREDENTIALS[external_service.credentials_format]
if credentials:
payload["data"]["attributes"]["credentials"] = credentials.asdict()
payload["data"]["attributes"]["credentials"] = credentials.asdict() # type: ignore
return payload


Expand Down Expand Up @@ -143,10 +143,12 @@ def test_post__sets_credentials(self):
self.assertEqual(_resp.status_code, HTTPStatus.CREATED)

account = db.AuthorizedStorageAccount.objects.get(id=_resp.data["id"])
mock_credentials = MOCK_CREDENTIALS[creds_format]
assert mock_credentials is not None
with self.subTest(creds_format=creds_format):
self.assertEqual(
account._credentials.credentials_blob,
MOCK_CREDENTIALS[creds_format].asdict(),
mock_credentials.asdict(),
)

def test_post__sets_auth_url(self):
Expand Down Expand Up @@ -390,6 +392,7 @@ def test_set_credentials__create(self):
)
self.assertIsNone(account._credentials)
mock_credentials = MOCK_CREDENTIALS[creds_format]
assert mock_credentials is not None
account.credentials = mock_credentials
account.save()
with self.subTest(creds_format=creds_format):
Expand Down
22 changes: 11 additions & 11 deletions addon_toolkit/constrained_network/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,25 @@ class HttpRequestor(typing.Protocol):
def response_info_cls(self) -> type[HttpResponseInfo]: ...

# abstract method for subclasses
def do_send(
def send_request(
self, request: HttpRequestInfo
) -> contextlib.AbstractAsyncContextManager[HttpResponseInfo]: ...

@contextlib.asynccontextmanager
async def request(
async def _request(
self,
http_method: HTTPMethod,
uri_path: str,
query: Multidict | KeyValuePairs | None = None,
headers: Multidict | KeyValuePairs | None = None,
):
) -> typing.Any: # loose type; method-specific methods below are more accurate
_request_info = HttpRequestInfo(
http_method=http_method,
uri_path=uri_path,
query=(query if isinstance(query, Multidict) else Multidict(query)),
headers=(headers if isinstance(headers, Multidict) else Multidict(headers)),
)
async with self.do_send(_request_info) as _response:
async with self.send_request(_request_info) as _response:
yield _response

# TODO: streaming send/receive (only if/when needed)
Expand All @@ -88,10 +88,10 @@ async def request(
# convenience methods for http methods
# (same call signature as self.request, minus `http_method`)

OPTIONS: _MethodRequestMethod = partialmethod(request, HTTPMethod.OPTIONS)
HEAD: _MethodRequestMethod = partialmethod(request, HTTPMethod.HEAD)
GET: _MethodRequestMethod = partialmethod(request, HTTPMethod.GET)
PATCH: _MethodRequestMethod = partialmethod(request, HTTPMethod.PATCH)
POST: _MethodRequestMethod = partialmethod(request, HTTPMethod.POST)
PUT: _MethodRequestMethod = partialmethod(request, HTTPMethod.PUT)
DELETE: _MethodRequestMethod = partialmethod(request, HTTPMethod.DELETE)
OPTIONS: _MethodRequestMethod = partialmethod(_request, HTTPMethod.OPTIONS)
HEAD: _MethodRequestMethod = partialmethod(_request, HTTPMethod.HEAD)
GET: _MethodRequestMethod = partialmethod(_request, HTTPMethod.GET)
PATCH: _MethodRequestMethod = partialmethod(_request, HTTPMethod.PATCH)
POST: _MethodRequestMethod = partialmethod(_request, HTTPMethod.POST)
PUT: _MethodRequestMethod = partialmethod(_request, HTTPMethod.PUT)
DELETE: _MethodRequestMethod = partialmethod(_request, HTTPMethod.DELETE)
8 changes: 5 additions & 3 deletions addon_toolkit/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@
import typing


@dataclasses.dataclass(frozen=True)
class Credentials(typing.Protocol):
def asdict(self):
def asdict(self) -> dict[str, typing.Any]:
return dataclasses.asdict(self)

def iter_headers(self) -> typing.Iterator[tuple[str, str]]: ...
def iter_headers(self) -> typing.Iterator[tuple[str, str]]:
yield from () # no headers unless implemented by subclass


@dataclasses.dataclass(frozen=True, kw_only=True)
class AccessTokenCredentials(Credentials):
access_token: str

def iter_headers(self):
def iter_headers(self) -> typing.Iterator[tuple[str, str]]:
yield ("Authorization", f"Bearer {self.access_token}")


Expand Down
24 changes: 15 additions & 9 deletions addon_toolkit/cursor.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
import base64
import dataclasses
import json
from typing import (
ClassVar,
Protocol,
)
import typing


def encode_cursor_dataclass(dataclass_instance) -> str:
class DataclassInstance(typing.Protocol):
__dataclass_fields__: typing.ClassVar[dict[str, typing.Any]]


SomeDataclassInstance = typing.TypeVar("SomeDataclassInstance", bound=DataclassInstance)


def encode_cursor_dataclass(dataclass_instance: DataclassInstance) -> str:
_as_json = json.dumps(dataclasses.astuple(dataclass_instance))
_cursor_bytes = base64.b64encode(_as_json.encode())
return _cursor_bytes.decode()


def decode_cursor_dataclass(cursor: str, dataclass_class):
def decode_cursor_dataclass(
cursor: str, dataclass_class: type[SomeDataclassInstance]
) -> SomeDataclassInstance:
_as_list = json.loads(base64.b64decode(cursor))
return dataclass_class(*_as_list)


class Cursor(Protocol):
class Cursor(DataclassInstance, typing.Protocol):
@classmethod
def from_str(cls, cursor: str):
def from_str(cls, cursor: str) -> typing.Self:
return decode_cursor_dataclass(cursor, cls)

@property
Expand Down Expand Up @@ -52,7 +58,7 @@ class OffsetCursor(Cursor):
limit: int
total_count: int # use -1 to mean "many more"

MAX_INDEX: ClassVar[int] = 9999
MAX_INDEX: typing.ClassVar[int] = 9999

@property
def next_cursor_str(self) -> str | None:
Expand Down
Loading