2020import logging
2121import os
2222import warnings
23+ from collections .abc import Callable
2324from dataclasses import asdict
24- from typing import Any , Optional
25+ from typing import Any , Generator , Optional
2526
2627import numpy as np
2728import psutil
@@ -577,6 +578,25 @@ def _build_model_optimizer(
577578
578579 return actor_module_fsdp , actor_optimizer , actor_lr_scheduler , actor_model_config
579580
581+ def update_weighs_by_checkpoint_engine (
582+ self ,
583+ weights : Generator [tuple [str , torch .Tensor ], None , None ],
584+ req_func : Callable [[list [tuple [str , str ]]], None ]
585+ ):
586+ named_tensors = {}
587+ for tensor_idx , (name , tensor ) in enumerate (weights ):
588+ if tensor_idx % self .world_size == self .rank :
589+ named_tensors [name ] = tensor
590+
591+ checkpoint_name = f"checkpoint_engine"
592+ self .parameter_server .register_checkpoint (checkpoint_name , named_tensors = named_tensors )
593+ named_tensors = {}
594+ dist .barrier ()
595+ self .parameter_server .gather_metas (checkpoint_name )
596+ self .parameter_server .update (checkpoint_name , req_func )
597+ self .parameter_server .unregister_checkpoint (checkpoint_name )
598+
599+
580600 def _build_rollout (self , trust_remote_code = False ):
581601 from torch .distributed .device_mesh import init_device_mesh
582602
@@ -588,10 +608,10 @@ def _build_rollout(self, trust_remote_code=False):
588608 # 2. build rollout device mesh
589609 infer_tp = self .config .rollout .tensor_model_parallel_size * self .config .rollout .data_parallel_size
590610 infer_pp = self .config .rollout .pipeline_model_parallel_size
591- infer_world_size = infer_tp * infer_pp
592- dp = self .world_size // infer_world_size
593- assert self .world_size % infer_world_size == 0 , (
594- f"rollout world_size: { self .world_size } is not divisible by infer_world_size: { infer_world_size } "
611+ self . infer_world_size = infer_tp * infer_pp
612+ dp = self .world_size // self . infer_world_size
613+ assert self .world_size % self . infer_world_size == 0 , (
614+ f"rollout world_size: { self .world_size } is not divisible by infer_world_size: { self . infer_world_size } "
595615 )
596616 rollout_device_mesh = init_device_mesh (
597617 device_name , mesh_shape = (dp , infer_tp , infer_pp ), mesh_dim_names = ["dp" , "infer_tp" , "infer_pp" ]
@@ -700,10 +720,14 @@ async def rollout_mode(self):
700720
701721 set_expandable_segments (False )
702722
723+ if self .config .rollout .enable_checkpoint_engine :
724+ device = "cpu"
725+ else :
726+ device = get_device_id () # used when fsdp2 set cpu_offload_policy
727+
703728 if peft_config is not None and self .base_sync_done :
704729 per_tensor_param = params .items () if isinstance (params , dict ) else params # Fixed: handle dict case
705730 else :
706- device = get_device_id () # used when fsdp2 set cpu_offload_policy
707731 per_tensor_param = (
708732 (name , param .to (device , non_blocking = True ).full_tensor () if isinstance (param , DTensor ) else param )
709733 for name , param in params .items ()
@@ -718,10 +742,18 @@ async def rollout_mode(self):
718742 (name , param .to (device , non_blocking = True ).full_tensor () if isinstance (param , DTensor ) else param )
719743 for name , param in base_model_params .items ()
720744 )
721- await self .rollout .update_weights (per_tensor_base_params , base_sync_done = False )
745+ if self .config .rollout .enable_checkpoint_engine :
746+ req_func = await self .rollout .checkpoint_engine_req_func (self .infer_world_size )
747+ self .update_weighs_by_checkpoint_engine (per_tensor_param , req_func )
748+ else :
749+ await self .rollout .update_weights (per_tensor_base_params , base_sync_done = False )
722750 del base_model_params , per_tensor_base_params
723-
724- await self .rollout .update_weights (per_tensor_param , peft_config = peft_config , base_sync_done = self .base_sync_done )
751+
752+ if self .config .rollout .enable_checkpoint_engine :
753+ req_func = await self .rollout .checkpoint_engine_req_func (self .infer_world_size )
754+ self .update_weighs_by_checkpoint_engine (per_tensor_param , req_func )
755+ else :
756+ await self .rollout .update_weights (per_tensor_param , peft_config = peft_config , base_sync_done = self .base_sync_done )
725757 log_gpu_memory_usage ("After update_weights" , logger = logger )
726758 del params , per_tensor_param
727759 aggressive_empty_cache (force_sync = True )
@@ -863,6 +895,11 @@ def init_model(self):
863895 checkpoint_config = checkpoint_contents ,
864896 )
865897
898+ if self .config .rollout .enable_checkpoint_engine :
899+ from checkpoint_engine .ps import ParameterServer
900+
901+ self .parameter_server = ParameterServer (auto_pg = False )
902+
866903 @register (dispatch_mode = make_nd_compute_dataproto_dispatch_fn (mesh_name = "actor" ))
867904 @DistProfiler .annotate (color = "red" , role = "actor_update" )
868905 def update_actor (self , data : DataProto ):
0 commit comments