Skip to content

Commit e126a89

Browse files
committed
feat: test_assign_receiver_ranks.py: add unit tests for _assign_receiver_ranks function
1 parent e79ef26 commit e126a89

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

checkpoint_engine/ps.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Callable
1212
from datetime import timedelta
1313
from functools import lru_cache
14-
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
14+
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple, TypeVar
1515

1616
import httpx
1717
import numpy as np
@@ -531,17 +531,25 @@ def _gen_h2d_buckets(
531531
return _assign_receiver_ranks(buckets, actual_local_topo, remote_topo)
532532

533533

534+
T = TypeVar("T")
535+
536+
534537
def _assign_receiver_ranks(
535-
buckets: list[tuple[int, H2DBucket]],
538+
buckets: list[tuple[int, "T"]],
536539
local_topo: dict[str, set[int]],
537540
remote_topo: dict[str, set[int]],
538-
) -> list[tuple[int, int, H2DBucket]]:
541+
) -> list[tuple[int, int, "T"]]:
539542
"""
540543
(owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)
541544
542545
Assign receiver ranks to buckets. If ranks is empty, assign the owner_rank as receiver_rank.
543546
GPU-rdma_device topology will be considered to make full use of the bandwidth.
544547
"""
548+
assert local_topo, "local_topo should not be empty"
549+
assert remote_topo, "remote_topo should not be empty"
550+
if not buckets:
551+
logger.warning("bucket list is empty, no need to assign receiver ranks")
552+
return []
545553
rank_to_rdma_device = {
546554
rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
547555
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import os
2+
import sys
3+
4+
5+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
6+
from checkpoint_engine.ps import _assign_receiver_ranks
7+
8+
9+
class TestAssignReceiverRanks:
    """Unit tests for ``_assign_receiver_ranks``."""

    def test_basic_functionality(self):
        """Check the result shape and rank assignment for a one-rank-per-device topology.

        Local and remote topologies are identical: eight RDMA devices, each
        mapped to exactly one rank. Buckets are distributed round-robin over
        the eight owner ranks.
        """
        # Named once so the fixtures and the assertions cannot drift apart.
        num_ranks = 8
        num_buckets = 100

        buckets = [(i % num_ranks, f"bucket{i}") for i in range(num_buckets)]
        local_topo = {f"rdma{i}": {i} for i in range(num_ranks)}
        remote_topo = {f"rdma{i}": {i} for i in range(num_ranks)}
        # Built once up front instead of being rebuilt for every result item.
        expected_buckets = {f"bucket{i}" for i in range(num_buckets)}

        result = _assign_receiver_ranks(buckets, local_topo, remote_topo)

        assert len(result) == num_buckets
        for item in result:
            assert len(item) == 3
            assert isinstance(item[0], int)  # receiver_rank
            assert isinstance(item[1], int)  # owner_rank
            assert isinstance(item[2], str)  # bucket

        for receiver_rank, owner_rank, bucket in result:
            assert receiver_rank in range(num_ranks)
            # This test expects the receiver to match the owner (modulo the rank count).
            assert owner_rank % num_ranks == receiver_rank
            assert bucket in expected_buckets

    def test_empty_buckets(self):
        """An empty bucket list must produce an empty assignment list."""
        buckets = []
        local_topo = {"rdma0": {0}}
        remote_topo = {"rdma0": {0}}

        result = _assign_receiver_ranks(buckets, local_topo, remote_topo)

        assert result == []

0 commit comments

Comments
 (0)