@@ -94,10 +94,10 @@ def forward( # pyright: ignore[reportIncompatibleMethodOverride]
         output_attentions = False
         assert not output_attentions

-        # print('🔧 OptimizedLlamaAttention.forward(): received position_ids:', position_ids)
+        # print('OptimizedLlamaAttention.forward(): received position_ids:', position_ids)
         # if position_ids is not None:
-        #     print(f'🔧 position_ids shape: {position_ids.shape}, dtype: {position_ids.dtype}')
-        #     print(f'🔧 position_ids content: {position_ids}')
+        #     print(f'position_ids shape: {position_ids.shape}, dtype: {position_ids.dtype}')
+        #     print(f'position_ids content: {position_ids}')

         if position_ids is None:
             past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
@@ -107,9 +107,9 @@ def forward( # pyright: ignore[reportIncompatibleMethodOverride]
                 device=hidden_states.device,
                 dtype=torch.long
             ).unsqueeze(0)  # pyright: ignore[reportAssignmentType]
-            # print(f'🔧 Generated fallback position_ids: {position_ids}')
+            # print(f'Generated fallback position_ids: {position_ids}')

-        # print('🔧 Final position_ids before processing:', position_ids)
+        # print('Final position_ids before processing:', position_ids)

         # Optimized: Avoid .item() CPU-GPU sync by using direct indexing
         # Most common case: 2D tensor [batch_size, seq_len]
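
Note: the fallback above builds a [1, seq_len] position tensor that continues from the number of tokens already held in the KV cache. A minimal standalone sketch of the same idea (the function name is illustrative, not part of the module):

    import torch

    def fallback_position_ids(past_seen_tokens: int, seq_len: int, device: torch.device) -> torch.Tensor:
        # Positions continue where the cache left off: [past, past + seq_len)
        return torch.arange(
            past_seen_tokens, past_seen_tokens + seq_len, device=device, dtype=torch.long
        ).unsqueeze(0)  # shape [1, seq_len], broadcastable over the batch

    print(fallback_position_ids(4, 3, torch.device("cpu")))  # tensor([[4, 5, 6]])
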
@@ -124,7 +124,7 @@ def forward( # pyright: ignore[reportIncompatibleMethodOverride]
         else:
             start_position = 0

-        # print(f'🔧 Extracted start_position: {start_position}')
+        # print(f'Extracted start_position: {start_position}')

         self.temp_hidden_states.val = super(OptimizedLlamaAttention, self).forward(
             hidden_states, cache_read_buf, weight_read_buf, attention_mask, cache_write_buf, start_position, k
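
Note: the `.item()` avoidance mentioned above matters because converting a GPU scalar to a Python number forces the host to wait for the device. A hedged sketch of the two paths (on CPU both are effectively free):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pos = torch.arange(128, device=device)

    blocking = pos[0].item()         # Python int: the host syncs with the GPU here
    non_blocking = pos.flatten()[0]  # 0-d tensor: stays on device, no sync yet
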
@@ -210,6 +210,11 @@ def __init__(self, config: LlamaConfig, layer_id: int, env: ExecutionEnv, policy

         # GPU stream management optimization
         self._streams_initialized = False
+
+        # Rolling buffer for output_ids to avoid an O(prompt_len) copy on each forward
+        self._cached_output_ids = None
+        self._cached_output_ids_shape = None
+        self._output_ids_prompt_initialized = False

         # log_mem(f"[LlamaDecoderLayer:{self.layer_id}] before init_all_weights")
         self.init_all_weights()
@@ -406,13 +411,16 @@ def forward(
             self._last_prompt_len = actual_prompt_len
             self._last_gen_len = max_new_tokens

+            # Reset the output_ids prompt flag when the task changes
+            self._output_ids_prompt_initialized = False
+
             if not self._is_initialized:
                 self._is_initialized = True

         # Performance monitoring: record Task rebuild time
         if task_rebuild_start is not None:
             task_rebuild_time = (time.time() - task_rebuild_start) * 1000
-            if task_rebuild_time > 1.0:  # 只记录超过1ms的情况
+            if task_rebuild_time > 1.0:  # Record only when it takes more than 1 ms
                 print(f"[BLOCK_PERF] Layer {self.layer_id} Task rebuild took: {task_rebuild_time:.3f} ms")

         task = self._cached_task
@@ -424,8 +432,29 @@ def forward(
         num_prompts = len(task.inputs)
         prompt_len, gen_len = task.prompt_len, task.gen_len

-        self.output_ids = np.ones((num_prompts, prompt_len + gen_len), dtype=np.int64)
-        self.output_ids[:, :prompt_len] = np.asarray(task.inputs)
+        # Use a rolling buffer to avoid an O(prompt_len) copy on each forward;
+        # only reallocate when the shape changes
+        output_ids_start = time.time()
+        target_shape = (num_prompts, prompt_len + gen_len)
+        if self._cached_output_ids is None or self._cached_output_ids_shape != target_shape:
+            # Shape changed, need to reallocate
+            self._cached_output_ids = np.ones(target_shape, dtype=np.int64)
+            self._cached_output_ids_shape = target_shape
+            self._output_ids_prompt_initialized = False
+            if verbose > 0:
+                print(f"[OUTPUT_IDS_PERF] Layer {self.layer_id}: Reallocated output_ids with shape {target_shape}")
+
+        # Only copy the prompt tokens when necessary (first time, or the task changed)
+        if not self._output_ids_prompt_initialized:
+            self._cached_output_ids[:, :prompt_len] = np.asarray(task.inputs)
+            self._output_ids_prompt_initialized = True
+            if verbose > 0:
+                print(f"[OUTPUT_IDS_PERF] Layer {self.layer_id}: Initialized prompt tokens ({prompt_len} tokens)")
+
+        self.output_ids = self._cached_output_ids
+        output_ids_time = (time.time() - output_ids_start) * 1000
+        if output_ids_time > 1.0:
+            print(f"[OUTPUT_IDS_PERF] Layer {self.layer_id} output_ids setup took: {output_ids_time:.3f} ms")

         # Smart cache clearing - avoid clearing every time
         cache_clear_start = time.time()
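
Note: the rolling buffer above replaces a per-call allocate-and-copy (O(num_prompts × prompt_len) work on every forward) with a one-time initialization per task. A self-contained sketch of the same discipline (class and field names are illustrative, not the module's real API):

    import numpy as np

    class OutputIdsBuffer:
        """Reuse one int64 buffer across calls; reallocate only on shape change."""

        def __init__(self):
            self._buf = None
            self._shape = None
            self._prompt_done = False

        def new_task(self):
            # Force the prompt region to be rewritten on the next call
            self._prompt_done = False

        def get(self, inputs: np.ndarray, gen_len: int) -> np.ndarray:
            num_prompts, prompt_len = inputs.shape
            shape = (num_prompts, prompt_len + gen_len)
            if self._buf is None or self._shape != shape:
                self._buf = np.ones(shape, dtype=np.int64)  # rare: shape changed
                self._shape = shape
                self._prompt_done = False
            if not self._prompt_done:
                self._buf[:, :prompt_len] = inputs          # once per task, not per step
                self._prompt_done = True
            return self._buf

One caveat of reuse: columns past prompt_len may still hold tokens from a previous task of the same shape, so callers should only read as far as the tokens generated so far.
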
@@ -496,10 +525,10 @@ def forward(
         if position_ids is not None and position_ids.numel() > 0:
             # Optimized: Avoid .item() sync
             current_position = position_ids.flatten()[0]
-            # print(f'🔧 Using actual position from position_ids: {current_position}')
+            # print(f'Using actual position from position_ids: {current_position}')
         else:
             current_position = 0
-            # print(f'🔧 No position_ids provided, using fallback position: {current_position}')
+            # print(f'No position_ids provided, using fallback position: {current_position}')

         i = current_position

@@ -613,7 +642,7 @@ def to_torch_tensor(x):
         outputs = (hidden_states, past_key_value)
         # log_mem(f"[Layer:{self.layer_id}] forward(end) out_shape={hidden_states.shape}")
         # Remove empty_cache call from each forward to reduce GPU overhead
-        # torch.cuda.empty_cache()  # 这会导致性能问题
+        # torch.cuda.empty_cache()  # this caused performance problems
         return outputs

     def load_weight(self, i, j, k, overlap=True):
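
Note: dropping the per-forward `torch.cuda.empty_cache()` is consistent with how PyTorch's caching allocator works: the call returns unused cached blocks to the driver, so subsequent allocations pay for fresh cudaMalloc calls instead of cheap cache hits. A small sketch to measure the cost on a given machine (CUDA only; the sizes and iteration count are arbitrary):

    import time
    import torch

    def avg_ms(clear_cache: bool, iters: int = 50) -> float:
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            x = torch.empty(1024, 1024, device="cuda")  # served by the caching allocator
            del x
            if clear_cache:
                torch.cuda.empty_cache()                # hands blocks back to the driver
        torch.cuda.synchronize()
        return (time.time() - start) * 1000 / iters

    if torch.cuda.is_available():
        print(f"with empty_cache: {avg_ms(True):.3f} ms/iter, without: {avg_ms(False):.3f} ms/iter")
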
@@ -672,7 +701,7 @@ def store_cache(self, i, j, k, overlap=True):
             with torch.cuda.stream(self.store_cache_stream):
                 self.layers[j].store_cache(self.cache_home[j][k], self.cache_write_buf[j][k], i)
                 # Remove unnecessary synchronization to reduce GPU blocking
-                # torch.cuda.synchronize()  # 这会造成性能瓶颈
+                # torch.cuda.synchronize()  # this was a performance bottleneck
         else:
             self.layers[j].store_cache(self.cache_home[j][k], self.cache_write_buf[j][k], i)

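
Note: removing the blanket `torch.cuda.synchronize()` is safe only if something else orders consumers after the copy enqueued on `store_cache_stream`. One standard pattern, sketched here as an assumption rather than what this code necessarily does, is a CUDA event recorded on the producer stream and waited on by the consumer stream, so the host thread never blocks:

    import torch

    if torch.cuda.is_available():
        producer = torch.cuda.Stream()
        done = torch.cuda.Event()

        with torch.cuda.stream(producer):
            # ... enqueue the cache store on the side stream ...
            done.record(producer)  # marks completion on the producer stream

        # Consumer side: order after the event without stalling the host
        torch.cuda.current_stream().wait_event(done)
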
@@ -742,7 +771,7 @@ def compute_layer(self, i, j, k, position_ids=None, generated_tokens_num=0):
         if j == 1:
             self.hidden[0][j][k].val = self.temp_hidden.val

-        # print(f'🔧 compute_layer: i={i}, j={j}, k={k}, received position_ids={position_ids}')
+        # print(f'compute_layer: i={i}, j={j}, k={k}, received position_ids={position_ids}')

         self.layers[j].forward(hidden_states=self.hidden[0][j][k],
                                cache_read_buf=self.cache_read_buf[j][k],
@@ -784,10 +813,10 @@ def forward(
             seq_length_with_past = seq_length_with_past + past_key_values_length
             past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)

-        # print(f'🔧 WrappedLlamaBlock.forward: received position_ids={position_ids}')
+        # print(f'WrappedLlamaBlock.forward: received position_ids={position_ids}')
         if position_ids is not None:
             pass
-            # print(f'🔧 WrappedLlamaBlock.forward: position_ids shape={position_ids.shape}, content={position_ids}')
+            # print(f'WrappedLlamaBlock.forward: position_ids shape={position_ids.shape}, content={position_ids}')

         # print(f"WrappedLlamaBlock, hidden_states: {hidden_states}, seq_length: {seq_length}, past_key_value: {past_key_value}")
         # Optimized: Reuse cached attention_mask