@@ -208,7 +208,7 @@ def __init__(
             max_seq_len=config.max_position_embeddings,
             base=config.rope_theta,
             interleaved=config.rope_interleaved,
-            seq_len_scaling_factor=config.rope_seq_len_scaling_factor,
+            seq_len_scaling_factor=config.rope_seq_len_interpolation_factor,
             fused=config._fused_rotary_emb,
         )
         self.attention = CoreAttention(config, tp_pg, cp_pg, layer_idx)
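Note on the renamed field: the hunk above only switches the RotaryEmbedding keyword to read from config.rope_seq_len_interpolation_factor. For readers unfamiliar with the knob, here is a minimal sketch of how such a sequence-length interpolation factor is commonly applied in RoPE, by compressing position indices before computing the rotary angles. The function name and signature below are illustrative assumptions, not this repository's RotaryEmbedding API.

# Illustrative sketch only (assumed semantics, not this repo's RotaryEmbedding):
# a seq_len interpolation factor compresses position indices so that sequences
# longer than the pretraining context map back into the original angle range.
from typing import Optional

import torch


def rope_angles(
    position_ids: torch.Tensor,  # 1-D tensor of token positions
    head_dim: int,
    base: float = 10000.0,
    seq_len_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = position_ids.to(torch.float32)
    if seq_len_scaling_factor is not None:
        positions = positions / seq_len_scaling_factor  # position interpolation
    return torch.outer(positions, inv_freq)  # [num_positions, head_dim // 2]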
@@ -238,28 +238,28 @@ def forward(

         if self._use_qkv_packed:
             attn_output = self._forward_packed(qkv, seq_length, position_ids, cu_seqlens)
-        # else:
-        #     q, k, v = qkv.split(
-        #         [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1
-        #     )  # [batch_size*seq_length, q_size], [batch_size*seq_length, kv_size]
-        #     q = q.view(-1, self.local_num_heads, self.head_dim)  # [b*s, num_heads, head_dim]
-        #     k = k.view(-1, self.local_num_kv_heads, self.head_dim)  # [b*s, num_kv_heads, head_dim]
-        #     v = v.view(-1, self.local_num_kv_heads, self.head_dim)  # [b*s, num_kv_heads, head_dim]
-        #     if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0:
-        #         rotary_pos_emb = self.rotary_emb(
-        #             position_ids=position_ids if not self.simple_causal_mask else None, seq_length=seq_length
-        #         )  # [b*s, dim] or [seq_length, dim]
-        #         q = self.rotary_emb.apply_rotary_pos_emb(
-        #             q, rotary_pos_emb, seq_length=seq_length
-        #         )  # [b*s, num_heads, head_dim]
-        #         k = self.rotary_emb.apply_rotary_pos_emb(
-        #             k, rotary_pos_emb, seq_length=seq_length
-        #         )  # [b*s, num_kv_heads, head_dim]
-        #     else:
-        #         log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0)
-        #     attn_output = self.attention(
-        #         q, k, v, position_ids=position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens
-        #     )
+        else:
+            q, k, v = qkv.split(
+                [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1
+            )  # [batch_size*seq_length, q_size], [batch_size*seq_length, kv_size]
+            q = q.view(-1, self.local_num_heads, self.head_dim)  # [b*s, num_heads, head_dim]
+            k = k.view(-1, self.local_num_kv_heads, self.head_dim)  # [b*s, num_kv_heads, head_dim]
+            v = v.view(-1, self.local_num_kv_heads, self.head_dim)  # [b*s, num_kv_heads, head_dim]
+            if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0:
+                rotary_pos_emb = self.rotary_emb(
+                    position_ids=position_ids if not self.simple_causal_mask else None, seq_length=seq_length
+                )  # [b*s, dim] or [seq_length, dim]
+                q = self.rotary_emb.apply_rotary_pos_emb(
+                    q, rotary_pos_emb, seq_length=seq_length
+                )  # [b*s, num_heads, head_dim]
+                k = self.rotary_emb.apply_rotary_pos_emb(
+                    k, rotary_pos_emb, seq_length=seq_length
+                )  # [b*s, num_kv_heads, head_dim]
+            else:
+                log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0)
+            attn_output = self.attention(
+                q, k, v, position_ids=position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens
+            )
         output = self.o_proj(attn_output)
         # Return original position_ids shape
         return {"hidden_states": output, "position_ids": position_ids.view(-1, seq_length)}
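For reference, the restored branch applies RoPE unless (self.layer_idx + 1) is a multiple of config.no_rope_layer, i.e. every no_rope_layer-th layer runs without rotary embeddings. Below is a standalone restatement of that condition; the helper name is illustrative, not part of the codebase.

# Standalone restatement of the no_rope_layer check used in the restored branch.
from typing import Optional


def applies_rope(layer_idx: int, no_rope_layer: Optional[int]) -> bool:
    return no_rope_layer is None or (layer_idx + 1) % no_rope_layer != 0


# With no_rope_layer = 4, every 4th layer (0-indexed layers 3, 7, ...) skips RoPE.
assert [applies_rope(i, 4) for i in range(8)] == [True, True, True, False, True, True, True, False]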