Skip to content

Commit c056feb

Browse files
authored
Fix shard runner info (#3345)
1 parent 5209741 commit c056feb

File tree

5 files changed

+40
-47
lines changed

5 files changed

+40
-47
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ collector = ["gateway"]
9898
# Enables the Framework trait which is an abstraction for old-style text commands.
9999
framework = ["gateway"]
100100
# Enables gateway support, which allows bots to listen for Discord events.
101-
gateway = ["model", "flate2"]
101+
gateway = ["model", "flate2", "dashmap"]
102102
# Enables HTTP, which enables bots to execute actions on Discord.
103103
http = ["dashmap", "mime_guess", "percent-encoding"]
104104
# Enables wrapper methods around HTTP requests on model types.

examples/e07_shard_manager/src/main.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
//!
88
//! This isn't particularly useful for small bots, but is useful for large bots that may need to
99
//! split load on separate VPSs or dedicated servers. Additionally, Discord requires that there be
10-
//! at least one shard for every
11-
//! 2500 guilds that a bot is on.
10+
//! at least one shard for every 2500 guilds that a bot is on.
1211
//!
1312
//! For the purposes of this example, we'll print the current statuses of the two shards to the
1413
//! terminal every 30 seconds. This includes the ID of the shard, the current connection stage,
@@ -60,22 +59,19 @@ async fn main() {
6059
let mut client =
6160
Client::builder(token, intents).event_handler(Handler).await.expect("Err creating client");
6261

63-
// Here we get a HashMap of of the shards' status that we move into a new thread. A separate
64-
// tokio task holds the ownership to each entry, so each one will require acquiring a lock
65-
// before reading.
66-
let runners = client.shard_manager.runner_info();
62+
// Here we get a DashMap of of the shards' status that we move into a new thread.
63+
let runners = client.shard_manager.runners.clone();
6764

6865
tokio::spawn(async move {
6966
loop {
7067
sleep(Duration::from_secs(30)).await;
7168

72-
for (id, runner) in &runners {
73-
if let Ok(runner) = runner.lock() {
74-
println!(
75-
"Shard ID {} is {} with a latency of {:?}",
76-
id, runner.stage, runner.latency,
77-
);
78-
}
69+
for entry in runners.iter() {
70+
let (id, (runner, _)) = entry.pair();
71+
println!(
72+
"Shard ID {} is {} with a latency of {:?}",
73+
id, runner.stage, runner.latency,
74+
);
7975
}
8076
}
8177
});

src/gateway/client/context.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
use std::sync::{Arc, Mutex};
1+
use std::sync::Arc;
22

3+
use dashmap::DashMap;
34
use futures::channel::mpsc::UnboundedSender as Sender;
45

56
#[cfg(feature = "cache")]
@@ -46,7 +47,8 @@ pub struct Context {
4647
pub http: Arc<Http>,
4748
#[cfg(feature = "cache")]
4849
pub cache: Arc<Cache>,
49-
pub runner_info: Arc<Mutex<ShardRunnerInfo>>,
50+
/// Metadata about the initialised shards, and their control channels.
51+
pub runners: Arc<DashMap<ShardId, (ShardRunnerInfo, Sender<ShardRunnerMessage>)>>,
5052
#[cfg(feature = "collector")]
5153
pub(crate) collectors: Arc<parking_lot::RwLock<Vec<CollectorCallback>>>,
5254
}

src/gateway/sharding/shard_manager.rs

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
use std::collections::HashMap;
21
use std::num::NonZeroU16;
2+
use std::sync::Arc;
33
#[cfg(feature = "framework")]
44
use std::sync::OnceLock;
5-
use std::sync::{Arc, Mutex};
65
use std::time::{Duration, Instant};
76

7+
use dashmap::DashMap;
88
use futures::StreamExt;
99
use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
1010
use tokio::time::{sleep, timeout};
@@ -67,7 +67,7 @@ pub struct ShardManager {
6767
///
6868
/// **Note**: It is highly recommended to not mutate this yourself unless you need to. Instead
6969
/// prefer to use methods on this struct that are provided where possible.
70-
pub runners: HashMap<ShardId, (Arc<Mutex<ShardRunnerInfo>>, Sender<ShardRunnerMessage>)>,
70+
pub runners: Arc<DashMap<ShardId, (ShardRunnerInfo, Sender<ShardRunnerMessage>)>>,
7171
/// A copy of the client's voice manager.
7272
#[cfg(feature = "voice")]
7373
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
@@ -103,7 +103,7 @@ impl ShardManager {
103103
framework: opt.framework,
104104
last_start: None,
105105
queue: ShardQueue::new(opt.max_concurrency),
106-
runners: HashMap::new(),
106+
runners: Arc::new(DashMap::new()),
107107
#[cfg(feature = "voice")]
108108
voice_manager: opt.voice_manager,
109109
ws_url: opt.ws_url,
@@ -187,7 +187,7 @@ impl ShardManager {
187187
pub fn restart(&mut self, shard_id: ShardId) {
188188
info!("Restarting shard {shard_id}");
189189

190-
if let Some((_, tx)) = self.runners.remove(&shard_id) {
190+
if let Some((_, (_, tx))) = self.runners.remove(&shard_id) {
191191
if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Restart) {
192192
warn!("Failed to send restart signal to shard {shard_id}: {why:?}");
193193
}
@@ -203,7 +203,7 @@ impl ShardManager {
203203
pub fn shutdown(&mut self, shard_id: ShardId, code: u16) {
204204
info!("Shutting down shard {}", shard_id);
205205

206-
if let Some((_, tx)) = self.runners.remove(&shard_id) {
206+
if let Some((_, (_, tx))) = self.runners.remove(&shard_id) {
207207
if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Shutdown(code)) {
208208
warn!("Failed to send shutdown signal to shard {shard_id}: {why:?}");
209209
}
@@ -263,18 +263,13 @@ impl ShardManager {
263263
let cloned_http = Arc::clone(&self.http);
264264
shard.set_application_id_callback(move |id| cloned_http.set_application_id(id));
265265

266-
let runner_info = Arc::new(Mutex::new(ShardRunnerInfo {
267-
latency: None,
268-
stage: ConnectionStage::Disconnected,
269-
}));
270-
271266
let mut runner = ShardRunner::new(ShardRunnerOptions {
272267
data: Arc::clone(&self.data),
273268
event_handler: self.event_handler.clone(),
274269
raw_event_handler: self.raw_event_handler.clone(),
275270
#[cfg(feature = "framework")]
276271
framework: self.framework.get().cloned(),
277-
runner_info: Arc::clone(&runner_info),
272+
runners: Arc::clone(&self.runners),
278273
manager_tx: self.manager_tx.clone(),
279274
#[cfg(feature = "voice")]
280275
voice_manager: self.voice_manager.clone(),
@@ -284,6 +279,11 @@ impl ShardManager {
284279
http: Arc::clone(&self.http),
285280
});
286281

282+
let runner_info = ShardRunnerInfo {
283+
latency: None,
284+
stage: ConnectionStage::Disconnected,
285+
};
286+
287287
self.runners.insert(shard_id, (runner_info, runner.runner_tx()));
288288

289289
spawn_named("shard_runner::run", async move { runner.run().await });
@@ -305,17 +305,7 @@ impl ShardManager {
305305
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
306306
#[must_use]
307307
pub fn shards_instantiated(&self) -> Vec<ShardId> {
308-
self.runners.keys().copied().collect()
309-
}
310-
311-
/// Returns the [`ShardRunnerInfo`] corresponding to each running shard.
312-
///
313-
/// Note that the shard runner also holds a copy of its info, which is why each entry is
314-
/// wrapped in `Arc<Mutex<T>>`.
315-
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
316-
#[must_use]
317-
pub fn runner_info(&self) -> HashMap<ShardId, Arc<Mutex<ShardRunnerInfo>>> {
318-
self.runners.iter().map(|(&id, (runner, _))| (id, Arc::clone(runner))).collect()
308+
self.runners.iter().map(|entries| *entries.key()).collect()
319309
}
320310

321311
/// Returns the gateway intents used for this gateway connection.
@@ -334,7 +324,8 @@ impl Drop for ShardManager {
334324
fn drop(&mut self) {
335325
info!("Shutting down all shards");
336326

337-
for (shard_id, (_, tx)) in self.runners.drain() {
327+
for entry in self.runners.iter() {
328+
let (shard_id, (_, tx)) = entry.pair();
338329
info!("Shutting down shard {}", shard_id);
339330
if let Err(why) = tx.unbounded_send(ShardRunnerMessage::Shutdown(1000)) {
340331
warn!("Failed to send shutdown signal to shard {shard_id}: {why:?}");

src/gateway/sharding/shard_runner.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
use std::sync::{Arc, Mutex};
1+
use std::sync::Arc;
22

3+
use dashmap::DashMap;
4+
use dashmap::try_result::TryResult;
35
use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender};
46
use tokio_tungstenite::tungstenite;
57
use tokio_tungstenite::tungstenite::error::Error as TungsteniteError;
@@ -28,7 +30,7 @@ use crate::model::event::Event;
2830
use crate::model::event::GatewayEvent;
2931
#[cfg(feature = "voice")]
3032
use crate::model::id::ChannelId;
31-
use crate::model::id::GuildId;
33+
use crate::model::id::{GuildId, ShardId};
3234
use crate::model::user::OnlineStatus;
3335

3436
/// A runner for managing a [`Shard`] and its respective WebSocket client.
@@ -38,7 +40,7 @@ pub struct ShardRunner {
3840
raw_event_handler: Option<Arc<dyn RawEventHandler>>,
3941
#[cfg(feature = "framework")]
4042
framework: Option<Arc<dyn Framework>>,
41-
runner_info: Arc<Mutex<ShardRunnerInfo>>,
43+
runners: Arc<DashMap<ShardId, (ShardRunnerInfo, Sender<ShardRunnerMessage>)>>,
4244
// channel to send messages back to the shard manager
4345
manager_tx: Sender<ShardManagerMessage>,
4446
// channel to receive messages from the shard manager and dispatches
@@ -66,7 +68,7 @@ impl ShardRunner {
6668
raw_event_handler: opt.raw_event_handler,
6769
#[cfg(feature = "framework")]
6870
framework: opt.framework,
69-
runner_info: opt.runner_info,
71+
runners: opt.runners,
7072
manager_tx: opt.manager_tx,
7173
runner_rx: rx,
7274
runner_tx: tx,
@@ -458,7 +460,9 @@ impl ShardRunner {
458460

459461
#[cfg_attr(feature = "tracing_instrument", instrument(skip(self)))]
460462
fn update_runner_info(&self) {
461-
if let Ok(mut runner_info) = self.runner_info.try_lock() {
463+
if let TryResult::Present(mut entry) = self.runners.try_get_mut(&self.shard.info.id) {
464+
let (runner_info, _) = entry.value_mut();
465+
462466
runner_info.latency = self.shard.latency();
463467
runner_info.stage = self.shard.stage();
464468
}
@@ -473,7 +477,7 @@ impl ShardRunner {
473477
http: Arc::clone(&self.http),
474478
#[cfg(feature = "cache")]
475479
cache: Arc::clone(&self.cache),
476-
runner_info: Arc::clone(&self.runner_info),
480+
runners: Arc::clone(&self.runners),
477481
#[cfg(feature = "collector")]
478482
collectors: Arc::clone(&self.collectors),
479483
}
@@ -491,7 +495,7 @@ pub struct ShardRunnerOptions {
491495
pub raw_event_handler: Option<Arc<dyn RawEventHandler>>,
492496
#[cfg(feature = "framework")]
493497
pub framework: Option<Arc<dyn Framework>>,
494-
pub runner_info: Arc<Mutex<ShardRunnerInfo>>,
498+
pub runners: Arc<DashMap<ShardId, (ShardRunnerInfo, Sender<ShardRunnerMessage>)>>,
495499
pub manager_tx: Sender<ShardManagerMessage>,
496500
pub shard: Shard,
497501
#[cfg(feature = "voice")]

0 commit comments

Comments
 (0)