Skip to content

Commit 7ffb878

Browse files
authored
Merge pull request #6 from AnswerDotAI/reconnect
Add reconnection support to GatewayClient
2 parents c74b093 + d3383e5 commit 7ffb878

4 files changed

Lines changed: 193 additions & 103 deletions

File tree

.github/workflows/test.yaml

Lines changed: 0 additions & 7 deletions
This file was deleted.

cordslite/_modidx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
'cordslite.core.GatewayClient._connect': ('core.html#gatewayclient._connect', 'cordslite/core.py'),
5353
'cordslite.core.GatewayClient._hb': ('core.html#gatewayclient._hb', 'cordslite/core.py'),
5454
'cordslite.core.GatewayClient._listen': ('core.html#gatewayclient._listen', 'cordslite/core.py'),
55+
'cordslite.core.GatewayClient._reconnect': ('core.html#gatewayclient._reconnect', 'cordslite/core.py'),
5556
'cordslite.core.GatewayClient.on': ('core.html#gatewayclient.on', 'cordslite/core.py'),
5657
'cordslite.core.GatewayClient.recv_evt': ('core.html#gatewayclient.recv_evt', 'cordslite/core.py'),
5758
'cordslite.core.GatewayClient.start': ('core.html#gatewayclient.start', 'cordslite/core.py'),
@@ -77,6 +78,7 @@
7778
'cordslite.core.Op.__repr__': ('core.html#op.__repr__', 'cordslite/core.py'),
7879
'cordslite.core.Op.heartbeat': ('core.html#op.heartbeat', 'cordslite/core.py'),
7980
'cordslite.core.Op.identify': ('core.html#op.identify', 'cordslite/core.py'),
81+
'cordslite.core.Op.resume': ('core.html#op.resume', 'cordslite/core.py'),
8082
'cordslite.core.Op.select_protocol': ('core.html#op.select_protocol', 'cordslite/core.py'),
8183
'cordslite.core.Op.speaking': ('core.html#op.speaking', 'cordslite/core.py'),
8284
'cordslite.core.Op.voice_heartbeat': ('core.html#op.voice_heartbeat', 'cordslite/core.py'),

cordslite/core.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -213,19 +213,15 @@ async def members(self:Guild, limit=100):
213213
data = (await self.client._req('GET', f'/guilds/{self.id}/members', params={'limit': limit})).json()
214214
return Members(Member(d, self.client) for d in data)
215215

216-
# %% ../nbs/00_core.ipynb #7f1f8399
216+
# %% ../nbs/00_core.ipynb #a301c00c
217217
class GatewayClient:
218218
def __init__(self, intents, client, token=None):
219-
self.intents = intents
220-
self.dc = client
219+
self.intents,self.dc = intents,client
221220
self.token = token or os.environ['DISCORD_BOT_TOKEN']
222-
self.ws = None
223-
self.hb_int = None
224-
self.session_id = None
225-
self.seq = None
221+
self.ws = self.hb_int = self.session_id = self.seq = None
226222
self.running = False
227-
gw_info = httpx.get('https://discord.com/api/v10/gateway/bot',
228-
headers={'Authorization': f'Bot {self.token}'}).json()
223+
gw_info = httpx.get(f'{client.base_url}/gateway/bot', headers={'Authorization': f'Bot {self.token}'}).json()
224+
if 'url' not in gw_info: raise ConnectionError(f"Gateway auth failed: {gw_info.get('message', gw_info)}")
229225
self.url = f"{gw_info['url']}?v=10&encoding=json"
230226
def __repr__(self): return f"GatewayClient({self.intents=}, {self.url=})"
231227

@@ -236,6 +232,8 @@ def __repr__(self): return f"Op(op={self.op}, d={self.d})"
236232
def identify(cls, token, intents): return cls(op=2, d=AttrDict(token=token, intents=intents, properties=dict(os='linux', browser='discord_wrapper', device='discord_wrapper')))
237233
@classmethod
238234
def heartbeat(cls, seq): return cls(op=1, d=seq)
235+
@classmethod
236+
def resume(cls, token, session_id, seq): return cls(op=6, d=AttrDict(token=token, session_id=session_id, seq=seq))
239237

240238
# %% ../nbs/00_core.ipynb #6932c93e
241239
evt_typs = {'MESSAGE_CREATE': Message,
@@ -271,6 +269,7 @@ async def _connect(self:GatewayClient):
271269
await self.ws.send(Op.identify(self.token, self.intents))
272270
rdy = Event(json.loads(await self.ws.recv()), self.dc)
273271
self.session_id, self.user_id, self.seq = rdy.d['session_id'], rdy.d['user']['id'], rdy.seq
272+
self.resume_url = rdy.d.get('resume_gateway_url', self.url)
274273
print(f"Connected! Session: {self.session_id}, heartbeat: {self.hb_int}ms")
275274
return rdy
276275

@@ -285,8 +284,15 @@ async def recv_evt(self:GatewayClient):
285284
@patch
286285
async def _listen(self:GatewayClient):
287286
while self.running:
288-
evt = await self.recv_evt()
289-
if evt.op == 0 and evt.type in getattr(self, 'handlers', {}):
287+
try: evt = await self.recv_evt()
288+
except Exception as e:
289+
print(f"Listen error: {e}")
290+
if self.running: await self._reconnect()
291+
return
292+
if evt.op == 11: self._got_ack = True
293+
elif evt.op == 1: await self.ws.send(Op.heartbeat(self.seq))
294+
elif evt.op == 7: await self._reconnect()
295+
elif evt.op == 0 and evt.type in getattr(self, 'handlers', {}):
290296
asyncio.create_task(self.handlers[evt.type](evt.d))
291297

292298
@patch
@@ -295,11 +301,29 @@ def on(self:GatewayClient, event_type, handler):
295301
self.handlers[event_type] = handler
296302

297303
# %% ../nbs/00_core.ipynb #3a138dab
304+
@patch
305+
async def _reconnect(self:GatewayClient):
306+
print("Reconnecting...")
307+
if self.ws: await self.ws.close()
308+
self.ws = await websockets.connect(self.resume_url)
309+
hello = json.loads(await self.ws.recv())
310+
self.hb_int = hello['d']['heartbeat_interval']
311+
await self.ws.send(Op.resume(self.token, self.session_id, self.seq))
312+
self._got_ack = True
313+
if hasattr(self, '_hb_task') and self._hb_task: self._hb_task.cancel()
314+
self._hb_task = asyncio.create_task(self._hb())
315+
print(f"Resumed session {self.session_id}")
316+
298317
@patch
299318
async def _hb(self:GatewayClient):
300319
await asyncio.sleep(self.hb_int / 1_000 * random.random())
301320
while self.running:
302-
await self.ws.send(Op.heartbeat(self.seq))
321+
if hasattr(self, '_got_ack') and not self._got_ack:
322+
print("Missed heartbeat ACK — reconnecting...")
323+
return await self._reconnect()
324+
self._got_ack = False
325+
try: await self.ws.send(Op.heartbeat(self.seq))
326+
except Exception as e: return print(f"Heartbeat send error: {e}")
303327
await asyncio.sleep(self.hb_int / 1_000)
304328

305329
# %% ../nbs/00_core.ipynb #4afce159

0 commit comments

Comments
 (0)