2020from pybag .encoding .rosmsg import RosMsgEncoder
2121from pybag .io .raw_writer import BaseWriter , BytesWriter , FileWriter
2222from pybag .schema .ros1_compiler import compile_ros1_serializer
23- from pybag .schema .ros1msg import Ros1MsgSchemaEncoder , compute_md5sum
24- from pybag .types import Message
23+ from pybag .schema .ros1msg import (
24+ Ros1MsgSchemaDecoder ,
25+ Ros1MsgSchemaEncoder ,
26+ compute_md5sum
27+ )
28+ from pybag .types import Message , SchemaText
2529
2630logger = logging .getLogger (__name__ )
2731
@@ -42,7 +46,7 @@ def __init__(
4246 writer : BaseWriter ,
4347 * ,
4448 compression : Literal ['none' , 'bz2' ] = 'none' ,
45- chunk_size : int = 1024 * 1024 , # 1MB default chunk size
49+ chunk_size : int | None = None ,
4650 ):
4751 """Initialize the bag writer.
4852
@@ -53,11 +57,12 @@ def __init__(
5357 """
5458 self ._writer = writer
5559 self ._compression = compression
56- self ._chunk_size = chunk_size
60+ self ._chunk_size = chunk_size or ( 1024 * 1024 ) # 1MB
5761 self ._record_writer = BagRecordWriter (writer )
5862
59- # Schema encoder
63+ # Schema encoder and decoder
6064 self ._schema_encoder = Ros1MsgSchemaEncoder ()
65+ self ._schema_decoder = Ros1MsgSchemaDecoder ()
6166
6267 # Tracking state
6368 self ._next_conn_id = 0
@@ -66,13 +71,17 @@ def __init__(
6671 self ._message_types : dict [type [Message ], tuple [str , str ]] = {} # type -> (msg_def, md5sum)
6772 self ._serializers : dict [type [Message ], Callable [[Any , Any ], None ]] = {}
6873
74+ # Pre-compiled serializers, one per registered topic (populated by add_connection)
75+ # Maps topic -> compiled serializer function
76+ self ._topic_serializers : dict [str , Callable [[Any , Any ], None ]] = {}
77+
6978 # Current chunk state
7079 self ._chunk_buffer = BytesWriter ()
7180 self ._chunk_record_writer = BagRecordWriter (self ._chunk_buffer )
7281 self ._chunk_start_time : int | None = None
7382 self ._chunk_end_time : int | None = None
7483 self ._chunk_message_counts : dict [int , int ] = {}
75- # Index entries for current chunk: conn_id -> [(time_sec, time_nsec , offset)]
84+ # Index entries for current chunk: conn_id -> [(time , offset)]
7685 self ._chunk_index_entries : dict [int , list [tuple [int , int ]]] = {}
7786
7887 # Chunk info records (for summary)
@@ -87,7 +96,7 @@ def open(
8796 file_path : str | Path ,
8897 * ,
8998 compression : Literal ['none' , 'bz2' ] = 'none' ,
90- chunk_size : int = 1024 * 1024 ,
99+ chunk_size : int | None = None ,
91100 ) -> "BagFileWriter" :
92101 """Create a writer for a file.
93102
@@ -123,11 +132,7 @@ def _write_header(self) -> None:
123132 self ._header_pos = self ._record_writer .tell ()
124133 # Write placeholder header with zeros (will be updated on close)
125134 self ._record_writer .write_bag_header (
126- BagHeaderRecord (
127- index_pos = 0 ,
128- conn_count = 0 ,
129- chunk_count = 0 ,
130- ),
135+ BagHeaderRecord (index_pos = 0 , conn_count = 0 , chunk_count = 0 ),
131136 )
132137
133138 def _get_message_info (self , message_type : type [Message ]) -> tuple [str , str ]:
@@ -161,12 +166,19 @@ def _get_serializer(self, message_type: type[Message]) -> Callable[[Any, Any], N
161166 self ._serializers [message_type ] = compile_ros1_serializer (schema , sub_schemas )
162167 return self ._serializers [message_type ]
163168
164- def add_connection (self , topic : str , message_type : type [Message ]) -> int :
169+ def add_connection (
170+ self ,
171+ topic : str ,
172+ * ,
173+ schema : SchemaText | type [Message ] | Message ,
174+ ) -> int :
165175 """Add a connection (topic) to the bag file.
166176
167177 Args:
168178 topic: The topic name.
169- message_type: The message type class.
179+ schema: A SchemaText object containing the message type name and
180+ schema definition text, or a message class/instance to
181+ generate the schema from.
170182
171183 Returns:
172184 The connection ID.
@@ -175,22 +187,38 @@ def add_connection(self, topic: str, message_type: type[Message]) -> int:
175187 if topic in self ._topics :
176188 return self ._topics [topic ]
177189
190+ # Convert message class or instance to SchemaText
191+ if isinstance (schema , type ) and hasattr (schema , '__msg_name__' ):
192+ schema = SchemaText (
193+ name = schema .__msg_name__ ,
194+ text = self ._schema_encoder .encode (schema ).decode ('utf-8' ),
195+ )
196+ elif isinstance (schema , Message ):
197+ schema_type = type (schema )
198+ schema = SchemaText (
199+ name = schema_type .__msg_name__ ,
200+ text = self ._schema_encoder .encode (schema_type ).decode ('utf-8' ),
201+ )
202+
178203 conn_id = self ._next_conn_id
179204 self ._next_conn_id += 1
180205
181- msg_def , md5sum = self ._get_message_info (message_type )
182- msg_type = message_type .__msg_name__
206+ # Use provided schema text directly
207+ msg_def = schema .text
208+ msg_type_name = schema .name
209+ md5sum = compute_md5sum (msg_def , msg_type_name )
183210
184211 # Build the connection data (connection header fields)
185212 data_buffer = BytesWriter ()
186213 # Two topic fields exist (in the record and connection headers).
187214 # This is because messages can be written to the bag file on a topic different
188215 # from where they were originally published
189216 data_buffer .write (self ._encode_header_field ('topic' , topic .encode ('utf-8' )))
190- data_buffer .write (self ._encode_header_field ('type' , msg_type .encode ('utf-8' )))
217+ data_buffer .write (self ._encode_header_field ('type' , msg_type_name .encode ('utf-8' )))
191218 data_buffer .write (self ._encode_header_field ('md5sum' , md5sum .encode ('ascii' )))
192219 data_buffer .write (self ._encode_header_field ('message_definition' , msg_def .encode ('utf-8' )))
193220
221+ # TODO: Add checks to see if previous topic exists
194222 connection = ConnectionRecord (
195223 conn = conn_id ,
196224 topic = topic ,
@@ -202,6 +230,12 @@ def add_connection(self, topic: str, message_type: type[Message]) -> int:
202230 # Write connection record to current chunk
203231 self ._chunk_record_writer .write_connection (connection )
204232
233+ # Always compile and store a serializer for this topic from the connection's schema text.
234+ # This allows us to serialize messages without relying on type annotations
235+ parsed_schema , sub_schemas = self ._schema_decoder .parse_schema (connection )
236+ serializer = compile_ros1_serializer (parsed_schema , sub_schemas )
237+ self ._topic_serializers [topic ] = serializer
238+
205239 return conn_id
206240
207241 def write_message (
@@ -212,15 +246,25 @@ def write_message(
212246 ) -> None :
213247 """Write a message to the bag file.
214248
249+ Automatically creates the connection (and schema) if it doesn't exist.
250+ If the connection was pre-registered with add_connection(), uses that schema.
251+
215252 Args:
216253 topic: The topic name.
217254 timestamp: The timestamp in nanoseconds since epoch.
218255 message: The message to write.
219256 """
220257 message_type = type (message )
221258
222- # Ensure connection exists
223- conn_id = self .add_connection (topic , message_type )
259+ # Check if connection already exists (may have been pre-registered)
260+ if topic in self ._topics :
261+ conn_id = self ._topics [topic ]
262+ else :
263+ # Auto-create connection from message type
264+ conn_id = self .add_connection (topic , schema = SchemaText (
265+ name = message_type .__msg_name__ ,
266+ text = self ._schema_encoder .encode (message_type ).decode ('utf-8' ),
267+ ))
224268
225269 # Update chunk time bounds
226270 if self ._chunk_start_time is None :
@@ -234,7 +278,7 @@ def write_message(
234278 msg_offset = self ._chunk_buffer .size ()
235279
236280 # Serialize the message
237- serializer = self ._get_serializer ( message_type )
281+ serializer = self ._topic_serializers [ topic ]
238282 encoder = RosMsgEncoder ()
239283 serializer (encoder , message )
240284 data = encoder .save ()
0 commit comments