Skip to content

Commit 4df01bb

Browse files
committed
feat: test for inplace pin memory added
1 parent 7e19c8a commit 4df01bb

File tree

1 file changed

+90
-2
lines changed

1 file changed

+90
-2
lines changed

tests/test_update.py

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def error_run(weights: list[tuple[str, torch.Tensor]]):
8282
try:
8383
trigger_error(socket_paths)
8484
except RuntimeError as e:
85-
assert str(e) == "Failed to update weights due to remote errors"
85+
assert str(e) == "Some workers failed to update weights"
8686

8787

8888
def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):
@@ -96,7 +96,7 @@ def check(names_to_check: dict[str, bool], weights: list[tuple[str, torch.Tensor
9696
for name, weight in weights:
9797
if name not in named_tensors:
9898
continue
99-
assert (weight == named_tensors[name]).all()
99+
assert (weight == named_tensors[name]).all(), f"Tensor {name} does not match!"
100100
names_to_check[name] = True
101101

102102
def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, str]]):
@@ -163,6 +163,61 @@ def run(
163163
assert proc.exitcode == 0
164164

165165

def run_with_files(
    checker_func: callable,
):
    """Register and broadcast a checkpoint assembled from mixed sources.

    Must run under torchrun (reads RANK from the environment). The rank's
    test tensors are split three ways so every registration path is
    exercised: one third is saved to /dev/shm (tmpfs) as a .safetensors
    file, one third to /tmp (disk) as a .safetensors file, and one third is
    registered directly from memory. A checker subprocess (checker_func)
    receives the broadcast weights via the queue and verifies them against
    the full tensor set.

    Args:
        checker_func: spawned with (rank, device_uuid, named_tensors, queue);
            must exit with code 0 on success.
    """
    rank = int(os.getenv("RANK"))
    ctx = get_context("spawn")
    queue = ctx.Queue()
    ps = ParameterServer(auto_pg=True)
    # Query the UUID through the server's own device manager so it matches
    # the device the ParameterServer actually bound to.
    _device_uuid = _get_physical_gpu_id(ps.device_manager, rank)
    named_tensors = dict(gen_test_tensors(rank))

    # Split the tensors into three disjoint groups covering all of them:
    #   first third  -> /dev/shm (tmpfs) .safetensors file
    #   middle third -> /tmp (disk) .safetensors file
    #   last third   -> kept in memory
    import safetensors.torch

    files = []
    dev_shm_dir = "/dev/shm/checkpoint_engine_tests"  # noqa: S108
    disk_dir = "/tmp/checkpoint_engine_tests"  # noqa: S108
    os.makedirs(dev_shm_dir, exist_ok=True)
    os.makedirs(disk_dir, exist_ok=True)
    tensors_items = list(named_tensors.items())
    third = len(tensors_items) // 3
    tensors_in_dev_shm = dict(tensors_items[:third])
    tensors_in_disk = dict(tensors_items[third : 2 * third])
    tensors_in_memory = dict(tensors_items[2 * third :])

    disk_files = [
        os.path.join(disk_dir, f"rank{_rank}_checkpoint.safetensors")
        for _rank in range(get_world_size())
    ]
    safetensors.torch.save_file(tensors_in_disk, disk_files[rank])
    time.sleep(1)
    files.append(disk_files[rank])
    dev_shm_files = [
        os.path.join(dev_shm_dir, f"rank{_rank}_checkpoint.safetensors")
        for _rank in range(get_world_size())
    ]
    safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank])
    time.sleep(1)
    files.append(dev_shm_files[rank])

    checkpoint_name = "test_with_files"
    proc = ctx.Process(target=checker_func, args=(rank, _device_uuid, named_tensors, queue))
    proc.start()
    ps.register_checkpoint(checkpoint_name, named_tensors=tensors_in_memory, files=files)
    ps.gather_metas(checkpoint_name)
    ps.update(checkpoint_name, queue.put, ranks=[])
    # sleep 3s to wait process group is destroyed
    time.sleep(3)
    ps.unregister_checkpoint(checkpoint_name)
    queue.put(None)
    proc.join()
    assert proc.exitcode == 0
219+
220+
166221
@pytest.mark.gpu
167222
@pytest.mark.parametrize(
168223
"test_name,rank_list",
@@ -211,6 +266,37 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
211266
assert result.returncode == 0
212267

213268

269+
@pytest.mark.gpu
def test_update_with_files(test_name: str = "test_with_files"):
    """Launch this file under torchrun so each rank runs run_with_files().

    Exercises checkpoint registration from a mix of in-memory tensors and
    .safetensors files (see the test_with_files branch in __main__).
    Requires at least 2 GPUs.
    """
    world_size = device_manager.device_module.device_count()
    assert world_size >= 2, "This test requires at least 2 GPUs."

    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    launch_cmd = [
        "torchrun",
        "--nproc_per_node",
        str(world_size),
        "--master_addr",
        "localhost",
        "--master_port",
        str(25400),
        __file__,
        test_name,
        "[]",  # empty rank_list: update all ranks
    ]

    completed = subprocess.run(  # noqa: S603
        launch_cmd,
        capture_output=False,
        text=True,
        cwd=repo_root,
        shell=False,
        check=False,
    )
    assert completed.returncode == 0
298+
299+
214300
if __name__ == "__main__":
215301
run_with_pytest = "PYTEST_CURRENT_TEST" in os.environ
216302
if not run_with_pytest:
@@ -230,5 +316,7 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
230316
expected_exception=RuntimeError,
231317
exception_msg="Failed to update weights due to remote errors",
232318
)
319+
elif test_type == "test_with_files":
320+
run_with_files(checker_proc)
233321
else:
234322
raise ValueError(f"Unknown TEST_TYPE: {test_type}")

0 commit comments

Comments
 (0)