
Commit b3126a6

JiuChen0, dannywillowliu-uchi, and root authored
Remove O(prompt_len) prompt copies (#35)
* Add batch inference support and CPU compatibility
  - Add --batch_size CLI argument for parallel sequence processing
  - Add conditional CUDA stream creation for CPU-only mode
  - Add device-aware ExecutionEnv and Policy resource distribution
  - Fix MPS compatibility on macOS
* Fix hardcoded model loading and support batch size
* Resolve dependency conflicts
* docs: refine README setup and usage sections for clarity and correctness
* Add batch-size-related updates
* Delete debug output
* Delete .id files
* Fix max token size problem
* Add prompt
* Reduce /dev/shm peak usage during warmup/prefill stage
* Delete dead code
* chore: comment out unused compare_tensors function
* Delete bitsandbytes quant
* Support FlexGen 4-bit quant
* Clean debug output for server id
* Add effective throughput
* Clean up unnecessary files
* Fix start compute time initialization error
* Use rolling buffer to avoid O(prompt_len) copy on each forward
* Fix debug I/O issue

---------

Co-authored-by: Danny Willow Liu <dannywillowliu@uchicago.edu>
Co-authored-by: root <root@investorairig80.maas>
1 parent 241bbc3 commit b3126a6

3 files changed

Lines changed: 49 additions & 19 deletions

File tree

src/bloombee/models/llama/block.py
src/bloombee/models/llama/model.py
src/bloombee/server/block_functions.py

Lines changed: 45 additions & 16 deletions
@@ -94,10 +94,10 @@ def forward(  # pyright: ignore[reportIncompatibleMethodOverride]
         output_attentions = False
         assert not output_attentions

-        # print('🔧 OptimizedLlamaAttention.forward(): received position_ids:', position_ids)
+        # print(' OptimizedLlamaAttention.forward(): received position_ids:', position_ids)
         # if position_ids is not None:
-        #     print(f'🔧 position_ids shape: {position_ids.shape}, dtype: {position_ids.dtype}')
-        #     print(f'🔧 position_ids content: {position_ids}')
+        #     print(f' position_ids shape: {position_ids.shape}, dtype: {position_ids.dtype}')
+        #     print(f' position_ids content: {position_ids}')

         if position_ids is None:
             past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
@@ -107,9 +107,9 @@ def forward(  # pyright: ignore[reportIncompatibleMethodOverride]
                 device=hidden_states.device,
                 dtype=torch.long
             ).unsqueeze(0)  # pyright: ignore[reportAssignmentType]
-            # print(f'🔧 Generated fallback position_ids: {position_ids}')
+            # print(f' Generated fallback position_ids: {position_ids}')

-        # print('🔧 Final position_ids before processing:', position_ids)
+        # print(' Final position_ids before processing:', position_ids)

         # Optimized: Avoid .item() CPU-GPU sync by using direct indexing
         # Most common case: 2D tensor [batch_size, seq_len]
@@ -124,7 +124,7 @@ def forward(  # pyright: ignore[reportIncompatibleMethodOverride]
         else:
             start_position = 0

-        # print(f'🔧 Extracted start_position: {start_position}')
+        # print(f' Extracted start_position: {start_position}')

         self.temp_hidden_states.val = super(OptimizedLlamaAttention, self).forward(
             hidden_states, cache_read_buf, weight_read_buf, attention_mask, cache_write_buf, start_position, k
@@ -210,6 +210,11 @@ def __init__(self, config: LlamaConfig, layer_id: int, env: ExecutionEnv, policy

         # GPU stream management optimization
         self._streams_initialized = False
+
+        # Rolling buffer for output_ids to avoid O(prompt_len) copy on each forward
+        self._cached_output_ids = None
+        self._cached_output_ids_shape = None
+        self._output_ids_prompt_initialized = False

         # log_mem(f"[LlamaDecoderLayer:{self.layer_id}] before init_all_weights")
         self.init_all_weights()
@@ -406,13 +411,16 @@ def forward(
         self._last_prompt_len = actual_prompt_len
         self._last_gen_len = max_new_tokens

+        # Reset output_ids prompt flag when task changes
+        self._output_ids_prompt_initialized = False
+
         if not self._is_initialized:
             self._is_initialized = True

         # Performance monitoring: record Task rebuild time
         if task_rebuild_start is not None:
             task_rebuild_time = (time.time() - task_rebuild_start) * 1000
-            if task_rebuild_time > 1.0:  # 只记录超过1ms的情况 (log only cases over 1 ms)
+            if task_rebuild_time > 1.0:  # Record only when it takes more than 1ms
                 print(f"[BLOCK_PERF] Layer {self.layer_id} Task rebuild took: {task_rebuild_time:.3f}ms")

         task = self._cached_task
@@ -424,8 +432,29 @@ def forward(
         num_prompts = len(task.inputs)
         prompt_len, gen_len = task.prompt_len, task.gen_len

-        self.output_ids = np.ones((num_prompts, prompt_len + gen_len), dtype=np.int64)
-        self.output_ids[:, :prompt_len] = np.asarray(task.inputs)
+        # Use rolling buffer to avoid O(prompt_len) copy on each forward
+        # Only reallocate when shape changes
+        output_ids_start = time.time()
+        target_shape = (num_prompts, prompt_len + gen_len)
+        if self._cached_output_ids is None or self._cached_output_ids_shape != target_shape:
+            # Shape changed, need to reallocate
+            self._cached_output_ids = np.ones(target_shape, dtype=np.int64)
+            self._cached_output_ids_shape = target_shape
+            self._output_ids_prompt_initialized = False
+            if verbose > 0:
+                print(f"[OUTPUT_IDS_PERF] Layer {self.layer_id}: Reallocated output_ids with shape {target_shape}")
+
+        # Only copy prompt tokens when necessary (first time or task changed)
+        if not self._output_ids_prompt_initialized:
+            self._cached_output_ids[:, :prompt_len] = np.asarray(task.inputs)
+            self._output_ids_prompt_initialized = True
+            if verbose > 0:
+                print(f"[OUTPUT_IDS_PERF] Layer {self.layer_id}: Initialized prompt tokens ({prompt_len} tokens)")
+
+        self.output_ids = self._cached_output_ids
+        output_ids_time = (time.time() - output_ids_start) * 1000
+        if output_ids_time > 1.0:
+            print(f"[OUTPUT_IDS_PERF] Layer {self.layer_id} output_ids setup took: {output_ids_time:.3f}ms")

         # Smart cache clearing - avoid clearing every time
         cache_clear_start = time.time()
@@ -496,10 +525,10 @@ def forward(
         if position_ids is not None and position_ids.numel() > 0:
             # Optimized: Avoid .item() sync
             current_position = position_ids.flatten()[0]
-            # print(f'🔧 Using actual position from position_ids: {current_position}')
+            # print(f' Using actual position from position_ids: {current_position}')
         else:
             current_position = 0
-            # print(f'🔧 No position_ids provided, using fallback position: {current_position}')
+            # print(f' No position_ids provided, using fallback position: {current_position}')

         i = current_position

@@ -613,7 +642,7 @@ def to_torch_tensor(x):
         outputs = (hidden_states, past_key_value)
         # log_mem(f"[Layer:{self.layer_id}] forward(end) out_shape={hidden_states.shape}")
         # Remove empty_cache call from each forward to reduce GPU overhead
-        # torch.cuda.empty_cache()  # 这会导致性能问题 (this causes performance problems)
+        # torch.cuda.empty_cache()
         return outputs

     def load_weight(self, i, j, k, overlap=True):
@@ -672,7 +701,7 @@ def store_cache(self, i, j, k, overlap=True):
             with torch.cuda.stream(self.store_cache_stream):
                 self.layers[j].store_cache(self.cache_home[j][k], self.cache_write_buf[j][k], i)
             # Remove unnecessary synchronization to reduce GPU blocking
-            # torch.cuda.synchronize()  # 这会造成性能瓶颈 (this creates a performance bottleneck)
+            # torch.cuda.synchronize()
         else:
             self.layers[j].store_cache(self.cache_home[j][k], self.cache_write_buf[j][k], i)

@@ -742,7 +771,7 @@ def compute_layer(self, i, j, k, position_ids=None, generated_tokens_num=0):
         if j == 1:
             self.hidden[0][j][k].val = self.temp_hidden.val

-        # print(f'🔧 compute_layer: i={i}, j={j}, k={k}, received position_ids={position_ids}')
+        # print(f' compute_layer: i={i}, j={j}, k={k}, received position_ids={position_ids}')

         self.layers[j].forward(hidden_states=self.hidden[0][j][k],
                                cache_read_buf=self.cache_read_buf[j][k],
@@ -784,10 +813,10 @@ def forward(
        seq_length_with_past = seq_length_with_past + past_key_values_length
        past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

-       # print(f'🔧 WrappedLlamaBlock.forward: received position_ids={position_ids}')
+       # print(f' WrappedLlamaBlock.forward: received position_ids={position_ids}')
        if position_ids is not None:
            pass
-           # print(f'🔧 WrappedLlamaBlock.forward: position_ids shape={position_ids.shape}, content={position_ids}')
+           # print(f' WrappedLlamaBlock.forward: position_ids shape={position_ids.shape}, content={position_ids}')

        # print(f"WrappedLlamaBlock, hidden_states: {hidden_states}, seq_length: {seq_length}, past_key_value: {past_key_value}")
        # Optimized: Reuse cached attention_mask
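
To make the rolling-buffer hunk above easier to follow, here is a condensed, standalone Python sketch of the same pattern. The class and method names below are illustrative only (they are not the repo's API), and the Task/verbose plumbing from the diff is elided:

    # Standalone sketch of the rolling-buffer pattern (names are hypothetical).
    import numpy as np

    class OutputIdsBuffer:
        """Reuses one output_ids array across forward passes.

        The old code allocated a (num_prompts, prompt_len + gen_len) array and
        copied all prompt tokens on every forward, an O(prompt_len) cost per
        step. Here the allocation and the prompt copy happen once per task;
        each decode step then only writes the newly generated column.
        """

        def __init__(self):
            self._cached = None
            self._cached_shape = None
            self._prompt_initialized = False

        def get(self, inputs: np.ndarray, gen_len: int) -> np.ndarray:
            num_prompts, prompt_len = inputs.shape
            target_shape = (num_prompts, prompt_len + gen_len)
            if self._cached is None or self._cached_shape != target_shape:
                # Reallocate only when the shape actually changes.
                self._cached = np.ones(target_shape, dtype=np.int64)
                self._cached_shape = target_shape
                self._prompt_initialized = False
            if not self._prompt_initialized:
                # Copy prompt tokens once per task, not once per forward.
                self._cached[:, :prompt_len] = inputs
                self._prompt_initialized = True
            return self._cached

    buf = OutputIdsBuffer()
    prompt = np.arange(8, dtype=np.int64).reshape(1, 8)
    out = buf.get(prompt, gen_len=4)
    assert buf.get(prompt, gen_len=4) is out  # same buffer reused, no re-copy

The design choice: the O(prompt_len) copy is paid once per task instead of once per decode step, so steady-state decoding does O(1) bookkeeping here regardless of prompt length.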

src/bloombee/models/llama/model.py

Lines changed: 3 additions & 3 deletions
@@ -185,7 +185,7 @@ def prepare_inputs_for_generation(
            input_ids = input_ids[:, past_length:]
            # print(f" Past length case: {original_shape} -> {input_ids.shape}, kept tokens: {input_ids}")
        else:
-           print(f" No truncation needed: past_length={past_length}, input_ids.shape[1]={input_ids.shape[1]}")
+           logger.debug(f"No truncation needed: past_length={past_length}, input_ids.shape[1]={input_ids.shape[1]}")

        if (
            max_cache_length is not None
@@ -205,10 +205,10 @@ def prepare_inputs_for_generation(
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
-           print(f" Using inputs_embeds for first generation step")
+           logger.debug("Using inputs_embeds for first generation step")
        else:
            model_inputs = {"input_ids": input_ids}
-           # print(f" Using input_ids: {input_ids}")
+           # logger.debug(f"Using input_ids: {input_ids}")

        model_inputs.update(
            {
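
Both hunks above assume a module-level logger is in scope. As a minimal sketch of what that presumably looks like (whether bloombee uses stdlib logging or a framework-provided get_logger here is an assumption; stdlib is shown):

    # Minimal module-level logger setup these hunks rely on (assumed, not repo code).
    import logging

    logger = logging.getLogger(__name__)

    # Unlike the unconditional print() it replaces, a debug record is emitted
    # only when the effective level is DEBUG or lower.
    logging.basicConfig(level=logging.DEBUG)
    logger.debug("No truncation needed: past_length=%d", 0)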

src/bloombee/server/block_functions.py

Lines changed: 1 addition & 0 deletions
@@ -289,6 +289,7 @@ async def iterate_rpc_inference(

        # Add Cross-GPU Transfer Latency measurement
        cross_gpu_start_time = perf_counter()
+       start_compute_time = perf_counter()  # Initialize compute time tracking

        # parse deep prompts (optional argument)
        has_prompts = prompts is not None and not is_dummy(prompts)
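
The added line guards against reading start_compute_time before anything has assigned it. A hypothetical reduction of that bug class (not the repo's actual control flow in iterate_rpc_inference):

    # Sketch: a timer read later in a loop must be initialized up front,
    # otherwise the first read can raise NameError or use a stale value.
    from time import perf_counter

    def run_steps(steps):
        cross_gpu_start_time = perf_counter()
        start_compute_time = perf_counter()  # the fix: always defined before use
        for is_compute in steps:
            if is_compute:
                start_compute_time = perf_counter()  # refreshed on compute steps
            # ... work happens here ...
            elapsed = perf_counter() - start_compute_time  # safe on every path
            print(f"compute elapsed: {elapsed:.6f}s")

    run_steps([False, True, True])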
