Skip to content

Commit f17c3eb

Browse files
committed
Rework validation of names and aliases for aggregate UDFs
Add better validation that catches a mismatch in `name = ...` and `alias = ...` between the `BasicUdf` and `AggregateUdf` implementations, and fix a problem that disallowed the use of aliases in aggregate UDFs. Error messages are significantly improved. Fixes <#59>
1 parent 8233525 commit f17c3eb

File tree

11 files changed

+527
-109
lines changed

11 files changed

+527
-109
lines changed

CHANGELOG.md

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,9 @@
44

55
## [Unreleased] - ReleaseDate
66

7-
### Added
8-
9-
### Changed
10-
11-
### Removed
12-
13-
14-
15-
## [0.5.4] - 2023-09-10
16-
17-
### Added
18-
19-
### Changed
20-
21-
### Removed
22-
7+
Rework the validation of names and aliases for aggregate UDFs. This fixes an
8+
issue where aliases could not be used for aggregate UDFs, and provides better
9+
error messages.
2310

2411

2512
## [0.5.4] - 2023-09-10

udf-macros/src/register.rs

Lines changed: 130 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#![allow(unused_imports)]
22

3+
use std::iter;
4+
35
use heck::AsSnakeCase;
46
use proc_macro::TokenStream;
57
use proc_macro2::{Span, TokenStream as TokenStream2};
@@ -8,24 +10,17 @@ use syn::parse::{Parse, ParseStream, Parser};
810
use syn::punctuated::Punctuated;
911
use syn::{
1012
parse_macro_input, parse_quote, DeriveInput, Error, Expr, ExprLit, Ident, ImplItem,
11-
ImplItemType, Item, ItemImpl, Lit, Meta, Path, PathSegment, Token, Type, TypePath,
13+
ImplItemType, Item, ItemImpl, Lit, LitStr, Meta, Path, PathSegment, Token, Type, TypePath,
1214
TypeReference,
1315
};
1416

1517
use crate::match_variant;
1618
use crate::types::{make_type_list, ImplType, RetType, TypeClass};
1719

