1 file changed: +11 -0 lines changed
Original file line number | Diff line number | Diff line change
@@ -159,6 +159,12 @@ def update_params(
159159 filter_types : Optional [Tuple [Any , ...]] = None ,
160160 ):
161161 del filter_types
162+ if self .llm is not None :
163+ self .llm .reset_prefix_cache ()
164+ self .llm .collective_rpc ("delete_kv_cache" ) # will free hbm
165+ elif self ._driver is not None :
166+ self ._driver .llm_engine .reset_prefix_cache ()
167+ self ._driver .llm_engine .collective_rpc ("delete_kv_cache" )
162168
163169 if self .to_hf_key_mappings :
164170 # Mapped Weight Sync (e.g. Vanilla -> vLLM)
@@ -192,6 +198,11 @@ def update_params(
192198 reshard_fn = reshard .reshard_pytree ,
193199 )
194200
201+ if self .llm is not None :
202+ self .llm .collective_rpc ("reinitialize_kv_cache" )
203+ elif self ._driver is not None :
204+ self ._driver .llm_engine .collective_rpc ("reinitialize_kv_cache" )
205+
195206 def load_checkpoint (self , path_or_weights : str | jaxtyping .PyTree ):
196207 # TODO(b/434741253): Consider support orbax checkpoint loading
197208 if isinstance (path_or_weights , jaxtyping .PyTree ):
You can’t perform that action at this time.
0 commit comments