-
-
Notifications
You must be signed in to change notification settings - Fork 77
Expand file tree
/
Copy pathsession.py
More file actions
238 lines (188 loc) · 8.68 KB
/
Copy pathsession.py
File metadata and controls
238 lines (188 loc) · 8.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import json
import typing as t
from datetime import timedelta
from urllib.parse import urlsplit, urlunsplit

from authlib.common.security import generate_token
from authlib.jose import jwt
from authlib.oauth2.rfc7636 import create_s256_code_challenge

from atproto_client.client.methods_mixin.time import TimeMethodsMixin
from atproto_client.client.session import (
    AsyncSessionChangeCallback,
    Session,
    SessionChangeCallback,
    SessionDispatcher,
    SessionEvent,
    SessionResponse,
    get_session_pds_endpoint,
)
from atproto_client.exceptions import LoginRequiredError
class SessionDispatchMixin:
    """Session-change callback registration for the synchronous client.

    The host class must provide ``self._session_dispatcher``.
    """

    def on_session_change(self, callback: 'SessionChangeCallback') -> 'SessionChangeCallback':
        """Register a callback for session change event.

        Args:
            callback: A callback to be called when the session changes.
                The callback must accept two arguments: event and session.

        Note:
            Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.

        Tip:
            You should save the session string to persistent storage
            on `SessionEvent.CREATE` and `SessionEvent.REFRESH` event.

        Example:
            >>> from atproto import Client, SessionEvent, Session
            >>>
            >>> client = Client()
            >>>
            >>> @client.on_session_change
            >>> def on_session_change(event: SessionEvent, session: Session):
            >>>     print(event, session)
            >>>
            >>> # or you can use this syntax:
            >>> # client.on_session_change(on_session_change)

        Returns:
            The same ``callback`` object, so that decorator usage does not
            rebind the decorated name to ``None``.
        """
        self._session_dispatcher.on_session_change(callback)
        return callback

    def _call_on_session_change_callbacks(self, event: 'SessionEvent') -> None:
        """Forward ``event`` to the session dispatcher's ``dispatch_session_change``."""
        self._session_dispatcher.dispatch_session_change(event)
class AsyncSessionDispatchMixin:
    """Session-change callback registration for the asynchronous client.

    The host class must provide ``self._session_dispatcher``.
    """

    def on_session_change(
        self, callback: t.Union['AsyncSessionChangeCallback', 'SessionChangeCallback']
    ) -> t.Union['AsyncSessionChangeCallback', 'SessionChangeCallback']:
        """Register a callback for session change event.

        Args:
            callback: A callback to be called when the session changes.
                The callback must accept two arguments: event and session.

        Note:
            Possible events: `SessionEvent.IMPORT`, `SessionEvent.CREATE`, `SessionEvent.REFRESH`.

        Note:
            You can register both synchronous and asynchronous callbacks.

        Tip:
            You should save the session string to persistent storage
            on `SessionEvent.CREATE` and `SessionEvent.REFRESH` event.

        Example:
            >>> from atproto import AsyncClient, SessionEvent, Session
            >>>
            >>> client = AsyncClient()
            >>>
            >>> @client.on_session_change
            >>> async def on_session_change(event: SessionEvent, session: Session):
            >>>     print(event, session)
            >>>
            >>> # or you can use this syntax:
            >>> # client.on_session_change(on_session_change)

        Returns:
            The same ``callback`` object, so that decorator usage does not
            rebind the decorated name to ``None``.
        """
        self._session_dispatcher.on_session_change(callback)
        return callback

    async def _call_on_session_change_callbacks(self, event: 'SessionEvent') -> None:
        """Await the session dispatcher's ``dispatch_session_change_async`` for ``event``."""
        await self._session_dispatcher.dispatch_session_change_async(event)
