@@ -13,6 +13,7 @@ use std::mem;
1313use std:: ops:: { Deref , DerefMut } ;
1414use std:: sync:: Arc ;
1515
16+ #[ derive( Debug ) ]
1617pub struct MemoryBuffer {
1718 buffer : Mutex < BufferInternal > ,
1819}
@@ -56,7 +57,9 @@ pub struct BufferInternal {
5657 staging_size : i64 ,
5758 flight_size : i64 ,
5859
59- staging : BatchMemoryBlock ,
60+ staging : Vec < Block > ,
61+ batch_boundaries : Vec < usize > , // Track where each batch starts
62+ block_position_index : HashMap < i64 , usize > , // Maps block_id to Vec index
6063
6164 flight : HashMap < u64 , Arc < BatchMemoryBlock > > ,
6265 flight_counter : u64 ,
@@ -68,7 +71,9 @@ impl BufferInternal {
6871 total_size : 0 ,
6972 staging_size : 0 ,
7073 flight_size : 0 ,
71- staging : Default :: default ( ) ,
74+ staging : Vec :: new ( ) ,
75+ batch_boundaries : Vec :: new ( ) ,
76+ block_position_index : HashMap :: new ( ) ,
7277 flight : Default :: default ( ) ,
7378 flight_counter : 0 ,
7479 }
@@ -124,40 +129,16 @@ impl MemoryBuffer {
124129 let mut read_len = 0i64 ;
125130 let mut flight_found = false ;
126131
127- let mut exit = false ;
128- while !exit {
129- exit = true ;
130- {
131- if last_block_id == INVALID_BLOCK_ID {
132- flight_found = true ;
133- }
134- for ( _, batch_block) in buffer. flight . iter ( ) {
135- for blocks in batch_block. iter ( ) {
136- for block in blocks {
137- if !flight_found && block. block_id == last_block_id {
138- flight_found = true ;
139- continue ;
140- }
141- if !flight_found {
142- continue ;
143- }
144- if read_len >= read_bytes_limit_len {
145- break ;
146- }
147- if let Some ( ref expected_task_id) = task_ids {
148- if !expected_task_id. contains ( block. task_attempt_id as u64 ) {
149- continue ;
150- }
151- }
152- read_len += block. length as i64 ;
153- read_result. push ( block) ;
154- }
155- }
156- }
157- }
132+ const FIRST_ATTEMP : u8 = 0 ;
133+ const FALLBACK : u8 = 1 ;
134+ let strategies = [ FIRST_ATTEMP , FALLBACK ] ;
158135
159- {
160- for blocks in buffer. staging . iter ( ) {
136+ for loop_index in strategies {
137+ if last_block_id == INVALID_BLOCK_ID {
138+ flight_found = true ;
139+ }
140+ for ( _, batch_block) in buffer. flight . iter ( ) {
141+ for blocks in batch_block. iter ( ) {
161142 for block in blocks {
162143 if !flight_found && block. block_id == last_block_id {
163144 flight_found = true ;
@@ -180,9 +161,38 @@ impl MemoryBuffer {
180161 }
181162 }
182163
183- if !flight_found {
164+ // Handle staging with Vec + index optimization
165+ let staging_start_idx = if loop_index == FIRST_ATTEMP && !flight_found {
166+ // Try to find position after last_block_id
167+ // Always set flight_found = true for the next searching
184168 flight_found = true ;
185- exit = false ;
169+ if let Some ( & position) = buffer. block_position_index . get ( & last_block_id) {
170+ position + 1
171+ } else {
172+ // Not found in staging, will handle in fallback
173+ continue ;
174+ }
175+ } else {
176+ // Fallback: read from beginning
177+ 0
178+ } ;
179+
180+ for block in & buffer. staging [ staging_start_idx..] {
181+ if read_len >= read_bytes_limit_len {
182+ break ;
183+ }
184+ if let Some ( ref expected_task_id) = task_ids {
185+ if !expected_task_id. contains ( block. task_attempt_id as u64 ) {
186+ continue ;
187+ }
188+ }
189+ read_len += block. length as i64 ;
190+ read_result. push ( block) ;
191+ }
192+
193+ // // If we found data in first attempt, no need for fallback
194+ if flight_found && loop_index == FIRST_ATTEMP {
195+ break ;
186196 }
187197 }
188198
@@ -225,7 +235,33 @@ impl MemoryBuffer {
225235 return Ok ( None ) ;
226236 }
227237
228- let staging: BatchMemoryBlock = { mem:: replace ( & mut buffer. staging , Default :: default ( ) ) } ;
238+ // Reconstruct batches from boundaries
239+ let mut batches = Vec :: new ( ) ;
240+ let mut start = 0 ;
241+ for i in 0 ..buffer. batch_boundaries . len ( ) {
242+ let end = buffer. batch_boundaries [ i] ;
243+ if end >= buffer. staging . len ( ) {
244+ break ;
245+ }
246+
247+ // Find next boundary or use end of staging
248+ let next_boundary = if i + 1 < buffer. batch_boundaries . len ( ) {
249+ buffer. batch_boundaries [ i + 1 ]
250+ } else {
251+ buffer. staging . len ( )
252+ } ;
253+
254+ batches. push ( buffer. staging [ start..next_boundary] . to_vec ( ) ) ;
255+ start = next_boundary;
256+ }
257+
258+ let staging: BatchMemoryBlock = BatchMemoryBlock ( batches) ;
259+
260+ // Clear everything
261+ buffer. staging . clear ( ) ;
262+ buffer. block_position_index . clear ( ) ;
263+ buffer. batch_boundaries . clear ( ) ;
264+
229265 let staging_ref = Arc :: new ( staging) ;
230266 let flight_id = buffer. flight_counter ;
231267
@@ -247,12 +283,20 @@ impl MemoryBuffer {
247283 #[ trace]
248284 pub fn append ( & self , blocks : Vec < Block > , size : u64 ) -> Result < ( ) > {
249285 let mut buffer = self . buffer . lock ( ) ;
250- let mut staging = & mut buffer. staging ;
251- staging. push ( blocks) ;
286+ let current_position = buffer. staging . len ( ) ;
252287
288+ // Record batch boundary
289+ if !blocks. is_empty ( ) {
290+ buffer. batch_boundaries . push ( current_position) ;
291+ }
292+ for ( idx, block) in blocks. into_iter ( ) . enumerate ( ) {
293+ buffer
294+ . block_position_index
295+ . insert ( block. block_id , current_position + idx) ;
296+ buffer. staging . push ( block) ;
297+ }
253298 buffer. staging_size += size as i64 ;
254299 buffer. total_size += size as i64 ;
255-
256300 Ok ( ( ) )
257301 }
258302}
0 commit comments