18-
/// Create an identifier from another identifier, changing the name to snake case
19-
macro_rules! format_ident_str {
20-
($formatter:tt, $ident:ident) => {
21-
Ident::new(format!($formatter, $ident).as_str(), Span::call_site())
22-
};
23-
}
24-
2520
/// Verify that an `ItemImpl` matches the end of any given path
2621
///
2722
/// implements `BasicUdf` (in any of its pathing options)
28-
fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
23+
fn impl_type(itemimpl: &ItemImpl) -> Option<ImplType> {
2924
let implemented = &itemimpl.trait_.as_ref().unwrap().1.segments;
3025

3126
let basic_paths: [Punctuated<PathSegment, Token![::]>; 3] = [
@@ -39,9 +34,12 @@ fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
3934
parse_quote! {AggregateUdf},
4035
];
4136

42-
match expected {
43-
ImplType::Basic => basic_paths.contains(implemented),
44-
ImplType::Aggregate => arg_paths.contains(implemented),
37+
if basic_paths.contains(implemented) {
38+
Some(ImplType::Basic)
39+
} else if arg_paths.contains(implemented) {
40+
Some(ImplType::Aggregate)
41+
} else {
42+
None
4543
}
4644
}
4745

@@ -57,14 +55,11 @@ fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
5755
pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
5856
let parsed = parse_macro_input!(input as ItemImpl);
5957

60-
let impls_basic = impls_path(&parsed, ImplType::Basic);
61-
let impls_agg = impls_path(&parsed, ImplType::Aggregate);
62-
63-
if !(impls_basic || impls_agg) {
58+
let Some(impl_ty) = impl_type(&parsed) else {
6459
return Error::new_spanned(&parsed, "Expected trait `BasicUdf` or `AggregateUdf`")
6560
.into_compile_error()
6661
.into();
67-
}
62+
};
6863

6964
// Full type path of our data struct
7065
let Type::Path(dstruct_path) = parsed.self_ty.as_ref() else {
@@ -73,7 +68,7 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
7368
.into();
7469
};
7570

76-
let base_fn_names = match parse_args(args, dstruct_path) {
71+
let parsed_meta = match ParsedMeta::parse(args, dstruct_path) {
7772
Ok(v) => v,
7873
Err(e) => return e.into_compile_error().into(),
7974
};
@@ -89,91 +84,110 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
8984
Span::call_site(),
9085
);
9186

92-
let (ret_ty, wrapper_def) = if impls_basic {
93-
match get_rt_and_wrapper(&parsed, dstruct_path, &wrapper_ident) {
87+
let (ret_ty, wrapper_def) = match impl_ty {
88+
ImplType::Basic => match get_ret_ty_and_wrapper(&parsed, dstruct_path, &wrapper_ident) {
9489
Ok((r, w)) => (Some(r), w),
9590
Err(e) => return e.into_compile_error().into(),
96-
}
97-
} else {
98-
(None, TokenStream2::new())
91+
},
92+
ImplType::Aggregate => (None, TokenStream2::new()),
9993
};
10094

101-
let content_iter = base_fn_names.iter().map(|base_fn_name| {
102-
if impls_basic {
103-
make_basic_fns(
104-
ret_ty.as_ref().unwrap(),
105-
base_fn_name,
106-
dstruct_path,
107-
&wrapper_ident,
108-
)
109-
} else {
110-
make_agg_fns(&parsed, base_fn_name, dstruct_path, &wrapper_ident)
111-
}
95+
let helper_traits = make_helper_trait_impls(dstruct_path, &parsed_meta, impl_ty);
96+
97+
let fn_items_iter = parsed_meta.all_names().map(|base_fn_name| match impl_ty {
98+
ImplType::Basic => make_basic_fns(
99+
ret_ty.as_ref().unwrap(),
100+
base_fn_name,
101+
dstruct_path,
102+
&wrapper_ident,
103+
),
104+
ImplType::Aggregate => make_agg_fns(&parsed, base_fn_name, dstruct_path, &wrapper_ident),
112105
});
113106

114107
quote! {
115108
#parsed
116109

117110
#wrapper_def
118111

119-
#( #content_iter )*
112+
#helper_traits
113+
114+
#( #fn_items_iter )*
120115
}
121116
.into()
122117
}
123118

124-
/// Parse attribute arguments. Returns an iterator of names
125-
fn parse_args(args: &TokenStream, dstruct_path: &TypePath) -> syn::Result<Vec<String>> {
126-
let meta = Punctuated::<Meta, Token![,]>::parse_terminated.parse(args.clone())?;
127-
let mut base_fn_names: Vec<String> = vec![];
128-
let mut primary_name_specified = false;
129-
130-
for m in meta {
131-
let Meta::NameValue(mval) = m else {
132-
return Err(Error::new_spanned(m, "expected `a = b atributes`"));
133-
};
119+
/// Arguments parsed from the attribute metadata, with a default name used when none is given
120+
struct ParsedMeta {
121+
name: String,
122+
aliases: Vec<String>,
123+
default_name_used: bool,
124+
}
134125

135-
if !mval.path.segments.iter().count() == 1 {
136-
return Err(Error::new_spanned(mval.path, "unexpected path"));
137-
}
126+
impl ParsedMeta {
127+
/// Parse attribute arguments into the primary name and any aliases
128+
fn parse(args: &TokenStream, dstruct_path: &TypePath) -> syn::Result<Self> {
129+
let meta = Punctuated::<Meta, Token![,]>::parse_terminated.parse(args.clone())?;
130+
let mut name_from_attributes = None;
131+
let mut aliases = Vec::new();
138132

139-
let key = mval.path.segments.first().unwrap();
133+
for m in meta {
134+
let Meta::NameValue(mval) = m else {
135+
return Err(Error::new_spanned(m, "expected `a = b atributes`"));
136+
};
140137

141-
let Expr::Lit(ExprLit {
142-
lit: Lit::Str(value),
143-
..
144-
}) = mval.value
145-
else {
146-
return Err(Error::new_spanned(mval.value, "expected a literal string"));
147-
};
138+
if !mval.path.segments.iter().count() == 1 {
139+
return Err(Error::new_spanned(mval.path, "unexpected path"));
140+
}
148141

149-
if key.ident == "name" {
150-
if primary_name_specified {
151-
return Err(Error::new_spanned(key, "`name` can only be specified once"));
142+
let key = mval.path.segments.first().unwrap();
143+
144+
let Expr::Lit(ExprLit {
145+
lit: Lit::Str(value),
146+
..
147+
}) = mval.value
148+
else {
149+
return Err(Error::new_spanned(mval.value, "expected a literal string"));
150+
};
151+
152+
if key.ident == "name" {
153+
if name_from_attributes.is_some() {
154+
return Err(Error::new_spanned(key, "`name` can only be specified once"));
155+
}
156+
name_from_attributes = Some(value.value());
157+
} else if key.ident == "alias" {
158+
aliases.push(value.value());
159+
} else {
160+
return Err(Error::new_spanned(
161+
key,
162+
"unexpected key (only `name` and `alias` are accepted)",
163+
));
152164
}
153-
base_fn_names.push(value.value());
154-
primary_name_specified = true;
155-
} else if key.ident == "alias" {
156-
base_fn_names.push(value.value());
157-
} else {
158-
return Err(Error::new_spanned(
159-
key,
160-
"unexpected key (only `name` and `alias` are accepted)",
161-
));
162165
}
163-
}
164166

165-
if !primary_name_specified {
166-
// If we don't have a name specified, use the type name as snake case
167-
let ty_ident = &dstruct_path.path.segments.last().unwrap().ident;
168-
let fn_name = AsSnakeCase(&ty_ident.to_string()).to_string();
169-
base_fn_names.push(fn_name);
167+
let mut default_name_used = false;
168+
let name = name_from_attributes.unwrap_or_else(|| {
169+
// If we don't have a name specified, use the type name as snake case
170+
let ty_ident = &dstruct_path.path.segments.last().unwrap().ident;
171+
let fn_name = AsSnakeCase(&ty_ident.to_string()).to_string();
172+
default_name_used = true;
173+
fn_name
174+
});
175+
176+
Ok(Self {
177+
name,
178+
aliases,
179+
default_name_used,
180+
})
170181
}
171182

172-
Ok(base_fn_names)
183+
/// Iterate over the primary name followed by all aliases
184+
fn all_names(&self) -> impl Iterator<Item = &String> {
185+
iter::once(&self.name).chain(self.aliases.iter())
186+
}
173187
}
174188

175189
/// Get the return type to use and a wrapper. Once per impl setup.
176-
fn get_rt_and_wrapper(
190+
fn get_ret_ty_and_wrapper(
177191
parsed: &ItemImpl,
178192
dstruct_path: &TypePath,
179193
wrapper_ident: &Ident,
@@ -209,16 +223,50 @@ fn get_rt_and_wrapper(
209223
Ok((ret_ty, wrapper_struct))
210224
}
211225

226+
/// Make implementations for our helper/metadata traits
227+
fn make_helper_trait_impls(
228+
dstruct_path: &TypePath,
229+
meta: &ParsedMeta,
230+
impl_ty: ImplType,
231+
) -> TokenStream2 {
232+
let name = LitStr::new(&meta.name, Span::call_site());
233+
let aliases = meta
234+
.aliases
235+
.iter()
236+
.map(|alias| LitStr::new(alias.as_ref(), Span::call_site()));
237+
let (trait_name, check_expr) = match impl_ty {
238+
ImplType::Basic => (
239+
quote! { ::udf::wrapper::RegisteredBasicUdf },
240+
TokenStream2::new(),
241+
),
242+
ImplType::Aggregate => (
243+
quote! { ::udf::wrapper::RegisteredAggregateUdf },
244+
quote! { const _: () = ::udf::wrapper::verify_aggregate_attributes::<#dstruct_path>(); },
245+
),
246+
};
247+
let default_name_used = meta.default_name_used;
248+
249+
quote! {
250+
impl #trait_name for #dstruct_path {
251+
const NAME: &'static str = #name;
252+
const ALIASES: &'static [&'static str] = &[#( #aliases ),*];
253+
const DEFAULT_NAME_USED: bool = #default_name_used;
254+
}
255+
256+
#check_expr
257+
}
258+
}
259+
212260
/// Create the basic function signatures (`xxx_init`, `xxx_deinit`, `xxx`)
213261
fn make_basic_fns(
214262
rt: &RetType,
215263
base_fn_name: &str,
216264
dstruct_path: &TypePath,
217265
wrapper_ident: &Ident,
218266
) -> TokenStream2 {
219-
let init_fn_name = format_ident_str!("{}_init", base_fn_name);
220-
let deinit_fn_name = format_ident_str!("{}_deinit", base_fn_name);
221-
let process_fn_name = format_ident_str!("{}", base_fn_name);
267+
let init_fn_name = format_ident!("{}_init", base_fn_name);
268+
let deinit_fn_name = format_ident!("{}_deinit", base_fn_name);
269+
let process_fn_name = format_ident!("{}", base_fn_name);
222270

223271
let init_fn = make_init_fn(dstruct_path, wrapper_ident, &init_fn_name);
224272
let deinit_fn = make_deinit_fn(dstruct_path, wrapper_ident, &deinit_fn_name);
@@ -269,9 +317,9 @@ fn make_agg_fns(
269317
dstruct_path: &TypePath, // Name of the data structure
270318
wrapper_ident: &Ident,
271319
) -> TokenStream2 {
272-
let clear_fn_name = format_ident_str!("{}_clear", base_fn_name);
273-
let add_fn_name = format_ident_str!("{}_add", base_fn_name);
274-
let remove_fn_name = format_ident_str!("{}_remove", base_fn_name);
320+
let clear_fn_name = format_ident!("{}_clear", base_fn_name);
321+
let add_fn_name = format_ident!("{}_add", base_fn_name);
322+
let remove_fn_name = format_ident!("{}_remove", base_fn_name);
275323

276324
// Determine whether this re-implements `remove`
277325
let impls_remove = &parsed
@@ -280,7 +328,6 @@ fn make_agg_fns(
280328
.filter_map(match_variant!(ImplItem::Fn))
281329
.map(|m| &m.sig.ident)
282330
.any(|id| *id == "remove");
283-
let base_fn_ident = Ident::new(base_fn_name, Span::call_site());
284331

285332
let clear_fn = make_clear_fn(dstruct_path, wrapper_ident, &clear_fn_name);
286333
let add_fn = make_add_fn(dstruct_path, wrapper_ident, &add_fn_name);
@@ -295,10 +342,6 @@ fn make_agg_fns(
295342
};
296343

297344
quote! {
298-
// Sanity check that we implemented
299-
#[allow(dead_code, non_upper_case_globals)]
300-
const did_you_apply_the_same_aliases_to_the_BasicUdf_impl: *const () = #base_fn_ident as _;
301-
302345
#clear_fn
303346

304347
#add_fn
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#![allow(unused)]
2+
3+
use udf::prelude::*;
4+
5+
struct MyUdf;
6+
7+
impl AggregateUdf for MyUdf {
8+
// Required methods
9+
fn clear(&mut self, cfg: &UdfCfg<Process>, error: Option<NonZeroU8>) -> Result<(), NonZeroU8> {
10+
todo!()
11+
}
12+
fn add(
13+
&mut self,
14+
cfg: &UdfCfg<Process>,
15+
args: &ArgList<'_, Process>,
16+
error: Option<NonZeroU8>,
17+
) -> Result<(), NonZeroU8> {
18+
todo!()
19+
}
20+
}
21+
22+
fn main() {}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
error[E0277]: the trait bound `MyUdf: BasicUdf` is not satisfied
2+
--> tests/fail/agg_missing_basic.rs:7:23
3+
|
4+
7 | impl AggregateUdf for MyUdf {
5+
| ^^^^^ the trait `BasicUdf` is not implemented for `MyUdf`
6+
|
7+
note: required by a bound in `udf::AggregateUdf`
8+
--> $WORKSPACE/udf/src/traits.rs
9+
|
10+
| pub trait AggregateUdf: BasicUdf {
11+
| ^^^^^^^^ required by this bound in `AggregateUdf`

0 commit comments

Comments
 (0)