Skip to content

Commit 351d9cf

Browse files
authored
Merge pull request #680 from MichaReiser/micha/tracked-fn-update-fallback
Use *fallback* trick for tracked-fn `Update` constraint
2 parents 538eaad + c48a9df commit 351d9cf

File tree

12 files changed

+174
-56
lines changed

12 files changed

+174
-56
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ rust-version = "1.76"
1010

1111
[dependencies]
1212
arc-swap = "1"
13+
compact_str = { version = "0.8", optional = true }
1314
crossbeam = "0.8"
1415
dashmap = { version = "6", features = ["raw-api"] }
1516
hashlink = "0.9"

components/salsa-macro-rules/src/setup_tracked_fn.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ macro_rules! setup_tracked_fn {
5555
// True if we `return_ref` flag was given to the function
5656
return_ref: $return_ref:tt,
5757

58+
maybe_update_fn: {$($maybe_update_fn:tt)*},
59+
5860
// Annoyingly macro-rules hygiene does not extend to items defined in the macro.
5961
// We have the procedural macro generate names for those items that are
6062
// not used elsewhere in the user's code.
@@ -145,6 +147,14 @@ macro_rules! setup_tracked_fn {
145147
}
146148
}
147149

150+
/// This method isn't used anywhere. It only exitst to enforce the `Self::Output: Update` constraint
151+
/// for types that aren't `'static`.
152+
///
153+
/// # Safety
154+
/// The same safety rules as for `Update` apply.
155+
$($maybe_update_fn)*
156+
157+
148158
impl $zalsa::function::Configuration for $Configuration {
149159
const DEBUG_NAME: &'static str = stringify!($fn_name);
150160

components/salsa-macros/src/tracked_fn.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,18 @@ impl Macro {
117117

118118
let return_ref: bool = self.args.return_ref.is_some();
119119

120+
let maybe_update_fn = quote_spanned! {output_ty.span()=> {
121+
#[allow(clippy::all, unsafe_code)]
122+
unsafe fn _maybe_update_fn<'db>(old_pointer: *mut #output_ty, new_value: #output_ty) -> bool {
123+
unsafe {
124+
use #zalsa::UpdateFallback;
125+
#zalsa::UpdateDispatch::<#output_ty>::maybe_update(
126+
old_pointer, new_value
127+
)
128+
}
129+
}
130+
}};
131+
120132
Ok(crate::debug::dump_tokens(
121133
fn_name,
122134
quote![salsa::plumbing::setup_tracked_fn! {
@@ -137,6 +149,7 @@ impl Macro {
137149
needs_interner: #needs_interner,
138150
lru: #lru,
139151
return_ref: #return_ref,
152+
maybe_update_fn: { #maybe_update_fn },
140153
unused_names: [
141154
#zalsa,
142155
#Configuration,

components/salsa-macros/src/update.rs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use proc_macro2::{Literal, TokenStream};
2+
use syn::spanned::Spanned;
23
use synstructure::BindStyle;
34

45
use crate::hygiene::Hygiene;
@@ -34,7 +35,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
3435
.bindings()
3536
.iter()
3637
.fold(quote!(), |tokens, binding| quote!(#tokens #binding,));
37-
let make_new_value = quote! {
38+
let make_new_value = quote_spanned! {variant.ast().ident.span()=>
3839
let #new_value = if let #variant_pat = #new_value {
3940
(#make_tuple)
4041
} else {
@@ -46,20 +47,28 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
4647
// For each field, invoke `maybe_update` recursively to update its value.
4748
// Or the results together (using `|`, not `||`, to avoid shortcircuiting)
4849
// to get the final return value.
49-
let update_fields = variant.bindings().iter().zip(0..).fold(
50+
let update_fields = variant.bindings().iter().enumerate().fold(
5051
quote!(false),
51-
|tokens, (binding, index)| {
52+
|tokens, (index, binding)| {
5253
let field_ty = &binding.ast().ty;
5354
let field_index = Literal::usize_unsuffixed(index);
5455

56+
let field_span = binding
57+
.ast()
58+
.ident
59+
.as_ref()
60+
.map(Spanned::span)
61+
.unwrap_or(binding.ast().span());
62+
63+
let update_field = quote_spanned! {field_span=>
64+
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
65+
#binding,
66+
#new_value.#field_index,
67+
)
68+
};
69+
5570
quote! {
56-
#tokens |
57-
unsafe {
58-
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
59-
#binding,
60-
#new_value.#field_index,
61-
)
62-
}
71+
#tokens | unsafe { #update_field }
6372
}
6473
},
6574
);
@@ -77,6 +86,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
7786
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
7887
let tokens = quote! {
7988
#[allow(clippy::all)]
89+
#[automatically_derived]
8090
unsafe impl #impl_generics salsa::Update for #ident #ty_generics #where_clause {
8191
unsafe fn maybe_update(#old_pointer: *mut Self, #new_value: Self) -> bool {
8292
use ::salsa::plumbing::UpdateFallback as _;

examples/calc/ir.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ pub enum ExpressionData<'db> {
6565
Call(FunctionId<'db>, Vec<Expression<'db>>),
6666
}
6767

68-
#[derive(Eq, PartialEq, Copy, Clone, Hash, Debug, salsa::Update)]
68+
#[derive(Eq, PartialEq, Copy, Clone, Hash, Debug)]
6969
pub enum Op {
7070
Add,
7171
Subtract,

src/function.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::{
99
salsa_struct::SalsaStructInDb,
1010
zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa},
1111
zalsa_local::QueryOrigin,
12-
Cycle, Database, Id, Revision, Update,
12+
Cycle, Database, Id, Revision,
1313
};
1414

1515
use self::delete::DeletedEntries;
@@ -43,7 +43,7 @@ pub trait Configuration: Any {
4343
type Input<'db>: Send + Sync;
4444

4545
/// The value computed by the function.
46-
type Output<'db>: fmt::Debug + Send + Sync + Update;
46+
type Output<'db>: fmt::Debug + Send + Sync;
4747

4848
/// Determines whether this function can recover from being a participant in a cycle
4949
/// (and, if so, how).

src/update.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::{
22
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
33
hash::{BuildHasher, Hash},
4+
marker::PhantomData,
45
path::PathBuf,
56
sync::Arc,
67
};
@@ -188,6 +189,29 @@ where
188189
}
189190
}
190191

192+
unsafe impl<A> Update for smallvec::SmallVec<A>
193+
where
194+
A: smallvec::Array,
195+
A::Item: Update,
196+
{
197+
unsafe fn maybe_update(old_pointer: *mut Self, new_vec: Self) -> bool {
198+
let old_vec: &mut smallvec::SmallVec<A> = unsafe { &mut *old_pointer };
199+
200+
if old_vec.len() != new_vec.len() {
201+
old_vec.clear();
202+
old_vec.extend(new_vec);
203+
return true;
204+
}
205+
206+
let mut changed = false;
207+
for (old_element, new_element) in old_vec.iter_mut().zip(new_vec) {
208+
changed |= A::Item::maybe_update(old_element, new_element);
209+
}
210+
211+
changed
212+
}
213+
}
214+
191215
macro_rules! maybe_update_set {
192216
($old_pointer: expr, $new_set: expr) => {{
193217
let old_pointer = $old_pointer;
@@ -291,6 +315,26 @@ where
291315
}
292316
}
293317

318+
unsafe impl<T> Update for Box<[T]>
319+
where
320+
T: Update,
321+
{
322+
unsafe fn maybe_update(old_pointer: *mut Self, new_box: Self) -> bool {
323+
let old_box: &mut Box<[T]> = unsafe { &mut *old_pointer };
324+
325+
if old_box.len() == new_box.len() {
326+
let mut changed = false;
327+
for (old_element, new_element) in old_box.iter_mut().zip(new_box) {
328+
changed |= T::maybe_update(old_element, new_element);
329+
}
330+
changed
331+
} else {
332+
*old_box = new_box;
333+
true
334+
}
335+
}
336+
}
337+
294338
unsafe impl<T> Update for Arc<T>
295339
where
296340
T: Update,
@@ -398,6 +442,9 @@ fallback_impl! {
398442
PathBuf,
399443
}
400444

445+
#[cfg(feature = "compact_str")]
446+
fallback_impl! { compact_str::CompactString, }
447+
401448
macro_rules! tuple_impl {
402449
($($t:ident),*; $($u:ident),*) => {
403450
unsafe impl<$($t),*> Update for ($($t,)*)
@@ -451,3 +498,9 @@ where
451498
}
452499
}
453500
}
501+
502+
unsafe impl<T> Update for PhantomData<T> {
503+
unsafe fn maybe_update(_old_pointer: *mut Self, _new_value: Self) -> bool {
504+
false
505+
}
506+
}

tests/compile-fail/tracked_fn_return_ref.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use salsa::Database as Db;
2-
use salsa::Update;
32

43
#[salsa::input]
54
struct MyInput {
Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,20 @@
1-
warning: unused import: `salsa::Update`
2-
--> tests/compile-fail/tracked_fn_return_ref.rs:2:5
3-
|
4-
2 | use salsa::Update;
5-
| ^^^^^^^^^^^^^
6-
|
7-
= note: `#[warn(unused_imports)]` on by default
8-
9-
error[E0277]: the trait bound `&'db str: Update` is not satisfied
10-
--> tests/compile-fail/tracked_fn_return_ref.rs:16:67
11-
|
12-
16 | fn tracked_fn_return_ref<'db>(db: &'db dyn Db, input: MyInput) -> &'db str {
13-
| ^^^^^^^^ the trait `Update` is not implemented for `&'db str`
1+
error: lifetime may not live long enough
2+
--> tests/compile-fail/tracked_fn_return_ref.rs:14:1
143
|
15-
= help: the trait `Update` is implemented for `String`
16-
note: required by a bound in `salsa::plumbing::function::Configuration::Output`
17-
--> src/function.rs
4+
14 | #[salsa::tracked]
5+
| ^^^^^^^^^^^^^^^^^ requires that `'db` must outlive `'static`
6+
15 | fn tracked_fn_return_ref<'db>(db: &'db dyn Db, input: MyInput) -> &'db str {
7+
| - lifetime `'db` defined here
188
|
19-
| type Output<'db>: fmt::Debug + Send + Sync + Update;
20-
| ^^^^^^ required by this bound in `Configuration::Output`
9+
= note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info)
2110

22-
error[E0277]: the trait bound `ContainsRef<'db>: Update` is not satisfied
23-
--> tests/compile-fail/tracked_fn_return_ref.rs:24:6
24-
|
25-
24 | ) -> ContainsRef<'db> {
26-
| ^^^^^^^^^^^^^^^^ the trait `Update` is not implemented for `ContainsRef<'db>`
11+
error: lifetime may not live long enough
12+
--> tests/compile-fail/tracked_fn_return_ref.rs:19:1
2713
|
28-
= help: the following other types implement trait `Update`:
29-
()
30-
(A, B)
31-
(A, B, C)
32-
(A, B, C, D)
33-
(A, B, C, D, E)
34-
(A, B, C, D, E, F)
35-
(A, B, C, D, E, F, G)
36-
(A, B, C, D, E, F, G, H)
37-
and $N others
38-
note: required by a bound in `salsa::plumbing::function::Configuration::Output`
39-
--> src/function.rs
14+
19 | #[salsa::tracked]
15+
| ^^^^^^^^^^^^^^^^^ requires that `'db` must outlive `'static`
16+
...
17+
23 | ) -> ContainsRef<'db> {
18+
| ----------- lifetime `'db` defined here
4019
|
41-
| type Output<'db>: fmt::Debug + Send + Sync + Update;
42-
| ^^^^^^ required by this bound in `Configuration::Output`
20+
= note: this error originates in the attribute macro `salsa::tracked` (in Nightly builds, run with -Z macro-backtrace for more info)

tests/lru.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ use std::sync::{
88

99
mod common;
1010
use common::LogDatabase;
11-
use salsa::{Database as _, Update};
11+
use salsa::Database as _;
1212
use test_log::test;
1313

14-
#[derive(Debug, PartialEq, Eq, Update)]
14+
#[derive(Debug, PartialEq, Eq)]
1515
struct HotPotato(u32);
1616

1717
thread_local! {

tests/tracked_struct.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
mod common;
2+
3+
use salsa::{Database, Setter};
4+
5+
#[salsa::tracked]
6+
struct Tracked<'db> {
7+
untracked_1: usize,
8+
9+
untracked_2: usize,
10+
}
11+
12+
#[salsa::input]
13+
struct MyInput {
14+
field1: usize,
15+
field2: usize,
16+
}
17+
18+
#[salsa::tracked]
19+
fn intermediate(db: &dyn salsa::Database, input: MyInput) -> Tracked<'_> {
20+
Tracked::new(db, input.field1(db), input.field2(db))
21+
}
22+
23+
#[salsa::tracked]
24+
fn accumulate(db: &dyn salsa::Database, input: MyInput) -> (usize, usize) {
25+
let tracked = intermediate(db, input);
26+
let one = read_tracked_1(db, tracked);
27+
let two = read_tracked_2(db, tracked);
28+
29+
(one, two)
30+
}
31+
32+
#[salsa::tracked]
33+
fn read_tracked_1<'db>(db: &'db dyn Database, tracked: Tracked<'db>) -> usize {
34+
tracked.untracked_1(db)
35+
}
36+
37+
#[salsa::tracked]
38+
fn read_tracked_2<'db>(db: &'db dyn Database, tracked: Tracked<'db>) -> usize {
39+
tracked.untracked_2(db)
40+
}
41+
42+
#[test]
43+
fn execute() {
44+
let mut db = salsa::DatabaseImpl::default();
45+
let input = MyInput::new(&db, 1, 1);
46+
47+
assert_eq!(accumulate(&db, input), (1, 1));
48+
49+
// Should only re-execute `read_tracked_1`.
50+
input.set_field1(&mut db).to(2);
51+
assert_eq!(accumulate(&db, input), (2, 1));
52+
53+
// Should only re-execute `read_tracked_2`.
54+
input.set_field2(&mut db).to(2);
55+
assert_eq!(accumulate(&db, input), (2, 2));
56+
}

tests/warnings/needless_lifetimes.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
use salsa::Update;
2-
31
#[salsa::db]
42
pub trait Db: salsa::Database {}
53

6-
#[derive(Debug, PartialEq, Eq, Hash, Update)]
4+
#[derive(Debug, PartialEq, Eq, Hash)]
75
pub struct Item {}
86

97
#[salsa::tracked]

0 commit comments

Comments
 (0)