@@ -165,22 +165,21 @@ def replace_buffer(self, buffers) -> None:
 
         assert len(new_recurrent) == self.num_linear_recurrent_layers
         assert len(new_conv) == self.num_linear_recurrent_layers
-        for layer in range(self.num_linear_recurrent_layers):
-            assert len(new_conv[layer]) == len(self.conv_buffers[layer])
 
+        # tp_size==1 sharding fix: see MHATokenToKVPool.replace_buffer
         tp_degenerate = self.mesh.shape.get("tensor", 1) == 1
         for layer in range(self.num_linear_recurrent_layers):
             buf = new_recurrent[layer]
-            if tp_degenerate and hasattr(self, "recurrent_sharding"):
+            if tp_degenerate:
                 buf = jax.device_put(buf, self.recurrent_sharding)
             self.recurrent_buffers[layer] = buf
 
-        for layer in range(self.num_linear_recurrent_layers):
+            assert len(new_conv[layer]) == len(self.conv_buffers[layer])
             for i in range(len(new_conv[layer])):
-                buf = new_conv[layer][i]
-                if tp_degenerate and hasattr(self, "conv_sharding"):
-                    buf = jax.device_put(buf, self.conv_sharding)
-                self.conv_buffers[layer][i] = buf
+                cbuf = new_conv[layer][i]
+                if tp_degenerate:
+                    cbuf = jax.device_put(cbuf, self.conv_sharding)
+                self.conv_buffers[layer][i] = cbuf
 
     def clear(self) -> None:
         for layer in range(self.num_linear_recurrent_layers):
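The hunk above drops the hasattr guards and re-shards unconditionally whenever the mesh's "tensor" axis is degenerate (size 1). Below is a minimal standalone sketch of that device_put step; the mesh layout, axis names, and buffer shape are illustrative assumptions, not the pool's actual configuration.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Mesh with a real "data" axis and a degenerate "tensor" axis (size 1),
# mirroring the tp_degenerate case handled in replace_buffer above.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "tensor"))
recurrent_sharding = NamedSharding(mesh, P("data", "tensor", None, None))

tp_degenerate = mesh.shape.get("tensor", 1) == 1
buf = jnp.zeros((devices.shape[0], 4, 8, 16))  # illustrative buffer shape
if tp_degenerate:
    # device_put onto the pool's NamedSharding normalizes whatever layout
    # the incoming buffer arrived with.
    buf = jax.device_put(buf, recurrent_sharding)
print(buf.sharding)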
@@ -195,7 +194,6 @@ def tree_flatten(self):
             tuple(self.linear_recurrent_layer_ids),
             self.size,
             self.dp_size,
-            self.total_slots,
             self.num_heads,
             self.head_dim,
             self.num_k_heads,
@@ -207,8 +205,6 @@ def tree_flatten(self):
             self.recurrent_partition_axis,
             self.conv_partition_axis,
             self.data_partition_axis,
-            self.recurrent_sharding,
-            self.conv_sharding,
         )
         return children, aux
 
@@ -218,7 +214,6 @@ def tree_unflatten(cls, aux_data, children):
             linear_recurrent_layer_ids_tup,
             size,
             dp_size,
-            total_slots,
             num_heads,
             head_dim,
             num_k_heads,
@@ -230,8 +225,6 @@ def tree_unflatten(cls, aux_data, children):
             recurrent_partition_axis,
             conv_partition_axis,
             data_partition_axis,
-            recurrent_sharding,
-            conv_sharding,
         ) = aux_data
         obj = cls.__new__(cls)
         obj.linear_recurrent_layer_ids = list(linear_recurrent_layer_ids_tup)
@@ -241,7 +234,7 @@ def tree_unflatten(cls, aux_data, children):
         obj.num_linear_recurrent_layers = len(obj.linear_recurrent_layer_ids)
         obj.size = size
         obj.dp_size = dp_size
-        obj.total_slots = total_slots
+        obj.total_slots = _ceil_to(size + 1, dp_size)
         obj.num_heads = num_heads
         obj.head_dim = head_dim
         obj.num_k_heads = num_k_heads
@@ -256,8 +249,10 @@ def tree_unflatten(cls, aux_data, children):
         obj.recurrent_partition_axis = recurrent_partition_axis
         obj.conv_partition_axis = conv_partition_axis
         obj.data_partition_axis = data_partition_axis
-        obj.recurrent_sharding = recurrent_sharding
-        obj.conv_sharding = conv_sharding
+        obj.recurrent_sharding = NamedSharding(
+            mesh, P(data_partition_axis, recurrent_partition_axis, None, None)
+        )
+        obj.conv_sharding = NamedSharding(mesh, P(data_partition_axis, conv_partition_axis, None))
         new_recurrent, new_conv = children
         obj.recurrent_buffers = list(new_recurrent)
         obj.conv_buffers = [list(inner) for inner in new_conv]
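For context, the tree_flatten/tree_unflatten hunks above stop carrying total_slots and the two NamedSharding objects in aux_data and instead rebuild them from the stored configuration on unflatten. Below is a minimal sketch of that pattern with a toy container; TinyPool and its fields are invented for illustration, and _ceil_to is given an assumed round-up definition since its real body is not shown in this diff.

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


def _ceil_to(x: int, multiple: int) -> int:
    # assumed definition: round x up to the nearest multiple
    return ((x + multiple - 1) // multiple) * multiple


@register_pytree_node_class
class TinyPool:
    def __init__(self, size, dp_size, buffers):
        self.size = size
        self.dp_size = dp_size
        self.total_slots = _ceil_to(size + 1, dp_size)
        self.buffers = list(buffers)

    def tree_flatten(self):
        # Array leaves go in children; only static config goes in aux_data.
        return (self.buffers,), (self.size, self.dp_size)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        size, dp_size = aux_data
        obj = cls.__new__(cls)
        obj.size = size
        obj.dp_size = dp_size
        # Derived value is recomputed here rather than stored in aux_data.
        obj.total_slots = _ceil_to(size + 1, dp_size)
        (buffers,) = children
        obj.buffers = list(buffers)
        return obj


pool = TinyPool(10, 4, [jnp.zeros((4, 2))])
leaves, treedef = jax.tree_util.tree_flatten(pool)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
assert restored.total_slots == pool.total_slots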