forked from EvolvingLMMs-Lab/lmms-eval
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_model_cleanup.py
More file actions
38 lines (26 loc) · 821 Bytes
/
test_model_cleanup.py
File metadata and controls
38 lines (26 loc) · 821 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from lmms_eval.api.model import lmms
class _DummyLM(lmms):
    """Minimal concrete ``lmms`` subclass used to exercise ``lmms.clean``.

    Every request-handling method is a stub that ignores ``requests`` and
    returns an empty list. NOTE(review): presumably these overrides exist
    because ``lmms`` declares them abstract -- confirm against ``lmms``.
    """

    def loglikelihood(self, requests):
        # Stub: nothing is scored in this test.
        return list()

    def generate_until(self, requests):
        # Stub: nothing is generated in this test.
        return list()

    def generate_until_multi_round(self, requests):
        # Stub: nothing is generated in this test.
        return list()


class _FakeAccelerator:
    """Test double for an accelerator that records ``free_memory`` calls.

    The fake starts out holding a single model in ``_models``. Calling
    ``free_memory`` increments ``free_memory_calls`` and empties ``_models``.

    Attributes:
        _models: list of model references currently held by the fake.
        free_memory_calls: number of times ``free_memory`` has been called.
    """

    def __init__(self, model):
        self._models = [model]
        self.free_memory_calls = 0

    def free_memory(self, *objects):
        """Record the call and drop every stored model reference.

        Optional positional ``objects`` are accepted so the fake matches the
        call shape of ``accelerate.Accelerator.free_memory(*objects)``. That
        way, code under test that passes objects does not fail with a
        ``TypeError``.

        Returns:
            A list with one ``None`` per object passed, mirroring the value
            the real API documents returning. With no arguments, it returns
            an empty list.
        """
        self.free_memory_calls += 1
        self._models = []
        return [None] * len(objects)


def test_clean_releases_accelerator_model_references():
    """``lmms.clean`` should free accelerator memory and drop ``_model``.

    A tiny ``torch.nn.Linear`` is attached to a dummy ``lmms`` instance in
    two places: directly as ``_model``, and inside a fake accelerator. After
    ``clean()`` runs, the test checks three things:

    - the accelerator's ``free_memory`` was called exactly once;
    - the accelerator no longer holds any models;
    - the ``_model`` attribute has been removed from the instance.
    """
    dummy_lm = _DummyLM()
    linear = torch.nn.Linear(1, 1)
    fake_acc = _FakeAccelerator(linear)

    # Attach the same module both directly and through the fake accelerator.
    dummy_lm._model = linear
    dummy_lm.accelerator = fake_acc

    dummy_lm.clean()

    assert fake_acc.free_memory_calls == 1
    assert fake_acc._models == []
    assert not hasattr(dummy_lm, "_model")