Skip to content

Commit d9ce8b1

Browse files
committed
revert: restore total_slots/sharding in pytree aux_data
1 parent b7cc7fb commit d9ce8b1

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

python/sgl_jax/srt/mem_cache/recurrent_state_pool.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -194,6 +194,7 @@ def tree_flatten(self):
194194
tuple(self.linear_recurrent_layer_ids),
195195
self.size,
196196
self.dp_size,
197+
self.total_slots,
197198
self.num_heads,
198199
self.head_dim,
199200
self.num_k_heads,
@@ -205,6 +206,8 @@ def tree_flatten(self):
205206
self.recurrent_partition_axis,
206207
self.conv_partition_axis,
207208
self.data_partition_axis,
209+
self.recurrent_sharding,
210+
self.conv_sharding,
208211
)
209212
return children, aux
210213

@@ -214,6 +217,7 @@ def tree_unflatten(cls, aux_data, children):
214217
linear_recurrent_layer_ids_tup,
215218
size,
216219
dp_size,
220+
total_slots,
217221
num_heads,
218222
head_dim,
219223
num_k_heads,
@@ -225,6 +229,8 @@ def tree_unflatten(cls, aux_data, children):
225229
recurrent_partition_axis,
226230
conv_partition_axis,
227231
data_partition_axis,
232+
recurrent_sharding,
233+
conv_sharding,
228234
) = aux_data
229235
obj = cls.__new__(cls)
230236
obj.linear_recurrent_layer_ids = list(linear_recurrent_layer_ids_tup)
@@ -234,7 +240,7 @@ def tree_unflatten(cls, aux_data, children):
234240
obj.num_linear_recurrent_layers = len(obj.linear_recurrent_layer_ids)
235241
obj.size = size
236242
obj.dp_size = dp_size
237-
obj.total_slots = _ceil_to(size + 1, dp_size)
243+
obj.total_slots = total_slots
238244
obj.num_heads = num_heads
239245
obj.head_dim = head_dim
240246
obj.num_k_heads = num_k_heads
@@ -249,10 +255,8 @@ def tree_unflatten(cls, aux_data, children):
249255
obj.recurrent_partition_axis = recurrent_partition_axis
250256
obj.conv_partition_axis = conv_partition_axis
251257
obj.data_partition_axis = data_partition_axis
252-
obj.recurrent_sharding = NamedSharding(
253-
mesh, P(data_partition_axis, recurrent_partition_axis, None, None)
254-
)
255-
obj.conv_sharding = NamedSharding(mesh, P(data_partition_axis, conv_partition_axis, None))
258+
obj.recurrent_sharding = recurrent_sharding
259+
obj.conv_sharding = conv_sharding
256260
new_recurrent, new_conv = children
257261
obj.recurrent_buffers = list(new_recurrent)
258262
obj.conv_buffers = [list(inner) for inner in new_conv]

0 commit comments

Comments
 (0)