 import jax
 import jax.numpy as jnp
 from flax import nnx
-from jax.sharding import NamedSharding
 from jax.sharding import PartitionSpec as P
 from transformers import UMT5Config

@@ -261,48 +260,98 @@ def __call__(

     def _native_attention(self, q, k, v, forward_batch: ForwardBatch):
         """Native attention for encoder/cross-attention with T5 position bias."""
-        num_tokens, hidden = q.shape[0], q.shape[-1]
-        head_dim = hidden // self.n_heads
-
-        # Reshape to [heads, tokens, head_dim]
-        def to_heads(x):
-            n_tok = x.shape[0]
-            return jnp.transpose(x.reshape(n_tok, self.n_heads, head_dim), (1, 0, 2))
-
-        q_h, k_h, v_h = to_heads(q), to_heads(k), to_heads(v)
-
-        # Compute scores in float32
-        scores = jnp.einsum("hqd,hkd->hqk", q_h.astype(jnp.float32), k_h.astype(jnp.float32))
+        hidden = q.shape[-1]
+        head_dim = self.d_kv  # T5 uses d_kv as head dimension, not hidden // n_heads
+        n_heads = self.n_heads  # Capture as local variable for closure
+        is_cross_attn = self.is_cross_attention  # Capture as local variable
+        has_rel_bias = hasattr(self, "rel_bias")  # Capture as local variable
+
+        # Debug: print dimensions
+        jax.debug.print(
+            "UMT5 _native_attention: q.shape={q_shape}, hidden={hidden}, d_kv={d_kv}, n_heads={n_heads}, inner_dim={inner_dim}",
+            q_shape=q.shape,
+            hidden=hidden,
+            d_kv=head_dim,
+            n_heads=n_heads,
+            inner_dim=self.inner_dim,
+        )

         # Get sequence lengths
         q_lens = getattr(forward_batch, "extend_seq_lens", forward_batch.seq_lens)
         # Fallback if seq_lens is None: assume single sequence
         if q_lens is None:
             q_lens = jnp.array([q.shape[0]], dtype=jnp.int32)

-        # Add position bias for self-attention (T5-specific)
-        if not self.is_cross_attention and hasattr(self, "rel_bias"):
-            pos_bias = self._compute_position_bias(q_lens, q.shape[0], k.shape[0])
-            scores = scores + pos_bias.astype(jnp.float32)
+        rel_bias_weight = self.rel_bias.embedding.value if hasattr(self, "rel_bias") else None

-        # Apply masking
         kv_lens = (
             getattr(forward_batch, "encoder_seq_lens", q_lens)
             if self.is_cross_attention
             else q_lens
         )
         is_causal = self.is_decoder and not self.is_cross_attention

-        # Apply block_diagonal_mask
-        scores = _apply_block_diagonal_mask(scores, q_lens, kv_lens, is_causal=is_causal)
+        # Wrap computation in shard_map for data parallelism
+        in_specs = (
+            P("data", "tensor"),  # q
+            P("data", "tensor"),  # k
+            P("data", "tensor"),  # v
+            P("data"),  # q_lens
+            P("data"),  # kv_lens
+            P(None, "tensor"),  # rel_bias_weight
+        )
+        out_specs = P("data", "tensor")
+
+        def _compute_attention(q_local, k_local, v_local, q_lens_local, kv_lens_local, rel_weight):
+            # Debug: print local shapes inside shard_map
+            jax.debug.print(
+                "Inside shard_map: q_local.shape={q_shape}, n_heads={n_heads}, head_dim={head_dim}",
+                q_shape=q_local.shape,
+                n_heads=n_heads,
+                head_dim=head_dim,
+            )
+            local_n_heads = q_local.shape[-1] // head_dim
+            local_hidden = q_local.shape[-1]
+
+            # Reshape to [heads, tokens, head_dim]
+            def to_heads(x):
+                n_tok = x.shape[0]
+                return jnp.transpose(x.reshape(n_tok, local_n_heads, head_dim), (1, 0, 2))
+
+            q_h, k_h, v_h = to_heads(q_local), to_heads(k_local), to_heads(v_local)

-        # Softmax and weighted sum
-        weights = jax.nn.softmax(scores, axis=-1)
-        out = jnp.einsum("hqk,hkd->hqd", weights, v_h.astype(jnp.float32))
+            # Compute scores in float32
+            scores = jnp.einsum("hqd,hkd->hqk", q_h.astype(jnp.float32), k_h.astype(jnp.float32))

-        return jnp.transpose(out, (1, 0, 2)).reshape(num_tokens, hidden)
+            # Add position bias for self-attention (T5-specific)
+            if not is_cross_attn and has_rel_bias:
+                pos_bias = self._compute_position_bias(
+                    q_lens_local, q_local.shape[0], k_local.shape[0], rel_weight
+                )
+                scores = scores + pos_bias.astype(jnp.float32)

-    def _compute_position_bias(self, seq_lens, q_len, k_len):
+            # Apply block_diagonal_mask
+            scores = _apply_block_diagonal_mask(
+                scores, q_lens_local, kv_lens_local, is_causal=is_causal
+            )
+
+            # Softmax and weighted sum
+            weights = jax.nn.softmax(scores, axis=-1)
+            out = jnp.einsum("hqk,hkd->hqd", weights, v_h.astype(jnp.float32))
+
+            return jnp.transpose(out, (1, 0, 2)).reshape(q_local.shape[0], local_hidden)
+
+        result = jax.shard_map(
+            _compute_attention,
+            mesh=self.mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_vma=False,
+        )(q, k, v, q_lens, kv_lens, rel_bias_weight)
+
+        return result
+
+    def _compute_position_bias(self, seq_lens, q_len, k_len, rel_weight):
         """Compute T5 position bias [heads, q_len, k_len]."""
         starts = jnp.cumsum(seq_lens) - seq_lens
         indicators = jnp.zeros(q_len, dtype=jnp.int32).at[starts].set(1)
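
This hunk moves the whole attention body into a per-device function and invokes it through `jax.shard_map`, so each device only ever sees its own slice of tokens (the `"data"` axis) and heads (the `"tensor"` axis). Not part of the commit, but as a reference for the pattern, here is a minimal standalone sketch with a hypothetical mesh and function (`scale_local`); it assumes `jax.shard_map` with the same keyword signature used above and a data axis spanning all local devices:

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# Hypothetical stand-in for self.mesh: a ("data", "tensor") mesh over all local devices.
mesh = jax.make_mesh((jax.device_count(), 1), ("data", "tensor"))

def scale_local(x_local):
    # Runs once per device; x_local is only this shard's [local_tokens, local_hidden] block.
    return x_local * 2.0

x = jnp.ones((4 * mesh.shape["data"], 16), dtype=jnp.float32)
y = jax.shard_map(
    scale_local,
    mesh=mesh,
    in_specs=P("data", "tensor"),   # how the inputs are split across the mesh
    out_specs=P("data", "tensor"),  # how the per-device outputs are reassembled
)(x)
print(y.shape)  # global (8 * data_axis_size // 2, 16) shape, stitched back together
```

The same reasoning explains why `_compute_attention` derives `local_n_heads` from `q_local.shape[-1] // head_dim` inside the mapped function: with the hidden dimension split over `"tensor"`, each shard holds only a fraction of the heads.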
@@ -318,7 +367,8 @@ def _compute_position_bias(self, seq_lens, q_len, k_len):
318367 num_buckets = self .num_buckets ,
319368 max_distance = self .max_distance ,
320369 )
321- return jnp .transpose (self .rel_bias (buckets ), (2 , 0 , 1 ))
370+ bias = rel_weight [buckets ]
371+ return jnp .transpose (bias , (2 , 0 , 1 ))
322372
323373
324374# =============================================================================
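
For reference, the `buckets` array that `rel_weight[buckets]` indexes follows the standard T5 relative-position bucketing (exact buckets for nearby offsets, log-spaced buckets out to `max_distance`). The sketch below is not from this file; `relative_position_bucket` is an illustrative name for whatever helper the module uses with the `num_buckets`/`max_distance` arguments shown above:

```python
import jax.numpy as jnp

def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    # Standard T5 bucketing: exact buckets for small offsets, log-spaced buckets for large ones.
    relative_buckets = jnp.zeros_like(relative_position)
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).astype(jnp.int32) * num_buckets
        relative_position = jnp.abs(relative_position)
    else:
        relative_position = -jnp.minimum(relative_position, 0)
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # Log-spaced bucket for large distances (clamped to avoid log(0); only used where is_small is False).
    scaled = jnp.log(jnp.maximum(relative_position, 1).astype(jnp.float32) / max_exact)
    scaled = scaled / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
    large_bucket = jnp.minimum(max_exact + scaled.astype(jnp.int32), num_buckets - 1)
    return relative_buckets + jnp.where(is_small, relative_position, large_bucket)

# Bucket ids for a [q_len, k_len] grid; rel_weight[buckets] then has shape
# [q_len, k_len, n_heads] and is transposed to [heads, q_len, k_len] as above.
q_len, k_len = 5, 5
buckets = relative_position_bucket(jnp.arange(k_len)[None, :] - jnp.arange(q_len)[:, None])
```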
@@ -467,8 +517,9 @@ def __call__(self, forward_batch: ForwardBatch, token_to_kv_pool=None, logits_me

         # Dummy logits for interface compatibility
         bs = forward_batch.seq_lens.shape[0]
-        dummy = jnp.zeros((bs, self.config.vocab_size), dtype=self.dtype)
-        dummy = jax.sharding.reshard(dummy, NamedSharding(self.mesh, P(None, "tensor")))
+        dummy = jnp.zeros(
+            (bs, self.config.vocab_size), dtype=self.dtype, out_sharding=P("data", "tensor")
+        )
         return LogitsProcessorOutput(next_token_logits=dummy, hidden_states=hidden), [], [], None

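
This last hunk swaps the allocate-then-reshard pattern for a direct sharded allocation via `out_sharding`. A hedged sketch of that idea, assuming a JAX version whose array-creation functions accept `out_sharding` and an already-active explicit mesh with `"data"`/`"tensor"` axes; `make_dummy_logits` is a hypothetical helper name, not something from the commit:

```python
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def make_dummy_logits(bs, vocab_size, dtype=jnp.bfloat16):
    # Allocate directly in the target layout instead of building a replicated
    # array and resharding it afterwards (the pattern the hunk removes).
    return jnp.zeros((bs, vocab_size), dtype=dtype, out_sharding=P("data", "tensor"))
```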