Skip to content

Commit a7ccc8d

Browse files
author
thuong
committed
feat: format code
1 parent 9488d97 commit a7ccc8d

File tree

1 file changed

+85
-41
lines changed

1 file changed

+85
-41
lines changed

riffle-server/src/store/mem/buffer.rs

Lines changed: 85 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use std::mem;
1313
use std::ops::{Deref, DerefMut};
1414
use std::sync::Arc;
1515

16+
#[derive(Debug)]
1617
pub struct MemoryBuffer {
1718
buffer: Mutex<BufferInternal>,
1819
}
@@ -56,7 +57,9 @@ pub struct BufferInternal {
5657
staging_size: i64,
5758
flight_size: i64,
5859

59-
staging: BatchMemoryBlock,
60+
staging: Vec<Block>,
61+
batch_boundaries: Vec<usize>, // Track where each batch starts
62+
block_position_index: HashMap<i64, usize>, // Maps block_id to Vec index
6063

6164
flight: HashMap<u64, Arc<BatchMemoryBlock>>,
6265
flight_counter: u64,
@@ -68,7 +71,9 @@ impl BufferInternal {
6871
total_size: 0,
6972
staging_size: 0,
7073
flight_size: 0,
71-
staging: Default::default(),
74+
staging: Vec::new(),
75+
batch_boundaries: Vec::new(),
76+
block_position_index: HashMap::new(),
7277
flight: Default::default(),
7378
flight_counter: 0,
7479
}
@@ -124,40 +129,16 @@ impl MemoryBuffer {
124129
let mut read_len = 0i64;
125130
let mut flight_found = false;
126131

127-
let mut exit = false;
128-
while !exit {
129-
exit = true;
130-
{
131-
if last_block_id == INVALID_BLOCK_ID {
132-
flight_found = true;
133-
}
134-
for (_, batch_block) in buffer.flight.iter() {
135-
for blocks in batch_block.iter() {
136-
for block in blocks {
137-
if !flight_found && block.block_id == last_block_id {
138-
flight_found = true;
139-
continue;
140-
}
141-
if !flight_found {
142-
continue;
143-
}
144-
if read_len >= read_bytes_limit_len {
145-
break;
146-
}
147-
if let Some(ref expected_task_id) = task_ids {
148-
if !expected_task_id.contains(block.task_attempt_id as u64) {
149-
continue;
150-
}
151-
}
152-
read_len += block.length as i64;
153-
read_result.push(block);
154-
}
155-
}
156-
}
157-
}
132+
const FIRST_ATTEMP: u8 = 0;
133+
const FALLBACK: u8 = 1;
134+
let strategies = [FIRST_ATTEMP, FALLBACK];
158135

159-
{
160-
for blocks in buffer.staging.iter() {
136+
for loop_index in strategies {
137+
if last_block_id == INVALID_BLOCK_ID {
138+
flight_found = true;
139+
}
140+
for (_, batch_block) in buffer.flight.iter() {
141+
for blocks in batch_block.iter() {
161142
for block in blocks {
162143
if !flight_found && block.block_id == last_block_id {
163144
flight_found = true;
@@ -180,9 +161,38 @@ impl MemoryBuffer {
180161
}
181162
}
182163

183-
if !flight_found {
164+
// Handle staging with Vec + index optimization
165+
let staging_start_idx = if loop_index == FIRST_ATTEMP && !flight_found {
166+
// Try to find position after last_block_id
167+
// Always set flight_found = true for the next search pass
184168
flight_found = true;
185-
exit = false;
169+
if let Some(&position) = buffer.block_position_index.get(&last_block_id) {
170+
position + 1
171+
} else {
172+
// Not found in staging, will handle in fallback
173+
continue;
174+
}
175+
} else {
176+
// Fallback: read from beginning
177+
0
178+
};
179+
180+
for block in &buffer.staging[staging_start_idx..] {
181+
if read_len >= read_bytes_limit_len {
182+
break;
183+
}
184+
if let Some(ref expected_task_id) = task_ids {
185+
if !expected_task_id.contains(block.task_attempt_id as u64) {
186+
continue;
187+
}
188+
}
189+
read_len += block.length as i64;
190+
read_result.push(block);
191+
}
192+
193+
// If we found data in the first attempt, there is no need for the fallback
194+
if flight_found && loop_index == FIRST_ATTEMP {
195+
break;
186196
}
187197
}
188198

@@ -225,7 +235,33 @@ impl MemoryBuffer {
225235
return Ok(None);
226236
}
227237

228-
let staging: BatchMemoryBlock = { mem::replace(&mut buffer.staging, Default::default()) };
238+
// Reconstruct batches from boundaries
239+
let mut batches = Vec::new();
240+
let mut start = 0;
241+
for i in 0..buffer.batch_boundaries.len() {
242+
let end = buffer.batch_boundaries[i];
243+
if end >= buffer.staging.len() {
244+
break;
245+
}
246+
247+
// Find next boundary or use end of staging
248+
let next_boundary = if i + 1 < buffer.batch_boundaries.len() {
249+
buffer.batch_boundaries[i + 1]
250+
} else {
251+
buffer.staging.len()
252+
};
253+
254+
batches.push(buffer.staging[start..next_boundary].to_vec());
255+
start = next_boundary;
256+
}
257+
258+
let staging: BatchMemoryBlock = BatchMemoryBlock(batches);
259+
260+
// Clear everything
261+
buffer.staging.clear();
262+
buffer.block_position_index.clear();
263+
buffer.batch_boundaries.clear();
264+
229265
let staging_ref = Arc::new(staging);
230266
let flight_id = buffer.flight_counter;
231267

@@ -247,12 +283,20 @@ impl MemoryBuffer {
247283
#[trace]
248284
pub fn append(&self, blocks: Vec<Block>, size: u64) -> Result<()> {
249285
let mut buffer = self.buffer.lock();
250-
let mut staging = &mut buffer.staging;
251-
staging.push(blocks);
286+
let current_position = buffer.staging.len();
252287

288+
// Record batch boundary
289+
if !blocks.is_empty() {
290+
buffer.batch_boundaries.push(current_position);
291+
}
292+
for (idx, block) in blocks.into_iter().enumerate() {
293+
buffer
294+
.block_position_index
295+
.insert(block.block_id, current_position + idx);
296+
buffer.staging.push(block);
297+
}
253298
buffer.staging_size += size as i64;
254299
buffer.total_size += size as i64;
255-
256300
Ok(())
257301
}
258302
}

0 commit comments

Comments
 (0)