Skip to content

Commit 90701aa

Browse files
committed
Release accelerator model refs during cleanup
1 parent d6cc2b5 commit 90701aa

2 files changed

Lines changed: 42 additions & 0 deletions

File tree

lmms_eval/api/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ def set_cache_hook(self, cache_hook) -> None:
164164
self.cache_hook = cache_hook
165165

166166
def clean(self):
167+
accelerator = getattr(self, "accelerator", None)
168+
if accelerator is not None and hasattr(accelerator, "free_memory"):
169+
accelerator.free_memory()
170+
167171
for attr_name in list(vars(self)):
168172
attr_value = getattr(self, attr_name)
169173
if isinstance(attr_value, nn.Module):

test/models/test_model_cleanup.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
3+
from lmms_eval.api.model import lmms
4+
5+
6+
class _DummyLM(lmms):
7+
def loglikelihood(self, requests):
8+
return []
9+
10+
def generate_until(self, requests):
11+
return []
12+
13+
def generate_until_multi_round(self, requests):
14+
return []
15+
16+
17+
class _FakeAccelerator:
18+
def __init__(self, model):
19+
self._models = [model]
20+
self.free_memory_calls = 0
21+
22+
def free_memory(self):
23+
self.free_memory_calls += 1
24+
self._models = []
25+
26+
27+
def test_clean_releases_accelerator_model_references():
28+
lm = _DummyLM()
29+
model = torch.nn.Linear(1, 1)
30+
accelerator = _FakeAccelerator(model)
31+
lm._model = model
32+
lm.accelerator = accelerator
33+
34+
lm.clean()
35+
36+
assert accelerator.free_memory_calls == 1
37+
assert accelerator._models == []
38+
assert not hasattr(lm, "_model")

0 commit comments

Comments
 (0)