Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions rs_bindings_from_cc/generate_bindings/database/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ memoized::query_group! {
/// Implementation: rs_bindings_from_cc/generate_bindings/has_bindings.rs?q=function:resolve_type_names
fn resolve_type_names(&self, parent: Rc<Record>) -> Result<Rc<HashMap<Rc<str>, ResolvedTypeName>>>;

/// Returns whether the given type ensures that the memory for all publicly accessible
/// fields is initialized.
///
/// Implementation: rs_bindings_from_cc/generate_bindings/lib.rs?q=function:is_default_initialized
fn is_default_initialized(&self, rs_type_kind: RsTypeKind) -> Result<bool>;

#[provided]
/// Returns the generated bindings for the given enum.
///
Expand Down
33 changes: 33 additions & 0 deletions rs_bindings_from_cc/generate_bindings/generate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,39 @@ fn api_func_shape_for_constructor(
match func.params.len() {
0 => panic!("Missing `__this` parameter in a constructor: {:?}", func),
1 => {
// Default constructor.
match func.safety_annotation {
SafetyAnnotation::DisableUnsafe => {
// The default constructor has been marked safe, the user now carries the
// burden of ensuring it's actually safe.
}
SafetyAnnotation::Unsafe => {
errors.add(anyhow!("Rust Default implementations cannot be unsafe."));
return None;
}
SafetyAnnotation::Unannotated => {
// Check if the record is default initialized.
let rs_type_kind = db
.rs_type_kind(CcType {
variant: CcTypeVariant::Decl(record.id),
is_const: false,
unknown_attr: "".into(),
explicit_lifetimes: Vec::new(),
})
.map_err(|e| errors.add(e))
.ok()?;
if !db.is_default_initialized(rs_type_kind).map_err(|e| errors.add(e)).ok()? {
// This type does not init the memory for all publicly accessible fields by default.
let err = if func.is_implicit {
anyhow!("The implicit default constructor will leave some fields uninitialized.")
} else {
anyhow!("The default constructor must be marked CRUBIT_UNSAFE_MARK_SAFE to guarantee that the fields that do not initialized themselves are initialized in the function.")
};
errors.add(err);
return None;
}
}
}
let func_name = make_rs_ident("default");
let impl_kind = ImplKind::new_trait(
TraitName::Default,
Expand Down
74 changes: 73 additions & 1 deletion rs_bindings_from_cc/generate_bindings/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use database::code_snippet::{
use database::db::{BindingsGenerator, CodegenFunctions};
use database::rs_snippet::{
BridgeRsTypeKind, Callable, FnTrait, Mutability, PassingConvention, RsTypeKind, RustPtrKind,
Safety,
Safety, UniformReprTemplateType,
};
use dyn_format::Format;
use error_report::{bail, ErrorReporting, ReportFatalError};
Expand Down Expand Up @@ -394,6 +394,7 @@ pub fn new_database<'db>(
crubit_abi_type,
has_bindings::type_target_restriction,
has_bindings::resolve_type_names,
is_default_initialized,
)
}

Expand Down Expand Up @@ -1510,3 +1511,74 @@ fn make_cpp_type_from_item(
.map_err(|e| anyhow!("Failed to parse C++ name `{cc_name}`: {e}"))?;
Ok(quote! { :: #(#namespace_parts::)* #cpp_type })
}

fn is_default_initialized(db: &BindingsGenerator, rs_type_kind: RsTypeKind) -> Result<bool> {
let result = match rs_type_kind.unalias() {
RsTypeKind::Error { .. } => false,
RsTypeKind::Pointer { kind, .. } => match kind {
RustPtrKind::CcPtr(_) => {
// C++ pointers are not default initialized.
false
}
RustPtrKind::Slice => {
// Slice comes from rs_std::SliceRef, which has a safe default constructor.
true
}
},
RsTypeKind::Reference { .. }
| RsTypeKind::RvalueReference { .. }
| RsTypeKind::FuncPtr { .. }
| RsTypeKind::IncompleteRecord { .. } => false,
RsTypeKind::Record { record, uniform_repr_template_type, owned_ptr_type, .. } => {
// Owned pointer types are transparent wrappers around a pointer, so they have the same
// initialization semantics (no initialization) as a pointer.
if owned_ptr_type.is_some() {
return Ok(false);
}
if let Some(uniform_repr_template_type) = uniform_repr_template_type {
// Exhaustive matching in case we add more.
return match uniform_repr_template_type.as_ref() {
UniformReprTemplateType::StdVector { .. } => Ok(true),
UniformReprTemplateType::StdUniquePtr { .. } => Ok(true),
UniformReprTemplateType::AbslSpan { .. } => Ok(true),
};
}

// TODO(okabayashi): Do we have a default constructor that's marked as
// CRUBIT_UNSAFE_MARK_SAFE?

// If we don't, then we need to check that all fields are default initialized.
for field in &record.fields {
if field.access != AccessSpecifier::Public {
// mapped to [MaybeUninit<u8>; N] so we don't care, go next.
continue;
}

if field.has_in_class_initializer {
// Default initialized, go next.
continue;
}

if db.is_default_initialized(db.rs_type_kind(field.type_.clone())?)? {
// The type is default initialized, go next.
continue;
}

// None of the default initialization checks worked, by itself this record will
// contain publicly accessible uninitialized memory.
return Ok(false);
}

true
}
RsTypeKind::Enum { .. } => false,
RsTypeKind::TypeAlias { .. } => unreachable!("called .unalias() above"),
RsTypeKind::Primitive(_) => false,
RsTypeKind::BridgeType { .. } => bail!("BridgeType should not be default initialized"),
RsTypeKind::ExistingRustType(_existing_rust_type) => {
// I don't really know
false
}
};
Ok(result)
}
3 changes: 2 additions & 1 deletion rs_bindings_from_cc/importers/cxx_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,8 @@ std::vector<Field> CXXRecordDeclImporter::ImportFields(
.is_no_unique_address =
field_decl->hasAttr<clang::NoUniqueAddressAttr>(),
.is_bitfield = field_decl->isBitField(),
.is_inheritable = is_inheritable});
.is_inheritable = is_inheritable,
.has_in_class_initializer = field_decl->hasInClassInitializer()});
}
return fields;
}
Expand Down
1 change: 1 addition & 0 deletions rs_bindings_from_cc/importers/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ std::optional<IR::Item> FunctionDeclImporter::Import(
.id = ictx_.GenerateItemId(function_decl),
.enclosing_item_id = *std::move(enclosing_item_id),
.lifetime_inputs = std::move(lifetime_inputs),
.is_implicit = function_decl->isImplicit(),
};
}

