Skip to content

Commit 612bcc1

Browse files
committed
Simplify SseManager a bit
1 parent 16c7b67 commit 612bcc1

File tree

5 files changed

+153
-72
lines changed

5 files changed

+153
-72
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ tar = "0.4"
6060
tempfile = "3.17.0"
6161
thiserror = "2"
6262
tokio = { version = "1.42.0", features = ["fs", "macros"] }
63+
tokio-stream = "0.1.17"
6364
zip = { version = "2", default-features = false }
6465

6566
[features]

src/handlers.rs

Lines changed: 141 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
1-
use std::time::Duration;
1+
use std::{path::PathBuf, sync::Arc, time::Duration};
22

33
use actix_web::{HttpRequest, HttpResponse, Responder, http::header::ContentType, web};
4-
use actix_web_lab::sse;
4+
use actix_web_lab::{
5+
sse::{self, Sse},
6+
util::InfallibleStream,
7+
};
58
use bytesize::ByteSize;
69
use dav_server::{
710
DavConfig, DavHandler,
811
actix::{DavRequest, DavResponse},
912
};
10-
use log::{error, info, warn};
13+
use futures::future::join_all;
14+
use log::{error, info};
1115
use percent_encoding::percent_decode_str;
1216
use serde::{Deserialize, Serialize};
13-
use tokio::sync::Mutex;
14-
use tokio::task::JoinSet;
17+
use tokio::{sync::Mutex, time::interval};
18+
use tokio::{sync::mpsc, task::JoinSet};
19+
use tokio_stream::wrappers::ReceiverStream;
1520

1621
use crate::{config::MiniserveConfig, errors::RuntimeError};
1722
use crate::{file_op::recursive_dir_size, file_utils};
@@ -39,7 +44,71 @@ pub enum ApiCommand {
3944
CalculateDirSizes(Vec<String>),
4045
}
4146

42-
pub type DirSizeJoinSet = JoinSet<Result<DirSize, RuntimeError>>;
47+
#[derive(Debug)]
48+
pub struct DirSizeTasks {
49+
tasks: Arc<Mutex<JoinSet<Result<DirSize, RuntimeError>>>>,
50+
}
51+
52+
impl DirSizeTasks {
53+
pub fn new(show_exact_bytes: bool, sse_manager: web::Data<SseManager>) -> Self {
54+
let tasks = Arc::new(Mutex::new(JoinSet::<Result<DirSize, RuntimeError>>::new()));
55+
56+
// Spawn a task that will periodically check for finished calculations.
57+
let tasks_ = tasks.clone();
58+
actix_web::rt::spawn(async move {
59+
let mut interval = interval(Duration::from_millis(50));
60+
loop {
61+
// See whether there are any calculations finished and if so dispatch a message to
62+
// the SSE channels.
63+
match tasks_.lock().await.try_join_next() {
64+
Some(Ok(Ok(finished_task))) => {
65+
let dir_size = if show_exact_bytes {
66+
format!("{} B", finished_task.size)
67+
} else {
68+
ByteSize::b(finished_task.size).to_string()
69+
};
70+
71+
let dir_size_reply = DirSizeReply {
72+
web_path: finished_task.web_path,
73+
size: dir_size,
74+
};
75+
76+
let msg = sse::Data::new_json(dir_size_reply)
77+
.expect("Couldn't serialize as JSON")
78+
.event("dir-size");
79+
sse_manager.broadcast(msg).await
80+
}
81+
Some(Ok(Err(e))) => {
82+
error!("Some error during dir size calculation: {e}");
83+
break;
84+
}
85+
Some(Err(e)) => {
86+
error!("Some error during dir size calculation joining: {e}");
87+
break;
88+
}
89+
None => {
90+
// If there's nothing we'll just chill a sec
91+
interval.tick().await;
92+
}
93+
};
94+
}
95+
});
96+
97+
Self { tasks }
98+
}
99+
100+
pub async fn calc_dir_size(&self, web_path: String, path: PathBuf) {
101+
self.tasks.lock().await.spawn(async move {
102+
recursive_dir_size(&path).await.map(|dir_size| {
103+
info!("Finished dir size calculation for {path:?}");
104+
DirSize {
105+
web_path,
106+
size: dir_size,
107+
}
108+
})
109+
});
110+
}
111+
}
43112

44113
// Holds the result of a calculated dir size
45114
#[derive(Debug, Clone)]
@@ -61,67 +130,81 @@ pub struct DirSizeReply {
61130
pub size: String,
62131
}
63132

