Skip to content

Commit 5b3e9da

Browse files
committed
feat: basic test added
1 parent 4c87948 commit 5b3e9da

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

tests/test_pin_memory.py

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,62 @@
1+
import os
2+
3+
import pytest
4+
import torch
5+
6+
from checkpoint_engine.ps import ParameterServer
7+
8+
9+
def generate_dummy_checkpoint() -> dict[str, torch.Tensor]:
    """Build a small dummy checkpoint of randomly initialised tensors.

    Returns:
        A mapping from parameter name to a freshly sampled ``torch.randn``
        tensor: two weight matrices and their matching bias vectors.
    """
    # Listed in sampling order so the RNG stream is consumed the same way
    # on every call.
    shapes = (
        ("layer1.weight", (1024, 1024)),
        ("layer1.bias", (1024,)),
        ("layer2.weight", (2048, 1024)),
        ("layer2.bias", (2048,)),
    )
    return {name: torch.randn(*shape) for name, shape in shapes}
20+
21+
22+
@pytest.mark.gpu
def test_register_pin_memory():
    """Check ``ParameterServer`` bookkeeping while checkpoints are registered and unregistered.

    The test walks through, in order:

    * registering and unregistering a regular checkpoint;
    * registering a checkpoint through the shared memory pool next to a
      regular one;
    * attempting to register a second shared-pool checkpoint while the pool
      is still in use;
    * releasing the shared pool and handing it to the second checkpoint;
    * unregistering everything, including a name that was already removed.

    After each step it asserts on ``ps._memory_pool`` and
    ``ps._current_shared_memory_pool_user``.
    """
    from unittest import mock

    # Scope RANK / WORLD_SIZE to this test. Writing them straight into
    # os.environ would leak the values into every later test in the process.
    # Presumably ParameterServer reads them on construction -- confirm.
    with mock.patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "1"}):
        ps = ParameterServer()
        checkpoint1 = generate_dummy_checkpoint()
        checkpoint_shared1 = generate_dummy_checkpoint()
        checkpoint2 = generate_dummy_checkpoint()
        checkpoint_shared2 = generate_dummy_checkpoint()

        # A regular checkpoint is gone from the pool once it is unregistered.
        ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1)
        ps.unregister_checkpoint("test_checkpoint1")
        assert "test_checkpoint1" not in ps._memory_pool

        # A shared-pool checkpoint is stored under the "__shared_memory_pool__"
        # key rather than its own name, and is recorded as the current user.
        # A regular checkpoint registered alongside keeps its own key.
        ps.register_checkpoint(
            "test_checkpoint_shared1", named_tensors=checkpoint_shared1, use_shared_memory_pool=True
        )
        ps.register_checkpoint("test_checkpoint2", named_tensors=checkpoint2)
        assert "test_checkpoint_shared1" not in ps._memory_pool
        assert "__shared_memory_pool__" in ps._memory_pool
        assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
        assert "test_checkpoint2" in ps._memory_pool

        # Original author's note on this call: "this will fail".
        # NOTE(review): no exception is caught here, so the test only passes if
        # the rejected registration does not raise. Confirm against
        # ParameterServer.register_checkpoint; if it raises, this call needs
        # pytest.raises.
        ps.register_checkpoint(
            "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
        )
        assert "test_checkpoint_shared2" not in ps._memory_pool
        assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"

        # Unregistering the current user clears the user field but keeps the
        # shared pool entry itself.
        ps.unregister_checkpoint("test_checkpoint_shared1")
        assert ps._current_shared_memory_pool_user == ""
        assert "__shared_memory_pool__" in ps._memory_pool

        # With the pool free again, the second shared checkpoint can claim it.
        ps.register_checkpoint(
            "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
        )
        assert "test_checkpoint_shared2" not in ps._memory_pool
        assert "__shared_memory_pool__" in ps._memory_pool
        assert ps._current_shared_memory_pool_user == "test_checkpoint_shared2"

        # "test_checkpoint1" was already unregistered above; this second call
        # is kept from the original test and is expected to leave it absent.
        ps.unregister_checkpoint("test_checkpoint1")
        assert "test_checkpoint1" not in ps._memory_pool
        ps.unregister_checkpoint("test_checkpoint2")
        assert "test_checkpoint2" not in ps._memory_pool

        # Final cleanup: the user field is cleared, the shared pool entry stays.
        ps.unregister_checkpoint("test_checkpoint_shared2")
        assert ps._current_shared_memory_pool_user == ""
        assert "__shared_memory_pool__" in ps._memory_pool

0 commit comments

Comments
 (0)