Skip to content

Commit ed535de

Browse files
khvn26tushar5526zachaysan
authored
feat: strict typing (#70)
* feat: add mypy, update tooling * feat: add type hints to core lib (#71) --------- Co-authored-by: Tushar <[email protected]> Co-authored-by: Zach Aysan <[email protected]>
1 parent f8fd387 commit ed535de

18 files changed

+501
-339
lines changed

.github/workflows/pytest.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
name: Formatting and Tests
1+
name: Linting and Tests
22

33
on:
44
- pull_request
55

66
jobs:
77
test:
88
runs-on: ubuntu-latest
9-
name: Pytest and Black formatting
9+
name: Linting and Tests
1010

1111
strategy:
1212
max-parallel: 4
@@ -28,13 +28,16 @@ jobs:
2828
run: |
2929
python -m pip install --upgrade pip
3030
pip install poetry
31-
poetry install
31+
poetry install --with dev
3232
3333
- name: Check Formatting
3434
run: |
3535
poetry run black --check .
3636
poetry run flake8 .
3737
poetry run isort --check .
3838
39+
- name: Check Typing
40+
run: poetry run mypy --strict .
41+
3942
- name: Run Tests
4043
run: poetry run pytest

.pre-commit-config.yaml

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
repos:
2-
- repo: https://github.com/asottile/seed-isort-config
3-
rev: v1.9.3
2+
- repo: https://github.com/pre-commit/mirrors-mypy
3+
rev: v1.5.1
44
hooks:
5-
- id: seed-isort-config
6-
- repo: https://github.com/pre-commit/mirrors-isort
7-
rev: v4.3.21
5+
- id: mypy
6+
args: [--strict]
7+
additional_dependencies:
8+
[pydantic, pytest, pytest_mock, types-requests, flagsmith-flag-engine, responses, types-pytz, sseclient-py]
9+
- repo: https://github.com/PyCQA/isort
10+
rev: 5.12.0
811
hooks:
912
- id: isort
1013
- repo: https://github.com/psf/black
11-
rev: 23.3.0
14+
rev: 23.7.0
1215
hooks:
1316
- id: black
1417
language_version: python3
18+
- repo: https://github.com/pycqa/flake8
19+
rev: 6.1.0
20+
hooks:
21+
- id: flake8
22+
name: flake8
1523
- repo: https://github.com/pre-commit/pre-commit-hooks
1624
rev: v4.4.0
1725
hooks:

flagsmith/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
from .flagsmith import Flagsmith # noqa
1+
from .flagsmith import Flagsmith
2+
3+
__all__ = ("Flagsmith",)

flagsmith/analytics.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import json
2+
import typing
23
from datetime import datetime
34

4-
from requests_futures.sessions import FuturesSession
5+
from requests_futures.sessions import FuturesSession # type: ignore
56

6-
ANALYTICS_ENDPOINT = "analytics/flags/"
7+
ANALYTICS_ENDPOINT: typing.Final[str] = "analytics/flags/"
78

89
# Used to control how often we send data(in seconds)
9-
ANALYTICS_TIMER = 10
10+
ANALYTICS_TIMER: typing.Final[int] = 10
1011

1112
session = FuturesSession(max_workers=4)
1213

@@ -17,7 +18,9 @@ class AnalyticsProcessor:
1718
the Flagsmith SDK. Docs: https://docs.flagsmith.com/advanced-use/flag-analytics.
1819
"""
1920

20-
def __init__(self, environment_key: str, base_api_url: str, timeout: int = 3):
21+
def __init__(
22+
self, environment_key: str, base_api_url: str, timeout: typing.Optional[int] = 3
23+
):
2124
"""
2225
Initialise the AnalyticsProcessor to handle sending analytics on flag usage to
2326
the Flagsmith API.
@@ -30,10 +33,10 @@ def __init__(self, environment_key: str, base_api_url: str, timeout: int = 3):
3033
self.analytics_endpoint = base_api_url + ANALYTICS_ENDPOINT
3134
self.environment_key = environment_key
3235
self._last_flushed = datetime.now()
33-
self.analytics_data = {}
34-
self.timeout = timeout
36+
self.analytics_data: typing.MutableMapping[str, typing.Any] = {}
37+
self.timeout = timeout or 3
3538

36-
def flush(self):
39+
def flush(self) -> None:
3740
"""
3841
Sends all the collected data to the api asynchronously and resets the timer
3942
"""
@@ -53,7 +56,7 @@ def flush(self):
5356
self.analytics_data.clear()
5457
self._last_flushed = datetime.now()
5558

56-
def track_feature(self, feature_name: str):
59+
def track_feature(self, feature_name: str) -> None:
5760
self.analytics_data[feature_name] = self.analytics_data.get(feature_name, 0) + 1
5861
if (datetime.now() - self._last_flushed).seconds > ANALYTICS_TIMER:
5962
self.flush()

flagsmith/flagsmith.py

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,23 @@
2020
from flagsmith.offline_handlers import BaseOfflineHandler
2121
from flagsmith.polling_manager import EnvironmentDataPollingManager
2222
from flagsmith.streaming_manager import EventStreamManager, StreamEvent
23-
from flagsmith.utils.identities import generate_identities_data
23+
from flagsmith.utils.identities import Identity, generate_identities_data
2424

2525
logger = logging.getLogger(__name__)
2626

2727
DEFAULT_API_URL = "https://edge.api.flagsmith.com/api/v1/"
2828
DEFAULT_REALTIME_API_URL = "https://realtime.flagsmith.com/"
2929

30+
JsonType = typing.Union[
31+
None,
32+
int,
33+
str,
34+
bool,
35+
typing.List["JsonType"],
36+
typing.List[typing.Mapping[str, "JsonType"]],
37+
typing.Dict[str, "JsonType"],
38+
]
39+
3040

3141
class Flagsmith:
3242
"""A Flagsmith client.
@@ -45,19 +55,21 @@ class Flagsmith:
4555

4656
def __init__(
4757
self,
48-
environment_key: str = None,
49-
api_url: str = None,
58+
environment_key: typing.Optional[str] = None,
59+
api_url: typing.Optional[str] = None,
5060
realtime_api_url: typing.Optional[str] = None,
51-
custom_headers: typing.Dict[str, typing.Any] = None,
52-
request_timeout_seconds: int = None,
61+
custom_headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
62+
request_timeout_seconds: typing.Optional[int] = None,
5363
enable_local_evaluation: bool = False,
5464
environment_refresh_interval_seconds: typing.Union[int, float] = 60,
55-
retries: Retry = None,
65+
retries: typing.Optional[Retry] = None,
5666
enable_analytics: bool = False,
57-
default_flag_handler: typing.Callable[[str], DefaultFlag] = None,
58-
proxies: typing.Dict[str, str] = None,
67+
default_flag_handler: typing.Optional[
68+
typing.Callable[[str], DefaultFlag]
69+
] = None,
70+
proxies: typing.Optional[typing.Dict[str, str]] = None,
5971
offline_mode: bool = False,
60-
offline_handler: BaseOfflineHandler = None,
72+
offline_handler: typing.Optional[BaseOfflineHandler] = None,
6173
enable_realtime_updates: bool = False,
6274
):
6375
"""
@@ -94,8 +106,8 @@ def __init__(
94106
self.offline_handler = offline_handler
95107
self.default_flag_handler = default_flag_handler
96108
self.enable_realtime_updates = enable_realtime_updates
97-
self._analytics_processor = None
98-
self._environment = None
109+
self._analytics_processor: typing.Optional[AnalyticsProcessor] = None
110+
self._environment: typing.Optional[EnvironmentModel] = None
99111
self._identity_overrides_by_identifier: typing.Dict[str, IdentityModel] = {}
100112

101113
# argument validation
@@ -159,6 +171,9 @@ def __init__(
159171
def _initialise_local_evaluation(self) -> None:
160172
if self.enable_realtime_updates:
161173
self.update_environment()
174+
if not self._environment:
175+
raise ValueError("Unable to get environment from API key")
176+
162177
stream_url = f"{self.realtime_api_url}sse/environments/{self._environment.api_key}/stream"
163178

164179
self.event_stream_thread = EventStreamManager(
@@ -196,6 +211,10 @@ def handle_stream_event(self, event: StreamEvent) -> None:
196211
if stream_updated_at.tzinfo is None:
197212
stream_updated_at = pytz.utc.localize(stream_updated_at)
198213

214+
if not self._environment:
215+
raise ValueError(
216+
"Unable to access environment. Environment should not be null"
217+
)
199218
environment_updated_at = self._environment.updated_at
200219
if environment_updated_at.tzinfo is None:
201220
environment_updated_at = pytz.utc.localize(environment_updated_at)
@@ -214,7 +233,9 @@ def get_environment_flags(self) -> Flags:
214233
return self._get_environment_flags_from_api()
215234

216235
def get_identity_flags(
217-
self, identifier: str, traits: typing.Dict[str, typing.Any] = None
236+
self,
237+
identifier: str,
238+
traits: typing.Optional[typing.Mapping[str, TraitValue]] = None,
218239
) -> Flags:
219240
"""
220241
Get all the flags for the current environment for a given identity. Will also
@@ -233,7 +254,9 @@ def get_identity_flags(
233254
return self._get_identity_flags_from_api(identifier, traits)
234255

235256
def get_identity_segments(
236-
self, identifier: str, traits: typing.Dict[str, typing.Any] = None
257+
self,
258+
identifier: str,
259+
traits: typing.Optional[typing.Mapping[str, TraitValue]] = None,
237260
) -> typing.List[Segment]:
238261
"""
239262
Get a list of segments that the given identity is in.
@@ -255,7 +278,7 @@ def get_identity_segments(
255278
segment_models = get_identity_segments(self._environment, identity_model)
256279
return [Segment(id=sm.id, name=sm.name) for sm in segment_models]
257280

258-
def update_environment(self):
281+
def update_environment(self) -> None:
259282
self._environment = self._get_environment_from_api()
260283
self._update_overrides()
261284

@@ -272,16 +295,20 @@ def _get_environment_from_api(self) -> EnvironmentModel:
272295
return EnvironmentModel.model_validate(environment_data)
273296

274297
def _get_environment_flags_from_document(self) -> Flags:
298+
if self._environment is None:
299+
raise TypeError("No environment present")
275300
return Flags.from_feature_state_models(
276301
feature_states=engine.get_environment_feature_states(self._environment),
277302
analytics_processor=self._analytics_processor,
278303
default_flag_handler=self.default_flag_handler,
279304
)
280305

281306
def _get_identity_flags_from_document(
282-
self, identifier: str, traits: typing.Dict[str, typing.Any]
307+
self, identifier: str, traits: typing.Mapping[str, TraitValue]
283308
) -> Flags:
284309
identity_model = self._get_identity_model(identifier, **traits)
310+
if self._environment is None:
311+
raise TypeError("No environment present")
285312
feature_states = engine.get_identity_feature_states(
286313
self._environment, identity_model
287314
)
@@ -294,11 +321,11 @@ def _get_identity_flags_from_document(
294321

295322
def _get_environment_flags_from_api(self) -> Flags:
296323
try:
297-
api_flags = self._get_json_response(
298-
url=self.environment_flags_url, method="GET"
299-
)
324+
json_response: typing.List[
325+
typing.Mapping[str, JsonType]
326+
] = self._get_json_response(url=self.environment_flags_url, method="GET")
300327
return Flags.from_api_flags(
301-
api_flags=api_flags,
328+
api_flags=json_response,
302329
analytics_processor=self._analytics_processor,
303330
default_flag_handler=self.default_flag_handler,
304331
)
@@ -310,11 +337,13 @@ def _get_environment_flags_from_api(self) -> Flags:
310337
raise
311338

312339
def _get_identity_flags_from_api(
313-
self, identifier: str, traits: typing.Dict[str, typing.Any]
340+
self, identifier: str, traits: typing.Mapping[str, typing.Any]
314341
) -> Flags:
315342
try:
316343
data = generate_identities_data(identifier, traits)
317-
json_response = self._get_json_response(
344+
json_response: typing.Dict[
345+
str, typing.List[typing.Dict[str, JsonType]]
346+
] = self._get_json_response(
318347
url=self.identities_url, method="POST", body=data
319348
)
320349
return Flags.from_api_flags(
@@ -329,7 +358,14 @@ def _get_identity_flags_from_api(
329358
return Flags(default_flag_handler=self.default_flag_handler)
330359
raise
331360

332-
def _get_json_response(self, url: str, method: str, body: dict = None):
361+
def _get_json_response(
362+
self,
363+
url: str,
364+
method: str,
365+
body: typing.Optional[
366+
typing.Union[Identity, typing.Dict[str, JsonType]]
367+
] = None,
368+
) -> typing.Any:
333369
try:
334370
request_method = getattr(self.session, method.lower())
335371
response = request_method(
@@ -371,7 +407,7 @@ def _get_identity_model(
371407
identity_traits=trait_models,
372408
)
373409

374-
def __del__(self):
410+
def __del__(self) -> None:
375411
if hasattr(self, "environment_data_polling_manager_thread"):
376412
self.environment_data_polling_manager_thread.stop()
377413

flagsmith/models.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import typing
24
from dataclasses import dataclass, field
35

@@ -28,8 +30,8 @@ class Flag(BaseFlag):
2830
def from_feature_state_model(
2931
cls,
3032
feature_state_model: FeatureStateModel,
31-
identity_id: typing.Union[str, int] = None,
32-
) -> "Flag":
33+
identity_id: typing.Optional[typing.Union[str, int]] = None,
34+
) -> Flag:
3335
return Flag(
3436
enabled=feature_state_model.enabled,
3537
value=feature_state_model.get_value(identity_id=identity_id),
@@ -38,7 +40,7 @@ def from_feature_state_model(
3840
)
3941

4042
@classmethod
41-
def from_api_flag(cls, flag_data: dict) -> "Flag":
43+
def from_api_flag(cls, flag_data: typing.Mapping[str, typing.Any]) -> Flag:
4244
return Flag(
4345
enabled=flag_data["enabled"],
4446
value=flag_data["feature_state_value"],
@@ -50,17 +52,17 @@ def from_api_flag(cls, flag_data: dict) -> "Flag":
5052
@dataclass
5153
class Flags:
5254
flags: typing.Dict[str, Flag] = field(default_factory=dict)
53-
default_flag_handler: typing.Callable[[str], DefaultFlag] = None
54-
_analytics_processor: AnalyticsProcessor = None
55+
default_flag_handler: typing.Optional[typing.Callable[[str], DefaultFlag]] = None
56+
_analytics_processor: typing.Optional[AnalyticsProcessor] = None
5557

5658
@classmethod
5759
def from_feature_state_models(
5860
cls,
59-
feature_states: typing.List[FeatureStateModel],
60-
analytics_processor: AnalyticsProcessor,
61-
default_flag_handler: typing.Callable,
62-
identity_id: typing.Union[str, int] = None,
63-
) -> "Flags":
61+
feature_states: typing.Sequence[FeatureStateModel],
62+
analytics_processor: typing.Optional[AnalyticsProcessor],
63+
default_flag_handler: typing.Optional[typing.Callable[[str], DefaultFlag]],
64+
identity_id: typing.Optional[typing.Union[str, int]] = None,
65+
) -> Flags:
6466
flags = {
6567
feature_state.feature.name: Flag.from_feature_state_model(
6668
feature_state, identity_id=identity_id
@@ -77,10 +79,10 @@ def from_feature_state_models(
7779
@classmethod
7880
def from_api_flags(
7981
cls,
80-
api_flags: typing.List[dict],
81-
analytics_processor: AnalyticsProcessor,
82-
default_flag_handler: typing.Callable,
83-
) -> "Flags":
82+
api_flags: typing.Sequence[typing.Mapping[str, typing.Any]],
83+
analytics_processor: typing.Optional[AnalyticsProcessor],
84+
default_flag_handler: typing.Optional[typing.Callable[[str], DefaultFlag]],
85+
) -> Flags:
8486
flags = {
8587
flag_data["feature"]["name"]: Flag.from_api_flag(flag_data)
8688
for flag_data in api_flags
@@ -120,12 +122,12 @@ def get_feature_value(self, feature_name: str) -> typing.Any:
120122
"""
121123
return self.get_flag(feature_name).value
122124

123-
def get_flag(self, feature_name: str) -> BaseFlag:
125+
def get_flag(self, feature_name: str) -> typing.Union[DefaultFlag, Flag]:
124126
"""
125127
Get a specific flag given the feature name.
126128
127129
:param feature_name: the name of the feature to retrieve the flag for.
128-
:return: BaseFlag object.
130+
:return: DefaultFlag | Flag object.
129131
:raises FlagsmithClientError: if feature doesn't exist
130132
"""
131133
try:

0 commit comments

Comments
 (0)