1- use std:: collections:: HashMap ;
21use std:: num:: NonZeroU16 ;
2+ use std:: sync:: Arc ;
33#[ cfg( feature = "framework" ) ]
44use std:: sync:: OnceLock ;
5- use std:: sync:: { Arc , Mutex } ;
65use std:: time:: { Duration , Instant } ;
76
7+ use dashmap:: DashMap ;
88use futures:: StreamExt ;
99use futures:: channel:: mpsc:: { self , UnboundedReceiver as Receiver , UnboundedSender as Sender } ;
1010use tokio:: time:: { sleep, timeout} ;
@@ -67,7 +67,7 @@ pub struct ShardManager {
6767 ///
6868 /// **Note**: It is highly recommended to not mutate this yourself unless you need to. Instead
6969 /// prefer to use methods on this struct that are provided where possible.
70- pub runners : HashMap < ShardId , ( Arc < Mutex < ShardRunnerInfo > > , Sender < ShardRunnerMessage > ) > ,
70+ pub runners : Arc < DashMap < ShardId , ( ShardRunnerInfo , Sender < ShardRunnerMessage > ) > > ,
7171 /// A copy of the client's voice manager.
7272 #[ cfg( feature = "voice" ) ]
7373 pub voice_manager : Option < Arc < dyn VoiceGatewayManager + ' static > > ,
@@ -103,7 +103,7 @@ impl ShardManager {
103103 framework : opt. framework ,
104104 last_start : None ,
105105 queue : ShardQueue :: new ( opt. max_concurrency ) ,
106- runners : HashMap :: new ( ) ,
106+ runners : Arc :: new ( DashMap :: new ( ) ) ,
107107 #[ cfg( feature = "voice" ) ]
108108 voice_manager : opt. voice_manager ,
109109 ws_url : opt. ws_url ,
@@ -187,7 +187,7 @@ impl ShardManager {
187187 pub fn restart ( & mut self , shard_id : ShardId ) {
188188 info ! ( "Restarting shard {shard_id}" ) ;
189189
190- if let Some ( ( _, tx ) ) = self . runners . remove ( & shard_id) {
190+ if let Some ( ( _, ( _ , tx ) ) ) = self . runners . remove ( & shard_id) {
191191 if let Err ( why) = tx. unbounded_send ( ShardRunnerMessage :: Restart ) {
192192 warn ! ( "Failed to send restart signal to shard {shard_id}: {why:?}" ) ;
193193 }
@@ -203,7 +203,7 @@ impl ShardManager {
203203 pub fn shutdown ( & mut self , shard_id : ShardId , code : u16 ) {
204204 info ! ( "Shutting down shard {}" , shard_id) ;
205205
206- if let Some ( ( _, tx ) ) = self . runners . remove ( & shard_id) {
206+ if let Some ( ( _, ( _ , tx ) ) ) = self . runners . remove ( & shard_id) {
207207 if let Err ( why) = tx. unbounded_send ( ShardRunnerMessage :: Shutdown ( code) ) {
208208 warn ! ( "Failed to send shutdown signal to shard {shard_id}: {why:?}" ) ;
209209 }
@@ -263,18 +263,13 @@ impl ShardManager {
263263 let cloned_http = Arc :: clone ( & self . http ) ;
264264 shard. set_application_id_callback ( move |id| cloned_http. set_application_id ( id) ) ;
265265
266- let runner_info = Arc :: new ( Mutex :: new ( ShardRunnerInfo {
267- latency : None ,
268- stage : ConnectionStage :: Disconnected ,
269- } ) ) ;
270-
271266 let mut runner = ShardRunner :: new ( ShardRunnerOptions {
272267 data : Arc :: clone ( & self . data ) ,
273268 event_handler : self . event_handler . clone ( ) ,
274269 raw_event_handler : self . raw_event_handler . clone ( ) ,
275270 #[ cfg( feature = "framework" ) ]
276271 framework : self . framework . get ( ) . cloned ( ) ,
277- runner_info : Arc :: clone ( & runner_info ) ,
272+ runners : Arc :: clone ( & self . runners ) ,
278273 manager_tx : self . manager_tx . clone ( ) ,
279274 #[ cfg( feature = "voice" ) ]
280275 voice_manager : self . voice_manager . clone ( ) ,
@@ -284,6 +279,11 @@ impl ShardManager {
284279 http : Arc :: clone ( & self . http ) ,
285280 } ) ;
286281
282+ let runner_info = ShardRunnerInfo {
283+ latency : None ,
284+ stage : ConnectionStage :: Disconnected ,
285+ } ;
286+
287287 self . runners . insert ( shard_id, ( runner_info, runner. runner_tx ( ) ) ) ;
288288
289289 spawn_named ( "shard_runner::run" , async move { runner. run ( ) . await } ) ;
@@ -305,17 +305,7 @@ impl ShardManager {
305305 #[ cfg_attr( feature = "tracing_instrument" , instrument( skip( self ) ) ) ]
306306 #[ must_use]
307307 pub fn shards_instantiated ( & self ) -> Vec < ShardId > {
308- self . runners . keys ( ) . copied ( ) . collect ( )
309- }
310-
311- /// Returns the [`ShardRunnerInfo`] corresponding to each running shard.
312- ///
313- /// Note that the shard runner also holds a copy of its info, which is why each entry is
314- /// wrapped in `Arc<Mutex<T>>`.
315- #[ cfg_attr( feature = "tracing_instrument" , instrument( skip( self ) ) ) ]
316- #[ must_use]
317- pub fn runner_info ( & self ) -> HashMap < ShardId , Arc < Mutex < ShardRunnerInfo > > > {
318- self . runners . iter ( ) . map ( |( & id, ( runner, _) ) | ( id, Arc :: clone ( runner) ) ) . collect ( )
308+ self . runners . iter ( ) . map ( |entries| * entries. key ( ) ) . collect ( )
319309 }
320310
321311 /// Returns the gateway intents used for this gateway connection.
@@ -334,7 +324,8 @@ impl Drop for ShardManager {
334324 fn drop ( & mut self ) {
335325 info ! ( "Shutting down all shards" ) ;
336326
337- for ( shard_id, ( _, tx) ) in self . runners . drain ( ) {
327+ for entry in self . runners . iter ( ) {
328+ let ( shard_id, ( _, tx) ) = entry. pair ( ) ;
338329 info ! ( "Shutting down shard {}" , shard_id) ;
339330 if let Err ( why) = tx. unbounded_send ( ShardRunnerMessage :: Shutdown ( 1000 ) ) {
340331 warn ! ( "Failed to send shutdown signal to shard {shard_id}: {why:?}" ) ;
0 commit comments