@@ -8,7 +8,9 @@ extern crate alloc;
88use alloc:: borrow:: Cow ;
99use alloc:: collections:: BTreeSet ;
1010use alloc:: format;
11+ use alloc:: string:: ToString ;
1112use alloc:: vec;
13+ use alloc:: vec:: Vec ;
1214use proc_macro:: TokenStream ;
1315use proc_macro2:: { Ident , Span } ;
1416use quote:: { quote, quote_spanned, ToTokens as _} ;
@@ -192,55 +194,54 @@ fn derive_move_and_assign_via_copy_impl(
192194 } )
193195}
194196
195- /// `project_pin_type!(foo::T)` is the name of the type returned by
196- /// `foo::T::project_pin()`.
197- ///
198- /// If `foo::T` is not `#[recursively_pinned]`, then this returns the name it
199- /// would have used, but is essentially useless.
200- #[ proc_macro]
201- pub fn project_pin_type ( name : TokenStream ) -> TokenStream {
202- project_type_impl ( name, project_pin_ident)
203- }
204-
205197fn project_pin_ident ( ident : & Ident ) -> Ident {
206198 Ident :: new ( & format ! ( "__CrubitProjectPin{}" , ident) , Span :: call_site ( ) )
207199}
208200
209- /// `project_ref_type!(foo::T)` is the name of the type returned by
210- /// `foo::T::project_ref()`.
211- ///
212- /// If `foo::T` is not `#[recursively_pinned]`, then this returns the name it
213- /// would have used, but is essentially useless.
214- #[ proc_macro]
215- pub fn project_ref_type ( name : TokenStream ) -> TokenStream {
216- project_type_impl ( name, project_ref_ident)
217- }
218-
219201fn project_ref_ident ( ident : & Ident ) -> Ident {
220202 Ident :: new ( & format ! ( "__CrubitProjectRef{}" , ident) , Span :: call_site ( ) )
221203}
222204
223- fn project_type_impl ( name : TokenStream , project_ident : fn ( & Ident ) -> Ident ) -> TokenStream {
224- let mut name = syn:: parse_macro_input!( name as syn:: Path ) ;
225- match name. segments . last_mut ( ) {
226- None => {
227- return syn:: Error :: new ( name. span ( ) , "Path must have at least one element" )
228- . into_compile_error ( )
229- . into ( ) ;
205+ fn most_public_vis ( a : & syn:: Visibility , b : & syn:: Visibility ) -> syn:: Visibility {
206+ use syn:: Visibility :: { Inherited , Public , Restricted } ;
207+ match ( a, b) {
208+ ( Public ( _) , _) | ( _, Public ( _) ) => Public ( <syn:: Token ![ pub ] >:: default ( ) ) ,
209+ ( Restricted ( r) , _) | ( _, Restricted ( r) ) => Restricted ( r. clone ( ) ) ,
210+ ( Inherited , Inherited ) => Inherited ,
211+ }
212+ }
213+
214+ fn most_private_vis ( a : & syn:: Visibility , b : & syn:: Visibility ) -> syn:: Visibility {
215+ use syn:: Visibility :: { Inherited , Public , Restricted } ;
216+ match ( a, b) {
217+ ( Inherited , _) | ( _, Inherited ) => Inherited ,
218+ ( Restricted ( r) , _) | ( _, Restricted ( r) ) => Restricted ( r. clone ( ) ) ,
219+ ( Public ( _) , Public ( _) ) => Public ( <syn:: Token ![ pub ] >:: default ( ) ) ,
220+ }
221+ }
222+
223+ fn get_max_visibility ( input : & syn:: DeriveInput ) -> syn:: Visibility {
224+ let mut max_vis = syn:: Visibility :: Inherited ;
225+ match & input. data {
226+ syn:: Data :: Struct ( data) => {
227+ for field in & data. fields {
228+ max_vis = most_public_vis ( & max_vis, & field. vis ) ;
229+ }
230230 }
231- Some ( last) => {
232- if let syn:: PathArguments :: Parenthesized ( p) = & last. arguments {
233- return syn:: Error :: new (
234- p. span ( ) ,
235- "Parenthesized paths (e.g. fn, Fn) do not have projected equivalents." ,
236- )
237- . into_compile_error ( )
238- . into ( ) ;
231+ syn:: Data :: Enum ( e) => {
232+ for variant in & e. variants {
233+ for field in & variant. fields {
234+ max_vis = most_public_vis ( & max_vis, & field. vis ) ;
235+ }
236+ }
237+ }
238+ syn:: Data :: Union ( u) => {
239+ for field in & u. fields . named {
240+ max_vis = most_public_vis ( & max_vis, & field. vis ) ;
239241 }
240- last. ident = project_ident ( & last. ident ) ;
241242 }
242243 }
243- TokenStream :: from ( quote ! { #name } )
244+ max_vis
244245}
245246
246247/// Defines the `project_pin`/`project_ref` function, and its return type.
@@ -261,6 +262,8 @@ fn project_method_impl(
261262 mut_ : proc_macro2:: TokenStream ,
262263 project_ident : fn ( & Ident ) -> Ident ,
263264) -> syn:: Result < proc_macro2:: TokenStream > {
265+ let max_field_vis = get_max_visibility ( input) ;
266+
264267 let is_fieldless = match & input. data {
265268 syn:: Data :: Struct ( data) => data. fields . is_empty ( ) ,
266269 syn:: Data :: Enum ( e) => e. variants . iter ( ) . all ( |variant| variant. fields . is_empty ( ) ) ,
@@ -269,19 +272,38 @@ fn project_method_impl(
269272 }
270273 } ;
271274
275+ // The projection type should be at most as public as the most public field
276+ // and at most as public as the input type.
277+ let vis = most_private_vis ( & input. vis , & max_field_vis) ;
278+
272279 let mut projected = input. clone ( ) ;
273280 // TODO(jeanpierreda): check attributes for repr(packed)
274281 projected. attrs . clear ( ) ;
275282 projected. ident = project_ident ( & projected. ident ) ;
276283
277- let lifetime = if is_fieldless {
278- quote ! { }
279- } else {
280- add_lifetime ( & mut projected. generics , "'proj" )
281- } ;
284+ // Projection types must be public to be used in the public RecursivelyPinned trait.
285+ // This is okay because the fields are still private, as is the projection method.
286+ // We also hide these unusable projection types from the docs.
287+ projected. vis = syn:: Visibility :: Public ( <syn:: Token ![ pub ] >:: default ( ) ) ;
288+ if matches ! ( vis, syn:: Visibility :: Inherited ) {
289+ projected. attrs . push ( syn:: parse_quote!( #[ doc( hidden) ] ) ) ;
290+ }
291+
292+ let doc_msg = format ! ( "Pinned references to the fields of [`{}`]." , input. ident) ;
293+ projected. attrs . push ( syn:: parse_quote!( #[ doc = #doc_msg] ) ) ;
294+
295+ let lifetime = add_lifetime ( & mut projected. generics , "'crubit_proj" ) ;
282296
283297 let project_field = |field : & mut syn:: Field | {
284298 field. attrs . clear ( ) ;
299+ let field_name = field
300+ . ident
301+ . as_ref ( )
302+ . map ( |id| id. to_string ( ) )
303+ . unwrap_or_else ( || "tuple field" . to_string ( ) ) ;
304+ let doc_msg =
305+ format ! ( "Pinned reference to the `{}` field of [`{}`]." , field_name, input. ident) ;
306+ field. attrs . push ( syn:: parse_quote!( #[ doc = #doc_msg] ) ) ;
285307 let field_ty = & field. ty ;
286308 let pin_ty = syn:: parse_quote!( :: core:: pin:: Pin <& #lifetime #mut_ #field_ty>) ;
287309 field. ty = syn:: Type :: Path ( pin_ty) ;
@@ -317,26 +339,47 @@ fn project_method_impl(
317339 let projected_ident = & projected. ident ;
318340 match & mut projected. data {
319341 syn:: Data :: Struct ( data) => {
320- for field in & mut data. fields {
321- project_field ( field) ;
342+ if is_fieldless {
343+ data. fields = syn:: Fields :: Named ( syn:: parse_quote!( {
344+ #[ doc( hidden) ]
345+ pub _phantom: :: core:: marker:: PhantomData <& #lifetime ( ) >,
346+ } ) ) ;
347+ project_body = quote ! {
348+ let _ = from;
349+ #projected_ident { _phantom: :: core:: marker:: PhantomData }
350+ } ;
351+ } else {
352+ for field in & mut data. fields {
353+ project_field ( field) ;
354+ }
355+ let ( pat, project) = pat_project ( & mut data. fields ) ;
356+ project_body = quote ! {
357+ let #input_ident #pat = from;
358+ #projected_ident #project
359+ } ;
322360 }
323- let ( pat, project) = pat_project ( & mut data. fields ) ;
324- project_body = quote ! {
325- let #input_ident #pat = from;
326- #projected_ident #project
327- } ;
328361 }
329362 syn:: Data :: Enum ( e) => {
330363 let mut match_body = quote ! { } ;
331364 for variant in & mut e. variants {
332- for field in & mut variant. fields {
333- project_field ( field) ;
334- }
335- let ( pat, project) = pat_project ( & mut variant. fields ) ;
336365 let variant_ident = & variant. ident ;
337- match_body. extend ( quote ! {
338- #input_ident:: #variant_ident #pat => #projected_ident:: #variant_ident #project,
339- } ) ;
366+ if is_fieldless {
367+ variant. fields = syn:: Fields :: Named ( syn:: parse_quote!( {
368+ #[ doc( hidden) ]
369+ pub _phantom: :: core:: marker:: PhantomData <& #lifetime ( ) >,
370+ } ) ) ;
371+ match_body. extend ( quote ! {
372+ #input_ident:: #variant_ident { .. } => #projected_ident:: #variant_ident { _phantom: :: core:: marker:: PhantomData } ,
373+ } ) ;
374+ } else {
375+ for field in & mut variant. fields {
376+ project_field ( field) ;
377+ }
378+ let ( pat, project) = pat_project ( & mut variant. fields ) ;
379+ match_body. extend ( quote ! {
380+ #input_ident:: #variant_ident #pat => #projected_ident:: #variant_ident #project,
381+ } ) ;
382+ }
340383 }
341384 project_body = quote ! {
342385 match from {
@@ -353,12 +396,18 @@ fn project_method_impl(
353396 input. generics . split_for_impl ( ) ;
354397 let ( _, projected_generics, _) = projected. generics . split_for_impl ( ) ;
355398
399+ let method_doc = format ! (
400+ "Projects a pinned reference to [`{}`] into a struct of pinned references to its fields." ,
401+ input. ident
402+ ) ;
403+
356404 Ok ( quote ! {
357405 #projected
358406
359407 impl #input_impl_generics #input_ident #input_type_generics #input_where_clause {
408+ #[ doc = #method_doc]
360409 #[ must_use]
361- pub fn #method_name<#lifetime>( self : :: core:: pin:: Pin <& #lifetime #mut_ Self >) -> #projected_ident #projected_generics {
410+ #vis fn #method_name <#lifetime> ( self : :: core:: pin:: Pin <& #lifetime #mut_ Self >) -> #projected_ident #projected_generics {
362411 unsafe {
363412 let from = :: core:: pin:: Pin :: into_inner_unchecked( self ) ;
364413 #project_body
@@ -384,7 +433,12 @@ fn add_lifetime(generics: &mut syn::Generics, prefix: &str) -> proc_macro2::Toke
384433 name = Cow :: Owned ( format ! ( "{prefix}_{i}" ) ) ;
385434 } ;
386435 let quoted_lifetime = quote ! { #lifetime} ;
387- generics. params . push ( syn:: GenericParam :: Lifetime ( syn:: LifetimeParam :: new ( lifetime) ) ) ;
436+ let pos = generics
437+ . params
438+ . iter ( )
439+ . position ( |p| !matches ! ( p, syn:: GenericParam :: Lifetime ( _) ) )
440+ . unwrap_or ( generics. params . len ( ) ) ;
441+ generics. params . insert ( pos, syn:: GenericParam :: Lifetime ( syn:: LifetimeParam :: new ( lifetime) ) ) ;
388442 quoted_lifetime
389443}
390444
@@ -749,6 +803,41 @@ fn recursively_pinned_impl(
749803 }
750804 } ;
751805
806+ let mut input_lifetimes = Vec :: new ( ) ;
807+ let mut input_types_and_consts = Vec :: new ( ) ;
808+
809+ for param in input. generics . params . iter ( ) {
810+ match param {
811+ syn:: GenericParam :: Lifetime ( lt) => {
812+ let lt = & lt. lifetime ;
813+ input_lifetimes. push ( quote ! { #lt } ) ;
814+ }
815+ syn:: GenericParam :: Type ( t) => {
816+ let id = & t. ident ;
817+ input_types_and_consts. push ( quote ! { #id } ) ;
818+ }
819+ syn:: GenericParam :: Const ( c) => {
820+ let id = & c. ident ;
821+ input_types_and_consts. push ( quote ! { #id } ) ;
822+ }
823+ }
824+ }
825+
826+ let ( projected_pin_type, projected_ref_type) = if matches ! ( input. data, syn:: Data :: Union ( _) ) {
827+ ( quote ! { ( ) } , quote ! { ( ) } )
828+ } else {
829+ let projected_pin_ident = project_pin_ident ( & name) ;
830+ let projected_ref_ident = project_ref_ident ( & name) ;
831+
832+ let mut args = input_lifetimes. clone ( ) ;
833+ args. push ( quote ! { ' crubit_proj } ) ;
834+ args. extend ( input_types_and_consts. clone ( ) ) ;
835+
836+ let type_args = quote ! { < #( #args) , * > } ;
837+
838+ ( quote ! { #projected_pin_ident #type_args } , quote ! { #projected_ref_ident #type_args } )
839+ } ;
840+
752841 Ok ( quote ! {
753842 #input
754843 #project_pin_impl
@@ -757,16 +846,15 @@ fn recursively_pinned_impl(
757846 #drop_impl
758847 #unpin_impl
759848
760- // Introduce a new scope to limit the blast radius of the CtorInitializedFields type.
761- // This lets us use relatively readable names: while the impl is visible outside the scope,
762- // type is otherwise not visible.
763849 const _ : ( ) = {
764850 #ctor_initialized_input
765851
766852 unsafe impl #input_impl_generics :: #ctor:: RecursivelyPinned for
767853 #name #input_type_generics #input_where_clause
768854 {
769855 type CtorInitializedFields = #ctor_initialized_name #input_type_generics;
856+ type ProjectedPin <' crubit_proj> = #projected_pin_type where Self : ' crubit_proj;
857+ type ProjectedRef <' crubit_proj> = #projected_ref_type where Self : ' crubit_proj;
770858 }
771859 } ;
772860 } )
@@ -809,6 +897,7 @@ mod test {
809897 struct __CrubitCtorS { x: i32 }
810898 unsafe impl :: ctor:: RecursivelyPinned for S {
811899 type CtorInitializedFields = __CrubitCtorS;
900+ ...
812901 }
813902 } ;
814903 }
@@ -866,6 +955,7 @@ mod test {
866955 }
867956 unsafe impl :: ctor:: RecursivelyPinned for E {
868957 type CtorInitializedFields = __CrubitCtorE;
958+ ...
869959 }
870960 } ;
871961 }
0 commit comments