@@ -48,38 +48,54 @@ def __getattr__(self, item):
4848 return getattr (self .bag , item )
4949
5050 def _get_entries (self , connections = None , start_time = None , end_time = None ):
51+ all_ranges = None
52+ if start_time is not None and isinstance (start_time , TimeRanges ):
53+ all_ranges = copy .copy (start_time )
54+ if end_time is not None and isinstance (end_time , TimeRanges ):
55+ if all_ranges is None :
56+ all_ranges = copy .copy (end_time )
57+ else :
58+ all_ranges .append (end_time .ranges )
59+
5160 for entry in heapq .merge (* self .bag ._get_indexes (connections ), key = lambda x : x .time .to_nsec ()): # noqa
52- if start_time is not None :
53- if isinstance (start_time , TimeRanges ):
54- time_ranges = start_time
55- if time_ranges > entry .time :
56- continue
57- elif time_ranges < entry .time :
58- return
59- elif entry .time not in time_ranges :
60- continue
61- elif entry .time < start_time :
61+ if all_ranges is not None :
62+ if all_ranges > entry .time :
6263 continue
63- if end_time is not None and entry .time > end_time :
64- return
64+ elif all_ranges < entry .time :
65+ return
66+ elif entry .time not in all_ranges :
67+ continue
68+ else :
69+ if start_time is not None and isinstance (start_time , genpy .Time ) and entry .time < start_time :
70+ continue
71+ if end_time is not None and isinstance (end_time , genpy .Time ) and entry .time > end_time :
72+ return
6573 yield entry
6674
6775 def _get_entries_reverse (self , connections = None , start_time = None , end_time = None ):
76+ all_ranges = None
77+ if start_time is not None and isinstance (start_time , TimeRanges ):
78+ all_ranges = copy .copy (start_time )
79+ if end_time is not None and isinstance (end_time , TimeRanges ):
80+ if all_ranges is None :
81+ all_ranges = copy .copy (end_time )
82+ else :
83+ all_ranges .append (end_time .ranges )
84+
6885 for entry in heapq .merge (* (reversed (index ) for index in self ._get_indexes (connections )),
6986 key = lambda x : x .time .to_nsec (), reverse = True ):
70- if start_time is not None :
71- if isinstance (start_time , TimeRanges ):
72- time_ranges = start_time
73- if time_ranges > entry .time :
74- return
75- elif time_ranges < entry .time :
76- continue
77- elif entry .time not in time_ranges :
78- continue
79- elif entry .time < start_time :
87+ if all_ranges is not None :
88+ if all_ranges > entry .time :
8089 return
81- if end_time is not None and entry .time > end_time :
82- continue
90+ elif all_ranges < entry .time :
91+ continue
92+ elif entry .time not in all_ranges :
93+ continue
94+ else :
95+ if start_time is not None and isinstance (start_time , genpy .Time ) and entry .time < start_time :
96+ return
97+ if end_time is not None and isinstance (end_time , genpy .Time ) and entry .time > end_time :
98+ continue
8399 yield entry
84100
85101
@@ -91,12 +107,24 @@ def __init__(self,
91107 mode = 'r' , # type: STRING_TYPE
92108 compression = rosbag .Compression .NONE , # type: rosbag.Compression
93109 options = None , # type: Dict[STRING_TYPE, Any]
94- skip_index = False # type: bool
110+ skip_index = False , # type: bool
111+ limit_to_first_bag = False , # type: bool
95112 ):
113+ """
114+ :param bag_files: The paths to bags to open (either a sequence of colon-separated string of paths).
115+ :param mode: Open mode (r/w/a).
116+ :param compression: Compression (used for write mode).
117+ :param options: Bag options (compression, chunk threshold, ...).
118+ :param skip_index: Whether index should be read right away. Otherwise, call read_index() when you need it.
119+ :param limit_to_first_bag: If True, the multibag will report its start and end to be equal to the
120+ first open bag. If False, the start and end correspond to the earliest and latest
121+ stamp in all bags.
122+ """
96123
97124 if isinstance (bag_files , STRING_TYPE ):
98125 bag_files = bag_files .split (os .path .pathsep )
99126 self .bags = [self .open_bag (b , compression , mode , options , skip_index ) for b in bag_files ]
127+ self ._limit_to_first_bag = limit_to_first_bag
100128
101129 def open_bag (self , b , compression , mode , options , skip_index ):
102130 return rosbag .Bag (b , mode , compression , options = options , skip_index = skip_index )
@@ -115,23 +143,33 @@ def __del__(self):
115143
116144 @property
117145 def size (self ):
146+ if self ._limit_to_first_bag :
147+ return self .bags [0 ].size
118148 return sum (b .size for b in self .bags )
119149
120150 def close (self ):
121151 if hasattr (self , 'bags' ):
122152 for b in self .bags :
123153 b .close ()
124154
125- def get_message_count (self , topic_filters = None , start_time = None , end_time = None ):
126- # type: (Optional[Sequence[STRING_TYPE]], Optional[Union[genpy.Time, TimeRanges]], Optional[genpy.Time]) -> int
155+ def get_message_count (self ,
156+ topic_filters = None , # type: Optional[Sequence[STRING_TYPE]]
157+ start_time = None , # type: Optional[Union[genpy.Time, TimeRanges]]
158+ end_time = None # type: Optional[Union[genpy.Time, TimeRanges]]
159+ ):
160+ # type: (...) -> int
127161 connections = dict (self ._get_connections (topic_filters , with_bag = True ))
128162 entries = self ._get_entries (connections , start_time , end_time )
129163 return sum (1 for _ in entries )
130164
131165 def get_start_time (self ): # type: () -> float
166+ if self ._limit_to_first_bag :
167+ return self .bags [0 ].get_start_time ()
132168 return min (b .get_start_time () for b in self .bags )
133169
134170 def get_end_time (self ): # type: () -> float
171+ if self ._limit_to_first_bag :
172+ return self .bags [0 ].get_end_time ()
135173 return max (b .get_end_time () for b in self .bags )
136174
137175 def read_index (self ):
@@ -151,7 +189,7 @@ def _get_connections(self, topics=None, connection_filter=None, with_bag=False):
151189 def _get_entries (self ,
152190 connections = None , # type: Optional[Dict[rosbag.Bag, Iterable[ConnectionInfo]]]
153191 start_time = None , # type: Optional[Union[genpy.Time, TimeRanges]]
154- end_time = None , # type: Optional[genpy.Time]
192+ end_time = None , # type: Optional[Union[ genpy.Time, TimeRanges] ]
155193 ):
156194 # type: (...) -> Iterator[Tuple[rosbag.Bag, ConnectionEntry, ConnectionInfo]]
157195 all_indexes = []
@@ -161,6 +199,11 @@ def _get_entries(self,
161199 for conn , index in zip (conns , indexes ):
162200 all_indexes .append ([(bag , entry , conn ) for entry in index ])
163201
202+ if start_time is None and self ._limit_to_first_bag :
203+ start_time = genpy .Time (self .get_start_time ())
204+ if end_time is None and self ._limit_to_first_bag :
205+ end_time = genpy .Time (self .get_end_time ())
206+
164207 time_ranges = None
165208 if start_time is not None and isinstance (start_time , TimeRanges ):
166209 time_ranges = {}
@@ -170,7 +213,20 @@ def _get_entries(self,
170213 time_ranges [bag ] = time_range
171214 start_time = None # Simplify the check in the next loop
172215
216+ extra_time_ranges = None
217+ if end_time is not None and isinstance (end_time , TimeRanges ):
218+ extra_time_ranges = {}
219+ for bag in self .bags :
220+ time_range = copy .copy (end_time )
221+ time_range .set_base_time (genpy .Time (bag .get_start_time ()))
222+ extra_time_ranges [bag ] = time_range
223+ end_time = None # Simplify the check in the next loop
224+
173225 for bag , entry , conn in heapq .merge (* all_indexes , key = lambda x : x [1 ].time .to_nsec ()):
226+ if extra_time_ranges is not None and entry .time in extra_time_ranges [bag ]:
227+ yield bag , entry , conn
228+ continue
229+
174230 if start_time is not None and entry .time < start_time :
175231 continue
176232 if end_time is not None and entry .time > end_time :
@@ -189,7 +245,7 @@ def _get_entries(self,
189245 def _read_messages (self ,
190246 topics = None , # type: Optional[Sequence[STRING_TYPE]]
191247 start_time = None , # type: Optional[Union[genpy.Time, TimeRanges]]
192- end_time = None , # type: Optional[genpy.Time]
248+ end_time = None , # type: Optional[Union[ genpy.Time, TimeRanges] ]
193249 connection_filter = None , # type: Optional[ConnectionFilter]
194250 raw = False , # type: bool
195251 return_connection_header = False , # type: bool
@@ -204,7 +260,7 @@ def _read_messages(self,
204260 def read_messages (self ,
205261 topics = None , # type: Optional[Sequence[STRING_TYPE]]
206262 start_time = None , # type: Optional[Union[genpy.Time, TimeRanges]]
207- end_time = None , # type: Optional[genpy.Time]
263+ end_time = None , # type: Optional[Union[ genpy.Time, TimeRanges] ]
208264 connection_filter = None , # type: Optional[ConnectionFilter]
209265 raw = False , # type: bool
210266 return_connection_header = False , # type: bool
0 commit comments