Skip to content

Commit 07f5b24

Browse files
committed
chore: More Ai
1 parent 23013f9 commit 07f5b24

File tree

5 files changed

+759
-19
lines changed

5 files changed

+759
-19
lines changed

src/pybag/bag_writer.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,13 @@ def __init__(
6868
self._chunk_end_time_sec: int | None = None
6969
self._chunk_end_time_nsec: int | None = None
7070
self._chunk_message_counts: dict[int, int] = {}
71+
# Index entries for current chunk: conn_id -> [(time_sec, time_nsec, offset)]
72+
self._chunk_index_entries: dict[int, list[tuple[int, int, int]]] = {}
7173

7274
# Chunk info records (for summary)
7375
self._chunk_infos: list[ChunkInfoRecord] = []
76+
# Index data for all chunks: list of (conn_id, entries) per chunk
77+
self._all_index_data: list[list[tuple[int, list[tuple[int, int, int]]]]] = []
7478

7579
# Write initial file structure
7680
self._write_header()
@@ -216,6 +220,9 @@ def write_message(
216220
# Track message count per connection
217221
self._chunk_message_counts[conn_id] = self._chunk_message_counts.get(conn_id, 0) + 1
218222

223+
# Record the offset within the chunk buffer before writing
224+
msg_offset = self._chunk_buffer.size()
225+
219226
# Write message to chunk buffer
220227
msg_record = MessageDataRecord(
221228
conn=conn_id,
@@ -225,6 +232,11 @@ def write_message(
225232
)
226233
self._chunk_record_writer.write_message_data(msg_record)
227234

235+
# Track index entry for this message
236+
if conn_id not in self._chunk_index_entries:
237+
self._chunk_index_entries[conn_id] = []
238+
self._chunk_index_entries[conn_id].append((time_sec, time_nsec, msg_offset))
239+
228240
# Check if we should flush the chunk
229241
if self._chunk_buffer.size() >= self._chunk_size:
230242
self._flush_chunk()
@@ -254,22 +266,34 @@ def _flush_chunk(self) -> None:
254266
)
255267
self._chunk_infos.append(chunk_info)
256268

269+
# Save index entries for this chunk
270+
chunk_index_data: list[tuple[int, list[tuple[int, int, int]]]] = []
271+
for conn_id, entries in self._chunk_index_entries.items():
272+
chunk_index_data.append((conn_id, list(entries)))
273+
self._all_index_data.append(chunk_index_data)
274+
257275
# Reset chunk state
258276
self._chunk_buffer.clear()
259277
self._chunk_start_time_sec = None
260278
self._chunk_start_time_nsec = None
261279
self._chunk_end_time_sec = None
262280
self._chunk_end_time_nsec = None
263281
self._chunk_message_counts.clear()
282+
self._chunk_index_entries.clear()
264283

265284
def close(self) -> None:
266285
"""Finalize and close the bag file."""
267286
# Flush any remaining chunk data
268287
self._flush_chunk()
269288

270-
# Record the index position
289+
# Record the index position (where index data, connections and chunk infos start)
271290
index_pos = self._record_writer.tell()
272291

292+
# Write INDEX_DATA records for each chunk
293+
for chunk_index_data in self._all_index_data:
294+
for conn_id, entries in chunk_index_data:
295+
self._record_writer.write_index_data(conn_id, entries)
296+
273297
# Write all connection records
274298
for conn in self._connections.values():
275299
self._record_writer.write_connection(conn)
@@ -278,11 +302,13 @@ def close(self) -> None:
278302
for chunk_info in self._chunk_infos:
279303
self._record_writer.write_chunk_info(chunk_info)
280304

281-
# Update the bag header with correct values
282-
# We need to seek back and rewrite it
283-
# For simplicity, we'll just note that proper implementation would
284-
# seek back to header_pos and rewrite with correct values
285-
# This is a limitation of the simple writer approach
305+
# Seek back to the header position and rewrite with correct values
306+
self._writer.seek_from_start(self._header_pos)
307+
self._record_writer.write_bag_header(
308+
index_pos=index_pos,
309+
conn_count=len(self._connections),
310+
chunk_count=len(self._chunk_infos),
311+
)
286312

287313
self._record_writer.close()
288314

src/pybag/io/raw_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def close(self) -> None:
5757
class FileWriter(BaseWriter):
5858
"""Write binary data to a file."""
5959

60-
def __init__(self, file_path: Path | str, mode: str = "w+b"):
    """Open *file_path* for writing.

    The default mode is "w+b" (read/write binary) so the bag writer can
    seek back and rewrite the header once the index position is known.
    """
    resolved = Path(file_path).absolute()
    self._file_path = resolved
    self._file = open(resolved, mode)
6363

src/pybag/schema/ros1msg.py

Lines changed: 137 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -411,26 +411,151 @@ def parse_schema(self, schema: Message | type[Message]) -> tuple[Schema, dict[st
411411
def compute_md5sum(message_definition: str, msg_type: str) -> str:
    """Compute the MD5 hash for a ROS 1 message definition.

    The hash is taken over the canonical MD5 text of the main message:
    comments removed and whitespace normalized, constants first in their
    original order as "type name=value", builtin fields as "type name",
    and complex field types replaced by the MD5 of their own (recursively
    canonicalized) definitions.

    Args:
        message_definition: The full message definition text (may include
            embedded sub-message definitions separated by 80 '=' characters).
        msg_type: The message type name (e.g., 'std_msgs/Header').

    Returns:
        The 32-character hexadecimal MD5 hash.
    """
    # Embedded sub-message definitions are needed to hash complex fields.
    sub_definitions = _parse_sub_message_definitions(message_definition)

    # The main message is everything before the first 80-'=' separator.
    main_definition, *_ = message_definition.split('=' * 80)

    canonical = _compute_md5_text(main_definition.strip(), msg_type, sub_definitions)
    return hashlib.md5(canonical.encode('utf-8')).hexdigest()
438+
439+
440+
def _parse_sub_message_definitions(message_definition: str) -> dict[str, str]:
    """Extract embedded sub-message definitions from a full definition.

    ROS 1 connection headers concatenate dependent message definitions
    after the main one; each is introduced by a separator line of 80 '='
    characters followed by a 'MSG: <type>' header line.

    Args:
        message_definition: The full message definition text.

    Returns:
        Mapping from message type name to its definition body.
    """
    separator = '=' * 80
    definitions: dict[str, str] = {}

    # Everything before the first separator is the main message; skip it.
    for section in message_definition.split(separator)[1:]:
        section = section.strip()
        if not section:
            continue

        header, _, body = section.partition('\n')
        header = header.strip()

        # Only sections announced with 'MSG: <type>' carry a definition.
        if header.startswith('MSG: '):
            definitions[header[len('MSG: '):].strip()] = body.strip()

    return definitions
472+
473+
# ROS 1 builtin field types. 'time' and 'duration' are builtin in ROS 1
# (unlike ROS 2, where they are ordinary message types).
_ROS1_BUILTIN_TYPES = {
    'bool', 'byte', 'char',
    'int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32', 'int64', 'uint64',
    'float32', 'float64',
    'string',
    'time', 'duration',
}


def _is_builtin_type(type_name: str) -> bool:
    """Check if a type (optionally with a trailing array suffix) is a ROS 1 builtin."""
    return re.sub(r'\[.*\]$', '', type_name) in _ROS1_BUILTIN_TYPES


def _compute_md5_text(
    msg_def: str,
    msg_type: str,
    sub_msg_defs: dict[str, str]
) -> str:
    """Compute the canonical MD5 text for a message definition.

    Follows the ROS 1 genmsg algorithm: constants come first, in their
    original order, canonicalized as "type name=value"; then fields as
    "type name" for builtins (array suffix kept), or "<md5-of-subtype>
    name" for complex types (array suffix dropped).

    Args:
        msg_def: The message definition (just fields, no embedded types).
        msg_type: The message type name.
        sub_msg_defs: Dictionary of sub-message type -> definition.

    Returns:
        The canonical text to hash for MD5 computation.
    """
    package = msg_type.split('/')[0] if '/' in msg_type else ''

    constants: list[str] = []
    fields: list[str] = []

    for raw_line in msg_def.split('\n'):
        # Strip a trailing comment; string constants are special-cased
        # below because their value keeps everything after '='.
        line = raw_line.split('#', 1)[0].strip()

        if '=' in line:
            # Constant: canonicalize to "type name=value". ROS hashes this
            # exact form regardless of whitespace around '=' in the source.
            decl, _, value = line.partition('=')
            decl_parts = decl.split()
            if len(decl_parts) >= 2:
                const_type, const_name = decl_parts[0], decl_parts[1]
                if const_type == 'string':
                    # For string constants, '#' does not start a comment:
                    # the value is the remainder of the raw line, verbatim.
                    value = raw_line.split('=', 1)[1]
                constants.append(f"{const_type} {const_name}={value.strip()}")
            continue

        if not line:
            continue

        # Field line: TYPE NAME
        parts = line.split()
        if len(parts) < 2:
            continue
        field_type, field_name = parts[0], parts[1]

        if _is_builtin_type(field_type):
            # Builtin type: keep the declared type, array suffix included.
            fields.append(f"{field_type} {field_name}")
            continue

        # Complex type: the MD5 of the sub-message replaces the type name
        # (array notation is dropped). Resolve relative names first.
        bare_type = re.sub(r'\[.*\]$', '', field_type)
        if '/' in bare_type:
            full_type = bare_type
        elif bare_type == 'Header':
            # 'Header' is special-cased to std_msgs/Header in ROS 1.
            full_type = 'std_msgs/Header'
        else:
            full_type = f"{package}/{bare_type}"

        sub_def = sub_msg_defs.get(full_type, '')
        if not sub_def and full_type == 'std_msgs/Header':
            # Fall back to the well-known Header definition.
            sub_def = "uint32 seq\ntime stamp\nstring frame_id"

        # Recursively canonicalize and hash the sub-message.
        sub_text = _compute_md5_text(sub_def, full_type, sub_msg_defs)
        sub_md5_hash = hashlib.md5(sub_text.encode('utf-8')).hexdigest()
        fields.append(f"{sub_md5_hash} {field_name}")

    # Combine: constants first (original order), then fields.
    return '\n'.join(constants + fields)

0 commit comments

Comments
 (0)