
Commit b7cc7fb

refactor: simplify replace_buffer and pytree in RecurrentStatePool
1 parent b19d8f0 commit b7cc7fb

1 file changed: 12 additions & 17 deletions

python/sgl_jax/srt/mem_cache/recurrent_state_pool.py

@@ -165,22 +165,21 @@ def replace_buffer(self, buffers) -> None:
 
         assert len(new_recurrent) == self.num_linear_recurrent_layers
         assert len(new_conv) == self.num_linear_recurrent_layers
-        for layer in range(self.num_linear_recurrent_layers):
-            assert len(new_conv[layer]) == len(self.conv_buffers[layer])
 
+        # tp_size==1 sharding fix: see MHATokenToKVPool.replace_buffer
         tp_degenerate = self.mesh.shape.get("tensor", 1) == 1
         for layer in range(self.num_linear_recurrent_layers):
             buf = new_recurrent[layer]
-            if tp_degenerate and hasattr(self, "recurrent_sharding"):
+            if tp_degenerate:
                 buf = jax.device_put(buf, self.recurrent_sharding)
             self.recurrent_buffers[layer] = buf
 
-        for layer in range(self.num_linear_recurrent_layers):
+            assert len(new_conv[layer]) == len(self.conv_buffers[layer])
             for i in range(len(new_conv[layer])):
-                buf = new_conv[layer][i]
-                if tp_degenerate and hasattr(self, "conv_sharding"):
-                    buf = jax.device_put(buf, self.conv_sharding)
-                self.conv_buffers[layer][i] = buf
+                cbuf = new_conv[layer][i]
+                if tp_degenerate:
+                    cbuf = jax.device_put(cbuf, self.conv_sharding)
+                self.conv_buffers[layer][i] = cbuf
 
     def clear(self) -> None:
         for layer in range(self.num_linear_recurrent_layers):
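
The replace_buffer change above folds the conv re-placement into the single per-layer loop and drops the hasattr guards, relying on the tp_size==1 path to re-place incoming buffers onto the pool's sharding. Below is a minimal, self-contained sketch of that jax.device_put re-placement pattern; the shapes and axis names are illustrative assumptions, not values taken from recurrent_state_pool.py.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Single-device mesh whose "tensor" axis is degenerate (size 1),
# i.e. the tp_degenerate case handled in replace_buffer.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "tensor"))

# Sharding for a 4-D recurrent state buffer (axis names are assumptions).
recurrent_sharding = NamedSharding(mesh, P("data", "tensor", None, None))

tp_degenerate = mesh.shape.get("tensor", 1) == 1

buf = jnp.zeros((8, 4, 16, 16))  # placeholder recurrent state buffer
if tp_degenerate:
    # Re-place the buffer so it carries the pool's sharding/layout,
    # regardless of how the caller's copy was placed.
    buf = jax.device_put(buf, recurrent_sharding)
```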
@@ -195,7 +194,6 @@ def tree_flatten(self):
             tuple(self.linear_recurrent_layer_ids),
             self.size,
             self.dp_size,
-            self.total_slots,
             self.num_heads,
             self.head_dim,
             self.num_k_heads,
@@ -207,8 +205,6 @@ def tree_flatten(self):
             self.recurrent_partition_axis,
             self.conv_partition_axis,
             self.data_partition_axis,
-            self.recurrent_sharding,
-            self.conv_sharding,
         )
         return children, aux
 
@@ -218,7 +214,6 @@ def tree_unflatten(cls, aux_data, children):
             linear_recurrent_layer_ids_tup,
             size,
             dp_size,
-            total_slots,
             num_heads,
             head_dim,
             num_k_heads,
@@ -230,8 +225,6 @@ def tree_unflatten(cls, aux_data, children):
             recurrent_partition_axis,
             conv_partition_axis,
             data_partition_axis,
-            recurrent_sharding,
-            conv_sharding,
         ) = aux_data
         obj = cls.__new__(cls)
         obj.linear_recurrent_layer_ids = list(linear_recurrent_layer_ids_tup)
@@ -241,7 +234,7 @@ def tree_unflatten(cls, aux_data, children):
         obj.num_linear_recurrent_layers = len(obj.linear_recurrent_layer_ids)
         obj.size = size
         obj.dp_size = dp_size
-        obj.total_slots = total_slots
+        obj.total_slots = _ceil_to(size + 1, dp_size)
         obj.num_heads = num_heads
         obj.head_dim = head_dim
         obj.num_k_heads = num_k_heads
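
With this hunk, total_slots is no longer carried in aux_data; tree_unflatten recomputes it from size and dp_size via _ceil_to, whose definition is not part of this diff. A plausible reading of that helper, stated as an assumption rather than the repository's actual code:

```python
# Assumed behaviour of _ceil_to (its definition is not shown in this diff):
# round `value` up to the nearest multiple of `multiple`.
def _ceil_to(value: int, multiple: int) -> int:
    return ((value + multiple - 1) // multiple) * multiple

# Under this reading, total_slots = _ceil_to(size + 1, dp_size) pads size + 1
# so the slot count divides evenly across dp_size data-parallel shards,
# e.g. _ceil_to(10, 4) == 12.
```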
@@ -256,8 +249,10 @@ def tree_unflatten(cls, aux_data, children):
         obj.recurrent_partition_axis = recurrent_partition_axis
         obj.conv_partition_axis = conv_partition_axis
         obj.data_partition_axis = data_partition_axis
-        obj.recurrent_sharding = recurrent_sharding
-        obj.conv_sharding = conv_sharding
+        obj.recurrent_sharding = NamedSharding(
+            mesh, P(data_partition_axis, recurrent_partition_axis, None, None)
+        )
+        obj.conv_sharding = NamedSharding(mesh, P(data_partition_axis, conv_partition_axis, None))
         new_recurrent, new_conv = children
         obj.recurrent_buffers = list(new_recurrent)
         obj.conv_buffers = [list(inner) for inner in new_conv]
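
This last hunk is the core of the pytree simplification: recurrent_sharding and conv_sharding (like total_slots above) are dropped from aux_data and rebuilt inside tree_unflatten from the mesh and partition-axis names. The sketch below shows that register_pytree_node_class pattern in a stripped-down form; ToyPool and its fields are invented for illustration and are not the RecurrentStatePool API.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyPool:
    def __init__(self, mesh, data_axis, buffers):
        self.mesh = mesh
        self.data_axis = data_axis
        self.buffers = buffers
        # Derived object: rebuilt on unflatten instead of stored in aux_data.
        self.sharding = NamedSharding(mesh, P(data_axis, None))

    def tree_flatten(self):
        children = (self.buffers,)
        aux = (self.mesh, self.data_axis)  # mesh + primitive config only
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        obj = cls.__new__(cls)
        obj.mesh, obj.data_axis = aux
        (obj.buffers,) = children
        # Reconstruct the derived sharding from the aux fields.
        obj.sharding = NamedSharding(obj.mesh, P(obj.data_axis, None))
        return obj


mesh = Mesh(np.array(jax.devices()[:1]).reshape(1), ("data",))
pool = ToyPool(mesh, "data", [jnp.zeros((4, 2))])

# Round-tripping through flatten/unflatten rebuilds pool.sharding.
leaves, treedef = jax.tree_util.tree_flatten(pool)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
assert restored.sharding == pool.sharding
```

Keeping aux_data small and rebuilding derived objects avoids duplicating state that is already determined by the mesh and axis names.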
