20
20
from flagsmith .offline_handlers import BaseOfflineHandler
21
21
from flagsmith .polling_manager import EnvironmentDataPollingManager
22
22
from flagsmith .streaming_manager import EventStreamManager , StreamEvent
23
- from flagsmith .utils .identities import generate_identities_data
23
+ from flagsmith .utils .identities import Identity , generate_identities_data
24
24
25
25
logger = logging .getLogger (__name__ )
26
26
27
27
DEFAULT_API_URL = "https://edge.api.flagsmith.com/api/v1/"
28
28
DEFAULT_REALTIME_API_URL = "https://realtime.flagsmith.com/"
29
29
30
+ JsonType = typing .Union [
31
+ None ,
32
+ int ,
33
+ str ,
34
+ bool ,
35
+ typing .List ["JsonType" ],
36
+ typing .List [typing .Mapping [str , "JsonType" ]],
37
+ typing .Dict [str , "JsonType" ],
38
+ ]
39
+
30
40
31
41
class Flagsmith :
32
42
"""A Flagsmith client.
@@ -45,19 +55,21 @@ class Flagsmith:
45
55
46
56
def __init__ (
47
57
self ,
48
- environment_key : str = None ,
49
- api_url : str = None ,
58
+ environment_key : typing . Optional [ str ] = None ,
59
+ api_url : typing . Optional [ str ] = None ,
50
60
realtime_api_url : typing .Optional [str ] = None ,
51
- custom_headers : typing .Dict [str , typing .Any ] = None ,
52
- request_timeout_seconds : int = None ,
61
+ custom_headers : typing .Optional [ typing . Dict [str , typing .Any ] ] = None ,
62
+ request_timeout_seconds : typing . Optional [ int ] = None ,
53
63
enable_local_evaluation : bool = False ,
54
64
environment_refresh_interval_seconds : typing .Union [int , float ] = 60 ,
55
- retries : Retry = None ,
65
+ retries : typing . Optional [ Retry ] = None ,
56
66
enable_analytics : bool = False ,
57
- default_flag_handler : typing .Callable [[str ], DefaultFlag ] = None ,
58
- proxies : typing .Dict [str , str ] = None ,
67
+ default_flag_handler : typing .Optional [
68
+ typing .Callable [[str ], DefaultFlag ]
69
+ ] = None ,
70
+ proxies : typing .Optional [typing .Dict [str , str ]] = None ,
59
71
offline_mode : bool = False ,
60
- offline_handler : BaseOfflineHandler = None ,
72
+ offline_handler : typing . Optional [ BaseOfflineHandler ] = None ,
61
73
enable_realtime_updates : bool = False ,
62
74
):
63
75
"""
@@ -94,8 +106,8 @@ def __init__(
94
106
self .offline_handler = offline_handler
95
107
self .default_flag_handler = default_flag_handler
96
108
self .enable_realtime_updates = enable_realtime_updates
97
- self ._analytics_processor = None
98
- self ._environment = None
109
+ self ._analytics_processor : typing . Optional [ AnalyticsProcessor ] = None
110
+ self ._environment : typing . Optional [ EnvironmentModel ] = None
99
111
self ._identity_overrides_by_identifier : typing .Dict [str , IdentityModel ] = {}
100
112
101
113
# argument validation
@@ -159,6 +171,9 @@ def __init__(
159
171
def _initialise_local_evaluation (self ) -> None :
160
172
if self .enable_realtime_updates :
161
173
self .update_environment ()
174
+ if not self ._environment :
175
+ raise ValueError ("Unable to get environment from API key" )
176
+
162
177
stream_url = f"{ self .realtime_api_url } sse/environments/{ self ._environment .api_key } /stream"
163
178
164
179
self .event_stream_thread = EventStreamManager (
@@ -196,6 +211,10 @@ def handle_stream_event(self, event: StreamEvent) -> None:
196
211
if stream_updated_at .tzinfo is None :
197
212
stream_updated_at = pytz .utc .localize (stream_updated_at )
198
213
214
+ if not self ._environment :
215
+ raise ValueError (
216
+ "Unable to access environment. Environment should not be null"
217
+ )
199
218
environment_updated_at = self ._environment .updated_at
200
219
if environment_updated_at .tzinfo is None :
201
220
environment_updated_at = pytz .utc .localize (environment_updated_at )
@@ -214,7 +233,9 @@ def get_environment_flags(self) -> Flags:
214
233
return self ._get_environment_flags_from_api ()
215
234
216
235
def get_identity_flags (
217
- self , identifier : str , traits : typing .Dict [str , typing .Any ] = None
236
+ self ,
237
+ identifier : str ,
238
+ traits : typing .Optional [typing .Mapping [str , TraitValue ]] = None ,
218
239
) -> Flags :
219
240
"""
220
241
Get all the flags for the current environment for a given identity. Will also
@@ -233,7 +254,9 @@ def get_identity_flags(
233
254
return self ._get_identity_flags_from_api (identifier , traits )
234
255
235
256
def get_identity_segments (
236
- self , identifier : str , traits : typing .Dict [str , typing .Any ] = None
257
+ self ,
258
+ identifier : str ,
259
+ traits : typing .Optional [typing .Mapping [str , TraitValue ]] = None ,
237
260
) -> typing .List [Segment ]:
238
261
"""
239
262
Get a list of segments that the given identity is in.
@@ -255,7 +278,7 @@ def get_identity_segments(
255
278
segment_models = get_identity_segments (self ._environment , identity_model )
256
279
return [Segment (id = sm .id , name = sm .name ) for sm in segment_models ]
257
280
258
- def update_environment (self ):
281
+ def update_environment (self ) -> None :
259
282
self ._environment = self ._get_environment_from_api ()
260
283
self ._update_overrides ()
261
284
@@ -272,16 +295,20 @@ def _get_environment_from_api(self) -> EnvironmentModel:
272
295
return EnvironmentModel .model_validate (environment_data )
273
296
274
297
def _get_environment_flags_from_document (self ) -> Flags :
298
+ if self ._environment is None :
299
+ raise TypeError ("No environment present" )
275
300
return Flags .from_feature_state_models (
276
301
feature_states = engine .get_environment_feature_states (self ._environment ),
277
302
analytics_processor = self ._analytics_processor ,
278
303
default_flag_handler = self .default_flag_handler ,
279
304
)
280
305
281
306
def _get_identity_flags_from_document (
282
- self , identifier : str , traits : typing .Dict [str , typing . Any ]
307
+ self , identifier : str , traits : typing .Mapping [str , TraitValue ]
283
308
) -> Flags :
284
309
identity_model = self ._get_identity_model (identifier , ** traits )
310
+ if self ._environment is None :
311
+ raise TypeError ("No environment present" )
285
312
feature_states = engine .get_identity_feature_states (
286
313
self ._environment , identity_model
287
314
)
@@ -294,11 +321,11 @@ def _get_identity_flags_from_document(
294
321
295
322
def _get_environment_flags_from_api (self ) -> Flags :
296
323
try :
297
- api_flags = self . _get_json_response (
298
- url = self . environment_flags_url , method = "GET"
299
- )
324
+ json_response : typing . List [
325
+ typing . Mapping [ str , JsonType ]
326
+ ] = self . _get_json_response ( url = self . environment_flags_url , method = "GET" )
300
327
return Flags .from_api_flags (
301
- api_flags = api_flags ,
328
+ api_flags = json_response ,
302
329
analytics_processor = self ._analytics_processor ,
303
330
default_flag_handler = self .default_flag_handler ,
304
331
)
@@ -310,11 +337,13 @@ def _get_environment_flags_from_api(self) -> Flags:
310
337
raise
311
338
312
339
def _get_identity_flags_from_api (
313
- self , identifier : str , traits : typing .Dict [str , typing .Any ]
340
+ self , identifier : str , traits : typing .Mapping [str , typing .Any ]
314
341
) -> Flags :
315
342
try :
316
343
data = generate_identities_data (identifier , traits )
317
- json_response = self ._get_json_response (
344
+ json_response : typing .Dict [
345
+ str , typing .List [typing .Dict [str , JsonType ]]
346
+ ] = self ._get_json_response (
318
347
url = self .identities_url , method = "POST" , body = data
319
348
)
320
349
return Flags .from_api_flags (
@@ -329,7 +358,14 @@ def _get_identity_flags_from_api(
329
358
return Flags (default_flag_handler = self .default_flag_handler )
330
359
raise
331
360
332
- def _get_json_response (self , url : str , method : str , body : dict = None ):
361
+ def _get_json_response (
362
+ self ,
363
+ url : str ,
364
+ method : str ,
365
+ body : typing .Optional [
366
+ typing .Union [Identity , typing .Dict [str , JsonType ]]
367
+ ] = None ,
368
+ ) -> typing .Any :
333
369
try :
334
370
request_method = getattr (self .session , method .lower ())
335
371
response = request_method (
@@ -371,7 +407,7 @@ def _get_identity_model(
371
407
identity_traits = trait_models ,
372
408
)
373
409
374
- def __del__ (self ):
410
+ def __del__ (self ) -> None :
375
411
if hasattr (self , "environment_data_polling_manager_thread" ):
376
412
self .environment_data_polling_manager_thread .stop ()
377
413
0 commit comments