64-
// Reply to check whether the client is still connected
65-
//
66-
// If the client has disconnected, we can cancel all the tasks and save some compute.
67-
#[derive(Debug, Clone, Serialize)]
68-
pub struct HeartbeatReply;
133+
#[derive(Debug, Clone, Default)]
134+
pub struct SseManager {
135+
clients: Arc<Mutex<Vec<mpsc::Sender<sse::Event>>>>,
136+
}
69137

70-
/// SSE API route that yields an event stream that clients can subscribe to
71-
pub async fn api_sse(
72-
config: web::Data<MiniserveConfig>,
73-
task_joinset: web::Data<Mutex<DirSizeJoinSet>>,
74-
) -> impl Responder {
75-
let (sender, receiver) = tokio::sync::mpsc::channel(2);
76-
77-
actix_web::rt::spawn(async move {
78-
loop {
79-
let msg = match task_joinset.lock().await.try_join_next() {
80-
Some(Ok(Ok(finished_task))) => {
81-
let dir_size = if config.show_exact_bytes {
82-
format!("{} B", finished_task.size)
83-
} else {
84-
ByteSize::b(finished_task.size).to_string()
85-
};
138+
impl SseManager {
139+
/// Constructs new broadcaster and spawns ping loop.
140+
pub fn new() -> Self {
141+
let clients = Arc::new(Mutex::new(Vec::<mpsc::Sender<sse::Event>>::new()));
86142

87-
let dir_size_reply = DirSizeReply {
88-
web_path: finished_task.web_path,
89-
size: dir_size,
90-
};
143+
// Spawn a task that will periodically check for stale clients.
144+
let clients_ = clients.clone();
145+
actix_web::rt::spawn(async move {
146+
let mut interval = interval(Duration::from_secs(10));
91147

92-
sse::Data::new_json(dir_size_reply)
93-
.expect("Couldn't serialize as JSON")
94-
.event("dir-size")
95-
}
96-
Some(Ok(Err(e))) => {
97-
error!("Some error during dir size calculation: {e}");
98-
break;
99-
}
100-
Some(Err(e)) => {
101-
error!("Some error during dir size calculation joining: {e}");
102-
break;
148+
loop {
149+
interval.tick().await;
150+
151+
// Clean up stale clients
152+
let clients = clients_.lock().await.clone();
153+
let mut ok_clients = Vec::new();
154+
for client in clients {
155+
if client
156+
.send(sse::Event::Comment("ping".into()))
157+
.await
158+
.is_ok()
159+
{
160+
// Clients that are able to receive this are still connected and the rest
161+
// will be dropped.
162+
ok_clients.push(client.clone());
163+
} else {
164+
info!("Removing a stale client");
165+
}
103166
}
104-
None => sse::Data::new_json(HeartbeatReply)
105-
.expect("Couldn't serialize as JSON")
106-
.event("heartbeat"),
107-
};
108-
109-
if sender.send(msg.into()).await.is_err() {
110-
warn!("Client disconnected; could not send SSE message");
111-
break;
167+
*clients_.lock().await = ok_clients;
112168
}
169+
});
113170

114-
tokio::time::sleep(Duration::from_secs(1)).await;
115-
}
116-
});
171+
Self { clients }
172+
}
173+
174+
/// Registers client with broadcaster, returning an SSE response body.
175+
pub async fn new_client(&self) -> Sse<InfallibleStream<ReceiverStream<sse::Event>>> {
176+
let (tx, rx) = mpsc::channel(10);
177+
178+
tx.send(sse::Data::new("Connected to SSE event stream").into())
179+
.await
180+
.unwrap();
181+
182+
self.clients.lock().await.push(tx);
183+
184+
Sse::from_infallible_receiver(rx)
185+
}
117186

118-
sse::Sse::from_infallible_receiver(receiver).with_keep_alive(Duration::from_secs(3))
187+
/// Broadcasts `msg` to all clients.
188+
pub async fn broadcast(&self, msg: sse::Data) {
189+
let clients = self.clients.lock().await.clone();
190+
191+
let send_futures = clients.iter().map(|client| client.send(msg.clone().into()));
192+
193+
// Try to send to all clients, ignoring failures disconnected clients will get swept up by
194+
// `remove_stale_clients`.
195+
let _ = join_all(send_futures).await;
196+
}
197+
}
198+
199+
/// SSE API route that yields an event stream that clients can subscribe to
200+
pub async fn api_sse(sse_manager: web::Data<SseManager>) -> impl Responder {
201+
sse_manager.new_client().await
119202
}
120203

121204
async fn handle_dir_size_tasks(
122205
dirs: Vec<String>,
123206
config: &MiniserveConfig,
124-
task_joinset: web::Data<Mutex<DirSizeJoinSet>>,
207+
dir_size_tasks: web::Data<DirSizeTasks>,
125208
) -> Result<(), RuntimeError> {
126209
for dir in dirs {
127210
// The dir argument might be percent-encoded so let's decode it just in case.
@@ -140,16 +223,7 @@ async fn handle_dir_size_tasks(
140223
.join(sanitized_path);
141224
info!("Requested directory size for {full_path:?}");
142225

143-
let mut joinset = task_joinset.lock().await;
144-
joinset.spawn(async move {
145-
recursive_dir_size(&full_path).await.map(|dir_size| {
146-
info!("Finished dir size calculation for {full_path:?}");
147-
DirSize {
148-
web_path: dir,
149-
size: dir_size,
150-
}
151-
})
152-
});
226+
dir_size_tasks.calc_dir_size(dir, full_path).await;
153227
}
154228
Ok(())
155229
}
@@ -159,11 +233,11 @@ async fn handle_dir_size_tasks(
159233
pub async fn api_command(
160234
command: web::Json<ApiCommand>,
161235
config: web::Data<MiniserveConfig>,
162-
task_joinset: web::Data<Mutex<DirSizeJoinSet>>,
236+
dir_size_tasks: web::Data<DirSizeTasks>,
163237
) -> Result<impl Responder, RuntimeError> {
164238
match command.into_inner() {
165239
ApiCommand::CalculateDirSizes(dirs) => {
166-
handle_dir_size_tasks(dirs, &config, task_joinset).await?;
240+
handle_dir_size_tasks(dirs, &config, dir_size_tasks).await?;
167241
Ok("Directories are being calculated")
168242
}
169243
}

src/main.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ use colored::*;
1919
use dav_server::{DavHandler, DavMethodSet};
2020
use fast_qr::QRBuilder;
2121
use log::{error, warn};
22-
use tokio::sync::Mutex;
2322

2423
mod archive;
2524
mod args;
@@ -38,7 +37,8 @@ mod webdav_fs;
3837
use crate::config::MiniserveConfig;
3938
use crate::errors::StartupError;
4039
use crate::handlers::{
41-
DirSizeJoinSet, api_command, api_sse, css, dav_handler, error_404, favicon, healthcheck,
40+
DirSizeTasks, SseManager, api_command, api_sse, css, dav_handler, error_404, favicon,
41+
healthcheck,
4242
};
4343
use crate::webdav_fs::RestrictedFs;
4444

@@ -214,13 +214,18 @@ async fn run(miniserve_config: MiniserveConfig) -> Result<(), StartupError> {
214214
.join("\n"),
215215
);
216216

217-
let dir_size_join_set = web::Data::new(Mutex::new(DirSizeJoinSet::new()));
217+
let sse_manager = web::Data::new(SseManager::new());
218+
let dir_size_tasks = web::Data::new(DirSizeTasks::new(
219+
miniserve_config.show_exact_bytes,
220+
sse_manager.clone(),
221+
));
218222

219223
let srv = actix_web::HttpServer::new(move || {
220224
App::new()
221225
.wrap(configure_header(&inside_config.clone()))
222226
.app_data(web::Data::new(inside_config.clone()))
223-
.app_data(dir_size_join_set.clone())
227+
.app_data(dir_size_tasks.clone())
228+
.app_data(sse_manager.clone())
224229
.app_data(stylesheet.clone())
225230
.wrap(from_fn(errors::error_page_middleware))
226231
.wrap(middleware::Logger::default())

tests/api.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::fixtures::{DIRECTORIES, Error, TestServer, server};
1818
#[case(utf8_percent_encode(DIRECTORIES[2], NON_ALPHANUMERIC).to_string())]
1919
fn api_dir_size(#[case] dir: String, server: TestServer) -> Result<(), Error> {
2020
let mut command = HashMap::new();
21-
command.insert("DirSize", dir);
21+
command.insert("CalculateDirSizes", vec![dir]);
2222

2323
let resp = Client::new()
2424
.post(server.url().join(&format!("__miniserve_internal/api"))?)

0 commit comments

Comments
 (0)