Skip to content

Commit f698118

Browse files
authored
Merge pull request #1138 from rust-lang/careful-dropping
Cancel tokens and abort tasks on drop
2 parents 87a7451 + 036f2b8 commit f698118

File tree

8 files changed

+79
-37
lines changed

8 files changed

+79
-37
lines changed

compiler/base/orchestrator/Cargo.lock

Lines changed: 21 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

compiler/base/orchestrator/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ snafu = { version = "0.8.0", default-features = false, features = ["futures", "s
2020
strum_macros = { version = "0.26.1", default-features = false }
2121
tokio = { version = "1.28", default-features = false, features = ["fs", "io-std", "io-util", "macros", "process", "rt", "time", "sync"] }
2222
tokio-stream = { version = "0.1.14", default-features = false }
23-
tokio-util = { version = "0.7.8", default-features = false, features = ["io", "io-util"] }
23+
tokio-util = { version = "0.7.8", default-features = false, features = ["io", "io-util", "rt"] }
2424
toml = { version = "0.8.2", default-features = false, features = ["parse", "display"] }
2525
tracing = { version = "0.1.37", default-features = false, features = ["attributes"] }
2626

compiler/base/orchestrator/src/coordinator.rs

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ use tokio::{
1515
process::{Child, ChildStdin, ChildStdout, Command},
1616
select,
1717
sync::{mpsc, oneshot, OnceCell},
18-
task::{JoinHandle, JoinSet},
18+
task::JoinSet,
1919
time::{self, MissedTickBehavior},
2020
try_join,
2121
};
2222
use tokio_stream::wrappers::ReceiverStream;
23-
use tokio_util::{io::SyncIoBridge, sync::CancellationToken};
23+
use tokio_util::{io::SyncIoBridge, sync::CancellationToken, task::AbortOnDropHandle};
2424
use tracing::{error, info, info_span, instrument, trace, trace_span, warn, Instrument};
2525

2626
use crate::{
@@ -30,7 +30,7 @@ use crate::{
3030
ExecuteCommandResponse, JobId, Multiplexed, OneToOneResponse, ReadFileRequest,
3131
ReadFileResponse, SerializedError2, WorkerMessage, WriteFileRequest,
3232
},
33-
DropErrorDetailsExt,
33+
DropErrorDetailsExt, TaskAbortExt as _,
3434
};
3535

3636
pub mod limits;
@@ -1161,7 +1161,7 @@ impl Drop for CancelOnDrop {
11611161
#[derive(Debug)]
11621162
struct Container {
11631163
permit: Box<dyn ContainerPermit>,
1164-
task: JoinHandle<Result<()>>,
1164+
task: AbortOnDropHandle<Result<()>>,
11651165
kill_child: TerminateContainer,
11661166
modify_cargo_toml: ModifyCargoToml,
11671167
commander: Commander,
@@ -1186,7 +1186,8 @@ impl Container {
11861186

11871187
let (command_tx, command_rx) = mpsc::channel(8);
11881188
let demultiplex_task =
1189-
tokio::spawn(Commander::demultiplex(command_rx, from_worker_rx).in_current_span());
1189+
tokio::spawn(Commander::demultiplex(command_rx, from_worker_rx).in_current_span())
1190+
.abort_on_drop();
11901191

11911192
let task = tokio::spawn(
11921193
async move {
@@ -1216,7 +1217,8 @@ impl Container {
12161217
Ok(())
12171218
}
12181219
.in_current_span(),
1219-
);
1220+
)
1221+
.abort_on_drop();
12201222

12211223
let commander = Commander {
12221224
to_worker_tx,
@@ -1865,7 +1867,8 @@ impl Container {
18651867
}
18661868
}
18671869
.instrument(trace_span!("cargo task").or_current())
1868-
});
1870+
})
1871+
.abort_on_drop();
18691872

18701873
Ok(SpawnCargo {
18711874
permit,
@@ -2128,7 +2131,7 @@ pub enum DoRequestError {
21282131

21292132
struct SpawnCargo {
21302133
permit: Box<dyn ProcessPermit>,
2131-
task: JoinHandle<Result<ExecuteCommandResponse, SpawnCargoError>>,
2134+
task: AbortOnDropHandle<Result<ExecuteCommandResponse, SpawnCargoError>>,
21322135
stdin_tx: mpsc::Sender<String>,
21332136
stdout_rx: mpsc::Receiver<String>,
21342137
stderr_rx: mpsc::Receiver<String>,
@@ -2842,14 +2845,9 @@ fn spawn_io_queue(stdin: ChildStdin, stdout: ChildStdout, token: CancellationTok
28422845
let handle = tokio::runtime::Handle::current();
28432846

28442847
loop {
2845-
let coordinator_msg = handle.block_on(async {
2846-
select! {
2847-
() = token.cancelled() => None,
2848-
msg = rx.recv() => msg,
2849-
}
2850-
});
2848+
let coordinator_msg = handle.block_on(token.run_until_cancelled(rx.recv()));
28512849

2852-
let Some(coordinator_msg) = coordinator_msg else {
2850+
let Some(Some(coordinator_msg)) = coordinator_msg else {
28532851
break;
28542852
};
28552853

compiler/base/orchestrator/src/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,16 @@ pub mod coordinator;
44
mod message;
55
pub mod worker;
66

7+
pub trait TaskAbortExt<T>: Sized {
8+
fn abort_on_drop(self) -> tokio_util::task::AbortOnDropHandle<T>;
9+
}
10+
11+
impl<T> TaskAbortExt<T> for tokio::task::JoinHandle<T> {
12+
fn abort_on_drop(self) -> tokio_util::task::AbortOnDropHandle<T> {
13+
tokio_util::task::AbortOnDropHandle::new(self)
14+
}
15+
}
16+
717
pub trait DropErrorDetailsExt<T> {
818
fn drop_error_details(self) -> Result<T, tokio::sync::mpsc::error::SendError<()>>;
919
}

compiler/base/orchestrator/src/worker.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ use tokio::{
4646
sync::mpsc,
4747
task::JoinSet,
4848
};
49-
use tokio_util::sync::CancellationToken;
49+
use tokio_util::sync::{CancellationToken, DropGuard};
5050

5151
use crate::{
5252
bincode_input_closed,
@@ -55,7 +55,7 @@ use crate::{
5555
ExecuteCommandResponse, JobId, Multiplexed, ReadFileRequest, ReadFileResponse,
5656
SerializedError2, WorkerMessage, WriteFileRequest, WriteFileResponse,
5757
},
58-
DropErrorDetailsExt,
58+
DropErrorDetailsExt as _, TaskAbortExt as _,
5959
};
6060

6161
pub async fn listen(project_dir: impl Into<PathBuf>) -> Result<(), Error> {
@@ -66,14 +66,16 @@ pub async fn listen(project_dir: impl Into<PathBuf>) -> Result<(), Error> {
6666
let mut io_tasks = spawn_io_queue(coordinator_msg_tx, worker_msg_rx);
6767

6868
let (process_tx, process_rx) = mpsc::channel(8);
69-
let process_task = tokio::spawn(manage_processes(process_rx, project_dir.clone()));
69+
let process_task =
70+
tokio::spawn(manage_processes(process_rx, project_dir.clone())).abort_on_drop();
7071

7172
let handler_task = tokio::spawn(handle_coordinator_message(
7273
coordinator_msg_rx,
7374
worker_msg_tx,
7475
project_dir,
7576
process_tx,
76-
));
77+
))
78+
.abort_on_drop();
7779

7880
select! {
7981
Some(io_task) = io_tasks.join_next() => {
@@ -403,7 +405,7 @@ struct ProcessState {
403405
processes: JoinSet<Result<(), ProcessError>>,
404406
stdin_senders: HashMap<JobId, mpsc::Sender<String>>,
405407
stdin_shutdown_tx: mpsc::Sender<JobId>,
406-
kill_tokens: HashMap<JobId, CancellationToken>,
408+
kill_tokens: HashMap<JobId, DropGuard>,
407409
}
408410

409411
impl ProcessState {
@@ -456,7 +458,7 @@ impl ProcessState {
456458

457459
let task_set = stream_stdio(worker_msg_tx.clone(), stdin_rx, stdin, stdout, stderr);
458460

459-
self.kill_tokens.insert(job_id, token.clone());
461+
self.kill_tokens.insert(job_id, token.clone().drop_guard());
460462

461463
self.processes.spawn({
462464
let stdin_shutdown_tx = self.stdin_shutdown_tx.clone();
@@ -508,8 +510,8 @@ impl ProcessState {
508510
}
509511

510512
fn kill(&mut self, job_id: JobId) {
511-
if let Some(token) = self.kill_tokens.get(&job_id) {
512-
token.cancel();
513+
if let Some(token) = self.kill_tokens.remove(&job_id) {
514+
drop(token);
513515
}
514516
}
515517
}

ui/Cargo.lock

Lines changed: 10 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ui/src/server_axum/cache.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use futures::{
22
future::{Fuse, FusedFuture as _},
33
FutureExt as _,
44
};
5-
use orchestrator::DropErrorDetailsExt as _;
5+
use orchestrator::{DropErrorDetailsExt as _, TaskAbortExt as _};
66
use snafu::prelude::*;
77
use std::{
88
future::Future,
@@ -13,9 +13,9 @@ use std::{
1313
use tokio::{
1414
select,
1515
sync::{mpsc, oneshot},
16-
task::JoinHandle,
1716
time,
1817
};
18+
use tokio_util::task::AbortOnDropHandle;
1919
use tracing::warn;
2020

2121
const ONE_HUNDRED_MILLISECONDS: Duration = Duration::from_millis(100);
@@ -48,12 +48,12 @@ where
4848
{
4949
pub fn spawn<Fut>(
5050
f: impl FnOnce(mpsc::Receiver<CacheTaskItem<T, E>>) -> Fut,
51-
) -> (JoinHandle<()>, Self)
51+
) -> (AbortOnDropHandle<()>, Self)
5252
where
5353
Fut: Future<Output = ()> + Send + 'static,
5454
{
5555
let (tx, rx) = mpsc::channel(8);
56-
let task = tokio::spawn(f(rx));
56+
let task = tokio::spawn(f(rx)).abort_on_drop();
5757
let cache_tx = CacheTx(tx);
5858
(task, cache_tx)
5959
}
@@ -148,7 +148,8 @@ where
148148
let new_value = generator().await.map_err(CacheError::from);
149149
CacheInfo::build(new_value)
150150
}
151-
});
151+
})
152+
.abort_on_drop();
152153

153154
new_value.set(new_value_task.fuse());
154155
}

ui/src/server_axum/websocket.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use tokio::{
2929
task::{AbortHandle, JoinSet},
3030
time,
3131
};
32-
use tokio_util::sync::CancellationToken;
32+
use tokio_util::sync::{CancellationToken, DropGuard};
3333
use tracing::{error, info, instrument, warn, Instrument};
3434

3535
#[derive(Debug, serde::Deserialize, serde::Serialize)]
@@ -525,7 +525,7 @@ async fn handle_idle(manager: &mut CoordinatorManager, tx: &ResponseTx) -> Contr
525525
ControlFlow::Continue(())
526526
}
527527

528-
type ActiveExecutionInfo = (CancellationToken, Option<mpsc::Sender<String>>);
528+
type ActiveExecutionInfo = (DropGuard, Option<mpsc::Sender<String>>);
529529

530530
async fn handle_msg(
531531
txt: &str,
@@ -545,7 +545,10 @@ async fn handle_msg(
545545

546546
let guard = db.clone().start_with_guard("ws.Execute", txt).await;
547547

548-
active_executions.insert(meta.sequence_number, (token.clone(), Some(execution_tx)));
548+
active_executions.insert(
549+
meta.sequence_number,
550+
(token.clone().drop_guard(), Some(execution_tx)),
551+
);
549552

550553
// TODO: Should a single execute / build / etc. session have a timeout of some kind?
551554
let spawned = manager
@@ -602,11 +605,11 @@ async fn handle_msg(
602605
}
603606

604607
Ok(ExecuteKill { meta }) => {
605-
let Some((token, _)) = active_executions.get(&meta.sequence_number) else {
608+
let Some((token, _)) = active_executions.remove(&meta.sequence_number) else {
606609
warn!("Received kill for an execution that is no longer active");
607610
return;
608611
};
609-
token.cancel();
612+
drop(token);
610613
}
611614

612615
Err(e) => {

0 commit comments

Comments
 (0)