@@ -261,48 +261,98 @@ def __call__(
261261
def _native_attention(self, q, k, v, forward_batch: ForwardBatch):
    """Native attention for encoder/cross-attention with T5 position bias.

    The attention math runs inside ``jax.shard_map`` over ``self.mesh``:
    ``q``/``k``/``v`` are partitioned with ``P("data", "tensor")`` (first
    axis over the ``"data"`` mesh axis, last axis over ``"tensor"``), the
    sequence-length vectors with ``P("data")``, and the relative-bias table
    with ``P(None, "tensor")``.

    Args:
        q, k, v: 2-D arrays whose last axis is reshaped into
            ``(heads, self.d_kv)`` per shard.
        forward_batch: Batch metadata; ``extend_seq_lens`` (falling back to
            ``seq_lens``) gives the query lengths and, for cross-attention,
            ``encoder_seq_lens`` gives the key/value lengths.

    Returns:
        The attention output assembled with ``out_specs=P("data", "tensor")``;
        each shard contributes ``[local_tokens, local_hidden]``.
    """
    # T5 uses d_kv as the head dimension, not hidden // n_heads.
    head_dim = self.d_kv
    # Read module attributes once here so the shard_map body only closes
    # over plain Python values.
    is_cross_attn = self.is_cross_attention
    has_rel_bias = hasattr(self, "rel_bias")
    # Only decoder self-attention is causal.
    is_causal = self.is_decoder and not is_cross_attn

    # Get sequence lengths
    q_lens = getattr(forward_batch, "extend_seq_lens", forward_batch.seq_lens)
    # Fallback if seq_lens is None: assume single sequence
    if q_lens is None:
        q_lens = jnp.array([q.shape[0]], dtype=jnp.int32)

    # Cross-attention attends over the encoder sequences; self-attention
    # reuses the query lengths.
    kv_lens = (
        getattr(forward_batch, "encoder_seq_lens", q_lens)
        if is_cross_attn
        else q_lens
    )

    # The raw embedding table is passed explicitly so shard_map can
    # partition it; it stays None when the layer has no relative bias.
    rel_bias_weight = self.rel_bias.embedding.value if has_rel_bias else None

    # Wrap computation in shard_map for data parallelism
    in_specs = (
        P("data", "tensor"),  # q
        P("data", "tensor"),  # k
        P("data", "tensor"),  # v
        P("data"),  # q_lens
        P("data"),  # kv_lens
        P(None, "tensor"),  # rel_bias_weight
    )
    out_specs = P("data", "tensor")

    def _compute_attention(
        q_local, k_local, v_local, q_lens_local, kv_lens_local, rel_weight
    ):
        # Derive the head count from the local shard width rather than
        # self.n_heads, because the last axis is split over "tensor".
        local_hidden = q_local.shape[-1]
        local_n_heads = local_hidden // head_dim

        # Reshape to [heads, tokens, head_dim]
        def to_heads(x):
            n_tok = x.shape[0]
            return jnp.transpose(
                x.reshape(n_tok, local_n_heads, head_dim), (1, 0, 2)
            )

        q_h, k_h, v_h = to_heads(q_local), to_heads(k_local), to_heads(v_local)

        # Compute scores in float32
        scores = jnp.einsum(
            "hqd,hkd->hqk", q_h.astype(jnp.float32), k_h.astype(jnp.float32)
        )

        # Add position bias for self-attention (T5-specific)
        if not is_cross_attn and has_rel_bias:
            pos_bias = self._compute_position_bias(
                q_lens_local, q_local.shape[0], k_local.shape[0], rel_weight
            )
            scores = scores + pos_bias.astype(jnp.float32)

        # Apply block_diagonal_mask
        scores = _apply_block_diagonal_mask(
            scores, q_lens_local, kv_lens_local, is_causal=is_causal
        )

        # Softmax and weighted sum
        weights = jax.nn.softmax(scores, axis=-1)
        out = jnp.einsum("hqk,hkd->hqd", weights, v_h.astype(jnp.float32))

        # Back to [tokens, heads * head_dim] for this shard.
        return jnp.transpose(out, (1, 0, 2)).reshape(
            q_local.shape[0], local_hidden
        )

    return jax.shard_map(
        _compute_attention,
        mesh=self.mesh,
        in_specs=in_specs,
        out_specs=out_specs,
        check_vma=False,
    )(q, k, v, q_lens, kv_lens, rel_bias_weight)
304354
305- def _compute_position_bias (self , seq_lens , q_len , k_len ):
355+ def _compute_position_bias (self , seq_lens , q_len , k_len , rel_weight ):
306356 """Compute T5 position bias [heads, q_len, k_len]."""
307357 starts = jnp .cumsum (seq_lens ) - seq_lens
308358 indicators = jnp .zeros (q_len , dtype = jnp .int32 ).at [starts ].set (1 )
@@ -318,7 +368,8 @@ def _compute_position_bias(self, seq_lens, q_len, k_len):
318368 num_buckets = self .num_buckets ,
319369 max_distance = self .max_distance ,
320370 )
321- return jnp .transpose (self .rel_bias (buckets ), (2 , 0 , 1 ))
371+ bias = rel_weight [buckets ]
372+ return jnp .transpose (bias , (2 , 0 , 1 ))
322373
323374
324375# =============================================================================
0 commit comments