diff --git a/Cargo.toml b/Cargo.toml index c78fd2f93..7a4bb19bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,6 @@ crossbeam-channel = "0.5.14" name = "compare" harness = false - [[bench]] name = "incremental" harness = false @@ -56,5 +55,9 @@ harness = false name = "accumulator" harness = false +[[bench]] +name = "dataflow" +harness = false + [workspace] members = ["components/salsa-macro-rules", "components/salsa-macros"] diff --git a/benches/dataflow.rs b/benches/dataflow.rs new file mode 100644 index 000000000..d535046e9 --- /dev/null +++ b/benches/dataflow.rs @@ -0,0 +1,170 @@ +//! Benchmark for fixpoint iteration cycle resolution. +//! +//! This benchmark simulates a (very simplified) version of a real dataflow analysis using fixpoint +//! iteration. +use codspeed_criterion_compat::{criterion_group, criterion_main, BatchSize, Criterion}; +use salsa::{CycleRecoveryAction, Database as Db, Setter}; +use std::collections::BTreeSet; +use std::iter::IntoIterator; + +/// A Use of a symbol. +#[salsa::input] +struct Use { + reaching_definitions: Vec, +} + +/// A Definition of a symbol, either of the form `base + increment` or `0 + increment`. 
+#[salsa::input] +struct Definition { + base: Option, + increment: usize, +} + +#[derive(Eq, PartialEq, Clone, Debug, salsa::Update)] +enum Type { + Bottom, + Values(Box<[usize]>), + Top, +} + +impl Type { + fn join(tys: impl IntoIterator) -> Type { + let mut result = Type::Bottom; + for ty in tys.into_iter() { + result = match (result, ty) { + (result, Type::Bottom) => result, + (_, Type::Top) => Type::Top, + (Type::Top, _) => Type::Top, + (Type::Bottom, ty) => ty, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend(a_ints); + set.extend(b_ints); + Type::Values(set.into_iter().collect()) + } + } + } + result + } +} + +#[salsa::tracked(cycle_fn=use_cycle_recover, cycle_initial=use_cycle_initial)] +fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { + let defs = u.reaching_definitions(db); + match defs[..] { + [] => Type::Bottom, + [def] => infer_definition(db, def), + _ => Type::join(defs.iter().map(|&def| infer_definition(db, def))), + } +} + +#[salsa::tracked(cycle_fn=def_cycle_recover, cycle_initial=def_cycle_initial)] +fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { + let increment_ty = Type::Values(Box::from([def.increment(db)])); + if let Some(base) = def.base(db) { + let base_ty = infer_use(db, base); + add(&base_ty, &increment_ty) + } else { + increment_ty + } +} + +fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { + Type::Bottom +} + +fn def_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _def: Definition, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { + Type::Bottom +} + +fn use_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _use: Use, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction { + match value { + Type::Bottom => CycleRecoveryAction::Iterate, + Type::Values(_) => { + if count > 4 { + 
CycleRecoveryAction::Fallback(Type::Top) + } else { + CycleRecoveryAction::Iterate + } + } + Type::Top => CycleRecoveryAction::Iterate, + } +} + +fn add(a: &Type, b: &Type) -> Type { + match (a, b) { + (Type::Bottom, _) | (_, Type::Bottom) => Type::Bottom, + (Type::Top, _) | (_, Type::Top) => Type::Top, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend( + a_ints + .into_iter() + .flat_map(|a| b_ints.into_iter().map(move |b| a + b)), + ); + Type::Values(set.into_iter().collect()) + } + } +} + +fn dataflow(criterion: &mut Criterion) { + criterion.bench_function("converge_diverge", |b| { + b.iter_batched_ref( + || { + let mut db = salsa::DatabaseImpl::new(); + + let defx0 = Definition::new(&db, None, 0); + let defy0 = Definition::new(&db, None, 0); + let defx1 = Definition::new(&db, None, 0); + let defy1 = Definition::new(&db, None, 0); + let use_x = Use::new(&db, vec![defx0, defx1]); + let use_y = Use::new(&db, vec![defy0, defy1]); + defx1.set_base(&mut db).to(Some(use_y)); + defy1.set_base(&mut db).to(Some(use_x)); + + // prewarm cache + let _ = infer_use(&db, use_x); + let _ = infer_use(&db, use_y); + + (db, defx1, use_x, use_y) + }, + |(db, defx1, use_x, use_y)| { + // Set the increment on x to 0. + defx1.set_increment(db).to(0); + + // Both symbols converge on 0. + assert_eq!(infer_use(db, *use_x), Type::Values(Box::from([0]))); + assert_eq!(infer_use(db, *use_y), Type::Values(Box::from([0]))); + + // Set the increment on x to 1. + defx1.set_increment(db).to(1); + + // Now the loop diverges and we fall back to Top. 
+ assert_eq!(infer_use(db, *use_x), Type::Top); + assert_eq!(infer_use(db, *use_y), Type::Top); + }, + BatchSize::LargeInput, + ); + }); +} + +criterion_group!(benches, dataflow); +criterion_main!(benches); diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 68bc68e0a..ff443e814 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -37,6 +37,9 @@ macro_rules! setup_tracked_fn { // Path to the cycle recovery function to use. cycle_recovery_fn: ($($cycle_recovery_fn:tt)*), + // Path to function to get the initial value to use for cycle recovery. + cycle_recovery_initial: ($($cycle_recovery_initial:tt)*), + // Name of cycle recovery strategy variant to use. cycle_recovery_strategy: $cycle_recovery_strategy:ident, @@ -168,7 +171,7 @@ macro_rules! setup_tracked_fn { const CYCLE_STRATEGY: $zalsa::CycleRecoveryStrategy = $zalsa::CycleRecoveryStrategy::$cycle_recovery_strategy; - fn should_backdate_value( + fn values_equal( old_value: &Self::Output<'_>, new_value: &Self::Output<'_>, ) -> bool { @@ -176,7 +179,7 @@ macro_rules! setup_tracked_fn { if $no_eq { false } else { - $zalsa::should_backdate_value(old_value, new_value) + $zalsa::values_equal(old_value, new_value) } } } @@ -187,12 +190,17 @@ macro_rules! 
setup_tracked_fn { $inner($db, $($input_id),*) } + fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db, ($($input_id),*): ($($input_ty),*)) -> Self::Output<$db_lt> { + $($cycle_recovery_initial)*(db, $($input_id),*) + } + fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, - cycle: &$zalsa::Cycle, + value: &Self::Output<$db_lt>, + count: u32, ($($input_id),*): ($($input_ty),*) - ) -> Self::Output<$db_lt> { - $($cycle_recovery_fn)*(db, cycle, $($input_id),*) + ) -> $zalsa::CycleRecoveryAction> { + $($cycle_recovery_fn)*(db, value, count, $($input_id),*) } fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index a8b8122b3..a1cd1e73f 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -3,11 +3,18 @@ // a macro because it can take a variadic number of arguments. #[macro_export] macro_rules! unexpected_cycle_recovery { - ($db:ident, $cycle:ident, $($other_inputs:ident),*) => { - { - std::mem::drop($db); - std::mem::drop(($($other_inputs),*)); - panic!("cannot recover from cycle `{:?}`", $cycle) - } - } + ($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{ + std::mem::drop($db); + std::mem::drop(($($other_inputs),*)); + panic!("cannot recover from cycle") + }}; +} + +#[macro_export] +macro_rules! 
unexpected_cycle_initial { + ($db:ident, $($other_inputs:ident),*) => {{ + std::mem::drop($db); + std::mem::drop(($($other_inputs),*)); + panic!("no cycle initial value") + }}; } diff --git a/components/salsa-macros/src/accumulator.rs b/components/salsa-macros/src/accumulator.rs index e84bae121..2885e131b 100644 --- a/components/salsa-macros/src/accumulator.rs +++ b/components/salsa-macros/src/accumulator.rs @@ -40,7 +40,8 @@ impl AllowedOptions for Accumulator { const SINGLETON: bool = false; const DATA: bool = false; const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + const CYCLE_INITIAL: bool = false; const LRU: bool = false; const CONSTRUCTOR_NAME: bool = false; const ID: bool = false; diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index e3e560520..cc330a584 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -52,7 +52,9 @@ impl crate::options::AllowedOptions for InputStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index 30d89f8fb..dea7116ce 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ -53,7 +53,9 @@ impl crate::options::AllowedOptions for InternedStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/components/salsa-macros/src/options.rs b/components/salsa-macros/src/options.rs index a114f0c86..98089efe2 100644 --- a/components/salsa-macros/src/options.rs +++ b/components/salsa-macros/src/options.rs @@ -50,10 +50,15 @@ pub(crate) struct Options { /// If this is `Some`, the value is the ``. 
pub db_path: Option, - /// The `recovery_fn = ` option is used to indicate the recovery function. + /// The `cycle_fn = ` option is used to indicate the cycle recovery function. /// /// If this is `Some`, the value is the ``. - pub recovery_fn: Option, + pub cycle_fn: Option, + + /// The `cycle_initial = ` option is the initial value for cycle iteration. + /// + /// If this is `Some`, the value is the ``. + pub cycle_initial: Option, /// The `data = ` option is used to define the name of the data type for an interned /// struct. @@ -92,7 +97,8 @@ impl Default for Options { no_lifetime: Default::default(), no_clone: Default::default(), db_path: Default::default(), - recovery_fn: Default::default(), + cycle_fn: Default::default(), + cycle_initial: Default::default(), data: Default::default(), constructor_name: Default::default(), phantom: Default::default(), @@ -114,7 +120,8 @@ pub(crate) trait AllowedOptions { const SINGLETON: bool; const DATA: bool; const DB: bool; - const RECOVERY_FN: bool; + const CYCLE_FN: bool; + const CYCLE_INITIAL: bool; const LRU: bool; const CONSTRUCTOR_NAME: bool; const ID: bool; @@ -237,20 +244,39 @@ impl syn::parse::Parse for Options { "`db` option not allowed here", )); } - } else if ident == "recovery_fn" { - if A::RECOVERY_FN { + } else if ident == "cycle_fn" { + if A::CYCLE_FN { + let _eq = Equals::parse(input)?; + let path = syn::Path::parse(input)?; + if let Some(old) = std::mem::replace(&mut options.cycle_fn, Some(path)) { + return Err(syn::Error::new( + old.span(), + "option `cycle_fn` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`cycle_fn` option not allowed here", + )); + } + } else if ident == "cycle_initial" { + if A::CYCLE_INITIAL { + // TODO(carljm) should it be an error to give cycle_initial without cycle_fn, + // or should we just allow this to fall into potentially infinite iteration, if + // iteration never converges? 
let _eq = Equals::parse(input)?; let path = syn::Path::parse(input)?; - if let Some(old) = std::mem::replace(&mut options.recovery_fn, Some(path)) { + if let Some(old) = std::mem::replace(&mut options.cycle_initial, Some(path)) { return Err(syn::Error::new( old.span(), - "option `recovery_fn` provided twice", + "option `cycle_initial` provided twice", )); } } else { return Err(syn::Error::new( ident.span(), - "`recovery_fn` option not allowed here", + "`cycle_initial` option not allowed here", )); } } else if ident == "data" { diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index c74a265ff..1a3151a73 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -41,7 +41,9 @@ impl crate::options::AllowedOptions for TrackedFn { const DB: bool = false; - const RECOVERY_FN: bool = true; + const CYCLE_FN: bool = true; + + const CYCLE_INITIAL: bool = true; const LRU: bool = true; @@ -72,9 +74,20 @@ impl Macro { let input_ids = self.input_ids(&item); let input_tys = self.input_tys(&item)?; let output_ty = self.output_ty(&db_lt, &item)?; - let (cycle_recovery_fn, cycle_recovery_strategy) = self.cycle_recovery(); + let (cycle_recovery_fn, cycle_recovery_initial, cycle_recovery_strategy) = + self.cycle_recovery()?; let is_specifiable = self.args.specify.is_some(); - let no_eq = self.args.no_eq.is_some(); + let no_eq = if let Some(token) = &self.args.no_eq { + if self.args.cycle_fn.is_some() { + return Err(syn::Error::new_spanned( + token, + "the `no_eq` option cannot be used with `cycle_fn`", + )); + } + true + } else { + false + }; let mut inner_fn = item.clone(); inner_fn.vis = syn::Visibility::Inherited; @@ -143,6 +156,7 @@ impl Macro { output_ty: #output_ty, inner_fn: { #inner_fn }, cycle_recovery_fn: #cycle_recovery_fn, + cycle_recovery_initial: #cycle_recovery_initial, cycle_recovery_strategy: #cycle_recovery_strategy, is_specifiable: #is_specifiable, no_eq: #no_eq, @@ 
-178,14 +192,28 @@ impl Macro { Ok(ValidFn { db_ident, db_path }) } - fn cycle_recovery(&self) -> (TokenStream, TokenStream) { - if let Some(recovery_fn) = &self.args.recovery_fn { - (quote!((#recovery_fn)), quote!(Fallback)) - } else { - ( + fn cycle_recovery(&self) -> syn::Result<(TokenStream, TokenStream, TokenStream)> { + // TODO should we ask the user to specify a struct that impls a trait with two methods, + // rather than asking for two methods separately? + match (&self.args.cycle_fn, &self.args.cycle_initial) { + (Some(cycle_fn), Some(cycle_initial)) => Ok(( + quote!((#cycle_fn)), + quote!((#cycle_initial)), + quote!(Fixpoint), + )), + (None, None) => Ok(( quote!((salsa::plumbing::unexpected_cycle_recovery!)), + quote!((salsa::plumbing::unexpected_cycle_initial!)), quote!(Panic), - ) + )), + (Some(_), None) => Err(syn::Error::new_spanned( + self.args.cycle_fn.as_ref().unwrap(), + "must provide `cycle_initial` along with `cycle_fn`", + )), + (None, Some(_)) => Err(syn::Error::new_spanned( + self.args.cycle_initial.as_ref().unwrap(), + "must provide `cycle_fn` along with `cycle_initial`", + )), } } diff --git a/components/salsa-macros/src/tracked_struct.rs b/components/salsa-macros/src/tracked_struct.rs index 996fbbcf1..92ee52de4 100644 --- a/components/salsa-macros/src/tracked_struct.rs +++ b/components/salsa-macros/src/tracked_struct.rs @@ -47,7 +47,9 @@ impl crate::options::AllowedOptions for TrackedStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/src/accumulator.rs b/src/accumulator.rs index 9f67bdb39..3aedd89dc 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -11,7 +11,8 @@ use accumulated::AnyAccumulated; use crate::{ cycle::CycleRecoveryStrategy, - ingredient::{fmt_index, Ingredient, Jar, MaybeChangedAfter}, + function::VerifyResult, + ingredient::{fmt_index, Ingredient, Jar}, plumbing::JarAux, table::Table, 
zalsa::IngredientIndex, @@ -106,10 +107,18 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Id, _revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { panic!("nothing should ever depend on an accumulator directly") } + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true + } + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { CycleRecoveryStrategy::Panic } diff --git a/src/active_query.rs b/src/active_query.rs index fe4f7a351..9156d2e3f 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -7,11 +7,12 @@ use crate::tracked_struct::{DisambiguatorMap, IdentityHash, IdentityMap}; use crate::zalsa_local::QueryEdge; use crate::{ accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, + cycle::CycleHeads, durability::Durability, hash::FxIndexSet, key::{DatabaseKeyIndex, InputDependencyIndex}, tracked_struct::Disambiguator, - Cycle, Revision, + Revision, }; #[derive(Debug)] @@ -37,9 +38,6 @@ pub(crate) struct ActiveQuery { /// True if there was an untracked read. untracked_read: bool, - /// Stores the entire cycle, if one is found and this query is part of it. - pub(crate) cycle: Option, - /// When new tracked structs are created, their data is hashed, and the resulting /// hash is added to this map. If it is not present, then the disambiguator is 0. /// Otherwise it is 1 more than the current value (which is incremented). @@ -60,6 +58,9 @@ pub(crate) struct ActiveQuery { /// [`InputAccumulatedValues::Empty`] if any input read during the query's execution /// has any accumulated values. pub(super) accumulated_inputs: InputAccumulatedValues, + + /// Provisional cycle results that this query depends on. 
+ pub(crate) cycle_heads: CycleHeads, } impl ActiveQuery { @@ -70,11 +71,11 @@ impl ActiveQuery { changed_at: Revision::start(), input_outputs: FxIndexSet::default(), untracked_read: false, - cycle: None, disambiguator_map: Default::default(), tracked_struct_ids: Default::default(), accumulated: Default::default(), accumulated_inputs: Default::default(), + cycle_heads: Default::default(), } } @@ -84,11 +85,13 @@ impl ActiveQuery { durability: Durability, revision: Revision, accumulated: InputAccumulatedValues, + cycle_heads: &CycleHeads, ) { self.input_outputs.insert(QueryEdge::Input(input)); self.durability = self.durability.min(durability); self.changed_at = self.changed_at.max(revision); self.accumulated_inputs |= accumulated; + self.cycle_heads.extend(cycle_heads); } pub(super) fn add_untracked_read(&mut self, changed_at: Revision) { @@ -132,36 +135,10 @@ impl ActiveQuery { tracked_struct_ids: self.tracked_struct_ids, accumulated_inputs: AtomicInputAccumulatedValues::new(self.accumulated_inputs), accumulated, + cycle_heads: self.cycle_heads, } } - /// Adds any dependencies from `other` into `self`. - /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(super) fn add_from(&mut self, other: &ActiveQuery) { - self.changed_at = self.changed_at.max(other.changed_at); - self.durability = self.durability.min(other.durability); - self.untracked_read |= other.untracked_read; - self.input_outputs - .extend(other.input_outputs.iter().copied()); - } - - /// Removes the participants in `cycle` from my dependencies. - /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(super) fn remove_cycle_participants(&mut self, cycle: &Cycle) { - for p in cycle.participant_keys() { - let p: InputDependencyIndex = p.into(); - self.input_outputs.shift_remove(&QueryEdge::Input(p)); - } - } - - /// Copy the changed-at, durability, and dependencies from `cycle_query`. 
- /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(crate) fn take_inputs_from(&mut self, cycle_query: &ActiveQuery) { - self.changed_at = cycle_query.changed_at; - self.durability = cycle_query.durability; - self.input_outputs.clone_from(&cycle_query.input_outputs); - } - pub(super) fn disambiguate(&mut self, key: IdentityHash) -> Disambiguator { self.disambiguator_map.disambiguate(key) } diff --git a/src/cycle.rs b/src/cycle.rs index 8483a2857..9e0c47014 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,109 +1,111 @@ -use crate::{key::DatabaseKeyIndex, Database}; -use std::{panic::AssertUnwindSafe, sync::Arc}; +use crate::key::DatabaseKeyIndex; +use rustc_hash::FxHashSet; -/// Captures the participants of a cycle that occurred when executing a query. +/// The maximum number of times we'll fixpoint-iterate before panicking. /// -/// This type is meant to be used to help give meaningful error messages to the -/// user or to help salsa developers figure out why their program is resulting -/// in a computation cycle. -/// -/// It is used in a few ways: -/// -/// * During [cycle recovery](https://https://salsa-rs.github.io/salsa/cycles/fallback.html), -/// where it is given to the fallback function. -/// * As the panic value when an unexpected cycle (i.e., a cycle where one or more participants -/// lacks cycle recovery information) occurs. -/// -/// You can read more about cycle handling in -/// the [salsa book](https://https://salsa-rs.github.io/salsa/cycles.html). -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct Cycle { - participants: CycleParticipants, +/// Should only be relevant in case of a badly configured cycle recovery. +pub const MAX_ITERATIONS: u32 = 200; + +/// Return value from a cycle recovery function. +#[derive(Debug)] +pub enum CycleRecoveryAction { + /// Iterate the cycle again to look for a fixpoint. + Iterate, + + /// Cut off iteration and use the given result value for this query. 
+ Fallback(T), } -// We want `Cycle`` to be thin -pub(crate) type CycleParticipants = Arc>; +/// Cycle recovery strategy: Is this query capable of recovering from +/// a cycle that results from executing the function? If so, how? +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum CycleRecoveryStrategy { + /// Cannot recover from cycles: panic. + /// + /// This is the default. + Panic, + + /// Recovers from cycles by fixpoint iterating and/or falling + /// back to a sentinel value. + /// + /// This choice is computed by the query's `cycle_recovery` + /// function and initial value. + Fixpoint, +} -impl Cycle { - pub(crate) fn new(participants: CycleParticipants) -> Self { - Self { participants } - } +/// A "cycle head" is the query at which we encounter a cycle; that is, if A -> B -> C -> A, then A +/// would be the cycle head. It returns an "initial value" when the cycle is encountered (if +/// fixpoint iteration is enabled for that query), and then is responsible for re-iterating the +/// cycle until it converges. Any provisional value generated by any query in the cycle will track +/// the cycle head(s) (can be plural in case of nested cycles) representing the cycles it is part +/// of. This struct tracks these cycle heads. +#[derive(Clone, Debug, Default)] +pub(crate) struct CycleHeads(Option>>); - /// True if two `Cycle` values represent the same cycle. - pub(crate) fn is(&self, cycle: &Cycle) -> bool { - Arc::ptr_eq(&self.participants, &cycle.participants) +impl CycleHeads { + pub(crate) fn is_empty(&self) -> bool { + // We ensure in `remove` and `extend` that we never have an empty hashset, we always use + // None to signify empty. + self.0.is_none() } - pub(crate) fn throw(self) -> ! 
{ - tracing::debug!("throwing cycle {:?}", self); - std::panic::resume_unwind(Box::new(self)) + pub(crate) fn contains(&self, value: &DatabaseKeyIndex) -> bool { + self.0.as_ref().is_some_and(|heads| heads.contains(value)) } - pub(crate) fn catch(execute: impl FnOnce() -> T) -> Result { - match std::panic::catch_unwind(AssertUnwindSafe(execute)) { - Ok(v) => Ok(v), - Err(err) => match err.downcast::() { - Ok(cycle) => Err(*cycle), - Err(other) => std::panic::resume_unwind(other), - }, + pub(crate) fn remove(&mut self, value: &DatabaseKeyIndex) -> bool { + if let Some(cycle_heads) = self.0.as_mut() { + let found = cycle_heads.remove(value); + if found && cycle_heads.is_empty() { + self.0.take(); + } + found + } else { + false } } +} - /// Iterate over the [`DatabaseKeyIndex`] for each query participating - /// in the cycle. The start point of this iteration within the cycle - /// is arbitrary but deterministic, but the ordering is otherwise determined - /// by the execution. - pub fn participant_keys(&self) -> impl Iterator + '_ { - self.participants.iter().copied() +impl std::iter::Extend for CycleHeads { + fn extend>(&mut self, iter: T) { + let mut iter = iter.into_iter(); + if let Some(first) = iter.next() { + let heads = self.0.get_or_insert(Box::new(FxHashSet::default())); + heads.insert(first); + heads.extend(iter) + } } +} + +impl std::iter::IntoIterator for CycleHeads { + type Item = DatabaseKeyIndex; + type IntoIter = std::collections::hash_set::IntoIter; - /// Returns a vector with the debug information for - /// all the participants in the cycle. 
- pub fn all_participants(&self, _db: &dyn Database) -> Vec { - self.participant_keys().collect() + fn into_iter(self) -> Self::IntoIter { + self.0.map(|heads| heads.into_iter()).unwrap_or_default() } +} + +impl<'a> std::iter::IntoIterator for &'a CycleHeads { + type Item = DatabaseKeyIndex; + type IntoIter = std::iter::Copied>; - /// Returns a vector with the debug information for - /// those participants in the cycle that lacked recovery - /// information. - pub fn unexpected_participants(&self, db: &dyn Database) -> Vec { - self.participant_keys() - .filter(|&d| d.cycle_recovery_strategy(db) == CycleRecoveryStrategy::Panic) - .collect() + fn into_iter(self) -> Self::IntoIter { + self.0 + .as_ref() + .map(|heads| heads.iter().copied()) + .unwrap_or_default() } } -impl std::fmt::Debug for Cycle { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - crate::attach::with_attached_database(|db| { - f.debug_struct("UnexpectedCycle") - .field("all_participants", &self.all_participants(db)) - .field("unexpected_participants", &self.unexpected_participants(db)) - .finish() - }) - .unwrap_or_else(|| { - f.debug_struct("Cycle") - .field("participants", &self.participants) - .finish() +impl From> for CycleHeads { + fn from(value: FxHashSet) -> Self { + Self(if value.is_empty() { + None + } else { + Some(Box::new(value)) }) } } -/// Cycle recovery strategy: Is this query capable of recovering from -/// a cycle that results from executing the function? If so, how? -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum CycleRecoveryStrategy { - /// Cannot recover from cycles: panic. - /// - /// This is the default. - /// - /// In the case of a failure due to a cycle, the panic - /// value will be the `Cycle`. - Panic, - - /// Recovers from cycles by storing a sentinel value. - /// - /// This value is computed by the query's `recovery_fn` - /// function. 
- Fallback, -} +pub(crate) static EMPTY_CYCLE_HEADS: CycleHeads = CycleHeads(None); diff --git a/src/function.rs b/src/function.rs index 1073496c6..dd1163de5 100644 --- a/src/function.rs +++ b/src/function.rs @@ -2,21 +2,24 @@ use std::{any::Any, fmt, mem::ManuallyDrop, sync::Arc}; use crate::{ accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, - cycle::CycleRecoveryStrategy, - ingredient::{fmt_index, MaybeChangedAfter}, + cycle::{CycleRecoveryAction, CycleRecoveryStrategy}, + ingredient::fmt_index, key::DatabaseKeyIndex, plumbing::JarAux, salsa_struct::SalsaStructInDb, + table::sync::ClaimResult, table::Table, zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}, zalsa_local::QueryOrigin, - Cycle, Database, Id, Revision, + Database, Id, Revision, }; use self::delete::DeletedEntries; use super::ingredient::Ingredient; +pub(crate) use maybe_changed_after::VerifyResult; + mod accumulated; mod backdate; mod delete; @@ -50,13 +53,12 @@ pub trait Configuration: Any { /// (and, if so, how). const CYCLE_STRATEGY: CycleRecoveryStrategy; - /// Invokes after a new result `new_value`` has been computed for which an older memoized - /// value existed `old_value`. Returns true if the new value is equal to the older one - /// and hence should be "backdated" (i.e., marked as having last changed in an older revision, - /// even though it was recomputed). + /// Invokes after a new result `new_value`` has been computed for which an older memoized value + /// existed `old_value`, or in fixpoint iteration. Returns true if the new value is equal to + /// the older one. /// - /// This invokes user's code in form of the `Eq` impl. - fn should_backdate_value(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool; + /// This invokes user code in form of the `Eq` impl. + fn values_equal(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool; /// Convert from the id used internally to the value that execute is expecting. 
/// This is a no-op if the input to the function is a salsa struct. @@ -68,15 +70,18 @@ pub trait Configuration: Any { /// This invokes the function the user wrote. fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; - /// If the cycle strategy is `Fallback`, then invoked when `key` is a participant - /// in a cycle to find out what value it should have. - /// - /// This invokes the recovery function given by the user. + /// Get the cycle recovery initial value. + fn cycle_initial<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; + + /// Decide whether to iterate a cycle again or fallback. `value` is the provisional return + /// value from the latest iteration of this cycle. `count` is the number of cycle iterations + /// we've already completed. fn recover_from_cycle<'db>( db: &'db Self::DbView, - cycle: &Cycle, + value: &Self::Output<'db>, + count: u32, input: Self::Input<'db>, - ) -> Self::Output<'db>; + ) -> CycleRecoveryAction>; } /// Function ingredients are the "workhorse" of salsa. @@ -117,9 +122,9 @@ pub struct IngredientImpl { } /// True if `old_value == new_value`. Invoked by the generated -/// code for `should_backdate_value` so as to give a better +/// code for `values_equal` so as to give a better /// error message. 
-pub fn should_backdate_value(old_value: &V, new_value: &V) -> bool { +pub fn values_equal(old_value: &V, new_value: &V) -> bool { old_value == new_value } @@ -196,11 +201,29 @@ where db: &dyn Database, input: Id, revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { let db = db.as_view::(); self.maybe_changed_after(db, input, revision) } + fn is_provisional_cycle_head<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool { + self.get_memo_from_table_for(db.zalsa(), input) + .is_some_and(|memo| memo.cycle_heads().contains(&self.database_key_index(input))) + } + + fn wait_for(&self, db: &dyn Database, key_index: Id) -> bool { + let (zalsa, zalsa_local) = db.zalsas(); + match zalsa.sync_table_for(key_index).claim( + db, + zalsa_local, + self.database_key_index(key_index), + self.memo_ingredient_index, + ) { + ClaimResult::Retry | ClaimResult::Claimed(_) => true, + ClaimResult::Cycle => false, + } + } + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { C::CYCLE_STRATEGY } diff --git a/src/function/backdate.rs b/src/function/backdate.rs index 322d2a904..25ef5c3e8 100644 --- a/src/function/backdate.rs +++ b/src/function/backdate.rs @@ -21,7 +21,7 @@ where // consumers must be aware of. Becoming *more* durable // is not. See the test `constant_to_non_constant`. 
if revisions.durability >= old_memo.revisions.durability - && C::should_backdate_value(old_value, value) + && C::values_equal(old_value, value) { tracing::debug!( "value is equal, back-dating to {:?}", diff --git a/src/function/execute.rs b/src/function/execute.rs index 3adbe4b08..caa08e6a2 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,7 +1,10 @@ use std::sync::Arc; use crate::{ - zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind, + cycle::{CycleRecoveryStrategy, MAX_ITERATIONS}, + zalsa::ZalsaDatabase, + zalsa_local::ActiveQueryGuard, + Database, Event, EventKind, }; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -22,12 +25,13 @@ where pub(super) fn execute<'db>( &'db self, db: &'db C::DbView, - active_query: ActiveQueryGuard<'_>, + mut active_query: ActiveQueryGuard<'db>, opt_old_memo: Option>>>, ) -> &'db Memo> { - let zalsa = db.zalsa(); + let (zalsa, zalsa_local) = db.zalsas(); let revision_now = zalsa.current_revision(); let database_key_index = active_query.database_key_index; + let id = database_key_index.key_index; tracing::info!("{:?}: executing query", database_key_index); @@ -37,53 +41,133 @@ where }) }); - // If we already executed this query once, then use the tracked-struct ids from the - // previous execution as the starting point for the new one. - if let Some(old_memo) = &opt_old_memo { - active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids); - } + let mut iteration_count: u32 = 0; + let mut fell_back = false; - // Query was not previously executed, or value is potentially - // stale, or value is absent. Let's execute! - let database_key_index = active_query.database_key_index; - let id = database_key_index.key_index; - let value = match Cycle::catch(|| C::execute(db, C::id_to_input(db, id))) { - Ok(v) => v, - Err(cycle) => { + // Our provisional value from the previous iteration, when doing fixpoint iteration. 
+ // Initially it's set to None, because the initial provisional value is created lazily, + // only when a cycle is actually encountered. + let mut opt_last_provisional: Option<&Memo<::Output<'db>>> = None; + + loop { + // If we already executed this query once, then use the tracked-struct ids from the + // previous execution as the starting point for the new one. + if let Some(old_memo) = &opt_old_memo { + active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids); + } + + // Query was not previously executed, or value is potentially + // stale, or value is absent. Let's execute! + let mut new_value = C::execute(db, C::id_to_input(db, id)); + let mut revisions = active_query.pop(); + + // Did the new result we got depend on our own provisional value, in a cycle? + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::Fixpoint + && revisions.cycle_heads.contains(&database_key_index) + { + let opt_owned_last_provisional; + let last_provisional_value = if let Some(last_provisional) = opt_last_provisional { + // We have a last provisional value from our previous time around the loop. + last_provisional + .value + .as_ref() + .expect("provisional value should not be evicted by LRU") + } else { + // This is our first time around the loop; a provisional value must have been + // inserted into the memo table when the cycle was hit, so let's pull our + // initial provisional value from there. 
+ opt_owned_last_provisional = self.get_memo_from_table_for(zalsa, id); + debug_assert!(opt_owned_last_provisional + .as_ref() + .unwrap() + .may_be_provisional()); + opt_owned_last_provisional + .as_deref() + .expect( + "{database_key_index:#?} is a cycle head, \ + but no provisional memo found", + ) + .value + .as_ref() + .expect("provisional value should not be evicted by LRU") + }; tracing::debug!( - "{database_key_index:?}: caught cycle {cycle:?}, have strategy {:?}", - C::CYCLE_STRATEGY + "{database_key_index:?}: execute: \ + I am a cycle head, comparing last provisional value with new value" ); - match C::CYCLE_STRATEGY { - crate::cycle::CycleRecoveryStrategy::Panic => cycle.throw(), - crate::cycle::CycleRecoveryStrategy::Fallback => { - if let Some(c) = active_query.take_cycle() { - assert!(c.is(&cycle)); - C::recover_from_cycle(db, &cycle, C::id_to_input(db, id)) - } else { - // we are not a participant in this cycle - debug_assert!(!cycle - .participant_keys() - .any(|k| k == database_key_index)); - cycle.throw() + // If the new result is equal to the last provisional result, the cycle has + // converged and we are done. + if !C::values_equal(&new_value, last_provisional_value) { + if fell_back { + // We fell back to a value last iteration, but the fallback didn't result + // in convergence. We only have bad options here: continue iterating + // (ignoring the request to fall back), or forcibly use the fallback and + // leave the cycle in an inconsistent state (we'll be using a value for + // this query that it doesn't evaluate to, given its inputs). Maybe we'll + // have to go with the latter, but for now let's panic and see if real use + // cases need non-converging fallbacks. 
+ panic!("{database_key_index:?}: execute: fallback did not converge"); + } + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do: + match C::recover_from_cycle( + db, + &new_value, + iteration_count, + C::id_to_input(db, id), + ) { + crate::CycleRecoveryAction::Iterate => { + tracing::debug!("{database_key_index:?}: execute: iterate again"); } + crate::CycleRecoveryAction::Fallback(fallback_value) => { + tracing::debug!( + "{database_key_index:?}: execute: user cycle_fn says to fall back" + ); + new_value = fallback_value; + // We have to insert the fallback value for this query and then iterate + // one more time to fill in correct values for everything else in the + // cycle based on it; then we'll re-insert it as final value. + fell_back = true; + } + } + iteration_count = iteration_count + .checked_add(1) + .expect("fixpoint iteration should converge before u32::MAX iterations"); + if iteration_count > MAX_ITERATIONS { + panic!("{database_key_index:?}: execute: too many cycle iterations"); } + opt_last_provisional = Some(self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + )); + + active_query = zalsa_local.push_query(database_key_index); + + continue; } + tracing::debug!( + "{database_key_index:?}: execute: fixpoint iteration has a final value" + ); + revisions.cycle_heads.remove(&database_key_index); } - }; - let mut revisions = active_query.pop(); - // If the new value is equal to the old one, then it didn't - // really change, even if some of its inputs have. So we can - // "backdate" its `changed_at` revision to be the same as the - // old value. 
- if let Some(old_memo) = &opt_old_memo { - self.backdate_if_appropriate(old_memo, &mut revisions, &value); - self.diff_outputs(db, database_key_index, old_memo, &mut revisions); - } + tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}"); - tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}"); + // If the new value is equal to the old one, then it didn't + // really change, even if some of its inputs have. So we can + // "backdate" its `changed_at` revision to be the same as the + // old value. + if let Some(old_memo) = &opt_old_memo { + self.backdate_if_appropriate(old_memo, &mut revisions, &new_value); + self.diff_outputs(db, database_key_index, old_memo, &mut revisions); + } - self.insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions)) + return self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + ); + } } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 7828f33b9..3f553554b 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,7 +1,8 @@ -use super::{memo::Memo, Configuration, IngredientImpl}; +use super::{memo::Memo, Configuration, IngredientImpl, VerifyResult}; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, runtime::StampedValue, - zalsa::ZalsaDatabase, AsDynDatabase as _, Id, + table::sync::ClaimResult, zalsa::ZalsaDatabase, zalsa_local::QueryRevisions, + AsDynDatabase as _, Id, }; impl IngredientImpl @@ -29,6 +30,7 @@ where Some(_) => InputAccumulatedValues::Any, None => memo.revisions.accumulated_inputs.load(), }, + memo.cycle_heads(), ); value @@ -42,7 +44,16 @@ where ) -> &'db Memo> { loop { if let Some(memo) = self.fetch_hot(db, id).or_else(|| self.fetch_cold(db, id)) { - return memo; + // If we get back a provisional cycle memo, and it's provisional on any cycle heads + // that are claimed by a different thread, we can't propagate the provisional memo + // any further (it could 
escape outside the cycle); we need to block on the other + // thread completing fixpoint iteration of the cycle, and then we can re-query for + // our no-longer-provisional memo. + if !(memo.may_be_provisional() + && memo.provisional_retry(db.as_dyn_database(), self.database_key_index(id))) + { + return memo; + } } } } @@ -53,7 +64,7 @@ where let memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { if memo.value.is_some() - && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo) + && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo, false) { // Unsafety invariant: memo is present in memo_map and we have verified that it is // still valid for the current revision. @@ -68,27 +79,77 @@ where let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. - let _claim_guard = zalsa.sync_table_for(id).claim( + let _claim_guard = match zalsa.sync_table_for(id).claim( db.as_dyn_database(), zalsa_local, database_key_index, self.memo_ingredient_index, - )?; - - // Push the query on the stack. - let active_query = zalsa_local.push_query(database_key_index); + ) { + ClaimResult::Retry => return None, + ClaimResult::Cycle => { + // check if there's a provisional value for this query + let memo_guard = self.get_memo_from_table_for(zalsa, id); + if let Some(memo) = &memo_guard { + if memo.value.is_some() + && memo.revisions.cycle_heads.contains(&database_key_index) + && self.shallow_verify_memo(db, zalsa, database_key_index, memo, true) + { + // Unsafety invariant: memo is present in memo_map. 
+ unsafe { + return Some(self.extend_memo_lifetime(memo)); + } + } + } + // no provisional value; create/insert/return initial provisional value + return self + .initial_value(db, database_key_index.key_index) + .map(|initial_value| { + tracing::debug!( + "hit cycle at {database_key_index:#?}, \ + inserting and returning fixpoint initial value" + ); + self.insert_memo( + zalsa, + id, + Memo::new( + Some(initial_value), + zalsa.current_revision(), + QueryRevisions::fixpoint_initial( + database_key_index, + zalsa.current_revision(), + ), + ), + ) + }) + .or_else(|| { + panic!( + "dependency graph cycle querying {database_key_index:#?}; \ + set cycle_fn/cycle_initial to fixpoint iterate" + ) + }); + } + ClaimResult::Claimed(guard) => guard, + }; // Now that we've claimed the item, check again to see if there's a "hot" value. - let zalsa = db.zalsa(); + let active_query = zalsa_local.push_query(database_key_index); let opt_old_memo = self.get_memo_from_table_for(zalsa, id); if let Some(old_memo) = &opt_old_memo { - if old_memo.value.is_some() && self.deep_verify_memo(db, old_memo, &active_query) { - // Unsafety invariant: memo is present in memo_map and we have verified that it is - // still valid for the current revision. - return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; + if old_memo.value.is_some() { + if let VerifyResult::Unchanged(_, cycle_heads) = + self.deep_verify_memo(db, old_memo, &active_query) + { + if cycle_heads.is_empty() { + // Unsafety invariant: memo is present in memo_map and we have verified that it is + // still valid for the current revision. 
+ return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; + } + } } } - Some(self.execute(db, active_query, opt_old_memo)) + let memo = self.execute(db, active_query, opt_old_memo); + + Some(memo) } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index c37b1de68..b30bb6e1c 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,14 +1,46 @@ use crate::{ accumulator::accumulated_map::InputAccumulatedValues, - ingredient::MaybeChangedAfter, + cycle::CycleRecoveryStrategy, key::DatabaseKeyIndex, + table::sync::ClaimResult, zalsa::{Zalsa, ZalsaDatabase}, zalsa_local::{ActiveQueryGuard, QueryEdge, QueryOrigin}, AsDynDatabase as _, Id, Revision, }; +use rustc_hash::FxHashSet; +use std::sync::atomic::Ordering; use super::{memo::Memo, Configuration, IngredientImpl}; +/// Result of memo validation. +pub enum VerifyResult { + /// Memo has changed and needs to be recomputed. + Changed, + + /// Memo remains valid. + /// + /// The first inner value tracks whether the memo or any of its dependencies have an + /// accumulated value. + /// + /// Database keys in the hashset represent cycle heads encountered in validation; don't mark + /// memos verified until we've iterated the full cycle to ensure no inputs changed. + Unchanged(InputAccumulatedValues, FxHashSet), +} + +impl VerifyResult { + pub(crate) fn changed_if(changed: bool) -> Self { + if changed { + Self::Changed + } else { + Self::unchanged() + } + } + + pub(crate) fn unchanged() -> Self { + Self::Unchanged(InputAccumulatedValues::Empty, FxHashSet::default()) + } +} + impl IngredientImpl where C: Configuration, @@ -18,7 +50,7 @@ where db: &'db C::DbView, id: Id, revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { let (zalsa, zalsa_local) = db.zalsas(); zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database()); @@ -30,11 +62,14 @@ where // Check if we have a verified version: this is the hot path. 
let memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { - if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { + if self.shallow_verify_memo(db, zalsa, database_key_index, memo, false) { return if memo.revisions.changed_at > revision { - MaybeChangedAfter::Yes + VerifyResult::Changed } else { - MaybeChangedAfter::No(memo.revisions.accumulated_inputs.load()) + VerifyResult::Unchanged( + memo.revisions.accumulated_inputs.load(), + FxHashSet::default(), + ) }; } drop(memo_guard); // release the arc-swap guard before cold path @@ -45,7 +80,7 @@ where } } else { // No memo? Assume has changed. - return MaybeChangedAfter::Yes; + return VerifyResult::Changed; } } } @@ -55,21 +90,34 @@ where db: &'db C::DbView, key_index: Id, revision: Revision, - ) -> Option { + ) -> Option { let (zalsa, zalsa_local) = db.zalsas(); let database_key_index = self.database_key_index(key_index); - let _claim_guard = zalsa.sync_table_for(key_index).claim( + let _claim_guard = match zalsa.sync_table_for(key_index).claim( db.as_dyn_database(), zalsa_local, database_key_index, self.memo_ingredient_index, - )?; - let active_query = zalsa_local.push_query(database_key_index); - + ) { + ClaimResult::Retry => return None, + ClaimResult::Cycle => match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Panic => panic!( + "dependency graph cycle validating {database_key_index:#?}; \ + set cycle_fn/cycle_initial to fixpoint iterate" + ), + CycleRecoveryStrategy::Fixpoint => { + return Some(VerifyResult::Unchanged( + InputAccumulatedValues::Empty, + FxHashSet::from_iter([database_key_index]), + )); + } + }, + ClaimResult::Claimed(guard) => guard, + }; // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index) else { - return Some(MaybeChangedAfter::Yes); + return Some(VerifyResult::Changed); }; tracing::debug!( @@ -79,11 +127,14 @@ where ); // Check if the inputs are still valid. We can just compare `changed_at`. 
- if self.deep_verify_memo(db, &old_memo, &active_query) { + let active_query = zalsa_local.push_query(database_key_index); + if let VerifyResult::Unchanged(_, cycle_heads) = + self.deep_verify_memo(db, &old_memo, &active_query) + { return Some(if old_memo.revisions.changed_at > revision { - MaybeChangedAfter::Yes + VerifyResult::Changed } else { - MaybeChangedAfter::No(old_memo.revisions.accumulated_inputs.load()) + VerifyResult::Unchanged(old_memo.revisions.accumulated_inputs.load(), cycle_heads) }); } @@ -96,21 +147,34 @@ where let changed_at = memo.revisions.changed_at; return Some(if changed_at > revision { - MaybeChangedAfter::Yes + VerifyResult::Changed } else { - MaybeChangedAfter::No(match &memo.revisions.accumulated { - Some(_) => InputAccumulatedValues::Any, - None => memo.revisions.accumulated_inputs.load(), - }) + VerifyResult::Unchanged( + match &memo.revisions.accumulated { + Some(_) => InputAccumulatedValues::Any, + None => memo.revisions.accumulated_inputs.load(), + }, + FxHashSet::default(), + ) }); } // Otherwise, nothing for it: have to consider the value to have changed. - Some(MaybeChangedAfter::Yes) + Some(VerifyResult::Changed) } /// True if the memo's value and `changed_at` time is still valid in this revision. /// Does only a shallow O(1) check, doesn't walk the dependencies. + /// + /// In general, a provisional memo (from cycle iteration) does not verify. Since we don't + /// eagerly finalize all provisional memos in cycle iteration, we have to lazily check here + /// (via `validate_provisional`) whether a may-be-provisional memo should actually be verified + /// final, because its cycle heads are all now final. + /// + /// If `allow_provisional` is `true`, don't check provisionality and return whatever memo we + /// find that can be verified in this revision, whether provisional or not. 
This only occurs at + /// one call-site, in `fetch_cold` when we actually encounter a cycle, and want to check if + /// there is an existing provisional memo we can reuse. #[inline] pub(super) fn shallow_verify_memo( &self, @@ -118,14 +182,23 @@ where zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo>, + allow_provisional: bool, ) -> bool { - let verified_at = memo.verified_at.load(); - let revision_now = zalsa.current_revision(); - tracing::debug!( "{database_key_index:?}: shallow_verify_memo(memo = {memo:#?})", memo = memo.tracing_debug() ); + if !allow_provisional && memo.may_be_provisional() { + tracing::debug!( + "{database_key_index:?}: validate_provisional(memo = {memo:#?})", + memo = memo.tracing_debug() + ); + if !self.validate_provisional(db, zalsa, memo) { + return false; + } + } + let verified_at = memo.verified_at.load(); + let revision_now = zalsa.current_revision(); if verified_at == revision_now { // Already verified. @@ -148,20 +221,41 @@ where false } - /// True if the memo's value and `changed_at` time is up-to-date in the current - /// revision. When this returns true, it also updates the memo's `verified_at` - /// field if needed to make future calls cheaper. + /// Check if this memo's cycle heads have all been finalized. If so, mark it verified final and + /// return true, if not return false. + fn validate_provisional( + &self, + db: &C::DbView, + zalsa: &Zalsa, + memo: &Memo>, + ) -> bool { + for cycle_head in &memo.revisions.cycle_heads { + if zalsa + .lookup_ingredient(cycle_head.ingredient_index) + .is_provisional_cycle_head(db.as_dyn_database(), cycle_head.key_index) + { + return false; + } + } + // Relaxed is sufficient here because there are no other writes we need to ensure have + // happened before marking this memo as verified-final. + memo.verified_final.store(true, Ordering::Relaxed); + true + } + + /// VerifyResult::Unchanged if the memo's value and `changed_at` time is up-to-date in the + /// current revision. 
When this returns Unchanged with no cycle heads, it also updates the + /// memo's `verified_at` field if needed to make future calls cheaper. /// /// Takes an [`ActiveQueryGuard`] argument because this function recursively /// walks dependencies of `old_memo` and may even execute them to see if their - /// outputs have changed. As that could lead to cycles, it is important that the - /// query is on the stack. + /// outputs have changed. pub(super) fn deep_verify_memo( &self, db: &C::DbView, old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, - ) -> bool { + ) -> VerifyResult { let zalsa = db.zalsa(); let database_key_index = active_query.database_key_index; @@ -170,88 +264,136 @@ where old_memo = old_memo.tracing_debug() ); - if self.shallow_verify_memo(db, zalsa, database_key_index, old_memo) { - return true; + if self.shallow_verify_memo(db, zalsa, database_key_index, old_memo, false) { + return VerifyResult::Unchanged(InputAccumulatedValues::Empty, Default::default()); + } + if old_memo.may_be_provisional() { + return VerifyResult::Changed; } - let inputs = match &old_memo.revisions.origin { - QueryOrigin::Assigned(_) => { - // If the value was assigneed by another query, - // and that query were up-to-date, - // then we would have updated the `verified_at` field already. - // So the fact that we are here means that it was not specified - // during this revision or is otherwise stale. - // - // Example of how this can happen: - // - // Conditionally specified queries - // where the value is specified - // in rev 1 but not in rev 2. - return false; - } - QueryOrigin::BaseInput => { - // This value was `set` by the mutator thread -- ie, it's a base input and it cannot be out of date. - return true; - } - QueryOrigin::DerivedUntracked(_) => { - // Untracked inputs? Have to assume that it changed. - return false; - } - QueryOrigin::Derived(edges) => { - // Fully tracked inputs? Iterate over the inputs and check them, one by one. 
- // - // NB: It's important here that we are iterating the inputs in the order that - // they executed. It's possible that if the value of some input I0 is no longer - // valid, then some later input I1 might never have executed at all, so verifying - // it is still up to date is meaningless. - let last_verified_at = old_memo.verified_at.load(); - let mut inputs = InputAccumulatedValues::Empty; - for &edge in edges.input_outputs.iter() { - match edge { - QueryEdge::Input(dependency_index) => { - match dependency_index - .maybe_changed_after(db.as_dyn_database(), last_verified_at) - { - MaybeChangedAfter::Yes => { - return false; - } - MaybeChangedAfter::No(input_accumulated) => { - inputs |= input_accumulated; + let mut cycle_heads = FxHashSet::default(); + loop { + let inputs = match &old_memo.revisions.origin { + QueryOrigin::Assigned(_) => { + // If the value was assigneed by another query, + // and that query were up-to-date, + // then we would have updated the `verified_at` field already. + // So the fact that we are here means that it was not specified + // during this revision or is otherwise stale. + // + // Example of how this can happen: + // + // Conditionally specified queries + // where the value is specified + // in rev 1 but not in rev 2. + return VerifyResult::Changed; + } + QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { + // This value was `set` by the mutator thread -- ie, it's a base input and it cannot be out of date. + return VerifyResult::unchanged(); + } + QueryOrigin::DerivedUntracked(_) => { + // Untracked inputs? Have to assume that it changed. + return VerifyResult::Changed; + } + QueryOrigin::Derived(edges) => { + // Fully tracked inputs? Iterate over the inputs and check them, one by one. + // + // NB: It's important here that we are iterating the inputs in the order that + // they executed. 
It's possible that if the value of some input I0 is no longer + // valid, then some later input I1 might never have executed at all, so verifying + // it is still up to date is meaningless. + let last_verified_at = old_memo.verified_at.load(); + let mut inputs = InputAccumulatedValues::Empty; + for &edge in edges.input_outputs.iter() { + match edge { + QueryEdge::Input(dependency_index) => { + match dependency_index + .maybe_changed_after(db.as_dyn_database(), last_verified_at) + { + VerifyResult::Changed => return VerifyResult::Changed, + VerifyResult::Unchanged(input_accumulated, cycles) => { + cycle_heads.extend(cycles); + inputs |= input_accumulated; + } } } - } - QueryEdge::Output(dependency_index) => { - // Subtle: Mark outputs as validated now, even though we may - // later find an input that requires us to re-execute the function. - // Even if it re-execute, the function will wind up writing the same value, - // since all prior inputs were green. It's important to do this during - // this loop, because it's possible that one of our input queries will - // re-execute and may read one of our earlier outputs - // (e.g., in a scenario where we do something like - // `e = Entity::new(..); query(e);` and `query` reads a field of `e`). - // - // NB. Accumulators are also outputs, but the above logic doesn't - // quite apply to them. Since multiple values are pushed, the first value - // may be unchanged, but later values could be different. - // In that case, however, the data accumulated - // by this function cannot be read until this function is marked green, - // so even if we mark them as valid here, the function will re-execute - // and overwrite the contents. - dependency_index - .mark_validated_output(db.as_dyn_database(), database_key_index); + QueryEdge::Output(dependency_index) => { + // Subtle: Mark outputs as validated now, even though we may + // later find an input that requires us to re-execute the function. 
+ // Even if it re-execute, the function will wind up writing the same value, + // since all prior inputs were green. It's important to do this during + // this loop, because it's possible that one of our input queries will + // re-execute and may read one of our earlier outputs + // (e.g., in a scenario where we do something like + // `e = Entity::new(..); query(e);` and `query` reads a field of `e`). + // + // NB. Accumulators are also outputs, but the above logic doesn't + // quite apply to them. Since multiple values are pushed, the first value + // may be unchanged, but later values could be different. + // In that case, however, the data accumulated + // by this function cannot be read until this function is marked green, + // so even if we mark them as valid here, the function will re-execute + // and overwrite the contents. + // + // TODO not if we found a cycle head other than ourself? + dependency_index.mark_validated_output( + db.as_dyn_database(), + database_key_index, + ); + } } } + inputs } - inputs - } - }; + }; - old_memo.mark_as_verified( - db.as_dyn_database(), - zalsa.current_revision(), - database_key_index, - inputs, - ); - true + // Possible scenarios here: + // + // 1. Cycle heads is empty. We traversed our full dependency graph and neither hit any + // cycles, nor found any changed dependencies. We can mark our memo verified and + // return Unchanged with empty cycle heads. + // + // 2. Cycle heads is non-empty, and does not contain our own key index. We are part of + // a cycle, and since we don't know if some other cycle participant that hasn't been + // traversed yet (that is, some other dependency of the cycle head, which is only a + // dependency of ours via the cycle) might still have changed, we can't yet mark our + // memo verified. We can return a provisional Unchanged, with cycle heads. + // + // 3. Cycle heads is non-empty, and contains only our own key index. 
We are the head of + // a cycle, and we've now traversed the entire cycle and found no changes, but no + // other cycle participants were verified (they would have all hit case 2 above). We + // can now safely mark our own memo as verified. Then we have to traverse the entire + // cycle again. This time, since our own memo is verified, there will be no cycle + // encountered, and the rest of the cycle will be able to verify itself. + // + // 4. Cycle heads is non-empty, and contains our own key index as well as other key + // indices. We are the head of a cycle nested within another cycle. We can't mark + // our own memo verified (for the same reason as in case 2: the full outer cycle + // hasn't been validated unchanged yet). We return Unchanged, with ourself removed + // from cycle heads. We will handle our own memo (and the rest of our cycle) on a + // future iteration; first the outer cycle head needs to verify itself. + + let in_heads = cycle_heads.remove(&database_key_index); + + if cycle_heads.is_empty() { + old_memo.mark_as_verified( + db.as_dyn_database(), + zalsa.current_revision(), + database_key_index, + inputs, + ); + + if in_heads { + // Iterate our dependency graph again, starting from the top. We clear the + // cycle heads here because we are starting a fresh traversal. (It might be + // logically clearer to create a new HashSet each time, but clearing the + // existing one is more efficient.) 
+ cycle_heads.clear(); + continue; + } + } + return VerifyResult::Unchanged(InputAccumulatedValues::Empty, cycle_heads); + } } } diff --git a/src/function/memo.rs b/src/function/memo.rs index d24151756..0b7852196 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -2,6 +2,7 @@ use std::any::Any; use std::fmt::Debug; use std::fmt::Formatter; use std::mem::ManuallyDrop; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use crate::accumulator::accumulated_map::InputAccumulatedValues; @@ -9,8 +10,11 @@ use crate::revision::AtomicRevision; use crate::table::memo::MemoTable; use crate::zalsa_local::QueryOrigin; use crate::{ - key::DatabaseKeyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, Event, EventKind, Id, - Revision, + cycle::{CycleHeads, CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}, + key::DatabaseKeyIndex, + zalsa::Zalsa, + zalsa_local::QueryRevisions, + Event, EventKind, Id, Revision, }; use super::{Configuration, IngredientImpl}; @@ -75,7 +79,8 @@ impl IngredientImpl { match &memo.revisions.origin { QueryOrigin::Assigned(_) | QueryOrigin::DerivedUntracked(_) - | QueryOrigin::BaseInput => { + | QueryOrigin::BaseInput + | QueryOrigin::FixpointInitial => { // Careful: Cannot evict memos whose values were // assigned as output of another query // or those with untracked inputs @@ -93,6 +98,7 @@ impl IngredientImpl { ref tracked_struct_ids, ref accumulated, ref accumulated_inputs, + ref cycle_heads, } = &memo.revisions; // Re-assemble the memo but with the value set to `None` Arc::new(Memo::new( @@ -105,6 +111,7 @@ impl IngredientImpl { tracked_struct_ids: tracked_struct_ids.clone(), accumulated: accumulated.clone(), accumulated_inputs: accumulated_inputs.clone(), + cycle_heads: cycle_heads.clone(), }, )) } @@ -118,6 +125,17 @@ impl IngredientImpl { self.deleted_entries.push(ManuallyDrop::into_inner(old)); } } + + pub(super) fn initial_value<'db>( + &'db self, + db: &'db C::DbView, + key: Id, + ) -> Option> { + match C::CYCLE_STRATEGY { + 
CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db, C::id_to_input(db, key))), + CycleRecoveryStrategy::Panic => None, + } + } } #[derive(Debug)] @@ -129,6 +147,9 @@ pub(super) struct Memo { /// as the current revision. pub(super) verified_at: AtomicRevision, + /// Is this memo verified to not be a provisional cycle result? + pub(super) verified_final: AtomicBool, + /// Revision information pub(super) revisions: QueryRevisions, } @@ -136,16 +157,82 @@ pub(super) struct Memo { // Memo's are stored a lot, make sure their size is doesn't randomly increase. // #[cfg(test)] const _: [(); std::mem::size_of::>()] = - [(); std::mem::size_of::<[usize; 12]>()]; + [(); std::mem::size_of::<[usize; 14]>()]; impl Memo { pub(super) fn new(value: Option, revision_now: Revision, revisions: QueryRevisions) -> Self { Memo { value, verified_at: AtomicRevision::from(revision_now), + verified_final: AtomicBool::new(revisions.cycle_heads.is_empty()), revisions, } } + + /// True if this may be a provisional cycle-iteration result. + #[inline] + pub(super) fn may_be_provisional(&self) -> bool { + // Relaxed is OK here, because `verified_final` is only ever mutated in one direction (from + // `false` to `true`), and changing it to `true` on memos with cycle heads where it was + // ever `false` is purely an optimization; if we read an out-of-date `false`, it just means + // we might go validate it again unnecessarily. + !self.verified_final.load(Ordering::Relaxed) + } + + /// Invoked when `refresh_memo` is about to return a memo to the caller; if that memo is + /// provisional, and its cycle head is claimed by another thread, we need to wait for that + /// other thread to complete the fixpoint iteration, and then retry fetching our own memo. + /// + /// Return `true` if the caller should retry, `false` if the caller should go ahead and return + /// this memo to the caller. 
+ pub(super) fn provisional_retry( + &self, + db: &dyn crate::Database, + database_key_index: DatabaseKeyIndex, + ) -> bool { + let mut retry = false; + for head in self.cycle_heads() { + if head == database_key_index { + continue; + } + let ingredient = db.zalsa().lookup_ingredient(head.ingredient_index); + if !ingredient.is_provisional_cycle_head(db, head.key_index) { + // This cycle is already finalized, so we don't need to wait on it; + // keep looping through cycle heads. + retry = true; + continue; + } + if ingredient.wait_for(db.as_dyn_database(), head.key_index) { + // There's a new memo available for the cycle head; fetch our own + // updated memo and see if it's still provisional or if the cycle + // has resolved. + retry = true; + continue; + } else { + // We hit a cycle blocking on the cycle head; this means it's in + // our own active query stack and we are responsible to resolve the + // cycle, so go ahead and return the provisional memo. + return false; + } + } + // If `retry` is `true`, all our cycle heads (barring ourself) are complete; re-fetch + // and we should get a non-provisional memo. If we get here and `retry` is still + // `false`, we have no cycle heads other than ourself, so we are a provisional value of + // the cycle head (either initial value, or from a later iteration) and should be + // returned to caller to allow fixpoint iteration to proceed. (All cases in the loop + // above other than "cycle head is self" are either terminal or set `retry`.) + retry + } + + /// Cycle heads that should be propagated to dependent queries. + pub(super) fn cycle_heads(&self) -> &CycleHeads { + if self.may_be_provisional() { + &self.revisions.cycle_heads + } else { + &EMPTY_CYCLE_HEADS + } + } + /// True if this memo is known not to have changed based on its durability. 
pub(super) fn check_durability(&self, zalsa: &Zalsa) -> bool { let last_changed = zalsa.last_changed_revision(self.revisions.durability); @@ -205,6 +292,7 @@ impl Memo { }, ) .field("verified_at", &self.memo.verified_at) + .field("verified_final", &self.memo.verified_final) .field("revisions", &self.memo.revisions) .finish() } diff --git a/src/function/specify.rs b/src/function/specify.rs index 5a9187c70..2c3da510c 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -1,3 +1,5 @@ +use std::sync::atomic::AtomicBool; + use crate::{ accumulator::accumulated_map::InputAccumulatedValues, revision::AtomicRevision, @@ -71,6 +73,7 @@ where tracked_struct_ids: Default::default(), accumulated: Default::default(), accumulated_inputs: Default::default(), + cycle_heads: Default::default(), }; if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) { @@ -81,6 +84,7 @@ where let memo = Memo { value: Some(value), verified_at: AtomicRevision::from(revision), + verified_final: AtomicBool::new(true), revisions, }; diff --git a/src/ingredient.rs b/src/ingredient.rs index a063d981c..e70e2fb22 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -6,6 +6,7 @@ use std::{ use crate::{ accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, cycle::CycleRecoveryStrategy, + function::VerifyResult, table::Table, zalsa::{IngredientIndex, MemoIngredientIndex}, zalsa_local::QueryOrigin, @@ -62,7 +63,20 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { db: &'db dyn Database, input: Id, revision: Revision, - ) -> MaybeChangedAfter; + ) -> VerifyResult; + + /// Is the value for `input` in this ingredient a cycle head that is still provisional? + /// + /// In the case of nested cycles, we are not asking here whether the value is provisional due + /// to the outer cycle being unresolved, only whether its own cycle remains provisional. 
+ fn is_provisional_cycle_head<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool; + + /// Invoked when the current thread needs to wait for a result for the given `key_index`. + /// + /// A return value of `true` indicates that a result is now available. A return value of + /// `false` means that a cycle was encountered; the waited-on query is either already claimed + /// by the current thread, or by a thread waiting on the current thread. + fn wait_for(&self, db: &dyn Database, key_index: Id) -> bool; /// What were the inputs (if any) that were used to create the value at `key_index`. fn origin(&self, db: &dyn Database, key_index: Id) -> Option; @@ -176,23 +190,3 @@ pub(crate) fn fmt_index( write!(fmt, "{debug_name}()") } } - -#[derive(Copy, Clone, Debug)] -pub enum MaybeChangedAfter { - /// The query result hasn't changed. - /// - /// The inner value tracks whether the memo or any of its dependencies have an accumulated value. - No(InputAccumulatedValues), - - /// The query's result has changed since the last revision or the query isn't cached yet. 
- Yes, -} - -impl From for MaybeChangedAfter { - fn from(value: bool) -> Self { - match value { - true => MaybeChangedAfter::Yes, - false => MaybeChangedAfter::No(InputAccumulatedValues::Empty), - } - } -} diff --git a/src/input.rs b/src/input.rs index 434293302..6cf9604e6 100644 --- a/src/input.rs +++ b/src/input.rs @@ -12,9 +12,10 @@ use input_field::FieldIngredientImpl; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, - cycle::CycleRecoveryStrategy, + cycle::{CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}, + function::VerifyResult, id::{AsId, FromId}, - ingredient::{fmt_index, Ingredient, MaybeChangedAfter}, + ingredient::{fmt_index, Ingredient}, input::singleton::{Singleton, SingletonChoice}, key::{DatabaseKeyIndex, InputDependencyIndex}, plumbing::{Jar, JarAux, Stamp}, @@ -180,6 +181,7 @@ impl IngredientImpl { stamp.durability, stamp.changed_at, InputAccumulatedValues::Empty, + &EMPTY_CYCLE_HEADS, ); &value.fields } @@ -218,10 +220,18 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Id, _revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { // Input ingredients are just a counter, they store no data, they are immortal. // Their *fields* are stored in function ingredients elsewhere. 
- MaybeChangedAfter::No(InputAccumulatedValues::Empty) + VerifyResult::unchanged() + } + + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true } fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { diff --git a/src/input/input_field.rs b/src/input/input_field.rs index b22d03b2d..a9a8c33cd 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,5 +1,6 @@ use crate::cycle::CycleRecoveryStrategy; -use crate::ingredient::{fmt_index, Ingredient, MaybeChangedAfter}; +use crate::function::VerifyResult; +use crate::ingredient::{fmt_index, Ingredient}; use crate::input::Configuration; use crate::table::Table; use crate::zalsa::IngredientIndex; @@ -55,11 +56,18 @@ where db: &dyn Database, input: Id, revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { let zalsa = db.zalsa(); let value = >::data(zalsa, input); + VerifyResult::changed_if(value.stamps[self.field_index].changed_at > revision) + } + + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } - MaybeChangedAfter::from(value.stamps[self.field_index].changed_at > revision) + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true } fn origin(&self, _db: &dyn Database, _key_index: Id) -> Option { diff --git a/src/interned.rs b/src/interned.rs index c31adeac3..7bc59074b 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -1,8 +1,10 @@ use dashmap::SharedValue; use crate::accumulator::accumulated_map::InputAccumulatedValues; +use crate::cycle::EMPTY_CYCLE_HEADS; use crate::durability::Durability; -use crate::ingredient::{fmt_index, MaybeChangedAfter}; +use crate::function::VerifyResult; +use crate::ingredient::fmt_index; use crate::key::InputDependencyIndex; use crate::plumbing::{Jar, JarAux}; use crate::table::memo::MemoTable; @@ -183,6 +185,7 @@ where Durability::MAX, self.reset_at, 
InputAccumulatedValues::Empty, + &EMPTY_CYCLE_HEADS, ); // Optimization to only get read lock on the map if the data has already been interned. @@ -287,8 +290,16 @@ where _db: &dyn Database, _input: Id, revision: Revision, - ) -> MaybeChangedAfter { - MaybeChangedAfter::from(revision < self.reset_at) + ) -> VerifyResult { + VerifyResult::changed_if(revision < self.reset_at) + } + + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true } fn cycle_recovery_strategy(&self) -> crate::cycle::CycleRecoveryStrategy { diff --git a/src/key.rs b/src/key.rs index 8fd159520..5509cd5c2 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,9 +1,6 @@ use core::fmt; -use crate::{ - accumulator::accumulated_map::InputAccumulatedValues, cycle::CycleRecoveryStrategy, - ingredient::MaybeChangedAfter, zalsa::IngredientIndex, Database, Id, -}; +use crate::{function::VerifyResult, zalsa::IngredientIndex, Database, Id}; /// An integer that uniquely identifies a particular query instance within the /// database. Used to track output dependencies between queries. Fully ordered and @@ -89,14 +86,14 @@ impl InputDependencyIndex { &self, db: &dyn Database, last_verified_at: crate::Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { match self.key_index { Some(key_index) => db .zalsa() .lookup_ingredient(self.ingredient_index) .maybe_changed_after(db, key_index, last_verified_at), // Data in tables themselves remain valid until the table as a whole is reset. 
- None => MaybeChangedAfter::No(InputAccumulatedValues::Empty), + None => VerifyResult::unchanged(), } } @@ -139,10 +136,6 @@ impl DatabaseKeyIndex { pub fn key_index(self) -> Id { self.key_index } - - pub(crate) fn cycle_recovery_strategy(self, db: &dyn Database) -> CycleRecoveryStrategy { - self.ingredient_index.cycle_recovery_strategy(db) - } } impl std::fmt::Debug for DatabaseKeyIndex { diff --git a/src/lib.rs b/src/lib.rs index 7b029f735..e900f5a75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ mod zalsa_local; pub use self::accumulator::Accumulator; pub use self::cancelled::Cancelled; -pub use self::cycle::Cycle; +pub use self::cycle::CycleRecoveryAction; pub use self::database::AsDynDatabase; pub use self::database::Database; pub use self::database_impl::DatabaseImpl; @@ -70,11 +70,11 @@ pub mod plumbing { pub use crate::array::Array; pub use crate::attach::attach; pub use crate::attach::with_attached_database; - pub use crate::cycle::Cycle; + pub use crate::cycle::CycleRecoveryAction; pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::current_revision; pub use crate::database::Database; - pub use crate::function::should_backdate_value; + pub use crate::function::values_equal; pub use crate::id::AsId; pub use crate::id::FromId; pub use crate::id::Id; @@ -114,6 +114,7 @@ pub mod plumbing { pub use salsa_macro_rules::setup_method_body; pub use salsa_macro_rules::setup_tracked_fn; pub use salsa_macro_rules::setup_tracked_struct; + pub use salsa_macro_rules::unexpected_cycle_initial; pub use salsa_macro_rules::unexpected_cycle_recovery; pub mod accumulator { diff --git a/src/runtime.rs b/src/runtime.rs index f28b93d98..16da23673 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,19 +1,14 @@ use std::{ mem, - panic::panic_any, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::atomic::{AtomicBool, Ordering}, thread::ThreadId, }; use parking_lot::Mutex; use crate::{ - active_query::ActiveQuery, 
cycle::CycleRecoveryStrategy, durability::Durability, - key::DatabaseKeyIndex, table::Table, zalsa_local::ZalsaLocal, Cancelled, Cycle, Database, - Event, EventKind, Revision, + durability::Durability, key::DatabaseKeyIndex, table::Table, zalsa_local::ZalsaLocal, + Cancelled, Database, Event, EventKind, Revision, }; use self::dependency_graph::DependencyGraph; @@ -49,7 +44,12 @@ pub struct Runtime { pub(crate) enum WaitResult { Completed, Panicked, - Cycle(Cycle), +} + +#[derive(Clone, Debug)] +pub(crate) enum BlockResult { + Completed, + Cycle, } #[derive(Copy, Clone, Debug)] @@ -152,8 +152,8 @@ impl Runtime { r_new } - /// Block until `other_id` completes executing `database_key`; - /// panic or unwind in the case of a cycle. + /// Block until `other_id` completes executing `database_key`, or return `BlockResult::Cycle` + /// immediately in case of a cycle. /// /// `query_mutex_guard` is the guard for the current query's state; /// it will be dropped after we have successfully registered the @@ -163,34 +163,19 @@ impl Runtime { /// /// If the thread `other_id` panics, then our thread is considered /// cancelled, so this function will panic with a `Cancelled` value. - /// - /// # Cycle handling - /// - /// If the thread `other_id` already depends on the current thread, - /// and hence there is a cycle in the query graph, then this function - /// will unwind instead of returning normally. The method of unwinding - /// depends on the [`Self::mutual_cycle_recovery_strategy`] - /// of the cycle participants: - /// - /// * [`CycleRecoveryStrategy::Panic`]: panic with the [`Cycle`] as the value. - /// * [`CycleRecoveryStrategy::Fallback`]: initiate unwinding with [`CycleParticipant::unwind`]. 
- pub(crate) fn block_on_or_unwind( + pub(crate) fn block_on( &self, db: &dyn Database, local_state: &ZalsaLocal, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, - ) { + ) -> BlockResult { let mut dg = self.dependency_graph.lock(); let thread_id = std::thread::current().id(); if dg.depends_on(other_id, thread_id) { - self.unblock_cycle_and_maybe_throw(db, local_state, &mut dg, database_key, other_id); - - // If the above fn returns, then (via cycle recovery) it has unblocked the - // cycle, so we can continue. - assert!(!dg.depends_on(other_id, thread_id)); + return BlockResult::Cycle; } db.salsa_event(&|| { @@ -214,126 +199,12 @@ impl Runtime { }); match result { - WaitResult::Completed => (), + WaitResult::Completed => BlockResult::Completed, // If the other thread panicked, then we consider this thread // cancelled. The assumption is that the panic will be detected // by the other thread and responded to appropriately. WaitResult::Panicked => Cancelled::PropagatedPanic.throw(), - - WaitResult::Cycle(c) => c.throw(), - } - } - - /// Handles a cycle in the dependency graph that was detected when the - /// current thread tried to block on `database_key_index` which is being - /// executed by `to_id`. If this function returns, then `to_id` no longer - /// depends on the current thread, and so we should continue executing - /// as normal. Otherwise, the function will throw a `Cycle` which is expected - /// to be caught by some frame on our stack. This occurs either if there is - /// a frame on our stack with cycle recovery (possibly the top one!) or if there - /// is no cycle recovery at all. 
- fn unblock_cycle_and_maybe_throw( - &self, - db: &dyn Database, - local_state: &ZalsaLocal, - dg: &mut DependencyGraph, - database_key_index: DatabaseKeyIndex, - to_id: ThreadId, - ) { - tracing::debug!( - "unblock_cycle_and_maybe_throw(database_key={:?})", - database_key_index - ); - - let (me_recovered, others_recovered, cycle) = local_state.with_query_stack(|from_stack| { - let from_id = std::thread::current().id(); - - // Make a "dummy stack frame". As we iterate through the cycle, we will collect the - // inputs from each participant. Then, if we are participating in cycle recovery, we - // will propagate those results to all participants. - let mut cycle_query = ActiveQuery::new(database_key_index); - - // Identify the cycle participants: - let cycle = { - let mut v = vec![]; - dg.for_each_cycle_participant( - from_id, - from_stack, - database_key_index, - to_id, - |aqs| { - aqs.iter_mut().for_each(|aq| { - cycle_query.add_from(aq); - v.push(aq.database_key_index); - }); - }, - ); - - // We want to give the participants in a deterministic order - // (at least for this execution, not necessarily across executions), - // no matter where it started on the stack. Find the minimum - // key and rotate it to the front. - - if let Some((_, index)) = v - .iter() - .enumerate() - .map(|(idx, key)| (key.ingredient_index.debug_name(db), idx)) - .min() - { - v.rotate_left(index); - } - - Cycle::new(Arc::new(v.into_boxed_slice())) - }; - tracing::debug!("cycle {cycle:?}, cycle_query {cycle_query:#?}"); - - // We can remove the cycle participants from the list of dependencies; - // they are a strongly connected component (SCC) and we only care about - // dependencies to things outside the SCC that control whether it will - // form again. - cycle_query.remove_cycle_participants(&cycle); - - // Mark each cycle participant that has recovery set, along with - // any frames that come after them on the same thread. 
Those frames - // are going to be unwound so that fallback can occur. - dg.for_each_cycle_participant(from_id, from_stack, database_key_index, to_id, |aqs| { - aqs.iter_mut() - .skip_while(|aq| { - match db - .zalsa() - .lookup_ingredient(aq.database_key_index.ingredient_index) - .cycle_recovery_strategy() - { - CycleRecoveryStrategy::Panic => true, - CycleRecoveryStrategy::Fallback => false, - } - }) - .for_each(|aq| { - tracing::debug!("marking {:?} for fallback", aq.database_key_index); - aq.take_inputs_from(&cycle_query); - assert!(aq.cycle.is_none()); - aq.cycle = Some(cycle.clone()); - }); - }); - - // Unblock every thread that has cycle recovery with a `WaitResult::Cycle`. - // They will throw the cycle, which will be caught by the frame that has - // cycle recovery so that it can execute that recovery. - let (me_recovered, others_recovered) = - dg.maybe_unblock_runtimes_in_cycle(from_id, from_stack, database_key_index, to_id); - (me_recovered, others_recovered, cycle) - }); - - if me_recovered { - // If the current thread has recovery, we want to throw - // so that it can begin. - cycle.throw() - } else if others_recovered { - // If other threads have recovery but we didn't: return and we will block on them. - } else { - // if nobody has recover, then we panic - panic_any(cycle); } } diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs index 84c5327fc..c90e650de 100644 --- a/src/runtime/dependency_graph.rs +++ b/src/runtime/dependency_graph.rs @@ -31,7 +31,6 @@ pub(super) struct DependencyGraph { #[derive(Debug)] struct Edge { blocked_on_id: ThreadId, - blocked_on_key: DatabaseKeyIndex, stack: QueryStack, /// Signalled whenever a query with dependents completes. @@ -55,115 +54,6 @@ impl DependencyGraph { p == to_id } - /// Invokes `closure` with a `&mut ActiveQuery` for each query that participates in the cycle. - /// The cycle runs as follows: - /// - /// 1. 
The runtime `from_id`, which has the stack `from_stack`, would like to invoke `database_key`... - /// 2. ...but `database_key` is already being executed by `to_id`... - /// 3. ...and `to_id` is transitively dependent on something which is present on `from_stack`. - pub(super) fn for_each_cycle_participant( - &mut self, - from_id: ThreadId, - from_stack: &mut QueryStack, - database_key: DatabaseKeyIndex, - to_id: ThreadId, - mut closure: impl FnMut(&mut [ActiveQuery]), - ) { - debug_assert!(self.depends_on(to_id, from_id)); - - // To understand this algorithm, consider this [drawing](https://is.gd/TGLI9v): - // - // database_key = QB2 - // from_id = A - // to_id = B - // from_stack = [QA1, QA2, QA3] - // - // self.edges[B] = { C, QC2, [QB1..QB3] } - // self.edges[C] = { A, QA2, [QC1..QC3] } - // - // The cyclic - // edge we have - // failed to add. - // : - // A : B C - // : - // QA1 v QB1 QC1 - // ┌► QA2 ┌──► QB2 ┌─► QC2 - // │ QA3 ───┘ QB3 ──┘ QC3 ───┐ - // │ │ - // └───────────────────────────────┘ - // - // Final output: [QB2, QB3, QC2, QC3, QA2, QA3] - - let mut id = to_id; - let mut key = database_key; - while id != from_id { - // Looking at the diagram above, the idea is to - // take the edge from `to_id` starting at `key` - // (inclusive) and down to the end. We can then - // load up the next thread (i.e., we start at B/QB2, - // and then load up the dependency on C/QC2). - let edge = self.edges.get_mut(&id).unwrap(); - closure(strip_prefix_query_stack_mut(&mut edge.stack, key)); - id = edge.blocked_on_id; - key = edge.blocked_on_key; - } - - // Finally, we copy in the results from `from_stack`. - closure(strip_prefix_query_stack_mut(from_stack, key)); - } - - /// Unblock each blocked runtime (excluding the current one) if some - /// query executing in that runtime is participating in cycle fallback. 
- /// - /// Returns a boolean (Current, Others) where: - /// * Current is true if the current runtime has cycle participants - /// with fallback; - /// * Others is true if other runtimes were unblocked. - pub(super) fn maybe_unblock_runtimes_in_cycle( - &mut self, - from_id: ThreadId, - from_stack: &QueryStack, - database_key: DatabaseKeyIndex, - to_id: ThreadId, - ) -> (bool, bool) { - // See diagram in `for_each_cycle_participant`. - let mut id = to_id; - let mut key = database_key; - let mut others_unblocked = false; - while id != from_id { - let edge = self.edges.get(&id).unwrap(); - let next_id = edge.blocked_on_id; - let next_key = edge.blocked_on_key; - - if let Some(cycle) = strip_prefix_query_stack(&edge.stack, key) - .iter() - .rev() - .find_map(|aq| aq.cycle.clone()) - { - // Remove `id` from the list of runtimes blocked on `next_key`: - self.query_dependents - .get_mut(&next_key) - .unwrap() - .retain(|r| *r != id); - - // Unblock runtime so that it can resume execution once lock is released: - self.unblock_runtime(id, WaitResult::Cycle(cycle)); - - others_unblocked = true; - } - - id = next_id; - key = next_key; - } - - let this_unblocked = strip_prefix_query_stack(from_stack, key) - .iter() - .any(|aq| aq.cycle.is_some()); - - (this_unblocked, others_unblocked) - } - /// Modifies the graph so that `from_id` is blocked /// on `database_key`, which is being computed by /// `to_id`. @@ -219,7 +109,6 @@ impl DependencyGraph { from_id, Edge { blocked_on_id: to_id, - blocked_on_key: database_key, stack: from_stack, condvar: condvar.clone(), }, @@ -260,22 +149,3 @@ impl DependencyGraph { edge.condvar.notify_one(); } } - -fn strip_prefix_query_stack(stack_mut: &[ActiveQuery], key: DatabaseKeyIndex) -> &[ActiveQuery] { - let prefix = stack_mut - .iter() - .take_while(|p| p.database_key_index != key) - .count(); - &stack_mut[prefix..] 
-} - -fn strip_prefix_query_stack_mut( - stack_mut: &mut [ActiveQuery], - key: DatabaseKeyIndex, -) -> &mut [ActiveQuery] { - let prefix = stack_mut - .iter() - .take_while(|p| p.database_key_index != key) - .count(); - &mut stack_mut[prefix..] -} diff --git a/src/table/sync.rs b/src/table/sync.rs index dfe78a23a..685b39af5 100644 --- a/src/table/sync.rs +++ b/src/table/sync.rs @@ -7,7 +7,7 @@ use parking_lot::RwLock; use crate::{ key::DatabaseKeyIndex, - runtime::WaitResult, + runtime::{BlockResult, WaitResult}, zalsa::{MemoIngredientIndex, Zalsa}, zalsa_local::ZalsaLocal, Database, @@ -30,6 +30,12 @@ struct SyncState { anyone_waiting: AtomicBool, } +pub(crate) enum ClaimResult<'a> { + Retry, + Cycle, + Claimed(ClaimGuard<'a>), +} + impl SyncTable { pub(crate) fn claim<'me>( &'me self, @@ -37,7 +43,7 @@ impl SyncTable { zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, - ) -> Option> { + ) -> ClaimResult<'me> { let mut syncs = self.syncs.write(); let zalsa = db.zalsa(); let thread_id = std::thread::current().id(); @@ -50,11 +56,12 @@ impl SyncTable { id: thread_id, anyone_waiting: AtomicBool::new(false), }); - Some(ClaimGuard { + ClaimResult::Claimed(ClaimGuard { database_key_index, memo_ingredient_index, zalsa, sync_table: self, + _padding: false, }) } Some(SyncState { @@ -68,8 +75,10 @@ impl SyncTable { // boolean is to decide *whether* to acquire the lock, // not to gate future atomic reads. 
anyone_waiting.store(true, Ordering::Relaxed); - zalsa.block_on_or_unwind(db, zalsa_local, database_key_index, *other_id, syncs); - None + match zalsa.block_on(db, zalsa_local, database_key_index, *other_id, syncs) { + BlockResult::Completed => ClaimResult::Retry, + BlockResult::Cycle => ClaimResult::Cycle, + } } } } @@ -83,6 +92,9 @@ pub(crate) struct ClaimGuard<'me> { memo_ingredient_index: MemoIngredientIndex, zalsa: &'me Zalsa, sync_table: &'me SyncTable, + // Reduce the size of ClaimResult by making more niches available in ClaimGuard; this fits into + // the padding of ClaimGuard so doesn't increase its size. + _padding: bool, } impl ClaimGuard<'_> { @@ -93,7 +105,7 @@ impl ClaimGuard<'_> { syncs[self.memo_ingredient_index.as_usize()].take().unwrap(); // NB: `Ordering::Relaxed` is sufficient here, - // see `store` above for explanation. + // see `claim` above for explanation. if anyone_waiting.load(Ordering::Relaxed) { self.zalsa .unblock_queries_blocked_on(self.database_key_index, wait_result) diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index cff5e2ee8..896b4482a 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -5,8 +5,9 @@ use tracked_field::FieldIngredientImpl; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, - cycle::CycleRecoveryStrategy, - ingredient::{fmt_index, Ingredient, Jar, JarAux, MaybeChangedAfter}, + cycle::{CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}, + function::VerifyResult, + ingredient::{fmt_index, Ingredient, Jar, JarAux}, key::{DatabaseKeyIndex, InputDependencyIndex}, plumbing::ZalsaLocal, revision::OptionalAtomicRevision, @@ -658,6 +659,7 @@ where data.durability, field_changed_at, InputAccumulatedValues::Empty, + &EMPTY_CYCLE_HEADS, ); unsafe { self.to_self_ref(&data.fields) } @@ -709,8 +711,16 @@ where _db: &dyn Database, _input: Id, _revision: Revision, - ) -> MaybeChangedAfter { - MaybeChangedAfter::No(InputAccumulatedValues::Empty) + ) -> VerifyResult { + VerifyResult::unchanged() + } 
+ + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true } fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index e3d1e07a0..40104104c 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,10 +1,8 @@ use std::marker::PhantomData; use crate::{ - ingredient::{Ingredient, MaybeChangedAfter}, - table::Table, - zalsa::IngredientIndex, - Database, Id, + function::VerifyResult, ingredient::Ingredient, table::Table, zalsa::IngredientIndex, Database, + Id, }; use super::{Configuration, Value}; @@ -61,11 +59,19 @@ where db: &'db dyn Database, input: Id, revision: crate::Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { let zalsa = db.zalsa(); let data = >::data(zalsa.table(), input); let field_changed_at = data.revisions[self.field_index]; - MaybeChangedAfter::from(field_changed_at > revision) + VerifyResult::changed_if(field_changed_at > revision) + } + + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true } fn origin( diff --git a/src/zalsa.rs b/src/zalsa.rs index 6c9f7c20c..07dc863f6 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -5,10 +5,9 @@ use std::any::{Any, TypeId}; use std::marker::PhantomData; use std::thread::ThreadId; -use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::{Ingredient, Jar, JarAux}; use crate::nonce::{Nonce, NonceGenerator}; -use crate::runtime::{Runtime, WaitResult}; +use crate::runtime::{BlockResult, Runtime, WaitResult}; use crate::table::memo::MemoTable; use crate::table::sync::SyncTable; use crate::table::Table; @@ -86,18 +85,9 @@ impl IngredientIndex { self.0 as usize } - pub(crate) fn cycle_recovery_strategy(self, db: &dyn Database) 
-> CycleRecoveryStrategy { - db.zalsa().lookup_ingredient(self).cycle_recovery_strategy() - } - pub fn successor(self, index: usize) -> Self { IngredientIndex(self.0 + 1 + index as u32) } - - /// Return the "debug name" of this ingredient (e.g., the name of the tracked struct it represents) - pub(crate) fn debug_name(self, db: &dyn Database) -> &'static str { - db.zalsa().lookup_ingredient(self).debug_name() - } } /// A special secondary index *just* for ingredients that attach @@ -279,16 +269,16 @@ impl Zalsa { } /// See [`Runtime::block_on_or_unwind`][] - pub(crate) fn block_on_or_unwind( + pub(crate) fn block_on( &self, db: &dyn Database, local_state: &ZalsaLocal, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, - ) { + ) -> BlockResult { self.runtime - .block_on_or_unwind(db, local_state, database_key, other_id, query_mutex_guard) + .block_on(db, local_state, database_key, other_id, query_mutex_guard) } /// See [`Runtime::unblock_queries_blocked_on`][] diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 25693bd54..3e171850f 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1,10 +1,11 @@ -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use tracing::debug; use crate::accumulator::accumulated_map::{ AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues, }; use crate::active_query::ActiveQuery; +use crate::cycle::CycleHeads; use crate::durability::Durability; use crate::key::{DatabaseKeyIndex, InputDependencyIndex, OutputDependencyIndex}; use crate::runtime::StampedValue; @@ -15,7 +16,6 @@ use crate::tracked_struct::{Disambiguator, Identity, IdentityHash, IdentityMap}; use crate::zalsa::IngredientIndex; use crate::Accumulator; use crate::Cancelled; -use crate::Cycle; use crate::Database; use crate::Event; use crate::EventKind; @@ -163,6 +163,7 @@ impl ZalsaLocal { durability: Durability, changed_at: Revision, accumulated: InputAccumulatedValues, + cycle_heads: &CycleHeads, ) 
{ debug!( "report_tracked_read(input={:?}, durability={:?}, changed_at={:?})", @@ -170,32 +171,7 @@ impl ZalsaLocal { ); self.with_query_stack(|stack| { if let Some(top_query) = stack.last_mut() { - top_query.add_read(input, durability, changed_at, accumulated); - - // We are a cycle participant: - // - // C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0 - // ^ ^ - // : | - // This edge -----+ | - // | - // | - // N0 - // - // In this case, the value we have just read from `Ci+1` - // is actually the cycle fallback value and not especially - // interesting. We unwind now with `CycleParticipant` to avoid - // executing the rest of our query function. This unwinding - // will be caught and our own fallback value will be used. - // - // Note that `Ci+1` may` have *other* callers who are not - // participants in the cycle (e.g., N0 in the graph above). - // They will not have the `cycle` marker set in their - // stack frames, so they will just read the fallback value - // from `Ci+1` and continue on their merry way. - if let Some(cycle) = &top_query.cycle { - cycle.clone().throw() - } + top_query.add_read(input, durability, changed_at, accumulated, cycle_heads); } }) } @@ -340,12 +316,36 @@ pub(crate) struct QueryRevisions { pub(super) tracked_struct_ids: IdentityMap, pub(super) accumulated: Option>, + /// [`InputAccumulatedValues::Empty`] if any input read during the query's execution /// has any direct or indirect accumulated values. pub(super) accumulated_inputs: AtomicInputAccumulatedValues, + + /// This result was computed based on provisional values from + /// these cycle heads. The "cycle head" is the query responsible + /// for managing a fixpoint iteration. In a cycle like + /// `--> A --> B --> C --> A`, the cycle head is query `A`: it is + /// the query whose value is requested while it is executing, + /// which must provide the initial provisional value and decide, + /// after each iteration, whether the cycle has converged or must + /// iterate again. 
+ pub(super) cycle_heads: CycleHeads, } impl QueryRevisions { + pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, revision: Revision) -> Self { + let cycle_heads = FxHashSet::from_iter([query]).into(); + Self { + changed_at: revision, + durability: Durability::MAX, + origin: QueryOrigin::FixpointInitial, + tracked_struct_ids: Default::default(), + accumulated: Default::default(), + accumulated_inputs: Default::default(), + cycle_heads, + } + } + pub(crate) fn stamped_value(&self, value: V) -> StampedValue { self.stamp_template().stamp(value) } @@ -394,6 +394,9 @@ pub enum QueryOrigin { /// The [`QueryEdges`] argument contains a listing of all the inputs we saw /// (but we know there were more). DerivedUntracked(QueryEdges), + + /// The value is an initial provisional value for a query that supports fixpoint iteration. + FixpointInitial, } impl QueryOrigin { @@ -401,7 +404,9 @@ impl QueryOrigin { pub(crate) fn inputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), - QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, + QueryOrigin::Assigned(_) | QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { + None + } }; opt_edges.into_iter().flat_map(|edges| edges.inputs()) } @@ -410,7 +415,9 @@ impl QueryOrigin { pub(crate) fn outputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), - QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, + QueryOrigin::Assigned(_) | QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { + None + } }; opt_edges.into_iter().flat_map(|edges| edges.outputs()) } @@ -522,18 +529,8 @@ impl ActiveQueryGuard<'_> { // Extract accumulated inputs. let popped_query = self.complete(); - // If this frame were a cycle participant, it would have unwound. 
- assert!(popped_query.cycle.is_none()); - popped_query.into_revisions() } - - /// If the active query is registered as a cycle participant, remove and - /// return that cycle. - pub(crate) fn take_cycle(&self) -> Option { - self.local_state - .with_query_stack(|stack| stack.last_mut()?.cycle.take()) - } } impl Drop for ActiveQueryGuard<'_> { diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 19f818b65..75d22073e 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -141,3 +141,41 @@ impl HasLogger for ExecuteValidateLoggerDatabase { &self.logger } } + +/// Trait implemented by databases that lets them provide a fixed u32 value. +pub trait HasValue { + fn get_value(&self) -> u32; +} + +#[salsa::db] +pub trait ValueDatabase: HasValue + Database {} + +#[salsa::db] +impl ValueDatabase for Db {} + +#[salsa::db] +#[derive(Clone, Default)] +pub struct DatabaseWithValue { + storage: Storage, + value: u32, +} + +impl HasValue for DatabaseWithValue { + fn get_value(&self) -> u32 { + self.value + } +} + +#[salsa::db] +impl Database for DatabaseWithValue { + fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} +} + +impl DatabaseWithValue { + pub fn new(value: u32) -> Self { + Self { + storage: Default::default(), + value, + } + } +} diff --git a/tests/compile-fail/get-set-on-private-input-field.rs b/tests/compile-fail/get-set-on-private-input-field.rs index 5ecec5836..345590b75 100644 --- a/tests/compile-fail/get-set-on-private-input-field.rs +++ b/tests/compile-fail/get-set-on-private-input-field.rs @@ -1,5 +1,3 @@ -use salsa::prelude::*; - mod a { #[salsa::input] pub struct MyInput { diff --git a/tests/compile-fail/get-set-on-private-input-field.stderr b/tests/compile-fail/get-set-on-private-input-field.stderr index b8dcca66d..40acd8c2d 100644 --- a/tests/compile-fail/get-set-on-private-input-field.stderr +++ b/tests/compile-fail/get-set-on-private-input-field.stderr @@ -1,17 +1,17 @@ error[E0624]: method `field` is private - --> 
tests/compile-fail/get-set-on-private-input-field.rs:14:11 + --> tests/compile-fail/get-set-on-private-input-field.rs:12:11 | -4 | #[salsa::input] +2 | #[salsa::input] | --------------- private method defined here ... -14 | input.field(&db); +12 | input.field(&db); | ^^^^^ private method error[E0624]: method `set_field` is private - --> tests/compile-fail/get-set-on-private-input-field.rs:15:11 + --> tests/compile-fail/get-set-on-private-input-field.rs:13:11 | -4 | #[salsa::input] +2 | #[salsa::input] | --------------- private method defined here ... -15 | input.set_field(&mut db).to(23); +13 | input.set_field(&mut db).to(23); | ^^^^^^^^^ private method diff --git a/tests/cycle.rs b/tests/cycle.rs new file mode 100644 index 000000000..fe5875fc0 --- /dev/null +++ b/tests/cycle.rs @@ -0,0 +1,1004 @@ +//! Test cases for fixpoint iteration cycle resolution. +//! +//! These test cases use a generic query setup that allows constructing arbitrary dependency +//! graphs, and attempts to achieve good coverage of various cases. +mod common; +use common::{ExecuteValidateLoggerDatabase, LogDatabase}; +use expect_test::expect; +use salsa::{CycleRecoveryAction, Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; +use test_log::test; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, salsa::Update)] +enum Value { + N(u8), + OutOfBounds, + TooManyIterations, +} + +impl Value { + fn to_value(self) -> Option { + if let Self::N(val) = self { + Some(val) + } else { + None + } + } +} + +/// A vector of inputs a query can evaluate to get an iterator of values to operate on. +/// +/// This allows creating arbitrary query graphs between the four queries below (`min_iterate`, +/// `max_iterate`, `min_panic`, `max_panic`) for testing cycle behaviors. 
+#[salsa::input] +struct Inputs { + inputs: Vec, +} + +impl Inputs { + fn values(self, db: &dyn Db) -> impl Iterator + '_ { + self.inputs(db).into_iter().map(|input| input.eval(db)) + } +} + +/// A single input, evaluating to a single [`Value`]. +#[derive(Clone, Debug)] +enum Input { + /// a simple value + Value(Value), + + /// a simple value, reported as an untracked read + UntrackedRead(Value), + + /// minimum of the given inputs, with fixpoint iteration on cycles + MinIterate(Inputs), + + /// maximum of the given inputs, with fixpoint iteration on cycles + MaxIterate(Inputs), + + /// minimum of the given inputs, panicking on cycles + MinPanic(Inputs), + + /// maximum of the given inputs, panicking on cycles + MaxPanic(Inputs), + + /// value of the given input, plus one; propagates error values + Successor(Box), + + /// successor, converts error values to zero + SuccessorOrZero(Box), +} + +impl Input { + fn eval(self, db: &dyn Db) -> Value { + match self { + Self::Value(value) => value, + Self::UntrackedRead(value) => { + db.report_untracked_read(); + value + } + Self::MinIterate(inputs) => min_iterate(db, inputs), + Self::MaxIterate(inputs) => max_iterate(db, inputs), + Self::MinPanic(inputs) => min_panic(db, inputs), + Self::MaxPanic(inputs) => max_panic(db, inputs), + Self::Successor(input) => match input.eval(db) { + Value::N(num) => Value::N(num + 1), + other => other, + }, + Self::SuccessorOrZero(input) => match input.eval(db) { + Value::N(num) => Value::N(num + 1), + _ => Value::N(0), + }, + } + } + + fn assert(self, db: &dyn Db, expected: Value) { + assert_eq!(self.eval(db), expected) + } + + fn assert_value(self, db: &dyn Db, expected: u8) { + self.assert(db, Value::N(expected)) + } + + fn assert_bounds(self, db: &dyn Db) { + self.assert(db, Value::OutOfBounds) + } + + fn assert_count(self, db: &dyn Db) { + self.assert(db, Value::TooManyIterations) + } +} + +const MIN_VALUE: u8 = 10; +const MAX_VALUE: u8 = 245; +const MAX_ITERATIONS: u32 = 3; + +/// 
Recover from a cycle by falling back to `Value::OutOfBounds` if the value is out of bounds, +/// `Value::TooManyIterations` if we've iterated more than `MAX_ITERATIONS` times, or else +/// iterating again. +fn cycle_recover( + _db: &dyn Db, + value: &Value, + count: u32, + _inputs: Inputs, +) -> CycleRecoveryAction { + if value + .to_value() + .is_some_and(|val| val <= MIN_VALUE || val >= MAX_VALUE) + { + CycleRecoveryAction::Fallback(Value::OutOfBounds) + } else if count > MAX_ITERATIONS { + CycleRecoveryAction::Fallback(Value::TooManyIterations) + } else { + CycleRecoveryAction::Iterate + } +} + +/// Fold an iterator of `Value` into a `Value`, given some binary operator to apply to two `u8`. +/// `Value::TooManyIterations` and `Value::OutOfBounds` will always propagate, with +/// `Value::TooManyIterations` taking precedence. +fn fold_values(values: impl IntoIterator, op: F) -> Value +where + F: Fn(u8, u8) -> u8, +{ + values + .into_iter() + .fold(None, |accum, elem| { + let Some(accum) = accum else { + return Some(elem); + }; + match (accum, elem) { + (Value::TooManyIterations, _) | (_, Value::TooManyIterations) => { + Some(Value::TooManyIterations) + } + (Value::OutOfBounds, _) | (_, Value::OutOfBounds) => Some(Value::OutOfBounds), + (Value::N(val1), Value::N(val2)) => Some(Value::N(op(val1, val2))), + } + }) + .expect("inputs should not be empty") +} + +/// Query minimum value of inputs, with cycle recovery. +#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=min_initial)] +fn min_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + fold_values(inputs.values(db), u8::min) +} + +fn min_initial(_db: &dyn Db, _inputs: Inputs) -> Value { + Value::N(255) +} + +/// Query maximum value of inputs, with cycle recovery. 
+#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=max_initial)] +fn max_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + fold_values(inputs.values(db), u8::max) +} + +fn max_initial(_db: &dyn Db, _inputs: Inputs) -> Value { + Value::N(0) +} + +/// Query minimum value of inputs, without cycle recovery. +#[salsa::tracked] +fn min_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + fold_values(inputs.values(db), u8::min) +} + +/// Query maximum value of inputs, without cycle recovery. +#[salsa::tracked] +fn max_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + fold_values(inputs.values(db), u8::max) +} + +fn untracked(num: u8) -> Input { + Input::UntrackedRead(Value::N(num)) +} + +fn value(num: u8) -> Input { + Input::Value(Value::N(num)) +} + +// Diagram nomenclature for nodes: Each node is represented as a:xx(ii), where `a` is a sequential +// identifier from `a`, `b`, `c`..., xx is one of the four query kinds: +// - `Ni` for `min_iterate` +// - `Xi` for `max_iterate` +// - `Np` for `min_panic` +// - `Xp` for `max_panic` +//\ +// and `ii` is the inputs for that query, represented as a comma-separated list, with each +// component representing an input: +// - `a`, `b`, `c`... where the input is another node, +// - `uXX` for `UntrackedRead(XX)` +// - `vXX` for `Value(XX)` +// - `sY` for `Successor(Y)` +// - `zY` for `SuccessorOrZero(Y)` +// +// We always enter from the top left node in the diagram. + +/// a:Np(a) -+ +/// ^ | +/// +--------+ +/// +/// Simple self-cycle, no iteration, should panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn self_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + a_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.eval(&db); +} + +/// a:Np(u10, a) -+ +/// ^ | +/// +-------------+ +/// +/// Simple self-cycle with untracked read, no iteration, should panic. 
+#[test] +#[should_panic(expected = "dependency graph cycle")] +fn self_untracked_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + a_in.set_inputs(&mut db).to(vec![untracked(10), a.clone()]); + + a.eval(&db); +} + +/// a:Ni(a) -+ +/// ^ | +/// +--------+ +/// +/// Simple self-cycle, iteration converges on initial value. +#[test] +fn self_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + a_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 255); +} + +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, one with iteration and one without. +/// If we enter from the one with iteration, we converge on its initial value. +#[test] +fn two_mixed_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 255); +} + +/// a:Np(b) --> b:Ni(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, one with iteration and one without. +/// If we enter from the one with no iteration, we panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn two_mixed_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(b_in); + let b = Input::MinIterate(a_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.eval(&db); +} + +/// a:Ni(b) --> b:Xi(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, both with iteration. +/// We converge on the initial value of whichever we first enter from. 
+#[test] +fn two_iterate_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MaxIterate(b_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 255); + b.assert_value(&db, 255); +} + +/// a:Xi(b) --> b:Ni(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, both with iteration. +/// We converge on the initial value of whichever we enter from. +/// (Same setup as above test, different query order.) +#[test] +fn two_iterate_converge_initial_value_2() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MinIterate(b_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 0); + b.assert_value(&db, 0); +} + +/// a:Np(b) --> b:Ni(c) --> c:Xp(b) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, enter indirectly at node with iteration, converge on its initial value. +#[test] +fn two_indirect_iterate_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.assert_value(&db, 255); +} + +/// a:Xp(b) --> b:Np(c) --> c:Xi(b) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, enter indirectly at node without iteration, panic. 
+#[test] +#[should_panic(expected = "dependency graph cycle")] +fn two_indirect_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinPanic(b_in); + let c = Input::MaxIterate(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.eval(&db); +} + +/// a:Np(b) -> b:Ni(v200,c) -> c:Xp(b) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, converges to non-initial value. +#[test] +fn two_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(200), c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.assert_value(&db, 200); +} + +/// a:Xp(b) -> b:Xi(v20,c) -> c:Xp(sb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back due to >3 iterations. +#[test] +fn two_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(20), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Successor(Box::new(b))]); + + a.assert_count(&db); +} + +/// a:Xp(b) -> b:Xi(v20,c) -> c:Xp(zb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back but fallback does not converge. 
+#[test] +#[should_panic(expected = "fallback did not converge")] +fn two_fallback_diverge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(20), c.clone()]); + c_in.set_inputs(&mut db) + .to(vec![Input::SuccessorOrZero(Box::new(b))]); + + a.assert_count(&db); +} + +/// a:Xp(b) -> b:Xi(v244,c) -> c:Xp(sb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back due to value reaching >MAX_VALUE (we start at 244 and each +/// iteration increments until we reach >245). +#[test] +fn two_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(244), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Successor(Box::new(b))]); + + a.assert_bounds(&db); +} + +/// a:Ni(b) -> b:Np(a, c) -> c:Np(v25, a) +/// ^ | | +/// +----------+------------------------+ +/// +/// Three-query cycle, (b) and (c) both depend on (a). We converge on 25. 
+#[test] +fn three_fork_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db).to(vec![value(25), a.clone()]); + + a.assert_value(&db, 25); +} + +/// a:Ni(b) -> b:Ni(a, c) -> c:Np(v25, b) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We converge on 25. +#[test] +fn layered_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db).to(vec![value(25), b]); + + a.assert_value(&db, 25); +} + +/// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v25, sb) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We hit max iterations and fall back. +#[test] +fn layered_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db) + .to(vec![value(25), Input::Successor(Box::new(b))]); + a.assert_count(&db); +} + +/// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v243, sb) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We hit max value and fall back. 
+#[test] +fn layered_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db) + .to(vec![value(243), Input::Successor(Box::new(b))]); + + a.assert_bounds(&db); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Np(v25, a, b) +/// ^ ^ | +/// +----------+------------------------+ +/// +/// Nested cycles. We converge on 25. +#[test] +fn nested_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![value(25), a.clone(), b]); + + a.assert_value(&db, 25); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Np(v25, b, a) +/// ^ ^ | +/// +----------+------------------------+ +/// +/// Nested cycles, inner first. We converge on 25. +#[test] +fn nested_inner_first_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![value(25), b, a.clone()]); + + a.assert_value(&db, 25); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v25, a, sb) +/// ^ ^ | +/// +----------+-------------------------+ +/// +/// Nested cycles. We hit max iterations and fall back. 
+#[test] +fn nested_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db) + .to(vec![value(25), a.clone(), Input::Successor(Box::new(b))]); + + a.assert_count(&db); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v25, b, sa) +/// ^ ^ | +/// +----------+-------------------------+ +/// +/// Nested cycles, inner first. We hit max iterations and fall back. +#[test] +fn nested_inner_first_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db) + .to(vec![value(25), b, Input::Successor(Box::new(a.clone()))]); + + a.assert_count(&db); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v243, a, sb) +/// ^ ^ | +/// +----------+--------------------------+ +/// +/// Nested cycles. We hit max value and fall back. 
+#[test] +fn nested_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c.clone()]); + c_in.set_inputs(&mut db).to(vec![ + value(243), + a.clone(), + Input::Successor(Box::new(b.clone())), + ]); + a.assert_bounds(&db); + b.assert_bounds(&db); + c.assert_bounds(&db); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v243, b, sa) +/// ^ ^ | +/// +----------+--------------------------+ +/// +/// Nested cycles, inner first. We hit max value and fall back. +#[test] +fn nested_inner_first_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db) + .to(vec![value(243), b, Input::Successor(Box::new(a.clone()))]); + + a.assert_bounds(&db); +} + +/// a:Ni(b) -> b:Ni(c, a) -> c:Np(v25, a, b) +/// ^ ^ | | +/// +----------+--------|------------------+ +/// | | +/// +-------------------+ +/// +/// Nested cycles, double head. We converge on 25. 
+#[test] +fn nested_double_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c, a.clone()]); + c_in.set_inputs(&mut db).to(vec![value(25), a.clone(), b]); + + a.assert_value(&db, 25); +} + +// Multiple-revision cycles + +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// a:Ni(b) --> b:Np(v30) +/// +/// Cycle becomes not-a-cycle in next revision. +#[test] +fn cycle_becomes_non_cycle() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.clone().assert_value(&db, 255); + + b_in.set_inputs(&mut db).to(vec![value(30)]); + + a.assert_value(&db, 30); +} + +/// a:Ni(b) --> b:Np(v30) +/// +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// Non-cycle becomes a cycle in next revision. +#[test] +fn non_cycle_becomes_cycle() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![value(30)]); + + a.clone().assert_value(&db, 30); + + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 255); +} + +/// a:Xi(b) -> b:Xi(c, a) -> c:Xp(v25, a, sb) +/// ^ ^ | | +/// +----------+--------|-------------------+ +/// | | +/// +-------------------+ +/// +/// Nested cycles, double head. We hit max iterations and fall back, then max value on the next +/// revision, then converge on the next. 
+#[test] +fn nested_double_multiple_revisions() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c, a.clone()]); + c_in.set_inputs(&mut db).to(vec![ + value(25), + a.clone(), + Input::Successor(Box::new(b.clone())), + ]); + + a.clone().assert_count(&db); + + // next revision, we hit max value instead + c_in.set_inputs(&mut db).to(vec![ + value(243), + a.clone(), + Input::Successor(Box::new(b.clone())), + ]); + + a.clone().assert_bounds(&db); + + // and next revision, we converge + c_in.set_inputs(&mut db) + .to(vec![value(240), a.clone(), b.clone()]); + + a.clone().assert_value(&db, 240); + + // one more revision, without relevant changes + a_in.set_inputs(&mut db).to(vec![b]); + + a.assert_value(&db, 240); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Ni(a) +/// ^ | +/// +---------------------------+ +/// +/// In a cycle with some LOW durability and some HIGH durability inputs, changing a LOW durability +/// input still re-executes the full cycle in the next revision. 
+#[test] +fn cycle_durability() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinIterate(c_in); + a_in.set_inputs(&mut db) + .with_durability(Durability::LOW) + .to(vec![b.clone()]); + b_in.set_inputs(&mut db) + .with_durability(Durability::HIGH) + .to(vec![c]); + c_in.set_inputs(&mut db) + .with_durability(Durability::HIGH) + .to(vec![a.clone()]); + + a.clone().assert_value(&db, 255); + + // next revision, we converge instead + a_in.set_inputs(&mut db) + .with_durability(Durability::LOW) + .to(vec![value(45), b]); + + a.assert_value(&db, 45); +} + +/// a:Np(v59, b) -> b:Ni(v60, c) -> c:Np(b) +/// ^ | +/// +---------------------+ +/// +/// If nothing in a cycle changed in the new revision, no part of the cycle should re-execute. +#[test] +fn cycle_unchanged() { + let mut db = ExecuteValidateLoggerDatabase::default(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![value(59), b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(60), c]); + c_in.set_inputs(&mut db).to(vec![b.clone()]); + + a.clone().assert_value(&db, 59); + b.clone().assert_value(&db, 60); + + db.assert_logs_len(4); + + // next revision, we change only A, which is not part of the cycle and the cycle does not + // depend on. 
+ a_in.set_inputs(&mut db).to(vec![value(45), b.clone()]); + b.assert_value(&db, 60); + + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + ]"#]]); + + a.assert_value(&db, 45); +} + +/// a:Np(v59, b) -> b:Ni(v60, c) -> c:Np(d) -> d:Ni(v61, b, e) -> e:Np(d) +/// ^ | ^ | +/// +--------------------------+ +--------------+ +/// +/// If nothing in a nested cycle changed in the new revision, no part of the cycle should +/// re-execute. +#[test] +fn cycle_unchanged_nested() { + let mut db = ExecuteValidateLoggerDatabase::default(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let d_in = Inputs::new(&db, vec![]); + let e_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + let d = Input::MinIterate(d_in); + let e = Input::MinPanic(e_in); + a_in.set_inputs(&mut db).to(vec![value(59), b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(60), c.clone()]); + c_in.set_inputs(&mut db).to(vec![d.clone()]); + d_in.set_inputs(&mut db) + .to(vec![value(61), b.clone(), e.clone()]); + e_in.set_inputs(&mut db).to(vec![d.clone()]); + + a.clone().assert_value(&db, 59); + b.clone().assert_value(&db, 60); + + db.assert_logs_len(10); + + // next revision, we change only A, which is not part of the cycle and the cycle does not + // depend on. 
+ a_in.set_inputs(&mut db).to(vec![value(45), b.clone()]); + b.assert_value(&db, 60); + + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(4)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + ]"#]]); + + a.assert_value(&db, 45); +} + +/// +--------------------------------+ +/// | v +/// a:Np(v59, b) -> b:Ni(v60, c) -> c:Np(d, e) -> d:Ni(v61, b, e) -> e:Ni(d) +/// ^ | ^ | +/// +-----------------------------+ +--------------+ +/// +/// If nothing in a nested cycle changed in the new revision, no part of the cycle should +/// re-execute. +#[test_log::test] +fn cycle_unchanged_nested_intertwined() { + // We run this test twice in order to catch some subtly different cases; see below. 
+ for i in 0..2 {
+ let mut db = ExecuteValidateLoggerDatabase::default();
+ let a_in = Inputs::new(&db, vec![]);
+ let b_in = Inputs::new(&db, vec![]);
+ let c_in = Inputs::new(&db, vec![]);
+ let d_in = Inputs::new(&db, vec![]);
+ let e_in = Inputs::new(&db, vec![]);
+ let a = Input::MinPanic(a_in);
+ let b = Input::MinIterate(b_in);
+ let c = Input::MinPanic(c_in);
+ let d = Input::MinIterate(d_in);
+ let e = Input::MinIterate(e_in);
+ a_in.set_inputs(&mut db).to(vec![value(59), b.clone()]);
+ b_in.set_inputs(&mut db).to(vec![value(60), c.clone()]);
+ c_in.set_inputs(&mut db).to(vec![d.clone(), e.clone()]);
+ d_in.set_inputs(&mut db)
+ .to(vec![value(61), b.clone(), e.clone()]);
+ e_in.set_inputs(&mut db).to(vec![d.clone()]);
+
+ a.clone().assert_value(&db, 59);
+ b.clone().assert_value(&db, 60);
+
+ // First time we run this test, don't fetch c/d/e here; this means they won't get marked
+ // `verified_final` in R6 (this revision), which will leave us in the next revision (R7)
+ // with a chain of could-be-provisional memos from the previous revision which should be
+ // final but were never confirmed as such; this triggers the case in `deep_verify_memo`
+ // where we need to double-check `validate_provisional` after traversing dependencies.
+ //
+ // Second time we run this test, fetch everything in R6, to check the behavior of
+ // `maybe_changed_after` with all validated-final memos.
+ if i == 1 {
+ c.clone().assert_value(&db, 60);
+ d.clone().assert_value(&db, 60);
+ e.clone().assert_value(&db, 60);
+ }
+
+ db.assert_logs_len(16 + i);
+
+ // next revision, we change only A, which is not part of the cycle and the cycle does not
+ // depend on.
+ a_in.set_inputs(&mut db).to(vec![value(45), b.clone()]); + b.assert_value(&db, 60); + + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(4)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + ]"#]]); + + a.assert_value(&db, 45); + } +} diff --git a/tests/cycle_initial_call_back_into_cycle.rs b/tests/cycle_initial_call_back_into_cycle.rs new file mode 100644 index 000000000..9dfe39a92 --- /dev/null +++ b/tests/cycle_initial_call_back_into_cycle.rs @@ -0,0 +1,36 @@ +//! Calling back into the same cycle from your cycle initial function will trigger another cycle. + +#[salsa::tracked] +fn initial_value(db: &dyn salsa::Database) -> u32 { + query(db) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query(db: &dyn salsa::Database) -> u32 { + let val = query(db); + if val < 5 { + val + 1 + } else { + val + } +} + +fn cycle_initial(db: &dyn salsa::Database) -> u32 { + initial_value(db) +} + +fn cycle_fn( + _db: &dyn salsa::Database, + _value: &u32, + _count: u32, +) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Iterate +} + +#[test_log::test] +#[should_panic(expected = "dependency graph cycle")] +fn the_test() { + let db = salsa::DatabaseImpl::default(); + + query(&db); +} diff --git a/tests/cycle_initial_call_query.rs b/tests/cycle_initial_call_query.rs new file mode 100644 index 000000000..4c52fff27 --- /dev/null +++ b/tests/cycle_initial_call_query.rs @@ -0,0 +1,35 @@ +//! It's possible to call a Salsa query from within a cycle initial fn. 
+ +#[salsa::tracked] +fn initial_value(_db: &dyn salsa::Database) -> u32 { + 0 +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query(db: &dyn salsa::Database) -> u32 { + let val = query(db); + if val < 5 { + val + 1 + } else { + val + } +} + +fn cycle_initial(db: &dyn salsa::Database) -> u32 { + initial_value(db) +} + +fn cycle_fn( + _db: &dyn salsa::Database, + _value: &u32, + _count: u32, +) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Iterate +} + +#[test_log::test] +fn the_test() { + let db = salsa::DatabaseImpl::default(); + + assert_eq!(query(&db), 5); +} diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs new file mode 100644 index 000000000..a4dc5e250 --- /dev/null +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -0,0 +1,43 @@ +//! Calling back into the same cycle from your cycle recovery function _can_ work out, as long as +//! the overall cycle still converges. + +mod common; +use common::{DatabaseWithValue, ValueDatabase}; + +#[salsa::tracked] +fn fallback_value(db: &dyn ValueDatabase) -> u32 { + query(db) + db.get_value() +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query(db: &dyn ValueDatabase) -> u32 { + let val = query(db); + if val < 5 { + val + 1 + } else { + val + } +} + +fn cycle_initial(_db: &dyn ValueDatabase) -> u32 { + 0 +} + +fn cycle_fn(db: &dyn ValueDatabase, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Fallback(fallback_value(db)) +} + +#[test] +fn converges() { + let db = DatabaseWithValue::new(10); + + assert_eq!(query(&db), 10); +} + +#[test] +#[should_panic(expected = "fallback did not converge")] +fn diverges() { + let db = DatabaseWithValue::new(3); + + query(&db); +} diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs new file mode 100644 index 000000000..a768017c8 --- /dev/null +++ b/tests/cycle_recovery_call_query.rs @@ 
-0,0 +1,35 @@ +//! It's possible to call a Salsa query from within a cycle recovery fn. + +#[salsa::tracked] +fn fallback_value(_db: &dyn salsa::Database) -> u32 { + 10 +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query(db: &dyn salsa::Database) -> u32 { + let val = query(db); + if val < 5 { + val + 1 + } else { + val + } +} + +fn cycle_initial(_db: &dyn salsa::Database) -> u32 { + 0 +} + +fn cycle_fn( + db: &dyn salsa::Database, + _value: &u32, + _count: u32, +) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Fallback(fallback_value(db)) +} + +#[test_log::test] +fn the_test() { + let db = salsa::DatabaseImpl::default(); + + assert_eq!(query(&db), 10); +} diff --git a/tests/cycles.rs b/tests/cycles.rs deleted file mode 100644 index be32beb8e..000000000 --- a/tests/cycles.rs +++ /dev/null @@ -1,437 +0,0 @@ -#![allow(warnings)] - -use std::panic::{RefUnwindSafe, UnwindSafe}; - -use expect_test::expect; -use salsa::DatabaseImpl; -use salsa::Durability; - -// Axes: -// -// Threading -// * Intra-thread -// * Cross-thread -- part of cycle is on one thread, part on another -// -// Recovery strategies: -// * Panic -// * Fallback -// * Mixed -- multiple strategies within cycle participants -// -// Across revisions: -// * N/A -- only one revision -// * Present in new revision, not old -// * Present in old revision, not new -// * Present in both revisions -// -// Dependencies -// * Tracked -// * Untracked -- cycle participant(s) contain untracked reads -// -// Layers -// * Direct -- cycle participant is directly invoked from test -// * Indirect -- invoked a query that invokes the cycle -// -// -// | Thread | Recovery | Old, New | Dep style | Layers | Test Name | -// | ------ | -------- | -------- | --------- | ------ | --------- | -// | Intra | Panic | N/A | Tracked | direct | cycle_memoized | -// | Intra | Panic | N/A | Untracked | direct | cycle_volatile | -// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | -// | 
Intra | Fallback | N/A | Tracked | indirect | inner_cycle | -// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | -// | Intra | Fallback | New | Tracked | direct | cycle_appears | -// | Intra | Fallback | Old | Tracked | direct | cycle_disappears | -// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | -// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | -// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | -// | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_mid_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs | - -#[derive(PartialEq, Eq, Hash, Clone, Debug, Update)] -struct Error { - cycle: Vec, -} - -use salsa::Database as Db; -use salsa::Setter; -use salsa::Update; - -#[salsa::input] -struct MyInput {} - -#[salsa::tracked] -fn memoized_a(db: &dyn Db, input: MyInput) { - memoized_b(db, input) -} - -#[salsa::tracked] -fn memoized_b(db: &dyn Db, input: MyInput) { - memoized_a(db, input) -} - -#[salsa::tracked] -fn volatile_a(db: &dyn Db, input: MyInput) { - db.report_untracked_read(); - volatile_b(db, input) -} - -#[salsa::tracked] -fn volatile_b(db: &dyn Db, input: MyInput) { - db.report_untracked_read(); - volatile_a(db, input) -} - -/// The queries A, B, and C in `Database` can be configured -/// to invoke one another in arbitrary ways using this -/// enum. 
-#[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum CycleQuery { - None, - A, - B, - C, - AthenC, -} - -#[salsa::input] -struct ABC { - a: CycleQuery, - b: CycleQuery, - c: CycleQuery, -} - -impl CycleQuery { - fn invoke(self, db: &dyn Db, abc: ABC) -> Result<(), Error> { - match self { - CycleQuery::A => cycle_a(db, abc), - CycleQuery::B => cycle_b(db, abc), - CycleQuery::C => cycle_c(db, abc), - CycleQuery::AthenC => { - let _ = cycle_a(db, abc); - cycle_c(db, abc) - } - CycleQuery::None => Ok(()), - } - } -} - -#[salsa::tracked(recovery_fn=recover_a)] -fn cycle_a(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.a(db).invoke(db, abc) -} - -fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { - Err(Error { - cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) -} - -#[salsa::tracked(recovery_fn=recover_b)] -fn cycle_b(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.b(db).invoke(db, abc) -} - -fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { - Err(Error { - cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) -} - -#[salsa::tracked] -fn cycle_c(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.c(db).invoke(db, abc) -} - -#[track_caller] -fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { - let v = std::panic::catch_unwind(f); - if let Err(d) = &v { - if let Some(cycle) = d.downcast_ref::() { - return cycle.clone(); - } - } - panic!("unexpected value: {:?}", v) -} - -#[test] -fn cycle_memoized() { - salsa::DatabaseImpl::new().attach(|db| { - let input = MyInput::new(db); - let cycle = extract_cycle(|| memoized_a(db, input)); - let expected = expect![[r#" - [ - memoized_a(Id(0)), - memoized_b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&cycle.all_participants(db)); - }) -} - -#[test] -fn cycle_volatile() { - salsa::DatabaseImpl::new().attach(|db| { - let input = MyInput::new(db); - let cycle = extract_cycle(|| volatile_a(db, input)); - let 
expected = expect![[r#" - [ - volatile_a(Id(0)), - volatile_b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&cycle.all_participants(db)); - }); -} - -#[test] -fn expect_cycle() { - // A --> B - // ^ | - // +-----+ - - salsa::DatabaseImpl::new().attach(|db| { - let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(db, abc).is_err()); - }) -} - -#[test] -fn inner_cycle() { - // A --> B <-- C - // ^ | - // +-----+ - salsa::DatabaseImpl::new().attach(|db| { - let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::B); - let err = cycle_c(db, abc); - assert!(err.is_err()); - let expected = expect![[r#" - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&err.unwrap_err().cycle); - }) -} - -#[test] -fn cycle_revalidate() { - // A --> B - // ^ | - // +-----+ - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - abc.set_b(&mut db).to(CycleQuery::A); // same value as default - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_recovery_unchanged_twice() { - // A --> B - // ^ | - // +-----+ - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - - abc.set_c(&mut db).to(CycleQuery::A); // force new revision - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_appears() { - let mut db = salsa::DatabaseImpl::new(); - // A --> B - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None); - assert!(cycle_a(&db, abc).is_ok()); - - // A --> B - // ^ | - // +-----+ - abc.set_b(&mut db).to(CycleQuery::A); - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_disappears() { - let mut db = salsa::DatabaseImpl::new(); - - // A --> B - // ^ | - // +-----+ - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, 
abc).is_err()); - - // A --> B - abc.set_b(&mut db).to(CycleQuery::None); - assert!(cycle_a(&db, abc).is_ok()); -} - -/// A variant on `cycle_disappears` in which the values of -/// `a` and `b` are set with durability values. -/// If we are not careful, this could cause us to overlook -/// the fact that the cycle will no longer occur. -#[test] -fn cycle_disappears_durability() { - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new( - &mut db, - CycleQuery::None, - CycleQuery::None, - CycleQuery::None, - ); - abc.set_a(&mut db) - .with_durability(Durability::LOW) - .to(CycleQuery::B); - abc.set_b(&mut db) - .with_durability(Durability::HIGH) - .to(CycleQuery::A); - - assert!(cycle_a(&db, abc).is_err()); - - // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However, - // because `b` participates in the same cycle as `a`, its final durability - // should be `LOW`. - // - // Check that setting a `LOW` input causes us to re-execute `b` query, and - // observe that the cycle goes away. 
- abc.set_a(&mut db) - .with_durability(Durability::LOW) - .to(CycleQuery::None); - - assert!(cycle_b(&mut db, abc).is_ok()); -} - -#[test] -fn cycle_mixed_1() { - salsa::DatabaseImpl::new().attach(|db| { - // A --> B <-- C - // | ^ - // +-----+ - let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::B); - - let expected = expect![[r#" - [ - "cycle_b(Id(0))", - "cycle_c(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&cycle_c(db, abc).unwrap_err().cycle); - }) -} - -#[test] -fn cycle_mixed_2() { - salsa::DatabaseImpl::new().attach(|db| { - // Configuration: - // - // A --> B --> C - // ^ | - // +-----------+ - let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::A); - let expected = expect![[r#" - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - "cycle_c(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&cycle_a(db, abc).unwrap_err().cycle); - }) -} - -#[test] -fn cycle_deterministic_order() { - // No matter whether we start from A or B, we get the same set of participants: - let f = || { - let mut db = salsa::DatabaseImpl::new(); - - // A --> B - // ^ | - // +-----+ - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - (db, abc) - }; - let (db, abc) = f(); - let a = cycle_a(&db, abc); - let (db, abc) = f(); - let b = cycle_b(&db, abc); - let expected = expect![[r#" - ( - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - ) - "#]]; - expected.assert_debug_eq(&(a.unwrap_err().cycle, b.unwrap_err().cycle)); -} - -#[test] -fn cycle_multiple() { - // No matter whether we start from A or B, we get the same set of participants: - let mut db = salsa::DatabaseImpl::new(); - - // Configuration: - // - // A --> B <-- C - // ^ | ^ - // +-----+ | - // | | - // +-----+ - // - // Here, conceptually, B encounters a cycle with A and then - // recovers. 
- let abc = ABC::new(&db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A);
-
- let c = cycle_c(&db, abc);
- let b = cycle_b(&db, abc);
- let a = cycle_a(&db, abc);
- let expected = expect![[r#"
- (
- [
- "cycle_a(Id(0))",
- "cycle_b(Id(0))",
- ],
- [
- "cycle_a(Id(0))",
- "cycle_b(Id(0))",
- ],
- [
- "cycle_a(Id(0))",
- "cycle_b(Id(0))",
- ],
- )
- "#]];
- expected.assert_debug_eq(&(
- c.unwrap_err().cycle,
- b.unwrap_err().cycle,
- a.unwrap_err().cycle,
- ));
-}
-
-#[test]
-fn cycle_recovery_set_but_not_participating() {
- salsa::DatabaseImpl::new().attach(|db| {
- // A --> C -+
- // ^ |
- // +--+
- let abc = ABC::new(db, CycleQuery::C, CycleQuery::None, CycleQuery::C);
-
- // Here we expect C to panic and A not to recover:
- let r = extract_cycle(|| drop(cycle_a(db, abc)));
- let expected = expect![[r#"
- [
- cycle_c(Id(0)),
- ]
- "#]];
- expected.assert_debug_eq(&r.all_participants(db));
- })
-}
diff --git a/tests/dataflow.rs b/tests/dataflow.rs
new file mode 100644
index 000000000..b5784d8e2
--- /dev/null
+++ b/tests/dataflow.rs
@@ -0,0 +1,246 @@
+//! Test case for fixpoint iteration cycle resolution.
+//!
+//! This test case is intended to simulate a (very simplified) version of a real dataflow analysis
+//! using fixpoint iteration.
+use salsa::{CycleRecoveryAction, Database as Db, Setter};
+use std::collections::BTreeSet;
+use std::iter::IntoIterator;
+
+/// A Use of a symbol.
+#[salsa::input]
+struct Use {
+ reaching_definitions: Vec<Definition>,
+}
+
+/// A Definition of a symbol, either of the form `base + increment` or `0 + increment`.
+#[salsa::input]
+struct Definition {
+ base: Option<Use>,
+ increment: usize,
+}
+
+#[derive(Eq, PartialEq, Clone, Debug, salsa::Update)]
+enum Type {
+ Bottom,
+ Values(Box<[usize]>),
+ Top,
+}
+
+impl Type {
+ fn join(tys: impl IntoIterator<Item = Type>) -> Type {
+ let mut result = Type::Bottom;
+ for ty in tys.into_iter() {
+ result = match (result, ty) {
+ (result, Type::Bottom) => result,
+ (_, Type::Top) => Type::Top,
+ (Type::Top, _) => Type::Top,
+ (Type::Bottom, ty) => ty,
+ (Type::Values(a_ints), Type::Values(b_ints)) => {
+ let mut set = BTreeSet::new();
+ set.extend(a_ints);
+ set.extend(b_ints);
+ Type::Values(set.into_iter().collect())
+ }
+ }
+ }
+ result
+ }
+}
+
+#[salsa::tracked(cycle_fn=use_cycle_recover, cycle_initial=use_cycle_initial)]
+fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type {
+ let defs = u.reaching_definitions(db);
+ match defs[..] {
+ [] => Type::Bottom,
+ [def] => infer_definition(db, def),
+ _ => Type::join(defs.iter().map(|&def| infer_definition(db, def))),
+ }
+}
+
+#[salsa::tracked(cycle_fn=def_cycle_recover, cycle_initial=def_cycle_initial)]
+fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type {
+ let increment_ty = Type::Values(Box::from([def.increment(db)]));
+ if let Some(base) = def.base(db) {
+ let base_ty = infer_use(db, base);
+ add(&base_ty, &increment_ty)
+ } else {
+ increment_ty
+ }
+}
+
+fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type {
+ Type::Bottom
+}
+
+fn def_cycle_recover(
+ _db: &dyn Db,
+ value: &Type,
+ count: u32,
+ _def: Definition,
+) -> CycleRecoveryAction<Type> {
+ cycle_recover(value, count)
+}
+
+fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type {
+ Type::Bottom
+}
+
+fn use_cycle_recover(
+ _db: &dyn Db,
+ value: &Type,
+ count: u32,
+ _use: Use,
+) -> CycleRecoveryAction<Type> {
+ cycle_recover(value, count)
+}
+
+fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction<Type> {
+ match value {
+ Type::Bottom => CycleRecoveryAction::Iterate,
+ Type::Values(_) => {
+ if count > 4 {
+ 
CycleRecoveryAction::Fallback(Type::Top) + } else { + CycleRecoveryAction::Iterate + } + } + Type::Top => CycleRecoveryAction::Iterate, + } +} + +fn add(a: &Type, b: &Type) -> Type { + match (a, b) { + (Type::Bottom, _) | (_, Type::Bottom) => Type::Bottom, + (Type::Top, _) | (_, Type::Top) => Type::Top, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend( + a_ints + .into_iter() + .flat_map(|a| b_ints.into_iter().map(move |b| a + b)), + ); + Type::Values(set.into_iter().collect()) + } + } +} + +/// x = 1 +#[test] +fn simple() { + let db = salsa::DatabaseImpl::new(); + + let def = Definition::new(&db, None, 1); + let u = Use::new(&db, vec![def]); + + let ty = infer_use(&db, u); + + assert_eq!(ty, Type::Values(Box::from([1]))); +} + +/// x = 1 if flag else 2 +#[test] +fn union() { + let db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 2); + let u = Use::new(&db, vec![def1, def2]); + + let ty = infer_use(&db, u); + + assert_eq!(ty, Type::Values(Box::from([1, 2]))); +} + +/// x = 1 if flag else 2; y = x + 1 +#[test] +fn union_add() { + let db = salsa::DatabaseImpl::new(); + + let x1 = Definition::new(&db, None, 1); + let x2 = Definition::new(&db, None, 2); + let x_use = Use::new(&db, vec![x1, x2]); + let y_def = Definition::new(&db, Some(x_use), 1); + let y_use = Use::new(&db, vec![y_def]); + + let ty = infer_use(&db, y_use); + + assert_eq!(ty, Type::Values(Box::from([2, 3]))); +} + +/// x = 1; loop { x = x + 0 } +#[test] +fn cycle_converges_then_diverges() { + let mut db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 0); + let u = Use::new(&db, vec![def1, def2]); + def2.set_base(&mut db).to(Some(u)); + + let ty = infer_use(&db, u); + + // Loop converges on 1 + assert_eq!(ty, Type::Values(Box::from([1]))); + + // Set the increment on x from 0 to 1 + let new_increment = 1; + 
def2.set_increment(&mut db).to(new_increment); + + // Now the loop diverges and we fall back to Top + assert_eq!(infer_use(&db, u), Type::Top); +} + +/// x = 1; loop { x = x + 1 } +#[test] +fn cycle_diverges_then_converges() { + let mut db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 1); + let u = Use::new(&db, vec![def1, def2]); + def2.set_base(&mut db).to(Some(u)); + + let ty = infer_use(&db, u); + + // Loop diverges. Cut it off and fallback to Type::Top + assert_eq!(ty, Type::Top); + + // Set the increment from 1 to 0. + def2.set_increment(&mut db).to(0); + + // Now the loop converges on 1. + assert_eq!(infer_use(&db, u), Type::Values(Box::from([1]))); +} + +/// x = 0; y = 0; loop { x = y + 0; y = x + 0 } +#[test_log::test] +fn multi_symbol_cycle_converges_then_diverges() { + let mut db = salsa::DatabaseImpl::new(); + + let defx0 = Definition::new(&db, None, 0); + let defy0 = Definition::new(&db, None, 0); + let defx1 = Definition::new(&db, None, 0); + let defy1 = Definition::new(&db, None, 0); + let use_x = Use::new(&db, vec![defx0, defx1]); + let use_y = Use::new(&db, vec![defy0, defy1]); + defx1.set_base(&mut db).to(Some(use_y)); + defy1.set_base(&mut db).to(Some(use_x)); + + // Both symbols converge on 0 + assert_eq!(infer_use(&db, use_x), Type::Values(Box::from([0]))); + assert_eq!(infer_use(&db, use_y), Type::Values(Box::from([0]))); + + // Set the increment on x to 0. + defx1.set_increment(&mut db).to(0); + + // Both symbols still converge on 0. + assert_eq!(infer_use(&db, use_x), Type::Values(Box::from([0]))); + assert_eq!(infer_use(&db, use_y), Type::Values(Box::from([0]))); + + // Set the increment on x from 0 to 1. + defx1.set_increment(&mut db).to(1); + + // Now the loop diverges and we fall back to Top. 
+ assert_eq!(infer_use(&db, use_x), Type::Top); + assert_eq!(infer_use(&db, use_y), Type::Top); +} diff --git a/tests/parallel/cycle_a_t1_b_t2.rs b/tests/parallel/cycle_a_t1_b_t2.rs new file mode 100644 index 000000000..aa0b84845 --- /dev/null +++ b/tests/parallel/cycle_a_t1_b_t2.rs @@ -0,0 +1,74 @@ +//! Test a specific cycle scenario: +//! +//! ```text +//! Thread T1 Thread T2 +//! --------- --------- +//! | | +//! v | +//! query_a() | +//! ^ | v +//! | +------------> query_b() +//! | | +//! +--------------------+ +//! ``` + +use salsa::CycleRecoveryAction; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(3); + +// Signal 1: T1 has entered `query_a` +// Signal 2: T2 has entered `query_b` + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + db.signal(1); + + // Wait for Thread T2 to enter `query_b` before we continue. + db.wait_for(2); + + query_b(db) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + // Wait for Thread T1 to enter `query_a` before we continue. 
+ db.wait_for(1);
+
+ db.signal(2);
+
+ let a_value = query_a(db);
+ CycleValue(a_value.0 + 1).min(MAX)
+}
+
+fn cycle_fn(
+ _db: &dyn KnobsDatabase,
+ _value: &CycleValue,
+ _count: u32,
+) -> CycleRecoveryAction<CycleValue> {
+ CycleRecoveryAction::Iterate
+}
+
+fn initial(_db: &dyn KnobsDatabase) -> CycleValue {
+ MIN
+}
+
+#[test_log::test]
+fn the_test() {
+ std::thread::scope(|scope| {
+ let db_t1 = Knobs::default();
+ let db_t2 = db_t1.clone();
+
+ let t1 = scope.spawn(move || query_a(&db_t1));
+ let t2 = scope.spawn(move || query_b(&db_t2));
+
+ let (r_t1, r_t2) = (t1.join().unwrap(), t2.join().unwrap());
+
+ assert_eq!((r_t1, r_t2), (MAX, MAX));
+ });
+}
diff --git a/tests/parallel/cycle_ab_peeping_c.rs b/tests/parallel/cycle_ab_peeping_c.rs
new file mode 100644
index 000000000..1c8233fab
--- /dev/null
+++ b/tests/parallel/cycle_ab_peeping_c.rs
@@ -0,0 +1,78 @@
+//! Test a specific cycle scenario:
+//!
+//! Thread T1 calls A which calls B which calls A.
+//!
+//! Thread T2 calls C which calls B.
+//!
+//! The trick is that the call from Thread T2 comes before B has reached a fixed point.
+//! We want to be sure that C sees the final value (and blocks until it is complete).
+
+use salsa::CycleRecoveryAction;
+
+use crate::setup::{Knobs, KnobsDatabase};
+
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)]
+struct CycleValue(u32);
+
+const MIN: CycleValue = CycleValue(0);
+const MID: CycleValue = CycleValue(11);
+const MAX: CycleValue = CycleValue(22);
+
+#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
+fn query_a(db: &dyn KnobsDatabase) -> CycleValue {
+ let b_value = query_b(db);
+
+ // When we reach the mid point, signal stage 1 (unblocking T2)
+ // and then wait for T2 to signal stage 2.
+ if b_value == MID {
+ db.signal(1);
+ db.wait_for(2);
+ }
+
+ b_value
+}
+
+fn cycle_fn(
+ _db: &dyn KnobsDatabase,
+ _value: &CycleValue,
+ _count: u32,
+) -> CycleRecoveryAction<CycleValue> {
+ CycleRecoveryAction::Iterate
+}
+
+fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue {
+ MIN
+}
+
+#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
+fn query_b(db: &dyn KnobsDatabase) -> CycleValue {
+ let a_value = query_a(db);
+
+ CycleValue(a_value.0 + 1).min(MAX)
+}
+
+#[salsa::tracked]
+fn query_c(db: &dyn KnobsDatabase) -> CycleValue {
+ // Wait until T1 has reached MID then execute `query_b`.
+ // This should block and (due to the configuration on our database) signal stage 2.
+ db.wait_for(1);
+
+ query_b(db)
+}
+
+#[test]
+fn the_test() {
+ std::thread::scope(|scope| {
+ let db_t1 = Knobs::default();
+
+ let db_t2 = db_t1.clone();
+ db_t2.signal_on_will_block(2);
+
+ let t1 = scope.spawn(move || query_a(&db_t1));
+ let t2 = scope.spawn(move || query_c(&db_t2));
+
+ let (r_t1, r_t2) = (t1.join().unwrap(), t2.join().unwrap());
+
+ assert_eq!((r_t1, r_t2), (MAX, MAX));
+ });
+}
diff --git a/tests/parallel/cycle_nested_three_threads.rs b/tests/parallel/cycle_nested_three_threads.rs
new file mode 100644
index 000000000..f0ff0e128
--- /dev/null
+++ b/tests/parallel/cycle_nested_three_threads.rs
@@ -0,0 +1,89 @@
+//! Test a nested-cycle scenario across three threads:
+//!
+//! ```text
+//! Thread T1 Thread T2 Thread T3
+//! --------- --------- ---------
+//! | | |
+//! v | |
+//! query_a() | |
+//! ^ | v |
+//! | +------------> query_b() |
+//! | ^ | v
+//! | | +------------> query_c()
+//! | | |
+//! +------------------+--------------------+
+//!
+//! 
``` + +use salsa::CycleRecoveryAction; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(3); + +// Signal 1: T1 has entered `query_a` +// Signal 2: T2 has entered `query_b` +// Signal 3: T3 has entered `query_c` + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + db.signal(1); + db.wait_for(3); + + query_b(db) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + db.wait_for(1); + db.signal(2); + db.wait_for(3); + + let c_value = query_c(db); + CycleValue(c_value.0 + 1).min(MAX) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_c(db: &dyn KnobsDatabase) -> CycleValue { + db.wait_for(2); + db.signal(3); + + let a_value = query_a(db); + let b_value = query_b(db); + CycleValue(a_value.0.max(b_value.0)) +} + +fn cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + +#[test_log::test] +fn the_test() { + std::thread::scope(|scope| { + let db_t1 = Knobs::default(); + let db_t2 = db_t1.clone(); + let db_t3 = db_t1.clone(); + + let t1 = scope.spawn(move || query_a(&db_t1)); + let t2 = scope.spawn(move || query_b(&db_t2)); + let t3 = scope.spawn(move || query_c(&db_t3)); + + let r_t1 = t1.join().unwrap(); + let r_t2 = t2.join().unwrap(); + let r_t3 = t3.join().unwrap(); + + assert_eq!((r_t1, r_t2, r_t3), (MAX, MAX, MAX)); + }); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index e01e46546..3f1d886c1 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -1,9 +1,8 @@ mod setup; +mod cycle_a_t1_b_t2; +mod cycle_ab_peeping_c; +mod cycle_nested_three_threads; mod 
parallel_cancellation; -mod parallel_cycle_all_recover; -mod parallel_cycle_mid_recover; -mod parallel_cycle_none_recover; -mod parallel_cycle_one_recover; mod parallel_map; mod signal; diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs deleted file mode 100644 index 08858ef5d..000000000 --- a/tests/parallel/parallel_cycle_all_recover.rs +++ /dev/null @@ -1,104 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. - -use crate::setup::Knobs; -use crate::setup::KnobsDatabase; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked(recovery_fn = recover_a1)] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - a2(db, input) -} - -fn recover_a1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_a1"); - key.field(db) * 10 + 1 -} - -#[salsa::tracked(recovery_fn=recover_a2)] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - b1(db, input) -} - -fn recover_a2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_a2"); - key.field(db) * 10 + 2 -} - -#[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - b2(db, input) -} - -fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b1"); - key.field(db) * 20 + 1 -} - -#[salsa::tracked(recovery_fn=recover_b2)] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - a1(db, input) -} - -fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - 
dbg!("recover_b2"); - key.field(db) * 20 + 2 -} - -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | signal stage 2 -// (unblocked) wait for stage 3 (blocks) -// a2 | -// b1 (blocks -> stage 3) | -// | (unblocked) -// | b2 -// | a1 (cycle detected, recovers) -// | b2 completes, recovers -// | b1 completes, recovers -// a2 sees cycle, recovers -// a1 completes, recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block(3); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b1(&db, input) - }); - - assert_eq!(thread_a.join().unwrap(), 11); - assert_eq!(thread_b.join().unwrap(), 21); -} diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs deleted file mode 100644 index c41ed32d1..000000000 --- a/tests/parallel/parallel_cycle_mid_recover.rs +++ /dev/null @@ -1,102 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. 
- -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // tell thread b we have started - db.signal(1); - - // wait for thread b to block on a1 - db.wait_for(2); - - a2(db, input) -} - -#[salsa::tracked] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // create the cycle - b1(db, input) -} - -#[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // wait for thread a to have started - db.wait_for(1); - b2(db, input) -} - -fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b1"); - key.field(db) * 20 + 2 -} - -#[salsa::tracked] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // will encounter a cycle but recover - b3(db, input); - b1(db, input); // hasn't recovered yet - 0 -} - -#[salsa::tracked(recovery_fn=recover_b3)] -pub(crate) fn b3(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // will block on thread a, signaling stage 2 - a1(db, input) -} - -fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b3"); - key.field(db) * 200 + 2 -} - -// Recover cycle test: -// -// The pattern is as follows. 
-// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | | -// | b2 -// | b3 -// | a1 (blocks -> stage 2) -// (unblocked) | -// a2 (cycle detected) | -// b3 recovers -// b2 resumes -// b1 recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block(3); - move || b1(&db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. - assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs deleted file mode 100644 index f1f0ee91e..000000000 --- a/tests/parallel/parallel_cycle_none_recover.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Test a cycle where no queries recover that occurs across threads. -//! See the `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. 
- -use crate::setup::Knobs; -use crate::setup::KnobsDatabase; -use expect_test::expect; -use salsa::Database; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - b(db, input) -} - -#[salsa::tracked] -pub(crate) fn b(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - - // Now try to execute A - a(db, input) -} - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, -1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block(3); - move || a(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b(&db, input) - }); - - // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately). - // Right now, it panics with a string. - let err_b = thread_b.join().unwrap_err(); - db.attach(|_| { - if let Some(c) = err_b.downcast_ref::() { - let expected = expect![[r#" - [ - a(Id(0)), - b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&c.all_participants(&db)); - } else { - panic!("b failed in an unexpected way: {:?}", err_b); - } - }); - - // We expect A to propagate a panic, which causes us to use the sentinel - // type `Canceled`. - assert!(thread_a - .join() - .unwrap_err() - .downcast_ref::() - .is_some()); -} diff --git a/tests/parallel/parallel_cycle_one_recover.rs b/tests/parallel/parallel_cycle_one_recover.rs deleted file mode 100644 index 65737797b..000000000 --- a/tests/parallel/parallel_cycle_one_recover.rs +++ /dev/null @@ -1,91 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! 
both intra and cross thread. - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - a2(db, input) -} - -#[salsa::tracked(recovery_fn=recover)] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - b1(db, input) -} - -fn recover(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover"); - key.field(db) * 20 + 2 -} - -#[salsa::tracked] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - b2(db, input) -} - -#[salsa::tracked] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - a1(db, input) -} - -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | signal stage 2 -// (unblocked) wait for stage 3 (blocks) -// a2 | -// b1 (blocks -> stage 3) | -// | (unblocked) -// | b2 -// | a1 (cycle detected) -// a2 recovery fn executes | -// a1 completes normally | -// b2 completes, recovers -// b1 completes, recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block(3); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b1(&db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. 
- assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index b29d1b7be..52c0ce227 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -11,10 +11,10 @@ use crate::signal::Signal; /// a certain behavior. #[salsa::db] pub(crate) trait KnobsDatabase: Database { - fn knobs(&self) -> &Knobs; - + /// Signal that we are entering stage 1. fn signal(&self, stage: usize); + /// Wait until we reach stage `stage` (no-op if we have already reached that stage). fn wait_for(&self, stage: usize); } @@ -80,10 +80,6 @@ impl salsa::Database for Knobs { #[salsa::db] impl KnobsDatabase for Knobs { - fn knobs(&self) -> &Knobs { - self - } - fn signal(&self, stage: usize) { self.signal.signal(stage); }