@@ -524,14 +524,31 @@ def _apply_transpose(
524524
525525
526526def _align_shape (
527- val : jnp .ndarray , tgt_shape : Tuple [int , ...], src_key : str
527+ val : jnp .ndarray ,
528+ tgt_shape : Tuple [int , ...],
529+ src_key : str ,
530+ rollout_engine : Optional [str ] = None ,
531+ ** kwargs ,
528532) -> jnp .ndarray :
529533 """Align source value shape to target shape through padding or repeating.
530534
535+ This function attempts to align the shape of a source JAX array (`val`) to a
536+ target shape (`tgt_shape`). It supports alignment by:
537+ 1. Reshaping: If the total number of elements matches the target's,
538+ commonly for attention biases and projections.
539+ 2. Padding/Repeating: For attention-related weights, it can pad the head
540+ dimension or repeat along the number of heads dimension.
541+ 3. Special Handling: Includes specific logic for 1-D KV biases in
542+ 'sglang_jax' rollout.
543+
531544 Args:
532545 val: Source value.
533546 tgt_shape: Target shape.
534547 src_key: Source key for error messages.
548+ rollout_engine: Optional string indicating the rollout engine, used for
549+ special-casing certain alignments (e.g., 'sglang_jax').
550+ **kwargs: Additional keyword arguments; must include 'num_kv_heads' and
551+ 'head_dim' when the 'sglang_jax' 1-D KV bias path is taken.
535552
536553 Returns:
537554 Shape-aligned value.
@@ -596,6 +613,27 @@ def _align_shape(
596613 raise ShapeMismatchError (
597614 f'Rank mismatch for { src_key } : { val .shape } vs { tgt_shape } '
598615 )
616+ elif rollout_engine == 'sglang_jax' and re .compile (
617+ r'layers\..*\.attn\.(k|v)_bias'
618+ ).match (src_key ):
619+ logging .debug (
620+ 'Handling 1-D KV bias for %s in SGLangJAX rollout.' , src_key
621+ )
622+ assert tgt_shape [0 ] > val .shape [0 ] and tgt_shape [0 ] % val .shape [0 ] == 0 , (
623+ f'Unexpected attention bias shape: { val .shape } and target shape:'
624+ f' { tgt_shape } '
625+ )
626+ repeat_factor = tgt_shape [0 ] // val .shape [0 ]
627+ logging .debug (
628+ 'Replicating 1-D KV bias on %s: %s -> %s (repeat x%d per head)' ,
629+ src_key ,
630+ val .shape ,
631+ tgt_shape ,
632+ repeat_factor ,
633+ )
634+ val_2d = jnp .reshape (val , (kwargs ['num_kv_heads' ], kwargs ['head_dim' ]))
635+ val_2d = jnp .repeat (val_2d , repeat_factor , axis = 0 )
636+ return jnp .reshape (val_2d , tgt_shape )
599637
600638 attention_patterns = [
601639 r'.*(q|k|v|o)_proj.*' ,
@@ -680,6 +718,7 @@ def transfer_state_with_mappings(
680718 transpose_keys = None ,
681719 reshard_fn = None ,
682720 rollout_engine = None ,
721+ ** kwargs ,
683722):
684723 """Transfer state using mappings, with optional transpose and shard logic.
685724
@@ -695,6 +734,8 @@ def transfer_state_with_mappings(
695734 transpose_keys: A dictionary defining which keys to transpose and the
696735 corresponding axes to transpose.
697736 reshard_fn: A function to shard the value.
737+ rollout_engine: The name of the rollout engine being used.
738+ **kwargs: Additional keyword arguments forwarded to _align_shape.
698739
699740 Returns:
700741 The target state with the transferred values.
@@ -722,7 +763,7 @@ def transfer_state_with_mappings(
722763 unscanned_src_to_tgt_flat = _unroll_scanned_layers (src_state , src_to_tgt_map )
723764
724765 # Transfer values with transformations
725- for (flat_src_key , tgt_key ), (
766+ for (flat_src_key , _ ), (
726767 val ,
727768 tgt_param ,
728769 ) in unscanned_src_to_tgt_flat .items ():
@@ -734,7 +775,9 @@ def transfer_state_with_mappings(
734775 val = key_mapping_hook_fns [flat_src_key ](val )
735776
736777 # Align shapes (padding/repeating as needed)
737- val = _align_shape (val , tgt_param .value .shape , flat_src_key )
778+ val = _align_shape (
779+ val , tgt_param .value .shape , flat_src_key , rollout_engine , ** kwargs
780+ )
738781
739782 # Cast to target dtype
740783 val = _apply_dtype_cast (val , tgt_param .value .dtype , flat_src_key )