@@ -55,32 +55,71 @@ def __init__(self,
5555
5656 def _get_number_of_params (self , model ) -> int :
5757 """计算模型中需要梯度的参数数量。"""
58- num_params = sum (p .numel () for p in model .parameters () if p .requires_grad )
58+ """计算模型中需要梯度的参数数量(兼容 DeepSpeed ZeRO-3 分区参数)。"""
59+ num_params = 0
60+ for p in model .parameters ():
61+ if p .requires_grad :
62+ # DeepSpeed ZeRO-3 下参数被分区,p.numel() 只返回分区后的大小,
63+ # 需要用 ds_numel 获取完整参数大小,以匹配 safe_get_full_grad 返回的梯度维度。
64+ if hasattr (p , 'ds_numel' ):
65+ num_params += p .ds_numel
66+ else :
67+ num_params += p .numel ()
5968 if self .accelerator .is_main_process :
6069 logger .info (f"Total number of parameters that require gradients: { num_params } " )
6170 return num_params
6271
63- def _prepare_optimizer_state (self , model , optimizer_state : Dict ) -> (torch .Tensor , torch .Tensor ):
64- """从优化器状态字典中准备 Adam 的一阶和二阶矩估计。"""
72+ def _prepare_optimizer_state (self , model , optimizer_state : Optional [ Dict ] = None ) -> (torch .Tensor , torch .Tensor ):
73+ """从优化器状态中准备 Adam 的一阶和二阶矩估计(兼容 DeepSpeed ZeRO-3) 。"""
6574 avg_list , avg_sq_list = [], []
66- for param in model .parameters ():
67- if param .requires_grad :
68- avg_list .append (optimizer_state [param ]["exp_avg" ].view (- 1 ))
69- avg_sq_list .append (optimizer_state [param ]["exp_avg_sq" ].view (- 1 ))
75+
76+ if self .accelerator .state .deepspeed_plugin is not None :
77+ # DeepSpeed 模式:使用 safe_get_full_optimizer_state 获取完整优化器状态
78+ from deepspeed .utils import safe_get_full_optimizer_state
79+ for param in model .parameters ():
80+ if param .requires_grad :
81+ exp_avg = safe_get_full_optimizer_state (param , "exp_avg" )
82+ exp_avg_sq = safe_get_full_optimizer_state (param , "exp_avg_sq" )
83+ if exp_avg is not None and exp_avg_sq is not None :
84+ avg_list .append (exp_avg .view (- 1 ))
85+ avg_sq_list .append (exp_avg_sq .view (- 1 ))
86+ else :
87+ # 非 DeepSpeed 模式:从传入的 optimizer_state 字典中获取
88+ if optimizer_state is None :
89+ raise ValueError ("optimizer_state must be provided for non-DeepSpeed 'adam' gradient type." )
90+ for param in model .parameters ():
91+ if param .requires_grad :
92+ avg_list .append (optimizer_state [param ]["exp_avg" ].view (- 1 ))
93+ avg_sq_list .append (optimizer_state [param ]["exp_avg_sq" ].view (- 1 ))
7094
7195 avg = torch .cat (avg_list ).to (self .device )
7296 avg_sq = torch .cat (avg_sq_list ).to (self .device )
7397 return avg , avg_sq
7498
7599 def _obtain_gradients (self , model , batch , gradient_type , m : Optional [torch .Tensor ] = None , v : Optional [torch .Tensor ] = None ) -> torch .Tensor :
76100 """根据指定的类型计算单个样本的梯度向量。"""
77- with self .accelerator .no_sync (model ):
101+ # 必须先对当前 batch 做 forward + backward,才能产生对应的梯度
102+ if self .accelerator .state .deepspeed_plugin is not None :
103+ # DeepSpeed 模式:直接调用 model forward/backward
78104 loss = model (** batch ).loss
79- self .accelerator .backward (loss )
80-
81- vectorized_grads = torch .cat (
82- [p .grad .view (- 1 ) for p in model .parameters () if p .grad is not None ]
83- )
105+ model .backward (loss )
106+ # 使用 safe_get_full_grad 获取完整梯度(ZeRO 分区下需要 gather)
107+ from deepspeed .utils import safe_get_full_grad
108+ grads = []
109+ for name , p in model .named_parameters ():
110+ g = safe_get_full_grad (p )
111+ if g is not None :
112+ grads .append (g .contiguous ().view (- 1 ))
113+ vectorized_grads = torch .cat (grads ) if grads else None
114+
115+ else :
116+ # 非 DeepSpeed 模式
117+ with self .accelerator .no_sync (model ):
118+ loss = model (** batch ).loss
119+ self .accelerator .backward (loss )
120+ vectorized_grads = torch .cat (
121+ [p .grad .view (- 1 ) for p in model .parameters () if p .grad is not None ]
122+ )
84123
85124 if gradient_type == "adam" :
86125 if m is None or v is None :
@@ -154,8 +193,9 @@ def _collect_and_save_projected_gradients(self, model, save_dir, dataset_to_use,
154193 # 2) 准备 Adam 状态 (如果需要)
155194 m , v = None , None
156195 if gradient_type == "adam" :
157- if optimizer_state is None :
158- raise ValueError ("optimizer_state must be provided for 'adam' gradient type." )
196+ # DeepSpeed 模式下可通过 safe_get_full_optimizer_state 直接获取,无需传入 optimizer_state
197+ if self .accelerator .state .deepspeed_plugin is None and optimizer_state is None :
198+ raise ValueError ("optimizer_state must be provided for non-DeepSpeed 'adam' gradient type." )
159199 m , v = self ._prepare_optimizer_state (model , optimizer_state )
160200
161201 # 3) 构造 DataLoader
@@ -353,4 +393,4 @@ def select(self, model, step_id: int, num_samples: int, **kwargs) -> List[int]:
353393 dist .broadcast_object_list (obj_list , src = 0 )
354394 selected_indices = obj_list [0 ]
355395
356- return selected_indices
396+ return selected_indices
0 commit comments