@@ -11,10 +11,12 @@ def __init__(self):
11
11
try :
12
12
self .db_location = os .environ ["FASTAPI_SIMPLE_SECURITY_DB_LOCATION" ]
13
13
except KeyError :
14
- self .db_location = "app/ sqlite.db"
14
+ self .db_location = "sqlite.db"
15
15
16
16
try :
17
- self .expiration_limit = int (os .environ ["FAST_API_SIMPLE_SECURITY_AUTOMATIC_EXPIRATION" ])
17
+ self .expiration_limit = int (
18
+ os .environ ["FAST_API_SIMPLE_SECURITY_AUTOMATIC_EXPIRATION" ]
19
+ )
18
20
except KeyError :
19
21
self .expiration_limit = 15
20
22
@@ -51,7 +53,9 @@ def create_key(self, never_expire) -> str:
51
53
api_key ,
52
54
1 ,
53
55
1 if never_expire else 0 ,
54
- (datetime .utcnow () + timedelta (days = self .expiration_limit )).isoformat (timespec = "seconds" ),
56
+ (
57
+ datetime .utcnow () + timedelta (days = self .expiration_limit )
58
+ ).isoformat (timespec = "seconds" ),
55
59
None ,
56
60
0 ,
57
61
),
@@ -83,34 +87,47 @@ def renew_key(self, api_key: str, new_expiration_date: str) -> Optional[str]:
83
87
84
88
# Previously revoked key. Issue a text warning and reactivate it.
85
89
if response [0 ] == 0 :
86
- response_lines .append ("This API key was revoked and has been reactivated." )
90
+ response_lines .append (
91
+ "This API key was revoked and has been reactivated."
92
+ )
87
93
# Expired key. Issue a text warning and reactivate it.
88
- if (not response [3 ]) and (datetime .fromisoformat (response [2 ]) < datetime .utcnow ()):
94
+ if (not response [3 ]) and (
95
+ datetime .fromisoformat (response [2 ]) < datetime .utcnow ()
96
+ ):
89
97
response_lines .append ("This API key was expired and is now renewed." )
90
98
91
99
if not new_expiration_date :
92
- parsed_expiration_date = (datetime . utcnow () + timedelta ( days = self . expiration_limit )). isoformat (
93
- timespec = "seconds"
94
- )
100
+ parsed_expiration_date = (
101
+ datetime . utcnow () + timedelta ( days = self . expiration_limit )
102
+ ). isoformat ( timespec = "seconds" )
95
103
else :
96
104
try :
97
105
# We parse and re-write to the right timespec
98
- parsed_expiration_date = datetime .fromisoformat (new_expiration_date ).isoformat (timespec = "seconds" )
106
+ parsed_expiration_date = datetime .fromisoformat (
107
+ new_expiration_date
108
+ ).isoformat (timespec = "seconds" )
99
109
except ValueError :
100
- return "The expiration date could not be parsed. Please use ISO 8601."
110
+ return (
111
+ "The expiration date could not be parsed. Please use ISO 8601."
112
+ )
101
113
102
114
c .execute (
103
115
"""
104
116
UPDATE fastapi_simple_security
105
117
SET expiration_date = ?, is_active = 1
106
118
WHERE api_key = ?
107
119
""" ,
108
- (parsed_expiration_date , api_key ,),
120
+ (
121
+ parsed_expiration_date ,
122
+ api_key ,
123
+ ),
109
124
)
110
125
111
126
connection .commit ()
112
127
113
- response_lines .append (f"The new expiration date for the API key is { parsed_expiration_date } " )
128
+ response_lines .append (
129
+ f"The new expiration date for the API key is { parsed_expiration_date } "
130
+ )
114
131
115
132
return " " .join (response_lines )
116
133
@@ -162,15 +179,24 @@ def check_key(self, api_key: str) -> bool:
162
179
# Inactive
163
180
or response [0 ] != 1
164
181
# Expired key
165
- or ((not response [3 ]) and (datetime .fromisoformat (response [2 ]) < datetime .utcnow ()))
182
+ or (
183
+ (not response [3 ])
184
+ and (datetime .fromisoformat (response [2 ]) < datetime .utcnow ())
185
+ )
166
186
):
167
187
# The key is not valid
168
188
return False
169
189
else :
170
190
# The key is valid
171
191
172
192
# We run the logging in a separate thread as writing takes some time
173
- threading .Thread (target = self ._update_usage , args = (api_key , response [1 ],)).start ()
193
+ threading .Thread (
194
+ target = self ._update_usage ,
195
+ args = (
196
+ api_key ,
197
+ response [1 ],
198
+ ),
199
+ ).start ()
174
200
175
201
# We return directly
176
202
return True
@@ -186,7 +212,11 @@ def _update_usage(self, api_key: str, usage_count: int):
186
212
SET total_queries = ?, latest_query_date = ?
187
213
WHERE api_key = ?
188
214
""" ,
189
- (usage_count + 1 , datetime .utcnow ().isoformat (timespec = "seconds" ), api_key ),
215
+ (
216
+ usage_count + 1 ,
217
+ datetime .utcnow ().isoformat (timespec = "seconds" ),
218
+ api_key ,
219
+ ),
190
220
)
191
221
192
222
connection .commit ()
@@ -201,7 +231,6 @@ def get_usage_stats(self) -> List[Tuple[str, int, str, str, int]]:
201
231
with sqlite3 .connect (self .db_location ) as connection :
202
232
c = connection .cursor ()
203
233
204
- # TODO Add filtering somehow
205
234
c .execute (
206
235
"""
207
236
SELECT api_key, is_active, never_expire, expiration_date, latest_query_date, total_queries
0 commit comments