1818from levanter .models .lm_model import LmHeadModel
1919
2020
21+ def _get_nnx_key_name (split_key : list [str ]) -> str :
22+ """
23+ Determine the NNX key name from the split Levanter key.
24+ If the key ends in 'bias', append '_bias' to the parameter name.
25+ Otherwise (e.g. 'weight'), use the parameter name directly.
26+ """
27+ key_name = split_key [- 2 ]
28+ if split_key [- 1 ] == "bias" :
29+ key_name = f"{ key_name } _bias"
30+ return key_name
31+
32+
2133def levanter_to_nnx_state (levanter_model : LmHeadModel ) -> dict :
2234 # The format of this state dict is flat like:
2335 # model.layers.0.self_attn.q_proj.weight -> jax array
@@ -46,7 +58,7 @@ def levanter_to_nnx_state(levanter_model: LmHeadModel) -> dict:
4658 # vLLM expects the weights to be padded to the next multiple of 128. I assume this is
4759 # because they want to use Pallas kernels which have this requirement.
4860 if "self_attn" in split_key_without_weight :
49- if "q_proj" in split_key_without_weight :
61+ if "q_proj" in split_key_without_weight and len ( value . shape ) == 4 :
5062 kv_heads , q_heads_per_group , head_size , embed = value .shape
5163 value = value .reshape (kv_heads * q_heads_per_group , head_size , embed )
5264
@@ -67,7 +79,7 @@ def levanter_to_nnx_state(levanter_model: LmHeadModel) -> dict:
6779 # pad 3rd dimension to 128 (e.g., (8, 2048, 64) -> (8, 2048, 128))
6880 value = jnp .pad (value , ((0 , 0 ), (0 , 0 ), (0 , next_multiple_of_128 - head_size )))
6981
70- current [split_key_without_weight [ - 1 ] ] = nnx .Param (value )
82+ current [_get_nnx_key_name ( split_key ) ] = nnx .Param (value )
7183 return nnx .State (nested_state_dict )
7284
7385
@@ -89,31 +101,46 @@ def levanter_state_dict_to_nnx_state_on_cpu(state_dict: dict) -> dict:
89101 current [part ] = {}
90102 current = current [part ]
91103
92- # for q, k, v projections, we need to pad the 2nd dimension to next multiple of 128
93- # vLLM expects the weights to be padded to the next multiple of 128. I assume this is
94- # because they want to use Pallas kernels which have this requirement.
104+ # vLLM requires weights/biases to be padded to the nearest multiple of 128 for Pallas kernels.
95105 if "self_attn" in split_key_without_weight :
106+ is_bias = split_key [- 1 ] == "bias"
107+
108+ # Flatten grouped query heads -> (Total Heads, Head Dim, [Embed]) for vLLM
96109 if "q_proj" in split_key_without_weight :
97- kv_heads , q_heads_per_group , head_size , embed = value .shape
98- value = value .reshape (kv_heads * q_heads_per_group , head_size , embed )
110+ if len (value .shape ) == 4 :
111+ # Weight: (KV, Group, HeadSize, Embed) -> (Heads, HeadSize, Embed)
112+ kv_heads , q_heads_per_group , head_size , embed = value .shape
113+ value = value .reshape (kv_heads * q_heads_per_group , head_size , embed )
114+ elif len (value .shape ) == 3 and is_bias :
115+ # Bias: (KV, Group, HeadSize) -> (Heads, HeadSize)
116+ kv_heads , q_heads_per_group , head_size = value .shape
117+ value = value .reshape (kv_heads * q_heads_per_group , head_size )
99118
119+ # Pad the head dimension (dim 1) for Q/K/V projections
100120 if (
101121 "q_proj" in split_key_without_weight
102122 or "k_proj" in split_key_without_weight
103123 or "v_proj" in split_key_without_weight
104124 ):
105- _heads , head_size , embed = value .shape
106- next_multiple_of_128 = ((head_size + 127 ) // 128 ) * 128
107- if head_size < next_multiple_of_128 :
108- # pad 2nd dimension to 128 (e.g., (8, 64, 2048) -> (8, 128, 2048))
109- value = jnp .pad (value , ((0 , 0 ), (0 , next_multiple_of_128 - head_size ), (0 , 0 )))
125+ pad_axis = 1
126+ if len (value .shape ) >= 2 :
127+ head_size = value .shape [pad_axis ]
128+ next_multiple_of_128 = ((head_size + 127 ) // 128 ) * 128
129+
130+ if head_size < next_multiple_of_128 :
131+ padding = [(0 , 0 )] * len (value .shape )
132+ padding [pad_axis ] = (0 , next_multiple_of_128 - head_size )
133+ value = jnp .pad (value , padding )
134+
135+ # Pad o_proj weights along the head dimension (dim 2)
110136 elif "o_proj" in split_key_without_weight :
111- embed , _heads , head_size = value .shape
112- next_multiple_of_128 = ((head_size + 127 ) // 128 ) * 128
113- if head_size < next_multiple_of_128 :
114- # pad 3rd dimension to 128 (e.g., (8, 2048, 64) -> (8, 2048, 128))
115- value = jnp .pad (value , ((0 , 0 ), (0 , 0 ), (0 , next_multiple_of_128 - head_size )))
137+ # Weight: (Embed, Heads, HeadSize). Skip bias as it is 1D (Embed,) or handled differently.
138+ if not is_bias and len (value .shape ) == 3 :
139+ embed , _heads , head_size = value .shape
140+ next_multiple_of_128 = ((head_size + 127 ) // 128 ) * 128
141+ if head_size < next_multiple_of_128 :
142+ value = jnp .pad (value , ((0 , 0 ), (0 , 0 ), (0 , next_multiple_of_128 - head_size )))
116143
117- current [split_key_without_weight [ - 1 ] ] = nnx .Param (value )
144+ current [_get_nnx_key_name ( split_key ) ] = nnx .Param (value )
118145
119146 return nnx .State (nested_state_dict )
0 commit comments