@@ -284,13 +284,19 @@ def update_flattened_params(self, metadatas: list[Dict]) -> None:
284284 named_params = FlattenedTensorBucket (metadata = metadatas , flattened_tensor = flatten_tensor ).reconstruct_tensors ()
285285
286286 patch_vllm_moe_model_weight_loader (self .model_runner .model )
287- # Re-run process_weights_after_loading on FusedMoE layers so the
288- # kernel-format layout is rebuilt after the in-place reload
289- # (workaround for vLLM issue #42821).
290- try :
291- self .model_runner .model .load_weights (weights = list (named_params .items ()))
292- finally :
293- finish_vllm_weight_reload (self .model_runner .model )
287+ self .model_runner .model .load_weights (weights = list (named_params .items ()))
288+
289+ def process_weights_after_loading (self ) -> None :
290+ """Re-run process_weights_after_loading once after ALL weight
291+ buckets have been loaded, so the kernel-format layout is rebuilt
292+ on complete weights rather than partial ones.
293+
294+ Uses vLLM's built-in ``process_weights_after_loading`` when
295+ *model_config* and *target_device* are available (same as verl);
296+ falls back to FusedMoE-only path otherwise.
297+ """
298+ model_config = self .model_runner .model_config
299+ finish_vllm_weight_reload (self .model_runner .model , model_config = model_config , target_device = self .device )
294300
295301 def close_communicator (self ) -> None :
296302 """
@@ -512,12 +518,13 @@ def _broadcast_obj(obj):
512518 if metadata .get ('is_last' ):
513519 break
514520
515- # Re-run process_weights_after_loading on FusedMoE layers so the
516- # kernel-format layout is rebuilt after the in-place reload
517- # (workaround for vLLM issue #42821). Skipped for LoRA sync
518- # because the adapter path doesn't call ``load_weights``.
521+ # Re-run process_weights_after_loading so the kernel-format
522+ # layout is rebuilt after the in-place reload (vLLM issue
523+ # #42821). Skipped for LoRA sync because the adapter path
524+ # doesn't call ``load_weights``.
519525 if not is_lora_sync :
520- finish_vllm_weight_reload (self .model_runner .model )
526+ model_config = self .model_runner .model_config
527+ finish_vllm_weight_reload (self .model_runner .model , model_config = model_config , target_device = self .device )
521528
522529 if is_lora_sync and all_lora_weights :
523530 req_kw = dict (
@@ -698,6 +705,7 @@ def _register_rl_rollout_app(self):
698705 self .app .post ('/update_adapter_flattened_param/' )(self .update_adapter_flattened_param )
699706 self .app .post ('/update_adapter_param/' )(self .update_adapter_param )
700707 self .app .post ('/update_flattened_params/' )(self .update_flattened_params )
708+ self .app .post ('/process_weights_after_loading/' )(self .process_weights_after_loading )
701709 self .app .post ('/reset_prefix_cache/' )(self .reset_prefix_cache )
702710 self .app .post ('/reset_encoder_cache/' )(self .reset_encoder_cache )
703711 self .app .post ('/reset_mm_cache/' )(self .reset_mm_cache )
@@ -926,6 +934,18 @@ async def update_flattened_params(self, request: UpdateFlattenedParamsRequest):
926934
927935 return {'message' : 'Request received, updating flattened parameters' }
928936
937+ async def process_weights_after_loading (self ):
938+ """
939+ Triggers process_weights_after_loading on all workers.
940+ """
941+ kwargs = {'method' : 'process_weights_after_loading' , 'args' : ()}
942+ for connection in self .connections :
943+ connection .send ({'type' : 'call' , 'method' : 'collective_rpc' , 'kwargs' : kwargs })
944+ # Wait for all workers to complete before returning
945+ loop = asyncio .get_running_loop ()
946+ await asyncio .gather (* (loop .run_in_executor (None , connection .recv ) for connection in self .connections ))
947+ return {'message' : 'Weights processed after loading' }
948+
929949 async def reset_prefix_cache (self ):
930950 """
931951 Resets the prefix cache for the model.
0 commit comments