Commit 89590fb

Authored by Masahiro Tanaka
Add destroy to tests to free memory (#7160)
ZeRO-3 requires explicit cleanup in tests when the distributed environment is reused. This PR adds `destroy` calls to the tests to free memory and avoid potential errors caused by memory leaks.

Signed-off-by: Masahiro Tanaka <[email protected]>
1 parent 1ca83a6 · commit 89590fb
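For context, the cleanup pattern this commit applies across the suite looks roughly like the sketch below. This is a minimal illustration only: `SimpleModel`, `config_dict`, and the tensor shapes are hypothetical stand-ins, not the actual fixtures from `test_zero.py`. The real tests run under `DistributedTest`, which can reuse the distributed environment between tests; per the commit message, that reuse is why the explicit teardown is needed.

```python
# Minimal sketch of the cleanup pattern (hypothetical model and config,
# not taken verbatim from test_zero.py).
import torch
import deepspeed


class SimpleModel(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, y):
        return torch.nn.functional.mse_loss(self.linear(x), y)


def run_zero3_train_step(hidden_dim=10):
    config_dict = {
        "train_micro_batch_size_per_gpu": 4,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
        "zero_optimization": {"stage": 3},
    }
    net = SimpleModel(hidden_dim)
    model, _, _, _ = deepspeed.initialize(model=net,
                                          model_parameters=net.parameters(),
                                          config=config_dict)
    x = torch.randn(4, hidden_dim, device=model.device)
    y = torch.randn(4, hidden_dim, device=model.device)
    loss = model(x, y)
    model.backward(loss)
    model.step()
    # ZeRO-3 installs module hooks and holds partitioned parameter and
    # optimizer state; when the same process runs several tests, free that
    # state explicitly rather than relying on garbage collection.
    model.destroy()
```

Where a test keeps the engine under the name `ds_engine`, the same teardown appears as `ds_engine.destroy()` in the diff below.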

File tree

1 file changed: +32 −0 lines changed


tests/unit/runtime/zero/test_zero.py

```diff
@@ -83,6 +83,7 @@ def test(self, zero_stage):
         data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device)

         run_unbalanced_gradients(model, data_loader)
+        model.destroy()


 # testing the fix https://github.com/deepspeedai/DeepSpeed/pull/1227
@@ -143,6 +144,8 @@ def forward(self, x, y):
             model.backward(loss)
             model.step()

+        model.destroy()
+

 # testing the fix https://github.com/deepspeedai/DeepSpeed/pull/1227
 # also reproduces the https://github.com/deepspeedai/DeepSpeed/pull/1372
@@ -243,6 +246,8 @@ def forward(self, x, y):
             # float() workaround for torch<1.6
             assert torch.allclose(orig_state_dict[name].float(), fp32_state_dict[name].float())

+        model.destroy()
+
     def test_2_param_groups(self, tmpdir, zero_stage, freeze_params):
         # TODO:
         # - need to test with multiple param groups
@@ -348,6 +353,8 @@ def forward(self, x, y):
             # float() workaround for torch<1.6
             assert torch.allclose(orig_state_dict[name].float(), fp32_state_dict[name].float())

+        model.destroy()
+

 @pytest.mark.parametrize("allgather_bucket_size", [1000, 1001])
 class TestIncorectAllgatherBucketSize(DistributedTest):
@@ -821,6 +828,8 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor:
         _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE})
         assert not math.isclose(ds_engine.optimizer._global_grad_norm, 0.0)

+        ds_engine.destroy()
+

 @pytest.mark.parametrize("init_context_manager", [True, False])
 @pytest.mark.parametrize("reduce_scatter", [True, False])
@@ -893,6 +902,8 @@ def forward(self, x: Tensor) -> Tensor:

             assert torch.allclose(weight_gradient, expected_weight_gradient)

+        ds_engine.destroy()
+

 @pytest.mark.parametrize("init_context_manager", [True, False])
 class TestZero3ParamPartitioningManyParams(DistributedTest):
@@ -977,6 +988,8 @@ def forward(self, x: Tensor) -> Tensor:
         for layer_num, activation in enumerate(weight_gradients):
             pass

+        ds_engine.destroy()
+

 class TestZero3InitForParentWeightInitialization(DistributedTest):
     world_size = 4
@@ -1197,6 +1210,8 @@ def create_tensor(vals):
         ds_engine.optimizer.step()
         _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE})

+        ds_engine.destroy()
+

 class TestParamPartitioningSkipInit(DistributedTest):
     world_size = 2
@@ -1274,6 +1289,8 @@ def forward(self, x, y):
             model.backward(loss)
             model.step()

+        model.destroy()
+

 class TestZeroOffloadStage1(DistributedTest):
     world_size = 2
@@ -1311,6 +1328,8 @@ def test(self):
             model.backward(loss)
             model.step()

+        model.destroy()
+

 @pytest.mark.parametrize("return_type", [tuple, list, dict])
 class TestZero3DictFwd(DistributedTest):
@@ -1373,6 +1392,8 @@ def forward(self, x, y):
             model.backward(loss)
             model.step()

+        model.destroy()
+

 @pytest.mark.parametrize("zero_stage", [1, 2, 3])
 class TestZeroAdamOptimizerStepCount(DistributedTest):
@@ -1439,6 +1460,8 @@ def test(self, zero_stage):
         assert all(step == step_counts[0] for step in step_counts)
         assert model.global_steps == step_counts[0]

+        model.destroy()
+

 @pytest.mark.parametrize("zero_stage", [1, 2, 3])
 class TestZeroFrozenWeights(DistributedTest):
@@ -1497,6 +1520,8 @@ def forward(self, x, y):
             model.backward(loss)
             model.step()

+        model.destroy()
+

 @pytest.mark.parametrize("force_ds_optim", [True, False])
 class TestZeroOffloadOptim(DistributedTest):
@@ -1577,6 +1602,8 @@ def test_training_partition_cache(self, training):
         model.empty_partition_cache()
         assert sum([p.numel() for p in model.parameters()]) == 0

+        model.destroy()
+

 @pytest.mark.parametrize("use_client_optimizer", [True, False])
 @pytest.mark.parametrize("empty_weight_group", [True, False])
@@ -1629,6 +1656,8 @@ def test_empty_param_groups(self, dtype, use_client_optimizer, empty_weight_grou
             config=config_dict,
         )

+        model.destroy()
+

 class TestZero3SwitchModes(DistributedTest):
     world_size = 2
@@ -1674,6 +1703,8 @@ def test(self, prefetch_ratio, zero_stage=3):
         for batch in data_loader:
             loss = model(batch[0], batch[1])

+        model.destroy()
+

 # Avoid overwriting client module id
 # https://github.com/deepspeedai/DeepSpeed/issues/6772
@@ -1707,3 +1738,4 @@ def forward(self, x):
         model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
         post_init_m_id = model.id
         assert pre_init_m_id == post_init_m_id
+        model.destroy()
```
