@@ -213,19 +213,15 @@ async def members(self:Guild, limit=100):
213213 data = (await self .client ._req ('GET' , f'/guilds/{ self .id } /members' , params = {'limit' : limit })).json ()
214214 return Members (Member (d , self .client ) for d in data )
215215
216- # %% ../nbs/00_core.ipynb #7f1f8399
216+ # %% ../nbs/00_core.ipynb #a301c00c
217217class GatewayClient :
218218 def __init__ (self , intents , client , token = None ):
219- self .intents = intents
220- self .dc = client
219+ self .intents ,self .dc = intents ,client
221220 self .token = token or os .environ ['DISCORD_BOT_TOKEN' ]
222- self .ws = None
223- self .hb_int = None
224- self .session_id = None
225- self .seq = None
221+ self .ws = self .hb_int = self .session_id = self .seq = None
226222 self .running = False
227- gw_info = httpx .get ('https://discord.com/api/v10/ gateway/bot' ,
228- headers = { 'Authorization' : f'Bot { self . token } ' }). json ( )
223+ gw_info = httpx .get (f' { client . base_url } / gateway/bot' , headers = { 'Authorization' : f'Bot { self . token } ' }). json ()
224+ if 'url' not in gw_info : raise ConnectionError ( f"Gateway auth failed: { gw_info . get ( 'message' , gw_info ) } " )
229225 self .url = f"{ gw_info ['url' ]} ?v=10&encoding=json"
230226 def __repr__ (self ): return f"GatewayClient({ self .intents = } , { self .url = } )"
231227
@@ -236,6 +232,8 @@ def __repr__(self): return f"Op(op={self.op}, d={self.d})"
236232 def identify (cls , token , intents ): return cls (op = 2 , d = AttrDict (token = token , intents = intents , properties = dict (os = 'linux' , browser = 'discord_wrapper' , device = 'discord_wrapper' )))
237233 @classmethod
238234 def heartbeat (cls , seq ): return cls (op = 1 , d = seq )
235+ @classmethod
236+ def resume (cls , token , session_id , seq ): return cls (op = 6 , d = AttrDict (token = token , session_id = session_id , seq = seq ))
239237
240238# %% ../nbs/00_core.ipynb #6932c93e
241239evt_typs = {'MESSAGE_CREATE' : Message ,
@@ -271,6 +269,7 @@ async def _connect(self:GatewayClient):
271269 await self .ws .send (Op .identify (self .token , self .intents ))
272270 rdy = Event (json .loads (await self .ws .recv ()), self .dc )
273271 self .session_id , self .user_id , self .seq = rdy .d ['session_id' ], rdy .d ['user' ]['id' ], rdy .seq
272+ self .resume_url = rdy .d .get ('resume_gateway_url' , self .url )
274273 print (f"Connected! Session: { self .session_id } , heartbeat: { self .hb_int } ms" )
275274 return rdy
276275
@@ -285,8 +284,15 @@ async def recv_evt(self:GatewayClient):
285284@patch
286285async def _listen (self :GatewayClient ):
287286 while self .running :
288- evt = await self .recv_evt ()
289- if evt .op == 0 and evt .type in getattr (self , 'handlers' , {}):
287+ try : evt = await self .recv_evt ()
288+ except Exception as e :
289+ print (f"Listen error: { e } " )
290+ if self .running : await self ._reconnect ()
291+ return
292+ if evt .op == 11 : self ._got_ack = True
293+ elif evt .op == 1 : await self .ws .send (Op .heartbeat (self .seq ))
294+ elif evt .op == 7 : await self ._reconnect ()
295+ elif evt .op == 0 and evt .type in getattr (self , 'handlers' , {}):
290296 asyncio .create_task (self .handlers [evt .type ](evt .d ))
291297
292298@patch
@@ -295,11 +301,29 @@ def on(self:GatewayClient, event_type, handler):
295301 self .handlers [event_type ] = handler
296302
297303# %% ../nbs/00_core.ipynb #3a138dab
304+ @patch
305+ async def _reconnect (self :GatewayClient ):
306+ print ("Reconnecting..." )
307+ if self .ws : await self .ws .close ()
308+ self .ws = await websockets .connect (self .resume_url )
309+ hello = json .loads (await self .ws .recv ())
310+ self .hb_int = hello ['d' ]['heartbeat_interval' ]
311+ await self .ws .send (Op .resume (self .token , self .session_id , self .seq ))
312+ self ._got_ack = True
313+ if hasattr (self , '_hb_task' ) and self ._hb_task : self ._hb_task .cancel ()
314+ self ._hb_task = asyncio .create_task (self ._hb ())
315+ print (f"Resumed session { self .session_id } " )
316+
298317@patch
299318async def _hb (self :GatewayClient ):
300319 await asyncio .sleep (self .hb_int / 1_000 * random .random ())
301320 while self .running :
302- await self .ws .send (Op .heartbeat (self .seq ))
321+ if hasattr (self , '_got_ack' ) and not self ._got_ack :
322+ print ("Missed heartbeat ACK — reconnecting..." )
323+ return await self ._reconnect ()
324+ self ._got_ack = False
325+ try : await self .ws .send (Op .heartbeat (self .seq ))
326+ except Exception as e : return print (f"Heartbeat send error: { e } " )
303327 await asyncio .sleep (self .hb_int / 1_000 )
304328
305329# %% ../nbs/00_core.ipynb #4afce159
0 commit comments