Skip to content

Commit 30639b6

Browse files
authored
feat: Convert Between MCAP and Bag (#125)
1 parent 6fd1833 commit 30639b6

File tree

12 files changed

+1662
-59
lines changed

12 files changed

+1662
-59
lines changed

benchmarks/writing/test_mcap_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _generate_pybag_odometries(count: int = 1000, seed: int = 0) -> list[Odometr
8383

8484
def _write_with_pybag(path: Path, messages: Iterable) -> None:
8585
writer = McapFileWriter.open(path)
86-
writer.add_channel("/odom", Odometry)
86+
writer.add_channel("/odom", schema=Odometry)
8787
for i, msg in enumerate(messages):
8888
timestamp = int(i * 1_500_000_000)
8989
writer.write_message("/odom", timestamp, msg)

src/pybag/bag_writer.py

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,12 @@
2020
from pybag.encoding.rosmsg import RosMsgEncoder
2121
from pybag.io.raw_writer import BaseWriter, BytesWriter, FileWriter
2222
from pybag.schema.ros1_compiler import compile_ros1_serializer
23-
from pybag.schema.ros1msg import Ros1MsgSchemaEncoder, compute_md5sum
24-
from pybag.types import Message
23+
from pybag.schema.ros1msg import (
24+
Ros1MsgSchemaDecoder,
25+
Ros1MsgSchemaEncoder,
26+
compute_md5sum
27+
)
28+
from pybag.types import Message, SchemaText
2529

2630
logger = logging.getLogger(__name__)
2731

@@ -42,7 +46,7 @@ def __init__(
4246
writer: BaseWriter,
4347
*,
4448
compression: Literal['none', 'bz2'] = 'none',
45-
chunk_size: int = 1024 * 1024, # 1MB default chunk size
49+
chunk_size: int | None = None,
4650
):
4751
"""Initialize the bag writer.
4852
@@ -53,11 +57,12 @@ def __init__(
5357
"""
5458
self._writer = writer
5559
self._compression = compression
56-
self._chunk_size = chunk_size
60+
self._chunk_size = chunk_size or (1024 * 1024) # 1MB
5761
self._record_writer = BagRecordWriter(writer)
5862

59-
# Schema encoder
63+
# Schema encoder and decoder
6064
self._schema_encoder = Ros1MsgSchemaEncoder()
65+
self._schema_decoder = Ros1MsgSchemaDecoder()
6166

6267
# Tracking state
6368
self._next_conn_id = 0
@@ -66,13 +71,17 @@ def __init__(
6671
self._message_types: dict[type[Message], tuple[str, str]] = {} # type -> (msg_def, md5sum)
6772
self._serializers: dict[type[Message], Callable[[Any, Any], None]] = {}
6873

74+
# Pre-compiled serializers for topics with explicit schemas
75+
# Maps topic -> compiled serializer function
76+
self._topic_serializers: dict[str, Callable[[Any, Any], None]] = {}
77+
6978
# Current chunk state
7079
self._chunk_buffer = BytesWriter()
7180
self._chunk_record_writer = BagRecordWriter(self._chunk_buffer)
7281
self._chunk_start_time: int | None = None
7382
self._chunk_end_time: int | None = None
7483
self._chunk_message_counts: dict[int, int] = {}
75-
# Index entries for current chunk: conn_id -> [(time_sec, time_nsec, offset)]
84+
# Index entries for current chunk: conn_id -> [(time, offset)]
7685
self._chunk_index_entries: dict[int, list[tuple[int, int]]] = {}
7786

7887
# Chunk info records (for summary)
@@ -87,7 +96,7 @@ def open(
8796
file_path: str | Path,
8897
*,
8998
compression: Literal['none', 'bz2'] = 'none',
90-
chunk_size: int = 1024 * 1024,
99+
chunk_size: int | None = None,
91100
) -> "BagFileWriter":
92101
"""Create a writer for a file.
93102
@@ -123,11 +132,7 @@ def _write_header(self) -> None:
123132
self._header_pos = self._record_writer.tell()
124133
# Write placeholder header with zeros (will be updated on close)
125134
self._record_writer.write_bag_header(
126-
BagHeaderRecord(
127-
index_pos=0,
128-
conn_count=0,
129-
chunk_count=0,
130-
),
135+
BagHeaderRecord(index_pos=0, conn_count=0, chunk_count=0),
131136
)
132137

133138
def _get_message_info(self, message_type: type[Message]) -> tuple[str, str]:
@@ -161,12 +166,19 @@ def _get_serializer(self, message_type: type[Message]) -> Callable[[Any, Any], N
161166
self._serializers[message_type] = compile_ros1_serializer(schema, sub_schemas)
162167
return self._serializers[message_type]
163168

164-
def add_connection(self, topic: str, message_type: type[Message]) -> int:
169+
def add_connection(
170+
self,
171+
topic: str,
172+
*,
173+
schema: SchemaText | type[Message] | Message,
174+
) -> int:
165175
"""Add a connection (topic) to the bag file.
166176
167177
Args:
168178
topic: The topic name.
169-
message_type: The message type class.
179+
schema: A SchemaText object containing the message type name and
180+
schema definition text, or a message class/instance to
181+
generate the schema from.
170182
171183
Returns:
172184
The connection ID.
@@ -175,22 +187,38 @@ def add_connection(self, topic: str, message_type: type[Message]) -> int:
175187
if topic in self._topics:
176188
return self._topics[topic]
177189

190+
# Convert message class or instance to SchemaText
191+
if isinstance(schema, type) and hasattr(schema, '__msg_name__'):
192+
schema = SchemaText(
193+
name=schema.__msg_name__,
194+
text=self._schema_encoder.encode(schema).decode('utf-8'),
195+
)
196+
elif isinstance(schema, Message):
197+
schema_type = type(schema)
198+
schema = SchemaText(
199+
name=schema_type.__msg_name__,
200+
text=self._schema_encoder.encode(schema_type).decode('utf-8'),
201+
)
202+
178203
conn_id = self._next_conn_id
179204
self._next_conn_id += 1
180205

181-
msg_def, md5sum = self._get_message_info(message_type)
182-
msg_type = message_type.__msg_name__
206+
# Use provided schema text directly
207+
msg_def = schema.text
208+
msg_type_name = schema.name
209+
md5sum = compute_md5sum(msg_def, msg_type_name)
183210

184211
# Build the connection data (connection header fields)
185212
data_buffer = BytesWriter()
186213
# Two topic fields exist (in the record and connection headers).
187214
# This is because messages can be written to the bag file on a topic different
188215
# from where they were originally published
189216
data_buffer.write(self._encode_header_field('topic', topic.encode('utf-8')))
190-
data_buffer.write(self._encode_header_field('type', msg_type.encode('utf-8')))
217+
data_buffer.write(self._encode_header_field('type', msg_type_name.encode('utf-8')))
191218
data_buffer.write(self._encode_header_field('md5sum', md5sum.encode('ascii')))
192219
data_buffer.write(self._encode_header_field('message_definition', msg_def.encode('utf-8')))
193220

221+
# TODO: Add checks to see if previous topic exists
194222
connection = ConnectionRecord(
195223
conn=conn_id,
196224
topic=topic,
@@ -202,6 +230,12 @@ def add_connection(self, topic: str, message_type: type[Message]) -> int:
202230
# Write connection record to current chunk
203231
self._chunk_record_writer.write_connection(connection)
204232

233+
# If explicit schema was provided, compile and store a serializer for this topic
234+
# This allows us to serialize messages without relying on type annotations
235+
parsed_schema, sub_schemas = self._schema_decoder.parse_schema(connection)
236+
serializer = compile_ros1_serializer(parsed_schema, sub_schemas)
237+
self._topic_serializers[topic] = serializer
238+
205239
return conn_id
206240

207241
def write_message(
@@ -212,15 +246,25 @@ def write_message(
212246
) -> None:
213247
"""Write a message to the bag file.
214248
249+
Automatically creates the connection (and schema) if it doesn't exist.
250+
If the connection was pre-registered with add_connection(), uses that schema.
251+
215252
Args:
216253
topic: The topic name.
217254
timestamp: The timestamp in nanoseconds since epoch.
218255
message: The message to write.
219256
"""
220257
message_type = type(message)
221258

222-
# Ensure connection exists
223-
conn_id = self.add_connection(topic, message_type)
259+
# Check if connection already exists (may have been pre-registered)
260+
if topic in self._topics:
261+
conn_id = self._topics[topic]
262+
else:
263+
# Auto-create connection from message type
264+
conn_id = self.add_connection(topic, schema=SchemaText(
265+
name=message_type.__msg_name__,
266+
text=self._schema_encoder.encode(message_type).decode('utf-8'),
267+
))
224268

225269
# Update chunk time bounds
226270
if self._chunk_start_time is None:
@@ -234,7 +278,7 @@ def write_message(
234278
msg_offset = self._chunk_buffer.size()
235279

236280
# Serialize the message
237-
serializer = self._get_serializer(message_type)
281+
serializer = self._topic_serializers[topic]
238282
encoder = RosMsgEncoder()
239283
serializer(encoder, message)
240284
data = encoder.save()

src/pybag/cli/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
22

33
from pybag.cli import (
4+
mcap_convert,
45
mcap_filter,
56
mcap_info,
67
mcap_merge,
@@ -23,6 +24,7 @@ def build_parser() -> argparse.ArgumentParser:
2324
subparsers = parser.add_subparsers(dest="command")
2425

2526
# TODO: Have some sort of entrypoint registration?
27+
mcap_convert.add_parser(subparsers)
2628
mcap_filter.add_parser(subparsers)
2729
mcap_merge.add_parser(subparsers)
2830
mcap_info.add_parser(subparsers)

0 commit comments

Comments
 (0)