Skip to content

Commit e406aab

Browse files
committed
chore: Tidy
1 parent f9c599a commit e406aab

File tree

8 files changed

+135
-141
lines changed

8 files changed

+135
-141
lines changed

src/pybag/bag/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
BagRecordType,
1212
ChunkInfoRecord,
1313
ChunkRecord,
14+
ConnectionHeader,
1415
ConnectionRecord,
1516
IndexDataRecord,
1617
MessageDataRecord
@@ -22,6 +23,7 @@
2223
'BagRecordType',
2324
'ChunkInfoRecord',
2425
'ChunkRecord',
26+
'ConnectionHeader',
2527
'ConnectionRecord',
2628
'IndexDataRecord',
2729
'MessageDataRecord',

src/pybag/bag/record_writer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def _write_record(
7676
result += self._writer.write(data)
7777
return result
7878

79+
# TODO: Make API consistent (i.e. take BadHeader object)
7980
def write_bag_header(
8081
self,
8182
index_pos: int,

src/pybag/bag_reader.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@
1515
ChunkInfoRecord,
1616
ChunkRecord,
1717
ConnectionRecord,
18-
IndexDataRecord,
1918
MessageDataRecord
2019
)
21-
from pybag.encoding.rosmsg import RosmsgDecoder
20+
from pybag.encoding.rosmsg import RosMsgDecoder
2221
from pybag.io.raw_reader import BaseReader, BytesReader, FileReader
2322
from pybag.schema.ros1_compiler import compile_ros1_schema
2423
from pybag.schema.ros1msg import Ros1MsgSchemaDecoder
2524

2625
logger = logging.getLogger(__name__)
2726

27+
# TODO: Do not load all messages at once
2828

2929
@dataclass(slots=True)
3030
class DecodedMessage:
@@ -59,7 +59,6 @@ def __init__(self, reader: BaseReader):
5959
self._bag_header: BagHeaderRecord | None = None
6060
self._connections: dict[int, ConnectionRecord] = {}
6161
self._chunk_infos: list[ChunkInfoRecord] = []
62-
self._message_count: int = 0
6362

6463
# Schema decoder for message definitions
6564
self._schema_decoder = Ros1MsgSchemaDecoder()
@@ -99,12 +98,12 @@ def _parse_file(self) -> None:
9998
# Parse version
10099
self._version = BagRecordParser.parse_version(self._reader)
101100
if self._version != '2.0':
102-
raise MalformedBag(f'Unsupported bag version: {self._version}')
101+
raise MalformedBag(f'Unsupported bag version: {self._version} (must be 2.0)')
103102

104103
# Parse bag header
105104
result = BagRecordParser.parse_record(self._reader)
106105
if result is None or result[0] != BagRecordType.BAG_HEADER:
107-
raise MalformedBag('Expected bag header record')
106+
raise MalformedBag(f'Expected bag header record, got {result}')
108107
self._bag_header = result[1]
109108

110109
# Seek to index section
@@ -122,7 +121,6 @@ def _parse_file(self) -> None:
122121
self._connections[record.conn] = record
123122
elif op == BagRecordType.CHUNK_INFO:
124123
self._chunk_infos.append(record)
125-
self._message_count += record.count
126124

127125
@property
128126
def version(self) -> str:
@@ -160,16 +158,12 @@ def get_message_count(self, topic: str) -> int:
160158
@property
161159
def start_time(self) -> int:
162160
"""Get the start time of the bag file in nanoseconds since epoch."""
163-
if not self._chunk_infos:
164-
return 0
165-
return min(ci.start_time_ns for ci in self._chunk_infos)
161+
return min([ci.start_time for ci in self._chunk_infos], default=0)
166162

167163
@property
168164
def end_time(self) -> int:
169165
"""Get the end time of the bag file in nanoseconds since epoch."""
170-
if not self._chunk_infos:
171-
return 0
172-
return max(ci.end_time_ns for ci in self._chunk_infos)
166+
return max([ci.end_time for ci in self._chunk_infos], default=0)
173167

174168
def _expand_topics(self, topic: str | list[str]) -> list[str]:
175169
"""Expand topic patterns to list of concrete topic names.
@@ -182,6 +176,7 @@ def _expand_topics(self, topic: str | list[str]) -> list[str]:
182176
"""
183177
available_topics = self.get_topics()
184178
topic_patterns = [topic] if isinstance(topic, str) else topic
179+
185180
matched_topics = set()
186181
for pattern in topic_patterns:
187182
matches = fnmatch.filter(available_topics, pattern)
@@ -218,8 +213,7 @@ def _deserialize_message(
218213
The deserialized message object.
219214
"""
220215
deserializer = self._get_deserializer(conn.conn)
221-
decoder = RosmsgDecoder(msg.data)
222-
return deserializer(decoder)
216+
return deserializer(RosMsgDecoder(msg.data))
223217

224218
def messages(
225219
self,
@@ -256,21 +250,22 @@ def messages(
256250
conn_ids_to_topics[conn.conn] = conn.topic
257251

258252
if not conn_ids_to_topics:
253+
logging.warning("No matching topics found")
259254
return
260255

261256
# Sort chunk infos by start time if needed
262257
chunk_infos = self._chunk_infos
263258
if in_log_time_order:
264-
chunk_infos = sorted(chunk_infos, key=lambda ci: ci.start_time_ns)
259+
chunk_infos = sorted(chunk_infos, key=lambda ci: ci.start_time)
265260

266261
# Collect messages (optionally sorted by time)
267262
all_messages: list[tuple[int, MessageDataRecord, ConnectionRecord]] = []
268263

269264
for chunk_info in chunk_infos:
270265
# Skip chunks outside the time range
271-
if start_time is not None and chunk_info.end_time_ns < start_time:
266+
if start_time is not None and chunk_info.end_time < start_time:
272267
continue
273-
if end_time is not None and chunk_info.start_time_ns > end_time:
268+
if end_time is not None and chunk_info.start_time > end_time:
274269
continue
275270

276271
# Check if this chunk has any relevant connections
@@ -285,7 +280,7 @@ def messages(
285280
self._reader.seek_from_start(chunk_info.chunk_pos)
286281
result = BagRecordParser.parse_record(self._reader)
287282
if result is None or result[0] != BagRecordType.CHUNK:
288-
logger.warning(f'Expected chunk at position {chunk_info.chunk_pos}')
283+
logger.warning(f'Expected chunk at position {chunk_info.chunk_pos}, got {result}')
289284
continue
290285

291286
chunk: ChunkRecord = result[1]
@@ -303,7 +298,7 @@ def messages(
303298
continue
304299

305300
# Time filtering
306-
log_time = msg.time_ns
301+
log_time = msg.time
307302
if start_time is not None and log_time < start_time:
308303
continue
309304
if end_time is not None and log_time > end_time:
@@ -319,10 +314,11 @@ def messages(
319314
# Yield decoded messages
320315
for log_time, msg, conn in all_messages:
321316
decoded_data = self._deserialize_message(msg, conn)
317+
conn_header = conn.connection_header
322318
decoded = DecodedMessage(
323319
connection_id=msg.conn,
324320
topic=conn.topic,
325-
msg_type=conn.msg_type,
321+
msg_type=conn_header.type,
326322
log_time=log_time,
327323
data=decoded_data,
328324
)

0 commit comments

Comments
 (0)