Skip to content

Commit ed5ee4e

Browse files
cramertjcopybara-github
authored andcommitted
ctor::recursively_pinned: Document projection types and adjust visibility
After this CL, projection types and methods are only as public as the type or most-public field. This prevents the projection methods being used as an escape hatch to override the type's usual visibility. PiperOrigin-RevId: 910936727
1 parent ff391f8 commit ed5ee4e

6 files changed

Lines changed: 359 additions & 109 deletions

File tree

support/ctor.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,7 @@ use core::mem::{ManuallyDrop, MaybeUninit};
167167
use core::ops::{Deref, DerefMut};
168168
use core::pin::Pin;
169169

170-
pub use ctor_proc_macros::{
171-
project_pin_type, project_ref_type, recursively_pinned, CtorFrom_Default, MoveAndAssignViaCopy,
172-
};
170+
pub use ctor_proc_macros::{recursively_pinned, CtorFrom_Default, MoveAndAssignViaCopy};
173171

174172
/// The error type for an infallible `Ctor`.
175173
///
@@ -1396,6 +1394,18 @@ pub unsafe trait RecursivelyPinned {
13961394
/// will not be initialized, so they must permit uninitialized memory.
13971395
/// (For example, ZST or MaybeUninit.)
13981396
type CtorInitializedFields: ?Sized;
1397+
1398+
/// The type returned by `project_pin` containing pinned references to the
1399+
/// fields of `Self`.
1400+
type ProjectedPin<'a>
1401+
where
1402+
Self: 'a;
1403+
1404+
/// The type returned by `project_ref` containing pinned references to the
1405+
/// fields of `Self`.
1406+
type ProjectedRef<'a>
1407+
where
1408+
Self: 'a;
13991409
}
14001410

14011411
/// The drop trait for `#[recursively_pinned(PinnedDrop)]` types.

support/ctor_macro_test.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ fn test_ctor_struct_hygiene() {
3131
}
3232
unsafe impl ::ctor::RecursivelyPinned for Struct {
3333
type CtorInitializedFields = Self;
34+
type ProjectedPin<'a>
35+
= Self
36+
where
37+
Self: 'a;
38+
type ProjectedRef<'a>
39+
= Self
40+
where
41+
Self: 'a;
3442
}
3543
let _ = ::ctor::ctor! {Struct { x: 0 }};
3644
}
@@ -41,6 +49,14 @@ fn test_ctor_tuple_struct_hygiene() {
4149
struct TupleStruct(i32);
4250
unsafe impl ::ctor::RecursivelyPinned for TupleStruct {
4351
type CtorInitializedFields = Self;
52+
type ProjectedPin<'a>
53+
= Self
54+
where
55+
Self: 'a;
56+
type ProjectedRef<'a>
57+
= Self
58+
where
59+
Self: 'a;
4460
}
4561
let _ = ::ctor::ctor! {TupleStruct(0)};
4662
}

support/ctor_proc_macros.rs

Lines changed: 151 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ extern crate alloc;
88
use alloc::borrow::Cow;
99
use alloc::collections::BTreeSet;
1010
use alloc::format;
11+
use alloc::string::ToString;
1112
use alloc::vec;
13+
use alloc::vec::Vec;
1214
use proc_macro::TokenStream;
1315
use proc_macro2::{Ident, Span};
1416
use quote::{quote, quote_spanned, ToTokens as _};
@@ -192,55 +194,54 @@ fn derive_move_and_assign_via_copy_impl(
192194
})
193195
}
194196

195-
/// `project_pin_type!(foo::T)` is the name of the type returned by
196-
/// `foo::T::project_pin()`.
197-
///
198-
/// If `foo::T` is not `#[recursively_pinned]`, then this returns the name it
199-
/// would have used, but is essentially useless.
200-
#[proc_macro]
201-
pub fn project_pin_type(name: TokenStream) -> TokenStream {
202-
project_type_impl(name, project_pin_ident)
203-
}
204-
205197
fn project_pin_ident(ident: &Ident) -> Ident {
206198
Ident::new(&format!("__CrubitProjectPin{}", ident), Span::call_site())
207199
}
208200