class SessionMethodsMixin(TimeMethodsMixin):
    """Session bookkeeping shared by the clients.

    Holds the current :obj:`Session`, decides when its access token should be
    refreshed, and builds the authorization headers for outgoing requests:
    a static bearer token, a DPoP-bound static token, or the session's access JWT.
    """

    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
        super().__init__(*args, **kwargs)
        self._session: t.Optional[Session] = None
        self._session_dispatcher = SessionDispatcher()

    def _register_auth_headers_source(self) -> None:
        # A callable (not a fixed dict) is registered, so headers can be built
        # from the session state at the time they are requested.
        self.request.add_additional_headers_source(self._get_access_auth_headers)

    def _should_refresh_session(self) -> bool:
        """Return whether the access token expires within the next 15 minutes.

        Sessions that use a static token (plain or DPoP-bound) are never refreshed.

        Raises:
            LoginRequiredError: There is no session, or its access JWT
                (or that JWT's ``exp`` claim) is missing.
        """
        session = self._session
        if not session:
            raise LoginRequiredError
        if session.static_access_token or session.static_dpop_token:
            return False
        if (
            session.access_jwt is None
            or session.access_jwt_payload is None
            or session.access_jwt_payload.exp is None
        ):
            raise LoginRequiredError

        expired_at = self.get_time_from_timestamp(session.access_jwt_payload.exp)
        # let's update the token a bit earlier than required
        return self.get_current_time() > expired_at - timedelta(minutes=15)

    def _set_or_update_session(self, session: SessionResponse, pds_endpoint: str) -> 'Session':
        """Store a new session or update the current one in place.

        On the first call, ``session`` is adopted as-is if it is already a
        :obj:`Session`; otherwise a :obj:`Session` is built from it. The result is
        handed to the dispatcher, and the auth header source is registered. Later
        calls copy tokens, identity and ``pds_endpoint`` onto the existing session.

        Args:
            session: Server response (or ready-made :obj:`Session`) to store.
            pds_endpoint: PDS URL the client is now talking to.

        Returns:
            The current :obj:`Session`.
        """
        if self._session:
            self._session.access_jwt = session.access_jwt
            self._session.refresh_jwt = session.refresh_jwt
            self._session.did = session.did
            self._session.handle = session.handle
            self._session.pds_endpoint = pds_endpoint
            return self._session

        if isinstance(session, Session):
            self._session = session
            # Keep an adopted session consistent with the endpoint actually in use,
            # so that export() carries it; an explicitly provided value is kept.
            if not self._session.pds_endpoint:
                self._session.pds_endpoint = pds_endpoint
        else:
            self._session = Session(
                access_jwt=session.access_jwt,
                refresh_jwt=session.refresh_jwt,
                did=session.did,
                handle=session.handle,
                pds_endpoint=pds_endpoint,
            )

        self._session_dispatcher.set_session(self._session)
        self._register_auth_headers_source()
        return self._session

    def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session:
        """Point the client at the session's PDS and store/update the session.

        Falls back to ``current_pds`` when no PDS endpoint can be derived from ``session``.
        """
        pds_endpoint = get_session_pds_endpoint(session)
        if not pds_endpoint:
            # current_pds ends with xrpc endpoint, but this is not a problem
            # overhead is only 4-5 symbols in the exported session string
            pds_endpoint = current_pds

        self._update_pds_endpoint(pds_endpoint)
        return self._set_or_update_session(session, pds_endpoint)

    def _get_access_auth_headers(self, *args: t.Any, **kwargs: t.Any) -> t.Dict[str, str]:
        """Build the authorization headers for an outgoing request.

        Precedence: static bearer token, then DPoP-bound static token (only when a
        DPoP key is also present), then the session's access JWT.

        Args:
            *args: Ignored.
            **kwargs: ``method`` and ``url`` of the request; read only for DPoP.

        Returns:
            Header name to value; empty when there is no session.
        """
        session = self._session
        if not session:
            return {}

        if session.static_access_token is not None:
            return {'Authorization': f'Bearer {session.static_access_token}'}

        if session.static_dpop_token is not None and session.static_dpop_jwk is not None:
            return {
                'Authorization': f'DPoP {session.static_dpop_token}',
                'DPoP': self._create_dpop_proof(kwargs.get('method', ''), kwargs.get('url', '')),
            }

        return {'Authorization': f'Bearer {session.access_jwt}'}

    def _create_dpop_proof(self, method: str, url: str) -> str:
        """Sign a DPoP proof JWT for one request using the session's static DPoP key.

        Args:
            method: HTTP method of the request (``htm`` claim).
            url: Target URL of the request; query and fragment are dropped for ``htu``.

        Returns:
            The encoded proof, suitable for the ``DPoP`` header.
        """
        session = self._session
        # only the public half of the key is embedded in the proof header
        public_jwk = json.loads(session.static_dpop_jwk.as_json(is_private=False))

        # RFC 9449 section 4.2: htu is the target URI without query and fragment parts
        if url:
            scheme, netloc, path, _query, _fragment = urlsplit(str(url))
            htu = urlunsplit((scheme, netloc, path, '', ''))
        else:
            htu = ''

        now = int(self.get_current_time().timestamp())
        claims = {
            'iss': session.static_dpop_issuer,
            'iat': now,
            'exp': now + 10,  # short-lived: valid for 10 seconds
            'jti': generate_token(),  # unique id per proof
            'htm': method,
            'htu': htu,
            # binds the proof to the access token via its S256 hash
            'ath': create_s256_code_challenge(session.static_dpop_token),
        }
        if session.static_dpop_nonce is not None:
            claims['nonce'] = session.static_dpop_nonce

        # NOTE(review): alg is fixed to ES256, so static_dpop_jwk is assumed to be a P-256 EC key
        header = {'typ': 'dpop+jwt', 'alg': 'ES256', 'jwk': public_jwk}
        return jwt.encode(header, claims, session.static_dpop_jwk).decode('utf-8')

    def _get_refresh_auth_headers(self) -> t.Dict[str, str]:
        """Return the bearer header carrying the refresh JWT; empty when there is no session."""
        if not self._session:
            return {}

        return {'Authorization': f'Bearer {self._session.refresh_jwt}'}

    def _update_pds_endpoint(self, pds_endpoint: str) -> None:
        """Switch the client's base URL to ``pds_endpoint``."""
        self.update_base_url(pds_endpoint)

    def export_session_string(self) -> str:
        """Export session string.

        Note:
            This method is useful for storing the session and reusing it later.

        Warning:
            You should use it if you create client instances often, because
            the server rate-limits `createSession`.
            Rate limited by handle:
            30/5 min, 300/day.

        Attention:
            You must export the session at the end of the Client's life cycle!
            Alternatively, you can subscribe to the session change event.
            Use :py:attr:`~on_session_change` to register a handler.

        Example:
            >>> from atproto import Client
            >>> # the first time login with login and password
            >>> client = Client()
            >>> client.login('login', 'password')
            >>> session_string = client.export_session_string()
            >>> # store session_string somewhere.
            >>> # for example, in env and next time use it for login
            >>> client2 = Client()
            >>> client2.login(session_string=session_string)

        Returns:
            :obj:`str`: Session string.

        Raises:
            LoginRequiredError: There is no session to export.
        """
        if not self._session:
            raise LoginRequiredError

        return self._session.export()