Skip to content

Commit 75140f0

Browse files
wang2yn84The tunix Authors
authored andcommitted
[Tunix] Clear and reinitialize vLLM KV cache when loading weights.
PiperOrigin-RevId: 875844163
1 parent ea5a0e6 commit 75140f0

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tunix/generate/vllm_sampler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ def update_params(
159159
filter_types: Optional[Tuple[Any, ...]] = None,
160160
):
161161
del filter_types
162+
if self.llm is not None:
163+
self.llm.reset_prefix_cache()
164+
self.llm.collective_rpc("delete_kv_cache") # will free hbm
165+
elif self._driver is not None:
166+
self._driver.llm_engine.reset_prefix_cache()
167+
self._driver.llm_engine.collective_rpc("delete_kv_cache")
162168

163169
if self.to_hf_key_mappings:
164170
# Mapped Weight Sync (e.g. Vanilla -> vLLM)
@@ -192,6 +198,11 @@ def update_params(
192198
reshard_fn=reshard.reshard_pytree,
193199
)
194200

201+
if self.llm is not None:
202+
self.llm.collective_rpc("reinitialize_kv_cache")
203+
elif self._driver is not None:
204+
self._driver.llm_engine.collective_rpc("reinitialize_kv_cache")
205+
195206
def load_checkpoint(self, path_or_weights: str | jaxtyping.PyTree):
196207
# TODO(b/434741253): Consider support orbax checkpoint loading
197208
if isinstance(path_or_weights, jaxtyping.PyTree):

0 commit comments

Comments
 (0)