Skip to content

Commit

Permalink
Merge pull request #680 from MichaReiser/micha/tracked-fn-update-fall…
Browse files Browse the repository at this point in the history
…back

Use *fallback* trick for tracked-fn `Update` constraint
  • Loading branch information
MichaReiser authored Feb 10, 2025
2 parents 538eaad + c48a9df commit 351d9cf
Show file tree
Hide file tree
Showing 12 changed files with 174 additions and 56 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ rust-version = "1.76"

[dependencies]
arc-swap = "1"
compact_str = { version = "0.8", optional = true }
crossbeam = "0.8"
dashmap = { version = "6", features = ["raw-api"] }
hashlink = "0.9"
Expand Down
10 changes: 10 additions & 0 deletions components/salsa-macro-rules/src/setup_tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ macro_rules! setup_tracked_fn {
// True if we `return_ref` flag was given to the function
return_ref: $return_ref:tt,

maybe_update_fn: {$($maybe_update_fn:tt)*},

// Annoyingly macro-rules hygiene does not extend to items defined in the macro.
// We have the procedural macro generate names for those items that are
// not used elsewhere in the user's code.
Expand Down Expand Up @@ -145,6 +147,14 @@ macro_rules! setup_tracked_fn {
}
}

/// This method isn't used anywhere. It only exitst to enforce the `Self::Output: Update` constraint
/// for types that aren't `'static`.
///
/// # Safety
/// The same safety rules as for `Update` apply.
$($maybe_update_fn)*


impl $zalsa::function::Configuration for $Configuration {
const DEBUG_NAME: &'static str = stringify!($fn_name);

Expand Down
13 changes: 13 additions & 0 deletions components/salsa-macros/src/tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ impl Macro {

let return_ref: bool = self.args.return_ref.is_some();

let maybe_update_fn = quote_spanned! {output_ty.span()=> {
#[allow(clippy::all, unsafe_code)]
unsafe fn _maybe_update_fn<'db>(old_pointer: *mut #output_ty, new_value: #output_ty) -> bool {
unsafe {
use #zalsa::UpdateFallback;
#zalsa::UpdateDispatch::<#output_ty>::maybe_update(
old_pointer, new_value
)
}
}
}};

