Skip to content

Commit f20102a

Browse files
committed
add support to DeepSpeed Zero3
1 parent e902430 commit f20102a

1 file changed

Lines changed: 56 additions & 16 deletions

File tree

src/dataflex/train/selector/less_selector.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,32 +55,71 @@ def __init__(self,
5555

5656
def _get_number_of_params(self, model) -> int:
5757
"""计算模型中需要梯度的参数数量。"""
58-
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
58+
"""计算模型中需要梯度的参数数量(兼容 DeepSpeed ZeRO-3 分区参数)。"""
59+
num_params = 0
60+
for p in model.parameters():
61+
if p.requires_grad:
62+
# DeepSpeed ZeRO-3 下参数被分区,p.numel() 只返回分区后的大小,
63+
# 需要用 ds_numel 获取完整参数大小,以匹配 safe_get_full_grad 返回的梯度维度。
64+
if hasattr(p, 'ds_numel'):
65+
num_params += p.ds_numel
66+
else:
67+
num_params += p.numel()
5968
if self.accelerator.is_main_process:
6069
logger.info(f"Total number of parameters that require gradients: {num_params}")
6170
return num_params
6271

63-
def _prepare_optimizer_state(self, model, optimizer_state: Dict) -> (torch.Tensor, torch.Tensor):
64-
"""从优化器状态字典中准备 Adam 的一阶和二阶矩估计。"""
72+
def _prepare_optimizer_state(self, model, optimizer_state: Optional[Dict] = None) -> (torch.Tensor, torch.Tensor):
73+
"""从优化器状态中准备 Adam 的一阶和二阶矩估计(兼容 DeepSpeed ZeRO-3)。"""
6574
avg_list, avg_sq_list = [], []
66-
for param in model.parameters():
67-
if param.requires_grad:
68-
avg_list.append(optimizer_state[param]["exp_avg"].view(-1))
69-
avg_sq_list.append(optimizer_state[param]["exp_avg_sq"].view(-1))
75+
76+
if self.accelerator.state.deepspeed_plugin is not None:
77+
# DeepSpeed 模式:使用 safe_get_full_optimizer_state 获取完整优化器状态
78+
from deepspeed.utils import safe_get_full_optimizer_state
79+
for param in model.parameters():
80+
if param.requires_grad:
81+
exp_avg = safe_get_full_optimizer_state(param, "exp_avg")
82+
exp_avg_sq = safe_get_full_optimizer_state(param, "exp_avg_sq")
83+
if exp_avg is not None and exp_avg_sq is not None:
84+
avg_list.append(exp_avg.view(-1))
85+
avg_sq_list.append(exp_avg_sq.view(-1))
86+
else:
87+
# 非 DeepSpeed 模式:从传入的 optimizer_state 字典中获取
88+
if optimizer_state is None:
89+
raise ValueError("optimizer_state must be provided for non-DeepSpeed 'adam' gradient type.")
90+
for param in model.parameters():
91+
if param.requires_grad:
92+
avg_list.append(optimizer_state[param]["exp_avg"].view(-1))
93+
avg_sq_list.append(optimizer_state[param]["exp_avg_sq"].view(-1))
7094

7195
avg = torch.cat(avg_list).to(self.device)
7296
avg_sq = torch.cat(avg_sq_list).to(self.device)
7397
return avg, avg_sq
7498

7599
def _obtain_gradients(self, model, batch, gradient_type, m: Optional[torch.Tensor] = None, v: Optional[torch.Tensor] = None) -> torch.Tensor:
76100
"""根据指定的类型计算单个样本的梯度向量。"""
77-
with self.accelerator.no_sync(model):
101+
# 必须先对当前 batch 做 forward + backward,才能产生对应的梯度
102+
if self.accelerator.state.deepspeed_plugin is not None:
103+
# DeepSpeed 模式:直接调用 model forward/backward
78104
loss = model(**batch).loss
79-
self.accelerator.backward(loss)
80-
81-
vectorized_grads = torch.cat(
82-
[p.grad.view(-1) for p in model.parameters() if p.grad is not None]
83-
)
105+
model.backward(loss)
106+
# 使用 safe_get_full_grad 获取完整梯度(ZeRO 分区下需要 gather)
107+
from deepspeed.utils import safe_get_full_grad
108+
grads = []
109+
for name, p in model.named_parameters():
110+
g = safe_get_full_grad(p)
111+
if g is not None:
112+
grads.append(g.contiguous().view(-1))
113+
vectorized_grads = torch.cat(grads) if grads else None
114+
115+
else:
116+
# 非 DeepSpeed 模式
117+
with self.accelerator.no_sync(model):
118+
loss = model(**batch).loss
119+
self.accelerator.backward(loss)
120+
vectorized_grads = torch.cat(
121+
[p.grad.view(-1) for p in model.parameters() if p.grad is not None]
122+
)
84123

85124
if gradient_type == "adam":
86125
if m is None or v is None:
@@ -154,8 +193,9 @@ def _collect_and_save_projected_gradients(self, model, save_dir, dataset_to_use,
154193
# 2) 准备 Adam 状态 (如果需要)
155194
m, v = None, None
156195
if gradient_type == "adam":
157-
if optimizer_state is None:
158-
raise ValueError("optimizer_state must be provided for 'adam' gradient type.")
196+
# DeepSpeed 模式下可通过 safe_get_full_optimizer_state 直接获取,无需传入 optimizer_state
197+
if self.accelerator.state.deepspeed_plugin is None and optimizer_state is None:
198+
raise ValueError("optimizer_state must be provided for non-DeepSpeed 'adam' gradient type.")
159199
m, v = self._prepare_optimizer_state(model, optimizer_state)
160200

161201
# 3) 构造 DataLoader
@@ -353,4 +393,4 @@ def select(self, model, step_id: int, num_samples: int, **kwargs) -> List[int]:
353393
dist.broadcast_object_list(obj_list, src=0)
354394
selected_indices = obj_list[0]
355395

356-
return selected_indices
396+
return selected_indices

0 commit comments

Comments
 (0)