diff --git a/src/storage.rs b/src/storage.rs index 10341e0b8..27f1f7461 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,3 +1,4 @@ +//! Public API facades for the implementation details of [`Zalsa`] and [`ZalsaLocal`]. use std::{marker::PhantomData, panic::RefUnwindSafe, sync::Arc}; use parking_lot::{Condvar, Mutex}; @@ -8,6 +9,55 @@ use crate::{ Database, Event, EventKind, }; +/// A handle to non-local database state. +pub struct StorageHandle { + // Note: Drop order is important, zalsa_impl needs to drop before coordinate + /// Reference to the database. + zalsa_impl: Arc, + + // Note: Drop order is important, coordinate needs to drop after zalsa_impl + /// Coordination data for cancellation of other handles when `zalsa_mut` is called. + /// This could be stored in Zalsa but it makes things marginally cleaner to keep it separate. + coordinate: CoordinateDrop, + + /// We store references to `Db` + phantom: PhantomData Db>, +} + +impl Clone for StorageHandle { + fn clone(&self) -> Self { + *self.coordinate.clones.lock() += 1; + + Self { + zalsa_impl: self.zalsa_impl.clone(), + coordinate: CoordinateDrop(Arc::clone(&self.coordinate)), + phantom: PhantomData, + } + } +} + +impl Default for StorageHandle { + fn default() -> Self { + Self { + zalsa_impl: Arc::new(Zalsa::new::()), + coordinate: CoordinateDrop(Arc::new(Coordinate { + clones: Mutex::new(1), + cvar: Default::default(), + })), + phantom: PhantomData, + } + } +} + +impl StorageHandle { + pub fn into_storage(self) -> Storage { + Storage { + handle: self, + zalsa_local: ZalsaLocal::new(), + } + } +} + /// Access the "storage" of a Salsa database: this is an internal plumbing trait /// automatically implemented by `#[salsa::db]` applied to a struct. /// @@ -20,24 +70,14 @@ pub unsafe trait HasStorage: Database + Clone + Sized { fn storage_mut(&mut self) -> &mut Storage; } -/// Concrete implementation of the [`Database`][] trait. -/// Takes an optional type parameter `U` that allows you to thread your own data. -pub struct Storage { - // Note: Drop order is important, zalsa_impl needs to drop before coordinate - /// Reference to the database. - zalsa_impl: Arc, - - // Note: Drop order is important, coordinate needs to drop after zalsa_impl - /// Coordination data for cancellation of other handles when `zalsa_mut` is called. - /// This could be stored in Zalsa but it makes things marginally cleaner to keep it separate. - coordinate: CoordinateDrop, +/// Concrete implementation of the [`Database`] trait with local state that can be used to drive computations. +pub struct Storage { + handle: StorageHandle, /// Per-thread state zalsa_local: zalsa_local::ZalsaLocal, - - /// We store references to `Db` - phantom: PhantomData Db>, } + struct Coordinate { /// Counter of the number of clones of actor. Begins at 1. /// Incremented when cloned, decremented when dropped. @@ -45,21 +85,30 @@ struct Coordinate { cvar: Condvar, } +// We cannot panic while holding a lock to `clones: Mutex` and therefore we cannot enter an +// inconsistent state. +impl RefUnwindSafe for Coordinate {} + impl Default for Storage { fn default() -> Self { Self { - zalsa_impl: Arc::new(Zalsa::new::()), - coordinate: CoordinateDrop(Arc::new(Coordinate { - clones: Mutex::new(1), - cvar: Default::default(), - })), + handle: StorageHandle::default(), zalsa_local: ZalsaLocal::new(), - phantom: PhantomData, } } } impl Storage { + /// Discard the local state of this handle, turning it into a [`StorageHandle`] that is [`Sync`] + /// and [`std::panic::UnwindSafe`]. + pub fn into_zalsa_handle(self) -> StorageHandle { + let Storage { + handle, + zalsa_local: _, + } = self; + handle + } + // ANCHOR: cancel_other_workers /// Sets cancellation flag and blocks until all other workers with access /// to this storage have completed. @@ -67,13 +116,13 @@ impl Storage { /// This could deadlock if there is a single worker with two handles to the /// same database! fn cancel_others(&self, db: &Db) { - self.zalsa_impl.set_cancellation_flag(); + self.handle.zalsa_impl.set_cancellation_flag(); db.salsa_event(&|| Event::new(EventKind::DidSetCancellationFlag)); - let mut clones = self.coordinate.clones.lock(); + let mut clones = self.handle.coordinate.clones.lock(); while *clones != 1 { - self.coordinate.cvar.wait(&mut clones); + self.handle.coordinate.cvar.wait(&mut clones); } } // ANCHOR_END: cancel_other_workers @@ -81,7 +130,7 @@ impl Storage { unsafe impl ZalsaDatabase for T { fn zalsa(&self) -> &Zalsa { - &self.storage().zalsa_impl + &self.storage().handle.zalsa_impl } fn zalsa_mut(&mut self) -> &mut Zalsa { @@ -89,7 +138,7 @@ unsafe impl ZalsaDatabase for T { let storage = self.storage_mut(); // The ref count on the `Arc` should now be 1 - let zalsa_mut = Arc::get_mut(&mut storage.zalsa_impl).unwrap(); + let zalsa_mut = Arc::get_mut(&mut storage.handle.zalsa_impl).unwrap(); zalsa_mut.new_revision(); zalsa_mut } @@ -103,17 +152,11 @@ unsafe impl ZalsaDatabase for T { } } -impl RefUnwindSafe for Storage {} - impl Clone for Storage { fn clone(&self) -> Self { - *self.coordinate.clones.lock() += 1; - Self { - zalsa_impl: self.zalsa_impl.clone(), - coordinate: CoordinateDrop(Arc::clone(&self.coordinate)), + handle: self.handle.clone(), zalsa_local: ZalsaLocal::new(), - phantom: PhantomData, } } } diff --git a/src/zalsa.rs b/src/zalsa.rs index cfd775d0c..31134861c 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -3,6 +3,7 @@ use parking_lot::{Mutex, RwLock}; use rustc_hash::FxHashMap; use std::any::{Any, TypeId}; use std::marker::PhantomData; +use std::panic::RefUnwindSafe; use std::thread::ThreadId; use crate::cycle::CycleRecoveryStrategy; @@ -149,6 +150,13 @@ pub struct Zalsa { runtime: Runtime, } +// Our fields locked behind Mutices and RwLocks cannot enter an inconsistent state due to panics +// as they are all merely ID mappings with the exception of the `Runtime::dependency_graph`. +// `Runtime::dependency_graph` does not invoke user queries though and as such will not arbitrarily +// panic. The only way it may panic is by failing one of its asserts in which case we are already +// in a broken state anyways. +impl RefUnwindSafe for Zalsa {} + impl Zalsa { pub(crate) fn new() -> Self { Self { diff --git a/tests/check_auto_traits.rs b/tests/check_auto_traits.rs new file mode 100644 index 000000000..6e9c62c62 --- /dev/null +++ b/tests/check_auto_traits.rs @@ -0,0 +1,50 @@ +//! Test that auto trait impls exist as expected. + +use std::panic::UnwindSafe; + +use salsa::Database; +use test_log::test; + +#[salsa::input] +struct MyInput { + field: String, +} + +#[salsa::tracked] +struct MyTracked<'db> { + field: MyInterned<'db>, +} + +#[salsa::interned] +struct MyInterned<'db> { + field: String, +} + +#[salsa::tracked] +fn test(db: &dyn Database, input: MyInput) { + let input = is_send(is_sync(input)); + let interned = is_send(is_sync(MyInterned::new(db, input.field(db).clone()))); + let _tracked_struct = is_send(is_sync(MyTracked::new(db, interned))); +} + +fn is_send(t: T) -> T { + t +} + +fn is_sync(t: T) -> T { + t +} + +fn is_unwind_safe(t: T) -> T { + t +} + +#[test] +fn execute() { + let db = is_send(salsa::DatabaseImpl::new()); + let _handle = is_send(is_sync(is_unwind_safe( + db.storage().clone().into_zalsa_handle(), + ))); + let input = MyInput::new(&db, "Hello".to_string()); + test(&db, input); +} diff --git a/tests/is_send_sync.rs b/tests/is_send_sync.rs deleted file mode 100644 index 6ada1bacc..000000000 --- a/tests/is_send_sync.rs +++ /dev/null @@ -1,38 +0,0 @@ -//! Test that a setting a field on a `#[salsa::input]` -//! overwrites and returns the old value. - -use salsa::Database; -use test_log::test; - -#[salsa::input] -struct MyInput { - field: String, -} - -#[salsa::tracked] -struct MyTracked<'db> { - field: MyInterned<'db>, -} - -#[salsa::interned] -struct MyInterned<'db> { - field: String, -} - -#[salsa::tracked] -fn test(db: &dyn Database, input: MyInput) { - let input = is_send_sync(input); - let interned = is_send_sync(MyInterned::new(db, input.field(db).clone())); - let _tracked_struct = is_send_sync(MyTracked::new(db, interned)); -} - -fn is_send_sync(t: T) -> T { - t -} - -#[test] -fn execute() { - let db = salsa::DatabaseImpl::new(); - let input = MyInput::new(&db, "Hello".to_string()); - test(&db, input); -}