Ok(crate::debug::dump_tokens(
fn_name,
quote![salsa::plumbing::setup_tracked_fn! {
Expand All @@ -137,6 +149,7 @@ impl Macro {
needs_interner: #needs_interner,
lru: #lru,
return_ref: #return_ref,
maybe_update_fn: { #maybe_update_fn },
unused_names: [
#zalsa,
#Configuration,
Expand Down
30 changes: 20 additions & 10 deletions components/salsa-macros/src/update.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use proc_macro2::{Literal, TokenStream};
use syn::spanned::Spanned;
use synstructure::BindStyle;

use crate::hygiene::Hygiene;
Expand Down Expand Up @@ -34,7 +35,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
.bindings()
.iter()
.fold(quote!(), |tokens, binding| quote!(#tokens #binding,));
let make_new_value = quote! {
let make_new_value = quote_spanned! {variant.ast().ident.span()=>
let #new_value = if let #variant_pat = #new_value {
(#make_tuple)
} else {
Expand All @@ -46,20 +47,28 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
// For each field, invoke `maybe_update` recursively to update its value.
// Or the results together (using `|`, not `||`, to avoid shortcircuiting)
// to get the final return value.
let update_fields = variant.bindings().iter().zip(0..).fold(
let update_fields = variant.bindings().iter().enumerate().fold(
quote!(false),
|tokens, (binding, index)| {
|tokens, (index, binding)| {
let field_ty = &binding.ast().ty;
let field_index = Literal::usize_unsuffixed(index);

let field_span = binding
.ast()
.ident
.as_ref()
.map(Spanned::span)
.unwrap_or(binding.ast().span());

let update_field = quote_spanned! {field_span=>
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
#binding,
#new_value.#field_index,
)
};

quote! {
#tokens |
unsafe {
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
#binding,
#new_value.#field_index,
)
}
#tokens | unsafe { #update_field }
}
},
);
Expand All @@ -77,6 +86,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let tokens = quote! {
#[allow(clippy::all)]
#[automatically_derived]
unsafe impl #impl_generics salsa::Update for #ident #ty_generics #where_clause {
unsafe fn maybe_update(#old_pointer: *mut Self, #new_value: Self) -> bool {
use ::salsa::plumbing::UpdateFallback as _;
Expand Down
2 changes: 1 addition & 1 deletion examples/calc/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub enum ExpressionData<'db> {
Call(FunctionId<'db>, Vec<Expression<'db>>),
}

#[derive(Eq, PartialEq, Copy, Clone, Hash, Debug, salsa::Update)]
#[derive(Eq, PartialEq, Copy, Clone, Hash, Debug)]
pub enum Op {
Add,
Subtract,
Expand Down
4 changes: 2 additions & 2 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
salsa_struct::SalsaStructInDb,
zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa},
zalsa_local::QueryOrigin,
Cycle, Database, Id, Revision, Update,
Cycle, Database, Id, Revision,
};

use self::delete::DeletedEntries;
Expand Down Expand Up @@ -43,7 +43,7 @@ pub trait Configuration: Any {
type Input<'db>: Send + Sync;

/// The value computed by the function.
type Output<'db>: fmt::Debug + Send + Sync + Update;
type Output<'db>: fmt::Debug + Send + Sync;

/// Determines whether this function can recover from being a participant in a cycle
/// (and, if so, how).
Expand Down
53 changes: 53 additions & 0 deletions src/update.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
hash::{BuildHasher, Hash},
marker::PhantomData,
path::PathBuf,
sync::Arc,
};
Expand Down Expand Up @@ -188,6 +189,29 @@ where
}
}

unsafe impl<A> Update for smallvec::SmallVec<A>
where
A: smallvec::Array,
A::Item: Update,
{
unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool {
let old_vec: &mut smallvec::SmallVec<A> = unsafe { &mut *old_pointer };

if old_vec.len() != new_vec.len() {
old_vec.clear();
old_vec.extend(new_vec);
return true;
}

let mut changed = false;
for (old_element, new_element) in old_vec.iter_mut().zip(new_vec) {
changed |= A::Item::maybe_update(old_element, new_element);
}

changed
}
}

macro_rules! maybe_update_set {
($old_pointer: expr, $new_set: expr) => {{
let old_pointer = $old_pointer;
Expand Down Expand Up @@ -291,6 +315,26 @@ where
}
}

unsafe impl<T> Update for Box<[T]>
where
T: Update,
{
unsafe fn maybe_update(old_pointer: *mut Self, new_box: Self) -> bool {
let old_box: &mut Box<[T]> = unsafe { &mut *old_pointer };

if old_box.len() == new_box.len() {
let mut changed = false;
for (old_element, new_element) in old_box.iter_mut().zip(new_box) {
changed |= T::maybe_update(old_element, new_element);
}
changed
} else {
*old_box = new_box;
true
}
}
}

unsafe impl<T> Update for Arc<T>
where
T: Update,
Expand Down Expand Up @@ -398,6 +442,9 @@ fallback_impl! {
PathBuf,
}

#[cfg(feature = "compact_str")]
fallback_impl! { compact_str::CompactString, }

macro_rules! tuple_impl {
($($t:ident),*; $($u:ident),*) => {
unsafe impl<$($t),*> Update for ($($t,)*)
Expand Down Expand Up @@ -451,3 +498,9 @@ where
}
}
}

unsafe impl<T> Update for PhantomData<T> {
unsafe fn maybe_update(_old_pointer: *mut Self, _new_value: Self) -> bool {
false
}
}
1 change: 0 additions & 1 deletion tests/compile-fail/tracked_fn_return_ref.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use salsa::Database as Db;
use salsa::Update;

#[salsa::input]
struct MyInput {
Expand Down
52 changes: 15 additions & 37 deletions tests/compile-fail/tracked_fn_return_ref.stderr
Original file line number Diff line number Diff line change
@@ -1,42 +1,20 @@
warning: unused import: `salsa::Update`
--> tests/compile-fail/tracked_fn_return_ref.rs:2:5
|
2 | use salsa::Update;
| ^^^^^^^^^^^^^
|
= note: `#[warn(unused_imports)]` on by default

error[E0277]: the trait bound `&'db str: Update` is not satisfied
--> tests/compile-fail/tracked_fn_return_ref.rs:16:67
|
16 | fn tracked_fn_return_ref<'db>(db: &'db dyn Db, input: MyInput) -> &'db str {
| ^^^^^^^^ the trait `Update` is not implemented for `&'db str`
error: lifetime may not live long enough
--> tests/compile-fail/tracked_fn_return_ref.rs:14:1
|
= help: the trait `Update` is implemented for `String`
note: required by a bound in `salsa::plumbing::function::Configuration::Output`
--> src/function.rs
14 | #[salsa::tracked]
| ^^^^^^^^^^^^^^^^^ requires that `'db` must outlive `'static`
15 | fn tracked_fn_return_ref<'db>(db: &'db dyn Db, input: MyInput) -> &'db str {
| - lifetime `'db` defined here
|
| type Output<'db>: fmt::Debug + Send + Sync + Update;
| ^^^^^^ required by this bound in `Configuration::Output`
= note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `ContainsRef<'db>: Update` is not satisfied
--> tests/compile-fail/tracked_fn_return_ref.rs:24:6
|
24 | ) -> ContainsRef<'db> {
| ^^^^^^^^^^^^^^^^ the trait `Update` is not implemented for `ContainsRef<'db>`
error: lifetime may not live long enough
--> tests/compile-fail/tracked_fn_return_ref.rs:19:1
|
= help: the following other types implement trait `Update`:
()
(A, B)
(A, B, C)
(A, B, C, D)
(A, B, C, D, E)
(A, B, C, D, E, F)
(A, B, C, D, E, F, G)
(A, B, C, D, E, F, G, H)
and $N others
note: required by a bound in `salsa::plumbing::function::Configuration::Output`
--> src/function.rs
19 | #[salsa::tracked]
| ^^^^^^^^^^^^^^^^^ requires that `'db` must outlive `'static`
...
23 | ) -> ContainsRef<'db> {
| ----------- lifetime `'db` defined here
|
| type Output<'db>: fmt::Debug + Send + Sync + Update;
| ^^^^^^ required by this bound in `Configuration::Output`
= note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info)
4 changes: 2 additions & 2 deletions tests/lru.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use std::sync::{

mod common;
use common::LogDatabase;
use salsa::{Database as _, Update};
use salsa::Database as _;
use test_log::test;

#[derive(Debug, PartialEq, Eq, Update)]
#[derive(Debug, PartialEq, Eq)]
struct HotPotato(u32);

thread_local! {
Expand Down
56 changes: 56 additions & 0 deletions tests/tracked_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
mod common;

use salsa::{Database, Setter};

#[salsa::tracked]
struct Tracked<'db> {
untracked_1: usize,

untracked_2: usize,
}

#[salsa::input]
struct MyInput {
field1: usize,
field2: usize,
}

#[salsa::tracked]
fn intermediate(db: &dyn salsa::Database, input: MyInput) -> Tracked<'_> {
Tracked::new(db, input.field1(db), input.field2(db))
}

#[salsa::tracked]
fn accumulate(db: &dyn salsa::Database, input: MyInput) -> (usize, usize) {
let tracked = intermediate(db, input);
let one = read_tracked_1(db, tracked);
let two = read_tracked_2(db, tracked);

(one, two)
}

#[salsa::tracked]
fn read_tracked_1<'db>(db: &'db dyn Database, tracked: Tracked<'db>) -> usize {
tracked.untracked_1(db)
}

#[salsa::tracked]
fn read_tracked_2<'db>(db: &'db dyn Database, tracked: Tracked<'db>) -> usize {
tracked.untracked_2(db)
}

#[test]
fn execute() {
let mut db = salsa::DatabaseImpl::default();
let input = MyInput::new(&db, 1, 1);

assert_eq!(accumulate(&db, input), (1, 1));

// Should only re-execute `read_tracked_1`.
input.set_field1(&mut db).to(2);
assert_eq!(accumulate(&db, input), (2, 1));

// Should only re-execute `read_tracked_2`.
input.set_field2(&mut db).to(2);
assert_eq!(accumulate(&db, input), (2, 2));
}
4 changes: 1 addition & 3 deletions tests/warnings/needless_lifetimes.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use salsa::Update;

#[salsa::db]
pub trait Db: salsa::Database {}

#[derive(Debug, PartialEq, Eq, Hash, Update)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct Item {}

#[salsa::tracked]
Expand Down

0 comments on commit 351d9cf

Please sign in to comment.