Commit bf6b980

ZeRO3 handling frozen weights (#2653)
1 parent 35575bc commit bf6b980

File tree

2 files changed, +85 -11 lines changed


deepspeed/runtime/zero/stage3.py (+25 -11)

@@ -252,11 +252,15 @@ def __init__(self,
         self.sub_group_size = sub_group_size
 
         self.sub_group_to_group_id = {}
-        see_memory_usage("Before creating fp16 partitions", force=False)
-        self._create_fp16_partitions_with_defragmentation()
+
+        # Trainable parameters
+        self.trainable_param_groups = self._get_trainable_parameter_groups()
+
+        see_memory_usage("Before creating fp16 partitions", force=True)
+        self._create_fp16_partitions_with_defragmentation(self.trainable_param_groups)
         num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
         see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}",
-                         force=False)
+                         force=True)
 
         # Optimizer tensor swapping
         if self.swap_optimizer:
@@ -350,19 +354,28 @@ def __init__(self,
     def destroy(self):
         self.parameter_offload.destroy()
 
+    def _get_trainable_parameter_groups(self):
+        param_groups = []
+        for param_group in self.optimizer.param_groups:
+            trainable_params = {
+                "params": [p for p in param_group["params"] if p.requires_grad]
+            }
+            param_groups.append(trainable_params)
+        return param_groups
+
     def _setup_for_real_optimizer(self):
-        see_memory_usage("Before creating fp32 partitions", force=False)
+        see_memory_usage("Before creating fp32 partitions", force=True)
         self._create_fp32_partitions()
-        see_memory_usage("After creating fp32 partitions", force=False)
+        see_memory_usage("After creating fp32 partitions", force=True)
         dist.barrier()
 
         # To support pipelined optimizer swapping
         self._create_next_swappable_fp32_groups()
 
-        see_memory_usage("Before initializing optimizer states", force=False)
+        see_memory_usage("Before initializing optimizer states", force=True)
 
         self.initialize_optimizer_states()
-        see_memory_usage("After initializing optimizer states", force=False)
+        see_memory_usage("After initializing optimizer states", force=True)
         dist.barrier()
 
         if dist.get_rank() == 0:
@@ -523,7 +536,7 @@ def _create_param_groups_fp16_flat_cpu_memory(self):
 
         aggregate_params_count = 0
 
-        for j, param_group in enumerate(self.optimizer.param_groups):
+        for j, param_group in enumerate(self.trainable_param_groups):
             params_in_group = sum([p.partition_numel() for p in param_group['params']])
 
             flat_buffer_size = params_in_group
@@ -552,11 +565,12 @@ def _create_param_groups_fp16_flat_cpu_memory(self):
                 torch.empty(1,
                             dtype=self.dtype))
 
-    def _create_fp16_partitions_with_defragmentation(self):
+    def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
         dist.barrier()
+
         param_groups: List[List[Parameter]] = tuple(
             self._create_fp16_sub_groups(param_group["params"])
-            for param_group in self.optimizer.param_groups)
+            for param_group in fp16_param_groups)
 
         # bookkeeping related to param groups
         for param_group_idx, param_group in enumerate(param_groups):
@@ -884,7 +898,6 @@ def initialize_optimizer_states(self):
                                           dtype=gradient_dtype,
                                           device=self.device)
 
-        timers = self.timers
        timer_names = set()
 
         if self.swap_optimizer:
@@ -2122,6 +2135,7 @@ def _get_param_groups(self):
 
     def _set_param_groups(self, value):
         self.optimizer.param_groups = value
+        self.trainable_param_groups = self._get_trainable_parameter_groups()
 
     param_groups = property(_get_param_groups, _set_param_groups)
 
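For reference, below is a minimal standalone sketch of the trainable-parameter filtering this commit adds. The names toy_model and toy_optimizer are hypothetical stand-ins for illustration; in the actual change the filtering runs inside the ZeRO stage 3 optimizer's __init__ on self.optimizer.param_groups.

# Minimal standalone sketch of the filtering added above (toy_model and
# toy_optimizer are hypothetical; the real code operates on self.optimizer.param_groups).
import torch

toy_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
toy_model[1].weight.requires_grad = False  # freeze the second layer
toy_model[1].bias.requires_grad = False

toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-4)

# Same idea as _get_trainable_parameter_groups(): keep only params that still
# require gradients, so frozen weights never enter the fp16/fp32 partitions.
trainable_param_groups = [
    {"params": [p for p in group["params"] if p.requires_grad]}
    for group in toy_optimizer.param_groups
]

# Only the first layer's weight and bias remain.
assert sum(len(g["params"]) for g in trainable_param_groups) == 2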

tests/unit/runtime/zero/test_zero.py (+60)

@@ -1313,3 +1313,63 @@ def test(self, zero_stage):
             state = optimizer.optimizer.state[param]
             step_counts.append(state['step'])
         assert all(step == step_counts[0] for step in step_counts)
+
+
+class TestZeroFrozenWeights(DistributedTest):
+    world_size = 1
+
+    def test(self):
+        config_dict = {
+            "train_batch_size": 4,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-4
+                }
+            },
+            "fp16": {
+                "enabled": True
+            },
+            "zero_optimization": {
+                "stage": 3
+            }
+        }
+        hidden_dim = 10
+
+        class MyModel(torch.nn.Module):
+            def __init__(self, hidden_dim):
+                super(MyModel, self).__init__()
+                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
+                self.l2 = torch.nn.Linear(hidden_dim, hidden_dim)
+                self.act = torch.nn.ReLU()
+                self.cel = torch.nn.CrossEntropyLoss()
+
+                # freeze one fc
+                self.l2.weight.requires_grad = False
+                self.l2.bias.requires_grad = False
+
+            def forward(self, x, y):
+                x = self.l1(x)
+                x = self.act(x)
+                x = self.l2(x)
+                loss = self.cel(x, y)
+                val = (x, loss)
+                return val
+
+        with deepspeed.zero.Init(config_dict_or_path=config_dict):
+            model = MyModel(hidden_dim)
+
+        model, _, _, _ = deepspeed.initialize(model=model,
+                                              model_parameters=model.parameters(),
+                                              config=config_dict)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        dist.barrier()
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            loss = loss[1]
+            model.backward(loss)
+            model.step()
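The new test freezes both parameters of the second Linear layer (l2) before deepspeed.initialize and checks that a full ZeRO stage 3 fp16 training loop (forward, backward, step) runs to completion. Assuming the repository's standard pytest-based unit-test setup, it can be selected locally with something like:

pytest tests/unit/runtime/zero/test_zero.py -k TestZeroFrozenWeights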
