Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions rs_bindings_from_cc/generate_bindings/database/rs_snippet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,18 +538,14 @@ impl ToTokens for FnTrait {
}
}

/// Information about a dyn callable type.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Callable {
pub backing_type: BackingType,
pub struct CallableSignature {
pub fn_trait: FnTrait,
pub return_type: Rc<RsTypeKind>,
pub param_types: Rc<[RsTypeKind]>,
pub invoker_ident: Ident,
pub manager_ident: Ident,
}

impl Callable {
impl CallableSignature {
/// Returns a `TokenStream` in the shape of `-> Output`, or None if the return type is void.
pub fn rust_return_type_fragment(&self, db: &BindingsGenerator) -> Option<TokenStream> {
if self.return_type.is_void() {
Expand Down Expand Up @@ -581,6 +577,23 @@ impl Callable {
}
}

/// Information about a dyn callable type.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Callable {
pub backing_type: BackingType,
pub sig: CallableSignature,

/// The name of an extern "C" function that knows how to invoke this callable.
/// It is declared in C++ and defined in Rust. It has the signature
/// `extern "C" fn(*mut Box<dyn F>, ...) -> ...`
pub invoker_ident: Ident,

/// The name of an extern "C" function that knows how to delete this callable.
/// It is declared in C++ and defined in Rust. It has the signature
/// `extern "C" fn(FunctionToCall, *mut TypeErasedState, *mut TypeErasedState)`.
pub manager_ident: Ident,
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum BridgeRsTypeKind {
BridgeVoidConverters {
Expand Down Expand Up @@ -668,16 +681,18 @@ impl BridgeRsTypeKind {
let target_identifier = record.owning_target.convert_to_cc_identifier();
BridgeRsTypeKind::Callable(Rc::new(Callable {
backing_type,
fn_trait: match fn_trait {
ir::FnTrait::Fn => FnTrait::Fn,
ir::FnTrait::FnMut => FnTrait::FnMut,
ir::FnTrait::FnOnce => FnTrait::FnOnce,
sig: CallableSignature {
fn_trait: match fn_trait {
ir::FnTrait::Fn => FnTrait::Fn,
ir::FnTrait::FnMut => FnTrait::FnMut,
ir::FnTrait::FnOnce => FnTrait::FnOnce,
},
return_type: Rc::new(db.rs_type_kind(return_type.clone())?),
param_types: param_types
.iter()
.map(|param_type| db.rs_type_kind(param_type.clone()))
.collect::<Result<_>>()?,
},
return_type: Rc::new(db.rs_type_kind(return_type.clone())?),
param_types: param_types
.iter()
.map(|param_type| db.rs_type_kind(param_type.clone()))
.collect::<Result<_>>()?,
invoker_ident: format_ident!(
"__crubit_invoker_{}{}",
record.rs_name.identifier.as_ref(),
Expand Down Expand Up @@ -1785,7 +1800,7 @@ impl RsTypeKind {
}
}
BridgeRsTypeKind::Callable(callable) => {
let callable_spelling = callable.dyn_fn_spelling(&db);
let callable_spelling = callable.sig.dyn_fn_spelling(&db);
quote! { ::alloc::boxed::Box<#callable_spelling> }
}
BridgeRsTypeKind::C9Co { has_reference_param, result_type, .. } => {
Expand Down Expand Up @@ -1944,8 +1959,8 @@ impl<'ty> Iterator for RsTypeKindIter<'ty> {
}
BridgeRsTypeKind::StdString { .. } => {}
BridgeRsTypeKind::Callable(callable) => {
self.todo.push(&callable.return_type);
self.todo.extend(callable.param_types.iter().rev());
self.todo.push(&callable.sig.return_type);
self.todo.extend(callable.sig.param_types.iter().rev());
}
BridgeRsTypeKind::C9Co { result_type, .. } => {
self.todo.push(result_type);
Expand Down
52 changes: 28 additions & 24 deletions rs_bindings_from_cc/generate_bindings/generate_dyn_callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn dyn_callable_crubit_abi_type(
{
bail!("absl::AnyInvocable appears in the C++ API, but CRUBIT_ANY_INVOCABLE_SUPPORT_HEADER is not set. It should be set as a path to a .h file.");
}
let dyn_fn_spelling = callable.dyn_fn_spelling(db);
let dyn_fn_spelling = callable.sig.dyn_fn_spelling(db);

let rust_type_tokens = match callable.backing_type {
BackingType::DynCallable => quote! {
Expand All @@ -37,9 +37,9 @@ pub fn dyn_callable_crubit_abi_type(
};

let on_empty_tokens = {
let rust_return_type_fragment = callable.rust_return_type_fragment(db);
let rust_return_type_fragment = callable.sig.rust_return_type_fragment(db);
let param_type_tokens =
callable.param_types.iter().map(|param_ty| param_ty.to_token_stream(db));
callable.sig.param_types.iter().map(|param_ty| param_ty.to_token_stream(db));

quote! {
::alloc::boxed::Box::new(|#(_: #param_type_tokens),*| #rust_return_type_fragment {
Expand All @@ -65,14 +65,15 @@ pub fn dyn_callable_crubit_abi_type(
}
};

let qualifier = match callable.fn_trait {
let qualifier = match callable.sig.fn_trait {
FnTrait::Fn => quote! { const },
FnTrait::FnMut => quote! {},
FnTrait::FnOnce => quote! { && },
};

let cpp_return_type = cpp_type_name::format_cpp_type(&callable.return_type, db.ir())?;
let cpp_return_type = cpp_type_name::format_cpp_type(&callable.sig.return_type, db.ir())?;
let cpp_param_types = callable
.sig
.param_types
.iter()
.map(|param_ty| cpp_type_name::format_cpp_type(param_ty, db.ir()))
Expand Down Expand Up @@ -145,11 +146,11 @@ fn generate_invoker_function_pointer(
// Even if the callable has all C ABI compatible inputs and outputs, we cannot pass the function
// pointer directly because cfi doesn't recognize Rust function pointers as safe.
let param_idents =
(0..callable.param_types.len()).map(|i| format_ident!("param_{i}")).collect::<Vec<_>>();
(0..callable.sig.param_types.len()).map(|i| format_ident!("param_{i}")).collect::<Vec<_>>();

let mut arg_transforms = quote! {};
let mut arg_exprs = Vec::with_capacity(param_idents.len());
for (i, param_ty) in callable.param_types.iter().enumerate() {
for (i, param_ty) in callable.sig.param_types.iter().enumerate() {
let param_ident = &param_idents[i];

match param_ty.passing_convention() {
Expand Down Expand Up @@ -182,7 +183,7 @@ fn generate_invoker_function_pointer(
}
}

let out_param_arg = match callable.return_type.passing_convention() {
let out_param_arg = match callable.sig.return_type.passing_convention() {
PassingConvention::AbiCompatible
| PassingConvention::Void
| PassingConvention::OwnedPtr => None,
Expand All @@ -193,7 +194,8 @@ fn generate_invoker_function_pointer(
Some(quote! { , out.Get() })
}
PassingConvention::ComposablyBridged => {
let crubit_abi_type = db.crubit_abi_type(RsTypeKind::clone(&callable.return_type))?;
let crubit_abi_type =
db.crubit_abi_type(RsTypeKind::clone(&callable.sig.return_type))?;
let crubit_abi_type_tokens = CrubitAbiTypeToCppTokens(&crubit_abi_type);
arg_transforms.extend(quote! {
unsigned char out[#crubit_abi_type_tokens::kSize];
Expand All @@ -209,7 +211,7 @@ fn generate_invoker_function_pointer(
#invoker_ident(state #(, #arg_exprs)* #out_param_arg);
};

match callable.return_type.passing_convention() {
match callable.sig.return_type.passing_convention() {
PassingConvention::AbiCompatible | PassingConvention::OwnedPtr => {
// Return the result.
invoke_ffi_and_transform_to_cpp = quote! {
Expand All @@ -224,7 +226,8 @@ fn generate_invoker_function_pointer(
});
}
PassingConvention::ComposablyBridged => {
let crubit_abi_type = db.crubit_abi_type(RsTypeKind::clone(&callable.return_type))?;
let crubit_abi_type =
db.crubit_abi_type(RsTypeKind::clone(&callable.sig.return_type))?;
let crubit_abi_type_tokens = CrubitAbiTypeToCppTokens(&crubit_abi_type);
let crubit_abi_type_expr_tokens = CrubitAbiTypeToCppExprTokens(&crubit_abi_type);
invoke_ffi_and_transform_to_cpp.extend(quote! {
Expand Down Expand Up @@ -265,14 +268,15 @@ fn generate_make_cpp_invoker_tokens(
callable: &Callable,
) -> Result<TokenStream> {
let param_idents =
(0..callable.param_types.len()).map(|i| format_ident!("param_{i}")).collect::<Vec<_>>();
let rust_param_types = callable.param_types.iter().map(|param_ty| param_ty.to_token_stream(db));
let rust_return_type_fragment = callable.rust_return_type_fragment(db);
(0..callable.sig.param_types.len()).map(|i| format_ident!("param_{i}")).collect::<Vec<_>>();
let rust_param_types =
callable.sig.param_types.iter().map(|param_ty| param_ty.to_token_stream(db));
let rust_return_type_fragment = callable.sig.rust_return_type_fragment(db);

let mut c_param_types = Vec::with_capacity(callable.param_types.len());
let mut arg_exprs = Vec::with_capacity(callable.param_types.len());
let mut c_param_types = Vec::with_capacity(callable.sig.param_types.len());
let mut arg_exprs = Vec::with_capacity(callable.sig.param_types.len());
// We are the caller
for (i, param_ty) in callable.param_types.iter().enumerate() {
for (i, param_ty) in callable.sig.param_types.iter().enumerate() {
let param_ident = &param_idents[i];

match param_ty.passing_convention() {
Expand Down Expand Up @@ -314,14 +318,14 @@ fn generate_make_cpp_invoker_tokens(
// What the extern "C" function should return.
let mut c_return_type_fragment = None;
// Set c_return_type_fragment, or push an out param, or nothing if void.
match callable.return_type.passing_convention() {
match callable.sig.return_type.passing_convention() {
PassingConvention::AbiCompatible => {
let c_return_type = callable.return_type.to_token_stream(db);
let c_return_type = callable.sig.return_type.to_token_stream(db);
c_return_type_fragment = Some(quote! { -> #c_return_type });
}
PassingConvention::Void => {}
PassingConvention::LayoutCompatible => {
let return_type_tokens = callable.return_type.to_token_stream(db);
let return_type_tokens = callable.sig.return_type.to_token_stream(db);
c_param_types.push(quote! { *mut #return_type_tokens });
arg_exprs.push(quote! { &raw mut out });
}
Expand All @@ -333,7 +337,7 @@ fn generate_make_cpp_invoker_tokens(
bail!("Ctor not supported");
}
PassingConvention::OwnedPtr => {
let c_return_type = callable.return_type.to_token_stream_with_owned_ptr_type(db);
let c_return_type = callable.sig.return_type.to_token_stream_with_owned_ptr_type(db);
c_return_type_fragment = Some(quote! { -> #c_return_type });
}
};
Expand All @@ -342,7 +346,7 @@ fn generate_make_cpp_invoker_tokens(
unsafe { c_invoker(managed.state() #(, #arg_exprs)*) }
};

match callable.return_type.passing_convention() {
match callable.sig.return_type.passing_convention() {
PassingConvention::AbiCompatible => {
// invoke_ffi_and_transform_to_rust is already a trailing expr.
}
Expand All @@ -354,7 +358,7 @@ fn generate_make_cpp_invoker_tokens(
}
}
PassingConvention::ComposablyBridged => {
let crubit_abi_type = db.crubit_abi_type(callable.return_type.as_ref().clone())?;
let crubit_abi_type = db.crubit_abi_type(callable.sig.return_type.as_ref().clone())?;
let crubit_abi_type_tokens = CrubitAbiTypeToRustTokens(&crubit_abi_type);
let crubit_abi_type_expr_tokens = CrubitAbiTypeToRustExprTokens(&crubit_abi_type);
invoke_ffi_and_transform_to_rust = quote! {
Expand All @@ -380,7 +384,7 @@ fn generate_make_cpp_invoker_tokens(
}
}

let dyn_fn_spelling = callable.dyn_fn_spelling(db);
let dyn_fn_spelling = callable.sig.dyn_fn_spelling(db);

Ok(quote! {
|managed: ::any_invocable::ManagedState,
Expand Down
36 changes: 19 additions & 17 deletions rs_bindings_from_cc/generate_bindings/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,11 @@ pub fn generate_bindings_tokens(
// The parameters shall be named `param_0`, `param_1`, etc.
// These names can be reused across different callables, so we reuse the same vec and
// just grow it when we need more Idents than it currently contains.
while callable.param_types.len() > param_idents_buffer.len() {
while callable.sig.param_types.len() > param_idents_buffer.len() {
param_idents_buffer.push(format_ident!("param_{}", param_idents_buffer.len()));
}
// Only take as many filled in names as we need.
let param_idents = &param_idents_buffer[..callable.param_types.len()];
let param_idents = &param_idents_buffer[..callable.sig.param_types.len()];

// If generate_dyn_callable_cpp_thunk fails, skip. We don't need to generate a nice
// error because whoever uses this will also fail and generate an error at the relevant
Expand Down Expand Up @@ -644,7 +644,7 @@ fn rs_type_kind_safety(db: &BindingsGenerator, rs_type_kind: RsTypeKind) -> Safe
}
BridgeRsTypeKind::StdString { .. } => Safety::Safe,
BridgeRsTypeKind::Callable(callable) => {
callable_safety(db, &callable.param_types, &callable.return_type)
callable_safety(db, &callable.sig.param_types, &callable.sig.return_type)
}
BridgeRsTypeKind::C9Co { result_type, .. } => {
// A Co<T> logically produces a T, so it is unsafe iff T is unsafe.
Expand Down Expand Up @@ -1190,10 +1190,11 @@ fn generate_dyn_callable_cpp_thunk(
param_idents: &[Ident],
) -> Option<TokenStream> {
assert!(
param_idents.len() == callable.param_types.len(),
param_idents.len() == callable.sig.param_types.len(),
"param_idents and param_types should have the same length, this is a Crubit bug."
);
let param_types = callable
.sig
.param_types
.iter()
.map(|param_type| -> Option<TokenStream> {
Expand All @@ -1217,16 +1218,16 @@ fn generate_dyn_callable_cpp_thunk(
let out_param_ident;
let out_param_type;
let decl_return_type_tokens;
match callable.return_type.passing_convention() {
match callable.sig.return_type.passing_convention() {
PassingConvention::AbiCompatible => {
out_param_ident = None;
out_param_type = None;
decl_return_type_tokens =
cpp_type_name::format_cpp_type(&callable.return_type, db.ir()).ok()?;
cpp_type_name::format_cpp_type(&callable.sig.return_type, db.ir()).ok()?;
}
PassingConvention::LayoutCompatible => {
let return_type_tokens =
cpp_type_name::format_cpp_type(&callable.return_type, db.ir()).ok()?;
cpp_type_name::format_cpp_type(&callable.sig.return_type, db.ir()).ok()?;
out_param_ident = Some(format_ident!("out"));
out_param_type = Some(quote! { #return_type_tokens* });
decl_return_type_tokens = quote! { void };
Expand Down Expand Up @@ -1293,14 +1294,14 @@ fn generate_dyn_callable_rust_thunk_impl(
param_idents: &[Ident],
) -> Option<TokenStream> {
assert!(
param_idents.len() == callable.param_types.len(),
param_idents.len() == callable.sig.param_types.len(),
"param_idents and param_types should have the same length, this is a Crubit bug."
);
let mut ffi_to_rust_transforms = quote! {};

let param_types_tokens = param_idents
.iter()
.zip(callable.param_types.iter())
.zip(callable.sig.param_types.iter())
.map(|(ident, ty)| -> Option<TokenStream> {
match ty.passing_convention() {
PassingConvention::AbiCompatible => {
Expand Down Expand Up @@ -1328,16 +1329,16 @@ fn generate_dyn_callable_rust_thunk_impl(
})
.collect::<Option<Vec<TokenStream>>>()?;

let unwrapper = match callable.fn_trait {
let unwrapper = match callable.sig.fn_trait {
FnTrait::Fn => quote! { &*f },
FnTrait::FnMut => quote! { &mut *f },
FnTrait::FnOnce => {
// Replace the FnOnce with an empty instance, so it can still be dropped.
// Since it's a ZST, no allocation will be performed, and it can even be forgotten
// without worrying about leaks.
let rust_return_type_fragment = callable.rust_return_type_fragment(db);
let rust_return_type_fragment = callable.sig.rust_return_type_fragment(db);
let param_type_tokens =
callable.param_types.iter().map(|param_ty| param_ty.to_token_stream(db));
callable.sig.param_types.iter().map(|param_ty| param_ty.to_token_stream(db));
quote! {
// SAFETY: f is guaranteed to be valid for reads and writes, is properly aligned,
// and points to a properly initialized value of the correct type.
Expand All @@ -1356,9 +1357,9 @@ fn generate_dyn_callable_rust_thunk_impl(
let return_type_fragment;
let out_param_ident;
let out_param_type;
match callable.return_type.passing_convention() {
match callable.sig.return_type.passing_convention() {
PassingConvention::AbiCompatible => {
let ffi_return_type = callable.return_type.to_token_stream(db);
let ffi_return_type = callable.sig.return_type.to_token_stream(db);
return_type_fragment = Some(quote! { -> #ffi_return_type });
out_param_ident = None;
out_param_type = None;
Expand All @@ -1373,13 +1374,14 @@ fn generate_dyn_callable_rust_thunk_impl(
}
};

let ffi_return_type = callable.return_type.to_token_stream(db);
let ffi_return_type = callable.sig.return_type.to_token_stream(db);
return_type_fragment = None;
out_param_ident = Some(out_ident);
out_param_type = Some(quote! { *mut #ffi_return_type });
}
PassingConvention::ComposablyBridged => {
let crubit_abi_type = db.crubit_abi_type(callable.return_type.as_ref().clone()).ok()?;
let crubit_abi_type =
db.crubit_abi_type(callable.sig.return_type.as_ref().clone()).ok()?;
let crubit_abi_type_expr_tokens = CrubitAbiTypeToRustExprTokens(&crubit_abi_type);
let bridge_buffer_ident = format_ident!("bridge_buffer");
invoke_rust_and_return_to_ffi = quote! {
Expand Down Expand Up @@ -1414,7 +1416,7 @@ fn generate_dyn_callable_rust_thunk_impl(

let param_idents = param_idents.iter().chain(out_param_ident.as_ref());
let param_types_tokens = param_types_tokens.iter().chain(out_param_type.as_ref());
let dyn_fn_spelling = callable.dyn_fn_spelling(db);
let dyn_fn_spelling = callable.sig.dyn_fn_spelling(db);
let invoker_ident = &callable.invoker_ident;
let manager_ident = &callable.manager_ident;

Expand Down