Skip to content

Commit 08018ab

Browse files
committed
misc: pytest version test
1 parent 7839fc4 commit 08018ab

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

tests/test_error_quit.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22
import random
3+
import subprocess
4+
import pytest
35
import time
46

57
import torch
@@ -86,13 +88,45 @@ def run():
8688
ps.update(checkpoint_name, queue.put, ranks=ranks)
8789
# sleep 3s to wait process group is destroyed
8890
time.sleep(3)
89-
except RuntimeError as e:
91+
except Exception as e:
9092
print(f"[rank{rank}] Caught exception from worker process: {e}")
9193
assert isinstance(e, RuntimeError)
94+
assert "failed to update weights due to remote error(s)" in str(e)
9295
finally:
9396
ps.unregister_checkpoint(checkpoint_name)
9497
queue.put(None)
9598

9699

100+
@pytest.mark.gpu
def test_update():
    """Launch this file under ``torchrun`` on every visible GPU and require a clean exit.

    One worker process is started per GPU reported by
    ``torch.cuda.device_count()``. Each worker executes this file as a script,
    which calls ``run()`` through the ``__main__`` guard. The test passes only
    if ``torchrun`` returns exit code 0.

    The test is skipped (not failed) when fewer than 2 GPUs are available.
    """
    # Imported locally so the module-level import block stays untouched.
    import socket

    world_size = torch.cuda.device_count()
    if world_size < 2:
        # A missing hardware prerequisite is a skip condition, not a failure.
        pytest.skip("This test requires at least 2 GPUs.")

    master_addr = "localhost"
    # Ask the OS for a currently unused port instead of guessing one at random,
    # which could collide with a port that is already bound.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        master_port = sock.getsockname()[1]

    cmd = [
        "torchrun",
        "--nproc_per_node",
        str(world_size),
        "--master_addr",
        master_addr,
        "--master_port",
        str(master_port),
        "tests/test_error_quit.py",
    ]

    # The script path in ``cmd`` is relative, so run from the parent of the
    # directory that contains this file.
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Worker output is deliberately not captured: it streams straight to the
    # test log. A non-zero exit does not raise here; it is asserted below.
    result = subprocess.run(cmd, cwd=repo_root)

    assert result.returncode == 0, (
        f"torchrun exited with code {result.returncode}"
    )
129+
130+
97131
# Script entry point: test_update() launches this file through torchrun,
# so each spawned worker process reaches run() through this guard.
if __name__ == "__main__":
    run()

0 commit comments

Comments
 (0)