|
| 1 | +use std::borrow::Borrow; |
1 | 2 | use std::future::Future;
|
| 3 | +use std::sync::atomic::AtomicU64; |
| 4 | +use std::sync::atomic::Ordering::Relaxed; |
| 5 | +use std::sync::Arc; |
2 | 6 |
|
3 |
| -pub use oneshot::channel as cancelation; |
4 |
| -use tokio::sync::oneshot; |
| 7 | +use tokio::sync::Notify; |
5 | 8 |
|
6 |
| -pub type CancelTx = oneshot::Sender<()>; |
7 |
| -pub type CancelRx = oneshot::Receiver<()>; |
8 |
| - |
9 |
| -pub async fn cancelable_future<T>(future: impl Future<Output = T>, cancel: CancelRx) -> Option<T> { |
| 9 | +pub async fn cancelable_future<T>( |
| 10 | + future: impl Future<Output = T>, |
| 11 | + cancel: impl Borrow<TaskHandle>, |
| 12 | +) -> Option<T> { |
10 | 13 | tokio::select! {
|
11 | 14 | biased;
|
12 |
| - _ = cancel => { |
| 15 | + _ = cancel.borrow().canceled() => { |
13 | 16 | None
|
14 | 17 | }
|
15 | 18 | res = future => {
|
16 | 19 | Some(res)
|
17 | 20 | }
|
18 | 21 | }
|
19 | 22 | }
|
| 23 | + |
| 24 | +#[derive(Default, Debug)] |
| 25 | +struct Shared { |
| 26 | + state: AtomicU64, |
| 27 | + // `Notify` has some features that we don't really need here because it |
| 28 | + // supports waking single tasks (`notify_one`) and does its own (more |
| 29 | + // complicated) state tracking, we could reimplement the waiter linked list |
| 30 | + // with modest effort and reduce memory consumption by one word/8 bytes and |
| 31 | + // reduce code complexity/number of atomic operations. |
| 32 | + // |
| 33 | + // I don't think that's worth the complexity (unsafe code). |
| 34 | + // |
| 35 | + // if we only cared about async code then we could also only use a notify |
| 36 | + // (without the generation count), this would be equivalent (or maybe more |
| 37 | + // correct if we want to allow cloning the TX) but it would be extremly slow |
| 38 | + // to frequently check for cancelation from sync code |
| 39 | + notify: Notify, |
| 40 | +} |
| 41 | + |
| 42 | +impl Shared { |
| 43 | + fn generation(&self) -> u32 { |
| 44 | + self.state.load(Relaxed) as u32 |
| 45 | + } |
| 46 | + |
| 47 | + fn num_running(&self) -> u32 { |
| 48 | + (self.state.load(Relaxed) >> 32) as u32 |
| 49 | + } |
| 50 | + |
| 51 | + /// Increments the generation count and sets `num_running` |
| 52 | + /// to the provided value, this operation is not with |
| 53 | + /// regard to the generation counter (doesn't use `fetch_add`) |
| 54 | + /// so the calling code must ensure it cannot execute concurrently |
| 55 | + /// to maintain correctness (but not safety) |
| 56 | + fn inc_generation(&self, num_running: u32) -> (u32, u32) { |
| 57 | + let state = self.state.load(Relaxed); |
| 58 | + let generation = state as u32; |
| 59 | + let prev_running = (state >> 32) as u32; |
| 60 | + // no need to create a new generation if the refcount is zero (fastpath) |
| 61 | + if prev_running == 0 && num_running == 0 { |
| 62 | + return (generation, 0); |
| 63 | + } |
| 64 | + let new_generation = generation.saturating_add(1); |
| 65 | + self.state.store( |
| 66 | + new_generation as u64 | ((num_running as u64) << 32), |
| 67 | + Relaxed, |
| 68 | + ); |
| 69 | + self.notify.notify_waiters(); |
| 70 | + (new_generation, prev_running) |
| 71 | + } |
| 72 | + |
| 73 | + fn inc_running(&self, generation: u32) { |
| 74 | + let mut state = self.state.load(Relaxed); |
| 75 | + loop { |
| 76 | + let current_generation = state as u32; |
| 77 | + if current_generation != generation { |
| 78 | + break; |
| 79 | + } |
| 80 | + let off = 1 << 32; |
| 81 | + let res = self.state.compare_exchange_weak( |
| 82 | + state, |
| 83 | + state.saturating_add(off), |
| 84 | + Relaxed, |
| 85 | + Relaxed, |
| 86 | + ); |
| 87 | + match res { |
| 88 | + Ok(_) => break, |
| 89 | + Err(new_state) => state = new_state, |
| 90 | + } |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + fn dec_running(&self, generation: u32) { |
| 95 | + let mut state = self.state.load(Relaxed); |
| 96 | + loop { |
| 97 | + let current_generation = state as u32; |
| 98 | + if current_generation != generation { |
| 99 | + break; |
| 100 | + } |
| 101 | + let num_running = (state >> 32) as u32; |
| 102 | + // running can't be zero here, that would mean we miscounted somewhere |
| 103 | + assert_ne!(num_running, 0); |
| 104 | + let off = 1 << 32; |
| 105 | + let res = self |
| 106 | + .state |
| 107 | + .compare_exchange_weak(state, state - off, Relaxed, Relaxed); |
| 108 | + match res { |
| 109 | + Ok(_) => break, |
| 110 | + Err(new_state) => state = new_state, |
| 111 | + } |
| 112 | + } |
| 113 | + } |
| 114 | +} |
| 115 | + |
| 116 | +// This intentionally doesn't implement `Clone` and requires a mutable reference |
| 117 | +// for cancelation to avoid races (in inc_generation). |
| 118 | + |
| 119 | +/// A task controller allows managing a single subtask enabling the controller |
| 120 | +/// to cancel the subtask and to check whether it is still running. |
| 121 | +/// |
| 122 | +/// For efficiency reasons the controller can be reused/restarted, |
| 123 | +/// in that case the previous task is automatically canceled. |
| 124 | +/// |
| 125 | +/// If the controller is dropped, the subtasks are automatically canceled. |
| 126 | +#[derive(Default, Debug)] |
| 127 | +pub struct TaskController { |
| 128 | + shared: Arc<Shared>, |
| 129 | +} |
| 130 | + |
| 131 | +impl TaskController { |
| 132 | + pub fn new() -> Self { |
| 133 | + TaskController::default() |
| 134 | + } |
| 135 | + /// Cancels the active task (handle). |
| 136 | + /// |
| 137 | + /// Returns whether any tasks were still running before the cancelation. |
| 138 | + pub fn cancel(&mut self) -> bool { |
| 139 | + self.shared.inc_generation(0).1 != 0 |
| 140 | + } |
| 141 | + |
| 142 | + /// Checks whether there are any task handles |
| 143 | + /// that haven't been dropped (or canceled) yet. |
| 144 | + pub fn is_running(&self) -> bool { |
| 145 | + self.shared.num_running() != 0 |
| 146 | + } |
| 147 | + |
| 148 | + /// Starts a new task and cancels the previous task (handles). |
| 149 | + pub fn restart(&mut self) -> TaskHandle { |
| 150 | + TaskHandle { |
| 151 | + generation: self.shared.inc_generation(1).0, |
| 152 | + shared: self.shared.clone(), |
| 153 | + } |
| 154 | + } |
| 155 | +} |
| 156 | + |
| 157 | +impl Drop for TaskController { |
| 158 | + fn drop(&mut self) { |
| 159 | + self.cancel(); |
| 160 | + } |
| 161 | +} |
| 162 | + |
| 163 | +/// A handle that is used to link a task with a task controller. |
| 164 | +/// |
| 165 | +/// It can be used to cancel async futures very efficiently but can also be checked for |
| 166 | +/// cancelation very quickly (single atomic read) in blocking code. |
| 167 | +/// The handle can be cheaply cloned (reference counted). |
| 168 | +/// |
| 169 | +/// The TaskController can check whether a task is "running" by inspecting the |
| 170 | +/// refcount of the (current) tasks handles. Therefore, if that information |
| 171 | +/// is important, ensure that the handle is not dropped until the task fully |
| 172 | +/// completes. |
| 173 | +pub struct TaskHandle { |
| 174 | + shared: Arc<Shared>, |
| 175 | + generation: u32, |
| 176 | +} |
| 177 | + |
| 178 | +impl Clone for TaskHandle { |
| 179 | + fn clone(&self) -> Self { |
| 180 | + self.shared.inc_running(self.generation); |
| 181 | + TaskHandle { |
| 182 | + shared: self.shared.clone(), |
| 183 | + generation: self.generation, |
| 184 | + } |
| 185 | + } |
| 186 | +} |
| 187 | + |
| 188 | +impl Drop for TaskHandle { |
| 189 | + fn drop(&mut self) { |
| 190 | + self.shared.dec_running(self.generation); |
| 191 | + } |
| 192 | +} |
| 193 | + |
| 194 | +impl TaskHandle { |
| 195 | + /// Waits until [`TaskController::cancel`] is called for the corresponding |
| 196 | + /// [`TaskController`]. Immediately returns if `cancel` was already called since |
| 197 | + pub async fn canceled(&self) { |
| 198 | + let notified = self.shared.notify.notified(); |
| 199 | + if !self.is_canceled() { |
| 200 | + notified.await |
| 201 | + } |
| 202 | + } |
| 203 | + |
| 204 | + pub fn is_canceled(&self) -> bool { |
| 205 | + self.generation != self.shared.generation() |
| 206 | + } |
| 207 | +} |
| 208 | + |
| 209 | +#[cfg(test)] |
| 210 | +mod tests { |
| 211 | + use std::future::poll_fn; |
| 212 | + |
| 213 | + use futures_executor::block_on; |
| 214 | + use tokio::task::yield_now; |
| 215 | + |
| 216 | + use crate::{cancelable_future, TaskController}; |
| 217 | + |
| 218 | + #[test] |
| 219 | + fn immediate_cancel() { |
| 220 | + let mut controller = TaskController::new(); |
| 221 | + let handle = controller.restart(); |
| 222 | + controller.cancel(); |
| 223 | + assert!(handle.is_canceled()); |
| 224 | + controller.restart(); |
| 225 | + assert!(handle.is_canceled()); |
| 226 | + |
| 227 | + let res = block_on(cancelable_future( |
| 228 | + poll_fn(|_cx| std::task::Poll::Ready(())), |
| 229 | + handle, |
| 230 | + )); |
| 231 | + assert!(res.is_none()); |
| 232 | + } |
| 233 | + |
| 234 | + #[test] |
| 235 | + fn running_count() { |
| 236 | + let mut controller = TaskController::new(); |
| 237 | + let handle = controller.restart(); |
| 238 | + assert!(controller.is_running()); |
| 239 | + assert!(!handle.is_canceled()); |
| 240 | + drop(handle); |
| 241 | + assert!(!controller.is_running()); |
| 242 | + assert!(!controller.cancel()); |
| 243 | + let handle = controller.restart(); |
| 244 | + assert!(!handle.is_canceled()); |
| 245 | + assert!(controller.is_running()); |
| 246 | + let handle2 = handle.clone(); |
| 247 | + assert!(!handle.is_canceled()); |
| 248 | + assert!(controller.is_running()); |
| 249 | + drop(handle2); |
| 250 | + assert!(!handle.is_canceled()); |
| 251 | + assert!(controller.is_running()); |
| 252 | + assert!(controller.cancel()); |
| 253 | + assert!(handle.is_canceled()); |
| 254 | + assert!(!controller.is_running()); |
| 255 | + } |
| 256 | + |
| 257 | + #[test] |
| 258 | + fn no_cancel() { |
| 259 | + let mut controller = TaskController::new(); |
| 260 | + let handle = controller.restart(); |
| 261 | + assert!(!handle.is_canceled()); |
| 262 | + |
| 263 | + let res = block_on(cancelable_future( |
| 264 | + poll_fn(|_cx| std::task::Poll::Ready(())), |
| 265 | + handle, |
| 266 | + )); |
| 267 | + assert!(res.is_some()); |
| 268 | + } |
| 269 | + |
| 270 | + #[test] |
| 271 | + fn delayed_cancel() { |
| 272 | + let mut controller = TaskController::new(); |
| 273 | + let handle = controller.restart(); |
| 274 | + |
| 275 | + let mut hit = false; |
| 276 | + let res = block_on(cancelable_future( |
| 277 | + async { |
| 278 | + controller.cancel(); |
| 279 | + hit = true; |
| 280 | + yield_now().await; |
| 281 | + }, |
| 282 | + handle, |
| 283 | + )); |
| 284 | + assert!(res.is_none()); |
| 285 | + assert!(hit); |
| 286 | + } |
| 287 | +} |
0 commit comments