Motivation
For advanced super-pod hardware such as GB200 NVL72, NVLink spans multiple nodes. To take advantage of this cross-node NVLink, the placement groups created by RayResourcePool need to be placed on specific nodes. The same approach can also be used in clusters built on other scale-up technologies.
Proposed Design
Ray cluster
When starting a Ray worker node, attach labels to it that carry the topology information.
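As a rough illustration (not an existing verl or Ray convention), the topology label could be encoded as a custom resource when the worker joins the cluster; the resource name scaleup-domain-0 below is a made-up example:

```python
# Hypothetical labeling scheme: each worker node advertises its scale-up domain
# as a custom resource when it joins the cluster, e.g. on each worker node:
#   ray start --address=<head-ip>:6379 --resources='{"scaleup-domain-0": 1}'
# The labels can then be read back from the cluster state:
import ray

ray.init(address="auto")
for node in ray.nodes():
    domains = [r for r in node["Resources"] if r.startswith("scaleup-domain-")]
    print(node["NodeManagerAddress"], domains)
```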
verl ResourcePoolManager
After the array of placement groups is created, run an additional scheduler over it.
The scheduler reads the topology labels of the nodes and produces placement groups bound to specific node IDs, so that tasks are assigned to nodes that can take advantage of the high-performance scale-up fabric.
For example, in verl/tree/main/verl/single_controller/ray/base.py:
```python
def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"):
    if self.pgs is not None:
        return self.pgs

    pg_name_prefix = (
        name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:"
    )
    if device_name == "npu":
        device_name = "NPU"
    elif device_name == "cuda":
        device_name = "GPU"

    bundle = {"CPU": self.max_colocate_count}
    if self.use_gpu:
        bundle[device_name] = 1
        if self.accelerator_type is not None:
            bundle[self.accelerator_type] = 1e-4
    pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store]

    lifetime = "detached" if self.detached else None
    pgs = [
        placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime)
        for idx, bundles in enumerate(pg_scheme)
    ]

    # node_schedule implements the topology-aware node selection. Before scheduling
    # the placement groups look like [{"GPU": 8}, {"GPU": 8}, {"GPU": 8}, {"GPU": 8}];
    # after node_schedule they become
    # [{"GPU": 8, "node:192.168.1.100": 1}, {"GPU": 8, "node:192.168.1.101": 1},
    #  {"GPU": 8, "node:192.168.1.102": 1}, {"GPU": 8, "node:192.168.1.103": 1}].
    # These four nodes (192.168.1.100-103) share the same scale-up domain, so the
    # model gets better network performance.
    pgs = node_schedule(pgs)

    ray.get([pg.ready() for pg in pgs])
    self.pgs = sort_placement_group_by_node_ip(pgs)
    return pgs
```
node_schedule adds a node label to each placement group so that every placement group is pinned to a specific node.
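For illustration, here is a hypothetical sketch of what node_schedule could look like. It assumes the worker nodes were started with a scaleup-domain-<id> custom resource as in the earlier example, and that the scheduling is applied to the bundle specs (pg_scheme) before placement_group() is called, since bundle resources cannot be changed after a placement group is created; it pins each placement group to a node by requesting a tiny amount of Ray's built-in node:<ip> resource.

```python
# Hypothetical sketch of node_schedule, assuming topology labels are exposed as
# "scaleup-domain-<id>" custom resources on each worker node.
from collections import defaultdict

import ray


def node_schedule(pg_scheme):
    """Pin each placement-group spec to a node, preferring one scale-up domain."""
    # Group alive nodes by their scale-up domain label.
    domains = defaultdict(list)
    for node in ray.nodes():
        if not node["Alive"]:
            continue
        for resource in node["Resources"]:
            if resource.startswith("scaleup-domain-"):
                domains[resource].append(node["NodeManagerAddress"])

    if not domains:
        return pg_scheme  # no topology labels found; leave the specs unchanged

    # Prefer the domain with the most nodes (e.g. one GB200 NVL72 rack).
    node_ips = max(domains.values(), key=len)

    # Ray exposes a per-node resource "node:<ip>"; requesting a small fraction of it
    # pins all bundles of one placement group to that specific node.
    for bundles, ip in zip(pg_scheme, node_ips):
        for bundle in bundles:
            bundle[f"node:{ip}"] = 0.001
    return pg_scheme
```

With labels like these in place, the four placement groups in the example above would land on 192.168.1.100-103, all inside the same scale-up domain.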