Expand Down
2 changes: 2 additions & 0 deletions rs_bindings_from_cc/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ llvm::json::Value Func::ToJson() const {
{"enclosing_item_id", enclosing_item_id},
{"adl_enclosing_record", adl_enclosing_record},
{"must_bind", must_bind},
{"is_implicit", is_implicit},
};

if (!lifetime_inputs.empty()) {
Expand Down Expand Up @@ -433,6 +434,7 @@ llvm::json::Value Field::ToJson() const {
{"is_no_unique_address", is_no_unique_address},
{"is_bitfield", is_bitfield},
{"is_inheritable", is_inheritable},
{"has_in_class_initializer", has_in_class_initializer},
};
}

Expand Down
5 changes: 5 additions & 0 deletions rs_bindings_from_cc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ struct Func {
bool must_bind = false;
// Lifetime variable names bound by this function.
std::vector<std::string> lifetime_inputs;
bool is_implicit;
};

inline std::ostream& operator<<(std::ostream& o, const Func& f) {
Expand Down Expand Up @@ -502,6 +503,10 @@ struct Field {
bool is_no_unique_address; // True if the field is [[no_unique_address]].
bool is_bitfield; // True if the field is a bitfield.
bool is_inheritable; // True if the field is inheritable.

/// True if the field has an in-class initializer, e.g.
/// struct Foo { int x = 1; };
bool has_in_class_initializer;
};

inline std::ostream& operator<<(std::ostream& o, const Field& f) {
Expand Down
3 changes: 3 additions & 0 deletions rs_bindings_from_cc/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ pub struct Func {
// Lifetime variable names bound by this function.
#[serde(default)]
pub lifetime_inputs: Vec<Rc<str>>,
pub is_implicit: bool,
}

impl GenericItem for Func {
Expand Down Expand Up @@ -1002,6 +1003,8 @@ pub struct Field {
// TODO(kinuko): Consider removing this, it is a duplicate of the same information
// in `Record`.
pub is_inheritable: bool,

pub has_in_class_initializer: bool,
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Deserialize)]
Expand Down