@@ -194,6 +194,7 @@ def tree_flatten(self):
194194 tuple (self .linear_recurrent_layer_ids ),
195195 self .size ,
196196 self .dp_size ,
197+ self .total_slots ,
197198 self .num_heads ,
198199 self .head_dim ,
199200 self .num_k_heads ,
@@ -205,6 +206,8 @@ def tree_flatten(self):
205206 self .recurrent_partition_axis ,
206207 self .conv_partition_axis ,
207208 self .data_partition_axis ,
209+ self .recurrent_sharding ,
210+ self .conv_sharding ,
208211 )
209212 return children , aux
210213
@@ -214,6 +217,7 @@ def tree_unflatten(cls, aux_data, children):
214217 linear_recurrent_layer_ids_tup ,
215218 size ,
216219 dp_size ,
220+ total_slots ,
217221 num_heads ,
218222 head_dim ,
219223 num_k_heads ,
@@ -225,6 +229,8 @@ def tree_unflatten(cls, aux_data, children):
225229 recurrent_partition_axis ,
226230 conv_partition_axis ,
227231 data_partition_axis ,
232+ recurrent_sharding ,
233+ conv_sharding ,
228234 ) = aux_data
229235 obj = cls .__new__ (cls )
230236 obj .linear_recurrent_layer_ids = list (linear_recurrent_layer_ids_tup )
@@ -234,7 +240,7 @@ def tree_unflatten(cls, aux_data, children):
234240 obj .num_linear_recurrent_layers = len (obj .linear_recurrent_layer_ids )
235241 obj .size = size
236242 obj .dp_size = dp_size
237- obj .total_slots = _ceil_to ( size + 1 , dp_size )
243+ obj .total_slots = total_slots
238244 obj .num_heads = num_heads
239245 obj .head_dim = head_dim
240246 obj .num_k_heads = num_k_heads
@@ -249,10 +255,8 @@ def tree_unflatten(cls, aux_data, children):
249255 obj .recurrent_partition_axis = recurrent_partition_axis
250256 obj .conv_partition_axis = conv_partition_axis
251257 obj .data_partition_axis = data_partition_axis
252- obj .recurrent_sharding = NamedSharding (
253- mesh , P (data_partition_axis , recurrent_partition_axis , None , None )
254- )
255- obj .conv_sharding = NamedSharding (mesh , P (data_partition_axis , conv_partition_axis , None ))
258+ obj .recurrent_sharding = recurrent_sharding
259+ obj .conv_sharding = conv_sharding
256260 new_recurrent , new_conv = children
257261 obj .recurrent_buffers = list (new_recurrent )
258262 obj .conv_buffers = [list (inner ) for inner in new_conv ]
0 commit comments