Skip to content

Commit 0aebe9f

Browse files
committed
Allow overwriting tracked struct field update functions
1 parent c4cf0d9 commit 0aebe9f

File tree

9 files changed

+189
-25
lines changed

9 files changed

+189
-25
lines changed

Diff for: components/salsa-macro-rules/src/maybe_backdate.rs

+3-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
macro_rules! maybe_backdate {
44
(
55
($maybe_clone:ident, no_backdate, $maybe_default:ident),
6-
$field_ty:ty,
6+
$maybe_update:tt,
77
$old_field_place:expr,
88
$new_field_place:expr,
99
$revision_place:expr,
@@ -21,17 +21,14 @@ macro_rules! maybe_backdate {
2121

2222
(
2323
($maybe_clone:ident, backdate, $maybe_default:ident),
24-
$field_ty:ty,
24+
$maybe_update:tt,
2525
$old_field_place:expr,
2626
$new_field_place:expr,
2727
$revision_place:expr,
2828
$current_revision:expr,
2929
$zalsa:ident,
3030
) => {
31-
if $zalsa::UpdateDispatch::<$field_ty>::maybe_update(
32-
std::ptr::addr_of_mut!($old_field_place),
33-
$new_field_place,
34-
) {
31+
if $maybe_update(std::ptr::addr_of_mut!($old_field_place), $new_field_place) {
3532
$revision_place = $current_revision;
3633
}
3734
};

Diff for: components/salsa-macro-rules/src/setup_tracked_struct.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ macro_rules! setup_tracked_struct {
5050
// Absolute indices of any untracked fields.
5151
absolute_untracked_indices: [$($absolute_untracked_index:tt),*],
5252

53+
// Tracked field types.
54+
tracked_maybe_updates: [$($tracked_maybe_update:tt),*],
55+
56+
// Untracked field types.
57+
untracked_maybe_updates: [$($untracked_maybe_update:tt),*],
58+
5359
// A set of "field options" for each tracked field.
5460
//
5561
// Each field option is a tuple `(maybe_clone, maybe_backdate)` where:
@@ -152,7 +158,7 @@ macro_rules! setup_tracked_struct {
152158
$(
153159
$crate::maybe_backdate!(
154160
$tracked_option,
155-
$tracked_ty,
161+
$tracked_maybe_update,
156162
(*old_fields).$absolute_tracked_index,
157163
new_fields.$absolute_tracked_index,
158164
revisions[$relative_tracked_index],
@@ -164,7 +170,7 @@ macro_rules! setup_tracked_struct {
164170
// If any untracked field has changed, return `true`, indicating that the tracked struct
165171
// itself should be considered changed.
166172
$(
167-
$zalsa::UpdateDispatch::<$untracked_ty>::maybe_update(
173+
$untracked_maybe_update(
168174
&mut (*old_fields).$absolute_untracked_index,
169175
new_fields.$absolute_untracked_index,
170176
)

Diff for: components/salsa-macros/src/input.rs

+2
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ impl crate::options::AllowedOptions for InputStruct {
6464
impl SalsaStructAllowedOptions for InputStruct {
6565
const KIND: &'static str = "input";
6666

67+
const ALLOW_MAYBE_UPDATE: bool = false;
68+
6769
const ALLOW_TRACKED: bool = false;
6870

6971
const HAS_LIFETIME: bool = false;

Diff for: components/salsa-macros/src/interned.rs

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ impl crate::options::AllowedOptions for InternedStruct {
6565
impl SalsaStructAllowedOptions for InternedStruct {
6666
const KIND: &'static str = "interned";
6767

68+
const ALLOW_MAYBE_UPDATE: bool = false;
69+
6870
const ALLOW_TRACKED: bool = false;
6971

7072
const HAS_LIFETIME: bool = true;

Diff for: components/salsa-macros/src/lib.rs

+5
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,8 @@ pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error
9393
tokens.extend(TokenStream::from(error.into_compile_error()));
9494
tokens
9595
}
96+
97+
mod kw {
98+
syn::custom_keyword!(with);
99+
syn::custom_keyword!(maybe_update);
100+
}

Diff for: components/salsa-macros/src/salsa_struct.rs

+70-11
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use crate::{
3030
options::{AllowedOptions, Options},
3131
};
3232
use proc_macro2::{Ident, Literal, Span, TokenStream};
33+
use syn::parse::ParseStream;
3334

3435
pub(crate) struct SalsaStruct<'s, A: SalsaStructAllowedOptions> {
3536
struct_item: &'s syn::ItemStruct,
@@ -41,6 +42,9 @@ pub(crate) trait SalsaStructAllowedOptions: AllowedOptions {
4142
/// The kind of struct (e.g., interned, input, tracked).
4243
const KIND: &'static str;
4344

45+
/// Are `#[maybe_update]` fields allowed?
46+
const ALLOW_MAYBE_UPDATE: bool;
47+
4448
/// Are `#[tracked]` fields allowed?
4549
const ALLOW_TRACKED: bool;
4650

@@ -55,29 +59,54 @@ pub(crate) trait SalsaStructAllowedOptions: AllowedOptions {
5559
}
5660

5761
pub(crate) struct SalsaField<'s> {
58-
field: &'s syn::Field,
62+
pub(crate) field: &'s syn::Field,
5963

6064
pub(crate) has_tracked_attr: bool,
6165
pub(crate) has_default_attr: bool,
6266
pub(crate) has_ref_attr: bool,
6367
pub(crate) has_no_eq_attr: bool,
68+
pub(crate) maybe_update_attr: Option<(syn::Path, syn::Expr)>,
6469
get_name: syn::Ident,
6570
set_name: syn::Ident,
6671
}
6772

6873
const BANNED_FIELD_NAMES: &[&str] = &["from", "new"];
6974

7075
#[allow(clippy::type_complexity)]
71-
pub(crate) const FIELD_OPTION_ATTRIBUTES: &[(&str, fn(&syn::Attribute, &mut SalsaField))] = &[
72-
("tracked", |_, ef| ef.has_tracked_attr = true),
73-
("default", |_, ef| ef.has_default_attr = true),
74-
("return_ref", |_, ef| ef.has_ref_attr = true),
75-
("no_eq", |_, ef| ef.has_no_eq_attr = true),
76+
pub(crate) const FIELD_OPTION_ATTRIBUTES: &[(
77+
&str,
78+
fn(&syn::Attribute, &mut SalsaField) -> syn::Result<()>,
79+
)] = &[
80+
("tracked", |_, ef| {
81+
ef.has_tracked_attr = true;
82+
Ok(())
83+
}),
84+
("default", |_, ef| {
85+
ef.has_default_attr = true;
86+
Ok(())
87+
}),
88+
("return_ref", |_, ef| {
89+
ef.has_ref_attr = true;
90+
Ok(())
91+
}),
92+
("no_eq", |_, ef| {
93+
ef.has_no_eq_attr = true;
94+
Ok(())
95+
}),
7696
("get", |attr, ef| {
77-
ef.get_name = attr.parse_args().unwrap();
97+
ef.get_name = attr.parse_args()?;
98+
Ok(())
7899
}),
79100
("set", |attr, ef| {
80-
ef.set_name = attr.parse_args().unwrap();
101+
ef.set_name = attr.parse_args()?;
102+
Ok(())
103+
}),
104+
("maybe_update", |attr, ef| {
105+
ef.maybe_update_attr = Some(attr.parse_args_with(|parser: ParseStream| {
106+
let expr = parser.parse::<syn::Expr>()?;
107+
Ok((attr.path().clone(), expr))
108+
})?);
109+
Ok(())
81110
}),
82111
];
83112

@@ -105,6 +134,7 @@ where
105134
fields,
106135
};
107136

137+
this.maybe_disallow_maybe_update_fields()?;
108138
this.maybe_disallow_tracked_fields()?;
109139
this.maybe_disallow_default_fields()?;
110140

@@ -129,6 +159,34 @@ where
129159
}
130160
}
131161

162+
/// Disallow `#[tracked]` attributes on the fields of this struct.
163+
///
164+
/// If an `#[tracked]` field is found, return an error.
165+
///
166+
/// # Parameters
167+
///
168+
/// * `kind`, the attribute name (e.g., `input` or `interned`)
169+
fn maybe_disallow_maybe_update_fields(&self) -> syn::Result<()> {
170+
if A::ALLOW_MAYBE_UPDATE {
171+
return Ok(());
172+
}
173+
174+
// Check if any field has the `#[maybe_update]` attribute.
175+
for ef in &self.fields {
176+
if ef.maybe_update_attr.is_some() {
177+
return Err(syn::Error::new_spanned(
178+
ef.field,
179+
format!(
180+
"`#[maybe_update]` cannot be used with `#[salsa::{}]`",
181+
A::KIND
182+
),
183+
));
184+
}
185+
}
186+
187+
Ok(())
188+
}
189+
132190
/// Disallow `#[tracked]` attributes on the fields of this struct.
133191
///
134192
/// If an `#[tracked]` field is found, return an error.
@@ -337,14 +395,14 @@ where
337395
self.args.no_lifetime.is_none()
338396
}
339397

340-
fn tracked_fields_iter(&self) -> impl Iterator<Item = (usize, &SalsaField<'s>)> {
398+
pub fn tracked_fields_iter(&self) -> impl Iterator<Item = (usize, &SalsaField<'s>)> {
341399
self.fields
342400
.iter()
343401
.enumerate()
344402
.filter(|(_, f)| f.has_tracked_attr)
345403
}
346404

347-
fn untracked_fields_iter(&self) -> impl Iterator<Item = (usize, &SalsaField<'s>)> {
405+
pub fn untracked_fields_iter(&self) -> impl Iterator<Item = (usize, &SalsaField<'s>)> {
348406
self.fields
349407
.iter()
350408
.enumerate()
@@ -374,6 +432,7 @@ impl<'s> SalsaField<'s> {
374432
has_ref_attr: false,
375433
has_default_attr: false,
376434
has_no_eq_attr: false,
435+
maybe_update_attr: None,
377436
get_name,
378437
set_name,
379438
};
@@ -382,7 +441,7 @@ impl<'s> SalsaField<'s> {
382441
for attr in &field.attrs {
383442
for (fa, func) in FIELD_OPTION_ATTRIBUTES {
384443
if attr.path().is_ident(fa) {
385-
func(attr, &mut result);
444+
func(attr, &mut result)?;
386445
}
387446
}
388447
}

Diff for: components/salsa-macros/src/tracked_struct.rs

+24-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::{
55
salsa_struct::{SalsaStruct, SalsaStructAllowedOptions},
66
};
77
use proc_macro2::TokenStream;
8+
use syn::spanned::Spanned;
89

910
/// For an entity struct `Foo` with fields `f1: T1, ..., fN: TN`, we generate...
1011
///
@@ -59,6 +60,8 @@ impl crate::options::AllowedOptions for TrackedStruct {
5960
impl SalsaStructAllowedOptions for TrackedStruct {
6061
const KIND: &'static str = "tracked";
6162

63+
const ALLOW_MAYBE_UPDATE: bool = true;
64+
6265
const ALLOW_TRACKED: bool = true;
6366

6467
const HAS_LIFETIME: bool = true;
@@ -78,6 +81,7 @@ impl Macro {
7881
#[allow(non_snake_case)]
7982
fn try_macro(&self) -> syn::Result<TokenStream> {
8083
let salsa_struct = SalsaStruct::new(&self.struct_item, &self.args)?;
84+
let zalsa = self.hygiene.ident("zalsa");
8185

8286
let attrs = &self.struct_item.attrs;
8387
let vis = &self.struct_item.vis;
@@ -108,10 +112,26 @@ impl Macro {
108112
let tracked_tys = salsa_struct.tracked_tys();
109113
let untracked_tys = salsa_struct.untracked_tys();
110114

115+
let tracked_maybe_update = salsa_struct.tracked_fields_iter().map(|(_, field)| {
116+
let field_ty = &field.field.ty;
117+
if let Some((with_token, maybe_update)) = &field.maybe_update_attr {
118+
quote_spanned! { with_token.span() => ({ let maybe_update: unsafe fn(*mut #field_ty, #field_ty) -> bool = #maybe_update; maybe_update }) }
119+
} else {
120+
quote! {(#zalsa::UpdateDispatch::<#field_ty>::maybe_update)}
121+
}
122+
});
123+
let untracked_maybe_update = salsa_struct.untracked_fields_iter().map(|(_, field)| {
124+
let field_ty = &field.field.ty;
125+
if let Some((with_token, maybe_update)) = &field.maybe_update_attr {
126+
quote_spanned! { with_token.span() => ({ let maybe_update: unsafe fn(*mut #field_ty, #field_ty) -> bool = #maybe_update; maybe_update }) }
127+
} else {
128+
quote! {(#zalsa::UpdateDispatch::<#field_ty>::maybe_update)}
129+
}
130+
});
131+
111132
let num_tracked_fields = salsa_struct.num_tracked_fields();
112133
let generate_debug_impl = salsa_struct.generate_debug_impl();
113134

114-
let zalsa = self.hygiene.ident("zalsa");
115135
let zalsa_struct = self.hygiene.ident("zalsa_struct");
116136
let Configuration = self.hygiene.ident("Configuration");
117137
let CACHE = self.hygiene.ident("CACHE");
@@ -146,6 +166,9 @@ impl Macro {
146166

147167
absolute_untracked_indices: [#(#absolute_untracked_indices),*],
148168

169+
tracked_maybe_updates: [#(#tracked_maybe_update),*],
170+
untracked_maybe_updates: [#(#untracked_maybe_update),*],
171+
149172
tracked_options: [#(#tracked_options),*],
150173
untracked_options: [#(#untracked_options),*],
151174

Diff for: components/salsa-macros/src/update.rs

+1-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use proc_macro2::{Literal, Span, TokenStream};
22
use syn::{parenthesized, parse::ParseStream, spanned::Spanned, Token};
33
use synstructure::BindStyle;
44

5-
use crate::hygiene::Hygiene;
5+
use crate::{hygiene::Hygiene, kw};
66

77
pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream> {
88
let hygiene = Hygiene::from2(&input);
@@ -85,9 +85,6 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
8585

8686
let (maybe_update, unsafe_token) = match attr {
8787
Some(attr) => {
88-
mod kw {
89-
syn::custom_keyword!(with);
90-
}
9188
attr.parse_args_with(|parser: ParseStream| {
9289
let mut content;
9390

@@ -98,7 +95,6 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
9895
let expr = content.parse::<syn::Expr>()?;
9996
Ok((
10097
quote_spanned! { with_token.span() => ({ let maybe_update: unsafe fn(*mut #field_ty, #field_ty) -> bool = #expr; maybe_update }) },
101-
// quote_spanned! { with_token.span() => ((#expr) as unsafe fn(*mut #field_ty, #field_ty) -> bool) },
10298
unsafe_token,
10399
))
104100
})?

0 commit comments

Comments
 (0)