209-
/// `project_ref_type!(foo::T)` is the name of the type returned by
210-
/// `foo::T::project_ref()`.
211-
///
212-
/// If `foo::T` is not `#[recursively_pinned]`, then this returns the name it
213-
/// would have used, but is essentially useless.
214-
#[proc_macro]
215-
pub fn project_ref_type(name: TokenStream) -> TokenStream {
216-
project_type_impl(name, project_ref_ident)
217-
}
218-
219201
fn project_ref_ident(ident: &Ident) -> Ident {
220202
Ident::new(&format!("__CrubitProjectRef{}", ident), Span::call_site())
221203
}
222204

223-
fn project_type_impl(name: TokenStream, project_ident: fn(&Ident) -> Ident) -> TokenStream {
224-
let mut name = syn::parse_macro_input!(name as syn::Path);
225-
match name.segments.last_mut() {
226-
None => {
227-
return syn::Error::new(name.span(), "Path must have at least one element")
228-
.into_compile_error()
229-
.into();
205+
fn most_public_vis(a: &syn::Visibility, b: &syn::Visibility) -> syn::Visibility {
206+
use syn::Visibility::{Inherited, Public, Restricted};
207+
match (a, b) {
208+
(Public(_), _) | (_, Public(_)) => Public(<syn::Token![pub]>::default()),
209+
(Restricted(r), _) | (_, Restricted(r)) => Restricted(r.clone()),
210+
(Inherited, Inherited) => Inherited,
211+
}
212+
}
213+
214+
fn most_private_vis(a: &syn::Visibility, b: &syn::Visibility) -> syn::Visibility {
215+
use syn::Visibility::{Inherited, Public, Restricted};
216+
match (a, b) {
217+
(Inherited, _) | (_, Inherited) => Inherited,
218+
(Restricted(r), _) | (_, Restricted(r)) => Restricted(r.clone()),
219+
(Public(_), Public(_)) => Public(<syn::Token![pub]>::default()),
220+
}
221+
}
222+
223+
fn get_max_visibility(input: &syn::DeriveInput) -> syn::Visibility {
224+
let mut max_vis = syn::Visibility::Inherited;
225+
match &input.data {
226+
syn::Data::Struct(data) => {
227+
for field in &data.fields {
228+
max_vis = most_public_vis(&max_vis, &field.vis);
229+
}
230230
}
231-
Some(last) => {
232-
if let syn::PathArguments::Parenthesized(p) = &last.arguments {
233-
return syn::Error::new(
234-
p.span(),
235-
"Parenthesized paths (e.g. fn, Fn) do not have projected equivalents.",
236-
)
237-
.into_compile_error()
238-
.into();
231+
syn::Data::Enum(e) => {
232+
for variant in &e.variants {
233+
for field in &variant.fields {
234+
max_vis = most_public_vis(&max_vis, &field.vis);
235+
}
236+
}
237+
}
238+
syn::Data::Union(u) => {
239+
for field in &u.fields.named {
240+
max_vis = most_public_vis(&max_vis, &field.vis);
239241
}
240-
last.ident = project_ident(&last.ident);
241242
}
242243
}
243-
TokenStream::from(quote! { #name })
244+
max_vis
244245
}
245246

246247
/// Defines the `project_pin`/`project_ref` function, and its return type.
@@ -261,6 +262,8 @@ fn project_method_impl(
261262
mut_: proc_macro2::TokenStream,
262263
project_ident: fn(&Ident) -> Ident,
263264
) -> syn::Result<proc_macro2::TokenStream> {
265+
let max_field_vis = get_max_visibility(input);
266+
264267
let is_fieldless = match &input.data {
265268
syn::Data::Struct(data) => data.fields.is_empty(),
266269
syn::Data::Enum(e) => e.variants.iter().all(|variant| variant.fields.is_empty()),
@@ -269,19 +272,38 @@ fn project_method_impl(
269272
}
270273
};
271274

275+
// The projection type should be at most as public as the most public field
276+
// and at most as public as the input type.
277+
let vis = most_private_vis(&input.vis, &max_field_vis);
278+
272279
let mut projected = input.clone();
273280
// TODO(jeanpierreda): check attributes for repr(packed)
274281
projected.attrs.clear();
275282
projected.ident = project_ident(&projected.ident);
276283

277-
let lifetime = if is_fieldless {
278-
quote! {}
279-
} else {
280-
add_lifetime(&mut projected.generics, "'proj")
281-
};
284+
// Projection types must be public to be used in the public RecursivelyPinned trait.
285+
// This is okay because the fields are still private, as is the projection method.
286+
// We also hide these unusable projection types from the docs.
287+
projected.vis = syn::Visibility::Public(<syn::Token![pub]>::default());
288+
if matches!(vis, syn::Visibility::Inherited) {
289+
projected.attrs.push(syn::parse_quote!(#[doc(hidden)]));
290+
}
291+
292+
let doc_msg = format!("Pinned references to the fields of [`{}`].", input.ident);
293+
projected.attrs.push(syn::parse_quote!(#[doc = #doc_msg]));
294+
295+
let lifetime = add_lifetime(&mut projected.generics, "'crubit_proj");
282296

283297
let project_field = |field: &mut syn::Field| {
284298
field.attrs.clear();
299+
let field_name = field
300+
.ident
301+
.as_ref()
302+
.map(|id| id.to_string())
303+
.unwrap_or_else(|| "tuple field".to_string());
304+
let doc_msg =
305+
format!("Pinned reference to the `{}` field of [`{}`].", field_name, input.ident);
306+
field.attrs.push(syn::parse_quote!(#[doc = #doc_msg]));
285307
let field_ty = &field.ty;
286308
let pin_ty = syn::parse_quote!(::core::pin::Pin<& #lifetime #mut_ #field_ty>);
287309
field.ty = syn::Type::Path(pin_ty);
@@ -317,26 +339,47 @@ fn project_method_impl(
317339
let projected_ident = &projected.ident;
318340
match &mut projected.data {
319341
syn::Data::Struct(data) => {
320-
for field in &mut data.fields {
321-
project_field(field);
342+
if is_fieldless {
343+
data.fields = syn::Fields::Named(syn::parse_quote!({
344+
#[doc(hidden)]
345+
pub _phantom: ::core::marker::PhantomData<& #lifetime ()>,
346+
}));
347+
project_body = quote! {
348+
let _ = from;
349+
#projected_ident { _phantom: ::core::marker::PhantomData }
350+
};
351+
} else {
352+
for field in &mut data.fields {
353+
project_field(field);
354+
}
355+
let (pat, project) = pat_project(&mut data.fields);
356+
project_body = quote! {
357+
let #input_ident #pat = from;
358+
#projected_ident #project
359+
};
322360
}
323-
let (pat, project) = pat_project(&mut data.fields);
324-
project_body = quote! {
325-
let #input_ident #pat = from;
326-
#projected_ident #project
327-
};
328361
}
329362
syn::Data::Enum(e) => {
330363
let mut match_body = quote! {};
331364
for variant in &mut e.variants {
332-
for field in &mut variant.fields {
333-
project_field(field);
334-
}
335-
let (pat, project) = pat_project(&mut variant.fields);
336365
let variant_ident = &variant.ident;
337-
match_body.extend(quote! {
338-
#input_ident::#variant_ident #pat => #projected_ident::#variant_ident #project,
339-
});
366+
if is_fieldless {
367+
variant.fields = syn::Fields::Named(syn::parse_quote!({
368+
#[doc(hidden)]
369+
pub _phantom: ::core::marker::PhantomData<& #lifetime ()>,
370+
}));
371+
match_body.extend(quote! {
372+
#input_ident::#variant_ident { .. } => #projected_ident::#variant_ident { _phantom: ::core::marker::PhantomData },
373+
});
374+
} else {
375+
for field in &mut variant.fields {
376+
project_field(field);
377+
}
378+
let (pat, project) = pat_project(&mut variant.fields);
379+
match_body.extend(quote! {
380+
#input_ident::#variant_ident #pat => #projected_ident::#variant_ident #project,
381+
});
382+
}
340383
}
341384
project_body = quote! {
342385
match from {
@@ -353,12 +396,18 @@ fn project_method_impl(
353396
input.generics.split_for_impl();
354397
let (_, projected_generics, _) = projected.generics.split_for_impl();
355398

399+
let method_doc = format!(
400+
"Projects a pinned reference to [`{}`] into a struct of pinned references to its fields.",
401+
input.ident
402+
);
403+
356404
Ok(quote! {
357405
#projected
358406

359407
impl #input_impl_generics #input_ident #input_type_generics #input_where_clause {
408+
#[doc = #method_doc]
360409
#[must_use]
361-
pub fn #method_name<#lifetime>(self: ::core::pin::Pin<& #lifetime #mut_ Self>) -> #projected_ident #projected_generics {
410+
#vis fn #method_name <#lifetime> (self: ::core::pin::Pin<& #lifetime #mut_ Self>) -> #projected_ident #projected_generics {
362411
unsafe {
363412
let from = ::core::pin::Pin::into_inner_unchecked(self);
364413
#project_body
@@ -384,7 +433,12 @@ fn add_lifetime(generics: &mut syn::Generics, prefix: &str) -> proc_macro2::Toke
384433
name = Cow::Owned(format!("{prefix}_{i}"));
385434
};
386435
let quoted_lifetime = quote! {#lifetime};
387-
generics.params.push(syn::GenericParam::Lifetime(syn::LifetimeParam::new(lifetime)));
436+
let pos = generics
437+
.params
438+
.iter()
439+
.position(|p| !matches!(p, syn::GenericParam::Lifetime(_)))
440+
.unwrap_or(generics.params.len());
441+
generics.params.insert(pos, syn::GenericParam::Lifetime(syn::LifetimeParam::new(lifetime)));
388442
quoted_lifetime
389443
}
390444

@@ -749,6 +803,41 @@ fn recursively_pinned_impl(
749803
}
750804
};
751805

806+
let mut input_lifetimes = Vec::new();
807+
let mut input_types_and_consts = Vec::new();
808+
809+
for param in input.generics.params.iter() {
810+
match param {
811+
syn::GenericParam::Lifetime(lt) => {
812+
let lt = &lt.lifetime;
813+
input_lifetimes.push(quote! { #lt });
814+
}
815+
syn::GenericParam::Type(t) => {
816+
let id = &t.ident;
817+
input_types_and_consts.push(quote! { #id });
818+
}
819+
syn::GenericParam::Const(c) => {
820+
let id = &c.ident;
821+
input_types_and_consts.push(quote! { #id });
822+
}
823+
}
824+
}
825+
826+
let (projected_pin_type, projected_ref_type) = if matches!(input.data, syn::Data::Union(_)) {
827+
(quote! { () }, quote! { () })
828+
} else {
829+
let projected_pin_ident = project_pin_ident(&name);
830+
let projected_ref_ident = project_ref_ident(&name);
831+
832+
let mut args = input_lifetimes.clone();
833+
args.push(quote! { 'crubit_proj });
834+
args.extend(input_types_and_consts.clone());
835+
836+
let type_args = quote! { < #(#args),* > };
837+
838+
(quote! { #projected_pin_ident #type_args }, quote! { #projected_ref_ident #type_args })
839+
};
840+
752841
Ok(quote! {
753842
#input
754843
#project_pin_impl
@@ -757,16 +846,15 @@ fn recursively_pinned_impl(
757846
#drop_impl
758847
#unpin_impl
759848

760-
// Introduce a new scope to limit the blast radius of the CtorInitializedFields type.
761-
// This lets us use relatively readable names: while the impl is visible outside the scope,
762-
// type is otherwise not visible.
763849
const _ : () = {
764850
#ctor_initialized_input
765851

766852
unsafe impl #input_impl_generics ::#ctor::RecursivelyPinned for
767853
#name #input_type_generics #input_where_clause
768854
{
769855
type CtorInitializedFields = #ctor_initialized_name #input_type_generics;
856+
type ProjectedPin<'crubit_proj> = #projected_pin_type where Self: 'crubit_proj;
857+
type ProjectedRef<'crubit_proj> = #projected_ref_type where Self: 'crubit_proj;
770858
}
771859
};
772860
})
@@ -809,6 +897,7 @@ mod test {
809897
struct __CrubitCtorS {x: i32}
810898
unsafe impl ::ctor::RecursivelyPinned for S {
811899
type CtorInitializedFields = __CrubitCtorS;
900+
...
812901
}
813902
};
814903
}
@@ -866,6 +955,7 @@ mod test {
866955
}
867956
unsafe impl ::ctor::RecursivelyPinned for E {
868957
type CtorInitializedFields = __CrubitCtorE;
958+
...
869959
}
870960
};
871961
}

0 commit comments

Comments
 (0)