 FSDP2_SUPPORTED = False

 try:
+    import torch.distributed.checkpoint as dcp
     from torch.distributed.checkpoint.state_dict import (
         StateDictOptions,
+        get_model_state_dict,
         set_model_state_dict,
     )

@@ -163,8 +165,27 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
         )
     fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1")
     fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda")
+
+    if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+        assert gpc.get_world_size(ParallelMode.EXPERT_DATA) * gpc.get_world_size(ParallelMode.EXPERT) == gpc.get_world_size(ParallelMode.GLOBAL)

     if fsdp_mode == "v1":
+        ignored_mod = []
+        if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+            for layer_id, layer in enumerate(model.model.layers):
+                if layer_id >= gpc.config.model.first_k_dense_replace:
+                    layer.feed_forward.moe_layer.experts = FSDP(
+                        layer.feed_forward.moe_layer.experts,
+                        process_group=gpc.get_group(ParallelMode.EXPERT_DATA),
+                        sharding_strategy=ShardingStrategy.FULL_SHARD,
+                        sync_module_states=fsdp_init_method != "cuda",  # sync model parameters
+                        forward_prefetch=True,
+                        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+                        limit_all_gathers=True,
+                        use_orig_params=True,
+                        device_id=None if fsdp_init_method == "cuda" else get_current_device(),  # needed for sync_module_states
+                    )
+                    ignored_mod.append(layer.feed_forward.moe_layer.experts)
         model = FSDP(
             module=model,
             process_group=gpc.get_group(ParallelMode.GLOBAL),
@@ -176,6 +197,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             limit_all_gathers=True,
             use_orig_params=True,
             device_id=None if fsdp_init_method == "cuda" else get_current_device(),  # needed for sync_module_states
+            ignored_modules=ignored_mod,
         )
         # For FSDP v1, to get ckpt resuming work normally, we do dummy forward.
         # This hack is needed due to FSDP v1 lazy initialization in model construction.
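
The wrapping scheme in the two hunks above is: each MoE expert bank gets its own inner FSDP instance sharded over the expert-data process group, and the outer FSDP wrap over the global group skips those modules via ignored_modules. The earlier assertion only checks the factorization, e.g. with 8 ranks, expert parallel size 2 and expert-data size 4 multiply back to 8. A minimal standalone sketch of the same idea (toy module names, not InternEvo's classes; it assumes torchrun has initialized the default process group and that expert_data_group is a subgroup created elsewhere):

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


class ToyExperts(nn.Module):
    """Stand-in for a MoE expert bank (hypothetical, for illustration only)."""

    def __init__(self, dim: int = 64, n_experts: int = 4):
        super().__init__()
        self.fcs = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))

    def forward(self, x):
        return sum(fc(x) for fc in self.fcs)


class ToyMoEBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.experts = ToyExperts(dim)

    def forward(self, x):
        return self.experts(self.attn(x))


def wrap(block: ToyMoEBlock, expert_data_group: dist.ProcessGroup) -> FSDP:
    # Inner wrap: expert parameters are sharded only across the (smaller)
    # expert-data group, mirroring the EXPERT_DATA wrap in the diff above.
    block.experts = FSDP(
        block.experts,
        process_group=expert_data_group,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        use_orig_params=True,
    )
    # Outer wrap: the rest of the block is sharded across the global group;
    # the already-wrapped experts are listed in ignored_modules so the outer
    # instance does not flatten or re-shard their parameters.
    return FSDP(
        block,
        process_group=dist.group.WORLD,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        use_orig_params=True,
        ignored_modules=[block.experts],
    )
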
@@ -196,7 +218,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     else:
         raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}")

-    if is_using_hf() and not gpc.config.ckpt.get("auto_resume", False):
+    if not gpc.config.ckpt.get("auto_resume", False):
         load_ckpt_info = gpc.config.ckpt.load_ckpt_info
         load_ckpt_path = load_ckpt_info.get("path", None)
         load_ckpt_content = load_ckpt_info.get("content", [])
@@ -205,19 +227,25 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             "model",
         ), "If auto_resume=False and checkpoint path is given, only model can be loaded"
         if DCP_SUPPORTED:
-            hf = gpc.config.hf
-            mod = LazyObject(hf.mod, hf.mod_cls)
-            mod = mod.build()
-            state_dict = mod.from_pretrained(
-                pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True
-            ).state_dict()
-            state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict}
-            set_model_state_dict(
-                model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True)
-            )
+            if is_using_hf():
+                hf = gpc.config.hf
+                mod = LazyObject(hf.mod, hf.mod_cls)
+                mod = mod.build()
+                state_dict = mod.from_pretrained(
+                    pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True
+                ).state_dict()
+                state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict}
+                set_model_state_dict(
+                    model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True)
+                )
+            else:
+                state_dict = get_model_state_dict(model=model)
+                state_dict = {key: state_dict[key].clone().detach() for key in state_dict}
+                dcp.load(state_dict=state_dict, checkpoint_id=load_ckpt_path)
+                set_model_state_dict(model=model, model_state_dict=state_dict)
             del state_dict
             internlm_accelerator.empty_cache()
         else:
             raise RuntimeError("DCP is not supported in this version of PyTorch.")

-    return model
+    return model
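
The preceding hunk drops the is_using_hf() guard from the load condition so that non-HF checkpoints can also be restored here, and the last hunk then branches on is_using_hf(): HF checkpoints are loaded as a full state dict from safetensors, while native checkpoints go through torch.distributed.checkpoint (DCP). A sketch of the two patterns in isolation (placeholder arguments; it assumes an initialized process group and an already FSDP-wrapped model, and omits the clone/detach copies done in the diff):

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)


def load_full_state_dict(model, hf_state_dict: dict) -> None:
    # HF branch: every rank materializes the full state dict (via
    # from_pretrained in the diff), and set_model_state_dict with
    # full_state_dict=True distributes it into the sharded FSDP parameters.
    set_model_state_dict(
        model=model,
        model_state_dict=hf_state_dict,
        options=StateDictOptions(full_state_dict=True),
    )


def load_dcp_checkpoint(model, ckpt_dir: str) -> None:
    # Non-HF branch: build a state-dict "skeleton" with the correct keys and
    # this rank's shards ...
    state_dict = get_model_state_dict(model=model)
    # ... let DCP fill it in place from the checkpoint directory ...
    dcp.load(state_dict=state_dict, checkpoint_id=ckpt_dir)
    # ... and push the populated dict back into the model.
    set_model_state_dict(model=model, model_state_dict=state_dict)
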