11#![ allow( unused_imports) ]
22
3+ use std:: iter;
4+
35use heck:: AsSnakeCase ;
46use proc_macro:: TokenStream ;
57use proc_macro2:: { Span , TokenStream as TokenStream2 } ;
@@ -8,24 +10,17 @@ use syn::parse::{Parse, ParseStream, Parser};
810use syn:: punctuated:: Punctuated ;
911use syn:: {
1012 parse_macro_input, parse_quote, DeriveInput , Error , Expr , ExprLit , Ident , ImplItem ,
11- ImplItemType , Item , ItemImpl , Lit , Meta , Path , PathSegment , Token , Type , TypePath ,
13+ ImplItemType , Item , ItemImpl , Lit , LitStr , Meta , Path , PathSegment , Token , Type , TypePath ,
1214 TypeReference ,
1315} ;
1416
1517use crate :: match_variant;
1618use crate :: types:: { make_type_list, ImplType , RetType , TypeClass } ;
1719
18- /// Create an identifier from another identifier, changing the name to snake case
19- macro_rules! format_ident_str {
20- ( $formatter: tt, $ident: ident) => {
21- Ident :: new( format!( $formatter, $ident) . as_str( ) , Span :: call_site( ) )
22- } ;
23- }
24-
2520/// Verify that an `ItemImpl` matches the end of any given path
2621///
2722/// implements `BasicUdf` (in any of its pathing options)
28- fn impls_path ( itemimpl : & ItemImpl , expected : ImplType ) -> bool {
23+ fn impl_type ( itemimpl : & ItemImpl ) -> Option < ImplType > {
2924 let implemented = & itemimpl. trait_ . as_ref ( ) . unwrap ( ) . 1 . segments ;
3025
3126 let basic_paths: [ Punctuated < PathSegment , Token ! [ :: ] > ; 3 ] = [
@@ -39,9 +34,12 @@ fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
3934 parse_quote ! { AggregateUdf } ,
4035 ] ;
4136
42- match expected {
43- ImplType :: Basic => basic_paths. contains ( implemented) ,
44- ImplType :: Aggregate => arg_paths. contains ( implemented) ,
37+ if basic_paths. contains ( implemented) {
38+ Some ( ImplType :: Basic )
39+ } else if arg_paths. contains ( implemented) {
40+ Some ( ImplType :: Aggregate )
41+ } else {
42+ None
4543 }
4644}
4745
@@ -57,14 +55,11 @@ fn impls_path(itemimpl: &ItemImpl, expected: ImplType) -> bool {
5755pub fn register ( args : & TokenStream , input : TokenStream ) -> TokenStream {
5856 let parsed = parse_macro_input ! ( input as ItemImpl ) ;
5957
60- let impls_basic = impls_path ( & parsed, ImplType :: Basic ) ;
61- let impls_agg = impls_path ( & parsed, ImplType :: Aggregate ) ;
62-
63- if !( impls_basic || impls_agg) {
58+ let Some ( impl_ty) = impl_type ( & parsed) else {
6459 return Error :: new_spanned ( & parsed, "Expected trait `BasicUdf` or `AggregateUdf`" )
6560 . into_compile_error ( )
6661 . into ( ) ;
67- }
62+ } ;
6863
6964 // Full type path of our data struct
7065 let Type :: Path ( dstruct_path) = parsed. self_ty . as_ref ( ) else {
@@ -73,7 +68,7 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
7368 . into ( ) ;
7469 } ;
7570
76- let base_fn_names = match parse_args ( args, dstruct_path) {
71+ let parsed_meta = match ParsedMeta :: parse ( args, dstruct_path) {
7772 Ok ( v) => v,
7873 Err ( e) => return e. into_compile_error ( ) . into ( ) ,
7974 } ;
@@ -89,91 +84,110 @@ pub fn register(args: &TokenStream, input: TokenStream) -> TokenStream {
8984 Span :: call_site ( ) ,
9085 ) ;
9186
92- let ( ret_ty, wrapper_def) = if impls_basic {
93- match get_rt_and_wrapper ( & parsed, dstruct_path, & wrapper_ident) {
87+ let ( ret_ty, wrapper_def) = match impl_ty {
88+ ImplType :: Basic => match get_ret_ty_and_wrapper ( & parsed, dstruct_path, & wrapper_ident) {
9489 Ok ( ( r, w) ) => ( Some ( r) , w) ,
9590 Err ( e) => return e. into_compile_error ( ) . into ( ) ,
96- }
97- } else {
98- ( None , TokenStream2 :: new ( ) )
91+ } ,
92+ ImplType :: Aggregate => ( None , TokenStream2 :: new ( ) ) ,
9993 } ;
10094
101- let content_iter = base_fn_names. iter ( ) . map ( |base_fn_name| {
102- if impls_basic {
103- make_basic_fns (
104- ret_ty. as_ref ( ) . unwrap ( ) ,
105- base_fn_name,
106- dstruct_path,
107- & wrapper_ident,
108- )
109- } else {
110- make_agg_fns ( & parsed, base_fn_name, dstruct_path, & wrapper_ident)
111- }
95+ let helper_traits = make_helper_trait_impls ( dstruct_path, & parsed_meta, impl_ty) ;
96+
97+ let fn_items_iter = parsed_meta. all_names ( ) . map ( |base_fn_name| match impl_ty {
98+ ImplType :: Basic => make_basic_fns (
99+ ret_ty. as_ref ( ) . unwrap ( ) ,
100+ base_fn_name,
101+ dstruct_path,
102+ & wrapper_ident,
103+ ) ,
104+ ImplType :: Aggregate => make_agg_fns ( & parsed, base_fn_name, dstruct_path, & wrapper_ident) ,
112105 } ) ;
113106
114107 quote ! {
115108 #parsed
116109
117110 #wrapper_def
118111
119- #( #content_iter ) *
112+ #helper_traits
113+
114+ #( #fn_items_iter ) *
120115 }
121116 . into ( )
122117}
123118
124- /// Parse attribute arguments. Returns an iterator of names
125- fn parse_args ( args : & TokenStream , dstruct_path : & TypePath ) -> syn:: Result < Vec < String > > {
126- let meta = Punctuated :: < Meta , Token ! [ , ] > :: parse_terminated. parse ( args. clone ( ) ) ?;
127- let mut base_fn_names: Vec < String > = vec ! [ ] ;
128- let mut primary_name_specified = false ;
129-
130- for m in meta {
131- let Meta :: NameValue ( mval) = m else {
132- return Err ( Error :: new_spanned ( m, "expected `a = b atributes`" ) ) ;
133- } ;
119+ /// Arguments we parse from metadata or default to
120+ struct ParsedMeta {
121+ name : String ,
122+ aliases : Vec < String > ,
123+ default_name_used : bool ,
124+ }
134125
135- if !mval. path . segments . iter ( ) . count ( ) == 1 {
136- return Err ( Error :: new_spanned ( mval. path , "unexpected path" ) ) ;
137- }
126+ impl ParsedMeta {
127+ /// Parse attribute arguments. Returns an iterator of names
128+ fn parse ( args : & TokenStream , dstruct_path : & TypePath ) -> syn:: Result < Self > {
129+ let meta = Punctuated :: < Meta , Token ! [ , ] > :: parse_terminated. parse ( args. clone ( ) ) ?;
130+ let mut name_from_attributes = None ;
131+ let mut aliases = Vec :: new ( ) ;
138132
139- let key = mval. path . segments . first ( ) . unwrap ( ) ;
133+ for m in meta {
134+ let Meta :: NameValue ( mval) = m else {
135+ return Err ( Error :: new_spanned ( m, "expected `a = b atributes`" ) ) ;
136+ } ;
140137
141- let Expr :: Lit ( ExprLit {
142- lit : Lit :: Str ( value) ,
143- ..
144- } ) = mval. value
145- else {
146- return Err ( Error :: new_spanned ( mval. value , "expected a literal string" ) ) ;
147- } ;
138+ if !mval. path . segments . iter ( ) . count ( ) == 1 {
139+ return Err ( Error :: new_spanned ( mval. path , "unexpected path" ) ) ;
140+ }
148141
149- if key. ident == "name" {
150- if primary_name_specified {
151- return Err ( Error :: new_spanned ( key, "`name` can only be specified once" ) ) ;
142+ let key = mval. path . segments . first ( ) . unwrap ( ) ;
143+
144+ let Expr :: Lit ( ExprLit {
145+ lit : Lit :: Str ( value) ,
146+ ..
147+ } ) = mval. value
148+ else {
149+ return Err ( Error :: new_spanned ( mval. value , "expected a literal string" ) ) ;
150+ } ;
151+
152+ if key. ident == "name" {
153+ if name_from_attributes. is_some ( ) {
154+ return Err ( Error :: new_spanned ( key, "`name` can only be specified once" ) ) ;
155+ }
156+ name_from_attributes = Some ( value. value ( ) ) ;
157+ } else if key. ident == "alias" {
158+ aliases. push ( value. value ( ) ) ;
159+ } else {
160+ return Err ( Error :: new_spanned (
161+ key,
162+ "unexpected key (only `name` and `alias` are accepted)" ,
163+ ) ) ;
152164 }
153- base_fn_names. push ( value. value ( ) ) ;
154- primary_name_specified = true ;
155- } else if key. ident == "alias" {
156- base_fn_names. push ( value. value ( ) ) ;
157- } else {
158- return Err ( Error :: new_spanned (
159- key,
160- "unexpected key (only `name` and `alias` are accepted)" ,
161- ) ) ;
162165 }
163- }
164166
165- if !primary_name_specified {
166- // If we don't have a name specified, use the type name as snake case
167- let ty_ident = & dstruct_path. path . segments . last ( ) . unwrap ( ) . ident ;
168- let fn_name = AsSnakeCase ( & ty_ident. to_string ( ) ) . to_string ( ) ;
169- base_fn_names. push ( fn_name) ;
167+ let mut default_name_used = false ;
168+ let name = name_from_attributes. unwrap_or_else ( || {
169+ // If we don't have a name specified, use the type name as snake case
170+ let ty_ident = & dstruct_path. path . segments . last ( ) . unwrap ( ) . ident ;
171+ let fn_name = AsSnakeCase ( & ty_ident. to_string ( ) ) . to_string ( ) ;
172+ default_name_used = true ;
173+ fn_name
174+ } ) ;
175+
176+ Ok ( Self {
177+ name,
178+ aliases,
179+ default_name_used,
180+ } )
170181 }
171182
172- Ok ( base_fn_names)
183+ /// Iterate the basic name and all aliases
184+ fn all_names ( & self ) -> impl Iterator < Item = & String > {
185+ iter:: once ( & self . name ) . chain ( self . aliases . iter ( ) )
186+ }
173187}
174188
175189/// Get the return type to use and a wrapper. Once per impl setup.
176- fn get_rt_and_wrapper (
190+ fn get_ret_ty_and_wrapper (
177191 parsed : & ItemImpl ,
178192 dstruct_path : & TypePath ,
179193 wrapper_ident : & Ident ,
@@ -209,16 +223,50 @@ fn get_rt_and_wrapper(
209223 Ok ( ( ret_ty, wrapper_struct) )
210224}
211225
226+ /// Make implementations for our helper/metadata traits
227+ fn make_helper_trait_impls (
228+ dstruct_path : & TypePath ,
229+ meta : & ParsedMeta ,
230+ impl_ty : ImplType ,
231+ ) -> TokenStream2 {
232+ let name = LitStr :: new ( & meta. name , Span :: call_site ( ) ) ;
233+ let aliases = meta
234+ . aliases
235+ . iter ( )
236+ . map ( |alias| LitStr :: new ( alias. as_ref ( ) , Span :: call_site ( ) ) ) ;
237+ let ( trait_name, check_expr) = match impl_ty {
238+ ImplType :: Basic => (
239+ quote ! { :: udf:: wrapper:: RegisteredBasicUdf } ,
240+ TokenStream2 :: new ( ) ,
241+ ) ,
242+ ImplType :: Aggregate => (
243+ quote ! { :: udf:: wrapper:: RegisteredAggregateUdf } ,
244+ quote ! { const _: ( ) = :: udf:: wrapper:: verify_aggregate_attributes:: <#dstruct_path>( ) ; } ,
245+ ) ,
246+ } ;
247+ let default_name_used = meta. default_name_used ;
248+
249+ quote ! {
250+ impl #trait_name for #dstruct_path {
251+ const NAME : & ' static str = #name;
252+ const ALIASES : & ' static [ & ' static str ] = & [ #( #aliases ) , * ] ;
253+ const DEFAULT_NAME_USED : bool = #default_name_used;
254+ }
255+
256+ #check_expr
257+ }
258+ }
259+
212260/// Create the basic function signatures (`xxx_init`, `xxx_deinit`, `xxx`)
213261fn make_basic_fns (
214262 rt : & RetType ,
215263 base_fn_name : & str ,
216264 dstruct_path : & TypePath ,
217265 wrapper_ident : & Ident ,
218266) -> TokenStream2 {
219- let init_fn_name = format_ident_str ! ( "{}_init" , base_fn_name) ;
220- let deinit_fn_name = format_ident_str ! ( "{}_deinit" , base_fn_name) ;
221- let process_fn_name = format_ident_str ! ( "{}" , base_fn_name) ;
267+ let init_fn_name = format_ident ! ( "{}_init" , base_fn_name) ;
268+ let deinit_fn_name = format_ident ! ( "{}_deinit" , base_fn_name) ;
269+ let process_fn_name = format_ident ! ( "{}" , base_fn_name) ;
222270
223271 let init_fn = make_init_fn ( dstruct_path, wrapper_ident, & init_fn_name) ;
224272 let deinit_fn = make_deinit_fn ( dstruct_path, wrapper_ident, & deinit_fn_name) ;
@@ -269,9 +317,9 @@ fn make_agg_fns(
269317 dstruct_path : & TypePath , // Name of the data structure
270318 wrapper_ident : & Ident ,
271319) -> TokenStream2 {
272- let clear_fn_name = format_ident_str ! ( "{}_clear" , base_fn_name) ;
273- let add_fn_name = format_ident_str ! ( "{}_add" , base_fn_name) ;
274- let remove_fn_name = format_ident_str ! ( "{}_remove" , base_fn_name) ;
320+ let clear_fn_name = format_ident ! ( "{}_clear" , base_fn_name) ;
321+ let add_fn_name = format_ident ! ( "{}_add" , base_fn_name) ;
322+ let remove_fn_name = format_ident ! ( "{}_remove" , base_fn_name) ;
275323
276324 // Determine whether this re-implements `remove`
277325 let impls_remove = & parsed
@@ -280,7 +328,6 @@ fn make_agg_fns(
280328 . filter_map ( match_variant ! ( ImplItem :: Fn ) )
281329 . map ( |m| & m. sig . ident )
282330 . any ( |id| * id == "remove" ) ;
283- let base_fn_ident = Ident :: new ( base_fn_name, Span :: call_site ( ) ) ;
284331
285332 let clear_fn = make_clear_fn ( dstruct_path, wrapper_ident, & clear_fn_name) ;
286333 let add_fn = make_add_fn ( dstruct_path, wrapper_ident, & add_fn_name) ;
@@ -295,10 +342,6 @@ fn make_agg_fns(
295342 } ;
296343
297344 quote ! {
298- // Sanity check that we implemented
299- #[ allow( dead_code, non_upper_case_globals) ]
300- const did_you_apply_the_same_aliases_to_the_BasicUdf_impl: * const ( ) = #base_fn_ident as _;
301-
302345 #clear_fn
303346
304347 #add_fn
0 commit comments