From 005b464deb9a98cbd78c3b3930a033773af5f33d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20Jos=C3=A9=20Pereira?= Date: Tue, 2 Dec 2025 19:03:14 -0300 Subject: [PATCH] Improve generation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Patrick José Pereira --- crates/generators/codegen-core/src/lib.rs | 3 + .../codegen-core/src/type_registry.rs | 254 +++++ crates/generators/cpp/src/lib.rs | 1011 ++++++++++++----- crates/generators/python/src/lib.rs | 814 ++++++++++--- crates/generators/rust/src/lib.rs | 249 +--- 5 files changed, 1691 insertions(+), 640 deletions(-) create mode 100644 crates/generators/codegen-core/src/type_registry.rs diff --git a/crates/generators/codegen-core/src/lib.rs b/crates/generators/codegen-core/src/lib.rs index 6675898..7b72781 100644 --- a/crates/generators/codegen-core/src/lib.rs +++ b/crates/generators/codegen-core/src/lib.rs @@ -2,6 +2,9 @@ use blueberry_ast::{ Annotation, AnnotationParam, Commented, ConstValue, Definition, MessageDef, Type, }; +pub mod type_registry; +pub use type_registry::{ResolvedMember, TypeRegistry, map_builtin_ident}; + pub const DEFAULT_MODULE_KEY: u16 = 0x4242; pub const MESSAGE_HEADER_SIZE: usize = 8; diff --git a/crates/generators/codegen-core/src/type_registry.rs b/crates/generators/codegen-core/src/type_registry.rs new file mode 100644 index 0000000..ff4f6b6 --- /dev/null +++ b/crates/generators/codegen-core/src/type_registry.rs @@ -0,0 +1,254 @@ +use std::collections::{HashMap, HashSet}; + +use blueberry_ast::{Definition, MessageDef, StructDef, Type}; + +#[derive(Clone)] +pub struct TypedefInfo { + pub ty: Type, + pub scope: Vec, +} + +#[derive(Clone)] +pub struct StructInfo { + pub def: StructDef, + pub scope: Vec, +} + +#[derive(Clone)] +pub struct MessageInfo { + pub def: MessageDef, + pub scope: Vec, +} + +#[derive(Clone)] +pub struct ResolvedMember { + pub name: String, + pub ty: Type, + pub comments: Vec, +} + +#[derive(Default)] +pub struct TypeRegistry { + typedefs: HashMap, TypedefInfo>, + structs: HashMap, StructInfo>, + messages: HashMap, MessageInfo>, + enums: HashMap, Type>, +} + +impl TypeRegistry { + pub fn new(definitions: &[Definition]) -> Self { + let mut registry = TypeRegistry::default(); + let mut scope = Vec::new(); + registry.collect(definitions, &mut scope); + registry + } + + pub fn resolve_type(&self, ty: &Type, scope: &[String]) -> Type { + match ty { + Type::Sequence { element_type, size } => Type::Sequence { + element_type: Box::new(self.resolve_type(element_type, scope)), + size: *size, + }, + Type::Array { + element_type, + dimensions, + } => { + let mut resolved = self.resolve_type(element_type, scope); + for &dim in dimensions.iter().rev() { + resolved = Type::Sequence { + element_type: Box::new(resolved), + size: Some(dim), + }; + } + resolved + } + Type::ScopedName(name) => { + if let [single] = name.as_slice() + && let Some(mapped) = map_builtin_ident(single) + { + return mapped; + } + if let Some(path) = self.resolve_typedef(name, scope) { + let info = self.typedefs.get(&path).expect("typedef info missing"); + self.resolve_type(&info.ty, &info.scope) + } else if let Some(path) = self.resolve_struct(name, scope) { + Type::ScopedName(path) + } else if let Some(path) = self.resolve_message(name, scope) { + Type::ScopedName(path) + } else if let Some(path) = self.resolve_enum(name, scope) { + Type::ScopedName(path) + } else { + Type::ScopedName(name.clone()) + } + } + other => other.clone(), + } + } + + pub fn collect_struct_members(&self, path: &[String]) -> Vec { + let info = self + .structs + .get(path) + .unwrap_or_else(|| panic!("missing struct {:?}", path)); + let mut members = Vec::new(); + if let Some(base) = &info.def.base + && let Some(base_path) = self.resolve_struct(base, &info.scope) + { + members.extend(self.collect_struct_members(&base_path)); + } + for member in &info.def.members { + let ty = self.resolve_type(&member.node.type_, &info.scope); + members.push(ResolvedMember { + name: member.node.name.clone(), + ty, + comments: member.comments.clone(), + }); + } + members + } + + pub fn collect_message_members(&self, path: &[String]) -> Vec { + let info = self + .messages + .get(path) + .unwrap_or_else(|| panic!("missing message {:?}", path)); + let mut members = Vec::new(); + if let Some(base) = &info.def.base + && let Some(base_path) = self.resolve_message(base, &info.scope) + { + members.extend(self.collect_message_members(&base_path)); + } + for member in &info.def.members { + let ty = self.resolve_type(&member.node.type_, &info.scope); + members.push(ResolvedMember { + name: member.node.name.clone(), + ty, + comments: member.comments.clone(), + }); + } + members + } + + pub fn struct_paths(&self) -> Vec> { + self.structs.keys().cloned().collect() + } + + pub fn enum_base(&self, path: &[String]) -> Option<&Type> { + self.enums.get(path) + } + + pub fn resolve_typedef(&self, name: &[String], scope: &[String]) -> Option> { + self.resolve_path(name, scope, self.typedefs.keys()) + } + + pub fn resolve_struct(&self, name: &[String], scope: &[String]) -> Option> { + self.resolve_path(name, scope, self.structs.keys()) + } + + pub fn resolve_message(&self, name: &[String], scope: &[String]) -> Option> { + self.resolve_path(name, scope, self.messages.keys()) + } + + pub fn resolve_enum(&self, name: &[String], scope: &[String]) -> Option> { + self.resolve_path(name, scope, self.enums.keys()) + } + + fn collect(&mut self, defs: &[Definition], scope: &mut Vec) { + for def in defs { + match def { + Definition::ModuleDef(module) => { + scope.push(module.node.name.clone()); + self.collect(&module.node.definitions, scope); + scope.pop(); + } + Definition::TypeDef(typedef) => { + let mut path = scope.clone(); + path.push(typedef.node.name.clone()); + self.typedefs.insert( + path, + TypedefInfo { + ty: typedef.node.base_type.clone(), + scope: scope.clone(), + }, + ); + } + Definition::StructDef(struct_def) => { + let mut path = scope.clone(); + path.push(struct_def.node.name.clone()); + self.structs.insert( + path, + StructInfo { + def: struct_def.node.clone(), + scope: scope.clone(), + }, + ); + } + Definition::MessageDef(message_def) => { + let mut path = scope.clone(); + path.push(message_def.node.name.clone()); + self.messages.insert( + path, + MessageInfo { + def: message_def.node.clone(), + scope: scope.clone(), + }, + ); + } + Definition::EnumDef(enum_def) => { + let mut path = scope.clone(); + path.push(enum_def.node.name.clone()); + let repr = enum_def + .node + .base_type + .clone() + .unwrap_or(Type::UnsignedLong); + self.enums.insert(path, repr); + } + Definition::ConstDef(_) | Definition::ImportDef(_) => {} + } + } + } + + fn resolve_path<'a, I>( + &self, + name: &[String], + scope: &[String], + entries: I, + ) -> Option> + where + I: IntoIterator>, + { + let paths: Vec> = entries.into_iter().cloned().collect(); + let lookup: HashSet> = paths.iter().cloned().collect(); + for prefix in (0..=scope.len()).rev() { + let mut candidate = scope[..prefix].to_vec(); + candidate.extend_from_slice(name); + if lookup.contains(&candidate) { + return Some(candidate); + } + } + self.resolve_by_suffix(name, &paths) + } + + fn resolve_by_suffix(&self, name: &[String], paths: &[Vec]) -> Option> { + let matches: Vec<&Vec> = paths.iter().filter(|path| path.ends_with(name)).collect(); + if matches.len() == 1 { + return Some(matches[0].clone()); + } + None + } +} + +pub fn map_builtin_ident(name: &str) -> Option { + match name { + "int8" => Some(Type::Octet), + "int16" => Some(Type::Short), + "int32" => Some(Type::Long), + "int64" => Some(Type::LongLong), + "uint8" => Some(Type::Octet), + "uint16" => Some(Type::UnsignedShort), + "uint32" => Some(Type::UnsignedLong), + "uint64" => Some(Type::UnsignedLongLong), + _ => None, + } +} diff --git a/crates/generators/cpp/src/lib.rs b/crates/generators/cpp/src/lib.rs index 5f57026..71e2afa 100644 --- a/crates/generators/cpp/src/lib.rs +++ b/crates/generators/cpp/src/lib.rs @@ -1,227 +1,528 @@ -use blueberry_ast::Definition; -use blueberry_codegen_core::{ - CodegenError, GeneratedFile, MessageSpec, PrimitiveType, class_name, collect_messages, - quoted_string, +use blueberry_ast::{ + Annotation, AnnotationParam, Commented, ConstDef, ConstValue, Definition, EnumDef, MessageDef, + StructDef, Type, TypeDef, }; +use blueberry_codegen_core::{CodegenError, DEFAULT_MODULE_KEY, GeneratedFile, TypeRegistry}; use genco::lang::c::Tokens; use genco::quote; const OUTPUT_PATH: &str = "cpp/messages.hpp"; pub fn generate(definitions: &[Definition]) -> Result, CodegenError> { - let messages = collect_messages(definitions)?; - if messages.is_empty() { - return Ok(Vec::new()); - } - let generator = CppGenerator::new(&messages); - let contents = generator.render(); + let generator = CppGenerator::new(definitions); + let contents = generator.render(definitions)?; + Ok(vec![GeneratedFile { path: OUTPUT_PATH.to_string(), contents, }]) } -struct CppGenerator<'a> { - messages: &'a [MessageSpec], +struct CppGenerator { + registry: TypeRegistry, } -impl<'a> CppGenerator<'a> { - fn new(messages: &'a [MessageSpec]) -> Self { - Self { messages } +impl CppGenerator { + fn new(definitions: &[Definition]) -> Self { + Self { + registry: TypeRegistry::new(definitions), + } } - fn render(&self) -> String { - let message_blocks: Vec = - self.messages.iter().map(|m| self.emit_message(m)).collect(); - + fn render(&self, definitions: &[Definition]) -> Result { + let defs = self.emit_definitions(definitions, &mut Vec::new(), DEFAULT_MODULE_KEY)?; let helpers = helpers_tokens(); - let header = message_header(); let tokens: Tokens = quote! { #pragma once #include + #include #include #include #include #include #include #include + #include #include + #include - namespace blueberry::messages { + namespace blueberry_generated { $helpers - $header - - $(for block in &message_blocks => - $block + $(for def in defs => + $def ) - } // namespace blueberry::messages + } // namespace blueberry_generated }; - tokens.to_file_string().expect("render cpp output") + Ok(tokens.to_file_string().expect("render cpp output")) } - fn emit_message(&self, message: &MessageSpec) -> Tokens { - let class_ident = class_name(&message.scope, &message.name); - let class_ident_ref = class_ident.as_str(); - let topic_literal = quoted_string(&message.topic); - let module_key = message.module_key; - let message_key = message.message_key; - let struct_size = message.field_payload_size(); - let padded_size = message.padded_payload_size(); - let payload_words = message.payload_words(); - - let fields = &message.fields; - let ctor_params: Tokens = quote! { - $(for field in fields join (, ) => - $(cpp_field_type(field.primitive)) $(field.name.clone()) = $(cpp_default(field.primitive)) - ) - }; - let ctor_inits: Tokens = quote! { - $(for field in fields join (, ) => - $(field.name.clone())_($(field.name.clone())) - ) - }; + fn emit_definitions( + &self, + defs: &[Definition], + scope: &mut Vec, + module_key: u16, + ) -> Result, CodegenError> { + let mut out = Vec::new(); + let mut message_keys = MessageKeyGen::default(); + + for def in defs { + match def { + Definition::ModuleDef(module) => { + let module_key = + annotation_u16(&module.annotations, "module_key").unwrap_or(module_key); + scope.push(module.node.name.clone()); + let nested = + self.emit_definitions(&module.node.definitions, scope, module_key)?; + scope.pop(); + + let name = &module.node.name; + out.push(quote! { + namespace $name { + $(for item in nested => + $item + + ) + } // namespace $name + }); + } + Definition::EnumDef(enum_def) => { + out.push(self.emit_enum(enum_def, scope)); + out.push(quote! { $['\n'] }); + } + Definition::StructDef(struct_def) => { + out.push(self.emit_struct(struct_def, scope)); + out.push(quote! { $['\n'] }); + } + Definition::MessageDef(message_def) => { + let message_key = annotation_u16(&message_def.annotations, "message_key") + .unwrap_or_else(|| message_keys.next()); + out.push(self.emit_message(message_def, scope, module_key, message_key)?); + out.push(quote! { $['\n'] }); + } + Definition::ConstDef(const_def) => { + out.push(self.emit_const(const_def, scope)); + } + Definition::TypeDef(typedef_def) => { + out.push(self.emit_typedef(typedef_def, scope)); + } + Definition::ImportDef(_) => {} + } + } - let accessors: Tokens = quote! { - $(for field in fields => - [[nodiscard]] $(cpp_field_type(field.primitive)) $(field.name.clone())() const noexcept { return $(field.name.clone())_; } - void set_$(field.name.clone())($(cpp_field_type(field.primitive)) value) noexcept { $(field.name.clone())_ = value; } - ) - }; + Ok(out) + } - let serialize_body: Tokens = quote! { - std::array payload{}; - std::size_t offset = 0; - $(for field in fields => - $(cpp_writer(field.primitive))($(field.name.clone())_, payload.data() + offset); - offset += $(field.primitive.size()); - ) - return payload; - }; + fn emit_enum(&self, enum_def: &Commented, scope: &[String]) -> Tokens { + let scoped = &enum_def.node.name; + let base = enum_def + .node + .base_type + .clone() + .unwrap_or(Type::UnsignedLong); + let base_ty = self.cpp_type(&base, scope); + let members: Vec = enum_def + .node + .enumerators + .iter() + .map(|member| { + if let Some(ConstValue::Integer(value)) = &member.value { + quote!( $(member.name.clone()) = $(value.value) ) + } else { + quote!( $(member.name.clone()) ) + } + }) + .collect(); + + quote! { + enum class $scoped : $base_ty { + $(for m in members join (, ) => $m) + }; + } + } - let deserialize_body: Tokens = if fields.is_empty() { - quote! { return $class_ident_ref(); } - } else { - let values: Tokens = quote! { - $(for field in fields => - auto value_$(field.name.clone()) = $(cpp_reader(field.primitive))(payload.data() + offset); - offset += $(field.primitive.size()); + fn emit_struct(&self, struct_def: &Commented, scope: &[String]) -> Tokens { + let mut path = scope.to_vec(); + path.push(struct_def.node.name.clone()); + let members = self.registry.collect_struct_members(&path); + let fields: Vec = members + .iter() + .map(|member| { + let ty = self.cpp_type(&member.ty, scope); + let name = &member.name; + quote!( $ty $name{}; ) + }) + .collect(); + let serialize_fields: Vec = members + .iter() + .map(|member| self.serialize_value(&member.name, &member.ty, scope, quote!(writer))) + .collect(); + let deserialize_fields: Vec = members + .iter() + .map(|member| { + let var_name = format!("value_{}", member.name); + let read_expr = self.deserialize_value(&member.ty, scope, quote!(reader)); + quote! { + auto $var_name = $read_expr; + } + }) + .collect(); + let assignments: Vec = members + .iter() + .map(|member| { + let var_name = format!("value_{}", member.name); + let field = &member.name; + quote!( result.$field = $var_name; ) + }) + .collect(); + + let name = &struct_def.node.name; + quote! { + struct $name { + $( + for field in &fields join ($['\r']) => $field ) + + [[nodiscard]] std::vector to_payload() const { + CdrWriter writer; + serialize(writer); + return writer.buffer; + } + + void serialize(CdrWriter &writer) const { + $(for s in &serialize_fields => $s$['\r']) + } + + [[nodiscard]] static $name from_payload(std::span payload) { + CdrReader reader(payload); + return read(reader); + } + + [[nodiscard]] static $name read(CdrReader &reader) { + $(for d in &deserialize_fields => $d$['\r']) + $name result{}; + $(for a in &assignments => $a$['\r']) + return result; + } }; - let args: Tokens = - quote! { $(for field in fields join (, ) => value_$(field.name.clone())) }; - quote! { - std::size_t offset = 0; - $values - return $class_ident_ref($args); - } - }; + } + } - let members: Tokens = quote! { - $(for field in fields => - $(cpp_field_type(field.primitive)) $(field.name.clone())_{};$['\n'] - ) - }; + fn emit_message( + &self, + message_def: &Commented, + scope: &[String], + module_key: u16, + message_key: u16, + ) -> Result { + let topic = annotation_string(&message_def.annotations, "topic").ok_or( + CodegenError::MissingTopic { + message: scoped_name(scope, &message_def.node.name), + }, + )?; + let mut path = scope.to_vec(); + path.push(message_def.node.name.clone()); + let members = self.registry.collect_message_members(&path); + let fields: Vec = members + .iter() + .map(|member| { + let ty = self.cpp_type(&member.ty, scope); + let name = &member.name; + quote!( $ty $name{}; ) + }) + .collect(); + let serialize_fields: Vec = members + .iter() + .map(|member| self.serialize_value(&member.name, &member.ty, scope, quote!(writer))) + .collect(); + let deserialize_fields: Vec = members + .iter() + .map(|member| { + let var_name = format!("value_{}", member.name); + let read_expr = self.deserialize_value(&member.ty, scope, quote!(reader)); + quote! { + auto $var_name = $read_expr; + } + }) + .collect(); + let assignments: Vec = members + .iter() + .map(|member| { + let var_name = format!("value_{}", member.name); + let field = &member.name; + quote!( result.$field = $var_name; ) + }) + .collect(); + + let name = &message_def.node.name; + let mut schema = scope.join("."); + if !schema.is_empty() { + schema.push('.'); + } + schema.push_str(name); + let topic_literal = quoted_string(&topic); + let schema_literal = quoted_string(&schema); - quote! { - class $class_ident_ref { - public: + Ok(quote! { + struct $name { static constexpr std::uint16_t kModuleKey = $module_key; static constexpr std::uint16_t kMessageKey = $message_key; - static constexpr std::size_t kStructSize = $struct_size; - static constexpr std::size_t kPaddedSize = $padded_size; - static constexpr std::size_t kPayloadWords = $payload_words; static constexpr std::string_view kTopicTemplate = $topic_literal; + static constexpr std::string_view kSchema = $schema_literal; - $class_ident_ref() = default; - $(if !fields.is_empty() { - explicit $class_ident_ref($ctor_params) : $ctor_inits {} - }) + $( + for field in &fields => + $field + ) - $accessors + [[nodiscard]] std::vector to_payload() const { + CdrWriter writer; + serialize(writer); + return writer.buffer; + } - [[nodiscard]] std::array Serialize() const { - $serialize_body + void serialize(CdrWriter &writer) const { + $(for s in &serialize_fields => $s$['\r']) } - [[nodiscard]] static $class_ident_ref Deserialize(std::span payload) { - if (payload.size() < kStructSize) { throw std::runtime_error("payload shorter than struct"); } - $deserialize_body + [[nodiscard]] static $name from_payload(std::span payload) { + CdrReader reader(payload); + return read(reader); } - [[nodiscard]] static $class_ident_ref FromFrame(std::span frame) { - auto header = MessageHeader::Parse(frame); - const auto payload_len = static_cast(header.payload_words) * 4u; - if (frame.size() < MessageHeader::kSize + payload_len) { throw std::runtime_error("frame shorter than payload"); } - auto payload = frame.subspan(MessageHeader::kSize, payload_len); - return Deserialize(payload); + [[nodiscard]] static $name read(CdrReader &reader) { + $(for d in &deserialize_fields => $d$['\r']) + + $name result{}; + $(for a in &assignments join ($['\r']) => $a) + return result; } - template - void Publish(Session &session, std::string_view device_type, std::string_view nid) const { - auto payload = Serialize(); - MessageHeader header{static_cast(kPayloadWords), 0, kModuleKey, kMessageKey}; - auto header_bytes = header.Pack(); - std::array frame{}; - std::memcpy(frame.data(), header_bytes.data(), header_bytes.size()); - if constexpr (kPaddedSize > 0) { std::memcpy(frame.data() + MessageHeader::kSize, payload.data(), payload.size()); } - session.put(FormatTopic(device_type, nid), std::span(frame.data(), frame.size())); + [[nodiscard]] static std::string topic() { + return format_topic(std::string(kTopicTemplate)); } + }; + }) + } + + fn emit_const(&self, const_def: &Commented, scope: &[String]) -> Tokens { + let name = &const_def.node.name; + let ty = self.cpp_type(&const_def.node.const_type, scope); + let value = const_literal(&const_def.node.value); + quote!( inline constexpr $ty $name = $value; ) + } + + fn emit_typedef(&self, typedef_def: &Commented, scope: &[String]) -> Tokens { + let name = &typedef_def.node.name; + let base = self.cpp_type(&typedef_def.node.base_type, scope); + quote!( using $name = $base; ) + } - template - static void Subscribe(Session &session, std::string_view device_type, std::string_view nid, Callback &&callback) { - session.subscribe(FormatTopic(device_type, nid), [cb = std::forward(callback)](std::span frame) { cb(FromFrame(frame)); }); + fn cpp_type(&self, ty: &Type, scope: &[String]) -> Tokens { + match self.registry.resolve_type(ty, scope) { + Type::Float => quote!(float), + Type::Double => quote!(double), + Type::LongDouble => quote!(long double), + Type::Long => quote!(std::int32_t), + Type::UnsignedLong => quote!(std::uint32_t), + Type::LongLong => quote!(std::int64_t), + Type::UnsignedLongLong => quote!(std::uint64_t), + Type::Short => quote!(std::int16_t), + Type::UnsignedShort => quote!(std::uint16_t), + Type::Octet => quote!(std::uint8_t), + Type::Boolean => quote!(bool), + Type::Char => quote!(char), + Type::String { .. } => quote!(std::string), + Type::Sequence { element_type, .. } => { + let inner = self.cpp_type(&element_type, scope); + quote!(std::vector<$inner>) + } + Type::Array { + element_type, + dimensions, + } => { + let mut elem = *element_type; + for _ in dimensions { + elem = Type::Sequence { + element_type: Box::new(elem), + size: None, + }; } + self.cpp_type(&elem, scope) + } + Type::ScopedName(path) => { + let scoped = absolute_path(&path); + quote!( $scoped ) + } + Type::WString | Type::WChar => quote!(std::u16string), + } + } - private: - static std::string FormatTopic(std::string_view device_type, std::string_view nid) { - std::string topic = std::string(kTopicTemplate); - replace_placeholder(topic, "{{device_type}}", device_type); - replace_placeholder(topic, "{{nid}}", nid); - return topic; + fn serialize_value(&self, name: &str, ty: &Type, scope: &[String], writer: Tokens) -> Tokens { + match self.registry.resolve_type(ty, scope) { + Type::Boolean => quote!( $writer.write_bool($name); ), + Type::Char => quote!( $writer.write_char($name); ), + Type::Octet => quote!( $writer.write_u8($name); ), + Type::Short => quote!( $writer.write_i16($name); ), + Type::UnsignedShort => quote!( $writer.write_u16($name); ), + Type::Long => quote!( $writer.write_i32($name); ), + Type::UnsignedLong => quote!( $writer.write_u32($name); ), + Type::LongLong => quote!( $writer.write_i64($name); ), + Type::UnsignedLongLong => quote!( $writer.write_u64($name); ), + Type::Float => quote!( $writer.write_f32($name); ), + Type::Double => quote!( $writer.write_f64($name); ), + Type::LongDouble => quote!( $writer.write_f64(static_cast($name)); ), + Type::String { .. } => quote!( $writer.write_string($name); ), + Type::Sequence { element_type, .. } => { + let inner = self.serialize_value("item", &element_type, scope, quote!(writer)); + quote! { + $writer.write_sequence($name, [](CdrWriter &writer, const auto &item) { + $inner + }); + } + } + Type::Array { + element_type, + dimensions, + } => { + let mut elem = *element_type; + for _ in dimensions { + elem = Type::Sequence { + element_type: Box::new(elem), + size: None, + }; + } + self.serialize_value(name, &elem, scope, writer) + } + Type::ScopedName(path) => { + if let Some(base) = self.registry.enum_base(&path) { + let base_ty = self.cpp_type(base, scope); + let write_fn = match base { + Type::UnsignedShort => quote!(write_u16), + Type::Short => quote!(write_i16), + Type::UnsignedLong => quote!(write_u32), + Type::Long => quote!(write_i32), + Type::UnsignedLongLong => quote!(write_u64), + Type::LongLong => quote!(write_i64), + Type::Octet => quote!(write_u8), + Type::Boolean => quote!(write_bool), + Type::Char => quote!(write_char), + Type::Float => quote!(write_f32), + Type::Double => quote!(write_f64), + _ => quote!(write_u32), + }; + quote! { + $writer.$write_fn(static_cast<$base_ty>($name)); + } + } else { + quote!( $name.serialize($writer); ) } + } + Type::WString | Type::WChar => quote!( /* unsupported wide string */ ), + } + } - $members - }; + fn deserialize_value(&self, ty: &Type, scope: &[String], reader: Tokens) -> Tokens { + match self.registry.resolve_type(ty, scope) { + Type::Boolean => quote!( $reader.read_bool() ), + Type::Char => quote!( $reader.read_char() ), + Type::Octet => quote!( $reader.read_u8() ), + Type::Short => quote!( $reader.read_i16() ), + Type::UnsignedShort => quote!( $reader.read_u16() ), + Type::Long => quote!( $reader.read_i32() ), + Type::UnsignedLong => quote!( $reader.read_u32() ), + Type::LongLong => quote!( $reader.read_i64() ), + Type::UnsignedLongLong => quote!( $reader.read_u64() ), + Type::Float => quote!( $reader.read_f32() ), + Type::Double => quote!( $reader.read_f64() ), + Type::LongDouble => quote!( static_cast($reader.read_f64()) ), + Type::String { .. } => quote!( $reader.read_string() ), + Type::Sequence { element_type, .. } => { + let inner = self.deserialize_value(&element_type, scope, quote!(reader)); + quote! { + $reader.read_sequence([&](CdrReader &reader) { + return $inner; + }) + } + } + Type::Array { + element_type, + dimensions, + } => { + let mut elem = *element_type; + for _ in dimensions { + elem = Type::Sequence { + element_type: Box::new(elem), + size: None, + }; + } + self.deserialize_value(&elem, scope, reader) + } + Type::ScopedName(path) => { + if let Some(base) = self.registry.enum_base(&path) { + let read_expr = match base { + Type::UnsignedShort => quote!(read_u16), + Type::Short => quote!(read_i16), + Type::UnsignedLong => quote!(read_u32), + Type::Long => quote!(read_i32), + Type::UnsignedLongLong => quote!(read_u64), + Type::LongLong => quote!(read_i64), + Type::Octet => quote!(read_u8), + Type::Boolean => quote!(read_bool), + Type::Char => quote!(read_char), + Type::Float => quote!(read_f32), + Type::Double => quote!(read_f64), + _ => quote!(read_u32), + }; + let scoped = absolute_path(&path); + quote! { + static_cast<$scoped>($reader.$read_expr()) + } + } else { + let scoped = absolute_path(&path); + quote!( $scoped::read($reader) ) + } + } + Type::WString | Type::WChar => quote!(std::u16string()), } } } fn helpers_tokens() -> Tokens { quote! { - inline void write_u16(std::uint16_t value, std::uint8_t *out) { + inline void write_u16_raw(std::uint16_t value, std::uint8_t *out) { out[0] = static_cast(value & 0xffu); out[1] = static_cast((value >> 8) & 0xffu); } - inline std::uint16_t read_u16(const std::uint8_t *in) { + inline std::uint16_t read_u16_raw(const std::uint8_t *in) { return static_cast(static_cast(in[0]) | static_cast(in[1] << 8)); } - inline void write_u32(std::uint32_t value, std::uint8_t *out) { + inline void write_u32_raw(std::uint32_t value, std::uint8_t *out) { for (int i = 0; i < 4; ++i) { out[i] = static_cast((value >> (8 * i)) & 0xffu); } } - inline std::uint32_t read_u32(const std::uint8_t *in) { + inline std::uint32_t read_u32_raw(const std::uint8_t *in) { return static_cast(in[0]) | (static_cast(in[1]) << 8) | (static_cast(in[2]) << 16) | (static_cast(in[3]) << 24); } - inline void write_u64(std::uint64_t value, std::uint8_t *out) { + inline void write_u64_raw(std::uint64_t value, std::uint8_t *out) { for (int i = 0; i < 8; ++i) { out[i] = static_cast((value >> (8 * i)) & 0xffu); } } - inline std::uint64_t read_u64(const std::uint8_t *in) { + inline std::uint64_t read_u64_raw(const std::uint8_t *in) { std::uint64_t result = 0; for (int i = 0; i < 8; ++i) { result |= static_cast(in[i]) << (8 * i); @@ -229,157 +530,343 @@ fn helpers_tokens() -> Tokens { return result; } - inline void write_bool(bool value, std::uint8_t *out) { out[0] = value ? 1u : 0u; } - inline bool read_bool(const std::uint8_t *in) { return in[0] != 0; } - inline void write_char(char value, std::uint8_t *out) { out[0] = static_cast(value); } - inline char read_char(const std::uint8_t *in) { return static_cast(in[0]); } - inline void write_u8(std::uint8_t value, std::uint8_t *out) { out[0] = value; } - inline std::uint8_t read_u8(const std::uint8_t *in) { return in[0]; } - inline void write_i16(std::int16_t value, std::uint8_t *out) { write_u16(static_cast(value), out); } - inline std::int16_t read_i16(const std::uint8_t *in) { return static_cast(read_u16(in)); } - inline void write_u16_val(std::uint16_t value, std::uint8_t *out) { write_u16(value, out); } - inline std::uint16_t read_u16_val(const std::uint8_t *in) { return read_u16(in); } - inline void write_i32(std::int32_t value, std::uint8_t *out) { write_u32(static_cast(value), out); } - inline std::int32_t read_i32(const std::uint8_t *in) { return static_cast(read_u32(in)); } - inline void write_u32_val(std::uint32_t value, std::uint8_t *out) { write_u32(value, out); } - inline std::uint32_t read_u32_val(const std::uint8_t *in) { return read_u32(in); } - inline void write_i64(std::int64_t value, std::uint8_t *out) { write_u64(static_cast(value), out); } - inline std::int64_t read_i64(const std::uint8_t *in) { return static_cast(read_u64(in)); } - inline void write_u64_val(std::uint64_t value, std::uint8_t *out) { write_u64(value, out); } - inline std::uint64_t read_u64_val(const std::uint8_t *in) { return read_u64(in); } - - inline void write_f32(float value, std::uint8_t *out) { - std::uint32_t bits; - std::memcpy(&bits, &value, sizeof(bits)); - write_u32(bits, out); - } + struct CdrWriter { + std::vector buffer; - inline float read_f32(const std::uint8_t *in) { - std::uint32_t bits = read_u32(in); - float value; - std::memcpy(&value, &bits, sizeof(value)); - return value; - } + void align(std::size_t alignment) { + if (alignment <= 1) { return; } + const auto pad = (alignment - (buffer.size() % alignment)) % alignment; + buffer.insert(buffer.end(), pad, 0); + } - inline void write_f64(double value, std::uint8_t *out) { - std::uint64_t bits; - std::memcpy(&bits, &value, sizeof(bits)); - write_u64(bits, out); - } + void write_bool(bool value) { + align(1); + buffer.push_back(value ? 1u : 0u); + } - inline double read_f64(const std::uint8_t *in) { - std::uint64_t bits = read_u64(in); - double value; - std::memcpy(&value, &bits, sizeof(value)); - return value; - } + void write_char(char value) { + align(1); + buffer.push_back(static_cast(value)); + } - inline void replace_placeholder(std::string &topic, std::string_view placeholder, std::string_view value) { - const std::string needle(placeholder); - const std::string replacement(value); - std::size_t pos = topic.find(needle); - while (pos != std::string::npos) { - topic.replace(pos, needle.size(), replacement); - pos = topic.find(needle, pos + replacement.size()); + void write_u8(std::uint8_t value) { + align(1); + buffer.push_back(value); } - } - } -} -fn message_header() -> Tokens { - quote! { - struct MessageHeader { - std::uint16_t payload_words{}; - std::uint16_t flags{}; - std::uint16_t module_key{}; - std::uint16_t message_key{}; - - static constexpr std::size_t kSize = 8; - - [[nodiscard]] std::array Pack() const { - std::array data{}; - write_u16(payload_words, data.data()); - write_u16(flags, data.data() + 2); - write_u16(module_key, data.data() + 4); - write_u16(message_key, data.data() + 6); - return data; - } - - [[nodiscard]] static MessageHeader Parse(std::span frame) { - if (frame.size() < static_cast(kSize)) { - throw std::runtime_error("frame shorter than header"); + void write_i16(std::int16_t value) { + align(2); + std::array bytes{}; + write_u16_raw(static_cast(value), bytes.data()); + buffer.insert(buffer.end(), bytes.begin(), bytes.end()); + } + + void write_u16(std::uint16_t value) { + align(2); + std::array bytes{}; + write_u16_raw(value, bytes.data()); + buffer.insert(buffer.end(), bytes.begin(), bytes.end()); + } + + void write_i32(std::int32_t value) { + align(4); + std::array bytes{}; + write_u32_raw(static_cast(value), bytes.data()); + buffer.insert(buffer.end(), bytes.begin(), bytes.end()); + } + + void write_u32(std::uint32_t value) { + align(4); + std::array bytes{}; + write_u32_raw(value, bytes.data()); + buffer.insert(buffer.end(), bytes.begin(), bytes.end()); + } + + void write_i64(std::int64_t value) { + align(8); + std::array bytes{}; + write_u64_raw(static_cast(value), bytes.data()); + buffer.insert(buffer.end(), bytes.begin(), bytes.end()); + } + + void write_u64(std::uint64_t value) { + align(8); + std::array bytes{}; + write_u64_raw(value, bytes.data()); + buffer.insert(buffer.end(), bytes.begin(), bytes.end()); + } + + void write_f32(float value) { + align(4); + std::array bytes{}; + std::uint32_t bits; + std::memcpy(&bits, &value, sizeof(bits)); + write_u32_raw(bits, bytes.data()); + buffer.insert(buffer.end(), bytes.begin(), bytes.end()); + } + + void write_f64(double value) { + align(8); + std::array bytes{}; + std::uint64_t bits; + std::memcpy(&bits, &value, sizeof(bits)); + write_u64_raw(bits, bytes.data()); + buffer.insert(buffer.end(), bytes.begin(), bytes.end()); + } + + void write_string(std::string_view value) { + align(4); + const auto len = static_cast(value.size() + 1); + write_u32(len); + buffer.insert(buffer.end(), value.begin(), value.end()); + buffer.push_back(0); + } + + template + void write_sequence(const std::vector &values, WriteFn &&write_fn) { + align(4); + write_u32(static_cast(values.size())); + for (const auto &item : values) { + write_fn(*this, item); + } + } + }; + + struct CdrReader { + std::span data; + std::size_t offset{0}; + + explicit CdrReader(std::span d) : data(d) {} + + void align(std::size_t alignment) { + if (alignment <= 1) { return; } + const auto pad = (alignment - (offset % alignment)) % alignment; + if (offset + pad > data.size()) { throw std::runtime_error("cursor out of range"); } + offset += pad; + } + + std::span take(std::size_t size) { + if (offset + size > data.size()) { throw std::runtime_error("cursor out of range"); } + auto slice = data.subspan(offset, size); + offset += size; + return slice; + } + + bool read_bool() { + align(1); + auto bytes = take(1); + return bytes[0] != 0; + } + + char read_char() { + align(1); + auto bytes = take(1); + return static_cast(bytes[0]); + } + + std::uint8_t read_u8() { + align(1); + auto bytes = take(1); + return bytes[0]; + } + + std::int16_t read_i16() { + align(2); + auto bytes = take(2); + return static_cast(read_u16_raw(bytes.data())); + } + + std::uint16_t read_u16() { + align(2); + auto bytes = take(2); + return read_u16_raw(bytes.data()); + } + + std::int32_t read_i32() { + align(4); + auto bytes = take(4); + return static_cast(read_u32_raw(bytes.data())); + } + + std::uint32_t read_u32() { + align(4); + auto bytes = take(4); + return read_u32_raw(bytes.data()); + } + + std::int64_t read_i64() { + align(8); + auto bytes = take(8); + return static_cast(read_u64_raw(bytes.data())); + } + + std::uint64_t read_u64() { + align(8); + auto bytes = take(8); + return read_u64_raw(bytes.data()); + } + + float read_f32() { + align(4); + auto bytes = take(4); + std::uint32_t bits = read_u32_raw(bytes.data()); + float value; + std::memcpy(&value, &bits, sizeof(value)); + return value; + } + + double read_f64() { + align(8); + auto bytes = take(8); + std::uint64_t bits = read_u64_raw(bytes.data()); + double value; + std::memcpy(&value, &bits, sizeof(value)); + return value; + } + + std::string read_string() { + align(4); + auto len = read_u32(); + if (len == 0) { return std::string(); } + auto bytes = take(len); + if (bytes[len - 1] != 0) { throw std::runtime_error("CDR string missing null terminator"); } + return std::string(reinterpret_cast(bytes.data()), len - 1); + } + + template + auto read_sequence(ReadFn &&read_fn) { + align(4); + auto len = read_u32(); + using Elem = std::invoke_result_t; + std::vector values; + values.reserve(len); + for (std::uint32_t i = 0; i < len; ++i) { + values.push_back(read_fn(*this)); } - MessageHeader header{}; - header.payload_words = read_u16(frame.data()); - header.flags = read_u16(frame.data() + 2); - header.module_key = read_u16(frame.data() + 4); - header.message_key = read_u16(frame.data() + 6); - return header; + return values; } }; + + inline std::string format_topic(std::string_view template_) { + return std::string(template_); + } } } -fn cpp_field_type(primitive: PrimitiveType) -> Tokens { - match primitive { - PrimitiveType::Bool => quote!(bool), - PrimitiveType::Char => quote!(char), - PrimitiveType::Octet => quote!(std::uint8_t), - PrimitiveType::I16 => quote!(std::int16_t), - PrimitiveType::U16 => quote!(std::uint16_t), - PrimitiveType::I32 => quote!(std::int32_t), - PrimitiveType::U32 => quote!(std::uint32_t), - PrimitiveType::I64 => quote!(std::int64_t), - PrimitiveType::U64 => quote!(std::uint64_t), - PrimitiveType::F32 => quote!(float), - PrimitiveType::F64 => quote!(double), +fn const_literal(value: &ConstValue) -> Tokens { + match value { + ConstValue::Integer(lit) => quote!( $(lit.value) ), + ConstValue::Float(v) => quote!( $(format!("{v}")) ), + ConstValue::Fixed(f) => { + let mut digits = f.digits.clone(); + if f.scale > 0 { + let point = digits.len().saturating_sub(f.scale as usize); + digits.insert(point, '.'); + } + if f.negative { + digits.insert(0, '-'); + } + quote!( $(digits) ) + } + ConstValue::Binary(bin) => { + let value = bin.to_i64(); + quote!( $value ) + } + ConstValue::String(s) => quote!( $(quoted_string(s)) ), + ConstValue::Boolean(value) => { + if *value { + quote!(true) + } else { + quote!(false) + } + } + ConstValue::Char(ch) => quote!( $(format!("'{}'", ch)) ), + ConstValue::ScopedName(path) => quote!( $(path.join("::")) ), + ConstValue::UnaryOp { .. } | ConstValue::BinaryOp { .. } => quote!(0), } } -fn cpp_default(primitive: PrimitiveType) -> Tokens { - match primitive { - PrimitiveType::Bool => quote!(false), - PrimitiveType::Char => quote!('\0'), - PrimitiveType::Octet - | PrimitiveType::I16 - | PrimitiveType::U16 - | PrimitiveType::I32 - | PrimitiveType::U32 - | PrimitiveType::I64 - | PrimitiveType::U64 => quote!(0), - PrimitiveType::F32 => quote!(0.0F), - PrimitiveType::F64 => quote!(0.0), +fn annotation_value<'a>(annotations: &'a [Annotation], name: &str) -> Option<&'a ConstValue> { + annotations + .iter() + .find(|annotation| { + annotation + .name + .last() + .map(|segment| segment.eq_ignore_ascii_case(name)) + .unwrap_or(false) + }) + .and_then(|annotation| { + annotation + .params + .iter() + .map(|param| match param { + AnnotationParam::Named { name, value } + if name.eq_ignore_ascii_case("value") => + { + value + } + AnnotationParam::Positional(value) => value, + AnnotationParam::Named { value, .. } => value, + }) + .next() + }) +} + +fn annotation_string(annotations: &[Annotation], name: &str) -> Option { + annotation_value(annotations, name).and_then(|value| match value { + ConstValue::String(value) => Some(value.clone()), + _ => None, + }) +} + +fn annotation_u16(annotations: &[Annotation], name: &str) -> Option { + annotation_value(annotations, name).and_then(|value| match value { + ConstValue::Integer(lit) if (0..=u16::MAX as i64).contains(&lit.value) => { + Some(lit.value as u16) + } + _ => None, + }) +} + +fn scoped_name(scope: &[String], name: &str) -> String { + if scope.is_empty() { + name.to_string() + } else { + format!("{}::{}", scope.join("::"), name) } } -fn cpp_writer(primitive: PrimitiveType) -> Tokens { - match primitive { - PrimitiveType::Bool => quote!(write_bool), - PrimitiveType::Char => quote!(write_char), - PrimitiveType::Octet => quote!(write_u8), - PrimitiveType::I16 => quote!(write_i16), - PrimitiveType::U16 => quote!(write_u16_val), - PrimitiveType::I32 => quote!(write_i32), - PrimitiveType::U32 => quote!(write_u32_val), - PrimitiveType::I64 => quote!(write_i64), - PrimitiveType::U64 => quote!(write_u64_val), - PrimitiveType::F32 => quote!(write_f32), - PrimitiveType::F64 => quote!(write_f64), +fn absolute_path(path: &[String]) -> String { + format!("::blueberry_generated::{}", path.join("::")) +} + +fn quoted_string(value: &str) -> String { + let mut escaped = String::new(); + for ch in value.chars() { + match ch { + '"' => escaped.push_str("\\\""), + '\\' => escaped.push_str("\\\\"), + '\n' => escaped.push_str("\\n"), + '\r' => escaped.push_str("\\r"), + '\t' => escaped.push_str("\\t"), + _ => escaped.push(ch), + } } + format!("\"{escaped}\"") } -fn cpp_reader(primitive: PrimitiveType) -> Tokens { - match primitive { - PrimitiveType::Bool => quote!(read_bool), - PrimitiveType::Char => quote!(read_char), - PrimitiveType::Octet => quote!(read_u8), - PrimitiveType::I16 => quote!(read_i16), - PrimitiveType::U16 => quote!(read_u16), - PrimitiveType::I32 => quote!(read_i32), - PrimitiveType::U32 => quote!(read_u32), - PrimitiveType::I64 => quote!(read_i64), - PrimitiveType::U64 => quote!(read_u64), - PrimitiveType::F32 => quote!(read_f32), - PrimitiveType::F64 => quote!(read_f64), +#[derive(Default)] +struct MessageKeyGen { + next: u16, +} + +impl MessageKeyGen { + fn next(&mut self) -> u16 { + self.next = self.next.wrapping_add(1); + if self.next == 0 { + self.next = 1; + } + self.next } } diff --git a/crates/generators/python/src/lib.rs b/crates/generators/python/src/lib.rs index ce3555e..0b63dd7 100644 --- a/crates/generators/python/src/lib.rs +++ b/crates/generators/python/src/lib.rs @@ -1,167 +1,549 @@ -use blueberry_ast::Definition; -use blueberry_codegen_core::{ - CodegenError, GeneratedFile, MessageSpec, PrimitiveType, class_name, collect_messages, - quoted_string, +use blueberry_ast::{ + Annotation, AnnotationParam, Commented, ConstDef, ConstValue, Definition, EnumDef, MessageDef, + StructDef, Type, TypeDef, }; +use blueberry_codegen_core::{CodegenError, GeneratedFile, TypeRegistry}; use genco::lang::python::Tokens; use genco::quote; const OUTPUT_PATH: &str = "python/messages.py"; pub fn generate(definitions: &[Definition]) -> Result, CodegenError> { - let messages = collect_messages(definitions)?; - if messages.is_empty() { - return Ok(Vec::new()); - } - let generator = PythonGenerator::new(&messages); - let contents = generator.render(); + let generator = PythonGenerator::new(definitions); + let contents = generator.render(definitions)?; Ok(vec![GeneratedFile { path: OUTPUT_PATH.to_string(), contents, }]) } -struct PythonGenerator<'a> { - messages: &'a [MessageSpec], +struct PythonGenerator { + registry: TypeRegistry, } -impl<'a> PythonGenerator<'a> { - fn new(messages: &'a [MessageSpec]) -> Self { - Self { messages } +impl PythonGenerator { + fn new(definitions: &[Definition]) -> Self { + Self { + registry: TypeRegistry::new(definitions), + } } - fn render(&self) -> String { - let classes: Vec = self.messages.iter().map(|m| self.emit_message(m)).collect(); - let header = message_header(); + fn render(&self, definitions: &[Definition]) -> Result { + let defs = self.emit_definitions(definitions, &mut Vec::new())?; + let helpers = helpers_tokens(); let tokens: Tokens = quote! { # Auto-generated Blueberry bindings from __future__ import annotations from dataclasses import dataclass + import enum import struct - from typing import ClassVar + from typing import ClassVar, Callable, List - $header + $helpers - $(for class_def in &classes => - $class_def + $(for def in defs => + $def ) }; - tokens.to_file_string().expect("render python output") + Ok(tokens.to_file_string().expect("render python output")) } - fn emit_message(&self, message: &MessageSpec) -> Tokens { - let class_ident = class_name(&message.scope, &message.name); - let topic_literal = quoted_string(&message.topic); - let struct_format = quoted_string(&format!( - "<{}", - message - .fields - .iter() - .map(|field| python_struct_code(field.primitive)) - .collect::() - )); - let module_key = message.module_key; - let message_key = message.message_key; - let struct_size = message.field_payload_size(); - let padded_size = message.padded_payload_size(); - let payload_words = message.payload_words(); - let field_defs: Tokens = quote! { - $(for field in &message.fields => - $(field.name.clone()): $(python_type_hint(field.primitive))$['\n'] + fn emit_definitions( + &self, + defs: &[Definition], + scope: &mut Vec, + ) -> Result, CodegenError> { + let mut out = Vec::new(); + let mut message_keys = MessageKeyGen::default(); + + for def in defs { + match def { + Definition::ModuleDef(module) => { + scope.push(module.node.name.clone()); + let nested = self.emit_definitions(&module.node.definitions, scope)?; + out.extend(nested); + scope.pop(); + } + Definition::EnumDef(enum_def) => { + out.push(quote! { $['\n'] }); + out.push(self.emit_enum(enum_def, scope)); + } + Definition::StructDef(struct_def) => { + out.push(quote! { $['\n'] }); + out.push(self.emit_struct(struct_def, scope)); + } + Definition::MessageDef(message_def) => { + let module_key = + annotation_u16(&message_def.annotations, "module_key").unwrap_or(0x4242); + let message_key = annotation_u16(&message_def.annotations, "message_key") + .unwrap_or_else(|| message_keys.next()); + out.push(quote! { $['\n'] }); + out.push(self.emit_message(message_def, scope, module_key, message_key)?); + } + Definition::ConstDef(const_def) => { + out.push(quote! { $['\r'] }); + out.push(self.emit_const(const_def, scope)); + } + Definition::TypeDef(typedef_def) => { + out.push(quote! { $['\r'] }); + out.push(self.emit_typedef(typedef_def, scope)); + } + Definition::ImportDef(_) => {} + } + } + + Ok(out) + } + + fn emit_enum(&self, enum_def: &Commented, scope: &[String]) -> Tokens { + let name = class_name(scope, &enum_def.node.name); + let members: Vec = enum_def + .node + .enumerators + .iter() + .enumerate() + .map(|(idx, member)| { + let value = member + .value + .as_ref() + .and_then(|v| match v { + ConstValue::Integer(lit) => Some(lit.value), + ConstValue::Binary(bin) => Some(bin.to_i64()), + _ => None, + }) + .unwrap_or(idx as i64); + quote!( $(member.name.clone()) = $(value) ) + }) + .collect(); + + quote! { + class $name(enum.IntEnum): + $(for m in members => + $m + ) + } + } + + fn emit_struct(&self, struct_def: &Commented, scope: &[String]) -> Tokens { + let mut path = scope.to_vec(); + path.push(struct_def.node.name.clone()); + let members = self.registry.collect_struct_members(&path); + let name = class_name(scope, &struct_def.node.name); + let name_ref = name.as_str(); + let fields: Tokens = quote! { + $(for member in &members => + $(member.name.clone()): $(python_type_hint(&member.ty, scope, &self.registry))$['\r'] ) }; - - let payload_methods = self.payload_tokens(message, &class_ident); - let frame_helpers = frame_helper_tokens(&class_ident); - let publish_helpers = publish_helper_tokens(&class_ident); + let serialize_fields: Vec = members + .iter() + .map(|member| self.serialize_value(&member.name, &member.ty, scope)) + .collect(); + let deserialize_fields: Vec = members + .iter() + .map(|member| { + let var_name = format!("value_{}", member.name); + let read_expr = self.deserialize_value(&member.ty, scope); + quote!( $var_name = $read_expr ) + }) + .collect(); + let assignments: Vec = members + .iter() + .map(|member| { + let var_name = format!("value_{}", member.name); + let field = &member.name; + quote!( $field=$var_name ) + }) + .collect(); quote! { @dataclass - class $class_ident: - topic_template: ClassVar[str] = $topic_literal - module_key: ClassVar[int] = $module_key - message_key: ClassVar[int] = $message_key - STRUCT_FORMAT: ClassVar[str] = $struct_format - STRUCT_SIZE: ClassVar[int] = $struct_size - PADDED_SIZE: ClassVar[int] = $padded_size - PAYLOAD_WORDS: ClassVar[int] = $payload_words + class $name_ref: + $fields - $field_defs + def to_payload(self) -> bytes: + writer = CdrWriter() + self.serialize(writer) + return bytes(writer.buffer) - $payload_methods + def serialize(self, writer: "CdrWriter") -> None: + $(for s in &serialize_fields =>$s$['\r']) - $frame_helpers + @classmethod + def from_payload(cls, payload: bytes) -> $name_ref: + reader = CdrReader(payload) + return cls.read(reader) - $publish_helpers + @classmethod + def read(cls, reader: "CdrReader") -> $name_ref: + $(for d in &deserialize_fields => $d$['\r']) + return cls($(for a in &assignments join (, ) => $a)) } } - fn payload_tokens(&self, message: &MessageSpec, class_ident: &str) -> Tokens { - if message.fields.is_empty() { - return quote! { + fn emit_message( + &self, + message_def: &Commented, + scope: &[String], + module_key: u16, + message_key: u16, + ) -> Result { + let topic = annotation_string(&message_def.annotations, "topic").ok_or( + CodegenError::MissingTopic { + message: scoped_name(scope, &message_def.node.name), + }, + )?; + let mut path = scope.to_vec(); + path.push(message_def.node.name.clone()); + let members = self.registry.collect_message_members(&path); + let name = class_name(scope, &message_def.node.name); + let name_ref = name.as_str(); + let fields: Tokens = quote! { + $(for member in &members => + $(member.name.clone()): $(python_type_hint(&member.ty, scope, &self.registry))$['\r'] + ) + }; + let serialize_fields: Vec = members + .iter() + .map(|member| self.serialize_value(&member.name, &member.ty, scope)) + .collect(); + let deserialize_fields: Vec = members + .iter() + .map(|member| { + let var_name = format!("value_{}", member.name); + let read_expr = self.deserialize_value(&member.ty, scope); + quote!( $var_name = $read_expr ) + }) + .collect(); + let assignments: Vec = members + .iter() + .map(|member| { + let var_name = format!("value_{}", member.name); + let field = &member.name; + quote!( $field=$var_name ) + }) + .collect(); + let topic_literal = quoted_string(&topic); + + Ok(quote! { + @dataclass + class $name_ref: + topic_template: ClassVar[str] = $topic_literal + module_key: ClassVar[int] = $module_key + message_key: ClassVar[int] = $message_key + + $fields + def to_payload(self) -> bytes: - return b"".ljust(self.PADDED_SIZE, b"\x00") + writer = CdrWriter() + self.serialize(writer) + return bytes(writer.buffer) + + def serialize(self, writer: "CdrWriter") -> None: + $(for s in &serialize_fields => $s$['\r']) @classmethod - def from_payload(cls, payload: bytes) -> $class_ident: - return cls() - }; - } + def from_payload(cls, payload: bytes) -> $name_ref: + reader = CdrReader(payload) + return cls.read(reader) - let pack_args: Tokens = - quote! { $(for field in &message.fields join (, ) => self.$(field.name.clone())) }; - let value_inits: Tokens = quote! { - $(for (idx, field) in message.fields.iter().enumerate() join (, ) => - $(field.name.clone()) = values[$idx] - ) - }; + @classmethod + def read(cls, reader: "CdrReader") -> $name_ref: + $(for d in &deserialize_fields => $d$['\r']) + return cls($(for a in &assignments join (, ) => $a)) - quote! { - def to_payload(self) -> bytes: - payload = struct.pack(self.STRUCT_FORMAT, $pack_args) - if len(payload) < self.PADDED_SIZE: - payload += b"\x00" * (self.PADDED_SIZE - len(payload)) - return payload + @classmethod + def from_frame(cls, frame: bytes) -> $name_ref: + header = MessageHeader.parse(frame) + payload_len = header.payload_words * 4 + payload = frame[MessageHeader.HEADER_LEN:MessageHeader.HEADER_LEN + payload_len] + return cls.from_payload(payload) - @classmethod - def from_payload(cls, payload: bytes) -> $class_ident: - values = struct.unpack(cls.STRUCT_FORMAT, payload[: cls.STRUCT_SIZE]) - return cls($value_inits) - } + @classmethod + def topic(cls, device_type: str, nid: str) -> str: + return cls.topic_template.format(device_type=device_type, nid=nid) + + @staticmethod + def subscribe(session, device_type: str, nid: str, callback) -> None: + topic = $name_ref.topic(device_type, nid) + session.subscribe(topic, lambda frame: callback($name_ref.from_frame(frame))) + + def publish(self, session, device_type: str, nid: str) -> None: + topic = self.topic(device_type, nid) + payload = self.to_payload() + header = MessageHeader(CdrWriter.payload_words_for(payload), 0, self.module_key, self.message_key) + session.put(topic, header.pack() + payload) + + @staticmethod + def payload_words(payload: bytes) -> int: + return CdrWriter.payload_words_for(payload) + }) } -} -fn frame_helper_tokens(class_ident: &str) -> Tokens { - quote! { - @classmethod - def from_frame(cls, frame: bytes) -> $class_ident: - header = MessageHeader.parse(frame) - payload_len = header.payload_words * 4 - payload = frame[MessageHeader.HEADER_LEN:MessageHeader.HEADER_LEN + payload_len] - return cls.from_payload(payload) + fn emit_const(&self, const_def: &Commented, scope: &[String]) -> Tokens { + let name = class_name(scope, &const_def.node.name); + let value = const_literal(&const_def.node.value); + quote!( $name = $value ) } -} -fn publish_helper_tokens(class_ident: &str) -> Tokens { - quote! { - @staticmethod - def subscribe(session, device_type: str, nid: str, callback) -> None: - topic = $class_ident.topic_template.format(device_type=device_type, nid=nid) - session.subscribe(topic, lambda frame: callback($class_ident.from_frame(frame))) + fn emit_typedef(&self, typedef_def: &Commented, scope: &[String]) -> Tokens { + let name = class_name(scope, &typedef_def.node.name); + let base = python_type_hint(&typedef_def.node.base_type, scope, &self.registry); + quote!( $name = $base ) + } - def publish(self, session, device_type: str, nid: str) -> None: - topic = self.topic_template.format(device_type=device_type, nid=nid) - payload = self.to_payload() - header = MessageHeader(self.PAYLOAD_WORDS, 0, self.module_key, self.message_key) - session.put(topic, header.pack() + payload) + fn serialize_value(&self, name: &str, ty: &Type, scope: &[String]) -> Tokens { + match self.registry.resolve_type(ty, scope) { + Type::Boolean => quote!( writer.write_bool(self.$name) ), + Type::Char => quote!( writer.write_char(self.$name) ), + Type::Octet => quote!( writer.write_u8(self.$name) ), + Type::Short => quote!( writer.write_i16(self.$name) ), + Type::UnsignedShort => quote!( writer.write_u16(self.$name) ), + Type::Long => quote!( writer.write_i32(self.$name) ), + Type::UnsignedLong => quote!( writer.write_u32(self.$name) ), + Type::LongLong => quote!( writer.write_i64(self.$name) ), + Type::UnsignedLongLong => quote!( writer.write_u64(self.$name) ), + Type::Float => quote!( writer.write_f32(self.$name) ), + Type::Double => quote!( writer.write_f64(self.$name) ), + Type::String { .. } => quote!( writer.write_string(self.$name) ), + Type::Sequence { element_type, .. } => { + let inner = self.serialize_value("item", &element_type, scope); + quote! { + writer.write_sequence(self.$name, lambda w, item: ($inner)) + } + } + Type::Array { + element_type, + dimensions, + } => { + let mut elem = *element_type; + for _ in dimensions { + elem = Type::Sequence { + element_type: Box::new(elem), + size: None, + }; + } + self.serialize_value(name, &elem, scope) + } + Type::ScopedName(path) => { + if let Some(base) = self.registry.enum_base(&path) { + let write_fn = writer_for(base); + quote!( writer.$write_fn(int(self.$name)) ) + } else { + quote!( self.$name.serialize(writer) ) + } + } + Type::WString | Type::WChar | Type::LongDouble => quote!(None), + } + } + + fn deserialize_value(&self, ty: &Type, scope: &[String]) -> Tokens { + match self.registry.resolve_type(ty, scope) { + Type::Boolean => quote!(reader.read_bool()), + Type::Char => quote!(reader.read_char()), + Type::Octet => quote!(reader.read_u8()), + Type::Short => quote!(reader.read_i16()), + Type::UnsignedShort => quote!(reader.read_u16()), + Type::Long => quote!(reader.read_i32()), + Type::UnsignedLong => quote!(reader.read_u32()), + Type::LongLong => quote!(reader.read_i64()), + Type::UnsignedLongLong => quote!(reader.read_u64()), + Type::Float => quote!(reader.read_f32()), + Type::Double => quote!(reader.read_f64()), + Type::String { .. } => quote!(reader.read_string()), + Type::Sequence { element_type, .. } => { + let inner = self.deserialize_value(&element_type, scope); + quote! { + reader.read_sequence(lambda r: ($inner)) + } + } + Type::Array { + element_type, + dimensions, + } => { + let mut elem = *element_type; + for _ in dimensions { + elem = Type::Sequence { + element_type: Box::new(elem), + size: None, + }; + } + self.deserialize_value(&elem, scope) + } + Type::ScopedName(path) => { + if let Some(base) = self.registry.enum_base(&path) { + let read_fn = reader_for(base); + let scoped = class_name_path(&path); + quote!( $scoped(reader.$read_fn()) ) + } else { + let scoped = class_name_path(&path); + quote!( $scoped.read(reader) ) + } + } + Type::WString | Type::WChar | Type::LongDouble => quote!(None), + } } } -fn message_header() -> Tokens { +fn helpers_tokens() -> Tokens { quote! { + class CdrWriter: + def __init__(self) -> None: + self.buffer: bytearray = bytearray() + + @staticmethod + def payload_words_for(payload: bytes) -> int: + return (len(payload) + 3) // 4 + + def align(self, alignment: int) -> None: + if alignment <= 1: + return + pad = (alignment - (len(self.buffer) % alignment)) % alignment + if pad: + self.buffer.extend(b"\x00" * pad) + + def write_bool(self, value: bool) -> None: + self.align(1) + self.buffer.append(1 if value else 0) + + def write_char(self, value: str) -> None: + self.align(1) + b = value.encode("latin1") if value else b"\x00" + self.buffer.extend(b[:1].ljust(1, b"\x00")) + + def write_u8(self, value: int) -> None: + self.align(1) + self.buffer.extend(struct.pack(" None: + self.align(2) + self.buffer.extend(struct.pack(" None: + self.align(2) + self.buffer.extend(struct.pack(" None: + self.align(4) + self.buffer.extend(struct.pack(" None: + self.align(4) + self.buffer.extend(struct.pack(" None: + self.align(8) + self.buffer.extend(struct.pack(" None: + self.align(8) + self.buffer.extend(struct.pack(" None: + self.align(4) + self.buffer.extend(struct.pack(" None: + self.align(8) + self.buffer.extend(struct.pack(" None: + self.align(4) + data = value.encode("utf-8") + b"\x00" + self.write_u32(len(data)) + self.buffer.extend(data) + + def write_sequence(self, values, write_fn: Callable[["CdrWriter", object], None]) -> None: + self.align(4) + self.write_u32(len(values)) + for item in values: + write_fn(self, item) + + class CdrReader: + def __init__(self, data: bytes) -> None: + self.data = data + self.offset = 0 + + def align(self, alignment: int) -> None: + if alignment <= 1: + return + pad = (alignment - (self.offset % alignment)) % alignment + if self.offset + pad > len(self.data): + raise RuntimeError("cursor out of range") + self.offset += pad + + def take(self, size: int) -> bytes: + if self.offset + size > len(self.data): + raise RuntimeError("cursor out of range") + start = self.offset + self.offset += size + return self.data[start:self.offset] + + def read_bool(self) -> bool: + self.align(1) + return bool(self.take(1)[0]) + + def read_char(self) -> str: + self.align(1) + return self.take(1).decode("latin1") + + def read_u8(self) -> int: + self.align(1) + return struct.unpack(" int: + self.align(2) + return struct.unpack(" int: + self.align(2) + return struct.unpack(" int: + self.align(4) + return struct.unpack(" int: + self.align(4) + return struct.unpack(" int: + self.align(8) + return struct.unpack(" int: + self.align(8) + return struct.unpack(" float: + self.align(4) + return struct.unpack(" float: + self.align(8) + return struct.unpack(" str: + self.align(4) + length = self.read_u32() + if length == 0: + return "" + data = self.take(length) + if data[-1] != 0: + raise RuntimeError("CDR string missing null terminator") + return data[:-1].decode("utf-8") + + def read_sequence(self, read_fn: Callable[["CdrReader"], object]): + self.align(4) + length = self.read_u32() + values = [] + for _ in range(length): + values.append(read_fn(self)) + return values + @dataclass class MessageHeader: payload_words: int @@ -181,31 +563,199 @@ fn message_header() -> Tokens { } } -fn python_type_hint(primitive: PrimitiveType) -> Tokens { - match primitive { - PrimitiveType::Bool => quote!(bool), - PrimitiveType::Char => quote!(str), - PrimitiveType::Octet - | PrimitiveType::I16 - | PrimitiveType::U16 - | PrimitiveType::I32 - | PrimitiveType::U32 - | PrimitiveType::I64 - | PrimitiveType::U64 => quote!(int), - PrimitiveType::F32 | PrimitiveType::F64 => quote!(float), - } -} - -fn python_struct_code(primitive: PrimitiveType) -> &'static str { - match primitive { - PrimitiveType::Bool | PrimitiveType::Octet | PrimitiveType::Char => "B", - PrimitiveType::I16 => "h", - PrimitiveType::U16 => "H", - PrimitiveType::I32 => "i", - PrimitiveType::U32 => "I", - PrimitiveType::I64 => "q", - PrimitiveType::U64 => "Q", - PrimitiveType::F32 => "f", - PrimitiveType::F64 => "d", +fn writer_for(ty: &Type) -> Tokens { + match ty { + Type::UnsignedShort => quote!(write_u16), + Type::Short => quote!(write_i16), + Type::UnsignedLong => quote!(write_u32), + Type::Long => quote!(write_i32), + Type::UnsignedLongLong => quote!(write_u64), + Type::LongLong => quote!(write_i64), + Type::Octet => quote!(write_u8), + Type::Boolean => quote!(write_bool), + Type::Char => quote!(write_char), + Type::Float => quote!(write_f32), + Type::Double => quote!(write_f64), + _ => quote!(write_u32), + } +} + +fn reader_for(ty: &Type) -> Tokens { + match ty { + Type::UnsignedShort => quote!(read_u16), + Type::Short => quote!(read_i16), + Type::UnsignedLong => quote!(read_u32), + Type::Long => quote!(read_i32), + Type::UnsignedLongLong => quote!(read_u64), + Type::LongLong => quote!(read_i64), + Type::Octet => quote!(read_u8), + Type::Boolean => quote!(read_bool), + Type::Char => quote!(read_char), + Type::Float => quote!(read_f32), + Type::Double => quote!(read_f64), + _ => quote!(read_u32), + } +} + +fn annotation_value<'a>(annotations: &'a [Annotation], name: &str) -> Option<&'a ConstValue> { + annotations + .iter() + .find(|annotation| { + annotation + .name + .last() + .map(|segment| segment.eq_ignore_ascii_case(name)) + .unwrap_or(false) + }) + .and_then(|annotation| { + annotation + .params + .iter() + .map(|param| match param { + AnnotationParam::Named { name, value } + if name.eq_ignore_ascii_case("value") => + { + Some(value) + } + AnnotationParam::Positional(value) => Some(value), + AnnotationParam::Named { value, .. } => Some(value), + }) + .next() + .flatten() + }) +} + +fn annotation_string(annotations: &[Annotation], name: &str) -> Option { + annotation_value(annotations, name).and_then(|value| match value { + ConstValue::String(value) => Some(value.clone()), + _ => None, + }) +} + +fn annotation_u16(annotations: &[Annotation], name: &str) -> Option { + annotation_value(annotations, name).and_then(|value| match value { + ConstValue::Integer(lit) if (0..=u16::MAX as i64).contains(&lit.value) => { + Some(lit.value as u16) + } + _ => None, + }) +} + +fn class_name(scope: &[String], name: &str) -> String { + let mut parts = scope.to_vec(); + parts.push(name.to_string()); + parts.join("_") +} + +fn class_name_path(path: &[String]) -> String { + path.join("_") +} + +fn scoped_name(scope: &[String], name: &str) -> String { + if scope.is_empty() { + name.to_string() + } else { + format!("{}::{}", scope.join("::"), name) + } +} + +fn quoted_string(value: &str) -> String { + let mut escaped = String::new(); + for ch in value.chars() { + match ch { + '"' => escaped.push_str("\\\""), + '\\' => escaped.push_str("\\\\"), + '\n' => escaped.push_str("\\n"), + '\r' => escaped.push_str("\\r"), + '\t' => escaped.push_str("\\t"), + _ => escaped.push(ch), + } + } + format!("\"{escaped}\"") +} + +fn python_type_hint(ty: &Type, scope: &[String], registry: &TypeRegistry) -> Tokens { + match registry.resolve_type(ty, scope) { + Type::Boolean => quote!(bool), + Type::Char => quote!(str), + Type::Octet + | Type::Short + | Type::UnsignedShort + | Type::Long + | Type::UnsignedLong + | Type::LongLong + | Type::UnsignedLongLong => quote!(int), + Type::Float | Type::Double => quote!(float), + Type::String { .. } => quote!(str), + Type::Sequence { element_type, .. } => { + let inner = python_type_hint(&element_type, scope, registry); + quote!(List[$inner]) + } + Type::Array { + element_type, + dimensions, + } => { + let mut elem = *element_type; + for _ in dimensions { + elem = Type::Sequence { + element_type: Box::new(elem), + size: None, + }; + } + python_type_hint(&elem, scope, registry) + } + Type::ScopedName(path) => { + let ident = class_name_path(&path); + quote!( $ident ) + } + Type::WString | Type::WChar | Type::LongDouble => quote!(object), + } +} + +fn const_literal(value: &ConstValue) -> Tokens { + match value { + ConstValue::Integer(lit) => quote!( $(lit.value) ), + ConstValue::Float(v) => quote!( $(format!("{v}")) ), + ConstValue::Fixed(f) => { + let mut digits = f.digits.clone(); + if f.scale > 0 { + let point = digits.len().saturating_sub(f.scale as usize); + digits.insert(point, '.'); + } + if f.negative { + digits.insert(0, '-'); + } + quote!( $(digits) ) + } + ConstValue::Binary(bin) => { + let value = bin.to_i64(); + quote!( $value ) + } + ConstValue::String(s) => quote!( $(quoted_string(s)) ), + ConstValue::Boolean(value) => { + if *value { + quote!(True) + } else { + quote!(False) + } + } + ConstValue::Char(ch) => quote!( $(format!("'{}'", ch)) ), + ConstValue::ScopedName(path) => quote!( $(path.join("::")) ), + ConstValue::UnaryOp { .. } | ConstValue::BinaryOp { .. } => quote!(0), + } +} + +#[derive(Default)] +struct MessageKeyGen { + next: u16, +} + +impl MessageKeyGen { + fn next(&mut self) -> u16 { + self.next = self.next.wrapping_add(1); + if self.next == 0 { + self.next = 1; + } + self.next } } diff --git a/crates/generators/rust/src/lib.rs b/crates/generators/rust/src/lib.rs index ecfcfce..a785f63 100644 --- a/crates/generators/rust/src/lib.rs +++ b/crates/generators/rust/src/lib.rs @@ -1,10 +1,10 @@ -use std::{collections::HashMap, fs, io, path::Path}; +use std::{fs, io, path::Path}; use blueberry_ast::{ Annotation, AnnotationParam, Commented, ConstDef, ConstValue, Definition, EnumDef, MessageDef, ModuleDef, StructDef, Type, TypeDef, }; -use blueberry_codegen_core::{CodegenError, GeneratedFile}; +use blueberry_codegen_core::{CodegenError, GeneratedFile, ResolvedMember, TypeRegistry}; use genco::lang::rust::Tokens; use genco::quote; @@ -97,7 +97,7 @@ impl RustGenerator { } fn emit_tests(&self, _definitions: &[Definition]) -> Tokens { - let mut struct_paths: Vec> = self.registry.structs.keys().cloned().collect(); + let mut struct_paths: Vec> = self.registry.struct_paths(); if struct_paths.is_empty() { return Tokens::new(); } @@ -669,249 +669,6 @@ impl RustGenerator { } } -#[derive(Clone)] -struct ResolvedMember { - name: String, - ty: Type, - comments: Vec, -} - -#[derive(Clone)] -struct TypedefInfo { - ty: Type, - scope: Vec, -} - -#[derive(Clone)] -struct StructInfo { - def: StructDef, - scope: Vec, -} - -#[derive(Clone)] -struct MessageInfo { - def: MessageDef, - scope: Vec, -} - -#[derive(Default)] -struct TypeRegistry { - typedefs: HashMap, TypedefInfo>, - structs: HashMap, StructInfo>, - messages: HashMap, MessageInfo>, - enums: HashMap, Type>, -} - -impl TypeRegistry { - fn new(definitions: &[Definition]) -> Self { - let mut registry = TypeRegistry::default(); - let mut scope = Vec::new(); - registry.collect(definitions, &mut scope); - registry - } - - fn collect(&mut self, defs: &[Definition], scope: &mut Vec) { - for def in defs { - match def { - Definition::ModuleDef(module) => { - scope.push(module.node.name.clone()); - self.collect(&module.node.definitions, scope); - scope.pop(); - } - Definition::TypeDef(typedef) => { - let mut path = scope.clone(); - path.push(typedef.node.name.clone()); - self.typedefs.insert( - path, - TypedefInfo { - ty: typedef.node.base_type.clone(), - scope: scope.clone(), - }, - ); - } - Definition::StructDef(struct_def) => { - let mut path = scope.clone(); - path.push(struct_def.node.name.clone()); - self.structs.insert( - path, - StructInfo { - def: struct_def.node.clone(), - scope: scope.clone(), - }, - ); - } - Definition::MessageDef(message_def) => { - let mut path = scope.clone(); - path.push(message_def.node.name.clone()); - self.messages.insert( - path, - MessageInfo { - def: message_def.node.clone(), - scope: scope.clone(), - }, - ); - } - Definition::EnumDef(enum_def) => { - let mut path = scope.clone(); - path.push(enum_def.node.name.clone()); - let repr = enum_def - .node - .base_type - .clone() - .unwrap_or(Type::UnsignedLong); - self.enums.insert(path, repr); - } - Definition::ConstDef(_) | Definition::ImportDef(_) => {} - } - } - } - - fn resolve_type(&self, ty: &Type, scope: &[String]) -> Type { - match ty { - Type::Sequence { element_type, size } => Type::Sequence { - element_type: Box::new(self.resolve_type(element_type, scope)), - size: *size, - }, - Type::Array { - element_type, - dimensions, - } => { - let mut resolved = self.resolve_type(element_type, scope); - for &dim in dimensions.iter().rev() { - resolved = Type::Sequence { - element_type: Box::new(resolved), - size: Some(dim), - }; - } - resolved - } - Type::ScopedName(name) => { - if let [single] = name.as_slice() - && let Some(mapped) = map_builtin_ident(single) - { - return mapped; - } - if let Some(path) = self.resolve_typedef(name, scope) { - let info = self.typedefs.get(&path).expect("typedef info missing"); - self.resolve_type(&info.ty, &info.scope) - } else if let Some(path) = self.resolve_struct(name, scope) { - Type::ScopedName(path) - } else if let Some(path) = self.resolve_message(name, scope) { - Type::ScopedName(path) - } else if let Some(path) = self.resolve_enum(name, scope) { - Type::ScopedName(path) - } else { - Type::ScopedName(name.clone()) - } - } - other => other.clone(), - } - } - - fn collect_struct_members(&self, path: &[String]) -> Vec { - let info = self - .structs - .get(path) - .unwrap_or_else(|| panic!("missing struct {:?}", path)); - let mut members = Vec::new(); - if let Some(base) = &info.def.base - && let Some(base_path) = self.resolve_struct(base, &info.scope) - { - members.extend(self.collect_struct_members(&base_path)); - } - for member in &info.def.members { - let ty = self.resolve_type(&member.node.type_, &info.scope); - members.push(ResolvedMember { - name: member.node.name.clone(), - ty, - comments: member.comments.clone(), - }); - } - members - } - - fn collect_message_members(&self, path: &[String]) -> Vec { - let info = self - .messages - .get(path) - .unwrap_or_else(|| panic!("missing message {:?}", path)); - let mut members = Vec::new(); - if let Some(base) = &info.def.base - && let Some(base_path) = self.resolve_message(base, &info.scope) - { - members.extend(self.collect_message_members(&base_path)); - } - for member in &info.def.members { - let ty = self.resolve_type(&member.node.type_, &info.scope); - members.push(ResolvedMember { - name: member.node.name.clone(), - ty, - comments: member.comments.clone(), - }); - } - members - } - - fn resolve_typedef(&self, name: &[String], scope: &[String]) -> Option> { - self.resolve_path(name, scope, self.typedefs.keys()) - } - - fn resolve_struct(&self, name: &[String], scope: &[String]) -> Option> { - self.resolve_path(name, scope, self.structs.keys()) - } - - fn resolve_message(&self, name: &[String], scope: &[String]) -> Option> { - self.resolve_path(name, scope, self.messages.keys()) - } - - fn resolve_enum(&self, name: &[String], scope: &[String]) -> Option> { - self.resolve_path(name, scope, self.enums.keys()) - } - - fn resolve_path<'a, I>( - &self, - name: &[String], - scope: &[String], - entries: I, - ) -> Option> - where - I: IntoIterator>, - { - let paths: Vec> = entries.into_iter().cloned().collect(); - let lookup: std::collections::HashSet> = paths.iter().cloned().collect(); - for prefix in (0..=scope.len()).rev() { - let mut candidate = scope[..prefix].to_vec(); - candidate.extend_from_slice(name); - if lookup.contains(&candidate) { - return Some(candidate); - } - } - self.resolve_by_suffix(name, &paths) - } - - fn resolve_by_suffix(&self, name: &[String], paths: &[Vec]) -> Option> { - let matches: Vec<&Vec> = paths.iter().filter(|path| path.ends_with(name)).collect(); - if matches.len() == 1 { - return Some(matches[0].clone()); - } - None - } -} - -fn map_builtin_ident(name: &str) -> Option { - match name { - "int8" => Some(Type::Octet), - "int16" => Some(Type::Short), - "int32" => Some(Type::Long), - "int64" => Some(Type::LongLong), - "uint8" => Some(Type::Octet), - "uint16" => Some(Type::UnsignedShort), - "uint32" => Some(Type::UnsignedLong), - "uint64" => Some(Type::UnsignedLongLong), - _ => None, - } -} - #[cfg(test)] mod tests { use super::*;