Skip to content

Commit 2ac012e

Browse files
Ali-Tehrani authored
and facebook-github-bot committed
Add intra_group_size to topology (meta-pytorch#3696)
Summary: Context --------- Every GB200 node has 2 B200 GPUs attached to it; however, NVLink allows up to 72 B200 GPUs to be connected. The planner needs to know how big the intra-topology group size is going to be. This causes the `local_world_size` to be different from the `intra_group_size`. Implementation ------------------ - Topology class: - Adds `pod_size`, and uses it to calculate the `intra_group_size` (the maximum number of processes linked with high intra bandwidth) in the Topology class. If `pod_size` isn't given, it defaults to `local_world_size`. - `shard_estimators.py` - The shard estimators now use the `intra_group_size` instead of `local_world_size`; this allows RW/TW/CW to properly account for the larger NVLink domain that comes with the pods. Reviewed By: isururanawaka Differential Revision: D91617887
1 parent 06d0acb commit 2ac012e

File tree

4 files changed

+28
-5
lines changed

4 files changed

+28
-5
lines changed

torchrec/distributed/planner/enumerators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(
8888
) -> None:
8989
self._compute_device: str = topology.compute_device
9090
self._world_size: int = topology.world_size
91-
self._local_world_size: int = topology.local_world_size
91+
self._local_world_size: int = topology.intra_group_size
9292
self._batch_size: int = batch_size
9393
self._constraints = constraints
9494
self._sharder_map: Dict[str, ModuleSharder[nn.Module]] = {}

torchrec/distributed/planner/partitioners.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ def partition(
302302
== PartitionByType.DEVICE.value
303303
):
304304
if minheap_devices is None:
305+
# Local world size should be used, since number of GPU per host/CPU
305306
minheap_devices = self._establish_minheap(
306307
devices, storage_constraint.local_world_size
307308
)
@@ -652,11 +653,11 @@ def _cohost_partition(
652653
def _get_host_level_devices(
653654
topology: Topology, all_devices: List[DeviceHardware]
654655
) -> List[List[DeviceHardware]]:
655-
num_hosts: int = topology.world_size // topology.local_world_size
656+
num_hosts: int = topology.world_size // topology.intra_group_size
656657
host_level_devices: List[List[DeviceHardware]] = []
657658
for i in range(num_hosts):
658659
devices_in_host = all_devices[
659-
i * topology.local_world_size : (i + 1) * topology.local_world_size
660+
i * topology.intra_group_size : (i + 1) * topology.intra_group_size
660661
]
661662
host_level_devices.append(devices_in_host)
662663
return host_level_devices

torchrec/distributed/planner/shard_estimators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def estimate(
234234
sharding_type=sharding_option.sharding_type,
235235
batch_sizes=batch_sizes,
236236
world_size=self._topology.world_size,
237-
local_world_size=self._topology.local_world_size,
237+
local_world_size=self._topology.intra_group_size,
238238
input_lengths=sharding_option.input_lengths,
239239
input_data_type_size=input_data_type_size,
240240
table_data_type_size=table_data_type_size,
@@ -1146,7 +1146,7 @@ def estimate(
11461146
shard_sizes=[shard.size for shard in sharding_option.shards],
11471147
batch_sizes=batch_sizes,
11481148
world_size=self._topology.world_size,
1149-
local_world_size=self._topology.local_world_size,
1149+
local_world_size=self._topology.intra_group_size,
11501150
input_lengths=sharding_option.input_lengths,
11511151
num_poolings=num_poolings,
11521152
caching_ratio=caching_ratio if caching_ratio else UVM_CACHING_RATIO,

torchrec/distributed/planner/types.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ def __init__(
286286
hbm_cap: Optional[int] = None,
287287
ddr_cap: Optional[int] = None,
288288
local_world_size: Optional[int] = None,
289+
pod_size: Optional[int] = None,
289290
hbm_mem_bw: float = HBM_MEM_BW,
290291
ddr_mem_bw: float = DDR_MEM_BW,
291292
hbm_to_ddr_mem_bw: float = HBM_TO_DDR_MEM_BW,
@@ -310,6 +311,10 @@ def __init__(
310311
"cuda",
311312
"mtia",
312313
], f"unsupported compute device {compute_device}"
314+
if pod_size and pod_size > world_size:
315+
raise ValueError(
316+
f"pod_size={pod_size} cannot be greater than world_size={world_size}"
317+
)
313318

314319
self._compute_device = compute_device
315320
self._world_size = world_size
@@ -343,9 +348,19 @@ def __init__(
343348
)
344349
)
345350

351+
# Local world size is the number of devices (GPUs) in a single node
346352
self._local_world_size: int = (
347353
local_world_size if local_world_size else world_size
348354
)
355+
self._pod_size: int = pod_size
356+
# Maximum numb of devices with high bandwidth interconnect (e.g. NVLink)
357+
# if pod_size isn't given, then assumes local_world_size is maximum group size
358+
self._intra_group_size: int = (
359+
pod_size * self._local_world_size
360+
if pod_size is not None
361+
else self._local_world_size
362+
)
363+
349364
self._hbm_mem_bw = hbm_mem_bw
350365
self._ddr_mem_bw = ddr_mem_bw
351366
self._hbm_to_ddr_mem_bw = hbm_to_ddr_mem_bw
@@ -381,6 +396,11 @@ def world_size(self) -> int:
381396
def local_world_size(self) -> int:
382397
return self._local_world_size
383398

399+
@property
400+
def intra_group_size(self) -> int:
401+
# The largest set of nodes connected with high intra-node bandwidth (e.g. NVLink)
402+
return self._intra_group_size
403+
384404
@property
385405
def hbm_mem_bw(self) -> float:
386406
return self._hbm_mem_bw
@@ -424,6 +444,7 @@ def __repr__(self) -> str:
424444
for idx, device in enumerate(self._devices):
425445
topology_repr += f"\tdevice {idx} {device}\n"
426446
topology_repr += f"local_world_size={self._local_world_size} \n"
447+
topology_repr += f"intra_group_size={self._intra_group_size} \n"
427448
topology_repr += str(self._comms_bandwidths) + "\n"
428449
return topology_repr
429450

@@ -449,6 +470,7 @@ def _hash(self) -> int:
449470
hbms,
450471
ddrs,
451472
self._local_world_size,
473+
self._intra_group_size,
452474
self._hbm_mem_bw,
453475
self._ddr_mem_bw,
454476
self._hbm_to_ddr_mem_bw,

0 commit comments

Comments
 (0)